# susy_correlations.py (2.5 KB)

  1. import pandas as pd
  2. import numpy as np
  3. import tomotopy as tp
  4. import matplotlib
  5. from matplotlib import pyplot as plt
  6. matplotlib.use("pgf")
  7. matplotlib.rcParams.update(
  8. {
  9. "pgf.texsystem": "xelatex",
  10. "font.family": "serif",
  11. "font.serif": "Times New Roman",
  12. "text.usetex": True,
  13. "pgf.rcfonts": False,
  14. }
  15. )
  16. import seaborn as sns
  17. import textwrap
  18. def reverse_colourmap(cmap, name):
  19. """
  20. In:
  21. cmap, name
  22. Out:
  23. my_cmap_r
  24. Explanation:
  25. t[0] goes from 0 to 1
  26. row i: x y0 y1 -> t[0] t[1] t[2]
  27. /
  28. /
  29. row i+1: x y0 y1 -> t[n] t[1] t[2]
  30. so the inverse should do the same:
  31. row i+1: x y1 y0 -> 1-t[0] t[2] t[1]
  32. /
  33. /
  34. row i: x y1 y0 -> 1-t[n] t[2] t[1]
  35. """
  36. reverse = []
  37. k = []
  38. for key in cmap._segmentdata:
  39. k.append(key)
  40. channel = cmap._segmentdata[key]
  41. data = []
  42. for t in channel:
  43. data.append((1-t[0],t[2],t[1]))
  44. reverse.append(sorted(data))
  45. LinearL = dict(zip(k,reverse))
  46. my_cmap_r = matplotlib.colors.LinearSegmentedColormap(name, LinearL)
  47. return my_cmap_r
  48. orig_cmap = matplotlib.cm.RdBu
  49. mdl = tp.CTModel.load("output/hep-ct-75-0.1-0.001-130000-20/model")
  50. correlations = mdl.get_correlations()
  51. usages = pd.read_csv('output/supersymmetry_usages.csv')
  52. usages = usages.groupby("term").agg(topic=("topic", lambda x: x.tolist()))
  53. topics = usages.loc["supersymmetry"]["topic"] + [t for t in usages.loc["susy"]["topic"] if t not in usages.loc["supersymmetry"]["topic"]]
  54. descriptions = pd.read_csv("output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv")
  55. labels = descriptions.loc[topics]["description"].tolist()
  56. submatrix = correlations[topics,:][:,topics]
  57. for i in range(submatrix.shape[0]):
  58. submatrix[i,i] = np.nan
  59. w = textwrap.TextWrapper(width=20,break_long_words=False,replace_whitespace=False)
  60. wlabels = ["\n".join(words) for words in map(w.wrap, labels)]
  61. # shrunk_cmap = shiftedColorMap(orig_cmap, start=-submatrix.max(), midpoint=0, stop=+submatrix.max(), name='shrunk')
  62. reverse_cmap = reverse_colourmap(orig_cmap, "reverse")
  63. sns.heatmap(submatrix, xticklabels = wlabels, yticklabels = wlabels, annot=True, fmt=".2f", cmap=reverse_cmap, center=0, vmin=-1, vmax=1)
  64. plt.savefig("plots/susy_correlations.pdf", bbox_inches="tight")
  65. plt.savefig("plots/susy_correlations.pgf", bbox_inches="tight")
  66. plt.savefig("plots/susy_correlations.eps", bbox_inches="tight")