123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
import argparse
from os.path import join as opj

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns

# Disable MathJax in the Kaleido export scope before any figure is written.
# NOTE(review): presumably the usual workaround for the "Loading [MathJax]"
# box that otherwise gets stamped onto PDF exports -- confirm.
pio.kaleido.scope.mathjax = None
# Directory holding every input file and receiving the rendered figure.
OUTPUT_DIR = "output/etm_20_r"

# One MCMC sample archive and one portfolio table per transition between
# consecutive time periods; the two lists are index-aligned.
MCMC_FILES = [
    "ei_mle_control_nu_aggregate_0_1.csv.npz",
    "ei_mle_control_nu_aggregate_1_2.csv.npz",
    "ei_mle_control_nu_aggregate_2_3.csv.npz",
]
PORTFOLIO_FILES = [
    "aggregate_0_1.csv",
    "aggregate_1_2.csv",
    "aggregate_2_3.csv",
]

# Column headers drawn above the node columns.  There must be one more
# period than there are transitions (len(MCMC_FILES) + 1).
PERIODS = ["2000-2004", "2005-2009", "2010-2014", "2015-2019"]

# Links carrying at least as much flow as independence of source and target
# would predict are drawn visibly; all others are almost transparent.
LINK_HIGHLIGHT = "rgba(100,100,100,0.3)"
LINK_MUTED = "rgba(200,200,200,0.02)"

FIG_HEIGHT = 600
FIG_WIDTH = 800
FIG_WIDTH_COMPACT = 650


def common_bai(portfolios):
    """Return the set of ``bai`` values that occur in every portfolio table.

    ``portfolios`` is a non-empty sequence of DataFrames with a ``bai``
    column (presumably an author identifier -- confirm against the data).
    """
    return set.intersection(*(set(frame["bai"]) for frame in portfolios))


def load_topic_labels(path):
    """Read the topics table and return ``(labels, junk_mask)``.

    ``junk_mask`` has one boolean per row of the table and is true where
    the label contains "Junk".  ``labels`` lists the remaining labels in
    file order, with the LaTeX-escaped ``\\&`` turned into a plain ``&``.
    """
    table = pd.read_csv(path)
    is_junk = table["label"].str.contains("Junk")
    cleaned = table["label"].str.replace("\\&", "&", regex=False)
    return cleaned[~is_junk].tolist(), is_junk.to_numpy()


def expected_flows(portfolio, theta, junk, keep_ids):
    """Aggregate per-row transfer matrices into one topic-to-topic matrix.

    Args:
        portfolio: DataFrame with a ``bai`` column and integer-like columns
            ``start_1`` .. ``start_T`` where ``T == len(junk)``.
        theta: array indexed ``[row, source_topic, target_topic]`` whose
            first axis is aligned with the rows of ``portfolio`` and whose
            topic axes cover the non-junk topics only.
        junk: boolean mask of length ``T`` flagging topics to drop.
        keep_ids: collection of ``bai`` values; other rows are ignored.

    Returns:
        2-D array ``F`` with ``F[i, j] = sum_a w[a, i] * theta[a, i, j]``
        where ``w[a]`` is row ``a``'s start counts over non-junk topics,
        normalised to sum to one.
    """
    start_columns = [f"start_{t + 1}" for t in range(len(junk))]
    starts = portfolio[start_columns].to_numpy().astype(int)[:, ~junk]

    # A row with no non-junk start counts has no defined distribution.  Give
    # it zero weight instead of 0/0 = NaN, which would otherwise propagate
    # through the einsum and blank out the whole matrix.
    row_totals = starts.sum(axis=1, keepdims=True)
    weights = np.divide(
        starts,
        row_totals,
        out=np.zeros(starts.shape, dtype=float),
        where=row_totals > 0,
    )

    keep = portfolio["bai"].isin(keep_ids).to_numpy()
    return np.einsum("ai,aij->ij", weights[keep], theta[keep])


def build_links(counts):
    """Translate per-transition flow matrices into Sankey link arrays.

    Node ``i`` of column ``k`` has index ``i + k * n_topics``; transition
    ``k`` links column ``k`` to column ``k + 1``.  Links are emitted grouped
    by (source topic, target topic) and then by transition, which is the
    order the diagram has always been built in.

    Returns:
        dict with equally long lists under ``source``, ``target``, ``value``
        and ``color``, ready to pass as ``go.Sankey(link=...)``.
    """
    n_bins = len(counts)
    n_topics = counts[0].shape[0]

    # Normalise each transition once instead of inside the triple loop.
    joint, out_share, in_share = [], [], []
    for flows in counts:
        total = flows.sum()
        joint.append(flows / total)
        out_share.append(flows.sum(axis=1) / total)
        in_share.append(flows.sum(axis=0) / total)

    links = {"source": [], "target": [], "value": [], "color": []}
    for i in range(n_topics):
        for j in range(n_topics):
            for k in range(n_bins):
                share = joint[k][i, j]
                links["source"].append(i + k * n_topics)
                links["target"].append(j + (k + 1) * n_topics)
                links["value"].append(share)
                # Highlight flows at or above what independent marginals
                # would give (joint >= product of marginals).
                above_chance = share >= out_share[k][i] * in_share[k][j]
                links["color"].append(
                    LINK_HIGHLIGHT if above_chance else LINK_MUTED
                )
    return links


def build_figure(counts, topics, colors, width):
    """Assemble the Sankey figure with one node column per time period.

    Only the first and last columns are labelled; ``topics`` and ``colors``
    must each have one entry per row of the flow matrices.
    """
    n_bins = len(counts)
    n_topics = counts[0].shape[0]

    fig = go.Figure(
        data=[
            go.Sankey(
                node=dict(
                    pad=7,
                    thickness=20,
                    line=dict(color="black", width=0.5),
                    label=topics + [""] * (n_topics * (n_bins - 1)) + topics,
                    color=colors * (n_bins + 1),
                ),
                # source/target are indices into the node label list
                link=build_links(counts),
            )
        ]
    )

    # Period captions, evenly spaced across the paper above the columns.
    for position, period in enumerate(PERIODS):
        fig.add_annotation(
            x=position / n_bins,
            y=1.05,
            xref="paper",
            yref="paper",
            text=period,
            showarrow=False,
            font=dict(size=13),
            align="center",
        )

    fig.update_layout(
        font_size=11,
        autosize=False,
        width=width,
        height=FIG_HEIGHT,
    )
    return fig


def main():
    """Load the inputs, build the diagram and write ``sankey_four.pdf``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--compact", action="store_true", default=False)
    args = parser.parse_args()
    width = FIG_WIDTH_COMPACT if args.compact else FIG_WIDTH

    samples = [np.load(opj(OUTPUT_DIR, name)) for name in MCMC_FILES]
    portfolios = [pd.read_csv(opj(OUTPUT_DIR, name)) for name in PORTFOLIO_FILES]

    keep_ids = common_bai(portfolios)
    topics, junk = load_topic_labels(opj(OUTPUT_DIR, "topics.csv"))

    counts = []
    for portfolio, sample in zip(portfolios, samples):
        theta = sample["beta"]
        print(len(portfolio))
        print(theta.shape)
        counts.append(expected_flows(portfolio, theta, junk, keep_ids))

    colors = sns.color_palette("hls", counts[0].shape[0]).as_hex()
    print(colors)

    fig = build_figure(counts, topics, colors, width)
    fig.write_image(
        opj(OUTPUT_DIR, "sankey_four.pdf"),
        width=width,
        height=FIG_HEIGHT,
    )


if __name__ == "__main__":
    main()
|