#!/usr/bin/env python
# coding: utf-8
"""Topic-vs-experiment-type citation heatmaps.

For each topic model directory passed via --inputs, compute how often
articles whose main topic is T cite experimental papers of each broad
instrument type (colliders, neutrinos, ...), and draw the proportions as
side-by-side heatmaps. The figure is saved as EPS next to the first model.
"""

import argparse
import pickle  # NOTE(review): unused in this chunk; kept in case the rest of the file needs it
from os.path import join as opj

import numpy as np
import pandas as pd

import matplotlib

# BUG FIX: select the backend *before* pyplot is imported so the choice
# reliably takes effect (the original imported pyplot first, twice).
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)

from matplotlib import pyplot as plt
import seaborn as sns  # kept: used by the commented-out heatmap variant below

# BUG FIX: the original called `.join(...)` on the preamble string and
# discarded the result (a no-op — the preamble was never set), and tried to
# load the amssymb *package* with \setmainfont. Assign the joined preamble
# and load amssymb with \usepackage instead.
plt.rcParams["text.latex.preamble"] = "\n".join(
    [
        r"\usepackage{amsmath}",
        r"\usepackage{amssymb}",
    ]
)

parser = argparse.ArgumentParser()
parser.add_argument("--inputs", nargs="+")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--keywords-threshold", type=int, default=200)
parser.add_argument("--articles-threshold", type=int, default=5)
# [2,3] for ACL, [3] for HEP
parser.add_argument("--early-periods", nargs="+", type=int, default=[0, 1])
# [2,3] for ACL, [3] for HEP
parser.add_argument("--late-periods", nargs="+", type=int, default=[3])
parser.add_argument("--fla", action="store_true", help="first or last author")
args = parser.parse_args()

# Filename suffix identifying non-default period choices (empty for defaults).
custom_range = (
    "_" + "-".join(map(str, args.early_periods)) + "_" + "-".join(map(str, args.late_periods))
    if (args.early_periods != [0, 1] or args.late_periods != [3])
    else ""
)
print(custom_range)

# Citation edges between articles: one row per (cites, cited) pair.
references = pd.read_parquet(opj(args.dataset, "articles_references.parquet"))
references["cites"] = references.cites.astype(int)
references["cited"] = references.cited.astype(int)

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[
    ["article_id", "date_created", "title", "accelerators"]
]
articles["article_id"] = articles.article_id.astype(int)
# Keep only articles with at least a 4-character date so the year slice works.
articles = articles[articles["date_created"].str.len() >= 4]
articles["year"] = articles["date_created"].str[:4].astype(int)

# Experimental papers = articles tagged with at least one accelerator code.
experimental = articles[articles["accelerators"].map(len) >= 1]
experimental = experimental.explode("accelerators")
# Collapse "A-B-C" accelerator codes to "A-B" so run/variant suffixes merge.
experimental["accelerators"] = experimental["accelerators"].str.replace(
    "(.*)-(.*)-(.*)$", r"\1-\2", regex=True
)

# Accelerator/experiment code -> broad experiment type.
types = {
    "FNAL-E": "colliders",
    "CERN-LEP": "colliders",
    "DESY-HERA": "colliders",
    "SUPER-KAMIOKANDE": "astro. neutrinos",
    "CERN-NA": "colliders",
    "CESR-CLEO": "colliders",
    "CERN-WA": "colliders",
    "BNL-E": "colliders",
    "KAMIOKANDE": "astro. neutrinos",
    "SLAC-E": "colliders",
    "SLAC-PEP2": "colliders",
    "KEK-BF": "colliders",
    "SNO": "neutrinos",
    "BNL-RHIC": "colliders",
    "WMAP": "cosmic $\\mu$wave background",
    "CERN-LHC": "colliders",
    "PLANCK": "cosmic $\\mu$wave background",
    "BEPC-BES": "colliders",
    "LIGO": "gravitational waves",
    "VIRGO": "gravitational waves",
    "CERN-PS": "colliders",
    "FERMI-LAT": "other cosmic sources",
    "XENON100": "dark matter (direct)",
    "ICECUBE": "astro. neutrinos",
    "LUX": "dark matter (direct)",
    "T2K": "neutrinos",
    "BICEP2": "cosmic $\\mu$wave background",
    "CDMS": "dark matter (direct)",
    "LAMPF-1173": "neutrinos",
    "FRASCATI-DAFNE": "colliders",
    "KamLAND": "neutrinos",
    "SDSS": "other cosmic sources",
    "JLAB-E-89": "colliders",
    "CHOOZ": "neutrinos",
    "XENON1T": "dark matter (direct)",
    "SCP": "supernovae",
    "DAYA-BAY": "neutrinos",
    "HOMESTAKE-CHLORINE": "neutrinos",
    "HIGH-Z": "supernovae",
    "K2K": "neutrinos",
    "MACRO": "other cosmic sources",
    "GALLEX": "neutrinos",
    "SAGE": "neutrinos",
    "PAMELA": "other cosmic sources",
    "CERN-UA": "colliders",
    "CERN SPS": "colliders",
    "DESY-PETRA": "colliders",
    "SLAC-SLC": "colliders",
    "LEPS": "colliders",
    "DOUBLECHOOZ": "neutrinos",
    "AUGER": "other cosmic sources",
    "AMS": "other cosmic sources",
    "DAMA": "dark matter (direct)",
    "DESY-DORIS": "colliders",
    "NOVOSIBIRSK-CMD": "colliders",
    "IMB": "neutrinos",
    "RENO": "neutrinos",
    "SLAC-SP": "colliders",
}

# Display order of experiment types (heatmap rows).
ordered_types = [
    "colliders",
    "neutrinos",
    "astro. neutrinos",
    "dark matter (direct)",
    "cosmic $\\mu$wave background",
    "supernovae",
    "other cosmic sources",
    "gravitational waves",
]

experimental = experimental[experimental["accelerators"].isin(types.keys())]
experimental["type"] = experimental["accelerators"].map(types)


def compute_counts(model, articles):
    """Return topic-citation proportions per experiment type for one model.

    Parameters
    ----------
    model : str
        Directory containing the model's ``articles.csv``, ``topics.csv``
        (with a ``label`` column) and ``topics_counts.npy`` count matrix.
    articles : pandas.DataFrame
        Article metadata to merge onto the model's article list.

    Returns
    -------
    pandas.DataFrame
        Rows: experiment types sorted per ``ordered_types``; columns:
        MultiIndex ``('proportion', <topic label>)`` with, for each type,
        the normalized distribution of citing articles' main topics.

    Relies on the module-level globals ``references``, ``experimental``
    and ``ordered_types``.
    """
    _articles = pd.read_csv(opj(model, "articles.csv"))
    articles = _articles.merge(articles, how="left")
    topics = pd.read_csv(opj(model, "topics.csv"))["label"].tolist()
    topic_matrix = np.load(opj(model, "topics_counts.npy"))
    # Row-normalize topic counts; the np.where guard avoids dividing by
    # zero for articles with an all-zero topic row.
    topic_matrix = topic_matrix / np.where(
        topic_matrix.sum(axis=1) > 0, topic_matrix.sum(axis=1), 1
    )[:, np.newaxis]
    articles["topics"] = list(topic_matrix)
    articles["main_topic"] = topic_matrix.argmax(axis=1)
    articles["main_topic"] = articles["main_topic"].map(lambda x: topics[x])
    # Drop articles whose dominant topic is a junk topic.
    articles = articles[~articles["main_topic"].str.contains("Junk")]
    # Citations from topic-modeled articles to experimental papers.
    citing_experiments = articles.merge(
        references, how="inner", left_on="article_id", right_on="cites"
    )
    citing_experiments = citing_experiments.merge(
        experimental, how="inner", left_on="cited", right_on="article_id"
    )
    counts = (
        citing_experiments.groupby("type")["main_topic"]
        .value_counts(normalize=True)
        .reset_index()
    )
    counts = counts.pivot(index="type", columns="main_topic")
    counts.sort_index(
        key=lambda idx: idx.map(lambda x: ordered_types.index(x)), inplace=True
    )
    return counts


model_names = ["Main model", "$K=15$", "$K=25$"]

fig, axes = plt.subplots(
    nrows=1,
    ncols=3,
    sharey=True,
    figsize=[4.8 * 1.5, 3.2 * 1.5],
    layout="constrained",
)

for i, model in enumerate(args.inputs):
    counts = compute_counts(model, articles.copy())
    print(counts.index)
    topics = counts["proportion"].columns
    values = np.stack(counts["proportion"].values)
    # Keep only topics reaching at least 5% in some experiment type.
    relevant_topics = np.arange(len(topics))[values.max(axis=0) > 0.05]
    topics = topics[relevant_topics]
    # sns.heatmap(counts, ax=axes[i], cmap="Reds")
    im = axes[i].matshow(
        counts["proportion"][topics],
        cmap="Reds",
        vmin=0,
        vmax=1,
        aspect=len(topics) / 15,
    )
    axes[i].set_xlabel(model_names[i])
    # LaTeX-friendly topic labels: escape "and", normalize case, shrink font.
    topics = [topic.replace(" and ", " \\& ").lower().capitalize() for topic in topics]
    topics = [f"\\small{{{topic}}}" for topic in topics]
    axes[i].set_xticks(np.arange(len(topics)), topics, rotation="vertical")
    axes[i].set_xticks(np.arange(len(topics) + 1) - 0.5, minor=True)
    # axes[i].axvline(len(topics)-0.5, color="black", lw=2)
    axes[i].set_xlim(-0.5, len(topics) - 0.5)

fig.colorbar(im, location="bottom")
axes[0].set_yticks(
    np.arange(len(ordered_types)),
    list(map(lambda x: f"\\small{{{x.capitalize()}}}", ordered_types)),
    va="center",
)
# NOTE(review): `i` is the leftover loop index, so only the LAST axis gets
# minor y-ticks; presumably all axes were intended — confirm before changing.
axes[i].set_yticks(np.arange(len(ordered_types) + 1) - 0.5, minor=True)

# p0 = axes[0].get_position()
# p2 = axes[2].get_position()
# ax_cbar = fig.add_axes([p0.x0, 0.08, p2.x1, 0.05])
# plt.colorbar(im, cax=ax_cbar, ticks=np.linspace(0, 1, 3, True), orientation='horizontal')

plt.subplots_adjust(wspace=0, hspace=0)
fig.savefig(opj(args.inputs[0], "topic_experiments.eps"), bbox_inches="tight")