123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363 |
- import numpy as np
- import pandas as pd
- from scipy.stats import entropy
- from scipy.special import softmax
- from sklearn.linear_model import LinearRegression
- from sklearn.metrics import r2_score
- from matplotlib import pyplot as plt
- import matplotlib
- from matplotlib import pyplot as plt
# Select the PGF backend and let LaTeX (XeLaTeX for PGF output) typeset text.
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
# ``str.join`` returns a new string, so the preamble has to be *assigned* for
# the packages to be registered. ``amssymb`` is a package, not a font, hence
# ``\usepackage`` rather than ``\setmainfont``.
# NOTE(review): only the rcParam named by the previous code is set here; if
# the packages are also wanted for PGF output, "pgf.preamble" needs the same
# value -- confirm.
plt.rcParams["text.latex.preamble"] = "\n".join(
    [
        r"\usepackage{amsmath}",
        r"\usepackage{amssymb}",
    ]
)
- import seaborn as sns
- import argparse
- from os.path import join as opj
- import pickle
# --- Command line ----------------------------------------------------------
parser = argparse.ArgumentParser()
# Every input and output path below is built from --input, so it is mandatory.
parser.add_argument("--input", required=True)
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--suffix", default=None)
parser.add_argument("--mle", action="store_true", default=False)
args = parser.parse_args()

suffix = f"_{args.suffix}" if args.suffix is not None else ""
# The archive is named ei_mle<suffix>.npz or ei_samples<suffix>.npz.
sample_format = "mle" if args.mle else "samples"
samples = np.load(opj(args.input, f"ei_{sample_format}{suffix}.npz"))

# --- Topics ----------------------------------------------------------------
topics_table = pd.read_csv(opj(args.input, "topics.csv"))
# n_topics counts every topic, including those flagged as junk.
n_topics = len(topics_table)
junk = topics_table["label"].str.contains("Junk")
topics = topics_table[~junk]["label"].tolist()

# --- Aggregated per-row tables (columns start_k / end_k / expertise_k) ------
df = pd.read_csv(opj(args.input, "aggregate.csv"))
topic_ids = range(1, n_topics + 1)
NR = df[[f"start_{k}" for k in topic_ids]].to_numpy().astype(int)
NC = df[[f"end_{k}" for k in topic_ids]].to_numpy().astype(int)
expertise = df[[f"expertise_{k}" for k in topic_ids]].fillna(0).to_numpy()

# Drop the columns that belong to junk topics.
keep = ~junk.to_numpy()
NR = NR[:, keep]
NC = NC[:, keep]
expertise = expertise[:, keep]
- from scipy.spatial.distance import pdist, squareform
- from fastcluster import linkage
def seriation(Z, N, cur_index):
    """Return the leaves below node ``cur_index`` of a dendrogram, left to right.

    Parameters
    ----------
    Z : 2-D array
        Linkage matrix. Row ``i`` stores, in its first two columns, the
        indices of the two nodes merged into node ``N + i``.
    N : int
        Number of leaves; every index below ``N`` is a leaf.
    cur_index : int
        Index of the node whose subtree is listed (``2 * N - 2`` for the
        root of a full tree).

    Returns
    -------
    list of int
        Leaf indices in dendrogram order.
    """
    # An explicit stack replaces recursion: a chain-like dendrogram is as deep
    # as it has leaves, which would exceed Python's recursion limit.
    leaves = []
    pending = [cur_index]
    while pending:
        node = pending.pop()
        if node < N:
            leaves.append(node)
            continue
        row = node - N
        # Push the right child first so the left subtree is emitted first.
        pending.append(int(Z[row, 1]))
        pending.append(int(Z[row, 0]))
    return leaves
-
def compute_serial_matrix(dist_mat, method="ward"):
    """Reorder a square distance matrix along a hierarchical clustering.

    Parameters
    ----------
    dist_mat : 2-D array
        Square distance matrix (it is passed to ``squareform``).
    method : str
        Linkage method forwarded to ``linkage``.

    Returns
    -------
    tuple
        ``(seriated_dist, res_order, res_linkage)``: the reordered matrix
        (zero diagonal, lower triangle mirrored from the upper one), the leaf
        order as a list, and the linkage matrix.
    """
    n_items = len(dist_mat)
    condensed = squareform(dist_mat)
    tree = linkage(condensed, method=method, preserve_input=True)
    # Node 2 * N - 2 is the last row of the linkage, i.e. the root.
    leaf_order = seriation(tree, n_items, 2 * n_items - 2)

    permutation = np.asarray(leaf_order)
    rows, cols = np.triu_indices(n_items, k=1)
    reordered = np.zeros((n_items, n_items))
    # Fill the strict upper triangle, then mirror it below the diagonal.
    reordered[rows, cols] = dist_mat[permutation[rows], permutation[cols]]
    reordered[cols, rows] = reordered[rows, cols]

    return reordered, leaf_order, tree
def _joint_above_mean_counts(values):
    """Count, for each pair of columns, the rows above the mean in both.

    Entry ``[i, j]`` is the number of rows whose value exceeds the column
    mean in column ``i`` and in column ``j``; the diagonal therefore holds
    the per-column counts.
    """
    # The mean is taken column by column so the thresholds are computed the
    # same way as a per-column ``values[:, k].mean()`` comparison.
    above = np.column_stack(
        [values[:, k] > values[:, k].mean() for k in range(values.shape[1])]
    ).astype(float)
    return above.T @ above


n_rows = len(expertise)

# Jaccard distance between topics, based on who is above the mean expertise.
# Counts are divided by n_rows on both sides to keep the ratio-of-means form.
both = _joint_above_mean_counts(expertise)
alone = both.diagonal()
either = alone[:, np.newaxis] + alone[np.newaxis, :] - both
dist = 1 - (both / n_rows) / (either / n_rows)
# 0/0 (a pair of topics with nobody above the mean) becomes distance 0.
dist = np.nan_to_num(dist)

m, order, dendo = compute_serial_matrix(dist)
order = np.array(order)[::-1]
print(order)
ordered_topics = [topics[i] for i in order]
np.save(opj(args.input, "topics_order.npy"), order)

# From here on expertise is normalised so that each row sums to one.
# NOTE(review): a row whose expertise sums to zero becomes NaN, which makes
# every column mean NaN and every comparison below False -- confirm that
# such rows cannot occur.
expertise = expertise / expertise.sum(axis=1)[:, np.newaxis]

# R[i, j]: among rows above the mean in topic i, the fraction that is also
# above the mean in topic j.
both = _joint_above_mean_counts(expertise)
R = (both / n_rows) / (both.diagonal() / n_rows)[:, np.newaxis]
print(expertise)
# NOTE(review): the extent of this conditional is reconstructed -- it covers
# every statement that reads the delta-specific sample keys. Confirm.
if not args.mle:
    fig, ax = plt.subplots()

    lower_q = 0.05 / 2
    upper_q = 1 - 0.05 / 2

    # Linear trend delta_0 + delta_nu * nu, evaluated per draw on a grid of nu.
    nu_grid = np.linspace(np.min(R), np.max(R), 100)
    trend = samples["delta_0"][:, np.newaxis] + np.einsum(
        "s,i->si", samples["delta_nu"], nu_grid
    )
    ax.fill_between(
        nu_grid,
        np.quantile(trend, axis=0, q=lower_q),
        np.quantile(trend, axis=0, q=upper_q),
        color="lightgray",
    )
    ax.plot(nu_grid, trend.mean(axis=0), color="black")

    delta_draws = samples["delta"]
    delta = delta_draws.mean(axis=0)
    delta_low = np.quantile(delta_draws, q=lower_q, axis=0)
    delta_high = np.quantile(delta_draws, q=upper_q, axis=0)
    # Keep only the cells whose interval lies strictly on one side of zero.
    sig = (delta_low * delta_high).flatten() > 0

    nu_sig = R.flatten()[sig]
    delta_sig = delta.flatten()[sig]
    ax.scatter(nu_sig, delta_sig, s=2)
    ax.errorbar(
        nu_sig,
        delta_sig,
        (
            delta_sig - delta_low.flatten()[sig],
            delta_high.flatten()[sig] - delta_sig,
        ),
        ls="none",
        lw=0.5,
    )

    ax.set_xlabel("$\\nu_{kk'}$")
    ax.set_ylabel("$\\delta_{kk'}$")
    fig.savefig(opj(args.input, f"delta_vs_nu{suffix}.eps"), bbox_inches="tight")

plt.clf()
# Heatmap of the mean of samples["counts"] over its first axis.
counts = samples["counts"].mean(axis=0)
# NOTE(review): if counts is a square 2-D matrix, this divides row i by the
# total of *column* i (sum over axis 0, broadcast as a column vector).
# Confirm this is the intended normalisation.
column_totals = counts.sum(axis=0)
counts = counts / column_totals[:, np.newaxis]

sns.heatmap(
    counts[np.ix_(order, order)],
    vmin=0,
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(opj(args.input, f"ei_counts{suffix}.eps"), bbox_inches="tight")
plt.clf()
# Reorder both parameters along their two trailing axes (fancy indexing
# copies, so the arrays stored in ``samples`` are left untouched).
mu = samples["mu"][:, :, order][:, order]
gamma = samples["gamma"][:, :, order][:, order]

# Draw by draw, add the diagonal of gamma to mu.
for mu_draw, gamma_draw in zip(mu, gamma):
    mu_draw += np.diag(np.diag(gamma_draw))

# Average over draws, then turn every row into a probability vector.
mu = softmax(mu.mean(axis=0), axis=1)

# Only the diagonal cells are annotated with their value.
n_shown = len(topics)
diagonal_labels = [[""] * n_shown for _ in range(n_shown)]
for k in range(n_shown):
    diagonal_labels[k][k] = f"{mu[k, k]:.2f}"

sns.heatmap(
    mu,
    vmin=0,
    vmax=mu.max(),
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=diagonal_labels,
    fmt="",
    annot_kws={"fontsize": 5},
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(opj(args.input, f"ei_mu{suffix}.eps"), bbox_inches="tight")
# Scatter of mu against R, cell by cell.
fig, ax = plt.subplots()
beta = mu
# ``mu`` has been permuted along both axes with ``order``; R must receive the
# same permutation, otherwise each point pairs two unrelated cells.
R_ordered = R[:, order][order]
ax.scatter(R_ordered.flatten(), beta.flatten())
fig.savefig(
    opj(args.input, "ei_mu_kl_dist.eps"),
    bbox_inches="tight",
)
def plot_matrix(param, hide_insignificant: bool = True):
    """Plot the mean of ``samples[param]`` over draws and save it as EPS.

    The first axis of ``samples[param]`` is treated as the draw axis. A 2-D
    array (one value per topic and draw) is shown as points with 95 % error
    bars; otherwise the two trailing axes are shown as a topic-by-topic
    heatmap. Topics follow the module-level ``order``.

    Parameters
    ----------
    param : str
        Key of the array in the module-level ``samples``; also used in the
        output file name ``ei_<param><suffix>.eps`` inside ``args.input``.
    hide_insignificant : bool
        Heatmap only. If true, cells whose 2.5 %-97.5 % interval does not lie
        strictly on one side of zero are blanked. If false, every cell is
        shown and annotated with ``*`` (that interval excludes zero) or
        ``**`` (the 0.15 %-99.85 % interval excludes zero).
    """
    fig, ax = plt.subplots()

    draws = samples[param]
    if draws.ndim == 2:
        # Vector parameter: only one topic axis to permute.
        draws = draws[:, order]
    else:
        draws = draws[:, :, order][:, order]
    delta = draws.mean(axis=0)

    up = np.quantile(draws, axis=0, q=1 - 0.05 / 2)
    low = np.quantile(draws, axis=0, q=0.05 / 2)

    if delta.ndim == 1:
        positions = np.arange(len(delta))
        ax.errorbar(positions, delta, yerr=(delta - low, up - delta), ls="none")
        ax.scatter(positions, delta)
        ax.set_xticks(positions)
        # The values were permuted with ``order``, so the labels must be too.
        ax.set_xticklabels(ordered_topics)
        ax.xaxis.set_tick_params(rotation=90)
    else:
        significant_2s = up * low > 0
        # Symmetric colour scale centred on zero.
        bound = np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta)))
        common = dict(
            cmap="RdBu",
            vmin=-bound,
            vmax=+bound,
            xticklabels=ordered_topics,
            yticklabels=ordered_topics,
            ax=ax,
        )
        if hide_insignificant:
            sns.heatmap(np.where(significant_2s, delta, np.nan), **common)
        else:
            up_3s = np.quantile(draws, axis=0, q=1 - 0.003 / 2)
            low_3s = np.quantile(draws, axis=0, q=0.003 / 2)
            significant_3s = up_3s * low_3s > 0
            n_rows, n_cols = delta.shape
            stars = [
                [
                    "$\\ast\\ast$"
                    if significant_3s[i, j]
                    else ("$\\ast$" if significant_2s[i, j] else "")
                    for j in range(n_cols)
                ]
                for i in range(n_rows)
            ]
            sns.heatmap(
                delta,
                annot=stars,
                fmt="",
                annot_kws={"fontsize": 6},
                **common,
            )
        ax.xaxis.set_tick_params(rotation=90)
        ax.yaxis.set_tick_params(rotation=0)

    fig.savefig(
        opj(args.input, f"ei_{param}{suffix}.eps"),
        bbox_inches="tight",
    )
    # Release the figure instead of leaving a cleared one open.
    plt.close(fig)
plt.clf()

# NOTE(review): the extent of this conditional is reconstructed. Only the
# "delta" plot is guarded here because "gamma" is read unconditionally
# elsewhere in the script -- confirm.
if not args.mle:
    plot_matrix("delta")
plot_matrix("gamma")

# Heatmap of 1 - R in dendrogram order, saved in three formats.
fig, ax = plt.subplots()
RC = R[:, order][order]
# NaN cells are left out when choosing the upper bound of the colour scale.
valid = RC[~np.isnan(RC)]
sns.heatmap(
    1 - RC,
    cmap="Blues",
    vmin=0,
    vmax=+np.abs(valid).max(),
    xticklabels=ordered_topics,
    yticklabels=[""] * len(ordered_topics),
    ax=ax,
    fmt="",
    annot_kws={"fontsize": 6},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)

for extension, extra in (("eps", {}), ("pdf", {}), ("png", {"dpi": 300})):
    fig.savefig(
        opj(args.input, f"ei_R{suffix}.{extension}"),
        bbox_inches="tight",
        **extra,
    )
# --- Main topic of every article -------------------------------------------
topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
topic_matrix = topic_matrix[:, ~junk]

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[
    ["article_id", "date_created", "title"]
]
# .copy() so the column assignments below act on an independent frame.
articles = articles[articles["date_created"].str.len() >= 4].copy()
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)].copy()
articles["year_group"] = articles["year"] // 5

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles["article_id"] = articles.article_id.astype(int)
articles = _articles.merge(articles, how="left")
# NOTE(review): assumes the merged frame has exactly one row per row of
# topic_matrix, in the same order -- confirm.
articles["main_topic"] = topic_matrix.argmax(axis=1)

# --- Citations between main topics ------------------------------------------
references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
references["cites"] = references.cites.astype(int)
references["cited"] = references.cited.astype(int)
references = references.merge(
    articles[["article_id", "main_topic"]],
    how="inner",
    left_on="cites",
    right_on="article_id",
)
references = references.merge(
    articles[["article_id", "main_topic"]],
    how="inner",
    left_on="cited",
    right_on="article_id",
    suffixes=("_cites", "_cited"),
)

# citation_matrix[i, j]: number of references from an article whose main
# topic is i to an article whose main topic is j. Accumulated in a single
# pass instead of filtering the table once per (i, j) pair.
citation_matrix = np.zeros((n_topics, n_topics))
np.add.at(
    citation_matrix,
    (
        references["main_topic_cites"].to_numpy(dtype=int),
        references["main_topic_cited"].to_numpy(dtype=int),
    ),
    1,
)
fig, ax = plt.subplots()

# Log of observed citations over the count expected if citing and cited
# topics were independent (row total * column total / grand total).
# Empty cells give log(0) = -inf and empty rows/columns give 0/0 = NaN;
# both are expected here, so the warnings are silenced.
with np.errstate(divide="ignore", invalid="ignore"):
    m = np.log(
        citation_matrix.sum()
        * citation_matrix
        / np.outer(citation_matrix.sum(axis=1), citation_matrix.sum(axis=0))
    )
m = m[:, order][order]

# The colour scale is symmetric and derived from finite cells only, so a
# single empty cell cannot turn the bounds into inf/NaN.
finite = m[np.isfinite(m)]
bound = np.max(np.abs(finite)) if finite.size else 1.0

ratio = np.exp(m)
n_rows, n_cols = m.shape
sns.heatmap(
    m,
    ax=ax,
    vmin=-bound,
    vmax=bound,
    cmap="RdBu",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    # Annotate only the cells cited at least 5 % more often than expected.
    annot=[
        [
            f"$\\times$\\textbf{{{ratio[i, j]:.1f}}}" if ratio[i, j] >= 1.05 else ""
            for j in range(n_cols)
        ]
        for i in range(n_rows)
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
ax.set_xlabel("\\textbf{Cited category (references)}")
ax.set_ylabel("\\textbf{Citing category}")
fig.savefig(opj(args.input, "topic_citation_matrix.eps"), bbox_inches="tight")
|