# validation.py
#!/usr/bin/env python3
# Validation figure: compares model-predicted counts against ground truth and
# raw VTC output on a 4x4 speaker grid (see the plotting loop below).
import pandas as pd
import pickle
import numpy as np
import argparse
import matplotlib
import matplotlib.pyplot as plt
# Select the PGF backend so the figure is rendered via LaTeX (pdflatex) and
# can be embedded in a LaTeX document with matching serif/Times typography.
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    "font.serif" : "Times New Roman",
    'text.usetex': True,
    'pgf.rcfonts': False,  # keep the rcParams fonts instead of pgf defaults
})
from sklearn.linear_model import LinearRegression
  17. def set_size(width, fraction=1, ratio = None):
  18. fig_width_pt = width * fraction
  19. inches_per_pt = 1 / 72.27
  20. if ratio is None:
  21. ratio = (5 ** 0.5 - 1) / 2
  22. fig_width_in = fig_width_pt * inches_per_pt
  23. fig_height_in = fig_width_in * ratio
  24. return fig_width_in, fig_height_in
# ---------------------------------------------------------------------------
# CLI: `data` is a pickle holding 'truth', 'vtc' and 'n_validation';
# `fit` is a parquet of posterior draws (columns 'pred.<k>.<col>.<row>');
# `output` is the path the figure is saved to.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description = 'plot_pred')
parser.add_argument('data')
parser.add_argument('fit')
parser.add_argument('output')
args = parser.parse_args()

# NOTE(review): pickle.load executes arbitrary code from the file — only run
# this script on trusted input files.
with open(args.data, 'rb') as fp:
    data = pickle.load(fp)
fit = pd.read_parquet(args.fit)

# 4x4 grid of panels; presumably rows index one speaker class and columns
# another (true vs attributed) — TODO confirm against the producer of `data`.
fig = plt.figure(figsize=set_size(450, 1, 1))
axes = [fig.add_subplot(4,4,i+1) for i in range(4*4)]
speakers = ['CHI', 'OCH', 'FEM', 'MAL']
colors = ['red', 'orange', 'green', 'blue']
# Number of validation points used from each array.
n_values = data['n_validation']

for i in range(4*4):
    ax = axes[i]
    row = i//4+1  # 1-based panel row
    col = i%4+1   # 1-based panel column
    # Ground truth for the row's speaker.
    truth = data['truth'][:n_values,row-1]
    #vtc = np.sum(data['vtc'][:n_values,i,:], axis = 1)
    # Raw VTC output for the (col, row) speaker pair.
    vtc = np.array(data['vtc'][:n_values,col-1,row-1])
    # Posterior draws per validation point; parquet columns are 1-based:
    # 'pred.<point>.<col>.<row>'. Resulting shape: (n_values, n_draws).
    pred_dist = np.array([fit[f'pred.{k+1}.{col}.{row}'] for k in range(n_values)])
    # Central 68% credible interval bounds and posterior mean per point.
    errors = np.quantile(pred_dist, [(1-0.68)/2, 1-(1-0.68)/2], axis = 1)
    pred = np.mean(pred_dist, axis = 1)
    # Linear fit of posterior mean against truth; drives the reference line.
    regr = LinearRegression()
    regr.fit(truth.reshape(-1, 1), pred)
    # p = np.zeros(n_values)
    # for k in range(n_values):
    # dy = np.abs(pred[k]-vtc[k])
    # more_extreme = pred_dist[k,np.abs(pred_dist[k,:]-pred[k])>dy]
    # p[k] = len(more_extreme)/pred_dist.shape[1]
    # chi_squared = -2*np.nansum(np.ma.log(p))/n_values
    # print(p.shape)
    # print(p)
    # print(chi_squared)
    # log_lik = np.array([fit[f'log_lik.{k+1}.{i+1}.{i+1}'] for k in range(n_values)])
    # print(log_lik)
    # log_lik = np.mean(log_lik)
    # print(log_lik)
    # print(np.exp(log_lik))
    # Keep only points visible on the [1, 1000] log-log axes.
    mask = (truth > 1) & (pred > 1)
    ax.set_xlim(1,1000)
    ax.set_ylim(1,1000)
    ax.set_xscale('log')
    ax.set_yscale('log')
    # Three x samples are enough: the plotted curve is a straight line.
    slopes_x = np.logspace(0,3,num=3)
    # NOTE(review): only the fitted slope is drawn; the intercept is
    # discarded. If a through-origin line is intended, the model should be
    # fit with fit_intercept=False — confirm.
    ax.plot(slopes_x, regr.coef_[0]*slopes_x, color = 'black', lw = 0.75)
    #ax.scatter(truth[mask], pred[mask], s = 1, color = 'black')
    #ax.errorbar(truth[mask], pred[mask], [pred[mask]-errors[0,mask],errors[1,mask]-pred[mask]], ls='none', elinewidth = 0.25, color = '#333')
    # Shade the credible band, clamped to the axis limits; fill_between
    # expects x in increasing order, hence the argsort.
    x = truth[mask]
    y1 = np.maximum(errors[0,mask],1)
    y2 = np.minimum(errors[1, mask], 1000)
    srt = np.argsort(x)
    ax.fill_between(x[srt], y1[srt], y2[srt], color = '#ccc', alpha = 0.5)
    # Scatter raw VTC vs truth (strictly positive pairs only, for log axes).
    ax.scatter(truth[(vtc > 0) & (truth > 0)], vtc[(vtc > 0) & (truth > 0)], s = 1, color = colors[col-1])
    #r2 = np.corrcoef(vtc, pred)[0,1]**2
    #baseline = np.corrcoef(vtc, truth)[0,1]**2
    #print(speakers[i], r2, baseline)
    #ax.text(2, 400, speakers[i], ha = 'left', va = 'center')
    # Clear all ticks, then restore them only on the outer grid edges below.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    if col == 4:
        ax.yaxis.tick_right()
    if row == 1:
        # Top row: speaker name as a column header, tick at the panel center.
        ax.xaxis.tick_top()
        ax.set_xticks([10**1.5])
        ax.set_xticklabels([speakers[col-1]])
    if row == 4:
        # Bottom row: decade tick labels 10^1..10^3. The comprehension's `i`
        # shadows the subplot index but does not leak (Python 3 scoping).
        ax.set_xticks(np.power(10, np.arange(1,4)))
        ax.set_xticklabels([f'10$^{i}$' for i in [1,2,3]])
    if col == 1:
        # Left column: speaker name as a row header.
        ax.set_yticks([10**1.5])
        ax.set_yticklabels([speakers[row-1]])
    if col == 4:
        # Right column: decade tick labels on the right-hand side.
        ax.yaxis.tick_right()
        ax.set_yticks(np.power(10, np.arange(1,4)))
        ax.set_yticklabels([f'10$^{i}$' for i in [1,2,3]])

plt.xlabel('')
#fig.suptitle("$\mu_{eff}$ distribution")
# Panels touch each other: headers/ticks above only exist on the outer edges.
fig.subplots_adjust(wspace = 0, hspace = 0)
plt.savefig(args.output)
# NOTE(review): the pgf backend is non-interactive, so show() is a no-op here.
plt.show()
  107. plt.show()