confusion_age.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. from scipy.special import logit, expit
  6. from scipy.stats import gamma
  7. import argparse
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. matplotlib.use("pgf")
  11. matplotlib.rcParams.update(
  12. {
  13. "pgf.texsystem": "pdflatex",
  14. "font.family": "serif",
  15. "font.serif": "Times New Roman",
  16. "text.usetex": True,
  17. "pgf.rcfonts": False,
  18. }
  19. )
  20. def set_size(width, fraction=1, ratio=None):
  21. fig_width_pt = width * fraction
  22. inches_per_pt = 1 / 72.27
  23. if ratio is None:
  24. ratio = (5**0.5 - 1) / 2
  25. fig_width_in = fig_width_pt * inches_per_pt
  26. fig_height_in = fig_width_in * ratio
  27. return fig_width_in, fig_height_in
  28. # parser = argparse.ArgumentParser(description="plot_pred")
  29. # parser.add_argument("data")
  30. # parser.add_argument("fit")
  31. # parser.add_argument("output")
  32. # args = parser.parse_args()
  33. fits = {
  34. # "vtc": np.load("output/aggregates_vtc_dev_siblings_effect.npz"),
  35. "vtc": np.load("output/aggregates_vtc_all_confusion_covariates.npz"),
  36. "lena": np.load("output/aggregates_lena_all_15_confusion_covariates.npz"),
  37. # "lena": np.load("output/aggregates_lena_dev_siblings_effect.npz"),
  38. # "lena": np.load("output/aggregates_lena_all_confusion_covariates.npz")
  39. }
  40. labels = {
  41. "vtc": "VTC",
  42. "lena": "LENA"
  43. }
  44. colors = {
  45. "vtc": "#377eb8",
  46. "lena": "#ff7f00",
  47. }
  48. fig, axes = plt.subplots(nrows=1, ncols=4, figsize=set_size(450, 1, 1*0.25), sharex=True, sharey=True)
  49. speakers = ["CHI", "OCH", "FEM", "MAL"]
  50. for i in range(4):
  51. ax = axes[i]
  52. row = i // 4
  53. col = i % 4
  54. ax.axhline(y=1, color="black", lw=0.5)
  55. ax.scatter([0], [1], color="black")
  56. split = np.linspace(-0.125,0.125,len(fits))
  57. for k, algo in enumerate(fits):
  58. fit = fits[algo]
  59. beta = np.exp(fit["beta_age_bin"]/10)
  60. low = np.quantile(beta, q=0.05/2, axis=0)
  61. high = np.quantile(beta, q=1-0.05/2, axis=0)
  62. mean = np.mean(beta, axis=0)
  63. age_bin = np.arange(low.shape[0])+1
  64. ax.scatter(age_bin+split[k], mean[:,i], label=labels[algo] if i==0 else None, color=colors[algo])
  65. ax.errorbar(age_bin+split[k], mean[:,i], (mean[:,i]-low[:,i],high[:,i]-mean[:,i]), color=colors[algo], ls="none")
  66. ax.set_xticks([])
  67. ax.set_xticklabels([])
  68. ax.set_xlim(-0.5, 3.5)
  69. ax.set_ylim(0.75, 1.25)
  70. if col == 0:
  71. ax.set_ylabel("Ratio")
  72. ax.set_xlabel("Age group (in years)")
  73. ax.set_yticks([0.8, 1, 1.2])
  74. ax.set_yticklabels([0.8, 1, 1.2])
  75. ax.tick_params(labelleft=True)
  76. else:
  77. ax.tick_params(labelleft=False)
  78. ax.set_title(speakers[col])
  79. ax.set_xticks(np.arange(0,4))
  80. ax.set_xticklabels([f"$[{i}-{i+1})$" if i<3 else f"$[{i}-\\infty($" for i in np.arange(0,4)], rotation=90)
  81. ax.text(
  82. 0.5, 0.9, f"{speakers[row]}$\\to${speakers[col]}",
  83. ha="center", transform = ax.transAxes
  84. )
  85. fig.subplots_adjust(wspace=0, hspace=0)
  86. fig.legend(bbox_to_anchor=(1, 0.1))
  87. plt.savefig("output/confusion_age.png", bbox_inches="tight", dpi=720)
  88. plt.savefig("output/confusion_age.pdf", bbox_inches="tight")
  89. plt.show()