# fit_trends.py
  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib
  4. from matplotlib import pyplot as plt
  5. import seaborn as sns
  6. matplotlib.use("pgf")
  7. matplotlib.rcParams.update(
  8. {
  9. "pgf.texsystem": "xelatex",
  10. "font.family": "serif",
  11. "font.serif": "Times New Roman",
  12. "text.usetex": True,
  13. "pgf.rcfonts": False,
  14. }
  15. )
  16. from sklearn.preprocessing import MultiLabelBinarizer
  17. from sklearn.linear_model import LinearRegression
  18. from scipy import stats
  19. import argparse
  20. def pearsonr_ci(x,y,alpha=0.05):
  21. ''' calculate Pearson correlation along with the confidence interval using scipy and numpy
  22. Source: https://zhiyzuo.github.io/Pearson-Correlation-CI-in-Python/
  23. Parameters
  24. ----------
  25. x, y : iterable object such as a list or np.array
  26. Input for correlation calculation
  27. alpha : float
  28. Significance level. 0.05 by default
  29. Returns
  30. -------
  31. r : float
  32. Pearson's correlation coefficient
  33. pval : float
  34. The corresponding p value
  35. lo, hi : float
  36. The lower and upper bound of confidence intervals
  37. '''
  38. r, p = stats.pearsonr(x,y)
  39. r_z = np.arctanh(r)
  40. se = 1/np.sqrt(x.size-3)
  41. z = stats.norm.ppf(1-alpha/2)
  42. lo_z, hi_z = r_z-z*se, r_z+z*se
  43. lo, hi = np.tanh((lo_z, hi_z))
  44. return r, p, lo, hi
  45. def is_susy(s: str):
  46. return "supersymmetr" in s or "susy" in s
  47. parser = argparse.ArgumentParser("extracting correlations")
  48. parser.add_argument("--articles", default="output/hep-ct-75-0.1-0.001-130000-20/topics_0.parquet")
  49. parser.add_argument("--since", help="since year", type=int, default=2011)
  50. parser.add_argument("--until", help="until year", type=int, default=2019)
  51. parser.add_argument("--domain", choices=["hep", "susy"], default="susy")
  52. parser.add_argument("--descriptions", default="output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv")
  53. args = parser.parse_args()
  54. years = np.arange(args.since, args.until+1)
  55. n_years = len(years)
  56. articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["article_id", "date_created", "pacs_codes", "categories", "abstract", "title"]]
  57. if args.domain == "susy":
  58. articles = articles[(articles["abstract"].str.lower().map(is_susy) == True) | (articles["title"].str.lower().map(is_susy) == True)]
  59. articles["article_id"] = articles["article_id"].astype(int)
  60. articles["year"] = articles["date_created"].str[:4].replace('', 0).astype(int)
  61. articles = articles[(articles["year"] >= years.min()) & (articles["year"] <= years.max())]
  62. topics = pd.read_parquet(args.articles)
  63. topics["article_id"] = topics["article_id"].astype(int)
  64. topics["topics"] = topics["probs"]
  65. topics.drop(columns = ["year"], inplace = True)
  66. topics = topics.merge(articles, how="inner", left_on = "article_id", right_on = "article_id")
  67. n_topics = len(topics["topics"].iloc[0])
  68. cumprobs = np.zeros((n_years, n_topics))
  69. counts = np.zeros(n_years)
  70. for year, _articles in topics.groupby("year"):
  71. for article in _articles.to_dict(orient = 'records'):
  72. for topic, prob in enumerate(article['probs']):
  73. cumprobs[year-years.min(),topic] += prob
  74. counts[year-years.min()] = len(_articles)
  75. fits = []
  76. for topic in range(n_topics):
  77. y = cumprobs[:,topic]/counts
  78. reg = LinearRegression().fit(years.reshape(-1, 1), y)
  79. r, p, lo_95, hi_95 = pearsonr_ci(years, y)
  80. r, p, lo_99, hi_99 = pearsonr_ci(years, y, alpha=0.01)
  81. fits.append({
  82. 'topic': topic,
  83. 'r2': reg.score(years.reshape(-1, 1), y)**2,
  84. 'slope': reg.coef_[0],
  85. 'lower_95': lo_95,
  86. 'high_95': hi_95,
  87. 'lower_99': lo_99,
  88. 'high_99': hi_99
  89. })
  90. fits = pd.DataFrame(fits)
  91. fits = fits.merge(pd.read_csv(args.descriptions))
  92. fits['95_significant'] = fits['lower_95']*fits['high_95'] > 0
  93. fits['99_significant'] = fits['lower_99']*fits['high_99'] > 0
  94. fits.sort_values("slope", ascending=True, inplace=True)
  95. fits.to_csv('output/fits.csv')
  96. significant_dec = fits[fits["99_significant"]==True].head(3)
  97. significant_inc = fits[fits["99_significant"]==True].tail(3)
  98. fig, axes = plt.subplots(1,2,sharey=True)
  99. colors = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3']
  100. ax = axes[0]
  101. n = 0
  102. for topic in significant_dec.to_dict(orient="records"):
  103. ax.plot(
  104. years,
  105. cumprobs[:,topic['topic']]/counts,
  106. color = colors[n]
  107. )
  108. ax.scatter(
  109. years,
  110. cumprobs[:,topic['topic']]/counts,
  111. label=topic['description'],
  112. color = colors[n]
  113. )
  114. n +=1
  115. ax.set_ylabel("Average relative contribution of each topic per year ($\\bar{\\theta_z}$)")
  116. ax.set_xlim(years.min(), years.max())
  117. ax.legend(fontsize='x-small', loc="upper right")
  118. ax = axes[1]
  119. n = 0
  120. for topic in significant_inc.to_dict(orient="records"):
  121. ax.plot(
  122. years,
  123. cumprobs[:,topic['topic']]/counts,
  124. color = colors[n]
  125. )
  126. ax.scatter(
  127. years,
  128. cumprobs[:,topic['topic']]/counts,
  129. label=topic['description'],
  130. color = colors[n]
  131. )
  132. n +=1
  133. ax.set_xlim(years.min(), years.max())
  134. ax.legend(fontsize='x-small', loc="upper right")
  135. fig.suptitle(
  136. "Coldest topics (left) and hottest topics (right) – {}, {}-{}".format(
  137. "high-energy physics" if args.domain == "hep" else "supersymmetry",
  138. args.since,
  139. args.until
  140. )
  141. )
  142. plt.savefig(f"plots/hot_cold_topics_hep_{args.since}_{args.until}_{args.domain}.pgf", bbox_inches="tight")
  143. plt.savefig(f"plots/hot_cold_topics_hep_{args.since}_{args.until}_{args.domain}.pdf", bbox_inches="tight")
  144. plt.savefig(f"plots/hot_cold_topics_hep_{args.since}_{args.until}_{args.domain}.eps", bbox_inches="tight")