votes_distrib.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import pandas as pd
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4. import seaborn as sns
  5. import argparse
  6. def order(x,l):
  7. if x == "NO-LABEL":
  8. return 2
  9. if x in ["NA","Junk"]:
  10. return 1
  11. else:
  12. return -sorted(l, reverse=True).index(x)
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument("--key")
  15. parser.add_argument("--threads-per-chain", type=int, default=4)
  16. parser.add_argument("--normalize", action="store_true", default=False)
  17. args = parser.parse_args()
  18. majority_col = f"majority_label_{args.key}"
  19. individual_col = f"labels_{args.key}"
  20. df = pd.concat([
  21. pd.read_csv("speech-maturity-dataset/data/babblecor/babblecor.csv"),
  22. pd.read_csv("speech-maturity-dataset/data/maturity1/maturity1.csv"),
  23. pd.read_csv("speech-maturity-dataset/data/maturity2/maturity2.csv"),
  24. ])
  25. df["clip_id"] = df.index.astype(int)+1
  26. df.dropna(axis=0, subset=[majority_col], inplace=True)
  27. df[individual_col] = df[individual_col].str.split(",")
  28. df["n_responses"] = df[individual_col].map(len)
  29. df = df.explode(individual_col)
  30. df = df.groupby(["clip_id",individual_col]).agg(
  31. n=("child_id", "count"),
  32. majority_label=(majority_col, "first"),
  33. n_responses=("n_responses", "first")
  34. ).reset_index()
  35. df = df.pivot(index="clip_id", columns=individual_col,values="n").fillna(0)
  36. print(df)
  37. data = {
  38. "N": len(df),
  39. "C": len(df.columns),
  40. "votes": df.values.astype(int),
  41. }
  42. from cmdstanpy import CmdStanModel
  43. model = CmdStanModel(
  44. stan_file=f"votes_distrib.stan",
  45. cpp_options={"STAN_THREADS": "TRUE"},
  46. )
  47. fit = model.sample(
  48. data=data,
  49. chains=2,
  50. threads_per_chain=args.threads_per_chain,
  51. iter_sampling=1000,
  52. iter_warmup=250,
  53. show_console=True,
  54. )
  55. vars = fit.stan_variables()
  56. samples = {}
  57. for (k, v) in vars.items():
  58. samples[k] = v
  59. np.savez('votes_distrib.npz', **samples)
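# NOTE: --normalize is parsed above but currently has no effect; the
# normalization step below is disabled. It also refers to the columns "n" and
# "n_responses", which no longer exist once df has been turned into the
# clip-by-label vote-count table.
# if args.normalize:
#     df["n"] /= df["n_responses"]
# # df = df.pivot(index=["clip_id"], columns=individual_col, values="n").reset_index()
# df = df.groupby([individual_col, "n"]).agg(
#     clip_id = ("n", "count")
# )
# df.to_csv(f"votes_distrib_{args.key}.csv")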
  64. # clip_id = ("n", "count")
  65. # )
  66. # df.to_csv(f"votes_distrib_{args.key}.csv")