# etm_ei.py
import argparse
import pickle
from os.path import join as opj, basename

import numpy as np
import pandas as pd
import seaborn as sns
from cmdstanpy import CmdStanModel
from matplotlib import pyplot as plt
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("--input", default="output/etm_20_pretrained")
  11. parser.add_argument("--portfolios", default=None)
  12. parser.add_argument("--resources", default=None)
  13. parser.add_argument("--n", type=int, default=4000)
  14. parser.add_argument("--min-pop", type=int, default=100)
  15. parser.add_argument("--stack-rows", type=int, default=1)
  16. parser.add_argument("--chains", type=int, default=1)
  17. parser.add_argument("--threads-per-chain", type=int, default=4)
  18. parser.add_argument("--samples", type=int, default=500)
  19. parser.add_argument("--warmup", type=int, default=300)
  20. parser.add_argument("--model", default="control_nu")
  21. parser.add_argument("--mle", action="store_true", default=False)
  22. args = parser.parse_args()
  23. portfolios = opj(args.input, "aggregate.csv") if args.portfolios is None else args.portfolios
  24. n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
  25. df = pd.read_csv(portfolios)
  26. df = df[df[[f"start_{k+1}" for k in range(n_topics)]].sum(axis=1) >= args.min_pop]
  27. if args.n < len(df):
  28. df = df.head(n=args.n)
  29. n_authors = len(df)
  30. print(opj(args.input, f"ei_samples_{args.model}_{basename(portfolios)}.npz"))
  31. resources = pd.read_parquet(
  32. opj(args.input, "pooled_resources.parquet") if args.resources is None else args.resources
  33. )
  34. print(len(resources))
  35. df = df.merge(resources, left_on="bai", right_on="bai")
  36. assert len(df)==n_authors
  37. print(n_authors)
  38. data = {
  39. "NR": np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int),
  40. "NC": np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int),
  41. "expertise": np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values),
  42. "R": n_topics,
  43. "C": n_topics,
  44. "n_units": len(df),
  45. "threads": args.threads_per_chain
  46. }
  47. data["cov"] = np.stack(df["pooled_resources"])
  48. junk = np.sum(data["NR"] + data["NC"], axis=0) == 0
  49. for col in ["NR", "NC", "cov", "expertise"]:
  50. data[col] = data[col][:, ~junk]
  51. data["R"] -= junk.sum()
  52. data["C"] -= junk.sum()
  53. data["cov"] = data["cov"]# / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
  54. fig, ax = plt.subplots()
  55. sns.heatmap(
  56. np.corrcoef(data["NC"].T, data["cov"].T), vmin=-0.5, vmax=0.5, cmap="RdBu", ax=ax
  57. )
  58. plt.show()
  59. fig, ax = plt.subplots()
  60. for i in range(data["R"]):
  61. ax.scatter(data["cov"][:,i], data["NR"][:,i]/data["NR"].sum(axis=1))
  62. plt.show()
  63. expertise = data["expertise"]
  64. data["nu"] = np.array([
  65. [((expertise[:,i]>expertise[:,i].mean())&(expertise[:,j]>expertise[:,j].mean())).mean()/(expertise[:,i]>expertise[:,i].mean()).mean() for j in range(data["R"])]
  66. for i in range(data["R"])
  67. ])
  68. model = CmdStanModel(
  69. stan_file=f"code/ei_cov_softmax_{args.model}.stan",
  70. cpp_options={"STAN_THREADS": "TRUE"},
  71. )
  72. if args.mle:
  73. fit = model.optimize(
  74. data=data,
  75. show_console=True,
  76. iter=5000,
  77. )
  78. else:
  79. fit = model.sample(
  80. data=data,
  81. chains=args.chains,
  82. threads_per_chain=args.threads_per_chain,
  83. iter_sampling=args.samples,
  84. iter_warmup=args.warmup,
  85. show_console=True,
  86. max_treedepth=11,
  87. # step_size=0.1
  88. )
  89. vars = fit.stan_variables()
  90. samples = {}
  91. for (k, v) in vars.items():
  92. samples[k] = v
  93. format = "mle" if args.mle else "samples"
  94. if args.portfolios is None:
  95. np.savez_compressed(opj(args.input, f"ei_{format}_{args.model}.npz"), **samples)
  96. else:
  97. np.savez_compressed(opj(args.input, f"ei_{format}_{args.model}_{basename(portfolios)}.npz"), **samples)