etm_ei.py 2.9 KB

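# Fit an ecological-inference style Stan model on topic-to-topic transition
# counts ("start_*" to "end_*"), using pooled resources and expertise as
# covariates, and save the posterior draws to a compressed .npz archive.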
import argparse
from os.path import join as opj

import numpy as np
import pandas as pd
import seaborn as sns
from cmdstanpy import CmdStanModel
from matplotlib import pyplot as plt
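
# Command-line options: input directory, subsample size, minimum population
# filter, and CmdStan sampler settings (chains, threads, draws, warmup).
# (--stack-rows is parsed but not used below.)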
parser = argparse.ArgumentParser()
parser.add_argument("--input", default="output/etm_20_pretrained")
parser.add_argument("--n", type=int, default=200)
parser.add_argument("--min-pop", type=int, default=100)
parser.add_argument("--stack-rows", type=int, default=1)
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--samples", type=int, default=500)
parser.add_argument("--warmup", type=int, default=300)
parser.add_argument("--model", default="control")
args = parser.parse_args()
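
# Example invocation (illustrative values):
#   python etm_ei.py --input output/etm_20_pretrained --model control --chains 4

# The number of topics is inferred from the topic table written upstream.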
n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
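
# Keep only units whose total start count reaches --min-pop, then truncate
# to at most --n units.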
df = pd.read_csv(opj(args.input, "aggregate.csv"))
df = df[df[[f"start_{k+1}" for k in range(n_topics)]].sum(axis=1) >= args.min_pop]
if args.n < len(df):
    df = df.head(n=args.n)
print(df)
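
# Attach pooled resources, joined on the shared "bai" key (presumably an
# author identifier).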
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, on="bai")
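
# Assemble the Stan data block: NR/NC are the start/end topic counts per unit
# (the rows and columns of the transition problem), R and C the category counts.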
data = {
    "NR": np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int),
    "NC": np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int),
    "expertise": np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values),
    "R": n_topics,
    "C": n_topics,
    "n_units": len(df),
    "threads": args.threads_per_chain,
}
data["cov"] = np.stack(df["pooled_resources"])
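
# Drop topics that never appear in either the start or end counts, and shrink
# the category counts to match.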
junk = np.sum(data["NR"] + data["NC"], axis=0) == 0
for col in ["NR", "NC", "cov", "expertise"]:
    data[col] = data[col][:, ~junk]
data["R"] -= junk.sum()
data["C"] -= junk.sum()
# Optional row-normalisation of the resource covariates, currently disabled:
# data["cov"] = data["cov"] / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
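
# Quick visual diagnostics: correlations between end counts and resource
# covariates, then per-topic resources against start-count shares.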
fig, ax = plt.subplots()
sns.heatmap(
    np.corrcoef(data["NC"].T, data["cov"].T), vmin=-0.5, vmax=0.5, cmap="RdBu", ax=ax
)
plt.show()

fig, ax = plt.subplots()
for i in range(data["R"]):
    ax.scatter(data["cov"][:, i], data["NR"][:, i] / data["NR"].sum(axis=1))
plt.show()
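
# nu[i, j] estimates P(above-average expertise in topic j | above-average
# expertise in topic i), an asymmetric co-expertise matrix passed to the model.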
expertise = data["expertise"]
above = expertise > expertise.mean(axis=0)  # units above each topic's mean expertise
data["nu"] = np.array([
    [(above[:, i] & above[:, j]).mean() / above[:, i].mean() for j in range(data["R"])]
    for i in range(data["R"])
])
print(data)
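
# Compile the chosen Stan model variant with threading enabled
# (recompilation is forced).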
model = CmdStanModel(
    stan_file=f"code/ei_cov_softmax_{args.model}.stan",
    cpp_options={"STAN_THREADS": "TRUE"},
    compile="force",
)
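
# Sample with a raised max_treedepth (11, up from CmdStan's default of 10).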
fit = model.sample(
    data=data,
    chains=args.chains,
    threads_per_chain=args.threads_per_chain,
    iter_sampling=args.samples,
    iter_warmup=args.warmup,
    show_console=True,
    max_treedepth=11,
    # step_size=0.1
)
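
# Persist every posterior variable to <input>/ei_samples_<model>.npz.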
samples = fit.stan_variables()  # dict mapping each Stan variable to its draws
np.savez_compressed(opj(args.input, f"ei_samples_{args.model}.npz"), **samples)