# tsne.py (3.6 KB)
  1. import pandas as pd
  2. import numpy as np
  3. import tomotopy as tp
  4. import matplotlib
  5. from matplotlib import pyplot as plt
  6. matplotlib.use("pgf")
  7. matplotlib.rcParams.update(
  8. {
  9. "pgf.texsystem": "xelatex",
  10. "font.family": "serif",
  11. "font.serif": "Times New Roman",
  12. "text.usetex": True,
  13. "pgf.rcfonts": False,
  14. }
  15. )
  16. from adjustText import adjust_text
  17. import textwrap
  18. from sklearn.manifold import TSNE
  19. articles = pd.read_parquet("inspire-harvest/database/articles.parquet")[["article_id", "pacs_codes", "categories"]]
  20. articles["article_id"] = articles["article_id"].astype(int)
  21. topics = pd.read_parquet("output/hep-ct-75-0.1-0.001-130000-20/topics_0.parquet")
  22. topics["article_id"] = topics["article_id"].astype(int)
  23. topics["topics"] = topics["probs"]
  24. topics = topics.merge(articles, how="inner", left_on = "article_id", right_on = "article_id")
  25. topics["categories"] = topics["categories"].map(
  26. lambda l: (
  27. [x in l for x in ["Theory-HEP", "Phenomenology-HEP", "Experiment-HEP"]]
  28. )
  29. )
  30. X = np.stack(topics["topics"].values)
  31. Y = np.stack(topics["categories"].values).astype(int)
  32. num = np.outer(X.sum(axis=0),Y.sum(axis=0))/(X.shape[0]**2)
  33. den = np.tensordot(X, Y, axes=([0],[0]))/X.shape[0]
  34. npmi = np.log(num)/np.log(den)-1
  35. topic_main_category = npmi.argmax(axis=1).astype(int)
  36. usages = pd.read_csv('output/supersymmetry_usages.csv')
  37. usages = usages.groupby("term").agg(topic=("topic", lambda x: x.tolist()))
  38. susy_topics = usages.loc["supersymmetry"]["topic"] + [t for t in usages.loc["susy"]["topic"] if t not in usages.loc["supersymmetry"]["topic"]]
  39. descriptions = pd.read_csv("output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv")
  40. labels = descriptions.loc[susy_topics]["description"].tolist()
  41. edges = np.array([False]*len(topic_main_category))
  42. edges[susy_topics]=True
  43. edges = ["black" if edge else "none" for edge in edges]
  44. mdl = tp.CTModel.load("output/hep-ct-75-0.1-0.001-130000-20/model")
  45. correlations = mdl.get_correlations()
  46. colors=['#377eb8', '#ff7f00', '#4daf4a']
  47. cats=["Theory", "Phenomenology", "Experiment"]
  48. # perplexity = 40 better identifies the different lumps
  49. tsne = TSNE(n_components=2, metric="precomputed", random_state=714, perplexity=40)
  50. points = tsne.fit_transform(1-correlations)
  51. from sklearn.linear_model import LinearRegression
  52. reg = LinearRegression()
  53. reg.fit(points, topic_main_category)
  54. angle = np.arctan(reg.coef_[0]/reg.coef_[1])-np.pi/2
  55. m = np.array([[np.cos(angle), np.sin(angle)], [-np.sin(angle),np.cos(angle)]])
  56. points=points@m
  57. fig, axes = plt.subplots(nrows=2,ncols=1,sharex=True,gridspec_kw={"height_ratios": [5, 1]})
  58. for i, cat in enumerate(cats):
  59. axes[0].scatter(
  60. points[topic_main_category==i,0],
  61. points[topic_main_category==i,1],
  62. color=colors[i],
  63. label=cat,
  64. edgecolors=[edges[i[0]] for i in np.argwhere(topic_main_category==i) if i!=np.nan]
  65. )
  66. texts = []
  67. for i,topic in enumerate(susy_topics):
  68. texts.append(
  69. axes[0].annotate(
  70. labels[i],
  71. xy=(points[topic,0],points[topic,1]),
  72. xytext=(points[topic,0],points[topic,1]+0.25),
  73. size="small"
  74. )
  75. )
  76. adjust_text(texts,ax=axes[0])
  77. import seaborn as sns
  78. sns.kdeplot(data=[points[topic_main_category==i,0] for i in range(3)],ax=axes[1],legend=False)
  79. plt.subplots_adjust(wspace=0, hspace=0)
  80. for i in range(2):
  81. axes[i].set_xticklabels([])
  82. axes[i].set_yticklabels([])
  83. axes[0].set_ylabel("y-axis")
  84. axes[1].set_xlabel("x-axis")
  85. axes[1].set_ylabel("Density")
  86. axes[0].legend()
  87. fig.savefig(f"plots/topics_tsne.eps", bbox_inches="tight")
  88. fig.savefig(f"plots/topics_tsne.png", bbox_inches="tight")