corpora_biases.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. from scipy.special import logit, expit
  6. import argparse
  7. import matplotlib
  8. import matplotlib.pyplot as plt
  9. matplotlib.use("pgf")
  10. matplotlib.rcParams.update(
  11. {
  12. "pgf.texsystem": "pdflatex",
  13. "font.family": "serif",
  14. "font.serif": "Times New Roman",
  15. "text.usetex": True,
  16. "pgf.rcfonts": False,
  17. }
  18. )
  19. def set_size(width, fraction=1, ratio=None):
  20. fig_width_pt = width * fraction
  21. inches_per_pt = 1 / 72.27
  22. if ratio is None:
  23. ratio = (5**0.5 - 1) / 2
  24. fig_width_in = fig_width_pt * inches_per_pt
  25. fig_height_in = fig_width_in * ratio
  26. return fig_width_in, fig_height_in
  27. parser = argparse.ArgumentParser(description="plot_pred")
  28. parser.add_argument("data")
  29. parser.add_argument("fit")
  30. parser.add_argument("output")
  31. parser.add_argument("--selected-corpus", type=int, default=-1)
  32. args = parser.parse_args()
  33. with open(args.data, "rb") as fp:
  34. data = pickle.load(fp)
  35. fit = pd.read_parquet(args.fit)
  36. fig = plt.figure(figsize=set_size(450, 1, 1))
  37. axes = [fig.add_subplot(4, 4, i + 1) for i in range(4 * 4)]
  38. speakers = ["CHI", "OCH", "FEM", "MAL"]
  39. n_groups = data["n_groups"]
  40. n_corpora = data["n_corpora"]
  41. for i in range(4 * 4):
  42. ax = axes[i]
  43. row = i // 4 + 1
  44. col = i % 4 + 1
  45. label = f"{col}.{row}"
  46. corpora = np.zeros((n_corpora, len(fit)))
  47. for c in range(n_corpora):
  48. corpora[c, :] = fit[f"corpus_bias.{c+1}.{label}"].values
  49. ax.set_xticks([])
  50. ax.set_xticklabels([])
  51. ax.set_yticks([])
  52. ax.set_yticklabels([])
  53. ax.set_ylim(0, 3)
  54. ax.set_xlim(-2, 2)
  55. if row == 1:
  56. ax.xaxis.tick_top()
  57. ax.set_xticks([0])
  58. ax.set_xticklabels([speakers[col - 1]])
  59. if row == 4:
  60. ax.set_xticks(np.linspace(-1, 1, 3, endpoint=True))
  61. ax.set_xticklabels(np.linspace(-1, 1, 3, endpoint=True))
  62. if col == 1:
  63. ax.set_yticks([1.5])
  64. ax.set_yticklabels([speakers[row - 1]])
  65. for c in range(n_corpora):
  66. if c != args.selected_corpus:
  67. ax.hist(
  68. corpora[c, :],
  69. bins=np.linspace(-2, 2, 40),
  70. density=True,
  71. histtype="step",
  72. color="gray",
  73. lw=0.5,
  74. alpha=0.5,
  75. )
  76. if args.selected_corpus >= 0:
  77. ax.hist(
  78. corpora[args.selected_corpus, :],
  79. bins=np.linspace(-2, 2, 40),
  80. density=True,
  81. histtype="step",
  82. color="red",
  83. lw=1,
  84. )
  85. # ax.axvline(np.mean(data), linestyle="--", linewidth=0.5, color="#333", alpha=1)
  86. # ax.text(0.5, 4.5, f"{low:.2f} - {high:.2f}", ha="center", va="center")
  87. fig.suptitle("Corpora biases $b_{ij}$ distributions")
  88. fig.subplots_adjust(wspace=0, hspace=0)
  89. plt.savefig(args.output)
  90. plt.show()