import argparse
from os.path import join as opj

import numpy as np
import pandas as pd
from scipy.special import softmax
from scipy.spatial.distance import pdist, squareform
from fastcluster import linkage

import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
from matplotlib import pyplot as plt
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
import seaborn as sns

parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--suffix", default=None)
parser.add_argument("--mle", action="store_true", default=False)
args = parser.parse_args()

suffix = f"_{args.suffix}" if args.suffix is not None else ""
fmt = "mle" if args.mle else "samples"
samples = np.load(opj(args.input, f"ei_{fmt}{suffix}.npz"))

topics = pd.read_csv(opj(args.input, "topics.csv"))
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()

# Total number of topics, junk included (columns in aggregate.csv).
n_topics = len(junk)

df = pd.read_csv(opj(args.input, "aggregate.csv"))
NR = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
NC = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].fillna(0).values)

# junk = np.sum(NR + NC, axis=0) == 0
# Drop junk topics from the count and expertise matrices.
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]


def seriation(Z, N, cur_index):
    """Return the leaf order implied by the linkage matrix Z."""
    if cur_index < N:
        return [cur_index]
    left = int(Z[cur_index - N, 0])
    right = int(Z[cur_index - N, 1])
    return seriation(Z, N, left) + seriation(Z, N, right)


def compute_serial_matrix(dist_mat, method="ward"):
    """Reorder dist_mat by hierarchical clustering so that similar topics end up adjacent."""
    N = len(dist_mat)
    flat_dist_mat = squareform(dist_mat)
    res_linkage = linkage(flat_dist_mat, method=method, preserve_input=True)
    res_order = seriation(res_linkage, N, N + N - 2)
    seriated_dist = np.zeros((N, N))
    a, b = np.triu_indices(N, k=1)
    seriated_dist[a, b] = dist_mat[
        [res_order[i] for i in a], [res_order[j] for j in b]
    ]
    seriated_dist[b, a] = seriated_dist[a, b]
    return seriated_dist, res_order, res_linkage


# Jaccard-like distance between topics: overlap of the sets of researchers with
# above-average expertise in each topic, divided by the size of their union.
dist = 1 - np.array([
    [
        ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
        / ((expertise[:, i] > expertise[:, i].mean()) | (expertise[:, j] > expertise[:, j].mean())).mean()
        for j in range(len(topics))
    ]
    for i in range(len(topics))
])
dist = np.nan_to_num(dist)

m, order, dendro = compute_serial_matrix(dist)
order = np.array(order)[::-1]
print(order)
ordered_topics = [topics[i] for i in order]
np.save(opj(args.input, "topics_order.npy"), order)

x = NR / NR.sum(axis=1)[:, np.newaxis]
y = NC / NC.sum(axis=1)[:, np.newaxis]
expertise = expertise / expertise.sum(axis=1)[:, np.newaxis]

# R[i, j]: probability of above-average expertise in topic j, conditional on
# above-average expertise in topic i (asymmetric, diagonal equal to 1).
R = np.array([
    [
        ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
        / (expertise[:, i] > expertise[:, i].mean()).mean()
        for j in range(len(topics))
    ]
    for i in range(len(topics))
])
print(expertise)
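# Added sanity checks (not in the original pipeline): after the row-normalisation
# above, each researcher's expertise shares should sum to 1 (all-zero rows become
# NaN and are skipped), and the diagonal of R should be exactly 1 wherever it is
# defined, since P(above-average in i | above-average in i) = 1.
_row_sums = expertise.sum(axis=1)
assert np.allclose(_row_sums[~np.isnan(_row_sums)], 1.0)
_diag = np.diag(R)
assert np.allclose(_diag[~np.isnan(_diag)], 1.0)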
if not args.mle:
    fig, ax = plt.subplots()
    x = np.linspace(np.min(R), np.max(R), 100)
    y = samples["delta_0"][:, np.newaxis] + np.einsum("s,i->si", samples["delta_nu"], x)
    ax.fill_between(
        x,
        np.quantile(y, axis=0, q=0.05 / 2),
        np.quantile(y, axis=0, q=1 - 0.05 / 2),
        color="lightgray",
    )
    ax.plot(x, y.mean(axis=0), color="black")

    delta = samples["delta"].mean(axis=0)
    # Keep only the coefficients whose 95% credible interval excludes zero.
    sig = (
        np.quantile(samples["delta"], q=0.05 / 2, axis=0)
        * np.quantile(samples["delta"], q=1 - 0.05 / 2, axis=0)
    ).flatten() > 0
    ax.scatter(R.flatten()[sig], delta.flatten()[sig], s=2)
    ax.errorbar(
        R.flatten()[sig],
        delta.flatten()[sig],
        (
            delta.flatten()[sig] - np.quantile(samples["delta"], q=0.05 / 2, axis=0).flatten()[sig],
            np.quantile(samples["delta"], q=1 - 0.05 / 2, axis=0).flatten()[sig] - delta.flatten()[sig],
        ),
        ls="none",
        lw=0.5,
    )
    ax.set_xlabel("$\\nu_{kk'}$")
    ax.set_ylabel("$\\delta_{kk'}$")
    fig.savefig(opj(args.input, f"delta_vs_nu{suffix}.eps"), bbox_inches="tight")

plt.clf()

counts = samples["counts"].mean(axis=0)
counts = counts / counts.sum(axis=0)[:, np.newaxis]
sns.heatmap(
    counts[:, order][order],
    vmin=0,
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(opj(args.input, f"ei_counts{suffix}.eps"), bbox_inches="tight")
plt.clf()

mu = samples["mu"][:, :, order][:, order]
gamma = samples["gamma"][:, :, order][:, order]
# Add the diagonal ("stay in the same topic") effect gamma to mu in each sample
# before averaging and mapping to the simplex.
for s in range(mu.shape[0]):
    mu[s, :, :] += np.diag(np.diag(gamma[s]))
mu = np.mean(mu, axis=0)
mu = softmax(mu, axis=1)
sns.heatmap(
    mu,
    vmin=0,
    vmax=np.max(mu),
    cmap="Blues",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=[
        [f"{mu[i, j]:.2f}" if i == j else "" for j in range(len(topics))]
        for i in range(len(topics))
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig(opj(args.input, f"ei_mu{suffix}.eps"), bbox_inches="tight")

fig, ax = plt.subplots()
beta = mu
ax.scatter(R.flatten(), beta.flatten())
fig.savefig(opj(args.input, "ei_mu_kl_dist.eps"), bbox_inches="tight")


def plot_matrix(param, hide_insignificant: bool = True):
    fig, ax = plt.subplots()
    x = samples[param][:, :, order][:, order]
    delta = x.mean(axis=0)
    up = np.quantile(x, axis=0, q=1 - 0.05 / 2)
    low = np.quantile(x, axis=0, q=0.05 / 2)
    up_3s = np.quantile(x, axis=0, q=1 - 0.003 / 2)
    low_3s = np.quantile(x, axis=0, q=0.003 / 2)

    if len(delta.shape) == 1:
        ax.errorbar(np.arange(len(delta)), delta, yerr=(delta - low, up - delta), ls="none")
        ax.scatter(np.arange(len(delta)), delta)
        ax.set_xticks(np.arange(len(delta)))
        ax.set_xticklabels(topics)
        ax.xaxis.set_tick_params(rotation=90)
        fig.savefig(opj(args.input, f"ei_{param}{suffix}.eps"), bbox_inches="tight")
    else:
        # An interval excludes zero iff its two bounds have the same sign.
        significant_2s = up * low > 0
        significant_3s = up_3s * low_3s > 0
        if hide_insignificant:
            sns.heatmap(
                np.where(significant_2s, delta, np.nan),
                cmap="RdBu",
                vmin=-np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                vmax=+np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                xticklabels=ordered_topics,
                yticklabels=ordered_topics,
                ax=ax,
                fmt="",
                annot_kws={"fontsize": 6},
            )
        else:
            sns.heatmap(
                delta,
                cmap="RdBu",
                vmin=-np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                vmax=+np.maximum(np.abs(np.min(delta)), np.abs(np.max(delta))),
                xticklabels=ordered_topics,
                yticklabels=ordered_topics,
                ax=ax,
                annot=[
                    [
                        "$\\ast\\ast$" if significant_3s[i, j]
                        else ("$\\ast$" if significant_2s[i, j] else "")
                        for j in range(len(topics))
                    ]
                    for i in range(len(topics))
                ],
                fmt="",
                annot_kws={"fontsize": 6},
            )
        ax.xaxis.set_tick_params(rotation=90)
        ax.yaxis.set_tick_params(rotation=0)
        fig.savefig(opj(args.input, f"ei_{param}{suffix}.eps"), bbox_inches="tight")
    plt.clf()


if not args.mle:
    plot_matrix("delta")
    plot_matrix("gamma")
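# Illustration of the sign-product significance rule used in plot_matrix
# (added; the numbers are made up):
#   low = -0.2, up = 0.4  ->  low * up = -0.08 <= 0  ->  interval contains zero
#   low =  0.1, up = 0.4  ->  low * up =  0.04 >  0  ->  interval excludes zero
# "$\ast$" marks cells significant at the 95% level, "$\ast\ast$" at the 99.7% level.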
fig, ax = plt.subplots()
RC = R[:, order][order]
# np.fill_diagonal(RC, np.nan)
sns.heatmap(
    1 - RC,
    cmap="Blues",
    vmin=0,
    vmax=+np.maximum(np.abs(np.min(RC[~np.isnan(RC)])), np.abs(np.max(RC[~np.isnan(RC)]))),
    xticklabels=ordered_topics,
    yticklabels=[""] * len(ordered_topics),
    ax=ax,
    fmt="",
    annot_kws={"fontsize": 6},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
fig.savefig(opj(args.input, f"ei_R{suffix}.eps"), bbox_inches="tight")
fig.savefig(opj(args.input, f"ei_R{suffix}.pdf"), bbox_inches="tight")
fig.savefig(opj(args.input, f"ei_R{suffix}.png"), bbox_inches="tight", dpi=300)

topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
topic_matrix = topic_matrix[:, ~junk]

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[
    ["article_id", "date_created", "title"]
]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
articles["year_group"] = articles["year"] // 5

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles["article_id"] = articles.article_id.astype(int)
articles = _articles.merge(articles, how="left")
articles["main_topic"] = topic_matrix.argmax(axis=1)

references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
references["cites"] = references.cites.astype(int)
references["cited"] = references.cited.astype(int)
references = references.merge(
    articles[["article_id", "main_topic"]],
    how="inner",
    left_on="cites",
    right_on="article_id",
)
references = references.merge(
    articles[["article_id", "main_topic"]],
    how="inner",
    left_on="cited",
    right_on="article_id",
    suffixes=("_cites", "_cited"),
)

# Count citations from each main topic to each other main topic.
citation_matrix = np.zeros((n_topics, n_topics))
for i in range(n_topics):
    for j in range(n_topics):
        citation_matrix[i, j] = len(
            references[
                (references["main_topic_cites"] == i)
                & (references["main_topic_cited"] == j)
            ]
        )

fig, ax = plt.subplots()
# Pointwise mutual information: log of the observed citation counts over the
# counts expected if citing and cited categories were independent.
m = np.log(
    citation_matrix.sum()
    * citation_matrix
    / np.outer(citation_matrix.sum(axis=1), citation_matrix.sum(axis=0))
)
m = m[:, order][order]
sns.heatmap(
    m,
    ax=ax,
    vmin=-np.max(np.abs(m)),
    vmax=np.max(np.abs(m)),
    cmap="RdBu",
    xticklabels=ordered_topics,
    yticklabels=ordered_topics,
    annot=[
        [
            f"$\\times$\\textbf{{{np.exp(m[i, j]):.1f}}}" if np.exp(m[i, j]) >= 1.05 else ""
            for j in range(len(topics))
        ]
        for i in range(len(topics))
    ],
    fmt="",
    annot_kws={"fontsize": 5},
)
ax.xaxis.set_tick_params(rotation=90)
ax.yaxis.set_tick_params(rotation=0)
ax.set_xlabel("\\textbf{Cited category (references)}")
ax.set_ylabel("\\textbf{Citing category}")
fig.savefig(opj(args.input, "topic_citation_matrix.eps"), bbox_inches="tight")
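# Example invocation (added; the script filename is illustrative, the flags are
# the ones defined by the argparse setup above):
#   python plot_ei.py --input output/run1 --dataset inspire-harvest/database
#   python plot_ei.py --input output/run1 --suffix 2010s
#   python plot_ei.py --input output/run1 --mle
# With --mle, ei_mle*.npz is loaded instead of posterior samples and the
# credible-interval figures (delta_vs_nu, delta, gamma) are skipped.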