confusion_ratio.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. #!/usr/bin/env python3
  2. import pandas as pd
  3. import pickle
  4. import numpy as np
  5. from scipy.special import logit, expit
  6. from scipy.stats import gamma
  7. import argparse
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. matplotlib.use("pgf")
  11. matplotlib.rcParams.update(
  12. {
  13. "pgf.texsystem": "pdflatex",
  14. "font.family": "serif",
  15. "font.serif": "Times New Roman",
  16. "text.usetex": True,
  17. "pgf.rcfonts": False,
  18. }
  19. )
  20. def set_size(width, fraction=1, ratio=None):
  21. fig_width_pt = width * fraction
  22. inches_per_pt = 1 / 72.27
  23. if ratio is None:
  24. ratio = (5**0.5 - 1) / 2
  25. fig_width_in = fig_width_pt * inches_per_pt
  26. fig_height_in = fig_width_in * ratio
  27. return fig_width_in, fig_height_in
  28. # parser = argparse.ArgumentParser(description="plot_pred")
  29. # parser.add_argument("data")
  30. # parser.add_argument("fit")
  31. # parser.add_argument("output")
  32. # args = parser.parse_args()
  33. fits = {
  34. # "vtc": np.load("output/aggregates_vtc_dev_siblings_effect.npz"),
  35. "vtc": np.load("output/aggregates_vtc_all_confusion_covariates.npz"),
  36. "lena": np.load("output/aggregates_lena_all_15_confusion_covariates.npz"),
  37. # "lena": np.load("output/aggregates_lena_dev_siblings_effect.npz"),
  38. # "lena": np.load("output/aggregates_lena_all_confusion_covariates.npz")
  39. }
  40. labels = {
  41. "vtc": "VTC",
  42. "lena": "LENA"
  43. }
  44. colors = {
  45. "vtc": "#377eb8",
  46. "lena": "#ff7f00",
  47. }
  48. fig = plt.figure(figsize=set_size(450, 1, 1))
  49. ax = fig.add_subplot(4, 4, 1)
  50. axes = [ax] + [fig.add_subplot(4, 4, i + 1, sharex=ax, sharey=ax) for i in range(1, 4 * 4)]
  51. speakers = ["CHI", "OCH", "FEM", "MAL"]
  52. for i in range(4 * 4):
  53. ax = axes[i]
  54. row = i // 4
  55. col = i % 4
  56. # if args.group is None:
  57. # data = np.hstack([fit[f'alphas.{k}.{label}']/(fit[f'alphas.{k}.{label}']+fit[f'betas.{k}.{label}']).values for k in range(1,n_groups+1)])
  58. # else:
  59. # data = fit[f'alphas.{args.group}.{label}']/(fit[f'alphas.{args.group}.{label}']+fit[f'betas.{args.group}.{label}']).values
  60. # data = np.hstack([(fit[f'group_mus.{k}.{label}']).values for k in range(1,59)])
  61. # data = fit[f'mus.{label}'].values
  62. ax.axhline(y=0, color="black", lw=0.5)
  63. for algo in fits:
  64. fit = fits[algo]
  65. alphas_urban = fit["alphas"][:,0,row,col]
  66. mus_urban = fit["mus"][:,0,row,col]
  67. alphas_rural = fit["alphas"][:,1,row,col]
  68. mus_rural = fit["mus"][:,1,row,col]
  69. scale_urban = mus_urban/alphas_urban
  70. scale_rural = mus_rural/alphas_rural
  71. x = np.linspace(0.01,0.99,200,True)
  72. pdf_urban = np.zeros((len(x), len(alphas_urban)))
  73. pdf_rural = np.zeros((len(x), len(alphas_rural)))
  74. for k in range(len(x)):
  75. # pdf_urban[k,:] = gamma.logpdf(x[k], alphas_urban, np.zeros(len(alphas_urban)), scale_urban)
  76. # pdf_rural[k,:] = gamma.logpdf(x[k], alphas_rural, np.zeros(len(alphas_rural)), scale_rural)
  77. pdf_urban[k,:] = gamma.pdf(x[k], alphas_urban, np.zeros(len(alphas_urban)), scale_urban)
  78. pdf_rural[k,:] = gamma.pdf(x[k], alphas_rural, np.zeros(len(alphas_rural)), scale_rural)
  79. # log_ratio = pdf_rural - pdf_urban
  80. # low = np.quantile(log_ratio, q=0.0275, axis=1)
  81. # high = np.quantile(log_ratio, q=0.975, axis=1)
  82. # mean = np.mean(log_ratio, axis=1)
  83. # ax.plot(x, mean, label=labels[algo] if i==0 else None, color=colors[algo])
  84. # ax.fill_between(x, low, high, alpha=0.2, color=colors[algo])
  85. low_urban = np.quantile(pdf_urban, q=0.05/2, axis=1)
  86. high_urban = np.quantile(pdf_urban, q=1-0.05/2, axis=1)
  87. mean_urban = np.mean(pdf_urban, axis=1)
  88. low_rural = np.quantile(pdf_rural, q=0.05/2, axis=1)
  89. high_rural = np.quantile(pdf_rural, q=1-0.05/2, axis=1)
  90. mean_rural = np.mean(pdf_rural, axis=1)
  91. ax.plot(x, mean_urban, label=labels[algo] if i==0 else None, color=colors[algo], ls="dashed")
  92. ax.fill_between(x, low_urban, high_urban, alpha=0.05, color=colors[algo])
  93. ax.plot(x, mean_rural, color=colors[algo])
  94. ax.fill_between(x, low_rural, high_rural, alpha=0.2, color=colors[algo])
  95. ax.set_xticks([])
  96. ax.set_xticklabels([])
  97. ax.set_yticks([])
  98. ax.set_yticklabels([])
  99. ax.set_xlim(0, 1)
  100. ax.set_ylim(0, 10)
  101. if col == 0:
  102. ax.set_ylabel(speakers[row])
  103. if row == 0:
  104. ax.set_title(speakers[col])
  105. if row == 3:
  106. ax.set_xticks(np.linspace(0.25, 1, 3, endpoint=False))
  107. ax.set_xticklabels(np.linspace(0.25, 1, 3, endpoint=False))
  108. ax.text(
  109. 0.5, 0.9, f"{speakers[row]}$\\to${speakers[col]}",
  110. ha="center", transform = ax.transAxes
  111. )
  112. fig.subplots_adjust(wspace=0, hspace=0)
  113. fig.legend(bbox_to_anchor=(1, 0.1))
  114. plt.savefig("output/confusion_ratio.png", bbox_inches="tight", dpi=720)
  115. plt.savefig("output/confusion_ratio.pdf", bbox_inches="tight")
  116. plt.show()