
pacs.py

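# Relate ETM topics to the PACS classification: compute topic-PACS correlations,
# export them as a LaTeX table (topics.tex), and plot them as an annotated
# heatmap (pacs_clustermap.eps).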
import pandas as pd
import pickle
import numpy as np
from scipy.stats import entropy
from os.path import join as opj, exists
import argparse
import textwrap
from sklearn.preprocessing import MultiLabelBinarizer

parser = argparse.ArgumentParser()
parser.add_argument("--input")
args = parser.parse_args()
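
# Load article metadata harvested from INSPIRE and keep records with a usable creation year.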
ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

articles = pd.read_parquet("../semantics/inspire-harvest/database/articles.parquet")[["article_id", "date_created", "pacs_codes", "curated"]]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values

topics = pd.read_csv(opj(args.input, "topics.csv"))
is_junk_topic = np.array(topics["label"].str.contains("Junk"))

pacs_description = pd.read_csv("inspire-harvest/database/pacs_codes.csv").set_index("code")["description"].to_dict()
codes = set(pacs_description.keys())

has_pacs_code = articles["pacs_codes"].map(lambda l: set(l) & codes).map(len) > 0
articles = articles[has_pacs_code]
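
# One-hot encode each article's PACS codes; cat_labels maps each column index
# of the indicator matrix back to its PACS code.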
binarizer = MultiLabelBinarizer()
pacs = binarizer.fit_transform(articles["pacs_codes"])
n_categories = pacs.shape[1]

print((pacs - pacs.mean(axis=0)).shape)

cat_classes = np.array([np.identity(n_categories)[cl] for cl in range(n_categories)])
cat_labels = binarizer.inverse_transform(cat_classes)

with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)

if not exists(opj(args.input, "theta.npy")):
    theta = etm_instance.get_document_topic_dist()
    np.save(opj(args.input, "theta.npy"), theta)
else:
    theta = np.load(opj(args.input, "theta.npy"))
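
# Sanity checks: total non-junk topic mass per document, and the effective number
# of topics (exponential of the Shannon entropy) for both the document-topic
# distribution and the topic counts matrix.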
print(theta[:, ~is_junk_topic].mean(axis=0).sum())

topic_matrix = np.load(opj(args.input, "topics_counts.npy"))

print("Theta average entropy:", np.exp(entropy(theta[:, ~is_junk_topic], axis=1)).mean())
print("Topic matrix average entropy:", np.nanmean(np.exp(np.nan_to_num(entropy(topic_matrix, axis=1)))))
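
# Keep theta rows for articles with a known PACS code, then compute the
# topic x PACS Pearson correlation matrix R (cross-covariance normalized by
# the standard deviations).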
theta = theta[has_pacs_code, :]
n_articles = theta.shape[0]

R = (theta - theta.mean(axis=0)).T @ (pacs - pacs.mean(axis=0)) / n_articles
R /= np.outer(theta.std(axis=0), pacs.std(axis=0))

print(R.shape)
print(R)
topics["top_pacs"] = ""
topics["relevant"] = 1
topics.loc[is_junk_topic, "label"] = "Uninterpretable"
topics.loc[is_junk_topic, "relevant"] = 0
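
# For each topic, keep the five most correlated PACS categories, formatted as
# "description (r)" entries for the LaTeX table.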
for i in range(R.shape[0]):
    ind = np.argsort(R[i, :])
    top = ind[-5:][::-1]
    topics.loc[i, "top_pacs"] = "\\\\ ".join([f"{textwrap.shorten(pacs_description[cat_labels[j][0]], width=40)} ({R[i,j]:.2f})" for j in top if cat_labels[j][0] in pacs_description])

topics["top_words"] = topics["top_words"].str.replace(",", ", ")
topics["top_words"] = topics["top_words"].str.replace("_", "\\_")
# topics["top_words"] = topics["top_words"].str.replace(r"(^(.*)-ray)|((.*)-ray)", "$\\gamma$-ray", regex=True)
# topics["top_words"] = topics["top_words"].apply(lambda s: "\\\\ ".join(textwrap.wrap(s, width=45, break_long_words=False)))
# topics["top_words"] = topics["top_words"].apply(lambda s: '\\begin{tabular}{l}' + s + '\\end{tabular}')
topics["top_pacs"] = topics["top_pacs"].apply(lambda s: '\\shortstack[l]{' + s + '}')

topics.sort_values(["relevant", "label"], ascending=[False, True], inplace=True)

pd.set_option('display.max_colwidth', None)
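
# Export the table as a LaTeX longtable; escape=False preserves the LaTeX markup
# built above, and the replace() call appends an \hline after each row.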
latex = topics.to_latex(
    columns=["label", "top_words", "top_pacs"],
    header=["Research area", "Top words", "Most correlated PACS categories"],
    index=False,
    multirow=True,
    multicolumn=True,
    longtable=True,
    column_format='p{0.15\\textwidth}|b{0.425\\textwidth}|b{0.425\\textwidth}',
    escape=False,
    caption="Research areas, their top-words, and their correlation with a standard classification (PACS).",
    label="table:research_areas"
)
latex = latex.replace(')} \\\\\n', ')}\\\\ \\hline\n')

with open(opj(args.input, "topics.tex"), "w+") as fp:
    fp.write(latex)
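
# For the heatmap, keep only PACS categories assigned to at least 100 articles
# and drop the junk topics.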
keep_pacs = pacs.sum(axis=0) >= 100
R = R[:, keep_pacs]
R = R[~is_junk_topic]

labels = [cat_labels[i][0] for i in range(len(keep_pacs)) if keep_pacs[i]]
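
# Boundaries in the ordered PACS codes: breaks_1 marks changes at the "NN.NN"
# sub-group level, breaks_2 at the two-digit top-level group.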
breaks_1 = np.array([i == 0 or labels[i-1][:5] != labels[i][:5] for i in range(len(labels))])
breaks_2 = np.array([i == 0 or labels[i-1][:2] != labels[i][:2] for i in range(len(labels))])

order = np.load(opj(args.input, "topics_order.npy"))
print(order)
R = R[order, :]
import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",  # assumed intent: amssymb is a package, not a main font
])
from matplotlib.gridspec import GridSpec

plt.clf()
fig = plt.figure(figsize=(6.4*2, 6.4*2.5))
gs = GridSpec(9, 8, hspace=0, wspace=0)
ax_heatmap = fig.add_subplot(gs[0:8, 2:8])
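
# Correlation heatmap with PACS categories on the y-axis and the re-ordered
# topics on the x-axis.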
sns.heatmap(
    R.T,
    cmap="RdBu",
    vmin=-0.5, vmax=+0.5,
    square=False,
    ax=ax_heatmap,
    cbar_kws={"shrink": 0.25}
)
ax_heatmap.invert_yaxis()
ax_heatmap.yaxis.set_visible(False)

topics = pd.read_csv(opj(args.input, "topics.csv"))
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
ordered_topics = [topics[i] for i in order]
ax_heatmap.set_xticklabels(ordered_topics, rotation=90)

xmin, xmax = ax_heatmap.get_xlim()
ymin, ymax = ax_heatmap.get_ylim()
ax_breaks = fig.add_subplot(gs[0:8, 0:2])
ax_breaks.set_xlim(0, 8)
ax_breaks.set_ylim(ymin, ymax)
ax_breaks.axis("off")
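
# Recover the names of the top-level PACS groups from the PACS HTML listing,
# keyed by their two-digit prefix.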
import re

with open("../semantics/inspire-harvest/database/pacs.html", "r") as f:
    html = f.read()

# Regex pattern to extract the code and label
pattern = r'<span class="pacs_num">(\d{2}\.\d{2}\.\d{2})</span>\s*(.*?)</h2>'
matches = re.findall(pattern, html)
matches = {code[:2]: label for code, label in matches}
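
# Draw separators at the PACS boundaries and, in the left panel, label each
# top-level group, shortening the label until it fits the available band.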
prev_break = 0
for i in range(len(breaks_1)):
    if breaks_1[i]:
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125/2)

    if breaks_2[i] or i == len(breaks_1)-1:
        ax_breaks.hlines(y=i-0.5, xmin=2, xmax=10, color="lightgray", lw=0.125)
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125)

        print(prev_break, i, labels[i], prev_break+(i-prev_break)/2.0)

        if prev_break != i:
            text = textwrap.shorten(matches[labels[i-1][:2]], width=35*3)
            lines = textwrap.wrap(text, width=35)
            too_small = len(lines) >= 0.35*(i-prev_break-2)

            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35*2)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.35*(i-prev_break-2)

            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.5*(i-prev_break-2)

            if (not too_small) or i == len(breaks_1)-1:
                ax_breaks.text(
                    7.5,
                    prev_break+(i-prev_break)/2.0,
                    "\n".join(lines),
                    ha="right",
                    va="center",
                    fontsize=8
                )

        prev_break = i

plt.savefig(opj(args.input, "pacs_clustermap.eps"), bbox_inches="tight")