# train_embeddings.py
  1. from AbstractSemantics.terms import TermExtractor
  2. from AbstractSemantics.embeddings import GensimWord2Vec
  3. import pandas as pd
  4. import numpy as np
  5. from os.path import join as opj
  6. from os.path import exists
  7. import itertools
  8. from functools import partial
  9. from collections import defaultdict
  10. import re
  11. from sklearn.preprocessing import MultiLabelBinarizer
  12. from sklearn.feature_extraction.text import TfidfTransformer
  13. from sklearn.model_selection import train_test_split
  14. import multiprocessing as mp
  15. from matplotlib import pyplot as plt
  16. import argparse
  17. import yaml
  18. import sys
  19. from gensim.models.callbacks import CallbackAny2Vec
  20. class MonitorCallback(CallbackAny2Vec):
  21. def __init__(self, test_words):
  22. self._test_words = test_words
  23. self.epoch = 0
  24. def on_epoch_end(self, model):
  25. loss = model.get_latest_training_loss()
  26. if self.epoch == 0:
  27. print('Loss after epoch {}: {}'.format(self.epoch, loss))
  28. else:
  29. print('Loss after epoch {}: {}'.format(self.epoch, loss- self.loss_previous_step))
  30. self.epoch += 1
  31. self.loss_previous_step = loss
  32. for word in self._test_words: # show wv logic changes
  33. print(f"{word}: {model.wv.most_similar(word)}")
  34. if __name__ == '__main__':
  35. parser = argparse.ArgumentParser('CT Model')
  36. parser.add_argument('location', help='model directory')
  37. parser.add_argument('filter', choices=['categories', 'keywords', 'no-filter'], help='filter type')
  38. parser.add_argument('--values', nargs='+', default=[], help='filter allowed values')
  39. parser.add_argument('--samples', type=int, default=50000)
  40. parser.add_argument('--dimensions', type=int, default=64)
  41. parser.add_argument('--constant-sampling', type=int, default=0)
  42. parser.add_argument('--reuse-articles', default=False, action="store_true", help="reuse article selection")
  43. parser.add_argument('--nouns', default=False, action="store_true", help="include nouns")
  44. parser.add_argument('--adjectives', default=False, action="store_true", help="include adjectives")
  45. parser.add_argument('--lemmatize', default=False, action="store_true", help="stemmer")
  46. parser.add_argument('--remove-latex', default=False, action="store_true", help="remove latex")
  47. parser.add_argument('--add-title', default=False, action="store_true", help="include title")
  48. parser.add_argument('--top-unithood', type=int, default=20000, help='top unithood filter')
  49. parser.add_argument('--min-token-length', type=int, default=0, help='minimum token length')
  50. parser.add_argument('--min-df', type=int, default=0, help='min_df')
  51. parser.add_argument('--reuse-stored-vocabulary', default=False, action='store_true')
  52. parser.add_argument('--threads', type=int, default=4)
  53. args = parser.parse_args(["output/embeddings", "categories", "--values", "Phenomenology-HEP", "Theory-HEP", "--samples", "150000", "--threads", "4"])
  54. with open(opj(args.location, "params.yml"), "w+") as fp:
  55. yaml.dump(args, fp)
  56. articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["title", "abstract", "article_id", "date_created", "categories"]]
  57. if args.add_title:
  58. articles["abstract"] = articles["abstract"].str.cat(articles["title"])
  59. articles.drop(columns = ["title"], inplace=True)
  60. if args.remove_latex:
  61. articles['abstract'] = articles['abstract'].apply(lambda s: re.sub('$[^>]+$', '', s))
  62. articles = articles[articles["abstract"].map(len)>=100]
  63. articles["abstract"] = articles["abstract"].str.lower()
  64. articles = articles[articles["date_created"].str.len() >= 4]
  65. articles["year"] = articles["date_created"].str[:4].astype(int)-1980
  66. articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
  67. articles["year_group"] = articles["year"]//5
  68. if args.reuse_articles:
  69. used = pd.read_csv(opj(args.location, 'articles.csv'))
  70. articles = articles[articles["article_id"].isin(used["article_id"])]
  71. else:
  72. articles = articles[~articles["abstract"].isnull()]
  73. if args.constant_sampling > 0:
  74. articles = articles.groupby("year").head(args.constant_sampling)
  75. keep = pd.Series([False]*len(articles), index=articles.index)
  76. print("Applying filter...")
  77. if args.filter == 'keywords':
  78. for value in args.values:
  79. keep |= articles["abstract"].str.contains(value)
  80. elif args.filter == 'categories':
  81. for value in args.values:
  82. keep |= articles["categories"].apply(lambda l: value in l)
  83. articles = articles[keep==True]
  84. articles = articles.sample(frac=1).head(args.samples)
  85. articles[["article_id"]].to_csv(opj(args.location, 'articles.csv'))
  86. articles.reset_index(inplace = True)
  87. print("Extracting n-grams...")
  88. extractor = TermExtractor(articles["abstract"].tolist())
  89. sentences = extractor.tokens(threads=args.threads, lemmatize=True, split_sentences=True)
  90. print(len(sentences))
  91. print(sentences[0])
  92. print(sentences[0][0])
  93. articles["sentences"] = sentences
  94. for category in args.values:
  95. _articles = articles[articles.categories.map(lambda l: category in l)]
  96. corpus = [sentence for sentences in _articles["sentences"].tolist() for sentence in sentences]
  97. print(category, len(corpus))
  98. emb = GensimWord2Vec(corpus)
  99. model = emb.model(
  100. vector_size=args.dimensions,
  101. window=10,
  102. workers=args.threads,
  103. compute_loss=True,
  104. epochs=50,
  105. callbacks=[MonitorCallback(["quark", "gluino", "renormalization"])]
  106. )
  107. # model.build_vocab(corpus)
  108. model.train(corpus, epochs=10, total_examples=model.corpus_count)
  109. model.train(corpus, epochs=10, total_examples=model.corpus_count)
  110. model.train(corpus, epochs=10, total_examples=model.corpus_count)
  111. model.save(opj(args.location, f"{category}.mdl"))