algo_comparison.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from email.errors import MalformedHeaderDefect
  2. from logging import logMultiprocessing
  3. import pandas as pd
  4. import numpy as np
  5. from matplotlib import pyplot as plt
  6. import matplotlib
  7. matplotlib.use("pgf")
  8. matplotlib.rcParams.update(
  9. {
  10. "pgf.texsystem": "pdflatex",
  11. "font.family": "serif",
  12. "font.serif": "Times New Roman",
  13. "text.usetex": True,
  14. "pgf.rcfonts": False,
  15. }
  16. )
  17. import pickle
  18. from os.path import join as opj
  19. import argparse
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument("--vtc")
  22. parser.add_argument("--lena")
  23. args = parser.parse_args()
  24. data = {
  25. "lena": opj("output", f"{args.lena}.pickle"),
  26. "vtc": opj("output", f"{args.vtc}.pickle")
  27. }
  28. for key in data:
  29. with open(data[key], "rb") as f:
  30. data[key] = pickle.load(f)
  31. samples = {
  32. "vtc": np.load(opj("output", f"{args.vtc}.npz")),
  33. "lena": np.load(opj("output", f"{args.lena}.npz"))
  34. }
  35. labels = {
  36. "lena": "LENA",
  37. "vtc": "VTC"
  38. }
  39. speakers = ["CHI", "OCH", "FEM", "MAL"]
  40. cb_colors = [
  41. "#377eb8",
  42. "#ff7f00",
  43. "#f781bf",
  44. "#4daf4a",
  45. "#a65628",
  46. "#984ea3",
  47. "#999999",
  48. "#e41a1c",
  49. "#dede00",
  50. ]
  51. corpora = {
  52. 'bergelson': 0, 'cougar': 1, 'fausey-trio': 2, 'lucid': 3, 'warlaumont': 4, 'winnipeg': 5
  53. }
  54. corpora_names = {
  55. corpora[corpus]: corpus
  56. for corpus in corpora
  57. }
  58. def algo_comparison(algo1, algo2, algo1_name, algo2_name):
  59. fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
  60. for row in range(2):
  61. for col in range(2):
  62. i = row + 2 * col
  63. for corpus in set(algo1["corpus"]):
  64. mask = np.array([algo1["corpus"][k - 1] == corpus for k in algo1["children"]])
  65. vocs_algo1 = algo1["vocs"][mask,i]
  66. vocs_algo2 = algo2["vocs"][mask,i]
  67. axes[row, col].scatter(
  68. vocs_algo1,
  69. vocs_algo2,
  70. s=0.5,
  71. color=cb_colors[corpus - 1],
  72. )
  73. if row == 1 and col == 0:
  74. axes[row, col].set_xlabel(algo1_name)
  75. axes[row, col].set_ylabel(algo2_name)
  76. axes[row,col].set_xscale("log")
  77. axes[row,col].set_yscale("log")
  78. axes[row,col].set_xlim(20,5000)
  79. axes[row,col].set_ylim(20,5000)
  80. axes[row, col].axline((100, 100), (1000, 1000), color="black")
  81. axes[row, col].set_title(speakers[i], y=1, pad=-14)
  82. plt.subplots_adjust(wspace=0, hspace=0)
  83. fig.savefig(f"output/algo_comparison_{algo1_name}_{algo2_name}.eps", bbox_inches="tight")
  84. algo_comparison(data["lena"], data["vtc"], "LENA", "VTC")