# cross_validation.py (3.1 KB)
  1. from os.path import join as opj
  2. from re import A
  3. import pandas as pd
  4. import numpy as np
  5. from matplotlib import pyplot as plt
  6. import matplotlib
  7. matplotlib.use("pgf")
  8. matplotlib.rcParams.update(
  9. {
  10. "pgf.texsystem": "pdflatex",
  11. "font.family": "serif",
  12. "font.serif": "Times New Roman",
  13. "text.usetex": True,
  14. "pgf.rcfonts": False,
  15. }
  16. )
  17. import pickle
  18. import argparse
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument("--vtc")
  21. parser.add_argument("--output")
  22. args = parser.parse_args()
  23. speakers = ["CHI", "OCH", "FEM", "MAL"]
  24. def validation(data, samples, given_lambda=False):
  25. fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
  26. # FEM,MAL
  27. # OCH,FEM OCH,MAL
  28. truth = np.zeros((data["n_groups"], 4))
  29. algo = np.zeros((data["n_groups"], 4))
  30. for c in range(data["n_clips"]):
  31. truth[data["group"][c]-1] += data["truth_total"][c]
  32. algo[data["group"][c]-1] += data["algo_total"][c]
  33. variable = "sim_vocs_given_lambda" if given_lambda else "sim_vocs"
  34. for row in range(2):
  35. for col in range(2):
  36. ax = axes[row, col]
  37. i = row*2+col
  38. ax.scatter(algo[:,i], truth[:,i], color="#f781bf", s=3.5, marker="s", facecolors='none')
  39. mu = np.mean(samples[variable][:,:,i], axis=0)
  40. low = np.quantile(samples[variable][:,:,i], axis=0, q=(1-0.68)/2)
  41. up = np.quantile(samples[variable][:,:,i], axis=0, q=1-(1-0.68)/2)
  42. ax.plot([0, 200], [0,200], color="black", lw=0.5, ls="dashed")
  43. ax.scatter(algo[:,i], mu, s=3, facecolors='none', edgecolors="#377eb8")
  44. ax.errorbar(algo[:,i], (low+up)/2, ((low+up)/2-low, up-(low+up)/2), ls="none", lw=0.5)
  45. if row == 1 and col == 0:
  46. ax.set_xlabel("algo (obs.)")
  47. ax.set_ylabel("algo (pred)")
  48. # ax.set_xscale("log")
  49. # ax.set_yscale("log")
  50. ax.set_ylim(-5,200)
  51. ax.set_xlim(-5,200)
  52. ax.set_title(speakers[i])
  53. x1, x2, y1, y2 = -2.5, 25, -2.5, 25
  54. axins = ax.inset_axes(
  55. [0.57, 0.06, 0.4, 0.4],
  56. xlim=(x1, x2), ylim=(y1, y2), xticklabels=[], yticklabels=[])
  57. axins.set_xlim(x1, x2)
  58. axins.set_ylim(y1, y2)
  59. axins.scatter(algo[:,i], truth[:,i], color="#f781bf", s=3.5, marker="s", facecolors='none')
  60. axins.plot([0, 200], [0,200], color="black", lw=0.5, ls="dashed")
  61. axins.scatter(algo[:,i], mu, s=3, facecolors='none', edgecolors="#377eb8")
  62. axins.errorbar(algo[:,i], (low+up)/2, ((low+up)/2-low, up-(low+up)/2), ls="none", lw=0.5)
  63. ax.indicate_inset_zoom(axins, edgecolor="black")
  64. fig.savefig(f"output/validation_{variable}_{args.output}.eps", bbox_inches="tight")
  65. data = {
  66. "vtc": f"output/{args.vtc}.pickle",
  67. }
  68. for key in data:
  69. with open(data[key], "rb") as f:
  70. data[key] = pickle.load(f)
  71. samples = {
  72. "vtc": np.load(f"output/{args.vtc}.npz"),
  73. }
  74. validation(data["vtc"], samples["vtc"], True)
  75. validation(data["vtc"], samples["vtc"], False)