123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- import pandas as pd
- import numpy as np
- from matplotlib import pyplot as plt
- import seaborn as sns
- import argparse
def order(x, l):
    """Return the sort rank of label ``x`` among the labels in ``l``.

    ``"NO-LABEL"`` always ranks 2 and ``"NA"`` / ``"Junk"`` always rank 1.
    Any other label is ranked by the negated position it occupies once
    ``l`` is sorted in descending order, so those ranks are zero or
    negative and sorting by rank ascending places them before the pinned
    labels.

    Raises ``ValueError`` if a non-pinned ``x`` is not contained in ``l``.
    """
    # Labels with a fixed rank, checked in the same order as a plain
    # if-chain would check them.
    pinned_ranks = ((("NO-LABEL",), 2), (("NA", "Junk"), 1))
    for names, rank in pinned_ranks:
        if x in names:
            return rank

    descending = sorted(l, reverse=True)
    return -descending.index(x)
# Build a heatmap comparing each clip's majority label (rows) with the
# individual responses it received (columns), averaged over clips.

# --- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
# --key selects the "majority_label_<key>" / "labels_<key>" column pair.
# It is required: without it the column names below would end in "None"
# and the first column lookup would fail with an opaque KeyError.
parser.add_argument(
    "--key",
    required=True,
    help="suffix of the majority_label_<key> / labels_<key> columns to compare",
)
parser.add_argument(
    "--normalize",
    action="store_true",
    help="divide each per-clip label count by the clip's number of responses",
)
args = parser.parse_args()

majority_col = f"majority_label_{args.key}"
individual_col = f"labels_{args.key}"

# --- Load -------------------------------------------------------------------
# Each read_csv call produces its own 0-based index. ignore_index=True gives
# the concatenated frame one continuous index; otherwise clip_id (derived
# from the index on the next line) would repeat across the three files and
# the groupby below would merge unrelated clips that share an id.
df = pd.concat(
    [
        pd.read_csv("speech-maturity-dataset/data/babblecor/babblecor.csv"),
        pd.read_csv("speech-maturity-dataset/data/maturity1/maturity1.csv"),
        pd.read_csv("speech-maturity-dataset/data/maturity2/maturity2.csv"),
    ],
    ignore_index=True,
)
df["clip_id"] = df.index.astype(int) + 1

# Rows without a majority label cannot be placed on the heatmap's y axis.
df.dropna(axis=0, subset=[majority_col], inplace=True)

# --- One row per (clip, individual response) --------------------------------
# The individual-label column holds a comma-separated string per clip.
df[individual_col] = df[individual_col].str.split(",")
df["n_responses"] = df[individual_col].map(len)
df = df.explode(individual_col)

# Count how often each individual label was given for each clip.
# NOTE(review): "count" only counts non-null child_id values; this assumes
# child_id is always populated - confirm against the CSV schema.
df = df.groupby(["clip_id", individual_col]).agg(
    n=("child_id", "count"),
    majority_label=(majority_col, "first"),
    n_responses=("n_responses", "first"),
).reset_index()

if args.normalize:
    # Turn raw counts into the fraction of the clip's responses.
    df["n"] /= df["n_responses"]

# --- Clip x label matrix, then average per majority label -------------------
df = df.pivot(index=["clip_id", "majority_label"], columns=individual_col, values="n").reset_index()
# A label nobody gave for a clip is a count of zero, not a missing value.
df = df.fillna(0).drop(columns=["clip_id"])
df = df.groupby("majority_label").mean()

# --- Order rows and columns with order() ------------------------------------
x_labels = set(df.columns.values)
x_order = {x: order(x, x_labels) for x in x_labels}

y_labels = set(df.index.values)
y_order = {y: order(y, y_labels) for y in y_labels}

print(x_order)
print(y_order)

df = df.iloc[df.index.map(y_order).argsort()]
df = df[sorted(df.columns, key=lambda x: x_order[x])]

# --- Plot -------------------------------------------------------------------
fig, ax = plt.subplots()
sns.heatmap(df, ax=ax, cmap="Blues", annot=True, fmt=".2f")

ax.set_ylabel("Majority label")
ax.set_xlabel("Individual response frequency")

fig.savefig(f"confusion_{args.key}.png", bbox_inches="tight", dpi=720)
|