Browse Source

[DATALAD] Recorded changes

Lucas Gautheron 3 months ago
parent
commit
a924323421
2 changed files with 35 additions and 7 deletions
  1. 6 2
      votes_distrib.py
  2. 29 5
      votes_distrib.stan

+ 6 - 2
votes_distrib.py

@@ -15,6 +15,7 @@ def order(x,l):
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--key")
+parser.add_argument("--threads-per-chain", type=int, default=4)
 parser.add_argument("--normalize", action="store_true", default=False)
 args = parser.parse_args()
 
@@ -54,11 +55,12 @@ from cmdstanpy import CmdStanModel
 
 model = CmdStanModel(
     stan_file=f"votes_distrib.stan",
+    cpp_options={"STAN_THREADS": "TRUE"},
 )
 fit = model.sample(
     data=data,
-    chains=4,
-    threads_per_chain=1,
+    chains=2,
+    threads_per_chain=args.threads_per_chain,
     iter_sampling=1000,
     iter_warmup=250,
     show_console=True,
@@ -69,6 +71,8 @@ samples = {}
 for (k, v) in vars.items():
     samples[k] = v
 
+np.savez('votes_distrib.npz', **samples)
+
 # if args.normalize:
 #     df["n"] /= df["n_responses"]
 

+ 29 - 5
votes_distrib.stan

@@ -1,20 +1,44 @@
+functions {
+    real model_lpmf(array[] int counts,
+        int start, int end,
+        array[,] int votes,
+        vector alphas
+        ) {
+            real ll = 0;
+            for (k in start:end) {
+                ll += dirichlet_multinomial_lpmf(votes[k] | alphas);
+            }
+            return ll;
+        }
+}
+
 data {
     int<lower=1> N;
     int<lower=1> C;
     array[N,C] int<lower=0> votes;
 }
 
+transformed data {
+    array[N] int counts;
+    for (i in 1:N) {
+        counts[i] = sum(votes[i]);
+    }
+}
+
 parameters {
     vector<lower=0>[C] alphas;
     //array[N] simplex[C] p;
 }
 
 model {
-    for (i in 1:N) {
-        votes[i] ~ multinomial_dirichlet(alphas);
-        //votes[i] ~ multinomial(p[i]);
-        //p[i] ~ dirichlet(alphas);
-    }
+    target += reduce_sum(
+      model_lpmf, counts, 1,
+      votes, alphas
+    );
 
     alphas ~ exponential(1);
 }
+
+generated quantities {
+    simplex[C] p = dirichlet_rng(alphas);
+}