123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- import numpy as np
- import pandas as pd
- import plotly.graph_objects as go
- import plotly.io as pio
- pio.kaleido.scope.mathjax = None
- import seaborn as sns
- import argparse
- from os.path import join as opj
# Command-line interface. --input is the directory that both the inputs are
# read from and every output of this script is written to.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input",
    # Without a directory os.path.join() fails with an opaque TypeError, so
    # let argparse report the missing option instead.
    required=True,
    help="directory containing ei_samples_<suffix>.npz and topics.csv",
)
parser.add_argument(
    "--suffix",
    help="suffix selecting which ei_samples_<suffix>.npz file is loaded",
)
parser.add_argument(
    "--compact",
    action="store_true",
    default=False,
    help="render the figure 650 px wide instead of 800 px",
)
args = parser.parse_args()

# Archive whose "counts" member is used throughout the rest of the script.
samples = np.load(opj(args.input, f"ei_samples_{args.suffix}.npz"))

topics = pd.read_csv(opj(args.input, "topics.csv"))
# Replace the literal two-character sequence "\&" with a plain "&".
topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
# Keep only the label strings, dropping every label that contains "Junk".
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()

# The number of topics is the size of the second axis of "counts".
# NOTE(review): presumably this equals len(topics) after the "Junk" filter;
# the code below indexes topics[i] for i < n_topics, so confirm.
n_topics = samples["counts"].shape[1]
# Parallel lists describing one Sankey link per (i, j) pair of topics.
source = []
target = []
value = []
color = []

# One hex colour per topic, drawn from seaborn's "hls" palette.
colors = sns.color_palette("hls", n_topics).as_hex()
print(colors)

# Read the array and average it over its first axis exactly once. Indexing the
# archive returned by np.load() re-reads the member on every access, which the
# original did several times per (i, j) pair.
counts = samples["counts"]
mean_counts = counts.mean(axis=0)

# Row sums and column sums of the averaged matrix.
total_incoming = mean_counts.sum(axis=1)
total_outcoming = mean_counts.sum(axis=0)
# NOTE(review): these percentage labels are built but the visible figure code
# labels its nodes with the bare topic names -- confirm whether they were
# meant to be used there.
incoming_labels = [f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)]
outcoming_labels = [f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)]

for i in range(n_topics):
    # Sum of row i of the averaged matrix; it does not depend on j.
    x = mean_counts[i, :].sum()
    for j in range(n_topics):
        v = counts[:, i, j].mean()
        # Sum of column j of the averaged matrix.
        y = mean_counts[:, j].sum()
        value.append(v)
        source.append(i)
        # Target nodes are numbered after the n_topics source nodes.
        target.append(j + n_topics)
        # Links at least as large as the product of their row and column sums
        # are drawn darker; all others are nearly transparent.
        color.append("rgba(100,100,100,0.2)" if v >= x*y else "rgba(200,200,200,0.02)")
# Figure size in pixels; the compact variant is narrower.
fig_width = 650 if args.compact else 800
fig_height = 600

# The node list holds every topic twice: first as a source, then as a target.
sankey_trace = go.Sankey(
    node={
        "pad": 7,
        "thickness": 20,
        "line": {"color": "black", "width": 0.5},
        "label": topics + topics,
        "color": colors + colors,
    },
    # source/target entries are indices into the node list above.
    link={
        "source": source,
        "target": target,
        "value": value,
        "color": color,
    },
)
fig = go.Figure(data=[sankey_trace])
fig.update_layout(
    font_size=13,
    autosize=False,
    width=fig_width,
    height=fig_height,
)

# The PDF is written next to the inputs, tagged with the suffix and, when
# requested, with "_compact".
compact_tag = "_compact" if args.compact else ""
pdf_path = opj(args.input, f"sankey_control_{args.suffix}{compact_tag}.pdf")
fig.write_image(pdf_path, width=fig_width, height=fig_height)
# Average the "counts" array over its first axis once, instead of re-reading
# the archive member and recomputing the mean several times per (i, j) cell.
mean_counts = samples["counts"].mean(axis=0)
# Row sums of the averaged matrix, used as the denominator of "ratio".
row_totals = [mean_counts[i, :].sum() for i in range(n_topics)]

# One row per ordered pair of topics:
#   magnitude -- the averaged cell value scaled by 100
#   ratio     -- the averaged cell value divided by the sum of its row
transfers = pd.DataFrame(
    [
        {
            "from": topics[i],
            "to": topics[j],
            "magnitude": 100 * mean_counts[i, j],
            "ratio": mean_counts[i, j] / row_totals[i],
        }
        for i in range(n_topics)
        for j in range(n_topics)
    ]
)
# Ten largest transfers between two *different* research areas.
largest = (
    transfers[transfers["from"] != transfers["to"]]
    .sort_values("magnitude", ascending=False)
    .head(10)
    .copy()
)
# The table is written with escape=False, so a bare "&" in a label would be
# read by LaTeX as a column separator. Escape every ampersand that is not
# already preceded by a backslash.
for column in ("from", "to"):
    largest[column] = largest[column].str.replace(r"(?<!\\)&", r"\\&", regex=True)

latex = largest.to_latex(
    columns=["from", "to", "magnitude"],
    header=["Origin research area", "Target research area", "Magnitude"],
    index=False,
    multirow=True,
    multicolumn=True,
    column_format='b{0.4\\textwidth}|b{0.4\\textwidth}|c',
    escape=False,
    float_format=lambda x: f"{x:.2f}",
    caption="Largest transfers across research areas.",
    label="table:largest_transfers"
)
# Append a horizontal rule after every row terminator ("\\" + newline).
latex = latex.replace('\\\\\n', '\\\\ \\hline\n')
with open(opj(args.input, "largest_transfers.tex"), "w+") as fp:
    fp.write(latex)
# Ten research areas with the highest self-transfer ratio, i.e. the diagonal
# entries (from == to) ranked by "ratio".
conservative = (
    transfers[transfers["from"] == transfers["to"]]
    .sort_values("ratio", ascending=False)
    .head(10)
    .copy()
)
# The table is written with escape=False, so a bare "&" in a label would be
# read by LaTeX as a column separator. Escape every ampersand that is not
# already preceded by a backslash.
conservative["from"] = conservative["from"].str.replace(r"(?<!\\)&", r"\\&", regex=True)

latex = conservative.to_latex(
    columns=["from", "ratio"],
    header=["Research area", "Conservatism"],
    index=False,
    multirow=True,
    multicolumn=True,
    # Two columns are emitted, so the format declares exactly two; the
    # previous three-column spec left a stray rule and an empty column.
    column_format='b{0.4\\textwidth}|c',
    escape=False,
    float_format=lambda x: f"{x:.2f}",
    caption="Most conservative research areas.",
    label="table:most_conservative"
)
# Append a horizontal rule after every row terminator ("\\" + newline).
latex = latex.replace('\\\\\n', '\\\\ \\hline\n')
with open(opj(args.input, "most_conservative.tex"), "w+") as fp:
    fp.write(latex)
|