etm_transfers.py

#!/usr/bin/env python
# coding: utf-8

import pandas as pd
import numpy as np
import networkx as nx
from ipysigma import Sigma
from matplotlib import pyplot as plt
import seaborn as sns
import pickle
from os.path import join as opj
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--location")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--keywords-threshold", type=int, default=200)
parser.add_argument("--articles-threshold", type=int, default=5)
parser.add_argument("--late-periods", nargs="+", type=int, default=[3])  # [2,3] for ACL, [3] for HEP
args = parser.parse_args()
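# Topic-model outputs from --location: human-readable topic labels and a
# per-article topic matrix (topics_counts.npy presumably holds, for each
# article, the number of keywords/terms assigned to each topic; its rows are
# assumed to be aligned with articles.csv in the same directory).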
topics = pd.read_csv(opj(args.location, "topics.csv"))["label"].tolist()
topic_matrix = np.load(opj(args.location, "topics_counts.npy"))
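# Article metadata from the INSPIRE dump: keep records with a usable creation
# date, derive the publication year relative to 2000, and bin it into 5-year
# periods ("year_group").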
articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created", "title"]]
articles = articles[articles["date_created"].str.len() >= 4]
if "year" not in articles.columns:
    articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
else:
    articles["year"] = articles["year"].astype(int) - 2002
articles = articles[(articles["year"] >= 0) & (articles["year"] <= 40)]
articles["year_group"] = articles["year"] // 5
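# Restrict to the articles used in the topic model (articles.csv in --location).
# The left merge keeps the row order of articles.csv, so rows remain aligned
# with topic_matrix and the argmax over each row gives the article's main topic.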
_articles = pd.read_csv(opj(args.location, "articles.csv"))
articles["article_id"] = articles.article_id.astype(int)
articles = _articles.merge(articles, how="left")
print(len(_articles))
print(len(articles))

articles["main_topic"] = topic_matrix.argmax(axis=1)
articles["main_topic"] = articles["main_topic"].map(lambda k: topics[k])
print(articles[["title", "main_topic"]].sample(frac=1).head(10))
print(articles[["title", "main_topic"]].sample(frac=1).head(10))
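# Authorship data: articles_authors.parquet maps each article to its authors
# ("bai" appears to be the INSPIRE author identifier). Count distinct authors
# per article so that topic counts can later be down-weighted by team size.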
all_authors = pd.read_parquet(opj(args.dataset, "articles_authors.parquet"))
all_authors["article_id"] = all_authors.article_id.astype(int)
n_authors = all_authors.groupby("article_id").agg(n_authors=("bai", lambda x: x.nunique())).reset_index()

n_articles = len(articles)
articles = articles.merge(n_authors, how="left", on="article_id")
assert len(articles) == n_articles, "# of articles does not match! cannot continue"

all_authors = all_authors.merge(articles, how="inner", on="article_id")
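# Assign each paper to a 5-year period and keep only authors who published in
# both an early period (year_range <= 1) and one of the --late-periods, with at
# least --articles-threshold papers in the first matching period on each side.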
all_authors["year_range"] = all_authors["year"] // 5
n_papers = all_authors.groupby(["bai", "year_range"]).agg(n=("article_id", "count")).reset_index()

filtered_authors = []
for author, n in n_papers.groupby("bai"):
    start = n[n["year_range"] <= 1]
    # end = n[n["year_range"]==3]
    end = n[n["year_range"].isin(args.late_periods)]
    if len(start) and len(end):
        filtered_authors.append({
            "author": author,
            "n_start": start.iloc[0]["n"],
            "n_end": end.iloc[0]["n"],
        })

filtered_authors = pd.DataFrame(filtered_authors)
filtered_authors = filtered_authors[(filtered_authors["n_start"] >= args.articles_threshold) & (filtered_authors["n_end"] >= args.articles_threshold)]
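# Build lookup tables (author id -> row, article id -> row in topic_matrix) and
# allocate per-author accumulators: raw topic counts for the early ("start")
# and late ("end") periods, plus an "expertise" profile weighted by 1/n_authors.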
authors = all_authors[all_authors["bai"].isin(filtered_authors["author"])]
start_authors = authors[authors["year_range"] <= 1]
# end_authors = authors[authors["year_range"]==3]
end_authors = authors[authors["year_range"].isin(args.late_periods)]

authorlist = list(authors["bai"].unique())
inv_articles = {n: i for i, n in enumerate(articles["article_id"].values)}
inv_authorlist = {author: i for i, author in enumerate(authorlist)}

n_authors = len(authorlist)
n_clusters = topic_matrix.shape[1]
n_years = articles["year"].max() + 1

start = np.zeros((n_authors, n_clusters))
end = np.zeros((n_authors, n_clusters))
expertise = np.zeros((n_authors, n_clusters))
start_count = np.zeros(n_authors)
end_count = np.zeros(n_authors)
expertise_norm = np.zeros(n_authors)
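# Accumulate topic counts over each author's early-period and late-period
# papers. For the early period also accumulate fractional counts (1/n_authors
# per paper), which serve as a team-size-adjusted measure of prior expertise.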
for author, _articles in start_authors.groupby("bai"):
    for article_id in _articles["article_id"].tolist():
        start[inv_authorlist[author], :] += topic_matrix[inv_articles[article_id], :].flat
        start_count[inv_authorlist[author]] += topic_matrix[inv_articles[article_id], :].sum()
        n = articles.iloc[inv_articles[article_id]]["n_authors"]
        expertise[inv_authorlist[author]] += (1 / n) * topic_matrix[inv_articles[article_id], :].flat
        expertise_norm[inv_authorlist[author]] += (1 / n) * topic_matrix[inv_articles[article_id], :].sum()

for author, _articles in end_authors.groupby("bai"):
    for article_id in _articles["article_id"].tolist():
        end[inv_authorlist[author], :] += topic_matrix[inv_articles[article_id], :].flat
        end_count[inv_authorlist[author]] += topic_matrix[inv_articles[article_id], :].sum()
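# In addition, store every author's full year-by-topic trajectory (all of their
# papers, not only the filtered periods) and pickle it for downstream analyses.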
authors_records = {}
for author, _articles in all_authors.groupby("bai"):
    record = np.zeros((n_years, n_clusters))
    record_count = np.zeros((n_years, n_clusters))
    for article in _articles.to_dict(orient="records"):
        year = article["year"]
        article_id = article["article_id"]
        record[year, :] += topic_matrix[inv_articles[article_id], :].flat
        record_count[year] += topic_matrix[inv_articles[article_id], :].sum()
    authors_records[author] = {
        "record": record,
        "record_count": record_count,
    }

with open(opj(args.location, "authors_full_records.pickle"), "wb") as handle:
    pickle.dump(authors_records, handle, protocol=pickle.HIGHEST_PROTOCOL)
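# Keep authors with at least --keywords-threshold accumulated topic counts in
# both periods, then normalize: start_norm/end_norm are each author's topic
# shares in the early/late period, and expertise is the 1/n_authors-weighted
# topic distribution. The per-author table is written to aggregate.csv.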
ok = (start_count >= args.keywords_threshold) & (end_count >= args.keywords_threshold)

cluster_names_start = [f"start_{n+1}" for n in range(n_clusters)]
cluster_names_end = [f"end_{n+1}" for n in range(n_clusters)]
cluster_names_expertise = [f"expertise_{n+1}" for n in range(n_clusters)]

start = start[ok]
end = end[ok]
start_count = start_count[ok]
end_count = end_count[ok]
expertise = expertise[ok] / expertise_norm[ok][:, np.newaxis]
start_norm = start / start_count[:, np.newaxis]
end_norm = end / end_count[:, np.newaxis]

print(start_norm.shape)
print(end_norm.shape)
print(start_norm.mean(axis=0))
print(end_norm.mean(axis=0))

aggregate = {}
for i in range(n_clusters):
    aggregate[cluster_names_start[i]] = start[:, i]
    aggregate[cluster_names_end[i]] = end[:, i]
    aggregate[cluster_names_expertise[i]] = expertise[:, i]
aggregate = pd.DataFrame(aggregate)
aggregate["bai"] = [bai for i, bai in enumerate(authorlist) if ok[i]]
aggregate.to_csv(opj(args.location, "aggregate.csv"))
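# Quick visual check: correlation matrix between early- and late-period topic
# shares across the retained authors.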
sns.heatmap(np.corrcoef(start_norm.T, end_norm.T), vmin=-0.5, vmax=0.5, cmap="RdBu")
plt.show()