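"""Assign every token occurrence in the corpus to a single topic and
accumulate keyword/topic counts per publication year.

The model is loaded from a pickled `etm_instance` (presumably an Embedded
Topic Model, judging by the name); only its `get_topics`,
`get_document_topic_dist` and `get_topic_word_dist` methods are used here.
"""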
import argparse
import pickle
from os.path import join as opj

import numpy as np
import pandas as pd
from scipy.stats import entropy
from tqdm import trange

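# `--input` must point at a model output directory containing ngrams.csv,
# dataset.pickle, etm_instance.pickle and articles.csv (all read below);
# `--dataset` points at the dump that contains articles.parquet.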
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--write-topics", action="store_true", default=False)
parser.add_argument("--debug", action="store_true", default=False)
args = parser.parse_args()

ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

with open(opj(args.input, "dataset.pickle"), "rb") as handle:
    data = pickle.load(handle)

with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)
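# `data` holds the corpus in indexed bag-of-words form: data["tokens"][d] are
# the vocabulary indices occurring in document d, and data["counts"][d] the
# matching occurrence counts (this is how both are used in the loop below).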

articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created"]]
articles = articles[articles["date_created"].str.len() >= 4]
if "year" not in articles.columns:
    # Store the publication year as an offset from 2000 so that it can be
    # used directly as an index into the `keywords` array below.
    articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
# The merged rows are assumed to stay aligned with the model's documents:
# `years[d]` is indexed with the document index `d` below.
years = articles["year"].values

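# Topic-label workflow: run once with --write-topics to dump each topic's top
# words into topics.csv with an empty "label" column, fill the labels in
# (apparently by hand, marking uninformative topics as "Junk"), then re-run
# without the flag so the labels are picked up below.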
if args.write_topics:
    top_words = [",".join(words) for words in etm_instance.get_topics(20)]
    topics = pd.DataFrame({"top_words": top_words})
    topics["label"] = ""
    topics.to_csv(opj(args.input, "topics.csv"))

topics = pd.read_csv(opj(args.input, "topics.csv"))
theta = etm_instance.get_document_topic_dist()
# Treat unlabeled topics as regular ones (na=False guards against the NaN
# labels of a freshly written, not-yet-annotated topics.csv).
is_junk_topic = np.array(topics["label"].str.contains("Junk", na=False))

if args.debug:
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Sanity check: correlations between topic weights across documents.
    sns.heatmap(
        np.corrcoef(theta, rowvar=False), vmin=-0.5, vmax=0.5, cmap="RdBu"
    )
    plt.show()

# theta has shape (n_docs, n_topics) and holds P(z|d); p_w_z has shape
# (n_topics, n_vocab) and holds P(w|z), as the indexing below implies.
topic_counts = np.zeros((theta.shape[0], theta.shape[1]))
p_w_z = etm_instance.get_topic_word_dist()
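
# The matrix product below marginalizes over topics:
#   P(w|d) = sum_z P(z|d) * P(w|z)
# so p_w_d[d, w] is the model's probability of word w in document d; it is the
# normalizing constant of the per-token posterior computed in the main loop.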
- print("Computing P(w|d) matrix...")
- p_w_d = theta @ p_w_z

# Per-year keyword/topic counts, indexed as (year offset, word index, topic).
keywords = np.zeros((articles.year.max() + 1, p_w_z.shape[1], p_w_z.shape[0]))

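# For each document: (1) compute each token's posterior topic distribution,
# (2) skip tokens that are ambiguous or fall into a junk topic, (3) tally the
# remaining tokens into topic_counts, then (4) credit every keyword of the
# document with the document's topic tally for its publication year.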
for d in trange(theta.shape[0]):
    for i, w in enumerate(data["tokens"][d]):
        # Posterior topic distribution of this token occurrence (Bayes' rule):
        #   P(z|w,d) = P(w|z) * P(z|d) / P(w|d)
        p = p_w_z[:, w] * theta[d, :] / p_w_d[d, w]
        # Perplexity of the posterior: the effective number of topics the
        # token could plausibly have been drawn from.
        S = np.exp(entropy(p))
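        # Tokens whose posterior is spread over two or more effective topics
        # are discarded as ambiguous. For instance, a uniform posterior over
        # two topics has entropy ln 2, hence perplexity exp(ln 2) = 2
        # (rejected), while a (0.9, 0.1) posterior has perplexity ~1.38 (kept).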
        if S >= 2:
            if args.debug:
                word = ngrams.iloc[w]["ngram"]
                print(f"{word} is ambiguous, entropy={S:.2f}")
            continue
        k = np.argmax(p)
        if is_junk_topic[k]:
            continue
        topic_counts[d, k] += data["counts"][d][i]
        if args.debug:
            word = ngrams.iloc[w]["ngram"]
            print(word, topics.iloc[k]["label"], data["counts"][d][i])

    n_words = topic_counts[d, :].sum()
    if args.debug:
        print(n_words)
    if n_words == 0:
        continue
    # Credit each keyword occurring in the document with the document's full
    # topic tally for its publication year.
    for w in data["tokens"][d]:
        keywords[years[d], w, :] += topic_counts[d, :]

print(topic_counts)
print(topic_counts.mean(axis=0))
print(topic_counts.sum(axis=0))

np.save(opj(args.input, "keywords.npy"), keywords)
np.save(opj(args.input, "topics_counts.npy"), topic_counts)