import argparse
import pickle
import textwrap
from os.path import join as opj, exists

import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.preprocessing import MultiLabelBinarizer

parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
args = parser.parse_args()

ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[
    ["article_id", "date_created", "pacs_codes", "curated"]
]
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values

topics = pd.read_csv(opj(args.input, "topics.csv"))
is_junk_topic = np.array(topics["label"].str.contains("Junk"))

# PACS code -> human-readable description.
pacs_description = (
    pd.read_csv(opj(args.dataset, "pacs_codes.csv"))
    .set_index("code")["description"]
    .to_dict()
)
codes = set(pacs_description.keys())

# Keep only articles carrying at least one known PACS code.
has_pacs_code = articles["pacs_codes"].map(lambda l: set(l) & codes).map(len) > 0
articles = articles[has_pacs_code]

# One-hot encode PACS codes into an (n_articles x n_categories) matrix.
binarizer = MultiLabelBinarizer()
pacs = binarizer.fit_transform(articles["pacs_codes"])
n_categories = pacs.shape[1]
print((pacs - pacs.mean(axis=0)).shape)

# Recover the PACS code corresponding to each binarized column.
cat_classes = np.array([np.identity(n_categories)[cl] for cl in range(n_categories)])
cat_labels = binarizer.inverse_transform(cat_classes)

with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)

# Cache the document-topic distribution, which is expensive to compute.
if not exists(opj(args.input, "theta.npy")):
    theta = etm_instance.get_document_topic_dist()
    np.save(opj(args.input, "theta.npy"), theta)
else:
    theta = np.load(opj(args.input, "theta.npy"))

print(theta[:, ~is_junk_topic].mean(axis=0).sum())

topic_matrix = np.load(opj(args.input, "topics_counts.npy"))

# Exponentiated Shannon entropy = effective number of topics per document.
print("Theta average entropy:", np.exp(entropy(theta[:, ~is_junk_topic], axis=1)).mean())
print(
    "Topic matrix average entropy:",
    np.nanmean(np.exp(np.nan_to_num(entropy(topic_matrix, axis=1)))),
)

theta = theta[has_pacs_code, :]
n_articles = theta.shape[0]

# Pearson correlation between topic loadings and PACS indicator variables:
# cross-covariance normalized by the outer product of standard deviations.
R = (theta - theta.mean(axis=0)).T @ (pacs - pacs.mean(axis=0)) / n_articles
R /= np.outer(theta.std(axis=0), pacs.std(axis=0))

print(R.shape)
print(R)

topics["top_pacs"] = ""
topics["relevant"] = 1
topics.loc[is_junk_topic, "label"] = "Uninterpretable"
topics.loc[is_junk_topic, "relevant"] = 0

# For each topic, list the five most correlated PACS categories.
for i in range(R.shape[0]):
    ind = np.argsort(R[i, :])
    top = np.array(ind[-5:])[::-1]
    topics.loc[i, "top_pacs"] = "\\\\ ".join(
        [
            f"{textwrap.shorten(pacs_description[cat_labels[j][0]], width=40)} ({R[i,j]:.2f})"
            for j in top
            if cat_labels[j][0] in pacs_description
        ]
    )

topics["top_words"] = topics["top_words"].str.replace(",", ", ")
topics["top_words"] = topics["top_words"].str.replace("_", "\\_")
# topics["top_words"] = topics["top_words"].str.replace(r"(^(.*)-ray)|((.*)-ray)", "$\\gamma$-ray", regex=True)
# topics["top_words"] = topics["top_words"].apply(lambda s: "\\\\ ".join(textwrap.wrap(s, width=45, break_long_words=False)))
# topics["top_words"] = topics["top_words"].apply(lambda s: '\\begin{tabular}{l}' + s + '\\end{tabular}')
topics["top_pacs"] = topics["top_pacs"].apply(lambda s: '\\shortstack[l]{' + s + '}')

topics.sort_values(["relevant", "label"], ascending=[False, True], inplace=True)

pd.set_option('display.max_colwidth', None)
latex = topics.to_latex(
    columns=["label", "top_words", "top_pacs"],
    header=["Research area", "Top words", "Most correlated PACS categories"],
    index=False,
    multirow=True,
    multicolumn=True,
    longtable=True,
    column_format='p{0.15\\textwidth}|b{0.425\\textwidth}|b{0.425\\textwidth}',
    escape=False,
    caption="Research areas, their top words, and their correlation with a standard classification (PACS).",
    label="table:research_areas",
)
latex = latex.replace(')} \\\\\n', ')}\\\\ \\hline\n')

with open(opj(args.input, "topics.tex"), "w+") as fp:
    fp.write(latex)
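# Optional sanity check: the vectorized correlation above should agree with
# scipy's pairwise Pearson coefficient (the ddof convention cancels between
# numerator and denominator). The (0, 0) pair is an arbitrary, illustrative
# choice; any topic/category pair should match.
from scipy.stats import pearsonr

print("R[0, 0] =", R[0, 0], "vs. pearsonr:", pearsonr(theta[:, 0], pacs[:, 0])[0])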
"Top words", "Most correlated PACS categories"], index=False, multirow=True, multicolumn=True, longtable=True, column_format='p{0.15\\textwidth}|b{0.425\\textwidth}|b{0.425\\textwidth}', escape=False, caption="Research areas, their top-words, and their correlation with a standard classification (PACS).", label="table:research_areas" ) latex = latex.replace(')} \\\\\n', ')}\\\\ \\hline\n') with open(opj(args.input, "topics.tex"), "w+") as fp: fp.write(latex) keep_pacs = pacs.sum(axis=0)>=100 R = R[:,keep_pacs] R = R[~is_junk_topic] labels = [cat_labels[i][0] for i in range(len(keep_pacs)) if keep_pacs[i]] breaks_1 = np.array([True if (i==0 or labels[i-1][:5]!=labels[i][:5]) else False for i in range(len(labels))]) breaks_2 = np.array([True if (i==0 or labels[i-1][:2]!=labels[i][:2]) else False for i in range(len(labels))]) order = np.load(opj(args.input, "topics_order.npy")) print(order) R = R[order,:] import seaborn as sns import matplotlib from matplotlib import pyplot as plt matplotlib.use("pgf") matplotlib.rcParams.update( { "pgf.texsystem": "xelatex", "font.family": "serif", "font.serif": "Times New Roman", "text.usetex": True, "pgf.rcfonts": False, } ) plt.rcParams["text.latex.preamble"].join([ r"\usepackage{amsmath}", r"\setmainfont{amssymb}", ]) from matplotlib.gridspec import GridSpec plt.clf() fig = plt.figure(figsize=(6.4*2, 6.4*2.5)) gs = GridSpec(9,8,hspace=0,wspace=0) ax_heatmap = fig.add_subplot(gs[0:8,2:8]) sns.heatmap( R.T, cmap="RdBu", vmin=-0.5, vmax=+0.5, square=False, ax=ax_heatmap, cbar_kws={"shrink": 0.25} ) ax_heatmap.invert_yaxis() ax_heatmap.yaxis.set_visible(False) topics = pd.read_csv(opj(args.input, "topics.csv")) topics = topics[~topics["label"].str.contains("Junk")]["label"].tolist() ordered_topics = [topics[i] for i in order] ax_heatmap.set_xticklabels(ordered_topics, rotation = 90) xmin, xmax = ax_heatmap.get_xlim() ymin, ymax = ax_heatmap.get_ylim() ax_breaks = fig.add_subplot(gs[0:8,0:2]) ax_breaks.set_xlim(0,8) ax_breaks.set_ylim(ymin,ymax) ax_breaks.axis("off") import re html = open(opj(args.dataset, "pacs.html"), "r").read() # Regex pattern to extract the code and label pattern = r'(\d{2}\.\d{2}\.\d{2})\s*(.*?)' matches = re.findall(pattern, html) matches = { matches[i][0][:2]: matches[i][1] for i in range(len(matches)) } import textwrap prev_break = 0 for i in range(len(breaks_1)): if breaks_1[i] == True: ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125/2) if breaks_2[i] == True or i == len(breaks_1)-1: ax_breaks.hlines(y=i-0.5, xmin=2, xmax=10, color="lightgray", lw=0.125) ax_heatmap.hlines(y=i-0.5, xmin=xmin, xmax=xmax, color="lightgray", lw=0.125) print(prev_break, i, labels[i], prev_break+(i-prev_break)/2.0) if prev_break != i: text = textwrap.shorten(matches[labels[i-1][:2]], width=35*3) lines = textwrap.wrap(text, width=35) too_small = len(lines) >= 0.35*(i-prev_break-2) if too_small: text = textwrap.shorten(matches[labels[i-1][:2]], width=35*2) lines = textwrap.wrap(text, width=35) too_small = len(lines) >= 0.35*(i-prev_break-2) if too_small: text = textwrap.shorten(matches[labels[i-1][:2]], width=35) lines = textwrap.wrap(text, width=35) too_small = len(lines) >= 0.5*(i-prev_break-2) if (not too_small) or i == len(breaks_1)-1: ax_breaks.text( 7.5, prev_break+(i-prev_break)/2.0, "\n".join(lines), ha="right", va="center", fontsize=8 ) prev_break = i plt.savefig(opj(args.input, "pacs_clustermap.eps"), bbox_inches="tight")