effects_comparison.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. #!/usr/bin/env python3
  2. from ChildProject.projects import ChildProject
  3. from ChildProject.annotations import AnnotationManager
  4. from ChildProject.metrics import segments_to_annotation
  5. import argparse
  6. import datalad.api
  7. from os.path import join as opj
  8. from os.path import basename, exists
  9. import multiprocessing as mp
  10. import numpy as np
  11. from scipy.stats import binom
  12. import pandas as pd
  13. from pyannote.core import Annotation, Segment, Timeline
  14. import matplotlib
  15. matplotlib.use("pgf")
  16. matplotlib.rcParams.update({
  17. "pgf.texsystem": "pdflatex",
  18. 'font.family': 'serif',
  19. 'text.usetex': True,
  20. 'pgf.rcfonts': False,
  21. })
  22. from matplotlib import pyplot as plt
  23. effects = ["beta_sib_och", "beta_sib_adu", "alpha_dev", "beta_dev"]
  24. variables = {
  25. "beta_sib_och": "$\\beta_{\\mathrm{OCH}}^{\\mathrm{sib}}$",
  26. "beta_sib_adu": "$\\beta_{\\mathrm{ADU}}^{\\mathrm{sib}}$",
  27. "alpha_dev": "$\\alpha_{\\mathrm{dev}}$",
  28. "beta_dev": "$\\beta_{\\mathrm{dev}}$",
  29. }
  30. prior_distribution = {
  31. effect: np.random.randn(10000) for effect in effects
  32. }
  33. samples = {
  34. "prior": prior_distribution,
  35. "truth": np.load("output/aggregates_truth_truth_only.npz"),
  36. "lena_raw": np.load("output/aggregates_lena_cougar_sibs_algo_siblings_adu.npz"),
  37. "vtc_raw": np.load("output/aggregates_vtc_cougar_sibs_algo_siblings_adu.npz"),
  38. # "vtc_calibrated": np.load("output/aggregates_vtc_dev_siblings_effect.npz"),
  39. # "lena_calibrated": np.load("output/aggregates_lena_dev_siblings_effect.npz"),
  40. # "vtc_calibrated": np.load("output/aggregates_vtc_fausey_15_dev_siblings.npz"),
  41. # "lena_calibrated": np.load("output/aggregates_lena_fausey_30_dev_siblings.npz")
  42. "vtc_calibrated": np.load("output/aggregates_vtc_sibs_dev_siblings.npz"),
  43. "lena_calibrated": np.load("output/aggregates_lena_sibs_dev_siblings.npz")
  44. }
  45. labels = {
  46. "prior": "Prior",
  47. "truth": "Manual annotations",
  48. "lena_raw": "LENA (uncalibrated)",
  49. "vtc_raw": "VTC (uncalibrated)",
  50. "lena_calibrated": "LENA (calibrated)",
  51. "vtc_calibrated": "VTC (calibrated)"
  52. }
  53. positions = {
  54. "prior": -2,
  55. "truth": 0,
  56. "lena_raw": 3,
  57. "vtc_raw": 2,
  58. "vtc_calibrated": 5,
  59. "lena_calibrated": 6,
  60. }
  61. def plot_effect(effect):
  62. fig, ax = plt.subplots(figsize=[6.4*0.8,3.2*0.8])
  63. for key in samples:
  64. mean = samples[key][effect].mean()
  65. up = np.quantile(samples[key][effect], q=1-0.05/2)
  66. low = np.quantile(samples[key][effect], q=0.05/2)
  67. ax.text(
  68. mean,
  69. positions[key]+0.25,
  70. f"\\footnotesize{{$\\mu={mean:.2f}$, $\\mathrm{{CI}}_{{95\\%}}=\\left[{low:.2f}, {up:.2f}\\right]$}}", ha="center"
  71. )
  72. ax.scatter([mean],[positions[key]], color="black" if key=="prior" else None)
  73. ax.errorbar([mean], [positions[key]], xerr=([mean-low], [up-mean]), ls="none", color="black" if key=="prior" else None)
  74. ax.set_xlabel(variables[effect])
  75. ax.set_yticks(list(positions.values()))
  76. ax.set_yticklabels([labels[key] for key in positions])
  77. ax.set_ylim(np.min(list(positions.values()))-0.25, np.max(list(positions.values()))+1)
  78. ax.axvline(0, 0, 1, color="black", ls="dashed")
  79. fig.savefig(f"output/effect_comparison_{effect}.eps", bbox_inches="tight")
  80. fig.savefig(f"output/effect_comparison_{effect}.png", bbox_inches="tight", dpi=720)
  81. fig.savefig(f"output/effect_comparison_{effect}.pdf", bbox_inches="tight")
  82. for effect in effects:
  83. plot_effect(effect)