# ei_sankey_four.py
  1. import numpy as np
  2. import pandas as pd
  3. import plotly.graph_objects as go
  4. import plotly.io as pio
  5. pio.kaleido.scope.mathjax = None
  6. import seaborn as sns
  7. import argparse
  8. from os.path import join as opj
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("--compact", action="store_true", default=False)
  11. args = parser.parse_args()
  12. mcmc = [
  13. "output/etm_20_r/ei_mle_control_nu_aggregate_0_1.csv.npz",
  14. "output/etm_20_r/ei_mle_control_nu_aggregate_1_2.csv.npz",
  15. "output/etm_20_r/ei_mle_control_nu_aggregate_2_3.csv.npz",
  16. ]
  17. portfolios = [
  18. "output/etm_20_r/aggregate_0_1.csv",
  19. "output/etm_20_r/aggregate_1_2.csv",
  20. "output/etm_20_r/aggregate_2_3.csv",
  21. ]
  22. samples = [np.load(_mcmc) for _mcmc in mcmc]
  23. portfolios = [pd.read_csv(portfolio) for portfolio in portfolios]
  24. n_bins = len(mcmc)
  25. bai_wl = set(portfolios[0]["bai"])
  26. for k in np.arange(n_bins):
  27. bai_wl = bai_wl&set(portfolios[k]["bai"])
  28. topics = pd.read_csv(opj("output/etm_20_r", "topics.csv"))
  29. n_topics = len(topics)
  30. junk = topics["label"].str.contains("Junk")
  31. topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
  32. topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
  33. counts = []
  34. for k in range(len(portfolios)):
  35. df = portfolios[k]
  36. keep = df["bai"].isin(bai_wl)
  37. NR = np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int)
  38. NC = np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int)
  39. N = NR.shape[0]
  40. NR = NR[:,~junk]
  41. NC = NC[:,~junk]
  42. x = NR/NR.sum(axis=1)[:,np.newaxis]
  43. print(len(df))
  44. theta = samples[k]["beta"]
  45. print(theta.shape)
  46. counts.append(np.einsum("ai,aij->ij", x[keep,:], theta[keep,:,:]))
  47. n_topics = counts[0].shape[0]
  48. source = []
  49. target = []
  50. value = []
  51. color = []
  52. colors = sns.color_palette("hls", n_topics).as_hex()
  53. print(colors)
  54. total_incoming = samples[0]["counts"].sum(axis=1)
  55. total_outcoming = samples[0]["counts"].sum(axis=0)
  56. incoming_labels = [
  57. f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)
  58. ]
  59. outcoming_labels = [
  60. f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)
  61. ]
  62. for i in range(n_topics):
  63. for j in range(n_topics):
  64. for k in range(n_bins):
  65. v = counts[k][i, j]/counts[k].sum()
  66. x = counts[k][i, :].sum()/counts[k].sum()
  67. y = counts[k][:, j].sum()/counts[k].sum()
  68. value.append(v)
  69. source.append(i + k*n_topics)
  70. target.append(j + (k+1)*n_topics)
  71. highlight = (v >= x * y)
  72. # print(topics[i], topics[j])
  73. # print((len(args.highlight_in) > 0 and topics[i] in args.highlight_in))
  74. # print((len(args.highlight_out) > 0 and topics[j] in args.highlight_out))
  75. # print((len(args.highlight_in + args.highlight_out) == 0))
  76. color.append("rgba(100,100,100,0.3)" if highlight else "rgba(200,200,200,0.02)")
  77. fig = go.Figure(
  78. data=[
  79. go.Sankey(
  80. node=dict(
  81. pad=7,
  82. thickness=20,
  83. line=dict(color="black", width=0.5),
  84. label=topics+[""]*(n_topics*(n_bins-1))+topics,
  85. color=colors*(n_bins+1),
  86. ),
  87. link=dict(
  88. source=source, # indices correspond to labels, eg A1, A2, A1, B1, ...
  89. target=target,
  90. value=value,
  91. color=color,
  92. ),
  93. )
  94. ]
  95. )
  96. for x_coordinate, column_name in enumerate(["2000-2004","2005-2009","2010-2014","2015-2019"]):
  97. fig.add_annotation(
  98. x=x_coordinate / n_bins,
  99. y=1.05,
  100. xref="paper",
  101. yref="paper",
  102. text=column_name,
  103. showarrow=False,
  104. font=dict(
  105. size=13,
  106. ),
  107. align="center",
  108. )
  109. fig.update_layout(font_size=11)
  110. fig.update_layout(
  111. autosize=False,
  112. width=650 if args.compact else 800,
  113. height=600,
  114. )
  115. fig.write_image(
  116. opj(
  117. "output/etm_20_r",
  118. f"sankey_four.pdf",
  119. ),
  120. width=650 if args.compact else 800,
  121. height=600,
  122. )