correlations_child_level.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. from email.errors import MalformedHeaderDefect
  2. from logging import logMultiprocessing
  3. import pandas as pd
  4. import numpy as np
  5. from matplotlib import pyplot as plt
  6. import matplotlib
  7. matplotlib.use("pgf")
  8. matplotlib.rcParams.update(
  9. {
  10. "pgf.texsystem": "pdflatex",
  11. "font.family": "serif",
  12. "font.serif": "Times New Roman",
  13. "text.usetex": True,
  14. "pgf.rcfonts": False,
  15. }
  16. )
  17. import pickle
  18. from os.path import join as opj
  19. import argparse
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument("--vtc")
  22. parser.add_argument("--lena")
  23. args = parser.parse_args()
  24. speakers = ["CHI", "OCH", "FEM", "MAL"]
  25. cb_colors = [
  26. "#377eb8",
  27. "#ff7f00",
  28. "#f781bf",
  29. "#4daf4a",
  30. "#a65628",
  31. "#984ea3",
  32. "#999999",
  33. "#e41a1c",
  34. "#dede00",
  35. ]
  36. corpora = {
  37. 'bergelson': 0, 'cougar': 1, 'fausey-trio': 2, 'lucid': 3, 'warlaumont': 4, 'winnipeg': 5
  38. }
  39. corpora_names = {
  40. corpora[corpus]: corpus
  41. for corpus in corpora
  42. }
  43. def correlations(data, samples):
  44. mu = {
  45. x: samples[x]["mu_child_level"]
  46. for x in ["vtc_raw", "lena_raw", "vtc_calibrated", "lena_calibrated"]
  47. }
  48. beta_och = {
  49. x: samples[x]["beta_sib_och"]
  50. for x in ["vtc_raw", "lena_raw", "vtc_calibrated", "lena_calibrated"]
  51. }
  52. beta_adu = {
  53. x: samples[x]["beta_sib_adu"]
  54. for x in ["vtc_raw", "lena_raw", "vtc_calibrated", "lena_calibrated"]
  55. }
  56. sibs = np.zeros(data["vtc_raw"]["n_children"])
  57. for k in range(data["vtc_raw"]["n_recs"]):
  58. sibs[data["vtc_raw"]["children"].iloc[k]-1] = data["vtc_raw"]["siblings"].iloc[k]
  59. has_sibs_data = sibs>=0
  60. fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=([6.4, 6.4]))
  61. for i in range(2):
  62. for j in range(2):
  63. if i<1-j:
  64. axes[i, j].axis('off')
  65. continue
  66. a = 1-i+1
  67. b = j+2
  68. mu_r = {}
  69. low = {}
  70. high = {}
  71. r = {}
  72. bins = np.linspace(-1,1,100)
  73. for x in ["vtc_raw", "lena_raw", "vtc_calibrated", "lena_calibrated"]:
  74. beta_a = beta_och[x] if a==1 else beta_adu[x]/10
  75. beta_b = beta_och[x] if b==1 else beta_adu[x]/10
  76. n_samples = mu[x].shape[0]
  77. mu_r[x] = [np.corrcoef(mu[x][k,has_sibs_data,a-1]*np.exp((sibs[has_sibs_data]==0)*beta_a[k]), mu[x][k,has_sibs_data,b-1]*np.exp((sibs[has_sibs_data]==0)*beta_b[k]))[0,1] for k in range(n_samples)]
  78. low[x] = np.quantile(mu_r[x],q=0.05/2)
  79. high[x] = np.quantile(mu_r[x],q=1-0.05/2)
  80. r[x] = np.mean(mu_r[x])
  81. axes[i,j].axvline(r["vtc_raw"], color=cb_colors[0], lw=1, ls="dashed")
  82. axes[i,j].axvspan(low["vtc_raw"], high["vtc_raw"], color=cb_colors[0], alpha=0.2)
  83. axes[i,j].axvline(r["lena_raw"], color=cb_colors[1], lw=1, ls="dashed")
  84. axes[i,j].axvspan(low["lena_raw"], high["lena_raw"], color=cb_colors[1], alpha=0.2)
  85. # axes[i,j].axvline(calibration_r, color="olive", lw=0.5, ls="dashed")
  86. axes[i,j].hist(mu_r["vtc_calibrated"], bins=bins, histtype="step", density=True, color=cb_colors[0])
  87. axes[i,j].text(
  88. 1-0.05, 0.95, "\\scriptsize\\textbf{VTC}:", ha="right", transform=axes[i,j].transAxes, color=cb_colors[0]
  89. )
  90. axes[i,j].text(
  91. 1-0.05, 0.9,
  92. f"\\scriptsize $r(\\mathrm{{{speakers[a]}}},\\mathrm{{{speakers[b]}}})={r['vtc_calibrated']:.2f}$",
  93. ha="right",
  94. transform=axes[i,j].transAxes,
  95. color="black"
  96. )
  97. axes[i,j].text(
  98. 1-0.05, 0.85,
  99. f"\\scriptsize $\\mathrm{{CI}}_{{95\\%}}$[{low['vtc_calibrated']:.2f}, {high['vtc_calibrated']:.2f}]",
  100. ha="right",
  101. transform = axes[i,j].transAxes,
  102. color="black"
  103. )
  104. axes[i,j].hist(mu_r["lena_calibrated"], bins=bins, histtype="step", density=True, color=cb_colors[1])
  105. axes[i,j].text(
  106. 1-0.05, 0.7, "\\scriptsize\\textbf{LENA}:", ha="right", transform=axes[i,j].transAxes, color=cb_colors[1]
  107. )
  108. axes[i,j].text(
  109. 1-0.05, 0.65,
  110. f"\\scriptsize $r(\\mathrm{{{speakers[a]}}},\\mathrm{{{speakers[b]}}})={r['lena_calibrated']:.2f}$",
  111. ha="right",
  112. transform=axes[i,j].transAxes,
  113. color="black"
  114. )
  115. axes[i,j].text(
  116. 1-0.05, 0.6,
  117. f"\\scriptsize $\\mathrm{{CI}}_{{95\\%}}$[{low['lena_calibrated']:.2f}, {high['lena_calibrated']:.2f}]",
  118. ha="right",
  119. transform = axes[i,j].transAxes,
  120. color="black"
  121. )
  122. axes[i,j].text(
  123. 1-0.05, 0.8,
  124. f"\\scriptsize $r(\\mathrm{{uncalib.}})={r['vtc_raw']:.2f}$",
  125. ha="right",
  126. transform = axes[i,j].transAxes,
  127. color="black"
  128. )
  129. axes[i,j].text(
  130. 1-0.05, 0.55,
  131. f"\\scriptsize $r(\\mathrm{{uncalib.}})={r['lena_raw']:.2f}$",
  132. ha="right",
  133. transform = axes[i,j].transAxes,
  134. color="black"
  135. )
  136. # axes[i,j].text(
  137. # 0.05, 0.6,
  138. # f"\\scriptsize $r(\\mathrm{{human}})={calibration_r:.2f}\\ast$",
  139. # ha="left",
  140. # transform = axes[i,j].transAxes,
  141. # color="olive"
  142. # )
  143. axes[i,j].set_yticks([])
  144. axes[i,j].set_yticklabels([])
  145. axes[i,j].set_xlim(-0.5, 0.5)
  146. plt.subplots_adjust(wspace=0.05, hspace=0.05)
  147. fig.savefig("output/correlations_child_level_all.eps", bbox_inches="tight")
  148. fig.savefig("output/correlations_child_level_all.png", bbox_inches="tight", dpi=720)
  149. fig.savefig("output/correlations_child_level_all.pdf", bbox_inches="tight")
  150. data = {
  151. "truth": "output/aggregates_lena_age24_human.pickle",
  152. "lena_raw": "output/aggregates_lena_age24_algo.pickle",
  153. "vtc_raw": "output/aggregates_vtc_age24_algo.pickle",
  154. "lena_calibrated": "output/aggregates_lena_age24_dev_siblings_binomial_hurdle_fast.pickle",
  155. "vtc_calibrated": "output/aggregates_vtc_age24_dev_siblings_binomial_hurdle_fast.pickle",
  156. }
  157. for key in data:
  158. with open(data[key], "rb") as f:
  159. data[key] = pickle.load(f)
  160. samples = {
  161. "truth": np.load("output/aggregates_lena_age24_human.npz"),
  162. "lena_raw": np.load("output/aggregates_lena_age24_algo.npz"),
  163. "vtc_raw": np.load("output/aggregates_vtc_age24_algo.npz"),
  164. # "lena_calibrated": np.load("output/aggregates_lena_dev_siblings_effect.npz"),
  165. # "vtc_calibrated": np.load("output/aggregates_vtc_dev_siblings_effect.npz"),
  166. "lena_calibrated": np.load("output/aggregates_lena_age24_dev_siblings_binomial_hurdle_fast.npz"),
  167. "vtc_calibrated": np.load("output/aggregates_vtc_age24_dev_siblings_binomial_hurdle_fast.npz")
  168. }
  169. labels = {
  170. "prior": "Prior",
  171. "truth": "Manual annotations",
  172. "lena_raw": "LENA (uncalibrated)",
  173. "vtc_raw": "VTC (uncalibrated)",
  174. "lena_calibrated": "LENA (calibrated)",
  175. "vtc_calibrated": "VTC (calibrated)"
  176. }
  177. correlations(data, samples)