"""Plot topic-to-topic transfers as a Sankey diagram and export summary tables.

Inputs, read from the ``--input`` directory:

* ``ei_mle_{suffix}.npz`` (with ``--mle``) or ``ei_samples_{suffix}.npz``,
  containing an array named ``counts``;
* ``topics.csv`` with a ``label`` column.

Outputs, written to the same directory:

* ``sankey_control_{suffix}[_compact][_transparent][_highlights].pdf``;
* ``largest_transfers.tex`` -- the ten largest off-diagonal transfers;
* ``most_conservative.tex`` -- the ten largest diagonal/row-total ratios.
"""

import argparse
from os.path import join as opj

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns

# Tell the kaleido export scope (used by fig.write_image below) not to load
# MathJax.
pio.kaleido.scope.mathjax = None

parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--suffix")
parser.add_argument("--compact", action="store_true", default=False)
parser.add_argument("--transparent", action="store_true", default=False)
parser.add_argument("--highlight-in", default=list(), nargs="+")
parser.add_argument("--highlight-out", default=list(), nargs="+")
parser.add_argument("--mle", action="store_true", default=False)
args = parser.parse_args()

# Which estimate file to read: "ei_mle_*.npz" or "ei_samples_*.npz".
estimate_kind = "mle" if args.mle else "samples"
samples = np.load(opj(args.input, f"ei_{estimate_kind}_{args.suffix}.npz"))

topics = pd.read_csv(opj(args.input, "topics.csv"))
# Labels are stored LaTeX-escaped; un-escape the ampersand for the plot.
topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()

# Indexing an NpzFile re-reads the array from the archive every time, so load
# it exactly once.
counts = samples["counts"]
# With --mle the array is already a (topic, topic) matrix; otherwise it is
# averaged over its first axis (presumably the posterior-sample axis -- TODO
# confirm) to obtain one.
mean_counts = counts if args.mle else counts.mean(axis=0)
n_topics = mean_counts.shape[0]

colors = sns.color_palette("hls", n_topics).as_hex()
print(colors)

# Marginals of the transfer matrix: row sums and column sums.
total_incoming = mean_counts.sum(axis=1)
total_outcoming = mean_counts.sum(axis=0)

# NOTE(review): these percentage labels are built but the figure below uses
# the plain topic names as node labels.
incoming_labels = [
    f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)
]
outcoming_labels = [
    f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)
]

highlight_in = set(args.highlight_in)
highlight_out = set(args.highlight_out)
# Without any explicit highlight request every link is eligible, unless
# --transparent asks for all links to be faded.
highlight_all = (not args.transparent) and not (highlight_in or highlight_out)

source = []
target = []
value = []
color = []

for i in range(n_topics):
    for j in range(n_topics):
        v = mean_counts[i, j]
        value.append(v)
        source.append(i)
        # Right-hand nodes are offset by n_topics in the node list.
        target.append(j + n_topics)

        eligible = (
            topics[i] in highlight_in
            or topics[j] in highlight_out
            or highlight_all
        )
        # Only emphasise links at least as large as the product of their
        # row and column totals.
        highlight = eligible and (v >= total_incoming[i] * total_outcoming[j])
        color.append(
            "rgba(100,100,100,0.3)" if highlight else "rgba(200,200,200,0.02)"
        )

fig_width = 650 if args.compact else 800
fig_height = 600

fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(
                pad=7,
                thickness=20,
                line=dict(color="black", width=0.5),
                # The same topics appear twice: left column, then right column.
                label=topics + topics,
                color=colors + colors,
            ),
            link=dict(
                source=source,
                target=target,
                value=value,
                color=color,
            ),
        )
    ]
)

fig.update_layout(font_size=13)
fig.update_layout(
    autosize=False,
    width=fig_width,
    height=fig_height,
)

# File-name fragment identifying the highlighted topics (empty when none).
highlights = f"{'_'.join(args.highlight_in)}{'_'.join(args.highlight_out)}".replace(
    "/", "_"
).replace("&", "_").replace(" ", "").lower()
if len(highlights):
    highlights = f"_{highlights}"

fig.write_image(
    opj(
        args.input,
        f"sankey_control_{args.suffix}{'_compact' if args.compact else ''}{'_transparent' if args.transparent else ''}{highlights}.pdf",
    ),
    width=fig_width,
    height=fig_height,
)

# One row per (from, to) pair. Uses the same mean matrix as the figure so the
# tables also work with --mle, where `counts` has no leading axis to average.
transfers = pd.DataFrame(
    [
        {
            "from": topics[i],
            "to": topics[j],
            "magnitude": 100 * mean_counts[i, j],
            "ratio": mean_counts[i, j] / total_incoming[i],
        }
        for i in range(n_topics)
        for j in range(n_topics)
    ]
)

# The tables are written with escape=False, so a bare "&" in a label would be
# read by LaTeX as a column separator; restore the escaping on a copy.
latex_transfers = transfers.assign(
    **{
        column: transfers[column].str.replace("&", "\\&", regex=False)
        for column in ("from", "to")
    }
)

# Options shared by both exported tables.
# NOTE(review): column_format declares three columns, but the "most
# conservative" table only has two -- confirm whether that is intended.
latex_options = dict(
    index=False,
    multirow=True,
    multicolumn=True,
    column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
    escape=False,
    float_format=lambda x: f"{x:.2f}",
)

off_diagonal = latex_transfers[latex_transfers["from"] != latex_transfers["to"]]
latex = (
    off_diagonal.sort_values("magnitude", ascending=False)
    .head(10)
    .to_latex(
        columns=["from", "to", "magnitude"],
        header=["Origin research area", "Target research area", "Magnitude"],
        caption="Largest transfers across research areas.",
        label="table:largest_transfers",
        **latex_options,
    )
)
# Draw a horizontal rule after every table row.
latex = latex.replace("\\\\\n", "\\\\ \\hline\n")

with open(opj(args.input, "largest_transfers.tex"), "w+") as fp:
    fp.write(latex)

diagonal = latex_transfers[latex_transfers["from"] == latex_transfers["to"]]
latex = (
    diagonal.sort_values("ratio", ascending=False)
    .head(10)
    .to_latex(
        columns=["from", "ratio"],
        header=["Research area", "Conservatism"],
        caption="Most conservative research areas.",
        label="table:most_conservative",
        **latex_options,
    )
)
latex = latex.replace("\\\\\n", "\\\\ \\hline\n")

with open(opj(args.input, "most_conservative.tex"), "w+") as fp:
    fp.write(latex)