# ei_sankey.py
  1. import numpy as np
  2. import pandas as pd
  3. import plotly.graph_objects as go
  4. import plotly.io as pio
  5. pio.kaleido.scope.mathjax = None
  6. import seaborn as sns
  7. import argparse
  8. from os.path import join as opj
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("--input")
  11. parser.add_argument("--suffix")
  12. parser.add_argument("--compact", action="store_true", default=False)
  13. args = parser.parse_args()
  14. samples = np.load(opj(args.input, f"ei_samples_{args.suffix}.npz"))
  15. topics = pd.read_csv(opj(args.input, "topics.csv"))
  16. topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
  17. topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
  18. n_topics = samples["counts"].shape[1]
  19. source = []
  20. target = []
  21. value = []
  22. color = []
  23. colors = sns.color_palette("hls", n_topics).as_hex()
  24. print(colors)
  25. total_incoming = samples["counts"].mean(axis=0).sum(axis=1)
  26. total_outcoming = samples["counts"].mean(axis=0).sum(axis=0)
  27. incoming_labels = [f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)]
  28. outcoming_labels = [f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)]
  29. for i in range(n_topics):
  30. for j in range(n_topics):
  31. v = samples["counts"][:, i, j].mean()
  32. x = samples["counts"].mean(axis=0)[i,:].sum()
  33. y = samples["counts"].mean(axis=0)[:,j].sum()
  34. value.append(v)
  35. source.append(i)
  36. target.append(j + n_topics)
  37. color.append("rgba(100,100,100,0.2)" if v >= x*y else "rgba(200,200,200,0.02)")
  38. fig = go.Figure(
  39. data=[
  40. go.Sankey(
  41. node=dict(
  42. pad=7,
  43. thickness=20,
  44. line=dict(color="black", width=0.5),
  45. label=topics + topics,
  46. color=colors + colors,
  47. ),
  48. link=dict(
  49. source=source, # indices correspond to labels, eg A1, A2, A1, B1, ...
  50. target=target,
  51. value=value,
  52. color=color,
  53. ),
  54. )
  55. ]
  56. )
  57. fig.update_layout(font_size=13)
  58. fig.update_layout(
  59. autosize=False,
  60. width=650 if args.compact else 800,
  61. height=600,
  62. )
  63. #fig.show()
  64. fig.write_image(opj(args.input, f"sankey_control_{args.suffix}{'_compact' if args.compact else ''}.pdf"), width=650 if args.compact else 800, height=600)
  65. transfers = pd.DataFrame([
  66. {
  67. "from": topics[i],
  68. "to": topics[j],
  69. "magnitude": 100*samples["counts"].mean(axis=0)[i,j],
  70. "ratio": samples["counts"].mean(axis=0)[i,j]/samples["counts"].mean(axis=0)[i,:].sum()
  71. }
  72. for i in range(n_topics)
  73. for j in range(n_topics)
  74. ])
  75. latex = transfers[transfers["from"]!=transfers["to"]].sort_values("magnitude", ascending=False).head(10).to_latex(
  76. columns=["from", "to", "magnitude"],
  77. header=["Origin research area", "Target research area", "Magnitude"],
  78. index=False,
  79. multirow=True,
  80. multicolumn=True,
  81. column_format='b{0.4\\textwidth}|b{0.4\\textwidth}|c',
  82. escape=False,
  83. float_format=lambda x: f"{x:.2f}",
  84. caption="Largest transfers across research areas.",
  85. label="table:largest_transfers"
  86. )
  87. latex = latex.replace('\\\\\n', '\\\\ \\hline\n')
  88. with open(opj(args.input, "largest_transfers.tex"), "w+") as fp:
  89. fp.write(latex)
  90. latex = transfers[transfers["from"]==transfers["to"]].sort_values("ratio", ascending=False).head(10).to_latex(
  91. columns=["from", "ratio"],
  92. header=["Research area", "Conservatism"],
  93. index=False,
  94. multirow=True,
  95. multicolumn=True,
  96. column_format='b{0.4\\textwidth}|b{0.4\\textwidth}|c',
  97. escape=False,
  98. float_format=lambda x: f"{x:.2f}",
  99. caption="Most conservative research areas.",
  100. label="table:most_conservative"
  101. )
  102. latex = latex.replace('\\\\\n', '\\\\ \\hline\n')
  103. with open(opj(args.input, "most_conservative.tex"), "w+") as fp:
  104. fp.write(latex)