trading_zone.py

#!/usr/bin/env python
from AbstractSemantics.terms import TermExtractor
import pandas as pd
import numpy as np
import os  # needed for the os.makedirs call added below (assumption)
from os.path import join as opj
from collections import defaultdict
import re
import argparse
import yaml
import sys


def is_hep(categories):
    """Return True if any of the article's categories is a HEP category."""
    return any("-HEP" in x for x in categories)
if __name__ == '__main__':
    parser = argparse.ArgumentParser('trading zone')
    parser.add_argument('--location', help='model directory', default="output/trading_zone")
    parser.add_argument('--filter', choices=['categories', 'keywords', 'no-filter'], help='filter type', default="categories")
    parser.add_argument('--values', nargs='+', default=["Theory-HEP", "Phenomenology-HEP", "Experiment-HEP"], help='filter allowed values')
    parser.add_argument('--exclude', nargs='+', default=[], help='excluded values')
    parser.add_argument('--samples', type=int, default=10000000)
    parser.add_argument('--constant-sampling', type=int, default=0)
    parser.add_argument('--reuse-articles', default=False, action="store_true", help="reuse article selection")
    parser.add_argument('--nouns', default=False, action="store_true", help="include nouns")
    parser.add_argument('--adjectives', default=False, action="store_true", help="include adjectives")
    # NOTE: arguments declared with default=True and action="store_true" are effectively always True.
    parser.add_argument('--lemmatize', default=True, action="store_true", help="lemmatize unigrams")
    parser.add_argument('--lemmatize-ngrams', default=True, action="store_true", help="lemmatize n-grams")
    parser.add_argument('--remove-latex', default=True, action="store_true", help="remove LaTeX")
    parser.add_argument('--limit-redundancy', default=False, action="store_true", help="limit redundancy")
    parser.add_argument('--add-title', default=True, action="store_true", help="include title")
    parser.add_argument('--top-unithood', type=int, default=2000, help='top unithood filter')
    parser.add_argument('--threads', type=int, default=16)
    parser.add_argument('--category-cited', type=int, help="filter cited category (0=theory, 1=phenomenology, 2=experiment)")
    parser.add_argument('--category-cites', type=int, help="filter citing category (0=theory, 1=phenomenology, 2=experiment)")
    parser.add_argument('--include-crosslists', default=False, action="store_true", help="include cross-listed papers")
    args = parser.parse_args()
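    # Illustrative invocation (argument values below are an example, not taken from the source):
    #   python trading_zone.py --category-cited 0 --category-cites 1
    # builds the dataset for theory articles (cited) receiving citations from phenomenology (citing),
    # writing its outputs to output/trading_zone_0_1.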
    location = f"{args.location}_{args.category_cited}_{args.category_cites}"
    # Ensure the output directory exists (added here; the original assumes it already does).
    os.makedirs(location, exist_ok=True)

    with open(opj(location, "params.yml"), "w+") as fp:
        yaml.dump(args, fp)

    articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["title", "abstract", "article_id", "date_created", "categories"]]
    articles = articles[articles.categories.map(is_hep)]
    articles = articles[articles["date_created"].str.len() >= 4]
    articles["year"] = articles["date_created"].str[:4].astype(int)
    articles = articles[articles["year"] <= 2019]

    if args.add_title:
        articles["abstract"] = articles["abstract"].str.cat(articles["title"])
        articles.drop(columns=["title"], inplace=True)

    # Drop articles without an abstract before any string processing below.
    articles = articles[~articles["abstract"].isnull()]

    if args.remove_latex:
        # Strip inline LaTeX between dollar signs; the '\$' escapes were lost in the original
        # listing, so the pattern below restores the assumed intent.
        articles['abstract'] = articles['abstract'].apply(lambda s: re.sub(r'\$[^>]+\$', '', s))

    articles = articles[articles["abstract"].map(len) >= 100]
    articles["abstract"] = articles["abstract"].str.lower()

    if args.constant_sampling > 0:
        articles = articles.groupby("year").head(args.constant_sampling)
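    # Build the article selection mask:
    #   'keywords'   keeps abstracts containing any of --values (substring match),
    #   'categories' keeps articles whose category list contains any of --values,
    # then removes anything matching --exclude.
    # (Note: with 'no-filter', the mask stays all-False as written and no article is selected.)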
    keep = pd.Series([False] * len(articles), index=articles.index)

    print("Applying filter...")
    if args.filter == 'keywords':
        for value in args.values:
            keep |= articles["abstract"].str.contains(value)
        for value in args.exclude:
            keep &= ~articles["abstract"].str.contains(value)
    elif args.filter == 'categories':
        for value in args.values:
            keep |= articles["categories"].apply(lambda l: value in l)
        for value in args.exclude:
            keep &= ~articles["categories"].apply(lambda l: value in l)

    articles = articles[keep]
    citations = articles[[x for x in articles.columns if x != "abstract"]].merge(
        pd.read_parquet("output/cross_citations{}.parquet".format("_crosslists" if args.include_crosslists else ""))[
            ["article_id_cited", "article_id_cites", "category_cites", "category_cited", "year_cites", "year_cited"]
        ],
        how="inner", left_on="article_id", right_on="article_id_cited",
    )
    # Keep citations towards the chosen cited category, coming either from the same category
    # or from the chosen citing category; a "trade" is a cross-category citation.
    citations = citations[
        (citations["category_cited"] == args.category_cited)
        & (citations["category_cites"].isin([args.category_cites, args.category_cited]))
    ]
    citations["trade"] = citations["category_cited"] != citations["category_cites"]
    citations = citations[(citations["year_cites"] >= 2001) & (citations["year_cites"] <= 2019)]
    citations["year_cites"] = (citations["year_cites"] - citations["year_cites"].min()).astype(int)
    citations.drop_duplicates(["article_id_cited", "article_id_cites"], inplace=True)
    citations = citations.sample(args.samples if args.samples < len(citations) else len(citations))
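    # Aggregate per cited article and per (rescaled) citing year:
    #   trades = number of cross-category citations received that year,
    #   total  = total citations received from the two categories considered.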
    citations = citations.groupby(["article_id_cited", "year_cites"]).agg(
        trades=("trade", "sum"),
        total=("article_id_cites", "count"),
        category_cited=("category_cited", "first"),
    )
    citations.reset_index(inplace=True)

    articles_to_keep = list(citations["article_id_cited"].unique())
    citations = citations[citations["article_id_cited"].isin(articles_to_keep)]
    articles = articles[articles["article_id"].isin(articles_to_keep)]

    articles = articles.merge(
        citations[["article_id_cited", "category_cited"]].drop_duplicates(),
        left_on="article_id",
        right_on="article_id_cited",
    )
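    # TermExtractor comes from the project's AbstractSemantics package; the patterns added
    # below appear to be Penn-Treebank-style POS constraints restricting candidate terms to
    # nouns (NN*) and/or adjectives (JJ) when the corresponding flags are set.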
    print("Extracting n-grams...")
    extractor = TermExtractor(
        articles["abstract"].tolist(),
        limit_redundancy=args.limit_redundancy,
    )
    if args.nouns:
        extractor.add_patterns([["NN.*"]])
    if args.adjectives:
        extractor.add_patterns([["^JJ$"]])

    ngrams = extractor.ngrams(
        threads=args.threads,
        lemmatize=args.lemmatize,
        lemmatize_ngrams=args.lemmatize_ngrams,
    )
    # Each extracted n-gram is a list of tokens; join them into plain strings.
    ngrams = [[" ".join(n) for n in l] for l in ngrams]
    articles["ngrams"] = ngrams
    print("Deriving vocabulary...")

    # Document frequency of each n-gram (number of articles containing it at least once).
    ngrams_occurrences = defaultdict(int)
    categories = articles["category_cited"].tolist()
    for article_ngrams in articles["ngrams"].tolist():
        for ngram in set(article_ngrams):
            ngrams_occurrences[ngram] += 1

    ngrams_occurrences = pd.DataFrame({
        "ngram": list(ngrams_occurrences.keys()),
        "count": list(ngrams_occurrences.values()),
    })
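    # "Unithood" scores each n-gram by its document frequency, boosted by its length:
    #   unithood = log(2 + number of spaces) * count / number of articles
    # so longer n-grams need fewer occurrences to rank among the top terms.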
    ngrams_occurrences["unithood"] = (
        np.log(2 + ngrams_occurrences["ngram"].str.count(" "))
        * ngrams_occurrences["count"] / len(articles)
    )
    ngrams_occurrences.set_index("ngram", inplace=True)
    top = ngrams_occurrences.sort_values("unithood", ascending=False).head(args.top_unithood)
    top.to_csv(opj(location, "ngrams.csv"))
    articles = articles.sample(frac=1)
    ngrams = articles["ngrams"].tolist()

    selected_ngrams = pd.read_csv(opj(location, 'ngrams.csv'))['ngram'].tolist()
    vocabulary = {n: i for i, n in enumerate(selected_ngrams)}

    # Bag-of-words counts over the candidate vocabulary (columns follow the unithood ranking).
    ngrams = [[ngram for ngram in _ngrams if ngram in vocabulary] for _ngrams in ngrams]
    ngrams_bow = [[vocabulary[ngram] for ngram in _ngrams] for _ngrams in ngrams]
    ngrams_bow = [[_ngrams.count(i) for i in range(len(selected_ngrams))] for _ngrams in ngrams_bow]

    n = []
    frac = []
    ngrams_bow = np.array(ngrams_bow)
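    # Scan vocabulary sizes in steps of 10 (following the unithood ranking) and record, for
    # each size, the percentage of articles whose bag-of-words would be empty. Assuming this
    # percentage decreases with vocabulary size, the selected size is the largest scanned
    # value for which it still exceeds 5%.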
    for i in range(0, args.top_unithood, 10):
        sum_bow = ngrams_bow[:, :i].sum(axis=1)
        n.append(i)
        frac.append(100 * (sum_bow == 0).mean())

    frac = np.array(frac)
    n_words = n[np.argmin(frac[frac > 5])]
    print(f"preserving {n_words} words out of {len(selected_ngrams)}.")
    selected_ngrams = selected_ngrams[:n_words]
    vocabulary = {n: i for i, n in enumerate(selected_ngrams)}
    ngrams = [[ngram for ngram in _ngrams if ngram in vocabulary] for _ngrams in ngrams]
    ngrams_bow = [[vocabulary[ngram] for ngram in _ngrams] for _ngrams in ngrams]
    ngrams_bow = [[_ngrams.count(i) for i in range(len(selected_ngrams))] for _ngrams in ngrams_bow]
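    # Binary presence matrix (articles x selected n-grams), saved together with the final
    # vocabulary and the per-article citation counts for the downstream model.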
    full_bow = (np.array(ngrams_bow) >= 1) * 1
    np.save(opj(location, 'full_bow.npy'), full_bow)
    pd.DataFrame({'ngram': selected_ngrams}).to_csv(opj(location, "selected_ngrams.csv"), index=False)

    articles.reset_index(inplace=True)
    articles["id"] = articles.index + 1
    citations = citations.merge(articles[["article_id", "id"]], left_on="article_id_cited", right_on="article_id")
    citations.to_csv(opj(location, 'citations.csv'))