trading_zone.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib
  4. from matplotlib import pyplot as plt
  5. matplotlib.use("pgf")
  6. matplotlib.rcParams.update(
  7. {
  8. "pgf.texsystem": "xelatex",
  9. "font.family": "serif",
  10. "font.serif": "Times New Roman",
  11. "text.usetex": True,
  12. "pgf.rcfonts": False,
  13. }
  14. )
  15. plt.rcParams["text.latex.preamble"].join([
  16. r"\usepackage{amsmath}",
  17. r"\setmainfont{amssymb}",
  18. ])
  19. import argparse
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument("--cited", choices=["0","1","2"], help="Cited category (0: theory, 1: phenomenology, 2: experiment)", required=True)
  22. parser.add_argument("--cites", choices=["0","1","2"], help="Citing category (0: theory, 1: phenomenology, 2: experiment)", required=True)
  23. parser.add_argument("--include-crosslists", default=False, action="store_true")
  24. args = parser.parse_args()
  25. crosslists = "_crosslists" if args.include_crosslists else ""
  26. boundary = f"{args.cited}_{args.cites}"
  27. ngrams = pd.read_csv(f"output/trading_zone{crosslists}_{boundary}/selected_ngrams.csv")
  28. ngrams["keyword"] = ngrams.index+1
  29. supersymmetry_keywords = ngrams[ngrams["ngram"].str.contains("super")]["keyword"].tolist()
  30. n_ngrams = len(ngrams)
  31. citations = pd.read_csv(f"output/trading_zone{crosslists}_{boundary}/citations.csv")
  32. p_t = citations.groupby("year_cites")["trades"].sum().values
  33. bow = np.load(f"output/trading_zone{crosslists}_{boundary}/full_bow.npy")
  34. p_w_t = np.zeros((len(p_t), bow.shape[1]))
  35. for citation in citations.to_dict(orient="records"):
  36. i = citation["id"]-1
  37. p_w_t[citation["year_cites"],:] += bow[i,:]*citation["trades"]
  38. p_w_t = (p_w_t.T/p_t).T
  39. ngrams["p_w_t_max"] = p_w_t.max(axis=0)
  40. ngrams["drop"] = False
  41. n_items = bow.shape[0]
  42. n_words = bow.shape[1]
  43. num = np.outer(bow[:2000].sum(axis=0),bow[:2000].sum(axis=0))/(2000**2)
  44. den = np.tensordot(bow[:2000,:], bow[:2000,:], axes=([0],[0]))/2000
  45. npmi = np.log(num)/np.log(den)-1
  46. x, y = np.where(npmi-np.identity(n_words)>0.9)
  47. for k,_ in enumerate(x):
  48. i = x[k]
  49. j = y[k]
  50. a = ngrams.at[i,"p_w_t_max"]
  51. b = ngrams.at[j,"p_w_t_max"]
  52. if a < b:
  53. ngrams.at[i,"drop"] = True
  54. else:
  55. ngrams.at[j,"drop"] = True
  56. ngrams.sort_values("p_w_t_max", inplace=True)
  57. ngrams = ngrams[ngrams["drop"]==False]
  58. years = 2001+np.arange(len(p_t))
  59. colors = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']
  60. colors += colors
  61. fig, axes = plt.subplots(1,2,sharey=True)
  62. ax = axes[0]
  63. n = 0
  64. for i, ngram in ngrams.tail(5).to_dict(orient="index").items():
  65. p = p_w_t[:,i]
  66. ax.scatter(years, p, color=colors[n], s=10, label=ngram["ngram"])
  67. ax.plot(years, p, color=colors[n], lw=0.5)
  68. ax.set_ylabel(f"$P(b_k=1|$trade$)$")
  69. n += 1
  70. ax = axes[1]
  71. n = 0
  72. for i, ngram in ngrams[ngrams["ngram"].str.contains("super")].tail(5).to_dict(orient="index").items():
  73. if n >= len(colors):
  74. break
  75. p = p_w_t[:,i]
  76. ax.scatter(years, p, color=colors[n], s=10, label=ngram["ngram"])
  77. ax.plot(years, p, color=colors[n], lw=0.5)
  78. ax.set_ylabel(f"$P(b_k=1|$trade$)$")
  79. n += 1
  80. for i in range(2):
  81. axes[i].set_xlim(2001,2019)
  82. axes[i].set_ylim(0.003,0.3)
  83. axes[i].set_yscale("log")
  84. axes[i].legend(loc=("best" if i < 2 else "lower right"), prop={'size': 6})
  85. plt.savefig(f"plots/trading_zone{crosslists}_{boundary}.pdf", bbox_inches="tight")
  86. plt.savefig(f"plots/trading_zone{crosslists}_{boundary}.pgf", bbox_inches="tight")
  87. plt.savefig(f"plots/trading_zone{crosslists}_{boundary}.eps", bbox_inches="tight")