topic_vs_experiments.py

#!/usr/bin/env python
# coding: utf-8

import pandas as pd
import numpy as np
import matplotlib

# Select the pgf backend before pyplot is imported.
matplotlib.use("pgf")
from matplotlib import pyplot as plt

matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
# The preamble must be assigned back (str.join alone discards its result), and
# amssymb is a package rather than a font, so it is loaded with \usepackage.
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
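# With the pgf backend, all text is typeset by XeLaTeX, so the tick labels
# built below can safely use LaTeX markup such as \small{...} and math mode.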
import seaborn as sns
import pickle
from os.path import join as opj
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inputs", nargs="+")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--keywords-threshold", type=int, default=200)
parser.add_argument("--articles-threshold", type=int, default=5)
parser.add_argument("--early-periods", nargs="+", type=int, default=[0, 1])  # [2,3] for ACL, [3] for HEP
parser.add_argument("--late-periods", nargs="+", type=int, default=[3])  # [2,3] for ACL, [3] for HEP
parser.add_argument("--fla", action="store_true", help="first or last author")
args = parser.parse_args()

# Suffix encoding any non-default choice of periods (empty for the defaults).
custom_range = (
    "_" + "-".join(map(str, args.early_periods)) + "_" + "-".join(map(str, args.late_periods))
    if (args.early_periods != [0, 1] or args.late_periods != [3])
    else ""
)
print(custom_range)
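# Example invocation (hypothetical paths); the script assumes three model
# runs, matching model_names below:
#   python topic_vs_experiments.py \
#       --inputs output/main_model output/k15 output/k25 \
#       --dataset inspire-harvest/database
# With, e.g., --early-periods 2 3 --late-periods 3, custom_range is "_2-3_3".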
references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
references["cites"] = references.cites.astype(int)
references["cited"] = references.cited.astype(int)

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "title", "accelerators"]]
articles["article_id"] = articles.article_id.astype(int)
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int)

# Keep articles attached to at least one experiment, one row per (article, experiment).
experimental = articles[articles["accelerators"].map(len) >= 1]
experimental = experimental.explode("accelerators")
# Collapse experiment codes by dropping their last dash-separated component.
experimental["accelerators"] = experimental["accelerators"].str.replace(
    "(.*)-(.*)-(.*)$", r"\1-\2", regex=True
)
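# For instance, "FNAL-E-830" collapses to "FNAL-E" (the greedy groups keep
# everything before the last dash), while codes with fewer than two dashes,
# such as "T2K" or "DESY-HERA", pass through unchanged.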
types = {
    "FNAL-E": "colliders",
    "CERN-LEP": "colliders",
    "DESY-HERA": "colliders",
    "SUPER-KAMIOKANDE": "astro. neutrinos",
    "CERN-NA": "colliders",
    "CESR-CLEO": "colliders",
    "CERN-WA": "colliders",
    "BNL-E": "colliders",
    "KAMIOKANDE": "astro. neutrinos",
    "SLAC-E": "colliders",
    "SLAC-PEP2": "colliders",
    "KEK-BF": "colliders",
    "SNO": "neutrinos",
    "BNL-RHIC": "colliders",
    "WMAP": "cosmic $\\mu$wave background",
    "CERN-LHC": "colliders",
    "PLANCK": "cosmic $\\mu$wave background",
    "BEPC-BES": "colliders",
    "LIGO": "gravitational waves",
    "VIRGO": "gravitational waves",
    "CERN-PS": "colliders",
    "FERMI-LAT": "other cosmic sources",
    "XENON100": "dark matter (direct)",
    "ICECUBE": "astro. neutrinos",
    "LUX": "dark matter (direct)",
    "T2K": "neutrinos",
    "BICEP2": "cosmic $\\mu$wave background",
    "CDMS": "dark matter (direct)",
    "LAMPF-1173": "neutrinos",
    "FRASCATI-DAFNE": "colliders",
    "KamLAND": "neutrinos",
    "SDSS": "other cosmic sources",
    "JLAB-E-89": "colliders",
    "CHOOZ": "neutrinos",
    "XENON1T": "dark matter (direct)",
    "SCP": "supernovae",
    "DAYA-BAY": "neutrinos",
    "HOMESTAKE-CHLORINE": "neutrinos",
    "HIGH-Z": "supernovae",
    "K2K": "neutrinos",
    "MACRO": "other cosmic sources",
    "GALLEX": "neutrinos",
    "SAGE": "neutrinos",
    "PAMELA": "other cosmic sources",
    "CERN-UA": "colliders",
    "CERN SPS": "colliders",
    "DESY-PETRA": "colliders",
    "SLAC-SLC": "colliders",
    "LEPS": "colliders",
    "DOUBLECHOOZ": "neutrinos",
    "AUGER": "other cosmic sources",
    "AMS": "other cosmic sources",
    "DAMA": "dark matter (direct)",
    "DESY-DORIS": "colliders",
    "NOVOSIBIRSK-CMD": "colliders",
    "IMB": "neutrinos",
    "RENO": "neutrinos",
    "SLAC-SP": "colliders",
}
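# Experiment codes absent from this mapping are dropped below; the mapped
# category ("type") becomes the heatmap row.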
# Display order of the experiment categories (heatmap rows).
ordered_types = [
    "colliders",
    "neutrinos",
    "astro. neutrinos",
    "dark matter (direct)",
    "cosmic $\\mu$wave background",
    "supernovae",
    "other cosmic sources",
    "gravitational waves",
]
experimental = experimental[experimental["accelerators"].isin(types.keys())]
experimental["type"] = experimental["accelerators"].map(types)


def compute_counts(model, articles):
    """Distribution over main topics of the articles citing each experiment type."""
    _articles = pd.read_csv(opj(model, "articles.csv"))
    articles = _articles.merge(articles, how="left")
    topics = pd.read_csv(opj(model, "topics.csv"))["label"].tolist()
    topic_matrix = np.load(opj(model, "topics_counts.npy"))
    # Normalize each row of topic counts to proportions, guarding against empty rows.
    topic_matrix = topic_matrix / np.where(topic_matrix.sum(axis=1) > 0, topic_matrix.sum(axis=1), 1)[:, np.newaxis]
    articles["topics"] = list(topic_matrix)
    articles["main_topic"] = topic_matrix.argmax(axis=1)
    articles["main_topic"] = articles["main_topic"].map(lambda x: topics[x])
    articles = articles[~articles["main_topic"].str.contains("Junk")]
    # Pair each article with the experimental papers it cites.
    citing_experiments = articles.merge(references, how="inner", left_on="article_id", right_on="cites")
    citing_experiments = citing_experiments.merge(experimental, how="inner", left_on="cited", right_on="article_id")
    counts = citing_experiments.groupby("type")["main_topic"].value_counts(normalize=True).reset_index()
    counts = counts.pivot(index="type", columns="main_topic")
    counts.sort_index(key=lambda idx: idx.map(lambda x: ordered_types.index(x)), inplace=True)
    return counts
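# compute_counts() returns a frame indexed by experiment type, with MultiIndex
# columns ("proportion", <topic label>): within each experiment type, the share
# of citation links whose citing article has that main topic.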
model_names = ["Main model", "$K=15$", "$K=25$"]

fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=[4.8 * 1.5, 3.2 * 1.5], layout="constrained")

for i, model in enumerate(args.inputs):
    counts = compute_counts(model, articles.copy())
    print(counts.index)
    topics = counts["proportion"].columns
    values = np.stack(counts["proportion"].values)
    # Keep topics reaching at least 5% for some experiment type.
    relevant_topics = np.arange(len(topics))[values.max(axis=0) > 0.05]
    topics = topics[relevant_topics]
    # sns.heatmap(counts, ax=axes[i], cmap="Reds")
    im = axes[i].matshow(counts["proportion"][topics], cmap="Reds", vmin=0, vmax=1, aspect=len(topics) / 15)
    axes[i].set_xlabel(model_names[i])
    # Escape "&" for LaTeX and shrink the tick labels.
    topics = [topic.replace(" and ", " \\& ").lower().capitalize() for topic in topics]
    topics = [f"\\small{{{topic}}}" for topic in topics]
    axes[i].set_xticks(np.arange(len(topics)), topics, rotation="vertical")
    axes[i].set_xticks(np.arange(len(topics) + 1) - 0.5, minor=True)
    # axes[i].axvline(len(topics)-0.5, color="black", lw=2)
    axes[i].set_xlim(-0.5, len(topics) - 0.5)

# One shared colorbar below the panels.
fig.colorbar(im, location="bottom")
axes[0].set_yticks(np.arange(len(ordered_types)), list(map(lambda x: f"\\small{{{x.capitalize()}}}", ordered_types)), va="center")
axes[i].set_yticks(np.arange(len(ordered_types) + 1) - 0.5, minor=True)
# p0 = axes[0].get_position()
# p2 = axes[2].get_position()
# ax_cbar = fig.add_axes([p0.x0, 0.08, p2.x1, 0.05])
# plt.colorbar(im, cax=ax_cbar, ticks=np.linspace(0, 1, 3, True), orientation='horizontal')
plt.subplots_adjust(wspace=0, hspace=0)

fig.savefig(opj(args.inputs[0], "topic_experiments.eps"), bbox_inches="tight")
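# The resulting figure has one heatmap per topic model: rows are experiment
# types (shared y-axis), columns are the topics exceeding 5% for at least one
# experiment type under that model.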