# ei_sankey.py

  1. import numpy as np
  2. import pandas as pd
  3. import plotly.graph_objects as go
  4. import plotly.io as pio
  5. pio.kaleido.scope.mathjax = None
  6. import seaborn as sns
  7. import argparse
  8. from os.path import join as opj
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("--input")
  11. parser.add_argument("--suffix")
  12. parser.add_argument("--compact", action="store_true", default=False)
  13. parser.add_argument("--transparent", action="store_true", default=False)
  14. parser.add_argument("--highlight-in", default=list(), nargs="+")
  15. parser.add_argument("--highlight-out", default=list(), nargs="+")
  16. parser.add_argument("--mle", action="store_true", default=False)
  17. args = parser.parse_args()
  18. format = "mle" if args.mle else "samples"
  19. samples = np.load(opj(args.input, f"ei_{format}_{args.suffix}.npz"))
  20. topics = pd.read_csv(opj(args.input, "topics.csv"))
  21. topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
  22. topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
  23. if args.mle:
  24. n_topics = samples["counts"].shape[0]
  25. else:
  26. n_topics = samples["counts"].shape[1]
  27. source = []
  28. target = []
  29. value = []
  30. color = []
  31. colors = sns.color_palette("hls", n_topics).as_hex()
  32. print(colors)
  33. if args.mle:
  34. total_incoming = samples["counts"].sum(axis=1)
  35. total_outcoming = samples["counts"].sum(axis=0)
  36. else:
  37. total_incoming = samples["counts"].mean(axis=0).sum(axis=1)
  38. total_outcoming = samples["counts"].mean(axis=0).sum(axis=0)
  39. incoming_labels = [
  40. f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)
  41. ]
  42. outcoming_labels = [
  43. f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)
  44. ]
  45. for i in range(n_topics):
  46. for j in range(n_topics):
  47. if args.mle:
  48. v = samples["counts"][i, j]
  49. x = samples["counts"][i, :].sum()
  50. y = samples["counts"][:, j].sum()
  51. else:
  52. v = samples["counts"][:, i, j].mean()
  53. x = samples["counts"].mean(axis=0)[i, :].sum()
  54. y = samples["counts"].mean(axis=0)[:, j].sum()
  55. value.append(v)
  56. source.append(i)
  57. target.append(j + n_topics)
  58. highlight = (
  59. (len(args.highlight_in) > 0 and topics[i] in args.highlight_in)
  60. or (len(args.highlight_out) > 0 and topics[j] in args.highlight_out)
  61. or (
  62. (not args.transparent)
  63. and (len(args.highlight_in + args.highlight_out) == 0)
  64. )
  65. ) and (v >= x * y)
  66. # print(topics[i], topics[j])
  67. # print((len(args.highlight_in) > 0 and topics[i] in args.highlight_in))
  68. # print((len(args.highlight_out) > 0 and topics[j] in args.highlight_out))
  69. # print((len(args.highlight_in + args.highlight_out) == 0))
  70. color.append("rgba(100,100,100,0.3)" if highlight else "rgba(200,200,200,0.02)")
  71. fig = go.Figure(
  72. data=[
  73. go.Sankey(
  74. node=dict(
  75. pad=7,
  76. thickness=20,
  77. line=dict(color="black", width=0.5),
  78. label=topics + topics,
  79. color=colors + colors,
  80. ),
  81. link=dict(
  82. source=source, # indices correspond to labels, eg A1, A2, A1, B1, ...
  83. target=target,
  84. value=value,
  85. color=color,
  86. ),
  87. )
  88. ]
  89. )
  90. fig.update_layout(font_size=13)
  91. fig.update_layout(
  92. autosize=False,
  93. width=650 if args.compact else 800,
  94. height=600,
  95. )
  96. # fig.show()
  97. highlights = f"{'_'.join(args.highlight_in)}{'_'.join(args.highlight_out)}".replace(
  98. "/", "_"
  99. ).replace("&", "_").replace(" ", "").lower()
  100. if len(highlights):
  101. highlights = f"_{highlights}"
  102. fig.write_image(
  103. opj(
  104. args.input,
  105. f"sankey_control_{args.suffix}{'_compact' if args.compact else ''}{'_transparent' if args.transparent else ''}{highlights}.pdf",
  106. ),
  107. width=650 if args.compact else 800,
  108. height=600,
  109. )
  110. transfers = pd.DataFrame(
  111. [
  112. {
  113. "from": topics[i],
  114. "to": topics[j],
  115. "magnitude": 100 * samples["counts"].mean(axis=0)[i, j],
  116. "ratio": samples["counts"].mean(axis=0)[i, j]
  117. / samples["counts"].mean(axis=0)[i, :].sum(),
  118. }
  119. for i in range(n_topics)
  120. for j in range(n_topics)
  121. ]
  122. )
  123. latex = (
  124. transfers[transfers["from"] != transfers["to"]]
  125. .sort_values("magnitude", ascending=False)
  126. .head(10)
  127. .to_latex(
  128. columns=["from", "to", "magnitude"],
  129. header=["Origin research area", "Target research area", "Magnitude"],
  130. index=False,
  131. multirow=True,
  132. multicolumn=True,
  133. column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
  134. escape=False,
  135. float_format=lambda x: f"{x:.2f}",
  136. caption="Largest transfers across research areas.",
  137. label="table:largest_transfers",
  138. )
  139. )
  140. latex = latex.replace("\\\\\n", "\\\\ \\hline\n")
  141. with open(opj(args.input, "largest_transfers.tex"), "w+") as fp:
  142. fp.write(latex)
  143. latex = (
  144. transfers[transfers["from"] == transfers["to"]]
  145. .sort_values("ratio", ascending=False)
  146. .head(10)
  147. .to_latex(
  148. columns=["from", "ratio"],
  149. header=["Research area", "Conservatism"],
  150. index=False,
  151. multirow=True,
  152. multicolumn=True,
  153. column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
  154. escape=False,
  155. float_format=lambda x: f"{x:.2f}",
  156. caption="Most conservative research areas.",
  157. label="table:most_conservative",
  158. )
  159. )
  160. latex = latex.replace("\\\\\n", "\\\\ \\hline\n")
  161. with open(opj(args.input, "most_conservative.tex"), "w+") as fp:
  162. fp.write(latex)