# optimal_transport.py
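# Fits a topic-to-topic cost matrix for an optimal-transport model of how
# physicists shift attention between research topics, via a Metropolis-Hastings
# sampler, then compares the couplings predicted under several candidate cost
# models (knowledge, etm, identity, random, linguistic) against observed ones.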

import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.special import softmax
import cvxpy as cp
import ot
from sklearn.linear_model import LinearRegression
from matplotlib import pyplot as plt
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
        "mathtext.default": "regular",
    }
)
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{bm}",
    r"\usepackage{amssymb}",
])
import seaborn as sns
import argparse
from os.path import join as opj, exists
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--suffix", default=None)
parser.add_argument("--model", default="knowledge", choices=["knowledge", "identity", "random", "etm", "linguistic", "linguistic_symmetric"])
parser.add_argument("--prior", default="bounded", choices=["bounded"])
parser.add_argument("--steps", default=1000000, type=int)
parser.add_argument("--burnin", default=50000, type=int)
parser.add_argument("--thin", default=100, type=int)
parser.add_argument("--alpha-prior", default=5, type=float)
args = parser.parse_args()
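# Inputs expected under --input: ei_samples{suffix}.npz, topics.csv,
# aggregate.csv, pooled_resources.parquet, topics_order.npy, and the
# nu_*.npy matrix matching --model.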
suffix = f"_{args.suffix}" if args.suffix is not None else ""
samples = np.load(opj(args.input, f"ei_samples{suffix}.npz"))
topics = pd.read_csv(opj(args.input, "topics.csv"))
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()
n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
df = pd.read_csv(opj(args.input, "aggregate.csv"))
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, on="bai")
NR = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
NC = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values)
S = np.stack(df["pooled_resources"])
# junk = np.sum(NR + NC, axis=0) == 0
N = NR.shape[0]
# drop junk topics from all arrays
NR = NR[:, ~junk]
NC = NC[:, ~junk]
expertise = expertise[:, ~junk]
S = S[:, ~junk]
# per-physicist topic distributions at the start and end of the period
x = NR/NR.sum(axis=1)[:, np.newaxis]
y = NC/NC.sum(axis=1)[:, np.newaxis]
S_distrib = S/S.sum(axis=1)[:, np.newaxis]
print(S_distrib)
# fraction with expertise in j among those with expertise in i
# (expertise thresholded at its per-topic mean)
R = np.array([
    [((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()/(expertise[:, i] > expertise[:, i].mean()).mean() for j in range(len(topics))]
    for i in range(len(topics))
])
K = expertise.shape[1]
# observed couplings
theta = samples["beta"].mean(axis=0)
theta = np.einsum("ai,aij->ij", x, theta)
order = np.load(opj(args.input, "topics_order.npy"))
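# Metropolis-Hastings sampler for the cost matrix C and scale beta, working on
# the Gibbs kernel K = exp(-lambda*C) rather than on C directly. Rescaling the
# rows and columns of K by positive diagonals preserves its cross-ratios, and
# hence the Sinkhorn coupling it induces for fixed marginals, so each move
# stays consistent with the observed coupling T. A small CVXPY program picks
# the initial column scaling so that C starts inside the prior support
# (entrywise positive and summing to one).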
def mcmc_bounded(T, x, alpha_prior, sigma, steps=1000):
    # x = x/x.std()
    T = T/T.sum()
    m = T.shape[0]
    n = T.shape[1]
    K = np.zeros((steps+1, m, n))
    lambd = m*n*3
    # Transform K in a way that sends C to the prior support
    # while preserving cross-ratios (note: assumes T is square, m == n)
    Dc = cp.Variable(m)
    prob = cp.Problem(
        cp.Minimize(cp.sum(cp.abs(Dc))),
        [
            m*cp.sum(Dc) == -np.sum(np.log(T))-lambd,  # C sums to 1 (-log K sums to lambda)
            Dc <= -np.log(np.max(T, axis=0)),  # C is positive
        ]
    )
    prob.solve(verbose=True)
    K[0] = T@np.diag(np.exp(Dc.value))
    beta = np.random.randn(steps+1)
    C = np.zeros((steps+1, m, n))
    C[0] = -np.log(K[0])/lambd
    accepted = np.array([False]*(steps+1))
    accepted[0] = True
    oob = np.array([False]*(steps+1))
    beta_prior = norm(loc=0, scale=1)
    for i in range(steps):
        # propose a row scaling and the inverse column scaling
        Dr = np.random.randn(m)*sigma
        Dr = np.exp(Dr)
        Dc = 1/Dr  # preserve the sum of C
        beta[i+1] = np.random.randn()*sigma + beta[i]
        K[i+1] = np.diag(Dr)@K[i]@np.diag(Dc)
        C[i+1] = -np.log(K[i+1])/lambd
        distrib_prop = softmax(x.flatten()*beta[i+1])
        distrib_prev = softmax(x.flatten()*beta[i])
        # reject proposals that leave the simplex
        oob[i+1] = np.abs(C[i+1].sum()-1) > 1e-6 or np.any(C[i+1] < 0)
        if not oob[i+1]:
            p_prop = beta_prior.logpdf(beta[i+1])
            p_prev = beta_prior.logpdf(beta[i])
            # KL-shaped prior tying C to softmax(beta*x), with a -0.5*sum(log C) correction
            p_prop += -alpha_prior*(C[i+1].flatten()*np.log(C[i+1].flatten()/distrib_prop)).sum() - 0.5*np.log(C[i+1].flatten()).sum()
            p_prev += -alpha_prior*(C[i].flatten()*np.log(C[i].flatten()/distrib_prev)).sum() - 0.5*np.log(C[i].flatten()).sum()
            a = p_prop - p_prev
        u = np.random.uniform(0, 1)
        if oob[i+1] or a <= np.log(u):
            # reject: keep the previous state
            C[i+1] = C[i]
            K[i+1] = K[i]
            beta[i+1] = beta[i]
            accepted[i+1] = False
        else:
            accepted[i+1] = True
        if i % 1000 == 0 and i > 0:
            print(f"step {i}/{steps}, rate={accepted[:i].mean():.3f}, oob={oob[:i].mean():.3f}, acc={accepted[:i].sum():.0f}")
            print(f"beta: {beta[:i].mean():.2f}, beta batch: {beta[i-1000:i].mean():.2f}, std batch: {beta[i-1000:i].std():.2f}")
    return C, beta, accepted
output = opj(args.input, f"cost_{args.model}_{args.prior}.npz")
if args.model == "knowledge":
    matrix = 1-np.load(opj(args.input, "nu_expertise.npy"))
elif args.model == "etm":
    matrix = 1-np.load(opj(args.input, "nu_etm.npy"))
elif args.model == "identity":
    matrix = 1-np.eye(K)
elif args.model == "random":
    matrix = np.random.uniform(0, 1, size=(K, K))
elif args.model == "linguistic":
    matrix = np.load(opj(args.input, "nu_ling.npy"))
elif args.model == "linguistic_symmetric":
    matrix = np.load(opj(args.input, "nu_ling_symmetric.npy"))
fig, ax = plt.subplots()
sns.heatmap(
    matrix[:, order][order],
    cmap="Reds",
    vmin=0,
    vmax=np.max(np.abs(matrix)),
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    ax=ax,
)
fig.savefig(opj(args.input, f"linguistic_gap_{args.model}_{args.prior}.eps"), bbox_inches="tight")
matrix_sd = 1
if not exists(output):
    if args.model in ["knowledge", "etm"]:
        C, beta, accepted = mcmc_bounded(theta, matrix, args.alpha_prior*K*K, 0.1, steps=args.steps)
    else:
        C, beta, accepted = mcmc_bounded(theta, matrix/matrix.std(), args.alpha_prior*K*K, 0.1, steps=args.steps)
    # discard burn-in and thin the chains
    C = C[args.burnin::args.thin]
    beta = beta[args.burnin::args.thin]
    accepted = accepted[args.burnin::args.thin]
    np.savez_compressed(output, C=C, beta=beta)
else:
    samples = np.load(output)
    C = samples["C"]
    beta = samples["beta"]
print(beta.mean())
print(beta.std())
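# Per-sample pseudo-R^2 of the linear model C ~ beta * matrix:
# 1 - mean squared residual / variance of the entries of C.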
res = C - np.einsum("s,ij->sij", beta, matrix/matrix_sd)
delta = res.mean(axis=0)
res = (res**2).mean(axis=(1, 2))
var = np.array([C[s].flatten().var() for s in range(C.shape[0])])
res = res/var
res = 1-res
print(res.mean())
fig, ax = plt.subplots()
sns.heatmap(
    C.mean(axis=0)[:, order][order],
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Reds",
    vmin=np.min(C.mean(axis=0)),
    vmax=np.max(C.mean(axis=0)),
    ax=ax,
)
fig.savefig(opj(args.input, f"cost_matrix_{args.model}_{args.prior}.eps"), bbox_inches="tight")
pearson = np.corrcoef(C.mean(axis=0).flatten(), matrix.flatten())[0, 1]
print("R:", pearson)
print("R^2:", pearson**2)
reg = LinearRegression()
fit = reg.fit(matrix.flatten().reshape(-1, 1), C.mean(axis=0).flatten())
if args.model == "knowledge":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    xs = np.linspace(0, 1, 4)
    ax.plot(1-xs, fit.predict(xs.reshape(-1, 1)), color="black")
    ax.scatter(1-matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    # error bars are boring as they only reflect the degeneracy of the cost matrix
    # low = np.quantile(C, q=0.05/2, axis=0)
    # up = np.quantile(C, q=1-0.05/2, axis=0)
    # mean = C.mean(axis=0)
    # ax.errorbar(
    #     1-matrix.flatten(),
    #     mean.flatten(),
    #     (np.maximum(mean.flatten()-low.flatten(), 0), np.maximum(up.flatten()-mean.flatten(), 0)),
    #     ls="none",
    #     lw=0.5,
    # )
    ax.set_xlabel("Fraction of physicists with expertise in $k'$\namong those with expertise in $k$ ($\\nu_{k,k'}$)")
    # pearson = np.corrcoef(softmax(np.einsum("s,i->si", beta, (1-matrix.flatten())/matrix.std()), axis=1).mean(axis=0), C.mean(axis=0).flatten())[0,1]
    ax.text(0.95, 0.95, f"$R={-pearson:.2f}$", ha="right", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "identity":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.axline((0, 0), slope=-beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter((1-matrix).flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("1 if $k=k'$, 0 otherwise")
    ax.text(0.95, 0.95, f"$R={-pearson:.2f}$", ha="right", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "linguistic":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.axline((0, 0), slope=beta.mean(axis=0)/matrix_sd, color="black")
    ax.scatter(matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$")
    ax.text(0.05, 0.95, f"$R={pearson:.2f}$", ha="left", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
elif args.model == "linguistic_symmetric":
    fig, ax = plt.subplots(figsize=(0.75*4.8, 0.75*3.2))
    ax.scatter(matrix.flatten(), C.mean(axis=0).flatten(), s=4)
    ax.set_xlabel("Linguistic gap from $k$ to $k'$\n$\\Delta_{k,k'}=H(\\varphi_{k'}+\\varphi_k)-H(\\varphi_k)$")
    pearson = np.corrcoef(softmax(np.einsum("s,i->si", beta, matrix.flatten()/matrix.std()), axis=1).mean(axis=0), C.mean(axis=0).flatten())[0, 1]
    ax.text(0.05, 0.95, f"$R={pearson:.2f}$", ha="left", va="top", transform=ax.transAxes)
    ax.set_ylabel("Cost of shifting attention\nfrom $k$ to $k'$ ($C_{k,k'}$)")
    fig.savefig(opj(args.input, f"cost_vs_nu_{args.model}.eps"), bbox_inches="tight")
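# The heatmaps below compare transfer matrices: the observed couplings theta,
# the couplings predicted by the fitted cost model via Sinkhorn transport, and
# an identity-cost baseline. In each, rows are renormalized for display and
# cells above a marginal-based significance threshold are bolded.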
# observed vs. predicted transfers
origin = x.mean(axis=0)
target = y.mean(axis=0)
fig, ax = plt.subplots()
shifts = theta[:, order][order]/theta.sum()
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_true_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
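# Predicted couplings via entropy-regularized optimal transport (POT's
# ot.sinkhorn(a, b, M, reg)), using the posterior-mean softmax cost; the
# regularization 1/(3*K*K) mirrors the lambda = 3*K*K scale in the sampler.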
T = ot.sinkhorn(
    origin,
    target,
    softmax(np.einsum("s,i->si", beta, matrix.flatten() if args.model in ["knowledge", "etm"] else matrix.flatten()/matrix.std()), axis=1).reshape((len(beta), K, K)).mean(axis=0),
    1/(3*K*K),
)
shifts = T[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_predicted_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")
T_baseline = ot.sinkhorn(
    origin,
    target,
    (1-np.identity(K))/K,
    50/(10*K*K),
)
shifts = T_baseline[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, "cost_matrix_predicted_couplings_identity.eps"), bbox_inches="tight")
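# Model comparison: sweep the entropic regularization and score each model by
# the total variation distance between its Sinkhorn coupling and the observed
# couplings theta, against the identity-cost baseline.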
def tv_dist(x, y):
    # total variation distance between two (normalized) couplings
    return np.abs(y/y.sum()-x/x.sum()).sum()/2
lambdas = np.logspace(np.log10(1/(5*10*K*K)), np.log10(100/(K*K)), 200)
perf = []
baseline = []
for l in lambdas:
    T = ot.sinkhorn(
        origin,
        target,
        softmax(np.einsum("s,i->si", beta, matrix.flatten() if args.model in ["knowledge", "etm"] else matrix.flatten()/matrix.std()), axis=1).reshape((len(beta), K, K)).mean(axis=0),
        l,
    )
    T_baseline = ot.sinkhorn(
        origin,
        target,
        (1-np.identity(K))/K,
        l,
    )
    perf.append(tv_dist(T.flatten(), theta.flatten()))
    baseline.append(tv_dist(T_baseline.flatten(), theta.flatten()))
fig, ax = plt.subplots()
ax.plot(lambdas, perf, label=f"{args.model} ({np.min(perf):.3f})")
ax.plot(lambdas, baseline, label=f"baseline ({np.min(baseline):.3f})")
ax.set_xscale("log")
fig.legend()
fig.savefig(opj(args.input, f"performance_{args.model}_{args.prior}.eps"), bbox_inches="tight")
# counterfactual: couplings implied by the fitted posterior-mean cost matrix
T = ot.sinkhorn(
    origin,
    target,
    C.mean(axis=0)/C.mean(axis=0).sum(),
    1/(3*K*K),
)
shifts = T[:, order][order]
sig = shifts > origin[order]*target[order]
shifts = shifts/shifts.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots()
sns.heatmap(
    shifts,
    xticklabels=[topics[i] for i in order],
    yticklabels=[topics[i] for i in order],
    cmap="Blues",
    vmin=0,
    ax=ax,
    annot=[[f"\\textbf{{{shifts[i,j]:.2f}}}" if sig[i, j] else "" for j in range(len(topics))] for i in range(len(topics))],
    fmt="",
    annot_kws={"fontsize": 6},
)
fig.savefig(opj(args.input, f"cost_matrix_counterfactual_couplings_{args.model}_{args.prior}.eps"), bbox_inches="tight")