# ei.py

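# Post-processing/plotting for the "ei" (expertise) model: loads posterior
# samples (ei_samples*.npz) together with topic labels and aggregates produced
# upstream, and writes a series of .eps figures (delta vs. nu, transition
# counts, mu, delta, gamma, R, and a topic-by-topic citation matrix) back into
# the input directory.
#
# Usage: python ei.py --input <model output dir> [--dataset <INSPIRE dump>] [--suffix <name>]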
import numpy as np
import pandas as pd
from scipy.special import softmax

import matplotlib

matplotlib.use("pgf")
from matplotlib import pyplot as plt

matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join(
    [
        r"\usepackage{amsmath}",
        r"\usepackage{amssymb}",
    ]
)
import seaborn as sns
import argparse
from os.path import join as opj
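# Command-line interface: --input names the directory with the model outputs
# (it also receives the figures), --dataset the INSPIRE data dump used for
# the citation matrix, and --suffix selects an alternative samples file.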
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--suffix", default=None)
args = parser.parse_args()

suffix = f"_{args.suffix}" if args.suffix is not None else ""
samples = np.load(opj(args.input, f"ei_samples{suffix}.npz"))
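# Topic labels and aggregates. Each row of aggregate.csv (one row per author,
# presumably) carries start_k/end_k topic counts and expertise_k scores;
# "Junk" topics are dropped from every matrix.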
topics = pd.read_csv(opj(args.input, "topics.csv"))
n_topics = len(topics)
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()

df = pd.read_csv(opj(args.input, "aggregate.csv"))
NR = df[[f"start_{k+1}" for k in range(n_topics)]].values.astype(int)
NC = df[[f"end_{k+1}" for k in range(n_topics)]].values.astype(int)
expertise = df[[f"expertise_{k+1}" for k in range(n_topics)]].fillna(0).values

# junk = np.sum(NR + NC, axis=0) == 0
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]
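# Seriation: order topics by hierarchical clustering (Ward linkage) so that
# similar topics end up adjacent on the heatmap axes. `seriation` reads the
# left-to-right leaf order off the linkage tree recursively.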
from scipy.spatial.distance import squareform
from fastcluster import linkage


def seriation(Z, N, cur_index):
    if cur_index < N:
        return [cur_index]
    else:
        left = int(Z[cur_index - N, 0])
        right = int(Z[cur_index - N, 1])
        return seriation(Z, N, left) + seriation(Z, N, right)


def compute_serial_matrix(dist_mat, method="ward"):
    N = len(dist_mat)
    flat_dist_mat = squareform(dist_mat)
    res_linkage = linkage(flat_dist_mat, method=method, preserve_input=True)
    res_order = seriation(res_linkage, N, N + N - 2)
    seriated_dist = np.zeros((N, N))
    a, b = np.triu_indices(N, k=1)
    seriated_dist[a, b] = dist_mat[[res_order[i] for i in a], [res_order[j] for j in b]]
    seriated_dist[b, a] = seriated_dist[a, b]
    return seriated_dist, res_order, res_linkage
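# Topic-topic distance: 1 minus the Jaccard index of the sets of rows with
# above-average expertise in each of the two topics.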
dist = 1 - np.array([
    [
        ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
        / ((expertise[:, i] > expertise[:, i].mean()) | (expertise[:, j] > expertise[:, j].mean())).mean()
        for j in range(len(topics))
    ]
    for i in range(len(topics))
])
dist = np.nan_to_num(dist)

m, order, dendo = compute_serial_matrix(dist)
order = np.array(order)[::-1]
print(order)
ordered_topics = [topics[i] for i in order]
np.save(opj(args.input, "topics_order.npy"), order)
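# Row-normalize counts and expertise into per-row shares (x and y, the start-
# and end-period topic shares, are not used further below). R[i, j] is the
# share of rows with above-average expertise in topic j among those with
# above-average expertise in topic i, i.e. an asymmetric overlap measure.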
x = NR / NR.sum(axis=1)[:, np.newaxis]
y = NC / NC.sum(axis=1)[:, np.newaxis]
expertise = expertise / expertise.sum(axis=1)[:, np.newaxis]

R = np.array([
    [
        ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
        / (expertise[:, i] > expertise[:, i].mean()).mean()
        for j in range(len(topics))
    ]
    for i in range(len(topics))
])
print(expertise)
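# delta vs. nu: plot the posterior regression line delta_0 + delta_nu * nu
# with a 95% credible band, then overlay the individual delta_{kk'} whose 95%
# credible interval excludes zero (both quantiles share the same sign).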
fig, ax = plt.subplots()
x = np.linspace(np.min(R), np.max(R), 100)
y = samples["delta_0"][:, np.newaxis] + np.einsum("s,i->si", samples["delta_nu"], x)
ax.fill_between(x, np.quantile(y, axis=0, q=0.05 / 2), np.quantile(y, axis=0, q=1 - 0.05 / 2), color="lightgray")
ax.plot(x, y.mean(axis=0), color="black")

delta = samples["delta"].mean(axis=0)
sig = (np.quantile(samples["delta"], q=0.05 / 2, axis=0) * np.quantile(samples["delta"], q=1 - 0.05 / 2, axis=0)).flatten() > 0
ax.scatter(R.flatten()[sig], delta.flatten()[sig], s=2)
ax.errorbar(
    R.flatten()[sig],
    delta.flatten()[sig],
    (
        delta.flatten()[sig] - np.quantile(samples["delta"], q=0.05 / 2, axis=0).flatten()[sig],
        np.quantile(samples["delta"], q=1 - 0.05 / 2, axis=0).flatten()[sig] - delta.flatten()[sig],
    ),
    ls="none",
    lw=0.5,
)
ax.set_xlabel("$\\nu_{kk'}$")
ax.set_ylabel("$\\delta_{kk'}$")
fig.savefig(opj(args.input, f"delta_vs_nu{suffix}.eps"), bbox_inches="tight")
plt.clf()
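# Posterior mean of the transition-count matrix, normalized and reordered by
# the seriation order, as a heatmap.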
counts = samples["counts"].mean(axis=0)
counts = counts / counts.sum(axis=0)[:, np.newaxis]
sns.heatmap(
    counts[:, order][order],
    vmin=0,
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(opj(args.input, f"ei_counts{suffix}.eps"), bbox_inches="tight")
plt.clf()
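# Transition matrix mu: per posterior sample, add the diagonal of gamma onto
# mu's diagonal, average over samples, and softmax each row into a
# probability distribution before plotting; only diagonal cells are annotated.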
mu = samples["mu"][:, :, order][:, order]
gamma = samples["gamma"][:, :, order][:, order]
for s in range(mu.shape[0]):
    mu[s, :, :] += np.diag(np.diag(gamma[s]))
mu = np.mean(mu, axis=0)
mu = softmax(mu, axis=1)
sns.heatmap(
    mu,
    vmin=0,
    vmax=np.max(mu),
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=[
        [f"{mu[i, j]:.2f}" if i == j else "" for j in range(len(topics))]
        for i in range(len(topics))
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(opj(args.input, f"ei_mu{suffix}.eps"), bbox_inches="tight")
fig, ax = plt.subplots()
beta = mu
# mu was reordered above, so R must be reordered the same way for the
# element-wise pairing to be meaningful
ax.scatter(R[:, order][order].flatten(), beta.flatten())
fig.savefig(
    opj(args.input, f"ei_mu_kl_dist{suffix}.eps"),
    bbox_inches="tight",
)
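# Generic plot for a posterior matrix parameter: posterior means with
# significance stars ("*" when the 95% credible interval excludes zero,
# "**" when the 99.7% interval does). Vector-valued parameters would fall
# back to the errorbar branch.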
def plot_matrix(param):
    fig, ax = plt.subplots()
    x = samples[param][:, :, order][:, order]
    delta = x.mean(axis=0)
    up = np.quantile(x, axis=0, q=1 - 0.05 / 2)
    low = np.quantile(x, axis=0, q=0.05 / 2)
    up_3s = np.quantile(x, axis=0, q=1 - 0.003 / 2)
    low_3s = np.quantile(x, axis=0, q=0.003 / 2)
    if len(delta.shape) == 1:
        ax.errorbar(
            np.arange(len(delta)), delta, yerr=(delta - low, up - delta), ls="none"
        )
        ax.scatter(np.arange(len(delta)), delta)
        ax.set_xticks(np.arange(len(delta)))
        ax.set_xticklabels(topics)
        ax.xaxis.set_tick_params(rotation=90)
        fig.savefig(
            opj(args.input, f"ei_{param}{suffix}.eps"),
            bbox_inches="tight",
        )
    else:
        significant_2s = up * low > 0
        significant_3s = up_3s * low_3s > 0
        sns.heatmap(
            delta,
            cmap="RdBu",
            vmin=-np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
            vmax=+np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
            xticklabels=ordered_topics,
            yticklabels=ordered_topics,
            ax=ax,
            annot=[
                [
                    "$\\ast\\ast$"
                    if significant_3s[i, j]
                    else ("$\\ast$" if significant_2s[i, j] else "")
                    for j in range(len(topics))
                ]
                for i in range(len(topics))
            ],
            fmt="",
            annot_kws={"fontsize": 6},
        )
        ax.xaxis.set_tick_params(rotation=90)
        ax.yaxis.set_tick_params(rotation=0)
        fig.savefig(
            opj(args.input, f"ei_{param}{suffix}.eps"),
            bbox_inches="tight",
        )
    plt.clf()


plt.clf()
plot_matrix("delta")
plot_matrix("gamma")
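# Heatmap of the expertise-overlap matrix R in seriation order.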
fig, ax = plt.subplots()
RC = R[:, order][order]
# np.fill_diagonal(RC, np.nan)
sns.heatmap(
    RC,
    cmap="Blues",
    vmin=0,
    vmax=+np.maximum(np.abs(np.min(RC[~np.isnan(RC)])), np.abs(np.max(RC[~np.isnan(RC)]))),
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    ax=ax,
    fmt="",
    annot_kws={"fontsize": 6},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
fig.savefig(
    opj(args.input, f"ei_R{suffix}.eps"),
    bbox_inches="tight",
)
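# Citation matrix between topics: assign each article its dominant topic via
# argmax over the (junk-filtered) topic-count matrix, then count citing/cited
# pairs from the INSPIRE references table.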
topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
topic_matrix = topic_matrix[:, ~junk]

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "title"]]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
articles["year_group"] = articles["year"] // 5

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles["article_id"] = articles.article_id.astype(int)
articles = _articles.merge(articles, how="left")
articles["main_topic"] = topic_matrix.argmax(axis=1)

references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
references["cites"] = references.cites.astype(int)
references["cited"] = references.cited.astype(int)
references = references.merge(articles[["article_id", "main_topic"]], how="inner", left_on="cites", right_on="article_id")
references = references.merge(articles[["article_id", "main_topic"]], how="inner", left_on="cited", right_on="article_id", suffixes=("_cites", "_cited"))

# main_topic indexes the junk-filtered topics, so the matrix is sized by
# len(topics) rather than the raw topic count
citation_matrix = np.zeros((len(topics), len(topics)))
for i in range(len(topics)):
    for j in range(len(topics)):
        citation_matrix[i, j] = len(references[(references["main_topic_cites"] == i) & (references["main_topic_cited"] == j)])
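# Plot log(observed / expected-under-independence) citation counts, i.e. the
# pointwise mutual information between citing and cited topics; cells where
# the observed/expected ratio is at least 1.05 are annotated with that factor.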
fig, ax = plt.subplots()
m = np.log(citation_matrix.sum() * citation_matrix / np.outer(citation_matrix.sum(axis=1), citation_matrix.sum(axis=0)))
m = m[:, order][order]
sns.heatmap(
    m,
    ax=ax,
    vmin=-np.max(np.abs(m)),
    vmax=np.max(np.abs(m)),
    cmap="RdBu",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=[
        [
            f"$\\times$\\textbf{{{np.exp(m[i, j]):.1f}}}"
            if np.exp(m[i, j]) >= 1.05
            else ""
            for j in range(len(topics))
        ]
        for i in range(len(topics))
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
ax.set_xlabel("\\textbf{Cited category (references)}")
ax.set_ylabel("\\textbf{Citing category}")
fig.savefig(opj(args.input, "topic_citation_matrix.eps"), bbox_inches="tight")