etm_ei.py 2.2 KB

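# Fit the ei_cov_softmax_control Stan model to per-unit topic start/end counts,
# using pooled-resource shares as covariates, and save the posterior draws.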
import argparse
import pickle
from os.path import join as opj

import numpy as np
import pandas as pd
import seaborn as sns
from cmdstanpy import CmdStanModel
from matplotlib import pyplot as plt
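
# Command-line arguments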
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--n", type=int, default=200)
parser.add_argument("--min-pop", type=int, default=100)
parser.add_argument("--stack-rows", type=int, default=1)
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--samples", type=int, default=500)
parser.add_argument("--warmup", type=int, default=1000)
args = parser.parse_args()
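
# Load the topic list and per-unit aggregates; keep units whose starting counts
# total at least --min-pop, then subsample --n units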
n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
df = pd.read_csv(opj(args.input, "aggregate.csv"))
df = df[df[[f"start_{k+1}" for k in range(n_topics)]].sum(axis=1) >= args.min_pop]
df = df.sample(n=args.n)
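
# Attach the pooled-resource covariates, joining on the "bai" identifier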
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, left_on="bai", right_on="bai")
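
# Assemble the data block for the Stan model: per-unit row (start) and
# column (end) counts, table dimensions, and threading information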
data = {
    "NR": np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int),
    "NC": np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int),
    "R": n_topics,
    "C": n_topics,
    "n_units": len(df),
    "threads": args.threads_per_chain,
}
data["cov"] = np.stack(df["pooled_resources"])
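
# Drop topics that never occur in either the start or the end counts,
# and shrink the table dimensions accordingly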
junk = np.sum(data["NR"] + data["NC"], axis=0) == 0
for col in ["NR", "NC", "cov"]:
    data[col] = data[col][:, ~junk]
data["R"] -= junk.sum()
data["C"] -= junk.sum()
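
# Normalize covariates to per-unit shares (the denominator is floored at 1
# to avoid dividing by zero on all-zero rows)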
data["cov"] = data["cov"] / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
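
# Quick visual check: correlations between end counts and covariate shares across units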
sns.heatmap(
    np.corrcoef(data["NC"].T, data["cov"].T), vmin=-0.5, vmax=0.5, cmap="RdBu"
)
plt.show()
print(data["cov"].shape)
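
# Compile the Stan model with multithreading enabled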
model = CmdStanModel(
    stan_file="code/ei_cov_softmax_control.stan",
    cpp_options={"STAN_THREADS": "TRUE"},
    compile="force",
)
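
# Draw posterior samples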
fit = model.sample(
    data=data,
    chains=args.chains,
    threads_per_chain=args.threads_per_chain,
    iter_sampling=args.samples,
    iter_warmup=args.warmup,
)
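
# Collect the posterior draws and save them as a compressed .npz archive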
stan_vars = fit.stan_variables()
samples = {}
for k, v in stan_vars.items():
    samples[k] = v
np.savez_compressed(opj(args.input, "ei_samples.npz"), **samples)