# topic_popularity.py
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. import pandas as pd
  4. import numpy as np
  5. from matplotlib import pyplot as plt
  6. import matplotlib
  7. from matplotlib import pyplot as plt
  8. matplotlib.use("pgf")
  9. matplotlib.rcParams.update(
  10. {
  11. "pgf.texsystem": "xelatex",
  12. "font.family": "serif",
  13. "font.serif": "Times New Roman",
  14. "text.usetex": True,
  15. "pgf.rcfonts": False,
  16. }
  17. )
  18. plt.rcParams["text.latex.preamble"].join([
  19. r"\usepackage{amsmath}",
  20. r"\setmainfont{amssymb}",
  21. ])
  22. import seaborn as sns
  23. import pickle
  24. from os.path import join as opj
  25. import argparse
  26. parser = argparse.ArgumentParser()
  27. parser.add_argument("--input")
  28. parser.add_argument("--dataset", default="inspire-harvest/database")
  29. parser.add_argument("--keywords-threshold", type=int, default=200)
  30. parser.add_argument("--articles-threshold", type=int, default=5)
  31. parser.add_argument("--early-periods", nargs="+", type=int, default=[0,1]) # [2,3] for ACL, [3] for HEP
  32. parser.add_argument("--late-periods", nargs="+", type=int, default=[3]) # [2,3] for ACL, [3] for HEP
  33. parser.add_argument("--fla", action="store_true", help="first or last author")
  34. args = parser.parse_args()
  35. custom_range = "_" + "-".join(map(str, args.early_periods)) + "_" + "-".join(map(str, args.late_periods)) if (args.early_periods!=[0,1] or args.late_periods!=[3]) else ""
  36. print(custom_range)
  37. references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
  38. references["cites"] = references.cites.astype(int)
  39. references["cited"] = references.cited.astype(int)
  40. topics = pd.read_csv(opj(args.input, "topics.csv"))["label"].tolist()
  41. topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
  42. articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "title", "accelerators"]]
  43. articles["article_id"] = articles.article_id.astype(int)
  44. experimental = articles[articles["accelerators"].map(len)>=1]
  45. experimental = experimental.explode("accelerators")
  46. experimental["accelerators"] = experimental["accelerators"].str.replace(
  47. "(.*)-(.*)-(.*)$", r"\1-\2", regex=True
  48. )
# Map each accelerator/experiment code to a coarse experiment category.
# Keys are matched against the "accelerators" column after the
# "A-B-C" -> "A-B" collapse performed above; codes absent from this dict
# are dropped from the analysis.
types = {
    "FNAL-E": "colliders",
    "CERN-LEP": "colliders",
    "DESY-HERA": "colliders",
    "SUPER-KAMIOKANDE": "astro. neutrinos",
    "CERN-NA": "colliders",
    "CESR-CLEO": "colliders",
    "CERN-WA": "colliders",
    "BNL-E": "colliders",
    "KAMIOKANDE": "astro. neutrinos",
    "SLAC-E": "colliders",
    "SLAC-PEP2": "colliders",
    "KEK-BF": "colliders",
    "SNO": "neutrinos",
    "BNL-RHIC": "colliders",
    "WMAP": "cosmic $\\mu$wave background",
    "CERN-LHC": "colliders",
    "PLANCK": "cosmic $\\mu$wave background",
    "BEPC-BES": "colliders",
    "LIGO": "gravitational waves",
    "VIRGO": "gravitational waves",
    "CERN-PS": "colliders",
    "FERMI-LAT": "other cosmic sources",
    "XENON100": "dark matter (direct)",
    "ICECUBE": "astro. neutrinos",
    "LUX": "dark matter (direct)",
    "T2K": "neutrinos",
    "BICEP2": "cosmic $\\mu$wave background",
    "CDMS": "dark matter (direct)",
    "LAMPF-1173": "neutrinos",
    "FRASCATI-DAFNE": "colliders",
    "KamLAND": "neutrinos",
    "SDSS": "other cosmic sources",
    "JLAB-E-89": "colliders",
    "CHOOZ": "neutrinos",
    "XENON1T": "dark matter (direct)",
    "SCP": "supernovae",
    "DAYA-BAY": "neutrinos",
    "HOMESTAKE-CHLORINE": "neutrinos",
    "HIGH-Z": "supernovae",
    "K2K": "neutrinos",
    "MACRO": "other cosmic sources",
    "GALLEX": "neutrinos",
    "SAGE": "neutrinos",
    "PAMELA": "other cosmic sources",
    "CERN-UA": "colliders",
    "CERN SPS": "colliders",
    "DESY-PETRA": "colliders",
    "SLAC-SLC": "colliders",
    "LEPS": "colliders",
    "DOUBLECHOOZ": "neutrinos",
    "AUGER": "other cosmic sources",
    "AMS": "other cosmic sources",
    "DAMA": "dark matter (direct)",
    "DESY-DORIS": "colliders",
    "NOVOSIBIRSK-CMD": "colliders",
    "IMB": "neutrinos",
    "RENO": "neutrinos",
    "SLAC-SP": "colliders"
}
  109. experimental = experimental[experimental["accelerators"].isin(types.keys())]
  110. experimental["type"] = experimental["accelerators"].map(types)
  111. articles = articles[articles["date_created"].str.len() >= 4]
  112. articles["year"] = articles["date_created"].str[:4].astype(int)
  113. _articles = pd.read_csv(opj(args.input, "articles.csv"))
  114. articles = _articles.merge(articles, how="left")
  115. print(topic_matrix.shape)
  116. topic_matrix = topic_matrix/np.where(topic_matrix.sum(axis=1)>0, topic_matrix.sum(axis=1), 1)[:,np.newaxis]
  117. articles["topics"] = list(topic_matrix)
  118. articles["main_topic"] = topic_matrix.argmax(axis=1)
  119. articles["main_topic"] = articles["main_topic"].map(lambda x: topics[x])
  120. citing_experiments = articles.merge(references, how="inner", left_on="article_id", right_on="cites")
  121. citing_experiments = citing_experiments.merge(experimental, how="inner", left_on="cited", right_on="article_id")
  122. counts = citing_experiments.groupby("type")["main_topic"].value_counts(normalize=True).reset_index()
  123. counts = counts.pivot(index="type", columns="main_topic")
  124. print(counts)
  125. fig, ax = plt.subplots()
  126. sns.heatmap(counts, ax=ax)
  127. fig.savefig(opj(args.input, "topic_experiments.eps"), bbox_inches="tight")
  128. articles = articles[articles["year"]<2020]
  129. popularity = articles.groupby("year").agg(
  130. topics=("topics", lambda x: list(np.mean(x, axis=0)))
  131. ).sort_index()
  132. fig, ax = plt.subplots()
  133. colors = [
  134. '#377eb8', '#ff7f00', '#4daf4a',
  135. '#f781bf', '#a65628', '#984ea3',
  136. '#999999', '#e41a1c', '#dede00'
  137. ]
  138. top = np.stack(popularity["topics"]).max(axis=0)
  139. top = np.argpartition(top, -12)[-12:]
  140. i = 0
  141. for t in top:
  142. print(t)
  143. ls = "dashed" if i//len(colors)>=1 else None
  144. ax.plot(popularity.index, popularity["topics"].map(lambda x: x[t]), label=topics[t], color=colors[i%9], ls=ls)
  145. ax.scatter(popularity.index, popularity["topics"].map(lambda x: x[t]), color=colors[i%9])
  146. i += 1
  147. ax.set_xticks(np.arange(2002, 2020, 2), np.arange(2002, 2020, 2))
  148. fig.legend(ncols=3, bbox_to_anchor=(0,1.1), loc="upper left", bbox_transform=fig.transFigure)
  149. fig.savefig(opj(args.input, "topic_popularity.pdf"), bbox_inches="tight")