123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421 |
- from calendar import c
- import numpy as np
- import pandas as pd
- from scipy.stats import norm
- from scipy.special import softmax
- import cvxpy as cp
- import ot
- from sklearn.linear_model import LinearRegression
- from scipy.linalg import logm
- from matplotlib import pyplot as plt
- import matplotlib
# Configure matplotlib for LaTeX-rendered, serif-font figures.
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
        'mathtext.default': 'regular',
    }
)
# LaTeX packages available to usetex text (used by the .eps outputs).
# str.join returns a new string, so the preamble must be assigned explicitly.
# amssymb is a package and is loaded with \usepackage (\setmainfont is a
# fontspec font command and does not take a package name).
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{bm}",
    r"\usepackage{amssymb}",
])
- import seaborn as sns
- import argparse
- from os.path import join as opj, exists
- import pickle
- from cmdstanpy import CmdStanModel
# Command-line interface. Every input is read from, and every output written
# to, the directory given by --input.
parser = argparse.ArgumentParser(
    description="Infer attention-shift cost matrices and compare them with a covariate matrix."
)
parser.add_argument(
    "--input",
    required=True,  # every file path below is built from this directory
    help="Directory containing the input files; outputs are written there as well.",
)
parser.add_argument("--suffix", default=None, help="Optional suffix selecting ei_samples_<suffix>.npz.")
parser.add_argument(
    "--model",
    default="knowledge",
    choices=["knowledge", "identity", "random", "etm", "linguistic", "linguistic_symmetric"],
    help="Covariate matrix the cost matrix is compared against.",
)
parser.add_argument("--prior", default="bounded", choices=["bounded"], help="Prior used by the sampler.")
parser.add_argument("--steps", default=1000000, type=int, help="Number of MCMC iterations.")
parser.add_argument("--burnin", default=50000, type=int, help="Number of initial MCMC states discarded.")
parser.add_argument("--thin", default=100, type=int, help="Keep one MCMC state out of every THIN after burn-in.")
parser.add_argument(
    "--alpha-prior",
    default=5,
    type=float,
    help="Prior concentration; multiplied by K*K before being passed to the sampler.",
)
args = parser.parse_args()
suffix = f"_{args.suffix}" if args.suffix is not None else ""
# Load posterior samples, topic metadata and per-author data, then drop the
# topics labelled "Junk".
with np.load(opj(args.input, f"ei_samples{suffix}.npz")) as samples:
    ei_beta = samples["beta"]

topics_df = pd.read_csv(opj(args.input, "topics.csv"))
n_topics = len(topics_df)  # number of topics before junk removal
# NumPy boolean mask of junk topics; missing labels are treated as non-junk.
junk = topics_df["label"].str.contains("Junk", na=False).to_numpy()
topics = topics_df.loc[~junk, "label"].tolist()

df = pd.read_csv(opj(args.input, "aggregate.csv"))
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, on="bai")


def _topic_matrix(frame, prefix, n):
    """Return the values of columns ``{prefix}_1`` .. ``{prefix}_n`` as a 2-D array."""
    return frame[[f"{prefix}_{k+1}" for k in range(n)]].to_numpy()


NR = _topic_matrix(df, "start", n_topics).astype(int)
NC = _topic_matrix(df, "end", n_topics).astype(int)
expertise = _topic_matrix(df, "expertise", n_topics)
S = np.stack(df["pooled_resources"])
N = NR.shape[0]
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]
S = S[:, ~junk]

# Row-normalised distributions over (non-junk) topics.
x = NR / NR.sum(axis=1)[:, np.newaxis]
y = NC / NC.sum(axis=1)[:, np.newaxis]
S_distrib = S / S.sum(axis=1)[:, np.newaxis]
print(S_distrib)

# R[i, j]: among rows with above-average expertise in topic i, the fraction
# that also has above-average expertise in topic j. (Not used further below.)
above = expertise > expertise.mean(axis=0)
R = (above.T.astype(float) @ above.astype(float) / len(above)) / above.mean(axis=0)[:, np.newaxis]
K = expertise.shape[1]

# observed couplings: posterior-mean beta averaged with the start distributions x
theta = np.einsum("ai,aij->ij", x, ei_beta.mean(axis=0))
order = np.load(opj(args.input, "topics_order.npy"))
def mcmc_bounded(T, x, alpha_prior, sigma, steps=1000):
    """Metropolis-Hastings sampler for a cost matrix ``C`` and a slope ``beta``.

    Initialization:

    - The chain starts from the observed coupling ``T``, normalized to sum
      to one.
    - The columns of ``T`` are rescaled by ``exp(log_scale)``, which
      preserves cross-ratios.
    - ``log_scale`` is chosen by a small convex program so that the initial
      kernel satisfies ``K0 <= 1``.
    - The initial cost ``C0 = -log(K0) / lambd`` then sums to one.

    Each step proposes ``diag(Dr) @ K @ diag(1 / Dr)`` with log-normal
    ``Dr``. For a square ``T`` this leaves ``sum(log K)``, and hence
    ``sum(C)``, unchanged. A Gaussian random-walk step on ``beta`` is
    proposed jointly. The unnormalized log target of a state
    ``(C, beta)`` is::

        log N(beta; 0, 1)
        - alpha_prior * sum(C * log(C / softmax(beta * x.flatten())))
        - 0.5 * sum(log C)

    Proposals with ``|sum(C) - 1| > 1e-6`` or any negative entry are
    rejected.

    Parameters
    ----------
    T : ndarray, shape (m, m)
        Observed coupling, any positive scale. ``np.log(T)`` is taken, so
        entries are expected to be strictly positive. Must be square.
    x : ndarray with m * m entries
        Covariate whose ``beta``-scaled softmax is the prior center of ``C``.
    alpha_prior : float
        Weight of the ``sum(C * log(C / softmax(...)))`` term.
    sigma : float
        Proposal standard deviation, for both ``log Dr`` and the ``beta``
        step.
    steps : int, default 1000
        Number of Metropolis-Hastings iterations.

    Returns
    -------
    C : ndarray, shape (steps + 1, m, m)
        Chain of cost matrices; index 0 is the initial state.
    beta : ndarray, shape (steps + 1,)
        Chain of slopes; ``beta[0]`` is a standard-normal draw.
    accepted : ndarray of bool, shape (steps + 1,)
        Whether each step was accepted; ``accepted[0]`` is True.

    Raises
    ------
    ValueError
        If ``T`` is not square.
    RuntimeError
        If the initialization program returns no solution.
    """
    T = T / T.sum()
    m, n = T.shape
    if m != n:
        # The proposal below only conserves sum(C) when T is square.
        raise ValueError(f"mcmc_bounded expects a square coupling matrix, got shape {T.shape}")
    lambd = m * n * 3

    # Column rescaling that sends C to the prior support while preserving
    # cross-ratios. There is one log-scale per column (n of them), and each
    # appears in m entries.
    log_scale = cp.Variable(n)
    prob = cp.Problem(
        cp.Minimize(cp.sum(cp.abs(log_scale))),
        [
            m * cp.sum(log_scale) == -np.sum(np.log(T)) - lambd,  # sum(C0) == 1
            log_scale <= -np.log(np.max(T, axis=0)),  # K0 <= 1, i.e. C0 >= 0
        ],
    )
    prob.solve(verbose=True)
    if log_scale.value is None:
        raise RuntimeError(f"could not initialise the chain (solver status: {prob.status})")
    # Only the current kernel is kept: the full (steps + 1, m, n) history was
    # never used and costs gigabytes for long chains.
    kernel = T * np.exp(log_scale.value)[np.newaxis, :]

    beta = np.random.randn(steps + 1)  # beta[0] is the starting value; later entries are overwritten
    C = np.zeros((steps + 1, m, n))
    C[0] = -np.log(kernel) / lambd

    accepted = np.zeros(steps + 1, dtype=bool)
    accepted[0] = True
    oob = np.zeros(steps + 1, dtype=bool)
    beta_prior = norm(loc=0, scale=1)
    x_flat = x.flatten()

    def log_target(cost, b):
        """Return the unnormalized log density of the state ``(cost, b)``."""
        c = cost.flatten()
        log_p = beta_prior.logpdf(b)
        log_p += -alpha_prior * (c * np.log(c / softmax(x_flat * b))).sum() - 0.5 * np.log(c).sum()
        return log_p

    # The current state's density only changes on acceptance, so cache it.
    log_p_current = log_target(C[0], beta[0])
    for i in range(steps):
        Dr = np.exp(np.random.randn(m) * sigma)
        beta[i + 1] = np.random.randn() * sigma + beta[i]
        proposal = Dr[:, np.newaxis] * kernel * (1 / Dr)[np.newaxis, :]
        C[i + 1] = -np.log(proposal) / lambd

        oob[i + 1] = np.abs(C[i + 1].sum() - 1) > 1e-6 or np.any(C[i + 1] < 0)
        reject = oob[i + 1]
        if not reject:
            log_p_proposal = log_target(C[i + 1], beta[i + 1])
            u = np.random.uniform(0, 1)
            # Same rule as "reject if a <= log(u)"; a NaN log-ratio is therefore accepted.
            reject = log_p_proposal - log_p_current <= np.log(u)
        if reject:
            C[i + 1] = C[i]
            beta[i + 1] = beta[i]
        else:
            kernel = proposal
            log_p_current = log_p_proposal
        accepted[i + 1] = not reject

        if i % 1000 == 0:
            # States 0..i+1 are filled; slicing up to i + 2 avoids empty slices at i == 0.
            filled = i + 2
            window = beta[max(0, filled - 1000):filled]
            print(f"step {i}/{steps}, rate={accepted[:filled].mean():.3f}, oob={oob[:filled].mean():.3f}, acc={accepted[:filled].sum():.0f}")
            print(f"beta: {beta[:filled].mean():.2f}, beta batch: {window.mean():.2f}, std batch: {window.std():.2f}")
    return C, beta, accepted
# Choose the covariate matrix the inferred cost matrix is compared against,
# and plot it (topics in display order).
output = opj(args.input, f"cost_{args.model}_{args.prior}.npz")
if args.model == "knowledge":
    matrix = 1 - np.load(opj(args.input, "nu_expertise.npy"))
elif args.model == "etm":
    matrix = 1 - np.load(opj(args.input, "nu_etm.npy"))
elif args.model == "identity":
    matrix = 1 - np.eye(K)
elif args.model == "random":
    # Fixed seed: the posterior is cached in `output`, so the random covariate
    # must be the same every time the script is re-run against that cache.
    matrix = np.random.default_rng(0).uniform(0, 1, size=(K, K))
elif args.model == "linguistic":
    matrix = np.load(opj(args.input, "nu_ling.npy"))
elif args.model == "linguistic_symmetric":
    matrix = np.load(opj(args.input, "nu_ling_symmetric.npy"))
else:
    raise ValueError(f"unknown model: {args.model}")

ordered_labels = [topics[i] for i in order]
fig, ax = plt.subplots()
sns.heatmap(
    matrix[:, order][order],
    cmap="Reds",
    vmin=0,
    vmax=np.max(np.abs(matrix)),
    xticklabels=ordered_labels,
    yticklabels=ordered_labels,
    ax=ax,
)
fig.savefig(opj(args.input, f"linguistic_gap_{args.model}_{args.prior}.eps"), bbox_inches="tight")
# Run the sampler (or reuse its cached, burnt-in and thinned output).
matrix_sd = 1
if not exists(output):
    # knowledge/etm matrices are passed as-is; the others are standardised.
    covariate = matrix if args.model in ["knowledge", "etm"] else matrix / matrix.std()
    C, beta, accepted = mcmc_bounded(theta, covariate, args.alpha_prior*K*K, 0.1, steps=args.steps)
    kept = slice(args.burnin, None, args.thin)
    C = C[kept]
    beta = beta[kept]
    accepted = accepted[kept]
    np.savez_compressed(output, C=C, beta=beta, accepted=accepted)
else:
    # Context manager closes the underlying npz file once the arrays are read.
    with np.load(output) as cached:
        C = cached["C"]
        beta = cached["beta"]
print(beta.mean())
print(beta.std())
# Summarise how well beta * matrix explains the sampled cost matrices.
mean_cost = C.mean(axis=0)  # posterior-mean cost matrix, computed once
residuals = C - np.einsum("s,ij->sij", beta, matrix/matrix_sd)
delta = residuals.mean(axis=0)
mse = (residuals**2).mean(axis=(1, 2))
var = C.reshape(C.shape[0], -1).var(axis=1)  # per-draw variance of the entries of C
res = 1 - mse/var  # per-draw 1 - MSE/Var (an R^2-style score)
print(res.mean())

ordered_labels = [topics[i] for i in order]
fig, ax = plt.subplots()
sns.heatmap(
    mean_cost[:, order][order],
    xticklabels=ordered_labels,
    yticklabels=ordered_labels,
    cmap="Reds",
    vmin=np.min(mean_cost),
    vmax=np.max(mean_cost),
    ax=ax,
)
fig.savefig(opj(args.input, f"cost_matrix_{args.model}_{args.prior}.eps"), bbox_inches="tight")

pearson = np.corrcoef(mean_cost.flatten(), matrix.flatten())[0, 1]
print("R:", pearson)
print("R^2:", pearson**2)
reg = LinearRegression()
fit = reg.fit(matrix.flatten().reshape(-1, 1), mean_cost.flatten())
# Model-specific scatter of the posterior-mean cost C_{k,k'} against the covariate.
cost_flat = C.mean(axis=0).flatten()
scatter_figsize = (0.75*4.8, 0.75*3.2)
linguistic_xlabel = "Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$"


def _finish_cost_scatter(fig, ax, xlabel, r_value, corner):
    """Label a cost scatter plot, annotate it with ``r_value`` and save it.

    ``corner`` is "right" or "left" and selects the top corner holding the
    R annotation.
    """
    ax.set_xlabel(xlabel)
    text_x = 0.95 if corner == "right" else 0.05
    ax.text(text_x, 0.95, f"$R={r_value:.2f}$", ha=corner, va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")


if args.model == "knowledge":
    fig, ax = plt.subplots(figsize=scatter_figsize)
    grid = np.linspace(0, 1, 4)
    # The x axis shows nu = 1 - matrix, hence the reflected regression line.
    ax.plot(1 - grid, fit.predict(grid.reshape(-1, 1)), color="black")
    ax.scatter(1 - matrix.flatten(), cost_flat, s=4)
    # Error bars are omitted: they only reflect the degeneracy of the cost matrix.
    _finish_cost_scatter(
        fig,
        ax,
        "Fraction of physicists with expertise in $k'$\namong those with expertise in $k$ ($\\nu_{k,k'}$)",
        -pearson,
        "right",
    )
elif args.model == "identity":
    fig, ax = plt.subplots(figsize=scatter_figsize)
    ax.axline((0, 0), slope=-beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter((1 - matrix).flatten(), cost_flat, s=4)
    _finish_cost_scatter(fig, ax, "1 if $k=k'$, 0 otherwise", -pearson, "right")
elif args.model == "linguistic":
    fig, ax = plt.subplots(figsize=scatter_figsize)
    ax.axline((0, 0), slope=beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter(matrix.flatten(), cost_flat, s=4)
    _finish_cost_scatter(fig, ax, linguistic_xlabel, pearson, "left")
elif args.model == "linguistic_symmetric":
    fig, ax = plt.subplots(figsize=scatter_figsize)
    ax.scatter(matrix.flatten(), cost_flat, s=4)
    # Here R correlates C with the posterior-mean softmax prediction instead.
    predicted = softmax(np.einsum("s,i->si", beta, matrix.flatten()/matrix.std()), axis=1).mean(axis=0)
    pearson = np.corrcoef(predicted, cost_flat)[0, 1]
    _finish_cost_scatter(fig, ax, linguistic_xlabel, pearson, "left")
# predicted transfers
origin = x.mean(axis=0)
target = y.mean(axis=0)
ordered_labels = [topics[i] for i in order]
# Independence coupling origin_k * target_k' in display order. (An elementwise
# origin[order] * target[order] would broadcast along rows and compare cell
# (k, k') with origin_k' * target_k' instead.)
independent = np.outer(origin[order], target[order])


def _plot_coupling_heatmap(coupling, filename):
    """Save a heatmap of the row-normalised ``coupling`` to ``args.input/filename``.

    ``coupling`` is a joint distribution already permuted into display order.
    Cells whose joint mass exceeds the independence coupling are annotated in
    bold with their row-normalised value.
    """
    sig = coupling > independent
    shifts = coupling / coupling.sum(axis=1)[:, np.newaxis]
    n_rows, n_cols = shifts.shape
    fig, ax = plt.subplots()
    sns.heatmap(
        shifts,
        xticklabels=ordered_labels,
        yticklabels=ordered_labels,
        cmap="Blues",
        vmin=0,
        ax=ax,
        annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i,j] else "" for j in range(n_cols)] for i in range(n_rows)],
        fmt="",
        annot_kws={"fontsize": 6},
    )
    fig.savefig(opj(args.input, filename), bbox_inches="tight")
    plt.close(fig)


# Observed couplings, normalised to a joint distribution.
_plot_coupling_heatmap(
    theta[:, order][order] / theta.sum(),
    f"cost_matrix_true_couplings_{args.model}_{args.prior}.eps",
)

# Entropic OT with the posterior-mean softmax cost (same covariate scaling as the fit).
covariate = matrix if args.model in ["knowledge", "etm"] else matrix / matrix.std()
predicted_cost = softmax(np.einsum("s,i->si", beta, covariate.flatten()), axis=1).reshape((len(beta), K, K)).mean(axis=0)
T = ot.sinkhorn(origin, target, predicted_cost, 1/(3*K*K))
_plot_coupling_heatmap(
    T[:, order][order],
    f"cost_matrix_predicted_couplings_{args.model}_{args.prior}.eps",
)

# Baseline: cost 1/K for moving to any other topic, 0 for staying.
T_baseline = ot.sinkhorn(origin, target, (1-np.identity(K))/K, 50/(10*K*K))
_plot_coupling_heatmap(
    T_baseline[:, order][order],
    "cost_matrix_predicted_couplings_identity.eps",
)
def tv_dist(x, y):
    """Return the total-variation distance between ``x`` and ``y``.

    Both arrays are first normalised to sum to one; the result is half the
    L1 distance between the normalised arrays.
    """
    p = x / x.sum()
    q = y / y.sum()
    return 0.5 * np.abs(q - p).sum()
# Sweep the entropic regularisation and compare predicted couplings with the
# observed ones (theta) in total-variation distance.
lambdas = np.logspace(np.log10(1/(5*10*K*K)), np.log10(100/(K*K)), 200)
# Same covariate scaling as the MCMC fit: knowledge/etm unscaled, others standardised.
covariate = matrix if args.model in ["knowledge", "etm"] else matrix / matrix.std()
# Loop-invariant costs, computed once rather than once per lambda.
predicted_cost = softmax(np.einsum("s,i->si", beta, covariate.flatten()), axis=1).reshape((len(beta), K, K)).mean(axis=0)
baseline_cost = (1-np.identity(K))/K
perf = []
baseline = []
for reg_strength in lambdas:
    T = ot.sinkhorn(origin, target, predicted_cost, reg_strength)
    T_baseline = ot.sinkhorn(origin, target, baseline_cost, reg_strength)
    perf.append(tv_dist(T.flatten(), theta.flatten()))
    baseline.append(tv_dist(T_baseline.flatten(), theta.flatten()))
fig, ax = plt.subplots()
ax.plot(lambdas, perf, label=f"{args.model} ({np.min(perf):.3f})")
ax.plot(lambdas, baseline, label=f"baseline ({np.min(baseline):.3f})")
ax.set_xscale("log")
ax.set_xlabel("Entropic regularisation")
ax.set_ylabel("TV distance to observed couplings")
fig.legend()
fig.savefig(opj(args.input, f"performance_{args.model}_{args.prior}.eps"), bbox_inches="tight")
# counterfactual: entropic OT using the posterior-mean cost matrix itself.
mean_cost = C.mean(axis=0)
T = ot.sinkhorn(origin, target, mean_cost / mean_cost.sum(), 1/(3*K*K))
shifts = T[:, order][order]
# Highlight cells above the independence coupling origin_k * target_k'
# (np.outer; an elementwise product would broadcast along rows instead).
sig = shifts > np.outer(origin[order], target[order])
shifts = shifts / shifts.sum(axis=1)[:, np.newaxis]
ordered_labels = [topics[i] for i in order]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=ordered_labels,
    yticklabels=ordered_labels,
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i,j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_counterfactual_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
|