# threshold.py
# Fit model.stan to per-child Baby/(Baby+Child) label ratios versus age and
# plot the resulting threshold posterior alongside the data.
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
  5. df = pd.read_csv("data/maturity2/maturity2.csv")
  6. df = df[df["selection"]=="key_chi"]
  7. df["child_age"] /= 365
  8. df = df[~df["majority_label_age"].isin(["Junk",""])]
  9. df = df.groupby(["child_id", "child_age", "majority_label_age"]).agg(
  10. n=("majority_label_age", "count")
  11. )
  12. df = df.reset_index()
  13. df = df.pivot(columns="majority_label_age", index=["child_id","child_age"], values="n")
  14. df["ratio"] = df["Baby"]/(df["Baby"]+df["Child"])
  15. df = df.reset_index()
  16. df = df[df["ratio"]>=0]
  17. from cmdstanpy import CmdStanModel
  18. data = {
  19. "N": len(df),
  20. "age": df["child_age"].astype(float).values,
  21. "ratio": df["ratio"].astype(float).values
  22. }
  23. model = CmdStanModel(
  24. stan_file=f"model.stan",
  25. )
  26. fit = model.sample(
  27. data=data,
  28. chains=4,
  29. threads_per_chain=1,
  30. iter_sampling=2000,
  31. iter_warmup=500,
  32. show_console=True,
  33. )
  34. vars = fit.stan_variables()
  35. samples = {}
  36. for (k, v) in vars.items():
  37. samples[k] = v
  38. m = samples["lp"].max()
  39. q = np.exp(samples["lp"]-m).mean(axis=0)
  40. p = q/q.sum()
  41. fig, ax1 = plt.subplots()
  42. ax1.scatter(df["child_age"], df["ratio"], label="Data")
  43. ax1.set_xlabel("Key child age (in years)")
  44. ax1.set_ylabel("Baby/(Baby+Child)")
  45. ax2 = ax1.twinx()
  46. ax2.plot(np.linspace(0,5,len(p)),q, color="black", label="Threshold probability density")
  47. fig.legend()
  48. fig.savefig("output/ratio_vs_age.png", bbox_inches="tight")