123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from email.errors import MalformedHeaderDefect
- from logging import logMultiprocessing
- import pandas as pd
- import numpy as np
- from matplotlib import pyplot as plt
- import matplotlib
# Select the PGF backend and hand text rendering to LaTeX so that figure
# typography is produced by pdflatex with a serif (Times New Roman) face.
_PGF_RC_PARAMS = {
    "pgf.texsystem": "pdflatex",
    "font.family": "serif",
    "font.serif": "Times New Roman",
    "text.usetex": True,
    # Keep matplotlib from injecting its own font setup into the PGF preamble.
    "pgf.rcfonts": False,
}
matplotlib.use("pgf")
matplotlib.rcParams.update(_PGF_RC_PARAMS)
- import pickle
- from os.path import join as opj
- import argparse
# Command-line interface: each option is the file stem (without extension) of
# one algorithm's results inside the "output" directory.
# Both options are required: when one was omitted argparse used to yield None,
# and the script then failed later while trying to open "output/None.pickle".
parser = argparse.ArgumentParser()
parser.add_argument(
    "--vtc",
    required=True,
    help="stem of the VTC result files, read from output/<stem>.pickle and output/<stem>.npz",
)
parser.add_argument(
    "--lena",
    required=True,
    help="stem of the LENA result files, read from output/<stem>.pickle and output/<stem>.npz",
)
args = parser.parse_args()
# Unpickle the per-algorithm result dictionaries from output/<stem>.pickle.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so these
# paths must only ever point at trusted, locally produced outputs.
data = {}
for key, stem in (("lena", args.lena), ("vtc", args.vtc)):
    with open(opj("output", f"{stem}.pickle"), "rb") as f:
        data[key] = pickle.load(f)

# Open the matching .npz archives from output/<stem>.npz.
samples = {
    name: np.load(opj("output", f"{stem}.npz"))
    for name, stem in (("vtc", args.vtc), ("lena", args.lena))
}
# Display name of each algorithm, keyed like ``data`` and ``samples``.
labels = {"lena": "LENA", "vtc": "VTC"}

# Speaker labels, one per panel of the comparison figure.
speakers = [
    "CHI",
    "OCH",
    "FEM",
    "MAL",
]

# Colour palette used to tell corpora apart
# (presumably "cb" stands for colour-blind friendly -- TODO confirm).
cb_colors = [
    "#377eb8",  # blue
    "#ff7f00",  # orange
    "#f781bf",  # pink
    "#4daf4a",  # green
    "#a65628",  # brown
    "#984ea3",  # purple
    "#999999",  # grey
    "#e41a1c",  # red
    "#dede00",  # yellow
]

# Corpus name -> integer id, numbered from 0 in the order listed here.
corpora = {
    name: index
    for index, name in enumerate(
        (
            'bergelson',
            'cougar',
            'fausey-trio',
            'lucid',
            'warlaumont',
            'winnipeg',
        )
    )
}

# Reverse lookup: integer id -> corpus name.
corpora_names = {index: name for name, index in corpora.items()}
def algo_comparison(algo1, algo2, algo1_name, algo2_name):
    """Save a 2x2 grid of log-log scatter plots comparing two algorithms.

    Each panel corresponds to one entry of the module-level ``speakers`` list
    and plots column ``i`` of ``algo1["vocs"]`` (x-axis) against the same
    column of ``algo2["vocs"]`` (y-axis). Points are coloured per corpus using
    ``cb_colors`` and a black line through (100, 100) and (1000, 1000) is drawn
    as a reference for perfect agreement.

    Args:
        algo1: Mapping providing ``"vocs"`` (indexed as
            ``[row_mask, speaker_column]``), ``"children"`` (iterable of
            1-based indices into ``"corpus"``) and ``"corpus"`` (one corpus id
            per child). Rows of ``"vocs"`` are assumed to line up with the
            entries of ``"children"`` -- TODO confirm against the producer.
        algo2: Mapping providing ``"vocs"`` with the same row/column layout as
            ``algo1["vocs"]``; the masks built from ``algo1`` are applied to it.
        algo1_name: Label of the x-axis; also part of the output file name.
        algo2_name: Label of the y-axis; also part of the output file name.

    Returns:
        None. The figure is written to
        ``output/algo_comparison_{algo1_name}_{algo2_name}.eps``.
    """
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)

    # Which rows belong to which corpus does not depend on the speaker panel,
    # so build every mask once instead of once per panel.
    corpus_ids = algo1["corpus"]
    masks = {
        corpus: np.array(
            [corpus_ids[k - 1] == corpus for k in algo1["children"]]
        )
        for corpus in set(corpus_ids)
    }

    for row in range(2):
        for col in range(2):
            ax = axes[row, col]
            # Speakers fill the grid column by column:
            # (0, 0) -> 0, (1, 0) -> 1, (0, 1) -> 2, (1, 1) -> 3.
            i = row + 2 * col

            for corpus, mask in masks.items():
                ax.scatter(
                    algo1["vocs"][mask, i],
                    algo2["vocs"][mask, i],
                    s=0.5,
                    # Corpus ids are presumably 1-based here (mirroring the
                    # ``k - 1`` offset above); an id of 0 would wrap around to
                    # the last colour -- TODO confirm.
                    color=cb_colors[corpus - 1],
                )

            # The panels share their axes, so only the bottom-left one is
            # labelled.
            if row == 1 and col == 0:
                ax.set_xlabel(algo1_name)
                ax.set_ylabel(algo2_name)

            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xlim(20, 5000)
            ax.set_ylim(20, 5000)

            ax.axline((100, 100), (1000, 1000), color="black")
            # Negative padding places the title inside the panel, which is
            # needed because the panels touch each other.
            ax.set_title(speakers[i], y=1, pad=-14)

    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(
        opj("output", f"algo_comparison_{algo1_name}_{algo2_name}.eps"),
        bbox_inches="tight",
    )
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# Produce the LENA (x-axis) versus VTC (y-axis) comparison figure.
algo_comparison(
    algo1=data["lena"],
    algo2=data["vtc"],
    algo1_name="LENA",
    algo2_name="VTC",
)
|