import pandas as pd
import pickle
import numpy as np
from scipy.stats import entropy
from os.path import join as opj, exists
import argparse
import textwrap
from sklearn.preprocessing import MultiLabelBinarizer
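# Command-line arguments: --input points to the directory with the topic-model
# outputs, --dataset to the directory with the INSPIRE database files.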
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
args = parser.parse_args()
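# Load the n-gram table and the article metadata; keep articles with a parsable
# creation year and encode years relative to 2000.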
ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "pacs_codes", "curated"]]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values
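# Topic labels and official PACS code descriptions; articles without any
# recognized PACS code are dropped.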
topics = pd.read_csv(opj(args.input, "topics.csv"))
is_junk_topic = np.array(topics["label"].str.contains("Junk"))

pacs_description = pd.read_csv(opj(args.dataset, "pacs_codes.csv")).set_index("code")["description"].to_dict()
codes = set(pacs_description.keys())

has_pacs_code = articles["pacs_codes"].map(lambda l: set(l) & codes).map(len) > 0
articles = articles[has_pacs_code]
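# One-hot encode each article's PACS codes (articles x categories).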
binarizer = MultiLabelBinarizer()
pacs = binarizer.fit_transform(articles["pacs_codes"])
n_categories = pacs.shape[1]
print((pacs - pacs.mean(axis=0)).shape)

cat_classes = np.array([np.identity(n_categories)[cl] for cl in range(n_categories)])
cat_labels = binarizer.inverse_transform(cat_classes)
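# Load the trained ETM instance and cache the document-topic distribution (theta).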
with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)

if not exists(opj(args.input, "theta.npy")):
    theta = etm_instance.get_document_topic_dist()
    np.save(opj(args.input, "theta.npy"), theta)
else:
    theta = np.load(opj(args.input, "theta.npy"))

print(theta[:, ~is_junk_topic].mean(axis=0).sum())
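# Diagnostics: the exponential of the Shannon entropy measures the effective
# number of categories per row.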
topic_matrix = np.load(opj(args.input, "topics_counts.npy"))
print("Theta average entropy:", np.exp(entropy(theta[:, ~is_junk_topic], axis=1)).mean())
print("Topic matrix average entropy:", np.nanmean(np.exp(np.nan_to_num(entropy(topic_matrix, axis=1)))))
theta = theta[has_pacs_code, :]
n_articles = theta.shape[0]

R = (theta - theta.mean(axis=0)).T @ (pacs - pacs.mean(axis=0)) / n_articles
R /= np.outer(theta.std(axis=0), pacs.std(axis=0))

print(R.shape)
print(R)
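# For each topic, keep the five most strongly correlated PACS categories,
# formatted as LaTeX lines "description (correlation)".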
- topics["top_pacs"] = ""
- topics["relevant"] = 1
- topics.loc[is_junk_topic, "label"] = "Uninterpretable"
- topics.loc[is_junk_topic, "relevant"] = 0
- for i in range(R.shape[0]):
- ind = np.argsort(R[i,:])
- top = np.array(ind[-5:])[::-1]
- topics.loc[i, "top_pacs"] = "\\\\ ".join([f"{textwrap.shorten(pacs_description[cat_labels[j][0]], width=40)} ({R[i,j]:.2f})" for j in top if cat_labels[j][0] in pacs_description])
- topics["top_words"] = topics["top_words"].str.replace(",", ", ")
- topics["top_words"] = topics["top_words"].str.replace("_", "\\_")
- # topics["top_words"] = topics["top_words"].str.replace(r"(^(.*)-ray)|((.*)-ray)", "$\\gamma$-ray", regex=True)
- # topics["top_words"] = topics["top_words"].apply(lambda s: "\\\\ ".join(textwrap.wrap(s, width=45, break_long_words=False)))
- # topics["top_words"] = topics["top_words"].apply(lambda s: '\\begin{tabular}{l}' + s +'\\end{tabular}')
- topics["top_pacs"] = topics["top_pacs"].apply(lambda s: '\\shortstack[l]{' + s +'}')
- topics.sort_values(["relevant", "label"], ascending=[False, True], inplace=True)
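# Export the table of research areas (label, top words, most correlated PACS
# categories) as a LaTeX longtable.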
pd.set_option('display.max_colwidth', None)

latex = topics.to_latex(
    columns=["label", "top_words", "top_pacs"],
    header=["Research area", "Top words", "Most correlated PACS categories"],
    index=False,
    multirow=True,
    multicolumn=True,
    longtable=True,
    column_format='p{0.15\\textwidth}|b{0.425\\textwidth}|b{0.425\\textwidth}',
    escape=False,
    caption="Research areas, their top-words, and their correlation with a standard classification (PACS).",
    label="table:research_areas"
)
latex = latex.replace(')} \\\\\n', ')}\\\\ \\hline\n')

with open(opj(args.input, "topics.tex"), "w+") as fp:
    fp.write(latex)
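# Prepare the heatmap data: keep PACS codes with at least 100 articles, drop
# junk topics, mark the rows where the PACS hierarchy changes (xx.xx and xx
# prefixes), and reorder topics according to a precomputed ordering.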
keep_pacs = pacs.sum(axis=0) >= 100
R = R[:, keep_pacs]
R = R[~is_junk_topic]

labels = [cat_labels[i][0] for i in range(len(keep_pacs)) if keep_pacs[i]]
breaks_1 = np.array([i == 0 or labels[i-1][:5] != labels[i][:5] for i in range(len(labels))])
breaks_2 = np.array([i == 0 or labels[i-1][:2] != labels[i][:2] for i in range(len(labels))])

order = np.load(opj(args.input, "topics_order.npy"))
print(order)
R = R[order, :]
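# Plotting setup: pgf backend driven by xelatex, serif fonts, LaTeX text rendering.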
import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
# LaTeX preamble for text rendering.
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
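# Figure layout: a 9x8 grid with the topic-PACS correlation heatmap on the
# right and a narrow, unlabeled axis on the left for PACS group annotations.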
from matplotlib.gridspec import GridSpec

plt.clf()
fig = plt.figure(figsize=(6.4*2, 6.4*2.5))
gs = GridSpec(9, 8, hspace=0, wspace=0)

ax_heatmap = fig.add_subplot(gs[0:8, 2:8])
sns.heatmap(
    R.T,
    cmap="RdBu",
    vmin=-0.5, vmax=+0.5,
    square=False,
    ax=ax_heatmap,
    cbar_kws={"shrink": 0.25}
)
ax_heatmap.invert_yaxis()
ax_heatmap.yaxis.set_visible(False)
topics = pd.read_csv(opj(args.input, "topics.csv"))
topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist()
ordered_topics = [topics[i] for i in order]
ax_heatmap.set_xticklabels(ordered_topics, rotation=90)

xmin, xmax = ax_heatmap.get_xlim()
ymin, ymax = ax_heatmap.get_ylim()

ax_breaks = fig.add_subplot(gs[0:8, 0:2])
ax_breaks.set_xlim(0, 8)
ax_breaks.set_ylim(ymin, ymax)
ax_breaks.axis("off")
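# Map two-digit PACS prefixes to their top-level group names, parsed from the
# PACS HTML index.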
import re

html = open(opj(args.dataset, "pacs.html"), "r").read()
# Regex pattern to extract the code and label
pattern = r'<span class="pacs_num">(\d{2}\.\d{2}\.\d{2})</span>\s*(.*?)</h2>'
matches = re.findall(pattern, html)
matches = {code[:2]: label for code, label in matches}
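# Draw separators at PACS hierarchy breaks and label each top-level PACS group,
# progressively shortening the label until it fits the available vertical space.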
prev_break = 0
for i in range(len(breaks_1)):
    if breaks_1[i]:
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125/2)
    if breaks_2[i] or i == len(breaks_1)-1:
        ax_breaks.hlines(y=i-0.5, xmin=2, xmax=10, color="lightgray", lw=0.125)
        ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125)
        print(prev_break, i, labels[i], prev_break+(i-prev_break)/2.0)
        if prev_break != i:
            text = textwrap.shorten(matches[labels[i-1][:2]], width=35*3)
            lines = textwrap.wrap(text, width=35)
            too_small = len(lines) >= 0.35*(i-prev_break-2)
            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35*2)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.35*(i-prev_break-2)
            if too_small:
                text = textwrap.shorten(matches[labels[i-1][:2]], width=35)
                lines = textwrap.wrap(text, width=35)
                too_small = len(lines) >= 0.5*(i-prev_break-2)
            if (not too_small) or i == len(breaks_1)-1:
                ax_breaks.text(
                    7.5,
                    prev_break+(i-prev_break)/2.0,
                    "\n".join(lines),
                    ha="right",
                    va="center",
                    fontsize=8
                )
        prev_break = i
plt.savefig(opj(args.input, "pacs_clustermap.eps"), bbox_inches="tight")