// enumeration_poisson_correlation.stan
  1. functions {
  2. real confusion_model_lpmf(array[] int group,
  3. int start, int end,
  4. int n_classes,
  5. array[,] int vtc,
  6. array[,] int truth,
  7. array[] real clip_duration,
  8. array[] matrix lambda,
  9. array[] vector lambda_fp
  10. ) {
  11. real ll = 0;
  12. vector [4] bp;
  13. vector[16384] log_contrib_comb;
  14. int n = size(log_contrib_comb);
  15. for (k in start:end) {
  16. for (i in 1:n_classes) {
  17. log_contrib_comb[:n] = rep_vector(0, n);
  18. n = 1;
  19. for (chi in 0:(truth[k,1]>0?max(truth[k,1], vtc[k,i]):0)) {
  20. bp[1] = truth[k,1]==0?0:poisson_lpmf(chi | truth[k,1]*lambda[group[k-start+1],1,i]);
  21. for (och in 0:(truth[k,2]>0?max(truth[k,2], vtc[k,i]-chi):0)) {
  22. bp[2] = truth[k,2]==0?0:poisson_lpmf(och | truth[k,2]*lambda[group[k-start+1],2,i]);
  23. for (fem in 0:(truth[k,3]>0?max(truth[k,3], vtc[k,i]-chi-och):0)) {
  24. bp[3] = truth[k,3]==0?0:poisson_lpmf(fem | truth[k,3]*lambda[group[k-start+1],3,i]);
  25. for (mal in 0:(truth[k,4]>0?max(truth[k,4], vtc[k,i]-chi-och-fem):0)) {
  26. bp[4] = truth[k,4]==0?0:poisson_lpmf(mal | truth[k,4]*lambda[group[k-start+1],4,i]);
  27. int delta = vtc[k,i] - (mal+fem+och+chi);
  28. if (delta >= 0) {
  29. log_contrib_comb[n] += sum(bp);
  30. log_contrib_comb[n] += poisson_lpmf(
  31. delta | lambda_fp[group[k-start+1],i]*clip_duration[k]
  32. );
  33. n = n+1;
  34. }
  35. }
  36. }
  37. }
  38. }
  39. if (n>1) {
  40. ll += log_sum_exp(log_contrib_comb[1:n-1]);
  41. }
  42. }
  43. }
  44. return ll;
  45. }
  46. real model_lpmf(array[] int children,
  47. int start, int end,
  48. int n_recs,
  49. int n_classes,
  50. real duration,
  51. array [,] int vocs,
  52. matrix truth_vocs,
  53. array [] matrix actual_confusion,
  54. array [] vector actual_fp_rate
  55. ) {
  56. real ll = 0;
  57. vector [4] expect;
  58. //vector [4] sd;
  59. for (k in start:end) {
  60. expect = rep_vector(0, 4);
  61. //sd = rep_vector(0, 4);
  62. for (i in 1:n_classes) {
  63. expect[i] = dot_product(truth_vocs[k,:], actual_confusion[k,:,i]);
  64. expect[i] += actual_fp_rate[k,i] * duration;
  65. }
  66. ll += normal_lpdf(vocs[k,:] | expect, sqrt(expect));
  67. }
  68. return ll;
  69. }
  70. }
// TODO: use speech rates to set priors on truth_vocs
  73. data {
  74. int<lower=1> n_classes; // number of classes
  75. // analysis data block
  76. int<lower=1> n_recs;
  77. int<lower=1> n_children;
  78. array[n_recs] int<lower=1> children;
  79. array[n_recs] real<lower=1> age;
  80. array[n_recs, n_classes] int<lower=0> vocs;
  81. array[n_children] int<lower=1> corpus;
  82. real<lower=0> recs_duration;
  83. // speaker confusion data block
  84. int<lower=1> n_clips; // number of clips
  85. int<lower=1> n_groups; // number of groups
  86. int<lower=1> n_corpora;
  87. array [n_clips] int group;
  88. array [n_clips] int conf_corpus;
  89. array [n_clips,n_classes] int<lower=0> vtc_total; // vtc vocs attributed to specific speakers
  90. array [n_clips,n_classes] int<lower=0> truth_total;
  91. array [n_clips] real<lower=0> clip_duration;
  92. int<lower=1> n_validation;
  93. // actual speech rates
  94. int<lower=1> n_rates;
  95. array [n_rates,n_classes] int<lower=0> speech_rates;
  96. array [n_rates] int group_corpus;
  97. array [n_rates] real<lower=0> durations;
  98. // parallel processing
  99. int<lower=1> threads;
  100. }
  101. parameters {
  102. matrix<lower=0> [n_recs, n_classes] truth_vocs;
  103. array [n_recs] vector[n_classes*n_classes] log_actual_confusion;
  104. array [n_recs] vector<lower=0>[n_classes] actual_fp_rate;
  105. // confusion parameters
  106. vector[n_classes*n_classes] mus;
  107. cholesky_factor_corr[n_classes*n_classes] L_Omega;
  108. vector<lower=0>[n_classes*n_classes] L_sigma;
  109. array [n_groups] vector[n_classes*n_classes] log_lambda;
  110. vector<lower=1>[n_classes] alphas_fp;
  111. vector<lower=0>[n_classes] mus_fp;
  112. array [n_groups] vector<lower=0>[n_classes] lambda_fp;
  113. //array [n_corpora] matrix[n_classes,n_classes] corpus_bias;
  114. //matrix<lower=0>[n_classes,n_classes] corpus_sigma;
  115. // speech rates
  116. matrix<lower=1>[n_classes,n_corpora] speech_rate_alpha;
  117. matrix<lower=0>[n_classes,n_corpora] speech_rate_mu;
  118. matrix<lower=0> [n_classes,n_rates] speech_rate;
  119. }
  120. transformed parameters {
  121. array [n_groups] matrix<lower=0>[n_classes,n_classes] lambda;
  122. array [n_recs] matrix<lower=0>[n_classes,n_classes] actual_confusion;
  123. for (i in 1:n_classes) {
  124. for (j in 1:n_classes) {
  125. for (c in 1:n_groups) {
  126. lambda[c,i,j] = exp(log_lambda[c,i+n_classes*(j-1)]);
  127. }
  128. for (k in 1:n_recs) {
  129. actual_confusion[k,i,j] = exp(log_actual_confusion[k,i+n_classes*(j-1)]);
  130. }
  131. }
  132. }
  133. }
  134. model {
  135. matrix[n_classes*n_classes, n_classes*n_classes] L_Sigma = diag_pre_multiply(L_sigma, L_Omega);
  136. //actual model
  137. //target += reduce_sum(
  138. // model_lpmf, children, 1,
  139. // n_recs, n_classes, recs_duration,
  140. // vocs,
  141. // truth_vocs, actual_confusion, actual_fp_rate
  142. //);
  143. for (k in 1:n_recs) {
  144. log_actual_confusion[k] ~ multi_normal_cholesky(mus, L_Sigma);
  145. actual_fp_rate[k] ~ gamma(alphas_fp, alphas_fp./mus_fp);
  146. }
  147. for (k in 1:n_recs) {
  148. truth_vocs[k,:] ~ gamma(
  149. speech_rate_alpha[:,corpus[children[k]]],
  150. (speech_rate_alpha[:,corpus[children[k]]]./speech_rate_mu[:,corpus[children[k]]])/1000/recs_duration
  151. );
  152. }
  153. target += reduce_sum(
  154. confusion_model_lpmf, group, n_clips%/%(threads*4),
  155. n_classes,
  156. vtc_total, truth_total, clip_duration,
  157. lambda, lambda_fp
  158. );
  159. mus_fp ~ exponential(1);
  160. alphas_fp ~ normal(1, 1);
  161. for (i in 1:n_classes) {
  162. lambda_fp[:,i] ~ gamma(alphas_fp[i], alphas_fp[i]/mus_fp[i]);
  163. }
  164. exp(mus) ~ exponential(1);
  165. L_Omega ~ lkj_corr_cholesky(2);
  166. L_sigma ~ exponential(5);
  167. for (c in 1:n_groups) {
  168. log_lambda[c] ~ multi_normal_cholesky(mus, L_Sigma);
  169. }
  170. // speech rates
  171. for (i in 1:n_classes) {
  172. speech_rate_alpha[i,:] ~ normal(1, 1);
  173. speech_rate_mu[i,:] ~ exponential(2);
  174. }
  175. for (g in 1:n_rates) {
  176. for (i in 1:n_classes) {
  177. speech_rate[i,g] ~ gamma(
  178. speech_rate_alpha[i,group_corpus[g]],
  179. (speech_rate_alpha[i,group_corpus[g]]/speech_rate_mu[i,group_corpus[g]])/1000
  180. );
  181. speech_rates[g,i] ~ poisson(speech_rate[i,g]*durations[g]);
  182. }
  183. }
  184. }