123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- #!/usr/bin/env python3
- import pandas as pd
- import pickle
- import numpy as np
- from scipy.special import logit, expit
- import argparse
- import matplotlib
- import matplotlib.pyplot as plt
# Route all output through the PGF backend so text is typeset by pdflatex
# in a serif (Times New Roman) face. "pgf.rcfonts": False stops the pgf
# backend from generating its own font setup from the rc font settings.
matplotlib.use("pgf")
_LATEX_RC = {
    "pgf.texsystem": "pdflatex",
    "font.family": "serif",
    "font.serif": "Times New Roman",
    "text.usetex": True,
    "pgf.rcfonts": False,
}
matplotlib.rcParams.update(_LATEX_RC)
def set_size(width, fraction=1, ratio=None):
    """Return a ``(width, height)`` figure size in inches for matplotlib.

    Args:
        width: target width in TeX points (72.27 pt per inch).
        fraction: share of ``width`` the figure should occupy.
        ratio: height / width; ``None`` selects the inverse golden ratio
            ``(sqrt(5) - 1) / 2``. Any other value, including 0, is used as-is.

    Returns:
        Tuple ``(width_in, height_in)`` in inches.
    """
    aspect = (5**0.5 - 1) / 2 if ratio is None else ratio
    width_in = (width * fraction) * (1 / 72.27)
    return width_in, width_in * aspect
SPEAKERS = ["CHI", "OCH", "FEM", "MAL"]


def _corpus_biases(fit, n_corpora, label):
    """Stack the ``corpus_bias.<c>.<label>`` columns of ``fit`` for every corpus.

    Args:
        fit: DataFrame read from the fit parquet file, one column per
            parameter component (presumably one row per posterior draw --
            TODO confirm).
        n_corpora: number of corpora; corpus columns are numbered 1..n_corpora.
        label: ``"<i>.<j>"`` suffix selecting one speaker pair.

    Returns:
        Float array of shape ``(n_corpora, len(fit))``. Row ``c`` holds column
        ``corpus_bias.{c + 1}.{label}``.

    Raises:
        KeyError: if an expected column is missing from ``fit``.
    """
    columns = [f"corpus_bias.{c + 1}.{label}" for c in range(n_corpora)]
    return fit[columns].to_numpy(dtype=float).T


def _style_axis(ax, row, col):
    """Fix the panel limits and label only the outer edges of the speaker grid.

    ``row`` and ``col`` are 1-based grid positions. The rules are:
    - the top row shows the column speaker's name above the panel;
    - the bottom row gets numeric x ticks at -1, 0 and 1;
    - the first column shows the row speaker's name;
    - every other tick is removed.
    """
    n = len(SPEAKERS)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(-2, 2)
    ax.set_ylim(0, 3)
    if row == 1:
        ax.xaxis.tick_top()
        ax.set_xticks([0])
        ax.set_xticklabels([SPEAKERS[col - 1]])
    if row == n:
        ticks = np.linspace(-1, 1, 3, endpoint=True)
        ax.set_xticks(ticks)
        ax.set_xticklabels(ticks)
    if col == 1:
        ax.set_yticks([1.5])
        ax.set_yticklabels([SPEAKERS[row - 1]])


def main():
    """Draw a speaker-by-speaker grid of per-corpus bias histograms and save it.

    Command-line arguments:
        data: pickled mapping with at least an ``n_corpora`` entry.
        fit: parquet file with ``corpus_bias.<c>.<i>.<j>`` columns.
        output: path the figure is written to.
        --selected-corpus: 0-based corpus to highlight in red. Negative values
            (default -1) highlight nothing; values >= n_corpora are rejected.
    """
    parser = argparse.ArgumentParser(description="plot_pred")
    parser.add_argument("data")
    parser.add_argument("fit")
    parser.add_argument("output")
    parser.add_argument("--selected-corpus", type=int, default=-1)
    args = parser.parse_args()

    # NOTE(review): pickle.load can execute arbitrary code; only pass
    # trusted data files here.
    with open(args.data, "rb") as fp:
        data = pickle.load(fp)
    fit = pd.read_parquet(args.fit)

    n_corpora = data["n_corpora"]
    selected = args.selected_corpus
    if selected >= n_corpora:
        # Previously surfaced as a bare IndexError mid-plot.
        parser.error(
            f"--selected-corpus must be below n_corpora ({n_corpora}), got {selected}"
        )

    n = len(SPEAKERS)
    bins = np.linspace(-2, 2, 40)  # identical for every histogram; build once
    fig = plt.figure(figsize=set_size(450, 1, 1))
    for i in range(n * n):
        ax = fig.add_subplot(n, n, i + 1)
        row, col = i // n + 1, i % n + 1
        # NOTE(review): the column suffix is "<col>.<row>", so the panel in grid
        # row r, column k shows b_{k r}; confirm this orientation is intended.
        biases = _corpus_biases(fit, n_corpora, f"{col}.{row}")
        _style_axis(ax, row, col)
        for c, draws in enumerate(biases):
            if c != selected:
                ax.hist(
                    draws,
                    bins=bins,
                    density=True,
                    histtype="step",
                    color="gray",
                    lw=0.5,
                    alpha=0.5,
                )
        if selected >= 0:
            ax.hist(
                biases[selected],
                bins=bins,
                density=True,
                histtype="step",
                color="red",
                lw=1,
            )

    fig.suptitle("Corpora biases $b_{ij}$ distributions")
    fig.subplots_adjust(wspace=0, hspace=0)
    # No plt.show(): this file selects the non-interactive pgf backend, under
    # which show() cannot display anything and only emits a warning.
    fig.savefig(args.output)


if __name__ == "__main__":
    main()
|