contours.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #!/usr/bin/env python3
  2. import argparse
  3. import pandas as pd
  4. import pickle
  5. import numpy as np
  6. import matplotlib
  7. import matplotlib.pyplot as plt
  8. matplotlib.use("pgf")
  9. matplotlib.rcParams.update({
  10. "pgf.texsystem": "pdflatex",
  11. 'font.family': 'serif',
  12. "font.serif" : "Times New Roman",
  13. 'text.usetex': True,
  14. 'pgf.rcfonts': False,
  15. })
  16. import seaborn as sns
  17. def set_size(width, fraction=1, ratio = None):
  18. """ Set aesthetic figure dimensions to avoid scaling in latex.
  19. Parameters
  20. ----------
  21. width: float
  22. Width in pts
  23. fraction: float
  24. Fraction of the width which you wish the figure to occupy
  25. Returns
  26. -------
  27. fig_dim: tuple
  28. Dimensions of figure in inches
  29. """
  30. # Width of figure
  31. fig_width_pt = width * fraction
  32. # Convert from pt to inches
  33. inches_per_pt = 1 / 72.27
  34. # Golden ratio to set aesthetic figure height
  35. if ratio is None:
  36. ratio = (5 ** 0.5 - 1) / 2
  37. # Figure width in inches
  38. fig_width_in = fig_width_pt * inches_per_pt
  39. # Figure height in inches
  40. fig_height_in = fig_width_in * ratio
  41. return fig_width_in, fig_height_in
  42. parser = argparse.ArgumentParser(description = 'plot_pred')
  43. parser.add_argument('data')
  44. parser.add_argument('fit')
  45. parser.add_argument('output')
  46. args = parser.parse_args()
  47. with open(args.data, 'rb') as fp:
  48. data = pickle.load(fp)
  49. fit = pd.read_parquet(args.fit)
  50. fig = plt.figure(figsize=set_size(450, 1, 1))
  51. axes = [fig.add_subplot(4,4,i+1) for i in range(4*4)]
  52. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  53. n_groups = data['n_groups']
  54. for i in range(4*4):
  55. ax = axes[i]
  56. row = i//4+1
  57. col = i%4+1
  58. label = f'{row}.{col}'
  59. #mus = np.hstack([fit[f'alphas.{k}.{label}']/(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,n_groups+1)])
  60. #etas = np.hstack([(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,n_groups+1)])
  61. #etas = np.log10(etas)
  62. ax.set_xticks([])
  63. ax.set_xticklabels([])
  64. ax.set_yticks([])
  65. ax.set_yticklabels([])
  66. ax.set_ylim(0,3)
  67. ax.set_xlim(0,1)
  68. if row == 1:
  69. ax.xaxis.tick_top()
  70. ax.set_xticks([0.5])
  71. ax.set_xticklabels([speakers[col-1]])
  72. if row == 4:
  73. ax.set_xticks(np.linspace(0.25,1,3, endpoint = False))
  74. ax.set_xticklabels(np.linspace(0.25,1,3, endpoint = False))
  75. if col == 1:
  76. ax.set_yticks([1.5])
  77. ax.set_yticklabels([speakers[row-1]])
  78. if col == 4:
  79. ax.yaxis.tick_right()
  80. ax.set_yticks(np.arange(1,3))
  81. ax.set_yticklabels([f'10$^{i}' for i in np.arange(1,3)])
  82. kplt = sns.kdeplot(fit[f'mus.{label}'], fit[f'etas.{label}'].apply(np.log), shade=True, cmap="viridis", ax = ax)
  83. #kplt = sns.kdeplot(mus, etas, shade=True, cmap="viridis", ax = ax)
  84. kplt.set(xlabel = None, ylabel = None)
  85. ax.axvline(np.mean(fit[f'mus.{label}']), linestyle = '--', linewidth = 0.5, color = '#333', alpha = 1)
  86. fig.subplots_adjust(wspace = 0, hspace = 0)
  87. plt.savefig(args.output)
  88. plt.show()