123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- #!/usr/bin/env python3
- import pandas as pd
- import pickle
- import numpy as np
- from scipy.special import logit, expit
- from scipy.stats import gamma, beta
- import argparse
- import matplotlib
- import matplotlib.pyplot as plt
- # matplotlib.use("pgf")
- # matplotlib.rcParams.update(
- # {
- # "pgf.texsystem": "pdflatex",
- # "font.family": "serif",
- # "font.serif": "Times New Roman",
- # "text.usetex": True,
- # "pgf.rcfonts": False,
- # }
- # )
- import seaborn as sns
- from os.path import join as opj
def set_size(width, fraction=1, ratio=None):
    """Convert a width in points into a (width, height) pair in inches.

    Args:
        width: Available width in points (72.27 points per inch).
        fraction: Share of ``width`` the figure should occupy.
        ratio: Height divided by width. When ``None``, the value
            ``(sqrt(5) - 1) / 2`` (about 0.618) is used.

    Returns:
        Tuple ``(width_in_inches, height_in_inches)``.
    """
    pt_to_inch = 1 / 72.27
    aspect = (5**0.5 - 1) / 2 if ratio is None else ratio

    width_in = width * fraction * pt_to_inch
    return width_in, width_in * aspect
# Command-line interface.
parser = argparse.ArgumentParser(
    description="Draw pair plots of the parameter samples stored in an .npz archive."
)
# Required: without it the script would try to open "output/None.npz".
parser.add_argument(
    "--samples",
    required=True,
    help="base name of the archive to read, i.e. output/<samples>.npz",
)
# NOTE(review): --output is accepted but not referenced anywhere in the
# visible script; kept so existing invocations keep working.
parser.add_argument("--output", help="accepted for compatibility")
args = parser.parse_args()

# Speaker labels; their order defines the class index used in the column names.
speakers = ["CHI", "OCH", "FEM", "MAL"]
n_classes = len(speakers)

# Archive of sampled parameter arrays, looked up by name further down.
samples = np.load(opj("output", f"{args.samples}.npz"))
def pair_plot(data, output):
    """Draw a KDE pair plot of ``data`` and save it as EPS and PNG.

    Args:
        data: Table whose columns are plotted against each other; passed
            unchanged to ``seaborn.pairplot``.
        output: Label inserted in the file names. The figure is written to
            ``output/pair_plot_<output>.eps`` and
            ``output/pair_plot_<output>.png`` (the PNG at 720 dpi).
    """
    # Clear whatever figure is current before drawing.
    plt.clf()
    plt.cla()
    # NOTE(review): pairplot sizes its own figure from its height/aspect
    # arguments, so this rcParams entry may not affect the saved plot -- confirm.
    plt.rcParams["figure.figsize"] = set_size(450, 1, 1)

    # pairplot creates a new figure, which becomes the current one that
    # plt.savefig() writes out below.
    sns.pairplot(data, kind="kde")
    try:
        plt.savefig(opj("output", f"pair_plot_{output}.eps"), bbox_inches="tight")
        plt.savefig(
            opj("output", f"pair_plot_{output}.png"), bbox_inches="tight", dpi=720
        )
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close()
# Flatten the arrays in `samples` into one named column per plotted quantity.
# The first axis of every array is kept whole (presumably the draw index --
# confirm against the code that writes the archive).

# Look each array up once: indexing the object returned by np.load() on an
# .npz archive re-reads the array from the archive on every access.
alpha_child_level = samples["alpha_child_level"]
mu_pop_level = samples["mu_pop_level"]
alpha_corpus_level = samples["alpha_corpus_level"]
mu_corpus_level = samples["mu_corpus_level"]
truth_vocs = samples["truth_vocs"]
# "mus" is optional in the archive.
mus = samples["mus"] if "mus" in samples else None

# Columns that do not depend on the speaker class.
data = {
    "alpha_dev": samples["alpha_dev"],
    "sigma_dev": samples["sigma_dev"],
    "beta_dev": samples["beta_dev"],
    "child_dev_age": samples["child_dev_age"][:, 0],
}

for i in range(n_classes):
    data[f"alpha_child_level.{i}"] = alpha_child_level[:, i]
    data[f"mu_pop_level.{i}"] = mu_pop_level[:, i]
    data[f"truth_vocs.{i}"] = truth_vocs[:, 0, i]

    # Corpus-level arrays are read at indices 0..n_classes-2 and stored under
    # class index i + 1, matching the range(1, n_classes) plots below.
    # NOTE(review): alpha_corpus_level is indexed [:, 0, i] while
    # mu_corpus_level is indexed [:, i, 0] -- confirm the axis order.
    if i < n_classes - 1:
        data[f"alpha_corpus_level.{i + 1}"] = alpha_corpus_level[:, 0, i]
        data[f"mu_corpus_level.{i + 1}"] = mu_corpus_level[:, i, 0]

    if mus is not None:
        for j in range(n_classes):
            data[f"mus.{i}.{j}"] = mus[:, i, j]

data = pd.DataFrame(data)

pair_plot(data[["alpha_dev", "sigma_dev", "beta_dev", "child_dev_age"]], "dev")

# One plot per speaker class except the first.
for i in range(1, n_classes):
    cols = [
        f"alpha_corpus_level.{i}",
        f"alpha_child_level.{i}",
        f"mu_pop_level.{i}",
        f"mu_corpus_level.{i}",
    ]
    pair_plot(data[cols], f"hierarchical_{speakers[i]}")

cols = [f"truth_vocs.{i}" for i in range(n_classes)]
pair_plot(data[cols], "truth_vocs")

# The mus.* columns exist only when the archive contains "mus"; skip the plot
# otherwise instead of failing with a KeyError on the column selection.
if mus is not None:
    cols = [
        f"mus.{i}.{j}"
        for i in range(n_classes - 1)
        for j in range(n_classes - 1)
    ]
    pair_plot(data[cols], "mus")
|