trading_zone.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib
  4. from matplotlib import pyplot as plt
  5. matplotlib.use("pgf")
  6. matplotlib.rcParams.update(
  7. {
  8. "pgf.texsystem": "xelatex",
  9. "font.family": "serif",
  10. "font.serif": "Times New Roman",
  11. "text.usetex": True,
  12. "pgf.rcfonts": False,
  13. }
  14. )
  15. plt.rcParams["text.latex.preamble"].join([
  16. r"\usepackage{amsmath}",
  17. r"\setmainfont{amssymb}",
  18. ])
  19. from scipy.sparse import csr_matrix
  20. import argparse
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument("--cited", choices=["0","1","2"], help="Cited category (0: theory, 1: phenomenology, 2: experiment)", required=True)
  23. parser.add_argument("--cites", choices=["0","1","2"], help="Citing category (0: theory, 1: phenomenology, 2: experiment)", required=True)
  24. parser.add_argument("--include-crosslists", default=False, action="store_true")
  25. args = parser.parse_args()
  26. crosslists = "_crosslists" if args.include_crosslists else ""
  27. boundary = f"{args.cited}_{args.cites}"
  28. ngrams = pd.read_csv(f"output/trading_zone{crosslists}_{boundary}/selected_ngrams.csv")
  29. ngrams["keyword"] = ngrams.index+1
  30. supersymmetry_keywords = ngrams[ngrams["ngram"].str.contains("super")]["keyword"].tolist()
  31. n_ngrams = len(ngrams)
  32. citations = pd.read_csv(f"output/trading_zone{crosslists}_{boundary}/citations.csv")
  33. p_t = citations.groupby("year_cites")["trades"].sum().values
  34. bow = np.load(f"output/trading_zone{crosslists}_{boundary}/full_bow.npy")
  35. p_w_t = np.zeros((len(p_t), bow.shape[1]))
  36. for citation in citations.to_dict(orient="records"):
  37. i = citation["id"]-1
  38. p_w_t[citation["year_cites"],:] += bow[i,:]*citation["trades"]
  39. p_w_t = (p_w_t.T/p_t).T
  40. ngrams["p_w_t_max"] = p_w_t.max(axis=0)
  41. ngrams["drop"] = False
  42. n_items = bow.shape[0]
  43. n_words = bow.shape[1]
  44. num = np.outer(bow.sum(axis=0)/n_items,bow.sum(axis=0)/n_items)
  45. den = (csr_matrix(bow.T).dot(csr_matrix(bow))/n_items).todense()
  46. npmi = np.log(num)/np.log(den)-1
  47. x, y = np.where(npmi-np.identity(n_words)>0.5)
  48. for k,_ in enumerate(x):
  49. i = x[k]
  50. j = y[k]
  51. a = ngrams.at[i,"p_w_t_max"]
  52. b = ngrams.at[j,"p_w_t_max"]
  53. if a < b:
  54. ngrams.at[i,"drop"] = True
  55. else:
  56. ngrams.at[j,"drop"] = True
  57. ngrams.sort_values("p_w_t_max", inplace=True)
  58. ngrams = ngrams[ngrams["drop"]==False]
  59. years = 2001+np.arange(len(p_t))
  60. colors = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']
  61. colors += colors
  62. fig, axes = plt.subplots(1,2,sharey=True)
  63. ax = axes[0]
  64. n = 0
  65. for i, ngram in ngrams.tail(5).to_dict(orient="index").items():
  66. p = p_w_t[:,i]
  67. ax.scatter(years, p, color=colors[n], s=10, label=ngram["ngram"])
  68. ax.plot(years, p, color=colors[n], lw=0.5)
  69. ax.set_ylabel(f"$P(b_k=1|$trade$)$")
  70. n += 1
  71. ax = axes[1]
  72. n = 0
  73. for i, ngram in ngrams[ngrams["ngram"].str.contains("super")].tail(5).to_dict(orient="index").items():
  74. if n >= len(colors):
  75. break
  76. p = p_w_t[:,i]
  77. ax.scatter(years, p, color=colors[n], s=10, label=ngram["ngram"])
  78. ax.plot(years, p, color=colors[n], lw=0.5)
  79. ax.set_ylabel(f"$P(b_k=1|$trade$)$")
  80. n += 1
  81. for i in range(2):
  82. axes[i].set_xlim(2001,2019)
  83. # axes[i].set_ylim(0.003)
  84. axes[i].set_yscale("log")
  85. axes[i].legend(loc=("best" if i < 2 else "lower right"), prop={'size': 6})
  86. plt.savefig(f"plots/trading_zone{crosslists}_{boundary}.pdf", bbox_inches="tight")
  87. plt.savefig(f"plots/trading_zone{crosslists}_{boundary}.pgf", bbox_inches="tight")
  88. plt.savefig(f"plots/trading_zone{crosslists}_{boundary}.eps", bbox_inches="tight")