enumeration_poisson_direct_dev.stan 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  functions {
    /**
     * Partial log-likelihood of the speaker-confusion data, written for
     * reduce_sum (sliced over `group`).
     *
     * For clip k with group g = group[k - start + 1] and each detected class i,
     * the detected count vtc[k, i] is split as
     *   chi + och + fem + mal + delta = vtc[k, i]
     * where the first four terms come from reference classes 1..4 and delta is
     * the false-positive remainder. The term from reference class c is
     * Poisson(truth[k, c] * lambda[g][c, i]) (identically 0 when
     * truth[k, c] == 0), and delta is Poisson(lambda_fp[g][i] * clip_duration[k]).
     * All splits are enumerated and marginalised with log_sum_exp.
     *
     * Only truth columns 1..4 are read, so n_classes must be at least 4.
     *
     * @param group slice of clip group indices covering clips start..end
     * @param start first clip (1-based) of the slice
     * @param end last clip of the slice
     * @param n_classes number of detected classes
     * @param vtc detected counts, indexed [clip, class]
     * @param truth reference counts, indexed [clip, class]
     * @param clip_duration duration of each clip
     * @param lambda per-group confusion rates, lambda[g][reference class, detected class]
     * @param lambda_fp per-group false-positive rates, lambda_fp[g][detected class]
     * @return summed log-likelihood over clips start..end and all detected classes
     */
    real confusion_model_lpmf(array[] int group,
                              int start, int end,
                              int n_classes,
                              array[,] int vtc,
                              array[,] int truth,
                              array[] real clip_duration,
                              array[] matrix lambda,
                              array[] vector lambda_fp) {
      real ll = 0;
      for (k in start:end) {
        int g = group[k - start + 1];
        for (i in 1:n_classes) {
          int n_det = vtc[k, i];
          // choose(n_det + 4, 4) is the number of ways to pick 4 non-negative
          // counts with sum <= n_det, i.e. an exact upper bound on the number
          // of splits. Sizing the buffer this way means it can never overflow
          // (a fixed 16384-element buffer overflowed for n_det >= 23).
          vector[choose(n_det + 4, 4)] terms;
          // Log-probability of every possible false-positive count 0..n_det,
          // computed once here instead of once per split.
          vector[n_det + 1] fp_lp;
          int n_terms = 0;
          for (d in 0:n_det) {
            fp_lp[d + 1] = poisson_lpmf(d | lambda_fp[g, i] * clip_duration[k]);
          }
          // Each loop is bounded by the detections not yet assigned, so every
          // visited split has delta >= 0 and none is wasted. A reference class
          // with a zero count can only contribute 0, with log-probability 0.
          for (chi in 0:(truth[k, 1] > 0 ? n_det : 0)) {
            real lp_chi = truth[k, 1] == 0 ? 0 : poisson_lpmf(chi | truth[k, 1] * lambda[g, 1, i]);
            for (och in 0:(truth[k, 2] > 0 ? n_det - chi : 0)) {
              real lp_och = truth[k, 2] == 0 ? 0 : poisson_lpmf(och | truth[k, 2] * lambda[g, 2, i]);
              for (fem in 0:(truth[k, 3] > 0 ? n_det - chi - och : 0)) {
                real lp_fem = truth[k, 3] == 0 ? 0 : poisson_lpmf(fem | truth[k, 3] * lambda[g, 3, i]);
                for (mal in 0:(truth[k, 4] > 0 ? n_det - chi - och - fem : 0)) {
                  real lp_mal = truth[k, 4] == 0 ? 0 : poisson_lpmf(mal | truth[k, 4] * lambda[g, 4, i]);
                  n_terms += 1;
                  terms[n_terms] = lp_chi + lp_och + lp_fem + lp_mal
                                   + fp_lp[n_det - chi - och - fem - mal + 1];
                }
              }
            }
          }
          // The all-zero split is always visited, so n_terms >= 1 here.
          ll += log_sum_exp(terms[1:n_terms]);
        }
      }
      return ll;
    }

    /**
     * Partial log-likelihood of the recording-level model, written for
     * reduce_sum (sliced over `children`, the child index of each recording).
     *
     * For each recording k in start..end, with adult = (truth_vocs[k, 3] +
     * truth_vocs[k, 4]) / 1000:
     *   expect = truth_vocs[k] * actual_confusion[k] + duration * actual_fp_rate[k]'
     *   vocs[k]                 ~ normal(expect, sqrt(expect))
     *   truth_vocs[k, 1] / 1000 ~ normal(mu_c + tc[k] + conv * adult, phi[1])
     *   adult                   ~ normal(mu_a + ta[child], phi[2])
     *
     * Reads truth columns 1, 3 and 4 explicitly, so n_classes must be at least 4.
     *
     * @param children slice of child indices for recordings start..end
     * @param start first recording (1-based) of the slice
     * @param end last recording of the slice
     * @param n_recs total number of recordings (not used in the body)
     * @param n_classes number of classes
     * @param duration recording duration applied to the false-positive rates
     * @param vocs detected counts, indexed [recording, class]
     * @param truth_vocs latent true counts, indexed [recording, class]
     * @param actual_confusion per-recording confusion rates [true class, detected class]
     * @param actual_fp_rate per-recording false-positive rates per detected class
     * @param mu_c mean of the child term
     * @param mu_a mean of the adult term
     * @param tc per-recording shift of the child term
     * @param ta per-child shift of the adult term
     * @param conv coefficient of contemporaneous adult speech on child speech
     * @param phi standard deviations of the child (1) and adult (2) terms
     * @return summed log-density over recordings start..end
     */
    real model_lpmf(array[] int children,
                    int start, int end,
                    int n_recs,
                    int n_classes,
                    real duration,
                    array[,] int vocs,
                    matrix truth_vocs,
                    array[] matrix actual_confusion,
                    array[] vector actual_fp_rate,
                    real mu_c,
                    real mu_a,
                    vector tc,
                    vector ta,
                    real conv,
                    vector phi) {
      real ll = 0;
      for (k in start:end) {
        // expect[i] = sum_c truth_vocs[k, c] * actual_confusion[k][c, i]
        //             + actual_fp_rate[k][i] * duration
        // Sized by n_classes so it always matches vocs[k].
        row_vector[n_classes] expect = truth_vocs[k] * actual_confusion[k]
                                       + duration * actual_fp_rate[k]';
        real adult = (truth_vocs[k, 3] + truth_vocs[k, 4]) / 1000;
        ll += normal_lpdf(vocs[k] | expect, sqrt(expect));
        ll += normal_lpdf(truth_vocs[k, 1] / 1000 | mu_c + tc[k] + conv * adult, phi[1]);
        ll += normal_lpdf(adult | mu_a + ta[children[k - start + 1]], phi[2]);
      }
      return ll;
    }
  }
  82. // TODO
  83. // use speech rates to set priors on truth_vocs
  data {
    // Number of speaker classes. The likelihood functions read truth columns
    // 1..4 explicitly, so fewer than four classes would index out of range.
    int<lower=4> n_classes;

    // --- analysis (recordings) data ---
    int<lower=1> n_recs;
    int<lower=1> n_children;
    // Number of corpora. Declared here so it can bound the corpus indices below
    // (Stan reads data by name, so declaration order does not affect input files).
    int<lower=1> n_corpora;
    array[n_recs] int<lower=1, upper=n_children> children; // child of each recording
    array[n_recs] real<lower=1> age;
    array[n_recs, n_classes] int<lower=0> vocs; // detected counts per recording and class
    array[n_children] int<lower=1, upper=n_corpora> corpus; // corpus of each child
    real<lower=0> recs_duration;

    // --- speaker confusion data ---
    int<lower=1> n_clips; // number of clips
    int<lower=1> n_groups; // number of groups
    array[n_clips] int<lower=1, upper=n_groups> group; // group of each clip
    array[n_clips] int conf_corpus;
    array[n_clips, n_classes] int<lower=0> vtc_total; // vtc vocs attributed to specific speakers
    array[n_clips, n_classes] int<lower=0> truth_total;
    array[n_clips] real<lower=0> clip_duration;
    int<lower=0> n_validation;

    // --- observed speech rates ---
    int<lower=1> n_rates;
    array[n_rates, n_classes] int<lower=0> speech_rates;
    array[n_rates] int<lower=1, upper=n_corpora> group_corpus; // corpus of each rate group
    array[n_rates] real<lower=0> durations;

    // --- parallel processing ---
    int<lower=1> threads;
  }
  parameters {
    // --- analysis ---
    // Latent (continuous) true vocalisation counts, [recording, class].
    matrix<lower=0>[n_recs, n_classes] truth_vocs;
    // Per-recording confusion rates, [true class, detected class].
    array[n_recs] matrix<lower=0>[n_classes, n_classes] actual_confusion;
    // Per-recording false-positive rates per detected class.
    array[n_recs] vector<lower=0>[n_classes] actual_fp_rate;
    real mu_a; // mean of adult (classes 3 + 4) vocalisations / 1000
    real<lower=0> sigma_a; // scale of the per-child adult deviation
    vector[n_children] z_ta; // per-child adult deviation, z-score
    real mu_c; // mean of child (class 1) vocalisations / 1000
    vector<lower=0>[2] phi; // residual standard deviations: [1] child, [2] adult
    real mu_beta_0; // average effect of age
    real<lower=0> sigma_beta_0; // effect of age dispersion
    vector[n_children] z_beta_0; // effect of age, z-score
    real beta_a; // effect of excess adult speech
    real conv; // conversational/contextual effects
    // --- confusion ---
    // Gamma shape (alphas) and mean (mus) of the confusion rates.
    matrix<lower=1>[n_classes, n_classes] alphas;
    matrix<lower=0>[n_classes, n_classes] mus;
    // Per-group confusion rates, [true class, detected class].
    array[n_groups] matrix<lower=0>[n_classes, n_classes] lambda;
    // Gamma shape and mean of the false-positive rates.
    vector<lower=1>[n_classes] alphas_fp;
    vector<lower=0>[n_classes] mus_fp;
    // Per-group false-positive rates per detected class.
    array[n_groups] vector<lower=0>[n_classes] lambda_fp;
    // --- speech rates ---
    // Gamma shape and mean (in thousands) of speech rates, [class, corpus].
    matrix<lower=1>[n_classes, n_corpora] speech_rate_alpha;
    matrix<lower=0>[n_classes, n_corpora] speech_rate_mu;
    vector<lower=0>[n_classes] speech_rate_mu_prior;
    // Latent speech rate for each rate observation, [class, observation].
    matrix<lower=0>[n_classes, n_rates] speech_rate;
  }
  transformed parameters {
    // Per-child age slope (non-centred around mu_beta_0).
    vector[n_children] beta_0 = mu_beta_0 + z_beta_0 * sigma_beta_0;
    // Per-recording shift of the child term.
    vector[n_recs] tc;
    // Per-child adult deviation; already carries the sigma_a scale.
    vector[n_children] ta = z_ta * sigma_a;
    // ta is used as-is: multiplying it by sigma_a again (as before) scaled
    // the adult effect by sigma_a^2. Age is divided by 12 (presumably
    // months -> years; TODO confirm the units of `age`).
    tc = (beta_0[children] + beta_a * ta[children]) .* to_vector(age) / 12;
  }
  model {
    // reduce_sum requires grainsize >= 1; the plain integer division is 0
    // whenever there are fewer items than 2*threads (resp. 4*threads).
    int grain_recs = max(1, n_recs %/% (threads * 2));
    int grain_clips = max(1, n_clips %/% (threads * 4));

    // --- recording-level likelihood ---
    target += reduce_sum(
      model_lpmf, children, grain_recs,
      n_recs, n_classes, recs_duration,
      vocs,
      truth_vocs, actual_confusion, actual_fp_rate,
      mu_c, mu_a, tc, ta, conv, phi
    );
    mu_c ~ normal(0, 5);
    mu_a ~ normal(0, 5);
    sigma_a ~ exponential(1);
    z_ta ~ normal(0, 1);
    mu_beta_0 ~ normal(0, 1);
    sigma_beta_0 ~ exponential(1);
    z_beta_0 ~ normal(0, 1);
    beta_a ~ normal(0, 1);
    conv ~ normal(0, 1);
    phi ~ exponential(1);
    for (k in 1:n_recs) {
      int corp = corpus[children[k]];
      // Per-recording rates: gamma with shape alpha and mean mu.
      for (i in 1:n_classes) {
        actual_confusion[k, i] ~ gamma(alphas[i], alphas[i] ./ mus[i]);
      }
      actual_fp_rate[k] ~ gamma(alphas_fp, alphas_fp ./ mus_fp);
      // True counts: gamma with mean 1000 * recs_duration * speech_rate_mu.
      truth_vocs[k] ~ gamma(
        speech_rate_alpha[:, corp],
        (speech_rate_alpha[:, corp] ./ speech_rate_mu[:, corp]) / 1000 / recs_duration
      );
    }

    // --- speaker-confusion likelihood ---
    target += reduce_sum(
      confusion_model_lpmf, group, grain_clips,
      n_classes,
      vtc_total, truth_total, clip_duration,
      lambda, lambda_fp
    );
    mus_fp ~ exponential(1);
    alphas_fp ~ normal(1, 1);
    to_vector(mus) ~ exponential(2);
    to_vector(alphas) ~ normal(1, 1);
    for (i in 1:n_classes) {
      lambda_fp[:, i] ~ gamma(alphas_fp[i], alphas_fp[i] / mus_fp[i]);
      for (j in 1:n_classes) {
        lambda[:, i, j] ~ gamma(alphas[i, j], alphas[i, j] / mus[i, j]);
      }
    }

    // --- speech rates ---
    speech_rate_mu_prior ~ exponential(2);
    for (i in 1:n_classes) {
      speech_rate_alpha[i] ~ normal(1, 1);
      speech_rate_mu[i] ~ gamma(4, 4 / speech_rate_mu_prior[i]);
    }
    for (g in 1:n_rates) {
      int corp = group_corpus[g];
      speech_rate[:, g] ~ gamma(
        speech_rate_alpha[:, corp],
        (speech_rate_alpha[:, corp] ./ speech_rate_mu[:, corp]) / 1000
      );
      speech_rates[g] ~ poisson(speech_rate[:, g] * durations[g]);
    }
  }