confusion_probs.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. import argparse
  6. import matplotlib
  7. import matplotlib.pyplot as plt
  8. matplotlib.use("pgf")
  9. matplotlib.rcParams.update({
  10. "pgf.texsystem": "pdflatex",
  11. 'font.family': 'serif',
  12. "font.serif" : "Times New Roman",
  13. 'text.usetex': True,
  14. 'pgf.rcfonts': False,
  15. })
  16. def set_size(width, fraction=1, ratio = None):
  17. fig_width_pt = width * fraction
  18. inches_per_pt = 1 / 72.27
  19. if ratio is None:
  20. ratio = (5 ** 0.5 - 1) / 2
  21. fig_width_in = fig_width_pt * inches_per_pt
  22. fig_height_in = fig_width_in * ratio
  23. return fig_width_in, fig_height_in
  24. parser = argparse.ArgumentParser(description = 'plot_pred')
  25. parser.add_argument('data')
  26. parser.add_argument('fit')
  27. parser.add_argument('output')
  28. args = parser.parse_args()
  29. with open(args.data, 'rb') as fp:
  30. data = pickle.load(fp)
  31. fit = pd.read_parquet(args.fit)
  32. fig = plt.figure(figsize=set_size(450, 1, 1))
  33. axes = [fig.add_subplot(4,4,i+1) for i in range(4*4)]
  34. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  35. n_groups = data['n_groups']
  36. for i in range(4*4):
  37. ax = axes[i]
  38. row = i//4+1
  39. col = i%4+1
  40. label = f'{col}.{row}'
  41. #if args.group is None:
  42. # data = np.hstack([fit[f'alphas.{k}.{label}']/(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,n_groups+1)])
  43. #else:
  44. # data = fit[f'alphas.{args.group}.{label}']/(fit[f'alphas.{args.group}.{label}']+fit[f'betas.{args.group}.{label}']).values
  45. #data = np.hstack([(fit[f'group_mus.{k}.{label}']).values for k in range(1,59)])
  46. #data = fit[f'mus.{label}'].values
  47. data = np.hstack([fit[f'probs.{k+1}.{label}'].values for k in range(n_groups)])
  48. ax.set_xticks([])
  49. ax.set_xticklabels([])
  50. ax.set_yticks([])
  51. ax.set_yticklabels([])
  52. ax.set_ylim(0,5)
  53. ax.set_xlim(0,1)
  54. low = np.quantile(data, 0.0275)
  55. high = np.quantile(data, 0.975)
  56. if row == 1:
  57. ax.xaxis.tick_top()
  58. ax.set_xticks([0.5])
  59. ax.set_xticklabels([speakers[col-1]])
  60. if row == 4:
  61. ax.set_xticks(np.linspace(0.25,1,3, endpoint = False))
  62. ax.set_xticklabels(np.linspace(0.25,1,3, endpoint = False))
  63. if col == 1:
  64. ax.set_yticks([2.5])
  65. ax.set_yticklabels([speakers[row-1]])
  66. ax.hist(data, bins = np.linspace(0,1,40), density = True, histtype = 'step')
  67. ax.axvline(np.mean(data), linestyle = '--', linewidth = 0.5, color = '#333', alpha = 1)
  68. ax.text(0.5, 4.5, f'{low:.2f} - {high:.2f}', ha = 'center', va = 'center')
  69. fig.suptitle("$p_{ij}$ distribution")
  70. fig.subplots_adjust(wspace = 0, hspace = 0)
  71. plt.savefig(args.output)
  72. plt.show()