#!/usr/bin/env python
# votes_distrib.py
import argparse

import numpy as np
import pandas as pd
import seaborn as sns
from cmdstanpy import CmdStanModel
from matplotlib import pyplot as plt
  7. def order(x,l):
  8. if x == "NO-LABEL":
  9. return 2
  10. if x in ["NA","Junk"]:
  11. return 1
  12. else:
  13. return -sorted(l, reverse=True).index(x)
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("--key")
  16. parser.add_argument("--threads-per-chain", type=int, default=4)
  17. parser.add_argument("--normalize", action="store_true", default=False)
  18. args = parser.parse_args()
  19. majority_col = f"majority_label_{args.key}"
  20. individual_col = f"labels_{args.key}"
  21. df = pd.concat([
  22. pd.read_csv("speech-maturity-dataset/data/babblecor/babblecor.csv"),
  23. pd.read_csv("speech-maturity-dataset/data/maturity1/maturity1.csv"),
  24. pd.read_csv("speech-maturity-dataset/data/maturity2/maturity2.csv"),
  25. ])
  26. df["clip_id"] = df.index.astype(int)+1
  27. df.dropna(axis=0, subset=[majority_col], inplace=True)
  28. df[individual_col] = df[individual_col].str.split(",")
  29. df["n_responses"] = df[individual_col].map(len)
  30. df = df.explode(individual_col)
  31. df = df.groupby(["clip_id",individual_col]).agg(
  32. n=("child_id", "count"),
  33. majority_label=(majority_col, "first"),
  34. n_responses=("n_responses", "first")
  35. ).reset_index()
  36. df = df.pivot(index="clip_id", columns=individual_col,values="n").fillna(0)
  37. print(df)
  38. data = {
  39. "N": len(df),
  40. "C": len(df.columns),
  41. "votes": df.values.astype(int),
  42. }
  43. from cmdstanpy import CmdStanModel
  44. model = CmdStanModel(
  45. stan_file=f"votes_distrib.stan",
  46. cpp_options={"STAN_THREADS": "TRUE"},
  47. )
  48. fit = model.sample(
  49. data=data,
  50. chains=2,
  51. threads_per_chain=args.threads_per_chain,
  52. iter_sampling=1000,
  53. iter_warmup=250,
  54. show_console=True,
  55. )
  56. vars = fit.stan_variables()
  57. samples = {}
  58. for (k, v) in vars.items():
  59. samples[k] = v
  60. np.savez('votes_distrib.npz', **samples)
  61. # if args.normalize:
  62. # df["n"] /= df["n_responses"]
  63. # # df = df.pivot(index=["clip_id"], columns=individual_col, values="n").reset_index()
  64. # df = df.groupby([individual_col, "n"]).agg(
  65. # clip_id = ("n", "count")
  66. # )
  67. # df.to_csv(f"votes_distrib_{args.key}.csv")