# tsne.py
  1. import pandas as pd
  2. import numpy as np
  3. import tomotopy as tp
  4. import matplotlib
  5. from matplotlib import pyplot as plt
  6. matplotlib.use("pgf")
  7. matplotlib.rcParams.update(
  8. {
  9. "pgf.texsystem": "xelatex",
  10. "font.family": "serif",
  11. "font.serif": "Times New Roman",
  12. "text.usetex": True,
  13. "pgf.rcfonts": False,
  14. }
  15. )
  16. from adjustText import adjust_text
  17. import textwrap
  18. from sklearn.manifold import TSNE
  19. articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["article_id", "pacs_codes", "categories"]]
  20. articles["article_id"] = articles["article_id"].astype(int)
  21. topics = pd.read_parquet("output/hep-ct-75-0.1-0.001-130000-20/topics_0.parquet")
  22. topics["article_id"] = topics["article_id"].astype(int)
  23. topics["topics"] = topics["probs"]
  24. topics = topics.merge(articles, how="inner", left_on = "article_id", right_on = "article_id")
  25. topics["categories"] = topics["categories"].map(
  26. lambda l: (
  27. [x in l for x in ["Theory-HEP", "Phenomenology-HEP", "Experiment-HEP"]]
  28. )
  29. )
  30. X = np.stack(topics["topics"].values)
  31. Y = np.stack(topics["categories"].values).astype(int)
  32. cat_topic_mean = np.zeros((Y.shape[1], X.shape[1]))
  33. for i in range(3):
  34. cat_topic_mean[i] = X[Y[:,i]==1,:].mean(axis=0)
  35. topic_main_category = cat_topic_mean.argmax(axis=0).astype(int)
  36. usages = pd.read_csv('output/supersymmetry_usages.csv')
  37. usages = usages.groupby("term").agg(topic=("topic", lambda x: x.tolist()))
  38. susy_topics = usages.loc["supersymmetry"]["topic"] + [t for t in usages.loc["susy"]["topic"] if t not in usages.loc["supersymmetry"]["topic"]]
  39. descriptions = pd.read_csv("output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv")
  40. labels = descriptions.loc[susy_topics]["description"].tolist()
  41. edges = np.array([False]*len(topic_main_category))
  42. edges[susy_topics]=True
  43. edges = ["black" if edge else "none" for edge in edges]
  44. mdl = tp.CTModel.load("output/hep-ct-75-0.1-0.001-130000-20/model")
  45. correlations = mdl.get_correlations()
  46. colors=['#377eb8', '#ff7f00', '#4daf4a']
  47. cats=["Theory", "Phenomenology", "Experiment"]
  48. tsne = TSNE(n_components=2, metric="precomputed", random_state=714, perplexity=40)
  49. points = tsne.fit_transform(1-correlations)
  50. from sklearn.linear_model import LinearRegression
  51. reg = LinearRegression()
  52. reg.fit(points, topic_main_category)
  53. angle = np.arctan(reg.coef_[0]/reg.coef_[1])-np.pi/2
  54. m = np.array([[np.cos(angle), np.sin(angle)], [-np.sin(angle),np.cos(angle)]])
  55. points=points@m
  56. fig, axes = plt.subplots(nrows=2,ncols=1,sharex=True,gridspec_kw={"height_ratios": [5, 1]},figsize=[6.4,5])
  57. for i, cat in enumerate(cats):
  58. axes[0].scatter(
  59. points[topic_main_category==i,0],
  60. points[topic_main_category==i,1],
  61. color=colors[i],
  62. label=cat,
  63. edgecolors=[edges[i[0]] for i in np.argwhere(topic_main_category==i) if i!=np.nan]
  64. )
  65. texts = []
  66. for i,topic in enumerate(susy_topics):
  67. texts.append(
  68. axes[0].annotate(
  69. labels[i],
  70. xy=(points[topic,0],points[topic,1]),
  71. xytext=(points[topic,0],points[topic,1]+0.25),
  72. size="small"
  73. )
  74. )
  75. adjust_text(texts,ax=axes[0])
  76. import seaborn as sns
  77. sns.kdeplot(data=[points[topic_main_category==i,0] for i in range(3)],ax=axes[1],legend=False)
  78. plt.subplots_adjust(wspace=0, hspace=0)
  79. for i in range(2):
  80. axes[i].set_xticklabels([])
  81. axes[i].set_yticklabels([])
  82. fig.legend()
  83. plt.savefig(f"plots/topics_tsne.eps", bboxes_inches="tight")
  84. plt.savefig(f"plots/topics_tsne.png", bboxes_inches="tight")