# confusion.py

  1. import pandas as pd
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4. import seaborn as sns
  5. import argparse
  6. def order(x,l):
  7. if x == "NO-LABEL":
  8. return 2
  9. if x in ["NA","Junk"]:
  10. return 1
  11. else:
  12. return -sorted(l, reverse=True).index(x)
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument("--key")
  15. parser.add_argument("--normalize", action="store_true", default=False)
  16. args = parser.parse_args()
  17. majority_col = f"majority_label_{args.key}"
  18. individual_col = f"labels_{args.key}"
  19. df = pd.concat([
  20. pd.read_csv("speech-maturity-dataset/data/babblecor/babblecor.csv"),
  21. pd.read_csv("speech-maturity-dataset/data/maturity1/maturity1.csv"),
  22. pd.read_csv("speech-maturity-dataset/data/maturity2/maturity2.csv"),
  23. ])
  24. df["clip_id"] = df.index.astype(int)+1
  25. df.dropna(axis=0, subset=[majority_col], inplace=True)
  26. df[individual_col] = df[individual_col].str.split(",")
  27. df["n_responses"] = df[individual_col].map(len)
  28. df = df.explode(individual_col)
  29. df = df.groupby(["clip_id",individual_col]).agg(
  30. n=("child_id", "count"),
  31. majority_label=(majority_col, "first"),
  32. n_responses=("n_responses", "first")
  33. ).reset_index()
  34. if args.normalize:
  35. df["n"] /= df["n_responses"]
  36. df = df.pivot(index=["clip_id", "majority_label"], columns=individual_col, values="n").reset_index()
  37. df.fillna(0, inplace=True)
  38. df.drop(columns=["clip_id"], inplace=True)
  39. df = df.groupby("majority_label").mean()
  40. x_labels = set(df.columns.values)
  41. x_order = {x: order(x,x_labels) for x in x_labels}
  42. y_labels = set(df.index.values)
  43. y_order = {y: order(y,y_labels) for y in y_labels}
  44. print(x_order)
  45. print(y_order)
  46. df = df.iloc[df.index.map(y_order).argsort()]
  47. df = df[sorted(df.columns,key=lambda x: x_order[x])]
  48. fig, ax = plt.subplots()
  49. sns.heatmap(df, ax=ax, cmap="Blues", annot=True, fmt=".2f")
  50. ax.set_ylabel("Majority label")
  51. ax.set_xlabel("Individual response frequency")
  52. fig.savefig(f"confusion_{args.key}.png", bbox_inches="tight", dpi=720)