"""Plot one author's topic matrix with marginal start/end topic distributions.

The script reads ``topics.csv``, ``aggregate.csv``, ``pooled_resources.parquet``
and ``ei_samples_<suffix>.npz`` from the ``--input`` directory, selects the row
whose ``bai`` column equals ``--bai``, and writes
``trajectory_example_<bai>.eps`` into the same directory.

The central panel shows ``theta[a] * x[a]`` (posterior mean of the ``beta``
samples for that author, scaled row-wise by the author's normalised ``start_*``
counts); the top and left panels show the normalised ``end_*`` and ``start_*``
counts respectively.
"""
import argparse
from os.path import join as opj

import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.stats import entropy

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
# NOTE(review): the result of this ``join`` call is discarded, so (in
# Matplotlib versions where this rcParam is a string) the statement does not
# modify rcParams at all. It is kept as-is to avoid changing how the figure is
# rendered; if a preamble is really wanted it must be *assigned*, and
# ``\setmainfont{amssymb}`` should be reviewed first (amssymb is a package
# name, not a font) -- TODO confirm intent.
plt.rcParams["text.latex.preamble"].join([
    r"\usepackage{amsmath}",
    r"\setmainfont{amssymb}",
])

# --------------------------------------------------------------------------
# Command line
# --------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# --input and --bai are mandatory in practice: without them the script used to
# fail later with an opaque TypeError / ValueError.
parser.add_argument(
    "--input",
    required=True,
    help="directory containing the input files; the figure is written there too",
)
parser.add_argument("--suffix", help="suffix of the ei_samples_<suffix>.npz file")
parser.add_argument(
    "--bai",
    required=True,
    help="author identifier, matched against the 'bai' column of aggregate.csv",
)
parser.add_argument("--reorder-topics", action="store_true", default=False)
args = parser.parse_args()

inp = args.input
bai = args.bai

# --------------------------------------------------------------------------
# Load data
# --------------------------------------------------------------------------
topics = pd.read_csv(opj(inp, "topics.csv"))
# Number of topics *including* junk ones: the per-topic columns of
# aggregate.csv are numbered 1..n_topics over the full topic list.
n_topics = len(topics)
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()

df = pd.read_csv(opj(inp, "aggregate.csv"))
df = df.merge(
    pd.read_parquet(opj(inp, "pooled_resources.parquet")),
    left_on="bai",
    right_on="bai",
)


def _topic_columns(prefix):
    """Return the column names ``<prefix>_1 .. <prefix>_<n_topics>``."""
    return [f"{prefix}_{k+1}" for k in range(n_topics)]


X = np.stack(df[_topic_columns("start")].values).astype(int)
Y = np.stack(df[_topic_columns("end")].values).astype(int)
S = np.stack(df["pooled_resources"])
expertise = np.stack(df[_topic_columns("expertise")].values)

# Disabled in the original script:
# stability = pd.read_csv(opj(inp, "institutional_stability.csv"), index_col="bai")
# df = df.merge(stability, left_on="bai", right_index=True)

# Drop the columns that belong to junk topics.
X = X[:, ~junk]
Y = Y[:, ~junk]
S = S[:, ~junk]
expertise = expertise[:, ~junk]

# Derived per-author columns. NOTE(review): these (and ``m`` below) are not
# read again in the visible part of the script.
df["social_diversity"] = np.exp(entropy(S, axis=1))
df["intellectual_diversity"] = np.exp(entropy(expertise, axis=1))
df["social_cap"] = np.log(1 + np.stack(S).sum(axis=1))

# Row-normalise the start / end counts so each author's row sums to 1.
x = X / X.sum(axis=1)[:, np.newaxis]
y = Y / Y.sum(axis=1)[:, np.newaxis]

samples = np.load(opj(inp, f"ei_samples_{args.suffix}.npz"))
m = np.einsum("skii,ki->sk", 1 - samples["beta"], x)
m = m.mean(axis=0)
# Average of the "beta" samples over the first (sample) axis.
theta = samples["beta"].mean(axis=0)

authors = df["bai"].tolist()
# From here on n_topics counts only the non-junk topics.
n_topics = len(topics)

try:
    a = authors.index(bai)
except ValueError:
    raise ValueError(
        f"author {bai!r} not found in the 'bai' column of "
        f"{opj(inp, 'aggregate.csv')}"
    ) from None

# Scale row i of the author's matrix by the author's start share of topic i.
beta = np.einsum("ij,i->ij", theta[a], x[a, :])

# --------------------------------------------------------------------------
# Optional topic reordering (hierarchical clustering + seriation)
# --------------------------------------------------------------------------
if args.reorder_topics:
    from scipy.spatial.distance import pdist, squareform
    from fastcluster import linkage

    def seriation(Z, N, cur_index):
        """Return the leaf indices under node ``cur_index`` of linkage ``Z``.

        Indices below ``N`` are leaves; any other index refers to the merge
        stored in row ``cur_index - N`` of ``Z``, whose first two entries are
        the merged children. Leaves are listed left child first.
        """
        if cur_index < N:
            return [cur_index]
        left = int(Z[cur_index - N, 0])
        right = int(Z[cur_index - N, 1])
        return seriation(Z, N, left) + seriation(Z, N, right)

    def compute_serial_matrix(dist_mat, method="ward"):
        """Cluster a square distance matrix and reorder it by leaf order.

        Returns a tuple ``(seriated_dist, res_order, res_linkage)``: the
        distance matrix permuted according to the dendrogram leaf order, that
        leaf order, and the linkage returned by ``fastcluster.linkage``.
        """
        N = len(dist_mat)
        flat_dist_mat = squareform(dist_mat)
        res_linkage = linkage(flat_dist_mat, method=method, preserve_input=True)
        # The root of the dendrogram has index 2N - 2.
        res_order = seriation(res_linkage, N, N + N - 2)
        seriated_dist = np.zeros((N, N))
        rows, cols = np.triu_indices(N, k=1)
        seriated_dist[rows, cols] = dist_mat[
            [res_order[i] for i in rows], [res_order[j] for j in cols]
        ]
        # Mirror the upper triangle to keep the matrix symmetric.
        seriated_dist[cols, rows] = seriated_dist[rows, cols]
        return seriated_dist, res_order, res_linkage

    # One boolean column per topic: which authors are above that topic's mean
    # expertise. Computed once per column (with the same per-column expression
    # as before) instead of four times per matrix cell.
    above_mean = np.stack(
        [expertise[:, k] > expertise[:, k].mean() for k in range(n_topics)],
        axis=1,
    )
    # Distance = 1 - (share of authors above the mean in both topics)
    #              / (share of authors above the mean in either topic).
    dist = 1 - np.array([
        [
            (above_mean[:, i] & above_mean[:, j]).mean()
            / (above_mean[:, i] | above_mean[:, j]).mean()
            for j in range(n_topics)
        ]
        for i in range(n_topics)
    ])
    m, order, dendo = compute_serial_matrix(dist)
    order = np.array(order)[::-1]
else:
    order = np.arange(n_topics)

# Labels must follow the same permutation as the plotted rows / columns.
topics = [topics[i] for i in order]

# --------------------------------------------------------------------------
# Figure: joint matrix with marginal bar charts on top and on the left
# --------------------------------------------------------------------------
fig = plt.figure(figsize=(6.4, 6.4))
gs = GridSpec(4, 4, hspace=0.1, wspace=0.1)
ax_joint = fig.add_subplot(gs[1:4, 1:4])
ax_marg_x = fig.add_subplot(gs[0, 1:4])
ax_marg_y = fig.add_subplot(gs[1:4, 0])

ax_joint.set_xlim(-0.5, n_topics - 0.5)
ax_marg_x.set_xlim(-0.5, n_topics - 0.5)
ax_marg_y.set_ylim(-0.5, n_topics - 0.5)

positions = np.arange(n_topics)
ax_joint.imshow(beta[:, order][order], cmap="Greys", aspect='auto', vmin=0, vmax=0.5)
ax_marg_x.bar(positions, height=y[a, order], width=1, color="red")
# Bars are placed at n_topics - 1 - k so that the first topic is drawn at the
# top of the left panel.
ax_marg_y.barh(
    n_topics - positions - 1,
    width=x[a, order],
    height=1,
    orientation="horizontal",
)

# Use the same value range on both marginals so bar lengths are comparable.
common_scale = np.maximum(np.max(x[a, order]), np.max(y[a, order]))
ax_marg_x.set_ylim(0, common_scale)
ax_marg_y.set_xlim(0, common_scale)
ax_marg_y.invert_xaxis()

# Turn off tick labels on marginals
plt.setp(ax_marg_x.get_xticklabels(), visible=False)
plt.setp(ax_marg_x.get_yticklabels(), visible=False)
plt.setp(ax_marg_y.get_yticklabels(), visible=False)
plt.setp(ax_marg_y.get_xticklabels(), visible=False)

ax_joint.yaxis.tick_right()
ax_joint.set_xticks(positions, positions)
ax_joint.set_xticklabels(topics, rotation=90)
ax_joint.set_yticks(positions, positions)
ax_joint.set_yticklabels(topics)

plt.show()
fig.savefig(opj(inp, f"trajectory_example_{bai}.eps"), bbox_inches="tight")