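"""Correlate ETM research topics with PACS categories.

Writes a LaTeX table of topics with their most correlated PACS categories
(topics.tex) and a topic-by-PACS correlation heatmap (pacs_clustermap.eps).

Assumed invocation (example paths are illustrative, not from the original):

    python pacs.py --input output/etm_run --dataset inspire-harvest/database
"""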
import pandas as pd
import pickle
import numpy as np
from scipy.stats import entropy
from os.path import join as opj, exists
import argparse
import textwrap
from sklearn.preprocessing import MultiLabelBinarizer
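
# Command-line arguments: --input is the topic-model output directory,
# --dataset the INSPIRE dump (the default assumes the inspire-harvest layout).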
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
args = parser.parse_args()
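
# Load article metadata, keep records with a usable creation date, and
# restrict to the articles actually used in the topic model (inner merge).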
ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))
articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "pacs_codes", "curated"]]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values
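
# Topics labeled "Junk" in topics.csv are treated as uninterpretable below.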
topics = pd.read_csv(opj(args.input, "topics.csv"))
is_junk_topic = np.array(topics["label"].str.contains("Junk"))

pacs_description = pd.read_csv(opj(args.dataset, "pacs_codes.csv")).set_index("code")["description"].to_dict()
codes = set(pacs_description.keys())
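
# Keep only articles that carry at least one recognized PACS code.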
has_pacs_code = articles["pacs_codes"].map(lambda l: set(l) & codes).map(len) > 0
articles = articles[has_pacs_code]

binarizer = MultiLabelBinarizer()
pacs = binarizer.fit_transform(articles["pacs_codes"])
n_categories = pacs.shape[1]
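
# Recover the PACS code behind each binary column by inverse-transforming
# one-hot rows: column j of `pacs` corresponds to cat_labels[j][0].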
print((pacs - pacs.mean(axis=0)).shape)
cat_classes = np.array([np.identity(n_categories)[cl] for cl in range(n_categories)])
cat_labels = binarizer.inverse_transform(cat_classes)
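
# Load the trained ETM and cache the document-topic distribution theta on
# disk (presumably because get_document_topic_dist() is costly to recompute).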
with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)

if not exists(opj(args.input, "theta.npy")):
    theta = etm_instance.get_document_topic_dist()
    np.save(opj(args.input, "theta.npy"), theta)
else:
    theta = np.load(opj(args.input, "theta.npy"))
print(theta[:, ~is_junk_topic].mean(axis=0).sum())
topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
print("Theta average entropy:", np.exp(entropy(theta[:, ~is_junk_topic], axis=1)).mean())
print("Topic matrix average entropy:", np.nanmean(np.exp(np.nan_to_num(entropy(topic_matrix, axis=1)))))
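
# Pearson correlation between topic weights and PACS indicators:
# R[t, c] = cov(theta[:, t], pacs[:, c]) / (std(theta[:, t]) * std(pacs[:, c]))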
theta = theta[has_pacs_code, :]
n_articles = theta.shape[0]
R = (theta - theta.mean(axis=0)).T @ (pacs - pacs.mean(axis=0)) / n_articles
R /= np.outer(theta.std(axis=0), pacs.std(axis=0))
print(R.shape)
print(R)
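
# Build the LaTeX table: junk topics are relabeled "Uninterpretable" and
# sorted to the end.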
topics["top_pacs"] = ""
topics["relevant"] = 1
topics.loc[is_junk_topic, "label"] = "Uninterpretable"
topics.loc[is_junk_topic, "relevant"] = 0
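
# For each topic, list its five most correlated PACS categories.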
for i in range(R.shape[0]):
    ind = np.argsort(R[i, :])
    top = np.array(ind[-5:])[::-1]
    topics.loc[i, "top_pacs"] = "\\\\ ".join([
        f"{textwrap.shorten(pacs_description[cat_labels[j][0]], width=40)} ({R[i,j]:.2f})"
        for j in top
        if cat_labels[j][0] in pacs_description
    ])
topics["top_words"] = topics["top_words"].str.replace(",", ", ")
topics["top_words"] = topics["top_words"].str.replace("_", "\\_")
# topics["top_words"] = topics["top_words"].str.replace(r"(^(.*)-ray)|((.*)-ray)", "$\\gamma$-ray", regex=True)
# topics["top_words"] = topics["top_words"].apply(lambda s: "\\\\ ".join(textwrap.wrap(s, width=45, break_long_words=False)))
# topics["top_words"] = topics["top_words"].apply(lambda s: '\\begin{tabular}{l}' + s + '\\end{tabular}')
topics["top_pacs"] = topics["top_pacs"].apply(lambda s: '\\shortstack[l]{' + s + '}')
topics.sort_values(["relevant", "label"], ascending=[False, True], inplace=True)

pd.set_option('display.max_colwidth', None)
latex = topics.to_latex(
    columns=["label", "top_words", "top_pacs"],
    header=["Research area", "Top words", "Most correlated PACS categories"],
    index=False,
    multirow=True,
    multicolumn=True,
    longtable=True,
    column_format='p{0.15\\textwidth}|b{0.425\\textwidth}|b{0.425\\textwidth}',
    escape=False,
    caption="Research areas, their top-words, and their correlation with a standard classification (PACS).",
    label="table:research_areas",
)
latex = latex.replace(')} \\\\\n', ')}\\\\ \\hline\n')
with open(opj(args.input, "topics.tex"), "w+") as fp:
    fp.write(latex)
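
# Heatmap of topic-PACS correlations: keep PACS categories tagged on at least
# 100 articles, drop junk topics, and order topics as in topics_order.npy.
# breaks_1/breaks_2 mark PACS subsection (first 5 chars) and section (first 2
# digits) boundaries along the category axis.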
keep_pacs = pacs.sum(axis=0) >= 100
R = R[:, keep_pacs]
R = R[~is_junk_topic]
labels = [cat_labels[i][0] for i in range(len(keep_pacs)) if keep_pacs[i]]
breaks_1 = np.array([i == 0 or labels[i-1][:5] != labels[i][:5] for i in range(len(labels))])
breaks_2 = np.array([i == 0 or labels[i-1][:2] != labels[i][:2] for i in range(len(labels))])

order = np.load(opj(args.input, "topics_order.npy"))
print(order)
R = R[order, :]
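
# Render with the pgf backend (xelatex) so figure fonts match the LaTeX document.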
import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
from matplotlib.gridspec import GridSpec

plt.clf()
fig = plt.figure(figsize=(6.4*2, 6.4*2.5))
gs = GridSpec(9, 8, hspace=0, wspace=0)
ax_heatmap = fig.add_subplot(gs[0:8, 2:8])
sns.heatmap(
    R.T,
    cmap="RdBu",
    vmin=-0.5, vmax=+0.5,
    square=False,
    ax=ax_heatmap,
    cbar_kws={"shrink": 0.25},
)
ax_heatmap.invert_yaxis()
ax_heatmap.yaxis.set_visible(False)
topics = pd.read_csv(opj(args.input, "topics.csv"))
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
ordered_topics = [topics[i] for i in order]
ax_heatmap.set_xticklabels(ordered_topics, rotation=90)

xmin, xmax = ax_heatmap.get_xlim()
ymin, ymax = ax_heatmap.get_ylim()

ax_breaks = fig.add_subplot(gs[0:8, 0:2])
ax_breaks.set_xlim(0, 8)
ax_breaks.set_ylim(ymin, ymax)
ax_breaks.axis("off")
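
# Recover human-readable names for top-level PACS sections (first two digits
# of the code) from the PACS HTML index shipped with the dataset.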
import re

html = open(opj(args.dataset, "pacs.html"), "r").read()
# Regex pattern to extract the code and label
pattern = r'<span class="pacs_num">(\d{2}\.\d{2}\.\d{2})</span>\s*(.*?)</h2>'
matches = re.findall(pattern, html)
matches = {code[:2]: label for code, label in matches}
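
# Draw separators: thin lines at every subsection boundary (breaks_1) and
# heavier lines at section boundaries (breaks_2); label each section block,
# shortening the text until it fits the available vertical space.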
prev_break = 0
for i in range(len(breaks_1)):
    if breaks_1[i]:
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125/2)
    if breaks_2[i] or i == len(breaks_1)-1:
        ax_breaks.hlines(y=i-0.5, xmin=2, xmax=10, color="lightgray", lw=0.125)
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125)
        print(prev_break, i, labels[i], prev_break+(i-prev_break)/2.0)
        if prev_break != i:
            text = textwrap.shorten(matches[labels[i-1][:2]], width=35*3)
            lines = textwrap.wrap(text, width=35)
            too_small = len(lines) >= 0.35*(i-prev_break-2)
            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35*2)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.35*(i-prev_break-2)
            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.5*(i-prev_break-2)
            if (not too_small) or i == len(breaks_1)-1:
                ax_breaks.text(
                    7.5,
                    prev_break+(i-prev_break)/2.0,
                    "\n".join(lines),
                    ha="right",
                    va="center",
                    fontsize=8,
                )
        prev_break = i

plt.savefig(opj(args.input, "pacs_clustermap.eps"), bbox_inches="tight")