# word2vec_validation.py
  1. import numpy as np
  2. from matplotlib import pyplot as plt
  3. import matplotlib
  4. matplotlib.use("pgf")
  5. matplotlib.rcParams.update(
  6. {
  7. "pgf.texsystem": "xelatex",
  8. "font.family": "serif",
  9. "font.serif": "Times New Roman",
  10. "text.usetex": True,
  11. "pgf.rcfonts": False,
  12. }
  13. )
  14. plt.rcParams["text.latex.preamble"].join([
  15. r"\usepackage{amsmath}",
  16. r"\setmainfont{amssymb}",
  17. ])
  18. from matplotlib.gridspec import GridSpec
  19. import pandas as pd
  20. import argparse
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument("--input")
  23. args = parser.parse_args()
  24. df = pd.read_csv(args.input)
  25. fig, ax = plt.subplots(figsize=[6.4,3.2])
  26. for d, v in df.groupby("dim"):
  27. ax.scatter([d]*len(v), v["loss"], alpha=0.5, color="gray", facecolors="none")
  28. avg = df.groupby("dim")["loss"].mean()
  29. ax.plot(avg.index, avg)
  30. ax.set_xlabel("Embeddings dimension ($L$)")
  31. ax.set_ylabel("Skip-gram Word2Vec Loss")
  32. fig.savefig("output/word2vec_validation.pdf", bbox_inches="tight")