"""Plot topic-to-topic transfers as a Sankey diagram and export LaTeX tables.

Reads ``ei_samples_<suffix>.npz`` (array ``counts``) and ``topics.csv``
(column ``label``) from ``--input`` and writes, into the same directory:

* ``sankey_control_<suffix>[_compact].pdf`` -- the Sankey diagram;
* ``largest_transfers.tex`` -- the 10 largest off-diagonal transfers;
* ``most_conservative.tex`` -- the 10 largest diagonal shares of a row.
"""

import argparse
from os.path import join as opj

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns

# Turn off MathJax for kaleido's static export (presumably to keep the
# "Loading MathJax" notice out of the exported PDF).
pio.kaleido.scope.mathjax = None

# A link whose mean flow is at least (origin row total) * (target column total)
# is drawn darker than the others.
_STRONG_LINK_COLOR = "rgba(100,100,100,0.2)"
_WEAK_LINK_COLOR = "rgba(200,200,200,0.02)"
_FIGURE_HEIGHT = 600


def _parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(
        description="Plot topic transfers as a Sankey diagram and export LaTeX tables."
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Directory holding the samples file and topics.csv; outputs are written here.",
    )
    parser.add_argument(
        "--suffix",
        required=True,
        help="Suffix of the samples file, i.e. ei_samples_<suffix>.npz.",
    )
    parser.add_argument(
        "--compact",
        action="store_true",
        default=False,
        help="Export a narrower figure (650px wide instead of 800px).",
    )
    return parser.parse_args(argv)


def _load_topic_labels(input_dir):
    r"""Return the labels from ``topics.csv`` with ``\&`` unescaped and "Junk" rows dropped."""
    labels = pd.read_csv(opj(input_dir, "topics.csv"))["label"]
    labels = labels.str.replace("\\&", "&", regex=False)
    return labels[~labels.str.contains("Junk")].tolist()


def _sankey_links(mean_counts):
    """Build Sankey link lists from a matrix of mean flows.

    Origin ``i`` is node ``i`` and target ``j`` is node ``j + n``, so the two
    sides of the diagram use distinct nodes.

    Returns:
        ``(source, target, value, color)`` lists with one entry per ``(i, j)``.
    """
    n = mean_counts.shape[0]
    row_totals = mean_counts.sum(axis=1)
    col_totals = mean_counts.sum(axis=0)
    source, target, value, color = [], [], [], []
    for i in range(n):
        for j in range(n):
            flow = mean_counts[i, j]
            source.append(i)
            target.append(j + n)
            value.append(flow)
            color.append(
                _STRONG_LINK_COLOR
                if flow >= row_totals[i] * col_totals[j]
                else _WEAK_LINK_COLOR
            )
    return source, target, value, color


def _build_figure(topics, colors, links, width):
    """Return the Sankey figure for the given labels, node colors and links."""
    source, target, value, color = links
    fig = go.Figure(
        data=[
            go.Sankey(
                node=dict(
                    pad=7,
                    thickness=20,
                    line=dict(color="black", width=0.5),
                    # Origin nodes first, then target nodes (see _sankey_links).
                    label=topics + topics,
                    color=colors + colors,
                ),
                link=dict(source=source, target=target, value=value, color=color),
            )
        ]
    )
    fig.update_layout(font_size=13, autosize=False, width=width, height=_FIGURE_HEIGHT)
    return fig


def _transfer_table(topics, mean_counts):
    """Tabulate every (origin, target) pair of the mean-flow matrix.

    ``magnitude`` is the mean flow times 100. ``ratio`` is the flow divided by
    the origin's row total.
    """
    row_totals = mean_counts.sum(axis=1)
    n = mean_counts.shape[0]
    return pd.DataFrame(
        [
            {
                "from": topics[i],
                "to": topics[j],
                "magnitude": 100 * mean_counts[i, j],
                "ratio": mean_counts[i, j] / row_totals[i],
            }
            for i in range(n)
            for j in range(n)
        ]
    )


def _write_latex_table(table, path, **to_latex_kwargs):
    r"""Render ``table`` with ``DataFrame.to_latex`` and write it to ``path``.

    Shared options: no index, cells emitted verbatim (``escape=False``), and
    floats with two decimals. Every row terminator ``\\`` gets a trailing
    ``\hline``. Extra keyword arguments are forwarded to ``to_latex``.
    """
    latex = table.to_latex(
        index=False,
        multirow=True,
        multicolumn=True,
        escape=False,
        float_format=lambda x: f"{x:.2f}",
        **to_latex_kwargs,
    )
    latex = latex.replace("\\\\\n", "\\\\ \\hline\n")
    with open(path, "w", encoding="utf-8") as fp:
        fp.write(latex)


def main(argv=None):
    """Load the inputs, then export the Sankey PDF and both LaTeX tables."""
    args = _parse_args(argv)

    # An NpzFile re-reads a member from the archive on every ``samples[key]``
    # access and holds the archive open until closed: read once, then close.
    with np.load(opj(args.input, f"ei_samples_{args.suffix}.npz")) as samples:
        counts = samples["counts"]
    n_topics = counts.shape[1]
    # Average over axis 0 (presumably the sample axis -- confirm upstream).
    mean_counts = counts.mean(axis=0)

    topics = _load_topic_labels(args.input)
    if len(topics) != n_topics:
        # Node labels are `topics + topics`, so any mismatch shifts the
        # target-side labels onto the wrong nodes.
        raise ValueError(
            f"topics.csv has {len(topics)} non-Junk labels but counts has "
            f"{n_topics} topics"
        )

    colors = sns.color_palette("hls", n_topics).as_hex()
    print(colors)

    width = 650 if args.compact else 800
    fig = _build_figure(topics, colors, _sankey_links(mean_counts), width)
    compact_tag = "_compact" if args.compact else ""
    fig.write_image(
        opj(args.input, f"sankey_control_{args.suffix}{compact_tag}.pdf"),
        width=width,
        height=_FIGURE_HEIGHT,
    )

    # Labels were unescaped for plotly. The tables are written with
    # escape=False and "&" separates tabular columns, so re-escape it here.
    latex_topics = [label.replace("&", "\\&") for label in topics]
    transfers = _transfer_table(latex_topics, mean_counts)
    off_diagonal = transfers[transfers["from"] != transfers["to"]]
    diagonal = transfers[transfers["from"] == transfers["to"]]

    _write_latex_table(
        off_diagonal.sort_values("magnitude", ascending=False).head(10),
        opj(args.input, "largest_transfers.tex"),
        columns=["from", "to", "magnitude"],
        header=["Origin research area", "Target research area", "Magnitude"],
        column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
        caption="Largest transfers across research areas.",
        label="table:largest_transfers",
    )
    _write_latex_table(
        diagonal.sort_values("ratio", ascending=False).head(10),
        opj(args.input, "most_conservative.tex"),
        columns=["from", "ratio"],
        header=["Research area", "Conservatism"],
        # Two data columns: a paragraph column for the label, a centred value.
        column_format="b{0.4\\textwidth}|c",
        caption="Most conservative research areas.",
        label="table:most_conservative",
    )


if __name__ == "__main__":
    main()