author.py

import numpy as np
from matplotlib import pyplot as plt
import matplotlib

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
# Extra LaTeX packages for the pgf/XeLaTeX preamble.
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
from matplotlib.gridspec import GridSpec
import pandas as pd
from os.path import join as opj
from scipy.stats import entropy
import argparse
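# Command-line interface (semantics inferred from how the arguments are used
# below): --input is the directory holding the model outputs (topics.csv,
# aggregate.csv, pooled_resources.parquet, ei_samples_*.npz), --suffix selects
# the ei_samples file, --bai identifies the author whose portfolio is plotted,
# and --reorder-topics switches on hierarchical reordering of the topics.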
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--suffix")
parser.add_argument("--bai")
parser.add_argument("--reorder-topics", action="store_true", default=False)
args = parser.parse_args()

inp = args.input
bai = args.bai
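# Load the topic labels and the per-author data. Topics whose label contains
# "Junk" are dropped from every per-topic matrix below.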
topics = pd.read_csv(opj(inp, "topics.csv"))
n_topics = len(topics)
junk = topics["label"].str.contains("Junk")
topics = topics[~junk]["label"].tolist()

df = pd.read_csv(opj(inp, "aggregate.csv"))
df = df.merge(pd.read_parquet(opj(inp, "pooled_resources.parquet")), left_on="bai", right_on="bai")
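# Per-author matrices, one row per author and one column per topic:
#   X, Y       -- integer counts per topic at the start and end of the period,
#   S          -- pooled resources per topic,
#   expertise  -- expertise scores per topic.
# (Reading "start_*"/"end_*" as the beginning and end of the observation
# period is an assumption based only on the column names.)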
X = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
Y = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
S = np.stack(df["pooled_resources"])
expertise = np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values)

# stability = pd.read_csv(opj(inp, "institutional_stability.csv"), index_col="bai")
# df = df.merge(stability, left_on="bai", right_index=True)

X = X[:, ~junk]
Y = Y[:, ~junk]
S = S[:, ~junk]
expertise = expertise[:, ~junk]
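# exp(Shannon entropy) is the effective number of categories (a Hill number),
# so these columns hold each author's effective number of resource topics and
# expertise topics; social_cap is log(1 + total pooled resources).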
df["social_diversity"] = np.exp(entropy(S, axis=1))
df["intellectual_diversity"] = np.exp(entropy(expertise, axis=1))
df["social_cap"] = np.log(1 + np.stack(S).sum(axis=1))

# Row-normalise the start/end counts into per-author topic distributions.
x = X / X.sum(axis=1)[:, np.newaxis]
y = Y / Y.sum(axis=1)[:, np.newaxis]

authors = df["bai"].tolist()
n_topics = len(topics)
a = authors.index(bai)
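# With --reorder-topics, the topic order is obtained by seriation of a Ward
# hierarchical clustering, so related topics end up adjacent in the plots.
# The pairwise distance is 1 minus a Jaccard-style overlap: the fraction of
# authors above mean expertise in both topics, divided by the fraction above
# mean expertise in at least one of them.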
if args.reorder_topics:
    from scipy.spatial.distance import pdist, squareform
    from fastcluster import linkage

    def seriation(Z, N, cur_index):
        if cur_index < N:
            return [cur_index]
        else:
            left = int(Z[cur_index - N, 0])
            right = int(Z[cur_index - N, 1])
            return seriation(Z, N, left) + seriation(Z, N, right)

    def compute_serial_matrix(dist_mat, method="ward"):
        N = len(dist_mat)
        flat_dist_mat = squareform(dist_mat)
        res_linkage = linkage(flat_dist_mat, method=method, preserve_input=True)
        res_order = seriation(res_linkage, N, N + N - 2)
        seriated_dist = np.zeros((N, N))
        a, b = np.triu_indices(N, k=1)
        seriated_dist[a, b] = dist_mat[[res_order[i] for i in a], [res_order[j] for j in b]]
        seriated_dist[b, a] = seriated_dist[a, b]
        return seriated_dist, res_order, res_linkage

    dist = 1 - np.array([
        [
            ((expertise[:, i] > expertise[:, i].mean()) & (expertise[:, j] > expertise[:, j].mean())).mean()
            / ((expertise[:, i] > expertise[:, i].mean()) | (expertise[:, j] > expertise[:, j].mean())).mean()
            for j in range(len(topics))
        ]
        for i in range(len(topics))
    ])

    m, order, dendo = compute_serial_matrix(dist)
    order = np.array(order)[::-1]
else:
    order = np.arange(n_topics)

topics = [topics[i] for i in order]
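# Figure 1: the selected author's portfolio as back-to-back horizontal bars,
# with the initial topic distribution (x) drawn to the left of zero and the
# final distribution (y, in red) to the right.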
fig, ax = plt.subplots(figsize=(7.5 / 1.25, 5 / 1.25))
fig.tight_layout()
ax.barh(np.arange(n_topics), -x[a, order], align="center", zorder=10, height=1)
ax.barh(np.arange(n_topics), y[a, order], align="center", color="red", zorder=10, height=1)
ax.axvline(0, color="black")
# plt.gca().invert_yaxis()
ax.set(yticks=np.arange(n_topics), yticklabels=topics)
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
# low, high = ax.get_xlim()
# bound = max(abs(low), abs(high))
# ax.set_xlim(-bound, bound)
plt.setp(ax.get_xticklabels(), visible=False)
fig.savefig(opj(inp, f"portfolios_{bai}.eps"), bbox_inches="tight")
fig.savefig(opj(inp, f"portfolios_{bai}.pdf"), bbox_inches="tight")
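# Load the posterior samples, average beta over draws, and weight row i of the
# author's beta by their initial share x[a, i], so entry (i, j) reads as the
# part of the initial portfolio in topic i that ends up in topic j.
# (Interpreting beta as a topic-transition matrix is an assumption based on
# how it is used here, not something stated in this script.)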
samples = np.load(opj(inp, f"ei_samples_{args.suffix}.npz"))
beta = samples["beta"].mean(axis=0)
# m = np.einsum("kii,ki->k", 1-beta, x)
# m = m.mean(axis=0)
# theta = samples["beta"].mean(axis=0)
beta = np.einsum("ij,i->ij", beta[a], x[a, :])
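# Figure 2: a joint plot built on a 4x4 GridSpec -- the reordered beta matrix
# as a heatmap in the lower-right block, the final distribution y as a bar
# chart in the top margin, and the initial distribution x as a horizontal bar
# chart in the left margin, both margins sharing a common scale.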
fig = plt.figure(figsize=(6.4, 6.4))
gs = GridSpec(4, 4, hspace=0.1, wspace=0.1)
ax_joint = fig.add_subplot(gs[1:4, 1:4])
ax_marg_x = fig.add_subplot(gs[0, 1:4])
ax_marg_y = fig.add_subplot(gs[1:4, 0])

ax_joint.set_xlim(-0.5, n_topics - 0.5)
ax_marg_x.set_xlim(-0.5, n_topics - 0.5)
ax_marg_y.set_ylim(-0.5, n_topics - 0.5)

ax_joint.imshow(beta[:, order][order], cmap="Greys", aspect="auto", vmin=0, vmax=0.5)
ax_marg_x.bar(np.arange(n_topics), height=y[a, order], width=1, color="red")
ax_marg_y.barh(n_topics - np.arange(n_topics) - 1, width=x[a, order], height=1)

common_scale = np.maximum(np.max(x[a, order]), np.max(y[a, order]))
ax_marg_x.set_ylim(0, common_scale)
ax_marg_y.set_xlim(0, common_scale)
ax_marg_y.invert_xaxis()

# Turn off tick labels on marginals
plt.setp(ax_marg_x.get_xticklabels(), visible=False)
plt.setp(ax_marg_x.get_yticklabels(), visible=False)
plt.setp(ax_marg_y.get_yticklabels(), visible=False)
plt.setp(ax_marg_y.get_xticklabels(), visible=False)

ax_joint.yaxis.tick_right()
ax_joint.set_xticks(np.arange(n_topics))
ax_joint.set_xticklabels(topics, rotation=90)
ax_joint.set_yticks(np.arange(n_topics))
ax_joint.set_yticklabels(topics)
# plt.show()  # no-op with the non-interactive pgf backend
fig.savefig(opj(inp, f"trajectory_example_{bai}.eps"), bbox_inches="tight")