# plot.py (2.2 KB)

  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib
  4. import matplotlib.pyplot as plt
  5. matplotlib.use("pgf")
  6. matplotlib.rcParams.update({
  7. "pgf.texsystem": "pdflatex",
  8. 'font.family': 'serif',
  9. "font.serif" : "Times New Roman",
  10. 'text.usetex': True,
  11. 'pgf.rcfonts': False,
  12. })
  13. def set_size(width, fraction=1, ratio = None):
  14. """ Set aesthetic figure dimensions to avoid scaling in latex.
  15. Parameters
  16. ----------
  17. width: float
  18. Width in pts
  19. fraction: float
  20. Fraction of the width which you wish the figure to occupy
  21. Returns
  22. -------
  23. fig_dim: tuple
  24. Dimensions of figure in inches
  25. """
  26. # Width of figure
  27. fig_width_pt = width * fraction
  28. # Convert from pt to inches
  29. inches_per_pt = 1 / 72.27
  30. # Golden ratio to set aesthetic figure height
  31. if ratio is None:
  32. ratio = (5 ** 0.5 - 1) / 2
  33. # Figure width in inches
  34. fig_width_in = fig_width_pt * inches_per_pt
  35. # Figure height in inches
  36. fig_height_in = fig_width_in * ratio
  37. return fig_width_in, fig_height_in
  38. fit = pd.read_csv('fit.csv')
  39. fig = plt.figure(figsize=set_size(450, 1, 1))
  40. axes = [fig.add_subplot(4,4,i+1) for i in range(4*4)]
  41. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  42. for i in range(4*4):
  43. ax = axes[i]
  44. row = i//4+1
  45. col = i%4+1
  46. label = f'confusion.{row}.{col}'
  47. ax.set_xticks([])
  48. ax.set_xticklabels([])
  49. ax.set_yticks([])
  50. ax.set_yticklabels([])
  51. ax.set_ylim(0,5)
  52. ax.set_xlim(0,1)
  53. low = fit[label].quantile(0.0275)
  54. high = fit[label].quantile(0.975)
  55. if row == 1:
  56. ax.xaxis.tick_top()
  57. ax.set_xticks([0.5])
  58. ax.set_xticklabels([speakers[col-1]])
  59. if row == 4:
  60. ax.set_xticks(np.linspace(0.25,1,3, endpoint = False))
  61. ax.set_xticklabels(np.linspace(0.25,1,3, endpoint = False))
  62. if col == 1:
  63. ax.set_yticks([2.5])
  64. ax.set_yticklabels([speakers[row-1]])
  65. ax.hist(fit[label], bins = np.linspace(0,1,40), density = True)
  66. ax.axvline(fit[label].mean(), linestyle = '--', linewidth = 0.5, color = '#333', alpha = 1)
  67. ax.text(0.5, 4.5, f'{low:.2f} - {high:.2f}', ha = 'center', va = 'center')
  68. fig.subplots_adjust(wspace = 0, hspace = 0)
  69. plt.savefig('confusion_fit.pdf')
  70. plt.show()