123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- from os.path import join as opj
- from re import A
- import pandas as pd
- import numpy as np
- from matplotlib import pyplot as plt
- import matplotlib
# Select the pgf backend and route all text through LaTeX so the figures
# match the fonts of the surrounding document.
matplotlib.use("pgf")
matplotlib.rcParams["pgf.texsystem"] = "pdflatex"
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["font.serif"] = "Times New Roman"
matplotlib.rcParams["text.usetex"] = True
# Do not let the pgf backend derive its font setup from the rc font settings.
matplotlib.rcParams["pgf.rcfonts"] = False
- import pickle
- import argparse
# Command-line interface. Parsed at module level because `validation` reads
# `args.output` when it builds the figure file name.
parser = argparse.ArgumentParser(
    description="Save validation scatter plots comparing per-group totals "
    "with sampled predictions."
)
parser.add_argument(
    "--vtc",
    required=True,  # without it the script would look for output/None.pickle
    help="basename of the inputs: output/<vtc>.pickle and output/<vtc>.npz",
)
parser.add_argument(
    "--output",
    help="suffix used in the figure name: output/validation_<variable>_<output>.eps",
)
args = parser.parse_args()

# Speaker labels, in the column order of the count arrays; also the order of
# the panels (row-major) in the 2x2 figure.
speakers = ["CHI", "OCH", "FEM", "MAL"]
def _draw_validation_artists(ax, observed, truth, mu, low, up):
    """Draw one speaker's points, identity line and predictive interval on ``ax``.

    Shared by the main panel and its zoomed inset so both show the same
    artists, created in the same order.
    """
    # Hollow squares: ground-truth totals against the algorithm's totals.
    ax.scatter(observed, truth, color="#f781bf", s=3.5, marker="s", facecolors="none")
    # Identity line as a visual reference.
    ax.plot([0, 200], [0, 200], color="black", lw=0.5, ls="dashed")
    # Hollow circles: mean of the sampled values for each group.
    ax.scatter(observed, mu, s=3, facecolors="none", edgecolors="#377eb8")
    # Bars span the [low, up] quantile interval, centred on its midpoint.
    center = (low + up) / 2
    ax.errorbar(observed, center, (center - low, up - center), ls="none", lw=0.5)


def validation(data, samples, given_lambda=False):
    """Plot per-group totals against sampled predictions, one panel per speaker.

    Clip-level totals are summed per group, then a 2x2 figure is drawn (one
    panel per entry of the module-level ``speakers``), each with a zoomed
    inset of the low-count region. The figure is written to
    ``output/validation_<variable>_<args.output>.eps`` and then closed.

    Args:
        data: mapping providing ``n_groups``, ``n_clips``, ``group``
            (1-based group index per clip), ``truth_total`` and
            ``algo_total`` (per-clip counts, one column per speaker).
        samples: mapping from variable name to an array indexed as
            ``[draw, group, speaker]`` (assumed from the indexing below --
            confirm against the sampler output).
        given_lambda: if true, plot ``sim_vocs_given_lambda`` instead of
            ``sim_vocs``.

    Reads the module-level ``speakers`` and ``args``.
    """
    n_speakers = len(speakers)
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
    n_cols = axes.shape[1]

    # Sum clip-level totals into their group; group ids are 1-based.
    truth = np.zeros((data["n_groups"], n_speakers))
    algo = np.zeros((data["n_groups"], n_speakers))
    groups = data["group"]
    truth_total = data["truth_total"]
    algo_total = data["algo_total"]
    for c in range(data["n_clips"]):
        g = groups[c] - 1
        truth[g] += truth_total[c]
        algo[g] += algo_total[c]

    variable = "sim_vocs_given_lambda" if given_lambda else "sim_vocs"
    # Fetch once: indexing an NpzFile re-reads the array from disk every time.
    draws = samples[variable]

    # Central 68% interval of the sampled values.
    tail = (1 - 0.68) / 2
    inset_lo, inset_hi = -2.5, 25

    for i, ax in enumerate(axes.flat):
        row, col = divmod(i, n_cols)
        speaker_draws = draws[:, :, i]
        mu = np.mean(speaker_draws, axis=0)
        low, up = np.quantile(speaker_draws, [tail, 1 - tail], axis=0)

        _draw_validation_artists(ax, algo[:, i], truth[:, i], mu, low, up)

        # Axes are shared, so label only the bottom-left panel.
        if row == 1 and col == 0:
            ax.set_xlabel("algo (obs.)")
            ax.set_ylabel("algo (pred)")
        ax.set_ylim(-5, 200)
        ax.set_xlim(-5, 200)
        ax.set_title(speakers[i])

        # Zoomed inset on the low-count region, where most points cluster.
        axins = ax.inset_axes(
            [0.57, 0.06, 0.4, 0.4],
            xlim=(inset_lo, inset_hi),
            ylim=(inset_lo, inset_hi),
            xticklabels=[],
            yticklabels=[],
        )
        _draw_validation_artists(axins, algo[:, i], truth[:, i], mu, low, up)
        ax.indicate_inset_zoom(axins, edgecolor="black")

    fig.savefig(f"output/validation_{variable}_{args.output}.eps", bbox_inches="tight")
    # Release the figure; otherwise it stays registered with pyplot across calls.
    plt.close(fig)
-
# Load the model input data that the samples were generated from.
# NOTE(security): pickle.load can execute arbitrary code; only point --vtc at
# files produced by this project's own pipeline.
data = {}
with open(f"output/{args.vtc}.pickle", "rb") as f:
    data["vtc"] = pickle.load(f)

# np.load on an .npz returns a lazily-read archive holding an open file;
# the context manager closes it once both figures have been written.
with np.load(f"output/{args.vtc}.npz") as vtc_samples:
    samples = {"vtc": vtc_samples}
    validation(data["vtc"], samples["vtc"], True)
    validation(data["vtc"], samples["vtc"], False)
|