pair_plots.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. from scipy.special import logit, expit
  6. from scipy.stats import gamma, beta
  7. import argparse
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. # matplotlib.use("pgf")
  11. # matplotlib.rcParams.update(
  12. # {
  13. # "pgf.texsystem": "pdflatex",
  14. # "font.family": "serif",
  15. # "font.serif": "Times New Roman",
  16. # "text.usetex": True,
  17. # "pgf.rcfonts": False,
  18. # }
  19. # )
  20. import seaborn as sns
  21. from os.path import join as opj
  22. def set_size(width, fraction=1, ratio=None):
  23. fig_width_pt = width * fraction
  24. inches_per_pt = 1 / 72.27
  25. if ratio is None:
  26. ratio = (5**0.5 - 1) / 2
  27. fig_width_in = fig_width_pt * inches_per_pt
  28. fig_height_in = fig_width_in * ratio
  29. return fig_width_in, fig_height_in
  30. parser = argparse.ArgumentParser()
  31. parser.add_argument("--samples")
  32. parser.add_argument("--output")
  33. args = parser.parse_args()
  34. speakers = ["CHI", "OCH", "FEM", "MAL"]
  35. n_classes = len(speakers)
  36. samples = np.load(opj("output", f"{args.samples}.npz"))
  37. def pair_plot(data, output):
  38. plt.clf()
  39. plt.cla()
  40. plt.rcParams['figure.figsize']=set_size(450, 1, 1)
  41. sns.pairplot(data, kind="kde")
  42. plt.savefig(opj("output", f"pair_plot_{output}.eps"), bbox_inches="tight")
  43. plt.savefig(opj("output", f"pair_plot_{output}.png"), bbox_inches="tight", dpi=720)
  44. data = {}
  45. for i in range(n_classes):
  46. data[f"alpha_child_level.{i}"] = samples["alpha_child_level"][:,i]
  47. data[f"mu_pop_level.{i}"] = samples["mu_pop_level"][:,i]
  48. if i < n_classes-1:
  49. data[f"alpha_corpus_level.{i+1}"] = samples["alpha_corpus_level"][:,0,i]
  50. data[f"mu_corpus_level.{i+1}"] = samples["mu_corpus_level"][:,i,0]
  51. data[f"truth_vocs.{i}"] = samples["truth_vocs"][:,0,i]
  52. data["alpha_dev"] = samples["alpha_dev"]
  53. data["sigma_dev"] = samples["sigma_dev"]
  54. data["beta_dev"] = samples["beta_dev"]
  55. data["child_dev_age"] = samples["child_dev_age"][:,0]
  56. if "mus" in samples:
  57. for j in range(n_classes):
  58. data[f"mus.{i}.{j}"] = samples["mus"][:,i,j]
  59. data = pd.DataFrame(data)
  60. pair_plot(data[["alpha_dev", "sigma_dev", "beta_dev", "child_dev_age"]], f"dev")
  61. for i in range (1,n_classes):
  62. cols = [
  63. f"alpha_corpus_level.{i}", f"alpha_child_level.{i}",
  64. f"mu_pop_level.{i}", f"mu_corpus_level.{i}"
  65. ]
  66. pair_plot(data[cols], f"hierarchical_{speakers[i]}")
  67. cols = [f"truth_vocs.{i}" for i in range(n_classes)]
  68. pair_plot(data[cols], "truth_vocs")
  69. cols = [f"mus.{i}.{j}" for i in range(n_classes-1) for j in range(n_classes-1)]
  70. pair_plot(data[cols], "mus")