import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.special import softmax
import cvxpy as cp
import ot
from sklearn.linear_model import LinearRegression
from matplotlib import pyplot as plt
import matplotlib

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
        "mathtext.default": "regular",
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
    r"\usepackage{bm}",
])
import seaborn as sns
import argparse
from os.path import join as opj, exists

parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--suffix", default=None)
parser.add_argument("--model", default="knowledge", choices=["knowledge", "identity", "random", "etm", "linguistic", "linguistic_symmetric"])
parser.add_argument("--prior", default="bounded", choices=["bounded"])
parser.add_argument("--steps", default=1000000, type=int)
parser.add_argument("--burnin", default=50000, type=int)
parser.add_argument("--thin", default=100, type=int)
parser.add_argument("--alpha-prior", default=5, type=float)
args = parser.parse_args()

suffix = f"_{args.suffix}" if args.suffix is not None else ""

# Posterior samples of the expertise-inference model, plus topic labels.
samples = np.load(opj(args.input, f"ei_samples{suffix}.npz"))
topics_df = pd.read_csv(opj(args.input, "topics.csv"))
n_topics = len(topics_df)
junk = topics_df["label"].str.contains("Junk")
topics = topics_df[~junk]["label"].tolist()

df = pd.read_csv(opj(args.input, "aggregate.csv"))
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, on="bai")

# Per-physicist topic counts at the start (NR) and end (NC) of the period,
# expertise scores, and pooled social resources.
NR = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
NC = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values)
S = np.stack(df["pooled_resources"])

# junk = np.sum(NR + NC, axis=0) == 0
N = NR.shape[0]

# Drop junk topics everywhere.
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]
S = S[:, ~junk]

# Normalize counts into per-physicist attention distributions.
x = NR/NR.sum(axis=1)[:, np.newaxis]
y = NC/NC.sum(axis=1)[:, np.newaxis]
S_distrib = S/S.sum(axis=1)[:, np.newaxis]
print(S_distrib)

# R[i, j]: fraction of physicists with above-average expertise in topic j
# among those with above-average expertise in topic i.
R = np.array([
    [((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()/(expertise[:, i] > expertise[:, i].mean()).mean() for j in range(len(topics))]
    for i in range(len(topics))
])

K = expertise.shape[1]

# observed couplings (posterior mean, aggregated over physicists)
theta = samples["beta"].mean(axis=0)
theta = np.einsum("ai,aij->ij", x, theta)

order = np.load(opj(args.input, "topics_order.npy"))
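# Note on the sampler defined below: mcmc_bounded explores cost matrices C that
# are compatible with the observed couplings T through the kernel
# K = exp(-lambd*C), i.e. C = -log(K)/lambd. Rescaling K by diagonal matrices on
# both sides shifts whole rows and columns of C by constants, which leaves the
# cross-ratios C[i,j] - C[i,l] - C[k,j] + C[k,l] untouched. A minimal sketch of
# that invariance (toy 3x3 kernel, illustration only, not part of the pipeline):
#
#   K_toy = np.exp(-np.random.rand(3, 3))
#   D_r, D_c = np.diag([2.0, 3.0, 0.5]), np.diag([0.5, 1.0, 2.0])
#   C_toy, C_scaled = -np.log(K_toy), -np.log(D_r @ K_toy @ D_c)
#   cross = lambda M: M[0, 0] - M[0, 1] - M[1, 0] + M[1, 1]
#   assert np.isclose(cross(C_toy), cross(C_scaled))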
def mcmc_bounded(T, x, alpha_prior, sigma, steps=1000):
    """Metropolis-Hastings sampler over cost matrices C with kernel K = exp(-lambd*C).

    T is the observed coupling matrix; x is the candidate cost predictor (it
    shadows the global attention matrix x); alpha_prior scales the KL penalty
    tying C to the softmax-transformed predictor; sigma is the random-walk scale.
    """
    # x = x/x.std()
    T = T/T.sum()
    m = T.shape[0]
    n = T.shape[1]
    K = np.zeros((steps+1, m, n))  # kernel trajectory; shadows the global topic count K
    lambd = m*n*3
    # Transform K in a way that sends C to the prior support
    # while preserving cross-ratios
    Dc = cp.Variable(m)
    prob = cp.Problem(
        cp.Minimize(cp.sum(cp.abs(Dc))),
        [
            m*cp.sum(Dc) == -np.sum(np.log(T))-lambd,  # so that C sums to 1
            Dc <= -np.log(np.max(T, axis=0))  # so that C is nonnegative
        ]
    )
    prob.solve(verbose=True)
    K[0] = T@np.diag(np.exp(Dc.value))
    beta = np.random.randn(steps+1)
    C = np.zeros((steps+1, m, n))
    C[0] = -np.log(K[0])/lambd
    accepted = np.array([False]*(steps+1))
    accepted[0] = True
    oob = np.array([False]*(steps+1))
    beta_prior = norm(loc=0, scale=1)
    for i in range(steps):
        # Symmetric diagonal rescaling: for square matrices, Dc = 1/Dr
        # preserves the sum of C.
        Dr = np.random.randn(m)*sigma
        Dr = np.exp(Dr)
        Dc = 1/Dr  # preserve the sum of C
        beta[i+1] = np.random.randn()*sigma+beta[i]
        K[i+1] = np.diag(Dr)@K[i]@np.diag(Dc)
        C[i+1] = -np.log(K[i+1])/lambd
        distrib_prop = softmax(x.flatten()*beta[i+1])
        distrib_prev = softmax(x.flatten()*beta[i])
        # Proposals that leave the prior support (C a distribution) are rejected.
        oob[i+1] = np.abs(C[i+1].sum()-1) > 1e-6 or np.any(C[i+1] < 0)
        if not oob[i+1]:
            p_prop = beta_prior.logpdf(beta[i+1])
            p_prev = beta_prior.logpdf(beta[i])
            # log-prior: -alpha_prior*KL(C || softmax(beta*x)) plus a
            # -0.5*sum(log C) correction term.
            p_prop += -alpha_prior*(C[i+1].flatten()*np.log(C[i+1].flatten()/distrib_prop)).sum() - 0.5*np.log(C[i+1].flatten()).sum()
            p_prev += -alpha_prior*(C[i].flatten()*np.log(C[i].flatten()/distrib_prev)).sum() - 0.5*np.log(C[i].flatten()).sum()
            a = p_prop-p_prev
        u = np.random.uniform(0, 1)
        # The short-circuit on oob[i+1] guarantees `a` is only read when defined.
        if oob[i+1] or a <= np.log(u):
            C[i+1] = C[i]
            K[i+1] = K[i]
            beta[i+1] = beta[i]
            accepted[i+1] = False
        else:
            accepted[i+1] = True
        if i % 1000 == 0 and i > 0:
            print(f"step {i}/{steps}, rate={accepted[:i].mean():.3f}, oob={oob[:i].mean():.3f}, acc={accepted[:i].sum():.0f}")
            print(f"beta: {beta[:i].mean():.2f}, beta batch: {beta[i-1000:i].mean():.2f}, std batch: {beta[i-1000:i].std():.2f}")
    return C, beta, accepted

output = opj(args.input, f"cost_{args.model}_{args.prior}.npz")

# Candidate cost predictor: costs decreasing with shared expertise (knowledge,
# etm), uniform off-diagonal (identity), random, or the linguistic gap.
if args.model == "knowledge":
    matrix = 1-np.load(opj(args.input, "nu_expertise.npy"))
elif args.model == "etm":
    matrix = 1-np.load(opj(args.input, "nu_etm.npy"))
elif args.model == "identity":
    matrix = 1-np.eye(K)
elif args.model == "random":
    matrix = np.random.uniform(0, 1, size=(K, K))
elif args.model == "linguistic":
    matrix = np.load(opj(args.input, "nu_ling.npy"))
elif args.model == "linguistic_symmetric":
    matrix = np.load(opj(args.input, "nu_ling_symmetric.npy"))

fig, ax = plt.subplots()
sns.heatmap(
    matrix[:, order][order],
    cmap="Reds",
    vmin=0,
    vmax=+np.max(np.abs(matrix)),
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    ax=ax,
)
fig.savefig(opj(args.input, f"linguistic_gap_{args.model}_{args.prior}.eps"), bbox_inches="tight")

matrix_sd = 1

# Run the sampler (or reload previous draws), then discard burn-in and thin.
if not exists(output):
    if args.model in ["knowledge", "etm"]:
        C, beta, accepted = mcmc_bounded(theta, matrix, args.alpha_prior*K*K, 0.1, steps=args.steps)
    else:
        C, beta, accepted = mcmc_bounded(theta, matrix/matrix.std(), args.alpha_prior*K*K, 0.1, steps=args.steps)
    C = C[args.burnin::args.thin]
    beta = beta[args.burnin::args.thin]
    accepted = accepted[args.burnin::args.thin]
    np.savez_compressed(output, C=C, beta=beta)
else:
    samples = np.load(output)
    C = samples["C"]
    beta = samples["beta"]

print(beta.mean())
print(beta.std())
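# Goodness-of-fit note: the block below computes a per-sample pseudo-R^2 for the
# linear relation C ~ beta*matrix, i.e. 1 - E[(C - beta*M)^2]/Var(C), averaged
# over posterior draws. A minimal sketch of the same quantity on toy arrays
# (hypothetical shapes, illustration only):
#
#   C_toy = np.random.rand(10, 4, 4)          # 10 draws of a 4x4 cost matrix
#   M_toy = np.random.rand(4, 4)
#   b_toy = np.random.rand(10)
#   resid = C_toy - np.einsum("s,ij->sij", b_toy, M_toy)
#   r2 = 1 - (resid**2).mean(axis=(1, 2)) / C_toy.reshape(10, -1).var(axis=1)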
res = C-np.einsum("s,ij->sij", beta, matrix/matrix_sd)
delta = res.mean(axis=0)
res = (res**2).mean(axis=(1, 2))
var = np.array([C[s].flatten().var() for s in range(C.shape[0])])
res = res/var
res = 1-res
print(res.mean())

# Posterior-mean cost matrix.
fig, ax = plt.subplots()
sns.heatmap(
    C.mean(axis=0)[:, order][order],
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Reds",
    vmin=+np.min(C.mean(axis=0)),
    vmax=+np.max(C.mean(axis=0)),
    ax=ax,
)
fig.savefig(opj(args.input, f"cost_matrix_{args.model}_{args.prior}.eps"), bbox_inches="tight")

pearson = np.corrcoef(C.mean(axis=0).flatten(), matrix.flatten())[0, 1]
print("R:", pearson)
print("R^2:", pearson**2)

reg = LinearRegression()
fit = reg.fit(matrix.flatten().reshape(-1, 1), C.mean(axis=0).flatten())

# Scatter of inferred costs against the candidate predictor, per cost model.
if args.model == "knowledge":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    xs = np.linspace(0, 1, 4)
    ax.plot(1-xs, fit.predict(xs.reshape(-1, 1)), color="black")
    ax.scatter(1-matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    # error bars are boring as they only reflect the degeneracy of the cost matrix
    # low = np.quantile(C, q=0.05/2, axis=0)
    # up = np.quantile(C, q=1-0.05/2, axis=0)
    # mean = C.mean(axis=0)
    # ax.errorbar(
    #     1-matrix.flatten(),
    #     mean.flatten(),
    #     (np.maximum(mean.flatten()-low.flatten(), 0), np.maximum(up.flatten()-mean.flatten(), 0)),
    #     ls="none",
    #     lw=0.5
    # )
    ax.set_xlabel("Fraction of physicists with expertise in $k'$\namong those with expertise in $k$ ($\\nu_{k,k'}$)")
    # pearson = np.corrcoef(softmax(np.einsum("s,i->si", beta, (1-matrix.flatten())/matrix.std()), axis=1).mean(axis=0), C.mean(axis=0).flatten())[0,1]
    ax.text(0.95, 0.95, f"$R={-pearson:.2f}$", ha="right", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "identity":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.axline((0, 0), slope=-beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter((1-matrix).flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("1 if $k=k'$, 0 otherwise")
    ax.text(0.95, 0.95, f"$R={-pearson:.2f}$", ha="right", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "linguistic":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.axline((0, 0), slope=beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter(matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$")
    ax.text(0.05, 0.95, f"$R={pearson:.2f}$", ha="left", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "linguistic_symmetric":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.scatter(matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$")
    pearson = np.corrcoef(softmax(np.einsum("s,i->si", beta, matrix.flatten()/matrix.std()), axis=1).mean(axis=0), C.mean(axis=0).flatten())[0, 1]
    ax.text(0.05, 0.95, f"$R={pearson:.2f}$", ha="left", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")

# predicted transfers
origin = x.mean(axis=0)
target = y.mean(axis=0)

# Observed couplings, row-normalized; cells whose share exceeds origin*target
# (elementwise) are shown in bold.
fig, ax = plt.subplots()
shifts = theta[:, order][order]/theta.sum()
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_true_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
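# The calls below use POT's ot.sinkhorn(a, b, M, reg), which returns the
# entropy-regularised optimal transport plan between marginals a and b for the
# cost matrix M, with regularisation strength reg. A minimal, self-contained
# usage sketch (toy marginals and cost, illustration only):
#
#   a_toy = np.array([0.5, 0.5])
#   b_toy = np.array([0.3, 0.7])
#   M_toy = np.array([[0.0, 1.0], [1.0, 0.0]])
#   T_toy = ot.sinkhorn(a_toy, b_toy, M_toy, reg=0.1)
#   # rows of T_toy sum (approximately) to a_toy, columns to b_toy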
# Couplings predicted by the fitted cost model.
T = ot.sinkhorn(
    origin,
    target,
    softmax(np.einsum("s,i->si", beta, matrix.flatten() if args.model in ["knowledge", "etm"] else matrix.flatten()/matrix.std()), axis=1).reshape((len(beta), K, K)).mean(axis=0),
    1/(3*K*K)
)
shifts = T[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_predicted_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")

# Baseline: uniform off-diagonal cost.
T_baseline = ot.sinkhorn(
    origin,
    target,
    (1-np.identity(K))/K,
    50/(10*K*K)
)
shifts = T_baseline[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, "cost_matrix_predicted_couplings_identity.eps"), bbox_inches="tight")

def tv_dist(x, y):
    """Total-variation distance between two (unnormalized) distributions."""
    return np.abs(y/y.sum()-x/x.sum()).sum()/2

# Sweep the regularisation strength and compare the fitted cost model against
# the uniform baseline in total-variation distance to the observed couplings.
lambdas = np.logspace(np.log10(1/(5*10*K*K)), np.log10(100/(K*K)), 200)
perf = []
baseline = []
for l in lambdas:
    T = ot.sinkhorn(
        origin,
        target,
        softmax(np.einsum("s,i->si", beta, matrix.flatten() if args.model == "knowledge" else matrix.flatten()/matrix.std()), axis=1).reshape((len(beta), K, K)).mean(axis=0),
        l
    )
    T_baseline = ot.sinkhorn(
        origin,
        target,
        (1-np.identity(K))/K,
        l
    )
    perf.append(tv_dist(T.flatten(), theta.flatten()))
    baseline.append(tv_dist(T_baseline.flatten(), theta.flatten()))

fig, ax = plt.subplots()
ax.plot(lambdas, perf, label=f"{args.model} ({np.min(perf):.3f})")
ax.plot(lambdas, baseline, label=f"baseline ({np.min(baseline):.3f})")
ax.set_xscale("log")
fig.legend()
fig.savefig(opj(args.input, f"performance_{args.model}_{args.prior}.eps"), bbox_inches="tight")

# counterfactual: couplings implied by the posterior-mean cost matrix itself
T = ot.sinkhorn(
    origin,
    target,
    C.mean(axis=0)/C.mean(axis=0).sum(),
    1/(3*K*K)
)
shifts = T[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_counterfactual_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
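# Example invocation (hypothetical script name and input directory, for
# illustration only; the options shown are the defaults declared above):
#
#   python fit_cost_matrix.py --input path/to/run --model knowledge \
#       --steps 1000000 --burnin 50000 --thin 100 --alpha-prior 5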