etm.py

from AbstractSemantics.terms import TermExtractor
from AbstractSemantics.embeddings import GensimWord2Vec
from gensim.models import KeyedVectors
import nltk

import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix

from os.path import join as opj
from os.path import exists

import itertools
from functools import partial
from collections import defaultdict
import re
import multiprocessing as mp

# from matplotlib import pyplot as plt

import argparse
import yaml
import sys
import pickle

from gensim.models.callbacks import CallbackAny2Vec
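

# Callback that logs the Word2Vec training loss after each epoch (as a delta from
# the previous epoch, since gensim reports a running total) and prints the nearest
# neighbours of a few probe words to monitor how the embedding space evolves.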
class MonitorCallback(CallbackAny2Vec):
    def __init__(self, test_words):
        self._test_words = test_words
        self.epoch = 0

    def on_epoch_end(self, model):
        loss = model.get_latest_training_loss()
        if self.epoch == 0:
            print("Loss after epoch {}: {}".format(self.epoch, loss))
        else:
            print(
                "Loss after epoch {}: {}".format(
                    self.epoch, loss - self.loss_previous_step
                )
            )
        self.epoch += 1
        self.loss_previous_step = loss

        for word in self._test_words:  # show wv logic changes
            print(f"{word}: {model.wv.most_similar(word)}")


def filter_ngrams(l, wl):
    # Keep only the n-grams that belong to the whitelist `wl`.
    return [ngram for ngram in l if ngram in wl]


def construct_bow(l, n):
    # Turn a list of token ids into (unique ids, counts); `n` (the vocabulary size) is unused.
    items = list(set(l))
    return np.array(items), np.array([l.count(i) for i in items])


def ngram_inclusion(i, js):
    # For each candidate `j`, flag whether the n-gram `i` occurs in `j` at token
    # boundaries with exactly one extra token (used by the commented-out
    # redundancy-removal step below).
    return [
        j.find(i) >= 0  # matching
        and bool(re.search(rf"(^|_){re.escape(i)}($|_)", j))
        and (j.count("_") == i.count("_") + 1)
        and (
            (i.count("_") >= 1)
            or bool(re.search(rf"(^|_){re.escape(i)}$", j))
            or bool(re.search(rf"^{re.escape(i)}($|_)", j))
        )
        for j in js
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser("CT Model")
    parser.add_argument("location", help="model directory")
    parser.add_argument(
        "filter", choices=["categories", "keywords", "no-filter"], help="filter type"
    )
    parser.add_argument("--values", nargs="+", default=[], help="filter allowed values")
    parser.add_argument("--dataset", default="inspire-harvest/database")

    # sample size
    parser.add_argument("--samples", type=int, default=50000)
    parser.add_argument("--constant-sampling", type=int, default=0)

    # text pre-processing
    parser.add_argument(
        "--add-title", default=False, action="store_true", help="include title"
    )
    parser.add_argument(
        "--remove-latex", default=False, action="store_true", help="remove latex"
    )
    parser.add_argument(
        "--lemmatize", default=False, action="store_true", help="lemmatize"
    )
    parser.add_argument(
        "--limit-redundancy",
        default=False,
        action="store_true",
        help="limit redundancy",
    )
    parser.add_argument("--blacklist", default=None, help="blacklist")

    # embeddings
    parser.add_argument("--dimensions", type=int, default=50)
    parser.add_argument("--pre-trained-embeddings", default=False, action="store_true")
    parser.add_argument("--use-saved-embeddings", default=False, action="store_true")

    # topic model parameters
    parser.add_argument("--topics", type=int, default=25)
    parser.add_argument("--min-df", type=float, default=0.001)
    parser.add_argument("--max-df", type=float, default=0.15)

    parser.add_argument("--threads", type=int, default=4)
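
    # NOTE: arguments are hard-coded below for a specific ACL 2002-2022 run;
    # drop the list passed to parse_args() to read them from the command line instead.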
    args = parser.parse_args(
        [
            "output/acl_2002_2022",
            "no-filter",
            "--dataset",
            "../acl",
            "--constant-sampling",
            "12000",
            "--samples",
            "300000",
            "--threads",
            "30",
            "--add-title",
            "--remove-latex",
            "--dimensions",
            "50",
            "--topics",
            "20",
            "--min-df",
            "0.00075",
            "--lemmatize",
            "--pre-trained-embeddings",
            # "--limit-redundancy"
            "--use-saved-embeddings"
            # "--blacklist",
            # "output/medialab/blacklist",
        ]
    )

    with open(opj(args.location, "params.yml"), "w+") as fp:
        yaml.dump(args, fp)

    articles = pd.read_parquet(
        opj(args.dataset, "articles.parquet")
    )[["title", "abstract", "article_id", "date_created", "categories"]]

    if args.add_title:
        articles["abstract"] = articles["abstract"].str.cat(articles["title"], sep=". ")

    articles.drop(columns=["title"], inplace=True)

    if args.remove_latex:
        # strip inline math ($...$), then LaTeX commands, then any leftover special characters
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\$[^$]+\$", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\b\\\w+", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"[^0-9a-zA-Z\- \.]+", "", s)
        )

    # articles["abstract"] = articles["abstract"].str.replace("-", " ")  # NEW

    articles = articles[articles["abstract"].map(len) >= 100]
    articles["abstract"] = articles["abstract"].str.lower()

    articles = articles[articles["date_created"].str.len() >= 4]
    if "year" not in articles.columns:
        articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
        articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
    else:
        articles["year"] = articles["year"].astype(int)
        articles = articles[articles["year"] >= 2002]

    articles["year_group"] = articles["year"] // 5
    keep = pd.Series([False] * len(articles), index=articles.index)

    print("Applying filter...")
    if args.filter == "keywords":
        for value in args.values:
            keep |= articles["abstract"].str.contains(value)
    elif args.filter == "categories":
        for value in args.values:
            keep |= articles["categories"].apply(lambda l: value in l)
    elif args.filter == "no-filter":
        keep |= True

    articles = articles[keep == True].sample(frac=1)

    if args.constant_sampling > 0:
        articles = articles.groupby("year").head(args.constant_sampling)

    articles = articles.sample(frac=1).head(args.samples)
    articles.reset_index(inplace=True)

    articles[["article_id"]].to_csv(opj(args.location, "articles.csv"))

    print(articles)
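
    # The extraction patterns below are Penn-Treebank POS-tag regexes: candidate
    # terms are adjective/noun sequences of one to three tokens.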
  167. print("Extracting n-grams...")
  168. extractor = TermExtractor(
  169. articles["abstract"].tolist(),
  170. limit_redundancy=args.limit_redundancy,
  171. patterns=[
  172. ["JJ.*"],
  173. ["NN.*"],
  174. ["JJ.*", "NN.*"],
  175. ["JJ.*", "NN.*", "NN.*"],
  176. # ["JJ.*", "NN", "CC", "NN.*"],
  177. # ["JJ.*", "NN.*", "JJ.*", "NN.*"],
  178. # ["RB.*", "JJ.*", "NN.*", "NN.*"],
  179. ],
  180. )
  181. ngrams = extractor.ngrams(
  182. threads=args.threads,
  183. lemmatize=args.lemmatize,
  184. lemmatize_ngrams=args.lemmatize,
  185. split_sentences=args.pre_trained_embeddings and not args.use_saved_embeddings,
  186. )
  187. del extractor
  188. del articles["abstract"]
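
    # When embeddings are pre-trained here, n-grams stay grouped by sentence so that
    # Word2Vec trains on sentence-level contexts, and are flattened back into one list
    # per article afterwards; otherwise each article is already a flat list of n-grams.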
    if args.pre_trained_embeddings and not args.use_saved_embeddings:
        ngrams = map(
            lambda l: [
                [
                    ("_".join(n))
                    .strip()
                    .replace("-", "_")
                    .replace(".._", "")
                    .replace("_..", "")
                    for n in sent
                ]
                for sent in l
            ],
            ngrams,
        )
        ngrams = list(ngrams)

        print("Pre-training embeddings...")
        emb = GensimWord2Vec(
            [sentence for sentences in ngrams for sentence in sentences]
        )
        model = emb.create_model(
            vector_size=args.dimensions,
            window=5,
            workers=args.threads,
            compute_loss=True,
            # epochs=90,
            # min_count=30,
            epochs=80,
            min_count=15,
            sg=1,
            callbacks=[
                MonitorCallback(
                    [
                        "transformer",
                        "embedding",
                        "syntax",
                        "grammar"
                    ]
                )
            ],
        )
        model.wv.save_word2vec_format(opj(args.location, "embeddings.bin"), binary=True)
        del model

        ngrams = [
            list(itertools.chain.from_iterable(article_sentences))
            for article_sentences in ngrams
        ]
    else:
        ngrams = map(
            lambda l: [
                "_".join(n)
                .strip()
                .replace("-", "_")
                .replace(".._", "")
                .replace("_..", "")
                for n in l
            ],
            ngrams,
        )
        ngrams = list(ngrams)
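
    # Document frequencies are computed over the sampled articles; --min-df and
    # --max-df are interpreted as proportions when < 1 and as absolute counts otherwise.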
  249. print("Deriving vocabulary...")
  250. voc = defaultdict(int)
  251. for article_ngrams in ngrams:
  252. _ngrams = set(article_ngrams)
  253. for ngram in _ngrams:
  254. voc[ngram] += 1
  255. voc = pd.DataFrame({"ngram": voc.keys(), "count": voc.values()})
  256. voc["df"] = voc["count"] / len(articles)
  257. voc.set_index("ngram", inplace=True)
  258. if args.min_df < 1:
  259. voc = voc[voc["df"] >= args.min_df]
  260. else:
  261. voc = voc[voc["count"] >= args.min_df]
  262. if args.max_df < 1:
  263. voc = voc[voc["df"] <= args.max_df]
  264. else:
  265. voc = voc[voc["count"] <= args.max_df]
  266. voc["len"] = voc.index.map(len)
  267. voc = voc[voc["len"] >= 2]
  268. stop_words = nltk.corpus.stopwords.words("english")
  269. voc = voc[~voc.index.isin(stop_words)]
  270. if args.blacklist is not None:
  271. print("Filtering black-listed keywords...")
  272. blacklist = pd.read_csv(args.blacklist)["ngram"].tolist()
  273. voc = voc[
  274. voc.index.map(lambda s: not any([ngram in s for ngram in blacklist]))
  275. == True
  276. ]
  277. print("Filtering completed.")
  278. voc = voc.sort_values("df", ascending=False)
  279. voc.to_csv(opj(args.location, "ngrams.csv"))
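
    # Re-read the vocabulary from the CSV just written; keep_default_na=False keeps
    # n-grams such as "nan" or "null" as strings instead of parsing them as missing values.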
    voc = pd.read_csv(opj(args.location, "ngrams.csv"), keep_default_na=False)[
        "ngram"
    ].tolist()
    vocabulary = {n: i for i, n in enumerate(voc)}

    print("Filtering n-grams...")
    with mp.Pool(processes=args.threads) as pool:
        ngrams = pool.map(partial(filter_ngrams, wl=voc), ngrams)

    print("Constructing bag-of-words...")
    bow = [[vocabulary[ngram] for ngram in _ngrams] for _ngrams in ngrams]

    # if args.limit_redundancy:
    #     print("Building 'within' matrix...")
    #     with mp.Pool(processes=args.threads) as pool:
    #         within = pool.map(partial(ngram_inclusion, js=voc), voc)
    #     within = np.array(within).astype(int)
    #     print("Removing double-counting...")
    #     bow = csr_matrix(bow)
    #     double_counting = bow.dot(csr_matrix(within.T))
    #     bow = bow - double_counting
    #     print(double_counting.sum(), "redundant keywords removed")
    #     del double_counting
    #     bow = bow.todense()
    #     print((bow <= -1).sum(), "keywords had negative counts after removal")

    del ngrams

    with mp.Pool(processes=args.threads) as pool:
        bow = pool.map(partial(construct_bow, n=len(voc)), bow)

    # Drop articles whose bag-of-words ended up empty after vocabulary filtering.
    keep = [i for i in range(len(bow)) if len(bow[i][0]) > 0]
    articles = articles.iloc[keep]
    articles[["article_id"]].to_csv(opj(args.location, "articles.csv"))
    bow = [bow[i] for i in keep]

    dataset = {
        "tokens": [bow[i][0] for i in range(len(bow))],
        "counts": [bow[i][1] for i in range(len(bow))],
        "article_id": articles["article_id"],
    }
    del bow

    with open(opj(args.location, "dataset.pickle"), "wb") as handle:
        pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
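
    # embeddings.bin was saved above in word2vec C binary format, hence
    # use_c_format_w2vec=True; when pre-trained embeddings are supplied the ETM
    # does not train its own (train_embeddings=False).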
  317. print("Training...")
  318. from embedded_topic_model.models.etm import ETM
  319. etm_instance = ETM(
  320. voc,
  321. num_topics=args.topics,
  322. rho_size=args.dimensions,
  323. emb_size=args.dimensions,
  324. epochs=25,
  325. debug_mode=True,
  326. train_embeddings=not args.pre_trained_embeddings,
  327. model_path=opj(args.location, "model"),
  328. embeddings=opj(args.location, "embeddings.bin")
  329. if args.pre_trained_embeddings
  330. else None,
  331. use_c_format_w2vec=True,
  332. )
  333. etm_instance.fit(dataset)
  334. with open(opj(args.location, "etm_instance.pickle"), "wb") as handle:
  335. pickle.dump(etm_instance, handle, protocol=pickle.HIGHEST_PROTOCOL)
  336. topics = etm_instance.get_topics(20)
  337. print(topics)
  338. topic_coherence = etm_instance.get_topic_coherence()
  339. print(topic_coherence)
  340. topic_diversity = etm_instance.get_topic_diversity()
  341. print(topic_diversity)