123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import numpy as np
- import pandas as pd
- import plotly.graph_objects as go
- import plotly.io as pio
# Clear the MathJax setting on plotly's Kaleido export scope before any figure
# is written by fig.write_image further down.
# NOTE(review): presumably this stops Kaleido from loading MathJax during the
# PDF export -- confirm against the installed plotly/kaleido versions.
- pio.kaleido.scope.mathjax = None
- import seaborn as sns
- import argparse
- from os.path import join as opj
# Command-line interface.
# --input is mandatory: every file read or written by this script lives in that
# directory, and leaving it out used to surface as a TypeError from
# os.path.join instead of a usage message.
parser = argparse.ArgumentParser(
    description=(
        "Draw a Sankey diagram of the transfers between research areas and "
        "export the largest transfers / most conservative areas as LaTeX tables."
    )
)
parser.add_argument(
    "--input",
    required=True,
    help=(
        "directory containing ei_samples_<suffix>.npz or ei_mle_<suffix>.npz "
        "and topics.csv; all outputs are written to the same directory"
    ),
)
parser.add_argument(
    "--suffix",
    help="suffix used in the name of the .npz input and of the exported figure",
)
parser.add_argument(
    "--compact",
    action="store_true",
    default=False,
    help="export a 650 px wide figure instead of 800 px",
)
parser.add_argument(
    "--transparent",
    action="store_true",
    default=False,
    help="when no highlight is requested, draw every link in the faint colour",
)
parser.add_argument(
    "--highlight-in",
    default=[],
    nargs="+",
    help="topic labels; links whose source is one of these topics may be emphasised",
)
parser.add_argument(
    "--highlight-out",
    default=[],
    nargs="+",
    help="topic labels; links whose target is one of these topics may be emphasised",
)
parser.add_argument(
    "--mle",
    action="store_true",
    default=False,
    help="read ei_mle_<suffix>.npz (a single count matrix) instead of ei_samples_<suffix>.npz",
)
args = parser.parse_args()
# Which estimate file to read. Deliberately not called `format`, which would
# shadow the builtin of the same name.
estimate_kind = "mle" if args.mle else "samples"

# np.load returns a lazy NpzFile for .npz archives: each `samples[key]` lookup
# reads the array from the archive again and the underlying file stays open.
# Materialise the arrays once into a plain dict and close the archive, so the
# many `samples["counts"]` lookups further down are cheap.
with np.load(opj(args.input, f"ei_{estimate_kind}_{args.suffix}.npz")) as archive:
    samples = {key: archive[key] for key in archive.files}

# Topic labels, in file order, with rows whose label contains "Junk" dropped.
# The "\&" sequences are turned into a plain "&".
# NOTE(review): presumably the CSV stores LaTeX-escaped labels -- confirm.
topics = pd.read_csv(opj(args.input, "topics.csv"))
topics["label"] = topics["label"].str.replace("\\&", "&", regex=False)
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()

# The MLE file holds one (topic, topic) matrix; the samples file has an extra
# leading axis that is averaged over later, so the topic axes start at index 1.
n_topics = samples["counts"].shape[0 if args.mle else 1]
# Build the Sankey link lists (source, target, value, color) from the
# topic-to-topic count matrix.

# Fetch the array once and reduce it to a single (source, target) matrix up
# front. In MLE mode it already is one; otherwise average over the leading
# axis. The per-pair code used to redo this reduction for every (i, j).
counts = samples["counts"]
mean_counts = counts if args.mle else counts.mean(axis=0)

# One node colour per topic (reused for both copies of the node list).
colors = sns.color_palette("hls", n_topics).as_hex()
print(colors)

# Marginals of the matrix: row sums (per source topic) and column sums
# (per target topic).
total_incoming = mean_counts.sum(axis=1)
total_outcoming = mean_counts.sum(axis=0)

# NOTE(review): these percentage-annotated labels are built here but the
# figure code in this file labels nodes with the plain topic names.
incoming_labels = [
    f"{topics[i]} ({100*total_incoming[i]:.0f}%)" for i in range(n_topics)
]
outcoming_labels = [
    f"{topics[i]} ({100*total_outcoming[i]:.0f}%)" for i in range(n_topics)
]

# With no explicit highlight lists every link is eligible for emphasis,
# unless --transparent was given.
highlight_everything = not args.transparent and not (
    args.highlight_in or args.highlight_out
)

source = []
target = []
value = []
color = []
for i in range(n_topics):
    source_selected = topics[i] in args.highlight_in
    for j in range(n_topics):
        v = mean_counts[i, j]
        selected = (
            source_selected
            or topics[j] in args.highlight_out
            or highlight_everything
        )
        # A selected link is only emphasised when its value is at least the
        # product of its row total and its column total.
        highlight = selected and v >= total_incoming[i] * total_outcoming[j]

        value.append(v)
        source.append(i)
        # Target nodes are the second copy of the topic list, hence the offset.
        target.append(j + n_topics)
        color.append("rgba(100,100,100,0.3)" if highlight else "rgba(200,200,200,0.02)")
# Assemble the Sankey figure.
# The node list is the topic list twice: link sources index the first copy,
# link targets (already offset by n_topics) index the second copy.
figure_width = 650 if args.compact else 800

sankey_nodes = {
    "pad": 7,
    "thickness": 20,
    "line": {"color": "black", "width": 0.5},
    "label": topics + topics,
    "color": colors + colors,
}
sankey_links = {
    "source": source,  # positions in the node label list
    "target": target,
    "value": value,
    "color": color,
}

fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])
# Fixed canvas size; the same dimensions are passed again when exporting.
fig.update_layout(
    font_size=13,
    autosize=False,
    width=figure_width,
    height=600,
)
- # fig.show()
# Turn the requested highlights into a file-name fragment: "/" and "&" become
# "_", spaces are removed, and everything is lower-cased.
slug_table = str.maketrans({"/": "_", "&": "_", " ": None})
joined_highlights = "_".join(args.highlight_in) + "_".join(args.highlight_out)
highlights = joined_highlights.translate(slug_table).lower()
if highlights:
    highlights = "_" + highlights

# Optional tags so that each flag combination gets its own output file.
compact_tag = "_compact" if args.compact else ""
transparent_tag = "_transparent" if args.transparent else ""
figure_name = (
    f"sankey_control_{args.suffix}{compact_tag}{transparent_tag}{highlights}.pdf"
)

# Export the figure as a PDF next to the input data.
fig.write_image(
    opj(args.input, figure_name),
    width=650 if args.compact else 800,
    height=600,
)
# Long-format table with one row per (origin, target) pair of topics.
#   magnitude: the matrix entry, scaled by 100
#   ratio:     the matrix entry divided by its row total
#
# In MLE mode the counts are already a 2-D matrix. Averaging over axis 0 there
# would collapse it to 1-D and make the [i, j] lookup raise IndexError, so the
# reduction is only applied to the sampled (3-D) counts.
counts = samples["counts"]
mean_counts = counts if args.mle else counts.mean(axis=0)
row_totals = mean_counts.sum(axis=1)

transfers = pd.DataFrame(
    [
        {
            "from": topics[i],
            "to": topics[j],
            "magnitude": 100 * mean_counts[i, j],
            "ratio": mean_counts[i, j] / row_totals[i],
        }
        for i in range(n_topics)
        for j in range(n_topics)
    ]
)
# LaTeX table of the ten largest transfers between two different topics.
largest_transfers = (
    transfers[transfers["from"] != transfers["to"]]
    .sort_values("magnitude", ascending=False)
    .head(10)
    .copy()
)
# The table is written with escape=False, so a plain "&" inside a label would
# be read by LaTeX as a column separator. Escape it in the label columns only.
for label_column in ("from", "to"):
    largest_transfers[label_column] = largest_transfers[label_column].str.replace(
        "&", "\\&", regex=False
    )

latex = largest_transfers.to_latex(
    columns=["from", "to", "magnitude"],
    header=["Origin research area", "Target research area", "Magnitude"],
    index=False,
    multirow=True,
    multicolumn=True,
    column_format="b{0.4\\textwidth}|b{0.4\\textwidth}|c",
    escape=False,
    float_format=lambda x: f"{x:.2f}",
    caption="Largest transfers across research areas.",
    label="table:largest_transfers",
)
# Put a horizontal rule after every table row.
latex = latex.replace("\\\\\n", "\\\\ \\hline\n")

# Write-only access is enough; the encoding is pinned so that non-ASCII labels
# do not depend on the platform default.
with open(opj(args.input, "largest_transfers.tex"), "w", encoding="utf-8") as fp:
    fp.write(latex)
# LaTeX table of the ten topics with the highest "ratio" on the diagonal
# (rows whose origin and target label are the same).
most_conservative = (
    transfers[transfers["from"] == transfers["to"]]
    .sort_values("ratio", ascending=False)
    .head(10)
    .copy()
)
# The table is written with escape=False, so a plain "&" inside a label would
# be read by LaTeX as a column separator. Escape it in the label column.
most_conservative["from"] = most_conservative["from"].str.replace(
    "&", "\\&", regex=False
)

latex = most_conservative.to_latex(
    columns=["from", "ratio"],
    header=["Research area", "Conservatism"],
    index=False,
    multirow=True,
    multicolumn=True,
    # Two columns are exported, so the column spec declares exactly two
    # (it previously declared three, leaving an empty trailing column).
    column_format="b{0.4\\textwidth}|c",
    escape=False,
    float_format=lambda x: f"{x:.2f}",
    caption="Most conservative research areas.",
    label="table:most_conservative",
)
# Put a horizontal rule after every table row.
latex = latex.replace("\\\\\n", "\\\\ \\hline\n")

# Write-only access is enough; the encoding is pinned so that non-ASCII labels
# do not depend on the platform default.
with open(opj(args.input, "most_conservative.tex"), "w", encoding="utf-8") as fp:
    fp.write(latex)
|