"""Assign each token of each document to its most likely topic.

Reads the fitted ETM instance and the bag-of-words dataset from --input,
computes the per-token topic posterior p(z|w,d), discards ambiguous tokens
(posterior perplexity >= 2) and junk topics, and saves per-document topic
counts together with a per-year keyword/topic tensor.
"""
import argparse
import pickle
from os.path import join as opj

import numpy as np
import pandas as pd
from scipy.stats import entropy
from tqdm import trange

parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--write-topics", action="store_true", default=False)
parser.add_argument("--debug", action="store_true", default=False)
args = parser.parse_args()

ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

with open(opj(args.input, "dataset.pickle"), "rb") as handle:
    data = pickle.load(handle)

with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)

# Publication years, encoded as an offset from 2000.
articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created"]]
articles = articles[articles["date_created"].str.len() >= 4]
if "year" not in articles.columns:
    articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

# Restrict to the articles used for training; row order must match the
# document order of the ETM dataset, since `years[d]` is indexed by document.
_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values

# Optionally dump the top words of each topic so they can be labeled by hand.
if args.write_topics:
    top_words = [",".join(words) for words in etm_instance.get_topics(20)]
    topics = pd.DataFrame({"top_words": top_words})
    topics["label"] = ""
    topics.to_csv(opj(args.input, "topics.csv"))

topics = pd.read_csv(opj(args.input, "topics.csv"))
theta = etm_instance.get_document_topic_dist()  # p(z|d), shape (n_docs, n_topics)
# Unlabeled topics come back from read_csv as NaN; treat them as non-junk.
is_junk_topic = np.array(topics["label"].str.contains("Junk", na=False))

if args.debug:
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Topic-topic correlations across documents (columns of theta).
    sns.heatmap(np.corrcoef(theta, rowvar=False), vmin=-0.5, vmax=0.5, cmap="RdBu")
    plt.show()

topic_counts = np.zeros((theta.shape[0], theta.shape[1]))
p_w_z = etm_instance.get_topic_word_dist()  # p(w|z), shape (n_topics, n_words)
print("Computing P(w|d) matrix...")
p_w_d = theta @ p_w_z  # p(w|d) = sum_z p(w|z) p(z|d)

# Per-year tensor of shape (n_years, n_words, n_topics).
keywords = np.zeros((articles.year.max() + 1, p_w_z.shape[1], p_w_z.shape[0]))

for d in trange(theta.shape[0]):
    for i, w in enumerate(data["tokens"][d]):
        # Bayes: p(z|w,d) = p(w|z) p(z|d) / p(w|d).
        p = p_w_z[:, w] * theta[d, :] / p_w_d[d, w]
        # Perplexity of the posterior: the effective number of candidate topics.
        S = np.exp(entropy(p))
        if S >= 2:
            # Token is spread over two or more topics; skip it.
            if args.debug:
                word = ngrams.iloc[w]["ngram"]
                print(f"{word} is ambiguous, perplexity={S:.2f}")
            continue
        k = np.argmax(p)
        if is_junk_topic[k]:
            continue
        topic_counts[d, k] += data["counts"][d][i]
        if args.debug:
            word = ngrams.iloc[w]["ngram"]
            print(word, topics.iloc[k]["label"], data["counts"][d][i])

    n_words = topic_counts[d, :].sum()
    if args.debug:
        print(n_words)
    if n_words == 0:
        continue

    # Credit each word of the document with the document's topic counts,
    # accumulated by publication year.
    for w in data["tokens"][d]:
        keywords[years[d], w, :] += topic_counts[d, :]

print(topic_counts)
print(topic_counts.mean(axis=0))
print(topic_counts.sum(axis=0))

np.save(opj(args.input, "keywords.npy"), keywords)
np.save(opj(args.input, "topics_counts.npy"), topic_counts)
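
# --- Optional sanity check (a minimal sketch, not part of the original
# pipeline). Assuming each row of theta and each column slice p_w_z[:, w]
# are proper distributions, the per-token posterior
# p(z|w,d) = p(w|z) p(z|d) / p(w|d) computed above must sum to 1; a large
# deviation would point to a mismatch between theta, p_w_z, and the vocabulary.
if args.debug and theta.shape[0] > 0 and len(data["tokens"][0]) > 0:
    w0 = data["tokens"][0][0]  # first token of the first document
    p0 = p_w_z[:, w0] * theta[0, :] / p_w_d[0, w0]
    print(f"posterior normalization check: sum(p(z|w,d)) = {p0.sum():.6f} (expected ~1)")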