etm_transfers.py

#!/usr/bin/env python
# coding: utf-8
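
# Computes how each author's topic profile shifts between an early and a late
# career period, from per-article topic counts produced by an upstream topic
# model (purpose inferred from the code; see inline comments below).
#
# Example invocation (paths are placeholders):
#   python etm_transfers.py --location <model-output-dir> --fla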
import argparse
import pickle
from os.path import join as opj

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument("--location")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--keywords-threshold", type=int, default=200)
parser.add_argument("--articles-threshold", type=int, default=5)
parser.add_argument("--early-periods", nargs="+", type=int, default=[0, 1])  # [2,3] for ACL, [3] for HEP
parser.add_argument("--late-periods", nargs="+", type=int, default=[3])  # [2,3] for ACL, [3] for HEP
parser.add_argument("--fla", action="store_true", help="only count papers where the author is first or last author")
args = parser.parse_args()
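
# Non-default period choices are encoded in a filename suffix so that such
# runs do not overwrite the default outputs.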
custom_range = (
    "_" + "-".join(map(str, args.early_periods)) + "_" + "-".join(map(str, args.late_periods))
    if (args.early_periods != [0, 1] or args.late_periods != [3])
    else ""
)
print(custom_range)
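
# Topic labels and the (articles x topics) topic-count matrix from the
# upstream topic model.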
topics = pd.read_csv(opj(args.location, "topics.csv"))["label"].tolist()
topic_matrix = np.load(opj(args.location, "topics_counts.npy"))
articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "title"]]
articles = articles[articles["date_created"].str.len() >= 4]
if "year" not in articles.columns:
    articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
else:
    articles["year"] = articles["year"].astype(int) - 2002
articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
articles["year_group"] = articles["year"] // 5
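
# Restrict to the articles this model was fitted on; articles.csv is assumed
# to list them in the same row order as topic_matrix.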
_articles = pd.read_csv(opj(args.location, "articles.csv"))
articles["article_id"] = articles.article_id.astype(int)
articles = _articles.merge(articles, how="left")
print(len(_articles))
print(len(articles))
articles["main_topic"] = topic_matrix.argmax(axis=1)
articles["main_topic"] = articles["main_topic"].map(lambda k: topics[k])
print(articles[["title", "main_topic"]].sample(frac=1).head(10))
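
# Authorship metadata: per-article author count plus first and last author
# ("bai" is presumably the INSPIRE BAI author identifier).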
all_authors = pd.read_parquet(opj(args.dataset, "articles_authors.parquet"))
all_authors["article_id"] = all_authors.article_id.astype(int)
n_authors = all_authors.groupby("article_id").agg(
    n_authors=("bai", "nunique"),
    first_author=("bai", "first"),
    last_author=("bai", "last"),
).reset_index()
n_articles = len(articles)
articles = articles.merge(n_authors, how="left", on="article_id")
assert len(articles) == n_articles, "# of articles does not match! cannot continue"
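
# Count each author's papers within five-year periods.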
all_authors = all_authors.merge(articles, how="inner", on="article_id")
all_authors["year_range"] = all_authors["year"] // 5
n_papers = all_authors.groupby(["bai", "year_range"]).agg(n=("article_id", "count")).reset_index()
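
# Keep authors with at least --articles-threshold papers in both the early
# and the late periods (counts taken from the first matching period).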
filtered_authors = []
for author, n in n_papers.groupby("bai"):
    start = n[n["year_range"].isin(args.early_periods)]
    end = n[n["year_range"].isin(args.late_periods)]
    if len(start) and len(end):
        filtered_authors.append({
            "author": author,
            "n_start": start.iloc[0]["n"],
            "n_end": end.iloc[0]["n"],
        })
filtered_authors = pd.DataFrame(filtered_authors)
filtered_authors = filtered_authors[
    (filtered_authors["n_start"] >= args.articles_threshold)
    & (filtered_authors["n_end"] >= args.articles_threshold)
]
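
# Index lookups and per-author topic-count accumulators: "start" and "end"
# hold early/late-period profiles, "expertise" covers all co-authored papers.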
authors = all_authors[all_authors["bai"].isin(filtered_authors["author"])]
start_authors = authors[authors["year_range"].isin(args.early_periods)]
end_authors = authors[authors["year_range"].isin(args.late_periods)]
authorlist = list(authors["bai"].unique())
inv_articles = {n: i for i, n in enumerate(articles["article_id"].values)}
inv_authorlist = {author: i for i, author in enumerate(authorlist)}
n_authors = len(authorlist)
n_clusters = topic_matrix.shape[1]
n_years = articles["year"].max() + 1
start = np.zeros((n_authors, n_clusters))
end = np.zeros((n_authors, n_clusters))
expertise = np.zeros((n_authors, n_clusters))
start_count = np.zeros(n_authors)
end_count = np.zeros(n_authors)
expertise_norm = np.zeros(n_authors)
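
# Accumulate early-period topic counts. Expertise counts every paper,
# weighted by 1/n_authors; with --fla, the start profile is restricted to
# papers where the author is first or last author.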
for author, _articles in start_authors.groupby("bai"):
    for article in _articles.to_dict(orient="records"):
        article_id = article["article_id"]
        n = articles.iloc[inv_articles[article_id]]["n_authors"]
        expertise[inv_authorlist[author]] += (1 / n) * topic_matrix[inv_articles[article_id], :]
        expertise_norm[inv_authorlist[author]] += (1 / n) * topic_matrix[inv_articles[article_id], :].sum()
        if args.fla and author not in [article["first_author"], article["last_author"]]:
            continue
        start[inv_authorlist[author], :] += topic_matrix[inv_articles[article_id], :]
        start_count[inv_authorlist[author]] += topic_matrix[inv_articles[article_id], :].sum()
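
# Same accumulation for the late period (expertise is not updated here).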
for author, _articles in end_authors.groupby("bai"):
    for article in _articles.to_dict(orient="records"):
        article_id = article["article_id"]
        if args.fla and author not in [article["first_author"], article["last_author"]]:
            continue
        end[inv_authorlist[author], :] += topic_matrix[inv_articles[article_id], :]
        end_count[inv_authorlist[author]] += topic_matrix[inv_articles[article_id], :].sum()
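
# Year-by-year topic record for every author, pickled for downstream use.
# Note that record_count has shape (n_years, n_clusters) but each row is
# assigned a scalar, so all columns within a year carry the same value.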
authors_records = {}
for author, _articles in all_authors.groupby("bai"):
    record = np.zeros((n_years, n_clusters))
    record_count = np.zeros((n_years, n_clusters))
    for article in _articles.to_dict(orient="records"):
        year = article["year"]
        article_id = article["article_id"]
        record[year, :] += topic_matrix[inv_articles[article_id], :]
        record_count[year] += topic_matrix[inv_articles[article_id], :].sum()
    authors_records[author] = {
        "record": record,
        "record_count": record_count,
    }
suffix = "_fla" if args.fla else ""
with open(opj(args.location, f"authors_full_records{suffix}{custom_range}.pickle"), "wb") as handle:
    pickle.dump(authors_records, handle, protocol=pickle.HIGHEST_PROTOCOL)
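
# Keep authors with at least --keywords-threshold keyword counts in both
# periods, then normalize counts into per-author topic distributions.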
ok = (start_count >= args.keywords_threshold) & (end_count >= args.keywords_threshold)
cluster_names_start = [f"start_{n+1}" for n in range(n_clusters)]
cluster_names_end = [f"end_{n+1}" for n in range(n_clusters)]
cluster_names_expertise = [f"expertise_{n+1}" for n in range(n_clusters)]
start = start[ok]
end = end[ok]
start_count = start_count[ok]
end_count = end_count[ok]
expertise = expertise[ok] / expertise_norm[ok][:, np.newaxis]
start_norm = start / start_count[:, np.newaxis]
end_norm = end / end_count[:, np.newaxis]
print(start_norm.shape)
print(end_norm.shape)
print(start_norm.mean(axis=0))
print(end_norm.mean(axis=0))
aggregate = {}
for i in range(n_clusters):
    aggregate[cluster_names_start[i]] = start[:, i]
    aggregate[cluster_names_end[i]] = end[:, i]
    aggregate[cluster_names_expertise[i]] = expertise[:, i]
aggregate = pd.DataFrame(aggregate)
aggregate["bai"] = [bai for i, bai in enumerate(authorlist) if ok[i]]
aggregate.to_csv(opj(args.location, f"aggregate{suffix}{custom_range}.csv"))
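
# Correlations between early and late topic profiles across authors.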
sns.heatmap(np.corrcoef(start_norm.T, end_norm.T), vmin=-0.5, vmax=0.5, cmap="RdBu")
plt.show()