etm.py

from AbstractSemantics.terms import TermExtractor
from AbstractSemantics.embeddings import GensimWord2Vec
from gensim.models import KeyedVectors
import nltk
import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix
from os.path import join as opj
from os.path import exists
import itertools
from functools import partial
from collections import defaultdict
import re
import multiprocessing as mp
# from matplotlib import pyplot as plt
import argparse
import yaml
import sys
import pickle
from gensim.models.callbacks import CallbackAny2Vec
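
# MonitorCallback prints training progress: gensim's get_latest_training_loss() is
# cumulative, so the callback below reports the per-epoch delta, plus the nearest
# neighbours of a few probe words after each epoch. Note that most_similar() raises
# a KeyError if a probe word was pruned by min_count.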
class MonitorCallback(CallbackAny2Vec):
    def __init__(self, test_words):
        self._test_words = test_words
        self.epoch = 0

    def on_epoch_end(self, model):
        loss = model.get_latest_training_loss()
        if self.epoch == 0:
            print("Loss after epoch {}: {}".format(self.epoch, loss))
        else:
            print(
                "Loss after epoch {}: {}".format(
                    self.epoch, loss - self.loss_previous_step
                )
            )
        self.epoch += 1
        self.loss_previous_step = loss
        for word in self._test_words:  # show wv logic changes
            print(f"{word}: {model.wv.most_similar(word)}")


def filter_ngrams(l, wl):
    return [ngram for ngram in l if ngram in wl]
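
# e.g. filter_ngrams(["dark_matter", "the", "higgs_boson"], wl={"dark_matter", "higgs_boson"})
# returns ["dark_matter", "higgs_boson"]; `wl` can be any container supporting `in`
# (illustrative values, not taken from the actual vocabulary).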


def construct_bow(l, n):
    items = list(set(l))
    return np.array(items), np.array([l.count(i) for i in items])
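
# e.g. construct_bow(["quark", "gluon", "quark"], n=3) returns a pair of arrays
# (unique items, counts), here (["quark", "gluon"], [2, 1]) up to set() ordering;
# the `n` argument is accepted but unused (illustrative example, not project data).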


def ngram_inclusion(i, js):
    return [
        j.find(i) >= 0  # matching
        and bool(re.search(rf"(^|_){re.escape(i)}($|_)", j))
        and (j.count("_") == i.count("_") + 1)
        and (
            (i.count("_") >= 1)
            or bool(re.search(rf"(^|_){re.escape(i)}$", j))
            or bool(re.search(rf"^{re.escape(i)}($|_)", j))
        )
        for j in js
    ]
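
# e.g. ngram_inclusion("dark_matter", ["dark_matter_halo", "cold_dark_matter", "darkness"])
# returns [True, True, False]: it flags the n-grams of `js` that contain `i` with exactly
# one extra token. Only used by the redundancy-removal block, which is commented out below
# (illustrative values).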


if __name__ == "__main__":
    parser = argparse.ArgumentParser("CT Model")
    parser.add_argument("location", help="model directory")
    parser.add_argument(
        "filter", choices=["categories", "keywords", "no-filter"], help="filter type"
    )
    parser.add_argument("--values", nargs="+", default=[], help="filter allowed values")
    parser.add_argument("--dataset", default="inspire-harvest/database")

    # sample size
    parser.add_argument("--samples", type=int, default=50000)
    parser.add_argument("--constant-sampling", type=int, default=0)

    # text pre-processing
    parser.add_argument(
        "--add-title", default=False, action="store_true", help="include title"
    )
    parser.add_argument(
        "--remove-latex", default=False, action="store_true", help="remove latex"
    )
    parser.add_argument(
        "--lemmatize", default=False, action="store_true", help="lemmatize"
    )
    parser.add_argument(
        "--limit-redundancy",
        default=False,
        action="store_true",
        help="limit redundancy",
    )
    parser.add_argument("--blacklist", default=None, help="blacklist")

    # embeddings
    parser.add_argument("--dimensions", type=int, default=50)
    parser.add_argument("--pre-trained-embeddings", default=False, action="store_true")
    parser.add_argument("--use-saved-embeddings", default=False, action="store_true")

    # topic model parameters
    parser.add_argument("--topics", type=int, default=25)
    parser.add_argument("--min-df", type=float, default=0.001)
    parser.add_argument("--max-df", type=float, default=0.15)
    parser.add_argument("--threads", type=int, default=4)
    args = parser.parse_args(
        [
            "output/etm_25_pretrained",
            "categories",
            "--values",
            "Theory-HEP",
            "Phenomenology-HEP",
            "--dataset",
            "../inspire-harvest/database",
            "--constant-sampling",
            "30000",
            "--samples",
            "300000",
            "--threads",
            "24",
            "--add-title",
            "--remove-latex",
            "--dimensions",
            "50",
            "--topics",
            "25",
            "--min-df",
            "0.00075",
            "--lemmatize",
            "--pre-trained-embeddings",
            # "--limit-redundancy"
            "--use-saved-embeddings"
            # "--blacklist",
            # "output/medialab/blacklist",
        ]
    )
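
    # NOTE: the explicit list above overrides whatever is passed on the command line;
    # replace it with parser.parse_args() to drive the script from the CLI.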

    with open(opj(args.location, "params.yml"), "w+") as fp:
        yaml.dump(args, fp)

    articles = pd.read_parquet(
        opj(args.dataset, "articles.parquet")
    )[["title", "abstract", "article_id", "date_created", "categories"]]

    if args.add_title:
        articles["abstract"] = articles["abstract"].str.cat(articles["title"], sep=". ")

    articles.drop(columns=["title"], inplace=True)

    if args.remove_latex:
        # strip inline TeX between dollar signs and bare \commands, then drop
        # remaining special characters
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\$[^>]+\$", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\b\\\w+", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"[^0-9a-zA-Z--- -\.]+", "", s)
        )

    # articles["abstract"] = articles["abstract"].str.replace("-", " ")  # NEW

    articles = articles[articles["abstract"].map(len) >= 100]
    articles["abstract"] = articles["abstract"].str.lower()

    articles = articles[articles["date_created"].str.len() >= 4]
    if "year" not in articles.columns:
        articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
        articles = articles[(articles["year"] >= 0) & (articles["year"] <= 20)]
    else:
        articles["year"] = articles["year"].astype(int)
        articles = articles[articles["year"] >= 2000]

    articles["year_group"] = articles["year"] // 5

    keep = pd.Series([False] * len(articles), index=articles.index)

    print("Applying filter...")
    if args.filter == "keywords":
        for value in args.values:
            keep |= articles["abstract"].str.contains(value)
    elif args.filter == "categories":
        for value in args.values:
            keep |= articles["categories"].apply(lambda l: value in l)
    elif args.filter == "no-filter":
        keep |= True

    articles = articles[keep].sample(frac=1)
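
    # --constant-sampling caps each publication year at that many articles before the
    # global --samples cap below is applied.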
    if args.constant_sampling > 0:
        articles = articles.groupby("year").head(args.constant_sampling)

    articles = articles.sample(frac=1).head(args.samples)
    articles.reset_index(inplace=True)
    articles[["article_id"]].to_csv(opj(args.location, "articles.csv"))

    print(articles)
  170. print("Extracting n-grams...")
  171. extractor = TermExtractor(
  172. articles["abstract"].tolist(),
  173. limit_redundancy=args.limit_redundancy,
  174. patterns=[
  175. ["JJ.*"],
  176. ["NN.*"],
  177. ["JJ.*", "NN.*"],
  178. ["JJ.*", "NN.*", "NN.*"],
  179. # ["JJ.*", "NN", "CC", "NN.*"],
  180. # ["JJ.*", "NN.*", "JJ.*", "NN.*"],
  181. # ["RB.*", "JJ.*", "NN.*", "NN.*"],
  182. ],
  183. )
  184. ngrams = extractor.ngrams(
  185. threads=args.threads,
  186. lemmatize=args.lemmatize,
  187. lemmatize_ngrams=args.lemmatize,
  188. split_sentences=args.pre_trained_embeddings and not args.use_saved_embeddings,
  189. )
  190. del extractor
  191. del articles["abstract"]
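
    # When pre-training embeddings, n-grams are kept sentence-by-sentence so Word2Vec sees
    # sentence-level contexts; otherwise each article becomes a single flat list of n-grams.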
    if args.pre_trained_embeddings and not args.use_saved_embeddings:
        ngrams = map(
            lambda l: [
                [
                    ("_".join(n))
                    .strip()
                    .replace("-", "_")
                    .replace(".._", "")
                    .replace("_..", "")
                    for n in sent
                ]
                for sent in l
            ],
            ngrams,
        )
        ngrams = list(ngrams)

        print("Pre-training embeddings...")
        emb = GensimWord2Vec(
            [sentence for sentences in ngrams for sentence in sentences]
        )
        model = emb.create_model(
            vector_size=args.dimensions,
            window=5,
            workers=args.threads,
            compute_loss=True,
            # epochs=90,
            # min_count=30,
            epochs=80,
            min_count=15,
            sg=1,
            callbacks=[
                MonitorCallback(
                    [
                        "transformer",
                        "embedding",
                        "syntax",
                        "grammar"
                    ]
                )
            ],
        )
        model.wv.save_word2vec_format(opj(args.location, "embeddings.bin"), binary=True)
        del model

        ngrams = [
            list(itertools.chain.from_iterable(article_sentences))
            for article_sentences in ngrams
        ]
    else:
        ngrams = map(
            lambda l: [
                "_".join(n)
                .strip()
                .replace("-", "_")
                .replace(".._", "")
                .replace("_..", "")
                for n in l
            ],
            ngrams,
        )
        ngrams = list(ngrams)
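
    # At this point each n-gram is a single underscore-joined token, e.g. ("dark", "matter")
    # -> "dark_matter", with hyphens also mapped to underscores ("two-loop" -> "two_loop");
    # the examples are illustrative, the actual tuples depend on the POS patterns above.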
  252. print("Deriving vocabulary...")
  253. voc = defaultdict(int)
  254. for article_ngrams in ngrams:
  255. _ngrams = set(article_ngrams)
  256. for ngram in _ngrams:
  257. voc[ngram] += 1
  258. voc = pd.DataFrame({"ngram": voc.keys(), "count": voc.values()})
  259. voc["df"] = voc["count"] / len(articles)
  260. voc.set_index("ngram", inplace=True)

    if args.min_df < 1:
        voc = voc[voc["df"] >= args.min_df]
    else:
        voc = voc[voc["count"] >= args.min_df]

    if args.max_df < 1:
        voc = voc[voc["df"] <= args.max_df]
    else:
        voc = voc[voc["count"] <= args.max_df]
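
    # --min-df / --max-df are read as document-frequency fractions when < 1 and as absolute
    # document counts otherwise; e.g. --min-df 0.00075 keeps n-grams appearing in at least
    # 0.075% of the sampled abstracts.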
  269. voc["len"] = voc.index.map(len)
  270. voc = voc[voc["len"] >= 2]
  271. stop_words = nltk.corpus.stopwords.words("english")
  272. voc = voc[~voc.index.isin(stop_words)]
  273. if args.blacklist is not None:
  274. print("Filtering black-listed keywords...")
  275. blacklist = pd.read_csv(args.blacklist)["ngram"].tolist()
  276. voc = voc[
  277. voc.index.map(lambda s: not any([ngram in s for ngram in blacklist]))
  278. == True
  279. ]
  280. print("Filtering completed.")
  281. voc = voc.sort_values("df", ascending=False)
  282. voc.to_csv(opj(args.location, "ngrams.csv"))
  283. voc = pd.read_csv(opj(args.location, "ngrams.csv"), keep_default_na=False)[
  284. "ngram"
  285. ].tolist()
  286. vocabulary = {n: i for i, n in enumerate(voc)}
  287. print("Filtering n-grams...")
  288. with mp.Pool(processes=args.threads) as pool:
  289. ngrams = pool.map(partial(filter_ngrams, wl=voc), ngrams)
  290. print("Constructing bag-of-words...")
  291. bow = [[vocabulary[ngram] for ngram in _ngrams] for _ngrams in ngrams]
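
    # The commented-out block below is the disabled double-counting removal step: it would
    # build a "within" inclusion matrix with ngram_inclusion and subtract the counts of
    # n-grams already covered by a longer retained n-gram.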
    # if args.limit_redundancy:
    #     print("Building 'within' matrix...")
    #     with mp.Pool(processes=args.threads) as pool:
    #         within = pool.map(partial(ngram_inclusion, js=voc), voc)
    #     within = np.array(within).astype(int)
    #     print("Removing double-counting...")
    #     bow = csr_matrix(bow)
    #     double_counting = bow.dot(csr_matrix(within.T))
    #     bow = bow - double_counting
    #     print(double_counting.sum(), "redundant keywords removed")
    #     del double_counting
    #     bow = bow.todense()
    #     print((bow <= -1).sum(), "keywords had negative counts after removal")

    del ngrams

    with mp.Pool(processes=args.threads) as pool:
        bow = pool.map(partial(construct_bow, n=len(voc)), bow)

    keep = [i for i in range(len(bow)) if len(bow[i][0]) > 0]
    articles = articles.iloc[keep]
    articles[["article_id"]].to_csv(opj(args.location, "articles.csv"))
    bow = [bow[i] for i in keep]

    dataset = {
        "tokens": [bow[i][0] for i in range(len(bow))],
        "counts": [bow[i][1] for i in range(len(bow))],
        "article_id": articles["article_id"],
    }
    del bow

    with open(opj(args.location, "dataset.pickle"), "wb") as handle:
        pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
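
    # The dataset dict follows the tokens/counts bag-of-words layout consumed by
    # etm_instance.fit() below; "article_id" is kept alongside so topic distributions
    # can later be mapped back to articles.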
  320. print("Training...")
  321. from embedded_topic_model.models.etm import ETM
  322. etm_instance = ETM(
  323. voc,
  324. num_topics=args.topics,
  325. rho_size=args.dimensions,
  326. emb_size=args.dimensions,
  327. epochs=25,
  328. debug_mode=True,
  329. train_embeddings=not args.pre_trained_embeddings,
  330. model_path=opj(args.location, "model"),
  331. embeddings=opj(args.location, "embeddings.bin")
  332. if args.pre_trained_embeddings
  333. else None,
  334. use_c_format_w2vec=True,
  335. )
  336. etm_instance.fit(dataset)
  337. with open(opj(args.location, "etm_instance.pickle"), "wb") as handle:
  338. pickle.dump(etm_instance, handle, protocol=pickle.HIGHEST_PROTOCOL)
  339. topics = etm_instance.get_topics(20)
  340. print(topics)
  341. topic_coherence = etm_instance.get_topic_coherence()
  342. print(topic_coherence)
  343. topic_diversity = etm_instance.get_topic_diversity()
  344. print(topic_diversity)