123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
#!/usr/bin/env python3
import argparse
import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.special import logit, expit
from scipy.stats import gamma
# Render through the PGF backend so that text in the figure is typeset by
# LaTeX (pdflatex) rather than by matplotlib's own text engine.
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        # Do not let matplotlib inject its rc font setup into the PGF preamble.
        "pgf.rcfonts": False,
    }
)
def set_size(width, fraction=1, ratio=None):
    """Return a ``(width, height)`` figure size in inches.

    Parameters
    ----------
    width : float
        Available width in TeX points (1 inch = 72.27 pt).
    fraction : float, optional
        Fraction of ``width`` the figure should occupy. Defaults to 1.
    ratio : float or None, optional
        Height-to-width ratio. When ``None``, the golden ratio
        ``(sqrt(5) - 1) / 2`` is used.

    Returns
    -------
    tuple of float
        ``(fig_width_in, fig_height_in)``, suitable for ``figsize=``.
    """
    fig_width_pt = width * fraction
    inches_per_pt = 1 / 72.27

    if ratio is None:
        # Golden ratio.
        ratio = (5**0.5 - 1) / 2

    fig_width_in = fig_width_pt * inches_per_pt
    fig_height_in = fig_width_in * ratio
    return fig_width_in, fig_height_in
# parser = argparse.ArgumentParser(description="plot_pred")
# parser.add_argument("data")
# parser.add_argument("fit")
# parser.add_argument("output")
# args = parser.parse_args()
# Saved aggregates for each algorithm, keyed by a short algorithm id.
# The same ids index `labels` and `colors` below. The commented-out paths
# are alternative inputs kept for reference.
fits = {
    # "vtc": np.load("output/aggregates_vtc_dev_siblings_effect.npz"),
    "vtc": np.load("output/aggregates_vtc_all_confusion_covariates.npz"),
    "lena": np.load("output/aggregates_lena_all_15_confusion_covariates.npz"),
    # "lena": np.load("output/aggregates_lena_dev_siblings_effect.npz"),
    # "lena": np.load("output/aggregates_lena_all_confusion_covariates.npz")
}

# Legend label for each algorithm id.
labels = {
    "vtc": "VTC",
    "lena": "LENA",
}

# Marker / error-bar colour for each algorithm id.
colors = {
    "vtc": "#377eb8",
    "lena": "#ff7f00",
}
# One panel per speaker class, all sharing both axes.
speakers = ["CHI", "OCH", "FEM", "MAL"]
fig, axes = plt.subplots(
    nrows=1,
    ncols=4,
    figsize=set_size(450, 1, 1 * 0.25),
    sharex=True,
    sharey=True,
)

# Horizontal offsets so that markers from different algorithms sit side by
# side instead of overlapping at the same x position.
split = np.linspace(-0.125, 0.125, len(fits))

# Summarise each fit once, up front: the summaries do not depend on the
# panel, and every `fit[...]` lookup on a loaded .npz re-reads the array.
# NOTE(review): after reducing over axis 0 the arrays are indexed as
# [age_bin, panel] below, so "beta_age_bin" is presumably
# (samples, age_bins, speakers) -- confirm against the model output.
summaries = {}
for algo, fit in fits.items():
    beta = np.exp(fit["beta_age_bin"] / 10)
    summaries[algo] = (
        np.mean(beta, axis=0),
        np.quantile(beta, q=0.05 / 2, axis=0),
        np.quantile(beta, q=1 - 0.05 / 2, axis=0),
    )

for i in range(4):
    ax = axes[i]
    row, col = divmod(i, 4)

    # Reference: horizontal line at ratio 1 and a fixed point at x = 0.
    ax.axhline(y=1, color="black", lw=0.5)
    ax.scatter([0], [1], color="black")

    for k, algo in enumerate(fits):
        mean, low, high = summaries[algo]
        # Estimated bins are plotted at x = 1, 2, ... (x = 0 is the
        # reference point drawn above).
        age_bin = np.arange(low.shape[0]) + 1
        x = age_bin + split[k]

        ax.scatter(
            x,
            mean[:, i],
            # Label only the first panel so the figure legend gets exactly
            # one entry per algorithm.
            label=labels[algo] if i == 0 else None,
            color=colors[algo],
        )
        # Asymmetric error bars spanning the 2.5%-97.5% quantile interval.
        ax.errorbar(
            x,
            mean[:, i],
            (mean[:, i] - low[:, i], high[:, i] - mean[:, i]),
            color=colors[algo],
            ls="none",
        )

    ax.set_xlim(-0.5, 3.5)
    ax.set_ylim(0.75, 1.25)

    if col == 0:
        # Only the left-most panel carries axis labels and y tick labels.
        ax.set_ylabel("Ratio")
        ax.set_xlabel("Age group (in years)")
        ax.set_yticks([0.8, 1, 1.2])
        ax.set_yticklabels([0.8, 1, 1.2])
        ax.tick_params(labelleft=True)
    else:
        ax.tick_params(labelleft=False)

    ax.set_title(speakers[col])

    ax.set_xticks(np.arange(0, 4))
    ax.set_xticklabels(
        [
            f"$[{age}-{age+1})$" if age < 3 else f"$[{age}-\\infty($"
            for age in np.arange(0, 4)
        ],
        rotation=90,
    )

    # NOTE(review): with a single row of panels `row` is always 0, so this
    # annotation always starts with "CHI" -- confirm that is intended.
    ax.text(
        0.5,
        0.9,
        f"{speakers[row]}$\\to${speakers[col]}",
        ha="center",
        transform=ax.transAxes,
    )

fig.subplots_adjust(wspace=0, hspace=0)
fig.legend(bbox_to_anchor=(1, 0.1))

plt.savefig("output/confusion_age.png", bbox_inches="tight", dpi=720)
plt.savefig("output/confusion_age.pdf", bbox_inches="tight")
plt.show()
|