# plot4.py (2.6 KB)
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import numpy as np
  4. import matplotlib
  5. import matplotlib.pyplot as plt
  6. matplotlib.use("pgf")
  7. matplotlib.rcParams.update({
  8. "pgf.texsystem": "pdflatex",
  9. 'font.family': 'serif',
  10. "font.serif" : "Times New Roman",
  11. 'text.usetex': True,
  12. 'pgf.rcfonts': False,
  13. })
  14. import seaborn as sns
  15. def set_size(width, fraction=1, ratio = None):
  16. """ Set aesthetic figure dimensions to avoid scaling in latex.
  17. Parameters
  18. ----------
  19. width: float
  20. Width in pts
  21. fraction: float
  22. Fraction of the width which you wish the figure to occupy
  23. Returns
  24. -------
  25. fig_dim: tuple
  26. Dimensions of figure in inches
  27. """
  28. # Width of figure
  29. fig_width_pt = width * fraction
  30. # Convert from pt to inches
  31. inches_per_pt = 1 / 72.27
  32. # Golden ratio to set aesthetic figure height
  33. if ratio is None:
  34. ratio = (5 ** 0.5 - 1) / 2
  35. # Figure width in inches
  36. fig_width_in = fig_width_pt * inches_per_pt
  37. # Figure height in inches
  38. fig_height_in = fig_width_in * ratio
  39. return fig_width_in, fig_height_in
  40. fit = pd.read_parquet('fit.parquet')
  41. fig = plt.figure(figsize=set_size(450, 1, 1))
  42. axes = [fig.add_subplot(4,4,i+1) for i in range(4*4)]
  43. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  44. for i in range(4*4):
  45. ax = axes[i]
  46. row = i//4+1
  47. col = i%4+1
  48. label = f'{row}.{col}'
  49. mus = np.hstack([fit[f'alphas.{k}.{label}']/(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,59)])
  50. etas = np.hstack([(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,59)])
  51. etas = np.log10(etas)
  52. ax.set_xticks([])
  53. ax.set_xticklabels([])
  54. ax.set_yticks([])
  55. ax.set_yticklabels([])
  56. ax.set_ylim(0,3)
  57. ax.set_xlim(0,1)
  58. if row == 1:
  59. ax.xaxis.tick_top()
  60. ax.set_xticks([0.5])
  61. ax.set_xticklabels([speakers[col-1]])
  62. if row == 4:
  63. ax.set_xticks(np.linspace(0.25,1,3, endpoint = False))
  64. ax.set_xticklabels(np.linspace(0.25,1,3, endpoint = False))
  65. if col == 1:
  66. ax.set_yticks([1.5])
  67. ax.set_yticklabels([speakers[row-1]])
  68. if col == 4:
  69. ax.yaxis.tick_right()
  70. ax.set_yticks(np.arange(1,3))
  71. ax.set_yticklabels([f'10$^{i}' for i in np.arange(1,3)])
  72. #sns.kdeplot(fit[f'mus.{label}'], fit[f'etas.{label}'].apply(np.log), shade=True, cmap="viridis", ax = ax)
  73. kplt = sns.kdeplot(mus, etas, shade=True, cmap="viridis", ax = ax)
  74. kplt.set(xlabel = None, ylabel = None)
  75. fig.subplots_adjust(wspace = 0, hspace = 0)
  76. plt.savefig('density.pdf')
  77. plt.show()