|
@@ -15,6 +15,7 @@ def order(x,l):
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument("--key")
|
|
|
+parser.add_argument("--threads-per-chain", type=int, default=4)
|
|
|
parser.add_argument("--normalize", action="store_true", default=False)
|
|
|
args = parser.parse_args()
|
|
|
|
|
@@ -54,11 +55,12 @@ from cmdstanpy import CmdStanModel
|
|
|
|
|
|
model = CmdStanModel(
|
|
|
stan_file=f"votes_distrib.stan",
|
|
|
+ cpp_options={"STAN_THREADS": "TRUE"},
|
|
|
)
|
|
|
fit = model.sample(
|
|
|
data=data,
|
|
|
- chains=4,
|
|
|
- threads_per_chain=1,
|
|
|
+ chains=2,
|
|
|
+ threads_per_chain=args.threads_per_chain,
|
|
|
iter_sampling=1000,
|
|
|
iter_warmup=250,
|
|
|
show_console=True,
|
|
@@ -69,6 +71,8 @@ samples = {}
|
|
|
for (k, v) in vars.items():
|
|
|
samples[k] = v
|
|
|
|
|
|
+np.savez('votes_distrib.npz', **samples)
|
|
|
+
|
|
|
# if args.normalize:
|
|
|
# df["n"] /= df["n_responses"]
|
|
|
|