pacs.py
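# pacs.py -- compare ETM research topics to the PACS classification:
# correlate document-topic weights with binarized PACS codes, export a
# LaTeX table of each topic's most correlated categories, and plot a
# topic x PACS correlation heatmap annotated with PACS section names.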

import pandas as pd
import pickle
import numpy as np
from scipy.stats import entropy
from os.path import join as opj, exists
import argparse
import textwrap
from sklearn.preprocessing import MultiLabelBinarizer
parser = argparse.ArgumentParser()
parser.add_argument("--input")
args = parser.parse_args()
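# Usage sketch (assuming the --input directory holds the ngrams.csv,
# articles.csv, topics.csv, *.npy and etm_instance.pickle files produced
# by the earlier pipeline steps):
#   python pacs.py --input <run_dir>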
ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

articles = pd.read_parquet("../semantics/inspire-harvest/database/articles.parquet")[["article_id", "date_created", "pacs_codes", "curated"]]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values
topics = pd.read_csv(opj(args.input, "topics.csv"))
is_junk_topic = np.array(topics["label"].str.contains("Junk"))
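# Keep only articles that carry at least one PACS code present in the
# official code list, so topics can be compared against the classification.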
pacs_description = pd.read_csv("inspire-harvest/database/pacs_codes.csv").set_index("code")["description"].to_dict()
codes = set(pacs_description.keys())
has_pacs_code = articles["pacs_codes"].map(lambda l: set(l) & codes).map(len) > 0
articles = articles[has_pacs_code]
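# One-hot encode each article's list of PACS codes into a binary
# (n_articles x n_categories) matrix; cat_labels maps columns back to codes.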
binarizer = MultiLabelBinarizer()
pacs = binarizer.fit_transform(articles["pacs_codes"])
n_categories = pacs.shape[1]
print((pacs - pacs.mean(axis=0)).shape)

cat_classes = np.array([np.identity(n_categories)[cl] for cl in range(n_categories)])
cat_labels = binarizer.inverse_transform(cat_classes)
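# Load the trained topic model; cache its document-topic distribution
# (theta) on disk so later runs can reuse it instead of recomputing it.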
with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)

if not exists(opj(args.input, "theta.npy")):
    theta = etm_instance.get_document_topic_dist()
    np.save(opj(args.input, "theta.npy"), theta)
else:
    theta = np.load(opj(args.input, "theta.npy"))
print(theta[:, ~is_junk_topic].mean(axis=0).sum())
topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
print("Theta average entropy:", np.exp(entropy(theta[:, ~is_junk_topic], axis=1)).mean())
print("Topic matrix average entropy:", np.nanmean(np.exp(np.nan_to_num(entropy(topic_matrix, axis=1)))))
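# Pearson correlation between topic weights and PACS indicator variables:
# R[k, c] = cov(theta[:, k], pacs[:, c]) / (std(theta[:, k]) * std(pacs[:, c]))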
theta = theta[has_pacs_code, :]
n_articles = theta.shape[0]

R = (theta - theta.mean(axis=0)).T @ (pacs - pacs.mean(axis=0)) / n_articles
R /= np.outer(theta.std(axis=0), pacs.std(axis=0))
print(R.shape)
print(R)
topics["top_pacs"] = ""
topics["relevant"] = 1
topics.loc[is_junk_topic, "label"] = "Uninterpretable"
topics.loc[is_junk_topic, "relevant"] = 0
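# For each topic, list its five most correlated PACS categories (shortened
# descriptions plus correlation coefficients) for the LaTeX table below.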
for i in range(R.shape[0]):
    ind = np.argsort(R[i, :])
    top = ind[-5:][::-1]
    topics.loc[i, "top_pacs"] = "\\\\ ".join([f"{textwrap.shorten(pacs_description[cat_labels[j][0]], width=40)} ({R[i,j]:.2f})" for j in top if cat_labels[j][0] in pacs_description])
topics["top_words"] = topics["top_words"].str.replace(",", ", ")
topics["top_words"] = topics["top_words"].str.replace("_", "\\_")
# topics["top_words"] = topics["top_words"].str.replace(r"(^(.*)-ray)|((.*)-ray)", "$\\gamma$-ray", regex=True)
# topics["top_words"] = topics["top_words"].apply(lambda s: "\\\\ ".join(textwrap.wrap(s, width=45, break_long_words=False)))
# topics["top_words"] = topics["top_words"].apply(lambda s: '\\begin{tabular}{l}' + s + '\\end{tabular}')
topics["top_pacs"] = topics["top_pacs"].apply(lambda s: '\\shortstack[l]{' + s + '}')
topics.sort_values(["relevant", "label"], ascending=[False, True], inplace=True)
pd.set_option('display.max_colwidth', None)
latex = topics.to_latex(
    columns=["label", "top_words", "top_pacs"],
    header=["Research area", "Top words", "Most correlated PACS categories"],
    index=False,
    multirow=True,
    multicolumn=True,
    longtable=True,
    column_format='p{0.15\\textwidth}|b{0.425\\textwidth}|b{0.425\\textwidth}',
    escape=False,
    caption="Research areas, their top-words, and their correlation with a standard classification (PACS).",
    label="table:research_areas"
)
latex = latex.replace(')} \\\\\n', ')}\\\\ \\hline\n')
with open(opj(args.input, "topics.tex"), "w+") as fp:
    fp.write(latex)
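# Restrict the heatmap to PACS categories with at least 100 articles and
# drop junk topics. Breaks in the PACS hierarchy are located from the code
# strings: the first two characters give the top-level section, the first
# five the subcategory.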
keep_pacs = pacs.sum(axis=0) >= 100
R = R[:, keep_pacs]
R = R[~is_junk_topic]

labels = [cat_labels[i][0] for i in range(len(keep_pacs)) if keep_pacs[i]]
breaks_1 = np.array([i == 0 or labels[i-1][:5] != labels[i][:5] for i in range(len(labels))])
breaks_2 = np.array([i == 0 or labels[i-1][:2] != labels[i][:2] for i in range(len(labels))])

order = np.load(opj(args.input, "topics_order.npy"))
print(order)
R = R[order, :]
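# --- Heatmap of topic/PACS correlations, rendered with the pgf backend ---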
import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
from matplotlib.gridspec import GridSpec

plt.clf()
fig = plt.figure(figsize=(6.4*2, 6.4*2.5))
gs = GridSpec(9, 8, hspace=0, wspace=0)

ax_heatmap = fig.add_subplot(gs[0:8, 2:8])
sns.heatmap(
    R.T,
    cmap="RdBu",
    vmin=-0.5, vmax=+0.5,
    square=False,
    ax=ax_heatmap,
    cbar_kws={"shrink": 0.25}
)
ax_heatmap.invert_yaxis()
ax_heatmap.yaxis.set_visible(False)
topics = pd.read_csv(opj(args.input, "topics.csv"))
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
ordered_topics = [topics[i] for i in order]
ax_heatmap.set_xticklabels(ordered_topics, rotation=90)

xmin, xmax = ax_heatmap.get_xlim()
ymin, ymax = ax_heatmap.get_ylim()

ax_breaks = fig.add_subplot(gs[0:8, 0:2])
ax_breaks.set_xlim(0, 8)
ax_breaks.set_ylim(ymin, ymax)
ax_breaks.axis("off")
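# Extract the human-readable name of each top-level PACS section from the
# PACS HTML index, keyed by the two-digit section prefix.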
import re

html = open("../semantics/inspire-harvest/database/pacs.html", "r").read()
# Regex pattern to extract the code and label
pattern = r'<span class="pacs_num">(\d{2}\.\d{2}\.\d{2})</span>\s*(.*?)</h2>'
matches = re.findall(pattern, html)
matches = {code[:2]: label for code, label in matches}
prev_break = 0
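# Draw separators at subcategory (breaks_1) and section (breaks_2) boundaries,
# and label each section with its name, shortening the text until it fits in
# the vertical space between two consecutive section breaks.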
for i in range(len(breaks_1)):
    if breaks_1[i]:
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125/2)
    if breaks_2[i] or i == len(breaks_1)-1:
        ax_breaks.hlines(y=i-0.5, xmin=2, xmax=10, color="lightgray", lw=0.125)
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125)
        print(prev_break, i, labels[i], prev_break+(i-prev_break)/2.0)
        if prev_break != i:
            text = textwrap.shorten(matches[labels[i-1][:2]], width=35*3)
            lines = textwrap.wrap(text, width=35)
            too_small = len(lines) >= 0.35*(i-prev_break-2)
            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35*2)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.35*(i-prev_break-2)
            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.5*(i-prev_break-2)
            if (not too_small) or i == len(breaks_1)-1:
                ax_breaks.text(
                    7.5,
                    prev_break+(i-prev_break)/2.0,
                    "\n".join(lines),
                    ha="right",
                    va="center",
                    fontsize=8
                )
        prev_break = i
plt.savefig(opj(args.input, "pacs_clustermap.eps"), bbox_inches="tight")