123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- #!/usr/bin/env python3
- import pandas as pd
- import pickle
- import numpy as np
- from scipy.special import logit, expit
- from scipy.stats import gamma
- import argparse
- import matplotlib
- import matplotlib.pyplot as plt
# Route all output through the PGF backend and typeset figure text with
# pdflatex so fonts match the LaTeX document the figure is embedded in.
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        # LaTeX engine settings for the pgf backend.
        "pgf.texsystem": "pdflatex",
        "pgf.rcfonts": False,
        # Text rendering and font selection.
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": "Times New Roman",
    }
)
def set_size(width, fraction=1, ratio=None):
    """Return a matplotlib ``figsize`` (inches) for a LaTeX width in points.

    Args:
        width: Target width in TeX points (1 in = 72.27 pt).
        fraction: Fraction of ``width`` the figure should occupy.
        ratio: Height / width ratio. When ``None``, the golden-ratio
            conjugate ``(sqrt(5) - 1) / 2`` is used.

    Returns:
        ``(width_in, height_in)`` tuple in inches.
    """
    aspect = (5**0.5 - 1) / 2 if ratio is None else ratio
    # Same evaluation order as points * (inches per point) to keep the
    # floating-point result unchanged.
    width_in = width * fraction * (1 / 72.27)
    return width_in, width_in * aspect
- # parser = argparse.ArgumentParser(description="plot_pred")
- # parser.add_argument("data")
- # parser.add_argument("fit")
- # parser.add_argument("output")
- # args = parser.parse_args()
def _load_npz(path):
    """Read every array from an ``.npz`` archive into a plain dict.

    ``np.load`` on an ``.npz`` returns a lazy ``NpzFile`` that keeps the file
    open and re-reads (and decompresses) a member on every ``[key]`` access.
    Loading eagerly inside ``with`` closes the handle immediately, and later
    lookups become cheap dict accesses.

    Args:
        path: Path to the ``.npz`` file.

    Returns:
        Dict mapping each array name stored in the archive to its ndarray.
    """
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}


# Model fits per algorithm; the plotting code reads the "alphas" and "mus"
# arrays from each entry. Alternative inputs kept for reference:
#   vtc:  output/aggregates_vtc_dev_siblings_effect.npz
#   lena: output/aggregates_lena_dev_siblings_effect.npz
#   lena: output/aggregates_lena_all_confusion_covariates.npz
fits = {
    "vtc": _load_npz("output/aggregates_vtc_all_confusion_covariates.npz"),
    "lena": _load_npz("output/aggregates_lena_all_15_confusion_covariates.npz"),
}
# Legend text and line colour for each algorithm key (same keys as `fits`).
labels = dict(vtc="VTC", lena="LENA")
colors = dict(vtc="#377eb8", lena="#ff7f00")
def _gamma_pdf_band(alphas, mus, x, tail=0.05):
    """Summarise a collection of Gamma densities evaluated on a grid.

    Each pair ``(alphas[j], mus[j])`` defines a Gamma density with shape
    ``alphas[j]`` and scale ``mus[j] / alphas[j]`` (so its mean is
    ``mus[j]``). All densities are evaluated on ``x`` in one broadcast call.

    Args:
        alphas: 1-D array of Gamma shape parameters, one per draw.
        mus: 1-D array of Gamma means, same length as ``alphas``.
        x: 1-D grid of evaluation points.
        tail: Total probability left outside the band; the band spans the
            ``tail / 2`` and ``1 - tail / 2`` pointwise quantiles.

    Returns:
        ``(mean, low, high)``, each of shape ``(len(x),)``: the pointwise mean
        density across draws and the lower/upper band edges.
    """
    alphas = np.asarray(alphas)
    scales = np.asarray(mus) / alphas
    # Column vector of grid points broadcast against the draws gives a
    # (len(x), n_draws) matrix: rows are grid points, columns are draws.
    pdf = gamma.pdf(np.asarray(x)[:, np.newaxis], alphas, scale=scales)
    low = np.quantile(pdf, q=tail / 2, axis=1)
    high = np.quantile(pdf, q=1 - tail / 2, axis=1)
    return pdf.mean(axis=1), low, high


fig = plt.figure(figsize=set_size(450, 1, 1))
ax = fig.add_subplot(4, 4, 1)
axes = [ax] + [fig.add_subplot(4, 4, i + 1, sharex=ax, sharey=ax) for i in range(1, 4 * 4)]
speakers = ["CHI", "OCH", "FEM", "MAL"]

# Density evaluation grid; identical for every panel and algorithm.
x = np.linspace(0.01, 0.99, 200, True)

# Every panel shares its x and y axes with axes[0], and shared axes also
# share their tick locator/formatter. Calling set_xticks per panel would make
# the panels overwrite each other's ticks (and leave the bottom-row labels on
# every row), so the ticks are set once for the whole grid here and
# per-panel visibility is handled with tick_params below.
xticks = np.linspace(0.25, 1, 3, endpoint=False)
axes[0].set_xticks(xticks)
axes[0].set_xticklabels([f"{t:g}" for t in xticks])
axes[0].set_yticks([])

for i, ax in enumerate(axes):
    row, col = divmod(i, 4)
    ax.axhline(y=0, color="black", lw=0.5)

    for algo, fit in fits.items():
        color = colors[algo]
        # Axis 1 of the fit arrays selects the group (index 0 is drawn as
        # "urban", index 1 as "rural"); axes 2 and 3 select the speaker pair
        # for this panel. Axis 0 is summarised over -- presumably posterior
        # draws; confirm against the model that wrote the .npz files.
        mean_urban, low_urban, high_urban = _gamma_pdf_band(
            fit["alphas"][:, 0, row, col], fit["mus"][:, 0, row, col], x
        )
        mean_rural, low_rural, high_rural = _gamma_pdf_band(
            fit["alphas"][:, 1, row, col], fit["mus"][:, 1, row, col], x
        )
        # Urban: dashed mean over a faint 95% band; rural: solid mean over a
        # darker band. Only the first panel's lines carry a label so the
        # figure legend lists each algorithm once.
        ax.plot(x, mean_urban, label=labels[algo] if i == 0 else None, color=color, ls="dashed")
        ax.fill_between(x, low_urban, high_urban, alpha=0.05, color=color)
        ax.plot(x, mean_rural, color=color)
        ax.fill_between(x, low_rural, high_rural, alpha=0.2, color=color)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 10)
    # Only the bottom row shows x tick marks and labels.
    if row != 3:
        ax.tick_params(axis="x", bottom=False, labelbottom=False)
    if col == 0:
        ax.set_ylabel(speakers[row])
    if row == 0:
        ax.set_title(speakers[col])
    ax.text(
        0.5, 0.9, f"{speakers[row]}$\\to${speakers[col]}",
        ha="center", transform=ax.transAxes,
    )
# Remove the gaps between panels so the 4x4 grid reads as one matrix.
fig.subplots_adjust(wspace=0, hspace=0)
# Collects the labelled lines from all axes (the plotting loop labels only
# the first panel's lines, giving one entry per algorithm).
fig.legend(bbox_to_anchor=(1, 0.1))
fig.savefig("output/confusion_ratio.png", bbox_inches="tight", dpi=720)
fig.savefig("output/confusion_ratio.pdf", bbox_inches="tight")
# The script forces the non-interactive "pgf" backend, on which plt.show()
# can only emit a warning; release the figure instead.
plt.close(fig)
|