confusion_probs.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. from scipy.special import logit, expit
  6. import argparse
  7. import matplotlib
  8. import matplotlib.pyplot as plt
  matplotlib.use("pgf")

  # Typeset all figure text with pdflatex in a Times New Roman serif face
  # (presumably so the output matches the typography of a LaTeX document).
  # Each key is assigned individually; this is equivalent to one
  # rcParams.update() call with the same entries in the same order.
  matplotlib.rcParams["pgf.texsystem"] = "pdflatex"
  matplotlib.rcParams["font.family"] = "serif"
  matplotlib.rcParams["font.serif"] = "Times New Roman"
  matplotlib.rcParams["text.usetex"] = True
  matplotlib.rcParams["pgf.rcfonts"] = False
  def set_size(width, fraction=1, ratio=None):
      """Return a ``(width, height)`` figure size in inches.

      Args:
          width: Full available width in TeX points (72.27 pt per inch).
          fraction: Share of ``width`` the figure should occupy.
          ratio: Height-to-width ratio. When ``None``, the inverse golden
              ratio ``(sqrt(5) - 1) / 2`` (about 0.618) is used.

      Returns:
          Tuple ``(width_in, height_in)`` suitable for ``figsize``.
      """
      pt_to_in = 1 / 72.27
      if ratio is None:
          ratio = (5**0.5 - 1) / 2
      # Convert the scaled width from points to inches, then derive height.
      width_in = (width * fraction) * pt_to_in
      return width_in, width_in * ratio
  # Command-line interface. The description previously read "plot_pred",
  # which does not describe what this script does.
  parser = argparse.ArgumentParser(
      description="Plot the distribution of each p_ij probability, pooled "
      "over groups, as a 4x4 grid of histograms."
  )
  parser.add_argument(
      "data", help="pickled input dict; must provide the 'n_groups' entry"
  )
  parser.add_argument(
      "fit", help="parquet file of fitted values with 'probs.<k>.<i>.<j>' columns"
  )
  parser.add_argument("output", help="path of the figure file to write")
  args = parser.parse_args()

  # NOTE: pickle.load can execute arbitrary code embedded in the file;
  # only pass trusted, locally produced files as `data`.
  with open(args.data, "rb") as fp:
      data = pickle.load(fp)
  fit = pd.read_parquet(args.fit)
  # A square figure 450 pt wide (ratio 1 in set_size) holding a 4x4 grid
  # of panels, created in subplot order 1..16 (row by row).
  fig = plt.figure(figsize=set_size(450, 1, 1))
  axes = [fig.add_subplot(4, 4, n) for n in range(1, 4 * 4 + 1)]

  # Tick labels for the grid edges, indexed by (row - 1) or (col - 1).
  speakers = ["CHI", "OCH", "FEM", "MAL"]
  n_groups = data["n_groups"]
  def _pooled_probs(fit, label, n_groups, fixed_bias):
      """Return the values of one p_ij cell, concatenated over all groups.

      Args:
          fit: Table whose columns include ``probs.<k>.<label>`` for
              k = 1..n_groups (plus ``fixed_bias.<label>`` when
              *fixed_bias* is true). Presumably one row per posterior
              draw -- confirm against the fitting code.
          label: Cell identifier of the form ``"<a>.<b>"``.
          n_groups: Number of groups to pool.
          fixed_bias: When true, add ``fixed_bias.<label>`` to every
              group's values on the logit scale, then map back with the
              logistic function (``expit``).

      Returns:
          1-D numpy array with the per-group columns stacked end to end.
      """
      per_group = [fit[f"probs.{k}.{label}"].values for k in range(1, n_groups + 1)]
      if not fixed_bias:
          return np.hstack(per_group)
      bias = fit[f"fixed_bias.{label}"].values
      return expit(np.hstack([logit(p) + bias for p in per_group]))


  # Loop-invariant: whether the fit carries a fixed logit-scale bias term.
  has_fixed_bias = "fixed_bias.1.1" in fit.columns
  # Bottom-row tick positions/labels: 0.25, 0.5, 0.75.
  edge_ticks = np.linspace(0.25, 1, 3, endpoint=False)

  for i, ax in enumerate(axes):
      row = i // 4 + 1
      col = i % 4 + 1
      # NOTE(review): the column index comes first in the label, so panel
      # (row, col) shows p_{col,row}; confirm against the model's indexing.
      label = f"{col}.{row}"
      # Named `samples` so the module-level `data` loaded from args.data is
      # not overwritten inside the loop.
      samples = _pooled_probs(fit, label, n_groups, has_fixed_bias)

      # Hide all ticks by default; edge panels re-enable some below.
      ax.set_xticks([])
      ax.set_xticklabels([])
      ax.set_yticks([])
      ax.set_yticklabels([])
      ax.set_ylim(0, 5)
      ax.set_xlim(0, 1)

      # Central 95% interval. The lower bound was previously the 2.75%
      # quantile, which made the interval asymmetric.
      low, high = np.quantile(samples, [0.025, 0.975])

      if row == 1:
          # Top row: column speaker name shown above the panel.
          ax.xaxis.tick_top()
          ax.set_xticks([0.5])
          ax.set_xticklabels([speakers[col - 1]])
      if row == 4:
          ax.set_xticks(edge_ticks)
          ax.set_xticklabels(edge_ticks)
      if col == 1:
          # First column: row speaker name on the left.
          ax.set_yticks([2.5])
          ax.set_yticklabels([speakers[row - 1]])

      ax.hist(samples, bins=np.linspace(0, 1, 40), density=True, histtype="step")
      ax.axvline(np.mean(samples), linestyle="--", linewidth=0.5, color="#333", alpha=1)
      ax.text(0.5, 4.5, f"{low:.2f} - {high:.2f}", ha="center", va="center")
  fig.suptitle("$p_{ij}$ distribution")
  # Panels share edges: no horizontal or vertical gap between subplots.
  fig.subplots_adjust(wspace=0, hspace=0)
  # The backend is forced to the non-interactive "pgf" backend at the top of
  # the script, so plt.show() could only emit a warning; just save the file.
  fig.savefig(args.output)