# category_prediction_stability.py
  1. import pandas as pd
  2. import numpy as np
  3. from scipy.stats import binom
  4. import matplotlib
  5. from matplotlib import pyplot as plt
  6. matplotlib.use("pgf")
  7. matplotlib.rcParams.update(
  8. {
  9. "pgf.texsystem": "xelatex",
  10. "font.family": "serif",
  11. "font.serif": "Times New Roman",
  12. "text.usetex": True,
  13. "pgf.rcfonts": False,
  14. }
  15. )
  16. cats = ["Experiment", "Phenomenology", "Theory"]
  17. colors = ['#4daf4a', '#ff7f00', '#377eb8']
  18. accuracy = pd.read_csv("output/category_prediction/accuracy_per_period_kfold.csv").sort_values("year_group")
  19. accuracy = accuracy[accuracy["year_group"]<8]
  20. fig, ax = plt.subplots(1,1)
  21. for i in range(3):
  22. ci = ([],[])
  23. for row in accuracy.to_dict(orient="records"):
  24. low,high = binom.ppf([0.025, 0.975], row[f"count_{i}"], row[f"accurate_{i}"], loc=0)
  25. ci[0].append(low/row[f"count_{i}"])
  26. ci[1].append(high/row[f"count_{i}"])
  27. ax.scatter(accuracy["year_group"], accuracy[f"accurate_{i}"], color=colors[i], label=cats[i], s=15)
  28. ax.errorbar(accuracy["year_group"], accuracy[f"accurate_{i}"], yerr=(ci[0]-accuracy[f"accurate_{i}"], accuracy[f"accurate_{i}"]-ci[1]), color=colors[i], ls="none")
  29. ax.plot(accuracy["year_group"], accuracy[f"dummy_accurate_{i}"], color=colors[i], ls="dashed")
  30. ax.set_xticks(accuracy["year_group"].tolist())
  31. ax.set_xticklabels([f"{1980+i*5}-{1980+(i+1)*5-1}" for i in accuracy["year_group"].tolist()], rotation=45, ha='right')
  32. ax.set_ylabel("Accuracy")
  33. ax.legend()
  34. fig.savefig("plots/category_prediction_stability.eps", bbox_inches="tight")
  35. fig.savefig("plots/category_prediction_stability.png", bbox_inches="tight")
  36. plt.show()