# features_corr.py — correlate per-article topic probabilities with article
# labels (arXiv categories, PACS codes, or a SUSY keyword flag) and export
# the chosen association statistic as a CSV (optionally plotted as a heatmap).
  1. import pandas as pd
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4. import seaborn as sns
  5. from sklearn.preprocessing import MultiLabelBinarizer
  6. import argparse
  7. parser = argparse.ArgumentParser("extracting correlations")
  8. parser.add_argument('type', choices=["categories", "pacs_codes", "susy"])
  9. parser.add_argument("cond", choices=["cat_topic", "topic_cat", "pearson", "pmi", "npmi"])
  10. parser.add_argument("articles")
  11. parser.add_argument("destination")
  12. parser.add_argument("--descriptions", required=False)
  13. parser.add_argument("--filter", nargs='+', default=[])
  14. args = parser.parse_args([
  15. "pacs_codes",
  16. "pmi",
  17. "output/hep-ct-75-0.1-0.001-130000-20/topics_0.parquet",
  18. "output/pmi.csv",
  19. "--descriptions", "output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv"
  20. ])
  21. def is_susy(s: str):
  22. return "supersymmetr" in s or "susy" in s
  23. articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["article_id", "pacs_codes", "categories"] + (["abstract", "title"] if args.type == "susy" else [])]
  24. articles["article_id"] = articles["article_id"].astype(int)
  25. if args.type == "susy":
  26. articles["susy"] = articles["title"].str.lower().map(is_susy) | articles["abstract"].str.lower().map(is_susy)
  27. articles["susy"] = articles["susy"].map(lambda x: ["susy"] if x else ["not_susy"])
  28. topics = pd.read_parquet(args.articles)
  29. topics["article_id"] = topics["article_id"].astype(int)
  30. topics["topics"] = topics["probs"]
  31. if 'categories' in topics.columns:
  32. topics.drop(columns = ['categories'], inplace = True)
  33. topics = topics.merge(articles, how="inner", left_on = "article_id", right_on = "article_id")
  34. topics = topics[topics[args.type].map(len) > 0]
  35. if args.type == "pacs_codes":
  36. codes = set(pd.read_csv("inspire-harvest/database/pacs_codes.csv")["code"])
  37. topics = topics[topics["pacs_codes"].map(lambda l: set(l)&codes).map(len) > 0]
  38. X = np.stack(topics["topics"].values)
  39. binarizer = MultiLabelBinarizer()
  40. Y = binarizer.fit_transform(topics[args.type])
  41. n_articles = len(X)
  42. n_topics = X.shape[1]
  43. n_categories = Y.shape[1]
  44. sums = np.zeros((n_topics, n_categories))
  45. topic_probs = np.zeros(n_topics)
  46. p_topic_cat = np.zeros((n_topics, n_categories))
  47. p_cat_topic = np.zeros((n_topics, n_categories))
  48. pearson = np.zeros((n_topics, n_categories))
  49. pmi = np.zeros((n_topics, n_categories))
  50. npmi = np.zeros((n_topics, n_categories))
  51. if args.cond == "pearson":
  52. for k in range(n_topics):
  53. for c in range(n_categories):
  54. pearson[k,c] = np.corrcoef(X[:,k],Y[:,c])[0,1]
  55. for i in range(n_articles):
  56. for k in range(n_topics):
  57. sums[k,:] += Y[i,:]*X[i,k]
  58. topic_probs = np.mean(X,axis=0)
  59. cat_probs = np.mean(Y,axis=0)
  60. cat_counts = np.sum(Y,axis=0)
  61. significant_cats = cat_counts>=100
  62. for k in range(n_topics):
  63. p_cat_topic[k,:] = sums[k,:]/(topic_probs[k]*n_articles)
  64. for c in range(n_categories):
  65. p_topic_cat[:,c] = sums[:,c]/(cat_probs[c]*n_articles)
  66. for k in range(n_topics):
  67. pmi[k,:] = np.log(sums[k,:]/(topic_probs[k]*np.sum(Y,axis=0)))
  68. for k in range(n_topics):
  69. npmi[k,:] = -np.log(sums[k,:]/(topic_probs[k]*np.sum(Y,axis=0)))/np.log(sums[k,:]/n_articles)
  70. cat_classes = np.array([np.identity(n_categories)[cl] for cl in range(n_categories)])
  71. cat_labels = binarizer.inverse_transform(cat_classes)
  72. data = dict()
  73. for c in range(n_categories):
  74. data[cat_labels[c][0]] = p_topic_cat[:,c] if args.cond == "topic_cat" else (p_cat_topic[:,c] if args.cond == "cat_topic" else (pearson[:,c] if args.cond == "pearson" else (pmi[:,c] if args.cond == "pmi" else npmi[:,c])))
  75. data = pd.DataFrame(data)
  76. if len(args.filter):
  77. data = data[args.filter]
  78. else:
  79. cats = map(lambda c: cat_labels[c][0], np.arange(n_categories)[significant_cats])
  80. data = data[cats]
  81. data["topic"] = data.index
  82. if args.descriptions:
  83. descriptions = pd.read_csv(args.descriptions)[["topic", "description"]].rename(columns={'description_fr': 'description'})
  84. data = data.merge(descriptions, how='left', left_index=True,right_on="topic")
  85. data.to_csv(args.destination)
  86. if len(args.filter):
  87. sns.heatmap(data[args.filter], annot=True, fmt=".2f")
  88. plt.show()