# ei_sankey.py (5.1 KB)
  1. import numpy as np
  2. import pandas as pd
  3. import plotly.graph_objects as go
  4. import plotly.io as pio
  5. pio.kaleido.scope.mathjax = None
  6. import seaborn as sns
  7. import argparse
  8. from os.path import join as opj
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("--input")
  11. parser.add_argument("--suffix")
  12. parser.add_argument("--compact", action="store_true", default=False)
  13. parser.add_argument("--transparent", action="store_true", default=False)
  14. parser.add_argument("--highlight-in", default=list(), nargs="+")
  15. parser.add_argument("--highlight-out", default=list(), nargs="+")
  16. args = parser.parse_args()
  17. samples = np.load(opj(args.input, f"ei_samples_{args.suffix}.npz"))
  18. topics = pd.read_csv(opj(args.input, "topics.csv"))
  19. topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
  20. topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
  21. n_topics = samples["counts"].shape[1]
  22. source = []
  23. target = []
  24. value = []
  25. color = []
  26. colors = sns.color_palette("hls", n_topics).as_hex()
  27. print(colors)
  28. total_incoming = samples["counts"].mean(axis=0).sum(axis=1)
  29. total_outcoming = samples["counts"].mean(axis=0).sum(axis=0)
  30. incoming_labels = [
  31. f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)
  32. ]
  33. outcoming_labels = [
  34. f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)
  35. ]
  36. for i in range(n_topics):
  37. for j in range(n_topics):
  38. v = samples["counts"][:, i, j].mean()
  39. x = samples["counts"].mean(axis=0)[i, :].sum()
  40. y = samples["counts"].mean(axis=0)[:, j].sum()
  41. value.append(v)
  42. source.append(i)
  43. target.append(j + n_topics)
  44. highlight = (
  45. (len(args.highlight_in) > 0 and topics[i] in args.highlight_in)
  46. or (len(args.highlight_out) > 0 and topics[j] in args.highlight_out)
  47. or (
  48. (not args.transparent)
  49. and (len(args.highlight_in + args.highlight_out) == 0)
  50. )
  51. ) and (v >= x * y)
  52. # print(topics[i], topics[j])
  53. # print((len(args.highlight_in) > 0 and topics[i] in args.highlight_in))
  54. # print((len(args.highlight_out) > 0 and topics[j] in args.highlight_out))
  55. # print((len(args.highlight_in + args.highlight_out) == 0))
  56. color.append("rgba(100,100,100,0.3)" if highlight else "rgba(200,200,200,0.02)")
  57. fig = go.Figure(
  58. data=[
  59. go.Sankey(
  60. node=dict(
  61. pad=7,
  62. thickness=20,
  63. line=dict(color="black", width=0.5),
  64. label=topics + topics,
  65. color=colors + colors,
  66. ),
  67. link=dict(
  68. source=source, # indices correspond to labels, eg A1, A2, A1, B1, ...
  69. target=target,
  70. value=value,
  71. color=color,
  72. ),
  73. )
  74. ]
  75. )
  76. fig.update_layout(font_size=13)
  77. fig.update_layout(
  78. autosize=False,
  79. width=650 if args.compact else 800,
  80. height=600,
  81. )
  82. # fig.show()
  83. highlights = f"{'_'.join(args.highlight_in)}{'_'.join(args.highlight_out)}".replace(
  84. "/", "_"
  85. ).replace("&", "_").replace(" ", "").lower()
  86. if len(highlights):
  87. highlights = f"_{highlights}"
  88. fig.write_image(
  89. opj(
  90. args.input,
  91. f"sankey_control_{args.suffix}{'_compact' if args.compact else ''}{'_transparent' if args.transparent else ''}{highlights}.pdf",
  92. ),
  93. width=650 if args.compact else 800,
  94. height=600,
  95. )
  96. transfers = pd.DataFrame(
  97. [
  98. {
  99. "from": topics[i],
  100. "to": topics[j],
  101. "magnitude": 100 * samples["counts"].mean(axis=0)[i, j],
  102. "ratio": samples["counts"].mean(axis=0)[i, j]
  103. / samples["counts"].mean(axis=0)[i, :].sum(),
  104. }
  105. for i in range(n_topics)
  106. for j in range(n_topics)
  107. ]
  108. )
  109. latex = (
  110. transfers[transfers["from"] != transfers["to"]]
  111. .sort_values("magnitude", ascending=False)
  112. .head(10)
  113. .to_latex(
  114. columns=["from", "to", "magnitude"],
  115. header=["Origin research area", "Target research area", "Magnitude"],
  116. index=False,
  117. multirow=True,
  118. multicolumn=True,
  119. column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
  120. escape=False,
  121. float_format=lambda x: f"{x:.2f}",
  122. caption="Largest transfers across research areas.",
  123. label="table:largest_transfers",
  124. )
  125. )
  126. latex = latex.replace("\\\\\n", "\\\\ \\hline\n")
  127. with open(opj(args.input, "largest_transfers.tex"), "w+") as fp:
  128. fp.write(latex)
  129. latex = (
  130. transfers[transfers["from"] == transfers["to"]]
  131. .sort_values("ratio", ascending=False)
  132. .head(10)
  133. .to_latex(
  134. columns=["from", "ratio"],
  135. header=["Research area", "Conservatism"],
  136. index=False,
  137. multirow=True,
  138. multicolumn=True,
  139. column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
  140. escape=False,
  141. float_format=lambda x: f"{x:.2f}",
  142. caption="Most conservative research areas.",
  143. label="table:most_conservative",
  144. )
  145. )
  146. latex = latex.replace("\\\\\n", "\\\\ \\hline\n")
  147. with open(opj(args.input, "most_conservative.tex"), "w+") as fp:
  148. fp.write(latex)