# word2vec_validation.py

from AbstractSemantics.terms import TermExtractor
from AbstractSemantics.embeddings import GensimWord2Vec
from gensim.models import KeyedVectors
import nltk
import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix
from os.path import join as opj
from os.path import exists
import itertools
from functools import partial
from collections import defaultdict
import re
import multiprocessing as mp
# from matplotlib import pyplot as plt
import argparse
import yaml
import sys
import pickle
from gensim.models.callbacks import CallbackAny2Vec
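

# Gensim's Word2Vec accumulates training loss across epochs; the callback
# below reads the running total at the end of each epoch, resets it so the
# next reading is per-epoch, and stores the latest value in the module-level
# `model_loss` so the dimension sweep in __main__ can record it.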
model_loss = 0


class MonitorCallback(CallbackAny2Vec):
    def __init__(self, test_words):
        self._test_words = test_words
        self.epoch = 0

    def on_epoch_end(self, model):
        loss = model.get_latest_training_loss()
        model.running_training_loss = 0.0
        print("Loss after epoch {}: {}".format(self.epoch, loss))
        global model_loss
        model_loss = loss
        self.epoch += 1
        for word in self._test_words:  # sanity check: show nearest neighbours of the probe words
            print(f"{word}: {model.wv.most_similar(word)}")


def filter_ngrams(l, wl):
    return [ngram for ngram in l if ngram in wl]


def construct_bow(l, n):
    items = list(set(l))
    return np.array(items), np.array([l.count(i) for i in items])
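

# For a term `i`, flag which candidate terms in `js` contain `i` as a whole
# sub-n-gram and are exactly one token longer (tokens are joined by "_");
# single-token terms must additionally sit at the start or end of `j`.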
def ngram_inclusion(i, js):
    return [
        j.find(i) >= 0  # matching
        and bool(re.search(rf"(^|\_){re.escape(i)}($|\_)", j))
        and (j.count("_") == i.count("_") + 1)
        and (
            (i.count("_") >= 1)
            or bool(re.search(rf"(^|\_){re.escape(i)}$", j))
            or bool(re.search(rf"^{re.escape(i)}($|\_)", j))
        )
        for j in js
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser("CT Model")
    parser.add_argument("location", help="model directory")
    parser.add_argument(
        "filter", choices=["categories", "keywords", "no-filter"], help="filter type"
    )
    parser.add_argument("--values", nargs="+", default=[], help="filter allowed values")
    parser.add_argument("--dataset", default="inspire-harvest/database")

    # sample size
    parser.add_argument("--samples", type=int, default=50000)
    parser.add_argument("--constant-sampling", type=int, default=0)

    # text pre-processing
    parser.add_argument(
        "--add-title", default=False, action="store_true", help="include title"
    )
    parser.add_argument(
        "--remove-latex", default=False, action="store_true", help="remove latex"
    )
    parser.add_argument(
        "--lemmatize", default=False, action="store_true", help="lemmatize"
    )
    parser.add_argument(
        "--limit-redundancy",
        default=False,
        action="store_true",
        help="limit redundancy",
    )
    parser.add_argument("--blacklist", default=None, help="blacklist")

    # embeddings
    parser.add_argument("--dimensions", type=int, default=50)
    parser.add_argument("--pre-trained-embeddings", default=False, action="store_true")
    parser.add_argument("--use-saved-embeddings", default=False, action="store_true")

    # topic model parameters
    parser.add_argument("--topics", type=int, default=25)
    parser.add_argument("--min-df", type=float, default=0.001)
    parser.add_argument("--max-df", type=float, default=0.15)
    parser.add_argument("--threads", type=int, default=4)

    args = parser.parse_args(
        [
            "output/etm_20_r",
            "categories",
            "--values",
            "Theory-HEP",
            "Phenomenology-HEP",
            "--dataset",
            "../inspire-harvest/database",
            "--constant-sampling",
            "30000",
            "--samples",
            "300000",
            "--threads",
            "24",
            "--add-title",
            "--remove-latex",
            "--dimensions",
            "50",
            "--topics",
            "25",
            "--min-df",
            "0.00075",
            "--lemmatize",
            "--pre-trained-embeddings",
            # "--limit-redundancy"
            "--use-saved-embeddings",
            # "--blacklist",
            # "output/medialab/blacklist",
        ]
    )
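    # NOTE: the hard-coded list above overrides anything passed on the command
    # line; it corresponds roughly to running:
    #   python word2vec_validation.py output/etm_20_r categories \
    #       --values Theory-HEP Phenomenology-HEP \
    #       --dataset ../inspire-harvest/database \
    #       --constant-sampling 30000 --samples 300000 --threads 24 \
    #       --add-title --remove-latex --dimensions 50 --topics 25 \
    #       --min-df 0.00075 --lemmatize --pre-trained-embeddings \
    #       --use-saved-embeddings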

    # with open(opj(args.location, "params.yml"), "w+") as fp:
    #     yaml.dump(args, fp)

    articles = pd.read_parquet(
        opj(args.dataset, "articles.parquet")
    )[["title", "abstract", "article_id", "date_created", "categories"]]

    if args.add_title:
        articles["abstract"] = articles["abstract"].str.cat(articles["title"], sep=". ")
    articles.drop(columns=["title"], inplace=True)

    if args.remove_latex:
        # strip inline math ($...$), remaining LaTeX commands (\foo), and stray symbols
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\$[^>]+\$", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\b\\\w+", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"[^0-9a-zA-Z--- -\.]+", "", s)
        )

    # articles["abstract"] = articles["abstract"].str.replace("-", " ")  # NEW
    articles = articles[articles["abstract"].map(len) >= 100]
    articles["abstract"] = articles["abstract"].str.lower()

    articles = articles[articles["date_created"].str.len() >= 4]
    if "year" not in articles.columns:
        # express the year as an offset from 2000 and keep 2000-2020
        articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
        articles = articles[(articles["year"] >= 0) & (articles["year"] <= 20)]
    else:
        # keep the same "years since 2000" convention when a year column already exists
        articles["year"] = articles["year"].astype(int) - 2000
        articles = articles[(articles["year"] >= 0) & (articles["year"] <= 20)]
  148. articles["year_group"] = articles["year"] // 5

    keep = pd.Series([False] * len(articles), index=articles.index)

    print("Applying filter...")
    if args.filter == "keywords":
        for value in args.values:
            keep |= articles["abstract"].str.contains(value)
    elif args.filter == "categories":
        for value in args.values:
            keep |= articles["categories"].apply(lambda l: value in l)
    elif args.filter == "no-filter":
        keep |= True

    articles = articles[keep].sample(frac=1)

    if args.constant_sampling > 0:
        articles = articles.groupby("year").head(args.constant_sampling)

    articles = articles.sample(frac=1).head(args.samples)
    articles.reset_index(inplace=True)
    print(articles)
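
    # Candidate terms are n-grams matching part-of-speech patterns
    # (Penn Treebank tags: JJ* adjectives, NN* nouns): bare nouns/adjectives
    # and adjective-noun(-noun) phrases.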
  165. print("Extracting n-grams...")
  166. extractor = TermExtractor(
  167. articles["abstract"].tolist(),
  168. # limit_redundancy=args.limit_redundancy,
  169. patterns=[
  170. ["JJ.*"],
  171. ["NN.*"],
  172. ["JJ.*", "NN.*"],
  173. ["JJ.*", "NN.*", "NN.*"],
  174. # ["JJ.*", "NN", "CC", "NN.*"],
  175. # ["JJ.*", "NN.*", "JJ.*", "NN.*"],
  176. # ["RB.*", "JJ.*", "NN.*", "NN.*"],
  177. ],
  178. )
  179. ngrams = extractor.ngrams(
  180. threads=args.threads,
  181. lemmatize=args.lemmatize,
  182. lemmatize_ngrams=args.lemmatize,
  183. split_sentences=True,
  184. )
  185. del extractor
  186. del articles["abstract"]
  187. ngrams = map(
  188. lambda l: [
  189. [
  190. ("_".join(n)).strip()
  191. .replace("-", "_")
  192. .replace(".._", "")
  193. .replace("_..", "")
  194. for n in sent
  195. ]
  196. for sent in l
  197. ],
  198. ngrams,
  199. )
  200. ngrams = list(ngrams)
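
    # Validation sweep: for each embedding dimension, train five skip-gram
    # Word2Vec models (sg=1, 20 epochs, min_count=30) and record the loss of
    # the final epoch of each run in word2vec_losses.csv, to help choose a
    # vector size.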
  201. print("Pre-training embeddings...")
  202. emb = GensimWord2Vec(
  203. [sentence for sentences in ngrams for sentence in sentences]
  204. )
  205. losses = []
  206. for dim in [10,25,50,75,100,150,200]:
  207. for attempt in np.arange(5):
  208. model = emb.create_model(
  209. vector_size=dim,
  210. window=5,
  211. workers=args.threads,
  212. compute_loss=True,
  213. epochs=20,
  214. min_count=30,
  215. sg=1,
  216. callbacks=[
  217. MonitorCallback(
  218. [
  219. "black_hole",
  220. "supersymmetry",
  221. ]
  222. )
  223. ],
  224. )
  225. print(model_loss)
  226. losses.append({
  227. 'dim': dim,
  228. 'loss': model_loss
  229. })
  230. print(losses)
  231. losses = pd.DataFrame(losses)
  232. losses.to_csv(opj(args.location, "word2vec_losses.csv"))