Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtime to a minimum, but we recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

confusion_probs.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. from scipy.special import logit, expit
  6. import argparse
  7. import matplotlib
  8. import matplotlib.pyplot as plt
  9. matplotlib.use("pgf")
  10. matplotlib.rcParams.update(
  11. {
  12. "pgf.texsystem": "pdflatex",
  13. "font.family": "serif",
  14. "font.serif": "Times New Roman",
  15. "text.usetex": True,
  16. "pgf.rcfonts": False,
  17. }
  18. )
  19. def set_size(width, fraction=1, ratio=None):
  20. fig_width_pt = width * fraction
  21. inches_per_pt = 1 / 72.27
  22. if ratio is None:
  23. ratio = (5**0.5 - 1) / 2
  24. fig_width_in = fig_width_pt * inches_per_pt
  25. fig_height_in = fig_width_in * ratio
  26. return fig_width_in, fig_height_in
  27. parser = argparse.ArgumentParser(description="plot_pred")
  28. parser.add_argument("data")
  29. parser.add_argument("fit")
  30. parser.add_argument("output")
  31. args = parser.parse_args()
  32. with open(args.data, "rb") as fp:
  33. data = pickle.load(fp)
  34. fit = pd.read_parquet(args.fit)
  35. fig = plt.figure(figsize=set_size(450, 1, 1))
  36. axes = [fig.add_subplot(4, 4, i + 1) for i in range(4 * 4)]
  37. speakers = ["CHI", "OCH", "FEM", "MAL"]
  38. n_groups = data["n_groups"]
  39. for i in range(4 * 4):
  40. ax = axes[i]
  41. row = i // 4 + 1
  42. col = i % 4 + 1
  43. label = f"{col}.{row}"
  44. # if args.group is None:
  45. # data = np.hstack([fit[f'alphas.{k}.{label}']/(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,n_groups+1)])
  46. # else:
  47. # data = fit[f'alphas.{args.group}.{label}']/(fit[f'alphas.{args.group}.{label}']+fit[f'betas.{args.group}.{label}']).values
  48. # data = np.hstack([(fit[f'group_mus.{k}.{label}']).values for k in range(1,59)])
  49. # data = fit[f'mus.{label}'].values
  50. if "fixed_bias.1.1" in fit.columns:
  51. data = expit(
  52. np.hstack(
  53. [
  54. logit(fit[f"probs.{k+1}.{label}"].values)
  55. + fit[f"fixed_bias.{label}"].values
  56. for k in range(n_groups)
  57. ]
  58. )
  59. )
  60. else:
  61. data = np.hstack([fit[f"probs.{k+1}.{label}"].values for k in range(n_groups)])
  62. ax.set_xticks([])
  63. ax.set_xticklabels([])
  64. ax.set_yticks([])
  65. ax.set_yticklabels([])
  66. ax.set_ylim(0, 5)
  67. ax.set_xlim(0, 1)
  68. low = np.quantile(data, 0.0275)
  69. high = np.quantile(data, 0.975)
  70. if row == 1:
  71. ax.xaxis.tick_top()
  72. ax.set_xticks([0.5])
  73. ax.set_xticklabels([speakers[col - 1]])
  74. if row == 4:
  75. ax.set_xticks(np.linspace(0.25, 1, 3, endpoint=False))
  76. ax.set_xticklabels(np.linspace(0.25, 1, 3, endpoint=False))
  77. if col == 1:
  78. ax.set_yticks([2.5])
  79. ax.set_yticklabels([speakers[row - 1]])
  80. ax.hist(data, bins=np.linspace(0, 1, 40), density=True, histtype="step")
  81. ax.axvline(np.mean(data), linestyle="--", linewidth=0.5, color="#333", alpha=1)
  82. ax.text(0.5, 4.5, f"{low:.2f} - {high:.2f}", ha="center", va="center")
  83. fig.suptitle("$p_{ij}$ distribution")
  84. fig.subplots_adjust(wspace=0, hspace=0)
  85. plt.savefig(args.output)
  86. plt.show()