# optimal_transport.py
import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.special import softmax
import cvxpy as cp
import ot
from sklearn.linear_model import LinearRegression
from matplotlib import pyplot as plt
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
        "mathtext.default": "regular",
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{bm}",
    r"\usepackage{amssymb}",
])
import seaborn as sns
import argparse
from os.path import join as opj, exists
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--suffix", default=None)
parser.add_argument("--model", default="knowledge", choices=["knowledge", "identity", "random", "etm", "linguistic", "linguistic_symmetric"])
parser.add_argument("--prior", default="bounded", choices=["bounded"])
parser.add_argument("--steps", default=1000000, type=int)
parser.add_argument("--burnin", default=50000, type=int)
parser.add_argument("--thin", default=100, type=int)
parser.add_argument("--alpha-prior", default=5, type=float)
parser.add_argument("--mle", action="store_true", default=False)
args = parser.parse_args()

suffix = f"_{args.suffix}" if args.suffix is not None else ""
fmt = "mle" if args.mle else "samples"
samples = np.load(opj(args.input, f"ei_{fmt}{suffix}.npz"))
topics = pd.read_csv(opj(args.input, "topics.csv"))
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()

n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
df = pd.read_csv(opj(args.input, "aggregate.csv"))
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, on="bai")

NR = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
NC = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values)
S = np.stack(df["pooled_resources"])

# drop junk topics
# junk = np.sum(NR + NC, axis=0) == 0
N = NR.shape[0]
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]
S = S[:, ~junk]

# row-normalized start/end topic distributions and resource shares
x = NR/NR.sum(axis=1)[:, np.newaxis]
y = NC/NC.sum(axis=1)[:, np.newaxis]
S_distrib = S/S.sum(axis=1)[:, np.newaxis]
print(S_distrib)

R = np.array([
    [((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()/(expertise[:, i] > expertise[:, i].mean()).mean() for j in range(len(topics))]
    for i in range(len(topics))
])
K = expertise.shape[1]
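
# R[i, j] is the share of physicists with above-mean expertise in topic i who
# also have above-mean expertise in topic j, i.e. an estimate of
# P(expertise in j | expertise in i); it is computed for reference and not
# used further in this script.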
# observed couplings
if args.mle:
    theta = samples["beta"]
else:
    theta = samples["beta"].mean(axis=0)
theta = np.einsum("ai,aij->ij", x, theta)
order = np.load(opj(args.input, "topics_order.npy"))
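
# A sketch of what mcmc_bounded does (as read from the code below): it samples
# cost matrices C on the simplex (C >= 0, sum(C) = 1) compatible with the
# observed coupling T. Writing K = exp(-lambda*C), diagonal rescalings
# diag(Dr) K diag(Dc) preserve the cross-ratios of K, so the chain random-walks
# over the degenerate set of costs while beta and delta tie C to the candidate
# predictor x through a softmax prior.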
def mcmc_bounded(T, x, alpha_prior, sigma, steps=1000):
    # x = x/x.std()
    T = T/T.sum()
    m = T.shape[0]
    n = T.shape[1]
    K = np.zeros((steps+1, m, n))
    lambd = m*n*6
    # Transform K in a way that sends C to the prior support
    # while preserving cross-ratios
    Dc = cp.Variable(m)
    prob = cp.Problem(
        cp.Minimize(cp.sum(cp.abs(Dc))),
        [
            m*cp.sum(Dc) == -np.sum(np.log(T))-lambd,  # sum of -log K equals lambda, so C sums to 1
            Dc <= -np.log(np.max(T, axis=0)),  # C is positive
        ]
    )
    prob.solve(verbose=True)
    K[0] = T@np.diag(np.exp(Dc.value))
    beta = np.random.randn(steps+1)
    delta = np.random.randn(steps+1)
    C = np.zeros((steps+1, m, n))
    C[0] = -np.log(K[0])/lambd
    accepted = np.array([False]*(steps+1))
    accepted[0] = True
    oob = np.array([False]*(steps+1))
    beta_prior = norm(loc=0, scale=1)
    delta_prior = norm(loc=0, scale=1)
    for i in range(steps):
        # symmetric diagonal-scaling proposal; Dc = 1/Dr preserves the sum of C
        Dr = np.random.randn(m)*sigma
        Dr = np.exp(Dr)
        Dc = 1/Dr
        beta[i+1] = np.random.randn()*sigma + beta[i]
        delta[i+1] = np.random.randn()*sigma*0.1 + delta[i]
        K[i+1] = np.diag(Dr)@K[i]@np.diag(Dc)
        C[i+1] = -np.log(K[i+1])/lambd
        distrib_prop = softmax(np.power(x.flatten(), np.exp(delta[i+1]))*beta[i+1])
        distrib_prev = softmax(np.power(x.flatten(), np.exp(delta[i]))*beta[i])
        oob[i+1] = np.abs(C[i+1].sum()-1) > 1e-6 or np.any(C[i+1] < 0)
        if not oob[i+1]:
            # log target: Gaussian priors on beta and delta (delta effectively on
            # a 0.1 scale), a KL term pulling C toward the predicted softmax
            # distribution weighted by alpha_prior, and a -0.5*sum(log C) term
            p_prop = beta_prior.logpdf(beta[i+1]) + delta_prior.logpdf(delta[i+1]*10)
            p_prev = beta_prior.logpdf(beta[i]) + delta_prior.logpdf(delta[i]*10)
            p_prop += -alpha_prior*(C[i+1].flatten()*np.log(C[i+1].flatten()/distrib_prop)).sum() - 0.5*np.log(C[i+1].flatten()).sum()
            p_prev += -alpha_prior*(C[i].flatten()*np.log(C[i].flatten()/distrib_prev)).sum() - 0.5*np.log(C[i].flatten()).sum()
            a = p_prop - p_prev
        u = np.random.uniform(0, 1)
        if oob[i+1] or a <= np.log(u):
            # reject: keep the previous state
            C[i+1] = C[i]
            K[i+1] = K[i]
            beta[i+1] = beta[i]
            delta[i+1] = delta[i]
            accepted[i+1] = False
        else:
            accepted[i+1] = True
        if i > 0 and i % 1000 == 0:
            print(f"step {i}/{steps}, rate={accepted[:i].mean():.3f}, oob={oob[:i].mean():.3f}, acc={accepted[:i].sum():.0f}")
            print(f"beta: {beta[:i].mean():.2f}, beta batch: {beta[i-1000:i].mean():.2f}, std batch: {beta[i-1000:i].std():.2f}")
            print(f"delta: {delta[:i].mean():.2f}, delta batch: {delta[i-1000:i].mean():.2f}, std batch: {delta[i-1000:i].std():.2f}")
    return C, beta, delta, accepted
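
# Minimal smoke test for mcmc_bounded on synthetic inputs (illustrative
# assumptions: 4x4 matrices, 2000 steps; not part of the original pipeline).
# Uncomment to check that the sampler stays on the simplex:
# _T = np.random.dirichlet(np.ones(16)).reshape(4, 4)
# _xs = np.random.uniform(0, 1, size=(4, 4))
# _C, _b, _d, _acc = mcmc_bounded(_T, _xs, alpha_prior=5.0*16, sigma=0.1, steps=2000)
# assert abs(_C[-1].sum() - 1) < 1e-6 and (_C[-1] >= 0).all()
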
output = opj(args.input, f"cost_{args.model}_{args.prior}.npz")

if args.model == "knowledge":
    matrix = 1-np.load(opj(args.input, "nu_expertise.npy"))
elif args.model == "etm":
    matrix = 1-np.load(opj(args.input, "nu_etm.npy"))
elif args.model == "identity":
    matrix = 1-np.eye(K)
elif args.model == "random":
    matrix = np.random.uniform(0, 1, size=(K, K))
elif args.model == "linguistic":
    matrix = np.load(opj(args.input, "nu_ling.npy"))
elif args.model == "linguistic_symmetric":
    matrix = np.load(opj(args.input, "nu_ling_symmetric.npy"))
fig, ax = plt.subplots()
sns.heatmap(
    matrix[:, order][order],
    cmap="Reds",
    vmin=0,
    vmax=np.max(np.abs(matrix)),
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    ax=ax,
)
fig.savefig(opj(args.input, f"linguistic_gap_{args.model}_{args.prior}.eps"), bbox_inches="tight")

matrix_sd = 1
if not exists(output):
    if args.model in ["knowledge", "etm"]:
        C, beta, delta, accepted = mcmc_bounded(theta, matrix, args.alpha_prior*K*K, 0.1, steps=args.steps)
    else:
        C, beta, delta, accepted = mcmc_bounded(theta, matrix/matrix.std(), args.alpha_prior*K*K, 0.1, steps=args.steps)
    # discard burn-in and thin the chains before saving
    C = C[args.burnin::args.thin]
    beta = beta[args.burnin::args.thin]
    delta = delta[args.burnin::args.thin]
    accepted = accepted[args.burnin::args.thin]
    np.savez_compressed(output, C=C, beta=beta, delta=delta)
else:
    samples = np.load(output)
    C = samples["C"]
    beta = samples["beta"]
    delta = samples["delta"]

print("beta mean:", beta.mean())
print("beta std:", beta.std())
# R^2-style fit: share of the variance of C explained by the posterior-mean
# softmax prediction
m = matrix/matrix_sd
m = np.repeat(m[np.newaxis, :, :], len(delta), axis=0)
delta = delta[:, np.newaxis, np.newaxis]
delta = np.repeat(delta, m.shape[1], axis=1)
delta = np.repeat(delta, m.shape[2], axis=2)
res = C.mean(axis=0).flatten() - softmax(np.einsum("s,sij->sij", beta, np.power(m, np.exp(delta))).mean(axis=0).flatten())
res = (res**2).mean()
var = C.flatten().var()
print("pseudo-R^2:", 1 - res/var)
fig, ax = plt.subplots()
sns.heatmap(
    C.mean(axis=0)[:, order][order],
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Reds",
    vmin=np.min(C.mean(axis=0)),
    vmax=np.max(C.mean(axis=0)),
    ax=ax,
)
fig.savefig(opj(args.input, f"cost_matrix_{args.model}_{args.prior}.eps"), bbox_inches="tight")
fig.savefig(opj(args.input, f"cost_matrix_{args.model}_{args.prior}.pdf"), bbox_inches="tight")
fig.savefig(opj(args.input, f"cost_matrix_{args.model}_{args.prior}.png"), bbox_inches="tight", dpi=300)

pearson = np.corrcoef(C.mean(axis=0).flatten(), matrix.flatten())[0, 1]
print("R:", pearson)
print("R^2:", pearson**2)
reg = LinearRegression()
fit = reg.fit(matrix.flatten().reshape(-1, 1), C.mean(axis=0).flatten())
if args.model == "knowledge":
    fig, ax = plt.subplots(figsize=(6.4, 4.8))
    xs = np.linspace(0, 1, 4)
    ax.plot(1-xs, fit.predict(xs.reshape(-1, 1)), color="black")
    ax.scatter(1-matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    # error bars are boring as they only reflect the degeneracy of the cost matrix
    # low = np.quantile(C, q=0.05/2, axis=0)
    # up = np.quantile(C, q=1-0.05/2, axis=0)
    # mean = C.mean(axis=0)
    # ax.errorbar(
    #     1-matrix.flatten(),
    #     mean.flatten(),
    #     (np.maximum(mean.flatten()-low.flatten(), 0), np.maximum(up.flatten()-mean.flatten(), 0)),
    #     ls="none",
    #     lw=0.5,
    # )
    ax.set_xlabel("Fraction of physicists with expertise in $k'$\namong those with expertise in $k$ ($\\nu_{k,k'}$)")
    # pearson = np.corrcoef(softmax(np.einsum("s,i->si", beta, (1-matrix.flatten())/matrix.std()), axis=1).mean(axis=0), C.mean(axis=0).flatten())[0,1]
    ax.text(0.95, 0.95, f"$R={-pearson:.2f}$", ha="right", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "identity":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.axline((0, 0), slope=-beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter((1-matrix).flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("1 if $k=k'$, 0 otherwise")
    ax.text(0.95, 0.95, f"$R={-pearson:.2f}$", ha="right", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "linguistic":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.axline((0, 0), slope=beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter(matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$")
    ax.text(0.05, 0.95, f"$R={pearson:.2f}$", ha="left", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "linguistic_symmetric":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.scatter(matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$")
    pearson = np.corrcoef(softmax(np.einsum("s,i->si", beta, matrix.flatten()/matrix.std()), axis=1).mean(axis=0), C.mean(axis=0).flatten())[0, 1]
    ax.text(0.05, 0.95, f"$R={pearson:.2f}$", ha="left", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
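
# The remaining heatmaps compare topic-to-topic attention shifts: observed
# couplings from theta, couplings predicted by the fitted cost model via
# Sinkhorn, an identity-cost baseline, and a counterfactual built from the
# posterior-mean cost matrix.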
# predicted transfers
origin = x.mean(axis=0)
target = y.mean(axis=0)

# observed couplings (reordered, row-normalized for display)
fig, ax = plt.subplots()
shifts = theta[:, order][order]/theta.sum()
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i,j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_true_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
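
# Predicted couplings: ot.sinkhorn(a, b, M, reg) from the POT library returns
# the entropy-regularized transport plan between marginals a and b for cost
# matrix M; reg = 1/(3*K*K) is the fixed regularization used here (a sweep
# over values follows below).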
T = ot.sinkhorn(
    origin,
    target,
    softmax(np.einsum("s,i->si", beta, matrix.flatten() if args.model in ["knowledge", "etm"] else matrix.flatten()/matrix.std()), axis=1).reshape((len(beta), K, K)).mean(axis=0),
    1/(3*K*K)
)
shifts = T[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i,j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_predicted_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
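
# Baseline: uniform off-diagonal cost, so every shift between distinct topics
# is equally expensive and the plan is shaped by the marginals and the
# regularizer alone.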
T_baseline = ot.sinkhorn(
    origin,
    target,
    (1-np.identity(K))/K,
    50/(10*K*K)
)
shifts = T_baseline[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i,j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, "cost_matrix_predicted_couplings_identity.eps"), bbox_inches="tight")
def tv_dist(x, y):
    # total variation distance between the normalized versions of x and y
    return np.abs(y/y.sum() - x/x.sum()).sum()/2
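
# Illustrative values: two disjoint distributions are at distance 1, e.g.
# tv_dist(np.array([1.0, 0.0]), np.array([0.0, 1.0])) == 1.0, while
# tv_dist(p, p) == 0.0 for any distribution p.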
lambdas = np.logspace(np.log10(1/(5*10*K*K)), np.log10(100/(K*K)), 200)
perf = []
baseline = []
for lam in lambdas:
    T = ot.sinkhorn(
        origin,
        target,
        softmax(np.einsum("s,i->si", beta, matrix.flatten() if args.model in ["knowledge", "etm"] else matrix.flatten()/matrix.std()), axis=1).reshape((len(beta), K, K)).mean(axis=0),
        lam
    )
    T_baseline = ot.sinkhorn(
        origin,
        target,
        (1-np.identity(K))/K,
        lam
    )
    perf.append(tv_dist(T.flatten(), theta.flatten()))
    baseline.append(tv_dist(T_baseline.flatten(), theta.flatten()))
fig, ax = plt.subplots()
ax.plot(lambdas, perf, label=f"{args.model} ({np.min(perf):.3f})")
ax.plot(lambdas, baseline, label=f"baseline ({np.min(baseline):.3f})")
ax.set_xscale("log")
fig.legend()
fig.savefig(opj(args.input, f"performance_{args.model}_{args.prior}.eps"), bbox_inches="tight")
# counterfactual: transport under the posterior-mean inferred cost matrix
T = ot.sinkhorn(
    origin,
    target,
    C.mean(axis=0)/C.mean(axis=0).sum(),
    1/(3*K*K)
)
shifts = T[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i,j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_counterfactual_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")