author.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. import matplotlib
  4. matplotlib.use("pgf")
  5. matplotlib.rcParams.update(
  6. {
  7. "pgf.texsystem": "xelatex",
  8. "font.family": "serif",
  9. "font.serif": "Times New Roman",
  10. "text.usetex": True,
  11. "pgf.rcfonts": False,
  12. }
  13. )
  14. plt.rcParams["text.latex.preamble"].join([
  15. r"\usepackage{amsmath}",
  16. r"\setmainfont{amssymb}",
  17. ])
  18. from matplotlib.gridspec import GridSpec
  19. import pandas as pd
  20. from os.path import join as opj
  21. from scipy.stats import entropy
  22. import argparse
  23. parser = argparse.ArgumentParser()
  24. parser.add_argument("--input")
  25. parser.add_argument("--suffix")
  26. parser.add_argument("--bai")
  27. parser.add_argument("--reorder-topics", action="store_true", default=False)
  28. args = parser.parse_args()
  29. inp = args.input
  30. bai = args.bai
  31. topics = pd.read_csv(opj(inp, "topics.csv"))
  32. n_topics = len(topics)
  33. junk = topics["label"].str.contains("Junk")
  34. topics = topics[~junk]["label"].tolist()
  35. df = pd.read_csv(opj(inp, "aggregate.csv"))
  36. df = df.merge(pd.read_parquet(opj(inp, "pooled_resources.parquet")), left_on="bai", right_on="bai")
  37. X = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
  38. Y = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
  39. S = np.stack(df["pooled_resources"])
  40. expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values)
  41. # stability = pd.read_csv(opj(inp, "institutional_stability.csv"), index_col="bai")
  42. # df = df.merge(stability, left_on="bai", right_index=True)
  43. X = X[:,~junk]
  44. Y = Y[:,~junk]
  45. S = S[:,~junk]
  46. expertise = expertise[:,~junk]
  47. df["social_diversity"] = np.exp(entropy(S,axis=1))
  48. df["intellectual_diversity"] = np.exp(entropy(expertise,axis=1))
  49. df["social_cap"] = np.log(1+np.stack(S).sum(axis=1))
  50. x = X/X.sum(axis=1)[:,np.newaxis]
  51. y = Y/Y.sum(axis=1)[:,np.newaxis]
  52. samples = np.load(opj(inp, f"ei_samples_{args.suffix}.npz"))
  53. m = np.einsum("skii,ki->sk", 1-samples["beta"], x)
  54. m = m.mean(axis=0)
  55. theta = samples["beta"].mean(axis=0)
  56. authors = df["bai"].tolist()
  57. n_topics = len(topics)
  58. a = authors.index(bai)
  59. beta = np.einsum("ij,i->ij", theta[a], x[a,:])
  60. if args.reorder_topics:
  61. from scipy.spatial.distance import pdist, squareform
  62. from fastcluster import linkage
  63. def seriation(Z,N,cur_index):
  64. if cur_index < N:
  65. return [cur_index]
  66. else:
  67. left = int(Z[cur_index-N,0])
  68. right = int(Z[cur_index-N,1])
  69. return (seriation(Z,N,left) + seriation(Z,N,right))
  70. def compute_serial_matrix(dist_mat,method="ward"):
  71. N = len(dist_mat)
  72. flat_dist_mat = squareform(dist_mat)
  73. res_linkage = linkage(flat_dist_mat, method=method,preserve_input=True)
  74. res_order = seriation(res_linkage, N, N + N-2)
  75. seriated_dist = np.zeros((N,N))
  76. a,b = np.triu_indices(N,k=1)
  77. seriated_dist[a,b] = dist_mat[ [res_order[i] for i in a], [res_order[j] for j in b]]
  78. seriated_dist[b,a] = seriated_dist[a,b]
  79. return seriated_dist, res_order, res_linkage
  80. dist = 1-np.array([
  81. [((expertise[:,i]>expertise[:,i].mean())&(expertise[:,j]>expertise[:,j].mean())).mean()/((expertise[:,i]>expertise[:,i].mean())|(expertise[:,j]>expertise[:,j].mean())).mean() for j in range(len(topics))]
  82. for i in range(len(topics))
  83. ])
  84. m, order, dendo = compute_serial_matrix(dist)
  85. order = np.array(order)[::-1]
  86. else:
  87. order = np.arange(n_topics)
  88. topics = [topics[i] for i in order]
  89. fig = plt.figure(figsize=(6.4,6.4))
  90. gs = GridSpec(4,4,hspace=0.1,wspace=0.1)
  91. ax_joint = fig.add_subplot(gs[1:4,1:4])
  92. ax_marg_x = fig.add_subplot(gs[0,1:4])
  93. ax_marg_y = fig.add_subplot(gs[1:4,0])
  94. ax_joint.set_xlim(-0.5,n_topics-0.5)
  95. ax_marg_x.set_xlim(-0.5,n_topics-0.5)
  96. ax_marg_y.set_ylim(-0.5,n_topics-0.5)
  97. ax_joint.imshow(beta[:, order][order], cmap="Greys", aspect='auto', vmin=0, vmax=0.5)
  98. ax_marg_x.bar(np.arange(n_topics), height=y[a,order], width=1, color="red")
  99. ax_marg_y.barh(n_topics-np.arange(n_topics)-1, width=x[a,order], height=1, orientation="horizontal")
  100. common_scale = np.maximum(np.max(x[a,order]),np.max(y[a,order]))
  101. ax_marg_x.set_ylim(0,common_scale)
  102. ax_marg_y.set_xlim(0,common_scale)
  103. ax_marg_y.invert_xaxis()
  104. # Turn off tick labels on marginals
  105. plt.setp(ax_marg_x.get_xticklabels(), visible=False)
  106. plt.setp(ax_marg_x.get_yticklabels(), visible=False)
  107. plt.setp(ax_marg_y.get_yticklabels(), visible=False)
  108. plt.setp(ax_marg_y.get_xticklabels(), visible=False)
  109. ax_joint.yaxis.tick_right()
  110. ax_joint.set_xticks(np.arange(n_topics), np.arange(n_topics))
  111. ax_joint.set_xticklabels(topics, rotation = 90)
  112. ax_joint.set_yticks(np.arange(n_topics), np.arange(n_topics))
  113. ax_joint.set_yticklabels(topics)
  114. plt.show()
  115. fig.savefig(opj(inp, f"trajectory_example_{bai}.eps"), bbox_inches="tight")