# ei.py

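# Post-processing/plotting script for the "ei" model outputs (ei_samples.npz or ei_mle.npz):
# loads topic labels and aggregate counts, orders topics by clustering their expertise
# overlap, and saves heatmaps of counts, mu, delta, gamma, R, and a topic citation matrix.
#
# Example invocation (paths are illustrative):
#   python ei.py --input <run_directory> --dataset inspire-harvest/database --suffix main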
import numpy as np
import pandas as pd
from scipy.stats import entropy
from scipy.special import softmax
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from matplotlib import pyplot as plt
import matplotlib

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
# LaTeX preamble (math packages)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])

import seaborn as sns
import argparse
from os.path import join as opj
import pickle
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--suffix", default=None)
parser.add_argument("--mle", action="store_true", default=False)
args = parser.parse_args()

suffix = f"_{args.suffix}" if args.suffix is not None else ""
format = "mle" if args.mle else "samples"

samples = np.load(opj(args.input, f"ei_{format}{suffix}.npz"))

topics = pd.read_csv(opj(args.input, "topics.csv"))
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()
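# Aggregate data: "start" and "end" topic counts (NR, NC) and expertise weights, one row
# per entry of aggregate.csv; columns corresponding to "Junk" topics are dropped to match
# the filtered topic list above.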
fig, ax = plt.subplots()

n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
df = pd.read_csv(opj(args.input, "aggregate.csv"))
NR = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
NC = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].fillna(0).values)

# junk = np.sum(NR + NC, axis=0) == 0
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]
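# Seriation: order topics by hierarchical clustering (Ward linkage) of a Jaccard-like
# distance between topics, where two topics are close when the same rows tend to have
# above-average expertise in both.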
from scipy.spatial.distance import pdist, squareform
from fastcluster import linkage


def seriation(Z, N, cur_index):
    if cur_index < N:
        return [cur_index]
    else:
        left = int(Z[cur_index - N, 0])
        right = int(Z[cur_index - N, 1])
        return seriation(Z, N, left) + seriation(Z, N, right)


def compute_serial_matrix(dist_mat, method="ward"):
    N = len(dist_mat)
    flat_dist_mat = squareform(dist_mat)
    res_linkage = linkage(flat_dist_mat, method=method, preserve_input=True)
    res_order = seriation(res_linkage, N, N + N - 2)
    seriated_dist = np.zeros((N, N))
    a, b = np.triu_indices(N, k=1)
    seriated_dist[a, b] = dist_mat[[res_order[i] for i in a], [res_order[j] for j in b]]
    seriated_dist[b, a] = seriated_dist[a, b]
    return seriated_dist, res_order, res_linkage


dist = 1 - np.array([
    [
        ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
        / ((expertise[:, i] > expertise[:, i].mean()) | (expertise[:, j] > expertise[:, j].mean())).mean()
        for j in range(len(topics))
    ]
    for i in range(len(topics))
])
dist = np.nan_to_num(dist)

m, order, dendo = compute_serial_matrix(dist)
order = np.array(order)[::-1]
print(order)
ordered_topics = [topics[i] for i in order]
np.save(opj(args.input, "topics_order.npy"), order)
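# Normalize expertise to per-row shares, then build R: R[i, j] is the conditional
# probability that a row with above-average expertise in topic i also has
# above-average expertise in topic j.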
x = NR / NR.sum(axis=1)[:, np.newaxis]
y = NR / NR.sum(axis=1)[:, np.newaxis]
expertise = expertise / expertise.sum(axis=1)[:, np.newaxis]

R = np.array([
    [
        ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
        / (expertise[:, i] > expertise[:, i].mean()).mean()
        for j in range(len(topics))
    ]
    for i in range(len(topics))
])
print(expertise)
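# Posterior-sample plot (skipped for MLE runs): delta_{kk'} against nu_{kk'}, with the
# linear relation delta_0 + delta_nu * nu and its 95% interval band; only pairs whose
# 95% interval excludes zero are drawn as points with error bars.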
if not args.mle:
    fig, ax = plt.subplots()
    x = np.linspace(np.min(R), np.max(R), 100)
    y = samples["delta_0"][:, np.newaxis] + np.einsum("s,i->si", samples["delta_nu"], x)
    ax.fill_between(x, np.quantile(y, axis=0, q=0.05 / 2), np.quantile(y, axis=0, q=1 - 0.05 / 2), color="lightgray")
    ax.plot(x, y.mean(axis=0), color="black")

    delta = samples["delta"].mean(axis=0)
    sig = (np.quantile(samples["delta"], q=0.05 / 2, axis=0) * np.quantile(samples["delta"], q=1 - 0.05 / 2, axis=0)).flatten() > 0

    ax.scatter(R.flatten()[sig], delta.flatten()[sig], s=2)
    ax.errorbar(
        R.flatten()[sig],
        delta.flatten()[sig],
        (
            delta.flatten()[sig] - np.quantile(samples["delta"], q=0.05 / 2, axis=0).flatten()[sig],
            np.quantile(samples["delta"], q=1 - 0.05 / 2, axis=0).flatten()[sig] - delta.flatten()[sig],
        ),
        ls="none",
        lw=0.5,
    )
    ax.set_xlabel("$\\nu_{kk'}$")
    ax.set_ylabel("$\\delta_{kk'}$")
    fig.savefig(opj(args.input, f"delta_vs_nu{suffix}.eps"), bbox_inches="tight")

plt.clf()
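# Heatmap of the sample-averaged counts matrix, normalized by per-topic totals and
# reordered according to the seriation order.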
counts = samples["counts"].mean(axis=(0))
counts = counts / counts.sum(axis=0)[:, np.newaxis]
sns.heatmap(
    counts[:, order][order],
    vmin=0,
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(
    opj(args.input, f"ei_counts{suffix}.eps"), bbox_inches="tight"
)
plt.clf()
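# Heatmap of softmax-transformed mu (with the diagonal of gamma added to the diagonal of
# mu in each sample before averaging); only diagonal entries are annotated.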
mu = samples["mu"][:, :, order][:, order]
gamma = samples["gamma"][:, :, order][:, order]
for s in range(mu.shape[0]):
    mu[s, :, :] += np.diag(np.diag(gamma[s]))
mu = np.mean(mu, axis=0)
mu = softmax(mu, axis=1)

sns.heatmap(
    mu,
    vmin=0,
    vmax=np.max(mu),
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=[
        [
            f"{mu[i,j]:.2f}" if i == j else ""
            for j in range(len(topics))
        ]
        for i in range(len(topics))
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(
    opj(args.input, f"ei_mu{suffix}.eps"), bbox_inches="tight"
)

fig, ax = plt.subplots()
beta = mu
ax.scatter(R.flatten(), beta.flatten())
fig.savefig(
    opj(args.input, f"ei_mu_kl_dist.eps"),
    bbox_inches="tight",
)
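# Generic plotting helper for a sampled parameter: vector parameters are drawn as posterior
# means with 95% intervals per topic; matrix parameters are drawn as a heatmap, masking
# entries whose 95% interval includes zero (default) or annotating significance with
# "*" (95% interval excludes zero) and "**" (99.7% interval excludes zero).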
def plot_matrix(param, hide_insignificant: bool = True):
    fig, ax = plt.subplots()
    x = samples[param][:, :, order][:, order]
    delta = x.mean(axis=(0))
    up = np.quantile(x, axis=0, q=1 - 0.05 / 2)
    low = np.quantile(x, axis=0, q=0.05 / 2)
    up_3s = np.quantile(x, axis=0, q=1 - 0.003 / 2)
    low_3s = np.quantile(x, axis=0, q=0.003 / 2)
    if len(delta.shape) == 1:
        ax.errorbar(
            np.arange(len(delta)), delta, yerr=(delta - low, up - delta), ls="none"
        )
        ax.scatter(np.arange(len(delta)), delta)
        ax.set_xticks(np.arange(len(delta)))
        ax.set_xticklabels(topics)
        ax.xaxis.set_tick_params(rotation=90)
        fig.savefig(
            opj(args.input, f"ei_{param}{suffix}.eps"),
            bbox_inches="tight",
        )
    else:
        significant_2s = up * low > 0
        significant_3s = up_3s * low_3s > 0
        if hide_insignificant:
            sns.heatmap(
                np.where(significant_2s, delta, np.nan),
                cmap="RdBu",
                vmin=-np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                vmax=+np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                xticklabels=ordered_topics,
                yticklabels=ordered_topics,
                ax=ax,
                fmt="",
                annot_kws={"fontsize": 6},
            )
        else:
            sns.heatmap(
                delta,
                cmap="RdBu",
                vmin=-np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                vmax=+np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                xticklabels=ordered_topics,
                yticklabels=ordered_topics,
                ax=ax,
                annot=[
                    [
                        "$\\ast\\ast$"
                        if significant_3s[i, j]
                        else ("$\\ast$" if significant_2s[i, j] else "")
                        for j in range(len(topics))
                    ]
                    for i in range(len(topics))
                ],
                fmt="",
                annot_kws={"fontsize": 6},
            )
        ax.xaxis.set_tick_params(rotation=90)
        ax.yaxis.set_tick_params(rotation=0)
        fig.savefig(
            opj(args.input, f"ei_{param}{suffix}.eps"),
            bbox_inches="tight",
        )
        plt.clf()
    plt.clf()
if not args.mle:
    plot_matrix("delta")
    plot_matrix("gamma")
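# Heatmap of 1 - R (the complement of the conditional expertise-overlap matrix), in
# seriation order; y tick labels are suppressed.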
fig, ax = plt.subplots()
RC = R[:, order][order]
# np.fill_diagonal(RC, np.nan)
sns.heatmap(
    1 - RC,
    cmap="Blues",
    vmin=0,
    vmax=+np.maximum(np.abs(np.min(RC[~np.isnan(RC)])), np.abs(np.max(RC[~np.isnan(RC)]))),
    xticklabels=ordered_topics,
    yticklabels=[""] * len(ordered_topics),
    ax=ax,
    fmt="",
    annot_kws={"fontsize": 6},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
fig.savefig(
    opj(args.input, f"ei_R{suffix}.eps"),
    bbox_inches="tight",
)
fig.savefig(
    opj(args.input, f"ei_R{suffix}.pdf"),
    bbox_inches="tight",
)
fig.savefig(
    opj(args.input, f"ei_R{suffix}.png"),
    bbox_inches="tight",
    dpi=300,
)
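# Topic-to-topic citation counts: each article is assigned its dominant topic, and
# citation_matrix[i, j] counts references from articles with main topic i to articles
# with main topic j.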
topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
topic_matrix = topic_matrix[:, ~junk]

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "title"]]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
articles["year_group"] = articles["year"] // 5

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles["article_id"] = articles.article_id.astype(int)
articles = _articles.merge(articles, how="left")
articles["main_topic"] = topic_matrix.argmax(axis=1)

references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
references["cites"] = references.cites.astype(int)
references["cited"] = references.cited.astype(int)
references = references.merge(articles[["article_id", "main_topic"]], how="inner", left_on="cites", right_on="article_id")
references = references.merge(articles[["article_id", "main_topic"]], how="inner", left_on="cited", right_on="article_id", suffixes=("_cites", "_cited"))

citation_matrix = np.zeros((n_topics, n_topics))
for i in range(n_topics):
    for j in range(n_topics):
        citation_matrix[i, j] = len(references[(references["main_topic_cites"] == i) & (references["main_topic_cited"] == j)])
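# Heatmap of the log citation over-representation ratio: observed citations between topic
# pairs relative to what the marginal citing/cited totals would predict; cells are
# annotated with the ratio when it exceeds 1.05.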
fig, ax = plt.subplots()
m = np.log(citation_matrix.sum() * citation_matrix / np.outer(citation_matrix.sum(axis=1), citation_matrix.sum(axis=0)))
m = m[:, order][order]
sns.heatmap(
    m,
    ax=ax,
    vmin=-np.max(np.abs(m)),
    vmax=np.max(np.abs(m)),
    cmap="RdBu",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=[
        [
            f"$\\times$\\textbf{{{np.exp(m[i,j]):.1f}}}"
            if np.exp(m[i, j]) >= 1.05
            else ""
            for j in range(len(topics))
        ]
        for i in range(len(topics))
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
ax.set_xlabel("\\textbf{Cited category (references)}")
ax.set_ylabel("\\textbf{Citing category}")
fig.savefig(opj(args.input, "topic_citation_matrix.eps"), bbox_inches="tight")