ctm.py

from AbstractSemantics.terms import TermExtractor
import pandas as pd
import numpy as np
from os.path import join as opj
from os.path import exists

import itertools
from functools import partial
from collections import defaultdict
import re

import tomotopy as tp
from sklearn.model_selection import train_test_split

import tqdm
import multiprocessing as mp

from matplotlib import pyplot as plt

import argparse
import yaml
import sys
parser = argparse.ArgumentParser('CT Model')
parser.add_argument('location', help='model directory')
parser.add_argument('filter', choices=['categories', 'keywords', 'no-filter'], help='filter type')
parser.add_argument('--values', nargs='+', default=[], help='filter allowed values')
parser.add_argument('--samples', type=int, default=100000)
parser.add_argument('--constant-sampling', type=int, default=0)
parser.add_argument('--reuse-articles', default=False, action="store_true", help="reuse article selection")
parser.add_argument('--nouns', default=False, action="store_true", help="include nouns")
parser.add_argument('--adjectives', default=False, action="store_true", help="include adjectives")
parser.add_argument('--lemmatize', default=False, action="store_true", help="lemmatize tokens")
parser.add_argument('--remove-latex', default=False, action="store_true", help="remove latex")
parser.add_argument('--limit-redundancy', default=False, action="store_true", help="limit redundancy")
parser.add_argument('--add-title', default=False, action="store_true", help="include title")
parser.add_argument('--top-unithood', type=int, default=20000, help='top unithood filter')
parser.add_argument('--min-token-length', type=int, default=0, help='minimum token length')
parser.add_argument('--min-df', type=int, default=0, help='min_df')
# parser.add_argument('--top-termhood', type=int, default=15000, help='top termhood filter')
parser.add_argument('--reload-model', default=False, action="store_true", help="reload saved model")
parser.add_argument('--reuse-stored-vocabulary', default=False, action='store_true')
parser.add_argument('--compute-best-params', action='store_true', help='optimize hyperparameters (maximizing C_v)', required=False)
parser.add_argument('--reuse-best-params', action='store_true', help='re-use optimal hyperparameters', required=False)
parser.add_argument('--topics', type=int, default=8, help='topics')
parser.add_argument('--alpha', default=0.1, type=float, help='LDA alpha prior')
parser.add_argument('--eta', default=0.01, type=float, help='LDA beta (eta) prior')
parser.add_argument('--threads', type=int, default=4)

args = parser.parse_args()
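# Example invocation (hypothetical values, for illustration only):
#   python ctm.py output/ categories --values Phenomenology-HEP --samples 50000 --nouns --adjectives --topics 25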
if __name__ == "__main__":
    with open(opj(args.location, "params.yml"), "w+") as fp:
        yaml.dump(args, fp)

    articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["title", "abstract", "article_id", "date_created", "categories"]]

    if args.add_title:
        articles["abstract"] = articles["abstract"].str.cat(articles["title"])

    articles.drop(columns=["title"], inplace=True)

    if args.remove_latex:
        # strip inline LaTeX delimited by $...$
        articles['abstract'] = articles['abstract'].apply(lambda s: re.sub(r'\$[^>]+\$', '', s))

    articles = articles[articles["abstract"].map(len) >= 100]
    articles["abstract"] = articles["abstract"].str.lower()

    articles = articles[articles["date_created"].str.len() >= 4]
    articles["year"] = articles["date_created"].str[:4].astype(int) - 1980
    articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
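    # Years are stored as offsets from 1980; only articles created between 1980 and 2020 are kept.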
    if args.reuse_articles:
        used = pd.read_csv(opj(args.location, 'articles.csv'))
        articles = articles[articles["article_id"].isin(used["article_id"])]
    else:
        articles = articles[~articles["abstract"].isnull()]

        if args.constant_sampling > 0:
            articles = articles.groupby("year").head(args.constant_sampling)

        keep = pd.Series([False]*len(articles), index=articles.index)

        print("Applying filter...")
        if args.filter == 'keywords':
            for value in args.values:
                keep |= articles["abstract"].str.contains(value)
        elif args.filter == 'categories':
            for value in args.values:
                keep |= articles["categories"].apply(lambda l: value in l)

        articles = articles[keep == True]
        articles = articles.sample(frac=1).head(args.samples)
        articles[["article_id"]].to_csv(opj(args.location, 'articles.csv'))

    articles.reset_index(inplace=True)
  77. print("Extracting n-grams...")
  78. extractor = TermExtractor(articles["abstract"].tolist(), limit_redundancy=args.limit_redundancy)
  79. if args.nouns:
  80. extractor.add_patterns([["NN.*"]])
  81. if args.adjectives:
  82. extractor.add_patterns([["^JJ$"]])
  83. ngrams = extractor.ngrams(threads=args.threads,lemmatize=args.lemmatize)
  84. ngrams = map(lambda l: [" ".join(n) for n in l], ngrams)
  85. ngrams = list(ngrams)
  86. articles["ngrams"] = ngrams
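    # At this point `ngrams` holds, for each article, a list of extracted terms as
    # space-joined strings (e.g. a 2-gram becomes "dark matter"; the example term is illustrative).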
  87. print("Deriving vocabulary...")
  88. if not args.reuse_stored_vocabulary:
  89. ngrams_occurrences = defaultdict(int)
  90. ngrams_cooccurrences = defaultdict(int)
  91. termhood = defaultdict(int)
  92. for ngrams in articles["ngrams"].tolist():
  93. _ngrams = set(ngrams)
  94. for ngram in _ngrams:
  95. ngrams_occurrences[ngram] += 1
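        # Unithood score computed below: log(1 + n_words) * document frequency / corpus size,
        # since str.count(" ") is the number of words minus one. Longer n-grams are thus
        # slightly favored over unigrams with the same document frequency.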
        ngrams_occurrences = pd.DataFrame(
            {"ngram": ngrams_occurrences.keys(), "count": ngrams_occurrences.values()}
        )
        ngrams_occurrences["unithood"] = (
            np.log(2 + ngrams_occurrences["ngram"].str.count(" "))
            * ngrams_occurrences["count"]
        )
        ngrams_occurrences["unithood"] /= len(articles)
        ngrams_occurrences.set_index("ngram", inplace=True)

        ngrams_occurrences["len"] = ngrams_occurrences.index.map(len)
        ngrams_occurrences = ngrams_occurrences[ngrams_occurrences["len"] > 1]

        top_unithood = ngrams_occurrences.sort_values("unithood", ascending=False).head(
            args.top_unithood
        )
        top = top_unithood
        top.to_csv(opj(args.location, "ngrams.csv"))

    selected_ngrams = set(pd.read_csv(opj(args.location, 'ngrams.csv'))['ngram'].tolist())
    ngrams = articles["ngrams"].tolist()
    ngrams = [[ngram for ngram in _ngrams if ngram in selected_ngrams] for _ngrams in ngrams]

    training_ngrams, validation_ngrams = train_test_split(ngrams, train_size=0.9)

    print("Creating tomotopy corpora...")
    training_corpus = tp.utils.Corpus()
    for doc in training_ngrams:
        training_corpus.add_doc(words=doc)

    validation_corpus = tp.utils.Corpus()
    for doc in validation_ngrams:
        validation_corpus.add_doc(words=doc)
    if args.compute_best_params:
        topics = list(range(25, 100, 25)) + list(range(100, 200, 50))
        alphas = np.logspace(-2, 0, 3, True)
        etas = np.logspace(-3, -1, 3, True)
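        # Grid searched below: k in {25, 50, 75, 100, 150}, alpha in {0.01, 0.1, 1.0},
        # eta in {0.001, 0.01, 0.1}, i.e. 45 combinations in total.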
        model_results = {
            'topics': [],
            'alphas': [],
            'etas': [],
            'u_mass': [],
            'c_uci': [],
            'c_npmi': [],
            'c_v': [],
            'train_ll_per_word': [],
            'validation_ll': [],
            'documents': [],
            'words': [],
            'perplexity': [],
            'train_perplexity': []
        }

        try:
            done = pd.read_csv(opj(args.location, 'lda_tuning_results.csv'))
            model_results = done.to_dict(orient="list")
            print(model_results)
        except Exception as e:
            print(e)
            done = None
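        # If lda_tuning_results.csv already exists, previous results are reloaded so that
        # (topics, alpha, eta) combinations that were already evaluated are skipped below.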
        with tqdm.tqdm(total=len(topics)*len(alphas)*len(etas)) as pbar:
            for k in topics:
                for alpha in alphas:
                    # alpha = alpha*10/k
                    for eta in etas:
                        print(k, alpha, eta)

                        is_done = done is not None and len(done[(done["topics"] == k) & (done["alphas"] == alpha) & (done["etas"] == eta)]) > 0
                        if is_done:
                            print("already done")
                            continue

                        try:
                            mdl = tp.CTModel(
                                tw=tp.TermWeight.ONE,
                                corpus=training_corpus,
                                k=k,
                                min_df=3,
                                smoothing_alpha=alpha,
                                eta=eta
                            )
                            mdl.train(0)

                            prev_ll_per_word = None
                            for _ in range(0, 100, 10):
                                mdl.train(10)
                                print('Iteration: {:05}\tll per word: {:.5f}'.format(mdl.global_step, mdl.ll_per_word))

                                if prev_ll_per_word is not None and prev_ll_per_word > mdl.ll_per_word:
                                    print("stopping here")
                                    break
                                else:
                                    prev_ll_per_word = mdl.ll_per_word
                        except:
                            print("failed")
                            pbar.update(1)
                            continue
                        for preset in ('u_mass', 'c_uci', 'c_npmi', 'c_v'):
                            coh = tp.coherence.Coherence(mdl, coherence=preset)
                            average_coherence = coh.get_score()
                            model_results[preset].append(average_coherence)
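                        # Held-out perplexity: exp(-sum of per-document log-likelihoods / total
                        # number of validation tokens), compared against the training perplexity
                        # reported by tomotopy.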
                        res, total_ll = mdl.infer(validation_corpus, together=True)
                        _ll = np.array([doc.get_ll() for doc in res])
                        words = np.array([len(doc.words) for doc in res])

                        perplexity = np.exp(-np.sum(total_ll)/np.sum(words))
                        print(perplexity, mdl.perplexity)
                        print(-np.sum(total_ll)/np.sum(words), np.log(mdl.perplexity), -np.sum(total_ll)/np.sum(words)/np.log(mdl.perplexity))
                        # print(total_ll, _ll)

                        print(f"Topics: {k}, Perplexity: {perplexity}")
                        print(mdl.ll_per_word)
                        print(mdl.perplexity)
                        print(mdl.num_words)

                        model_results['train_ll_per_word'].append(mdl.ll_per_word)
                        model_results['validation_ll'].append(np.sum(total_ll))
                        model_results['documents'].append(len(res))
                        model_results['words'].append(np.sum(words))
                        model_results['perplexity'].append(perplexity)
                        model_results['train_perplexity'].append(mdl.perplexity)
                        model_results['topics'].append(k)
                        model_results['alphas'].append(alpha)
                        model_results['etas'].append(eta)

                        pd.DataFrame(model_results).to_csv(opj(args.location, 'lda_tuning_results.csv'), index=False)
                        pbar.update(1)
    params = {'topics': args.topics}

    if not args.reload_model:
        print("Training LDA...")

        min_df = args.min_df
        print(min_df)

        mdl = tp.CTModel(
            tw=tp.TermWeight.ONE,
            corpus=training_corpus,
            k=params['topics'],
            min_df=min_df,
            smoothing_alpha=args.alpha,
            eta=args.eta
        )
        mdl.train(0)

        print('Num docs:', len(mdl.docs), ', Vocab size:', len(mdl.used_vocabs), ', Num words:', mdl.num_words)
        print('Removed top words:', mdl.removed_top_words)
        print('Training...', file=sys.stderr, flush=True)

        for _ in range(0, 250, 10):
            mdl.train(10)
            print('Iteration: {:05}\tll per word: {:.5f}'.format(mdl.global_step, mdl.ll_per_word))
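        # The block below exports an interactive pyLDAvis report; document-topic distributions
        # are renormalized so that each row sums to one before being passed to pyLDAvis.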
        import pyLDAvis

        topic_term_dists = np.stack([mdl.get_topic_word_dist(k) for k in range(mdl.k)])
        doc_topic_dists = np.stack([doc.get_topic_dist() for doc in mdl.docs])
        doc_topic_dists /= doc_topic_dists.sum(axis=1, keepdims=True)
        doc_lengths = np.array([len(doc.words) for doc in mdl.docs])
        vocab = list(mdl.used_vocabs)
        term_frequency = mdl.used_vocab_freq

        prepared_data = pyLDAvis.prepare(
            topic_term_dists,
            doc_topic_dists,
            doc_lengths,
            vocab,
            term_frequency,
            start_index=0,  # tomotopy starts topic ids at 0, pyLDAvis at 1
            sort_topics=False  # IMPORTANT: otherwise topic ids between pyLDAvis and tomotopy do not match!
        )
        pyLDAvis.save_html(prepared_data, opj(args.location, 'ldavis.html'))

        print('Saving...', file=sys.stderr, flush=True)
        mdl.save(opj(args.location, "model"), True)
    else:
        print("Loading pre-trained model...")
        mdl = tp.CTModel.load(opj(args.location, "model"))

    mdl.summary()

    # extract candidates for auto topic labeling
    extractor = tp.label.PMIExtractor(min_cf=10, min_df=5, max_len=5, max_cand=10000)
    cands = extractor.extract(mdl)

    labeler = tp.label.FoRelevance(mdl, cands, min_df=5, smoothing=1e-2, mu=0.25)
    for k in range(mdl.k):
        print("== Topic #{} ==".format(k))
        print("Labels:", ', '.join(label for label, score in labeler.get_topic_labels(k, top_n=5)))
        for word, prob in mdl.get_topic_words(k, top_n=10):
            print(word, prob, sep='\t')
        print()

    for preset in ('u_mass', 'c_uci', 'c_npmi', 'c_v'):
        coh = tp.coherence.Coherence(mdl, coherence=preset)
        average_coherence = coh.get_score()
        coherence_per_topic = [coh.get_score(topic_id=k) for k in range(mdl.k)]
        print('==== Coherence: {} ===='.format(preset))
        print('Average:', average_coherence, '\nPer Topic:', coherence_per_topic)
        print()
  268. print("Applying model...")
  269. used_vocab = set(mdl.used_vocabs)
  270. articles["ngrams"] = ngrams
  271. articles = articles[articles["ngrams"].map(len) > 0]
  272. articles = articles[articles["ngrams"].map(lambda l: len(set(l)&used_vocab) > 0) == True]
  273. ngrams = articles["ngrams"].tolist()
  274. corpus = tp.utils.Corpus()
  275. for doc in ngrams:
  276. corpus.add_doc(words=doc)
  277. test_result_cps, ll = mdl.infer(corpus)
  278. topic_dist = []
  279. for i, doc in enumerate(test_result_cps):
  280. print(i, doc)
  281. dist = doc.get_topic_dist()
  282. topic_dist.append(dist)
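    # Document-topic distributions are written to a fresh topics_{n}.parquet file
    # (n is incremented so that previous runs are not overwritten).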
    n = 0
    while exists(opj(args.location, f"topics_{n}.parquet")):
        n += 1

    path = opj(args.location, f"topics_{n}.parquet")

    articles["probs"] = topic_dist
    articles["topics"] = articles["probs"].map(lambda l: ",".join(list(map('{:.6f}'.format, l))))
    articles[["year", "article_id", "topics", "probs"]].to_parquet(path, index=False)
    try:
        descriptions = pd.read_csv(opj(args.location, "descriptions.csv")).set_index("topic")
    except:
        descriptions = None

    cumprobs = np.zeros((42, mdl.k))
    counts = np.zeros(42)

    for year, _articles in articles.groupby("year"):
        print(year)
        for article in _articles.to_dict(orient='records'):
            for topic, prob in enumerate(article['probs']):
                cumprobs[year, topic] += prob

        counts[year] = len(_articles)

    cumprobs.dump(opj(args.location, 'cumsprobs.npy'))
    counts.dump(opj(args.location, 'counts.npy'))
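    # Two summary plots: the absolute yearly mass of each topic (sum over documents of p(t|d)),
    # and the same quantity normalized by the number of articles per year.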
    lines = ['-', '--', '-.', ':', 'dotted', (0, (1, 10)), (0, (3, 10, 1, 10)), (0, (5, 10)), (0, (3, 1, 1, 1, 1, 1)), '-', '--']

    for topic in range(mdl.k):
        plt.plot(
            1980 + np.arange(42),
            cumprobs[:, topic],
            linestyle=lines[topic//7],
            label=topic if descriptions is None else descriptions.loc[topic, "description"]
        )

    plt.title("Absolute magnitude of supersymmetry research topics")
    plt.ylabel("Estimated amount of articles\n($\\sum_{d_i \\in \\mathrm{year}} p(t|d_i)$)")
    plt.xlim(1980, 2018)
    plt.legend(fontsize='x-small')
    plt.savefig(opj(args.location, "topics_count.pdf"))
    plt.clf()
    for topic in range(mdl.k):
        plt.plot(
            1980 + np.arange(42),
            cumprobs[:, topic]/counts,
            linestyle=lines[topic//7],
            label=topic if descriptions is None else descriptions.loc[topic, "description"]
        )

    plt.title("Relative magnitude of supersymmetry research topics")
    plt.ylabel("Probability of each topic throughout years\n($p(t|\\mathrm{year})$)")
    plt.xlim(1980, 2018)
    plt.legend(fontsize='x-small')
    plt.savefig(opj(args.location, "topics_probs.pdf"))
    plt.clf()