# etm_compile.py
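"""Compile per-token topic assignments from a trained ETM.

For every token of every document, this script computes the topic
posterior p(z | w, d), keeps only unambiguous tokens assigned to
non-junk topics, and saves per-document topic counts together with a
year x word x topic keyword tensor. (Summary inferred from the code
below.)
"""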

import argparse
import pickle
from os.path import join as opj

import numpy as np
import pandas as pd
from scipy.stats import entropy
from tqdm import trange
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--dataset", default="inspire-harvest/database")
parser.add_argument("--write-topics", action="store_true", default=False)
parser.add_argument("--debug", action="store_true", default=False)
args = parser.parse_args()
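# Files expected inside --input: ngrams.csv (the vocabulary),
# dataset.pickle (per-document token ids and counts),
# etm_instance.pickle (the trained ETM), and articles.csv (document
# ids, assumed row-aligned with the ETM's document-topic matrix).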
ngrams = pd.read_csv(opj(args.input, "ngrams.csv"))

with open(opj(args.input, "dataset.pickle"), "rb") as handle:
    data = pickle.load(handle)

with open(opj(args.input, "etm_instance.pickle"), "rb") as handle:
    etm_instance = pickle.load(handle)
articles = pd.read_parquet(opj(args.dataset, "articles.parquet"))[["article_id", "date_created"]]
articles = articles[articles["date_created"].str.len() >= 4]
if "year" not in articles.columns:
    # Years are stored as offsets from 2000.
    articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
articles["article_id"] = articles.article_id.astype(int)

_articles = pd.read_csv(opj(args.input, "articles.csv"))
articles = _articles.merge(articles, how="inner")
years = articles["year"].values
if args.write_topics:
    top_words = [",".join(l) for l in etm_instance.get_topics(20)]
    topics = pd.DataFrame({"top_words": top_words})
    topics["label"] = ""
    topics.to_csv(opj(args.input, "topics.csv"))
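# topics.csv is apparently meant to be annotated by hand between runs:
# topics whose "label" contains "Junk" are excluded from the counts
# below.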
topics = pd.read_csv(opj(args.input, "topics.csv"))
theta = etm_instance.get_document_topic_dist()
is_junk_topic = np.array(topics["label"].str.contains("Junk"))
if args.debug:
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Topic-topic correlations across documents (columns of theta).
    sns.heatmap(
        np.corrcoef(theta, rowvar=False), vmin=-0.5, vmax=0.5, cmap="RdBu"
    )
    plt.show()
topic_counts = np.zeros((theta.shape[0], theta.shape[1]))
p_w_z = etm_instance.get_topic_word_dist()

print("Computing P(w|d) matrix...")
p_w_d = theta @ p_w_z

keywords = np.zeros((articles.year.max() + 1, p_w_z.shape[1], p_w_z.shape[0]))
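# For each token w of document d, Bayes' rule gives the topic posterior
#   p(z | w, d) = p(w | z) * p(z | d) / p(w | d),
# with p(w | d) = sum_z p(w | z) * p(z | d) (the matrix product above).
# exp(entropy(p)) is the perplexity of that posterior, i.e. the
# effective number of topics a token is spread over; tokens with an
# effective topic count >= 2 are treated as ambiguous and skipped.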
for d in trange(theta.shape[0]):
    for i, w in enumerate(data["tokens"][d]):
        p = p_w_z[:, w] * theta[d, :] / p_w_d[d, w]
        S = np.exp(entropy(p))
        if S >= 2:
            if args.debug:
                word = ngrams.iloc[w]["ngram"]
                print(f"{word} is ambiguous, entropy={S:.2f}")
            continue
        else:
            k = np.argmax(p)
            if is_junk_topic[k]:
                continue
            topic_counts[d, k] += data["counts"][d][i]
            if args.debug:
                word = ngrams.iloc[w]["ngram"]
                print(word, topics.iloc[k]["label"], data["counts"][d][i])

    n_words = topic_counts[d, :].sum()
    if args.debug:
        print(n_words)
    if n_words == 0:
        continue

    for i, w in enumerate(data["tokens"][d]):
        keywords[years[d], w, :] += topic_counts[d, :]
print(topic_counts)
print(topic_counts.mean(axis=0))
print(topic_counts.sum(axis=0))

np.save(opj(args.input, "keywords.npy"), keywords)
np.save(opj(args.input, "topics_counts.npy"), topic_counts)
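# Example invocation (the run directory name is only illustrative):
#   python etm_compile.py --input runs/etm --write-topics
# then label topics.csv by hand and re-run without --write-topics so
# the labels are not overwritten.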