// enumeration_poisson.stan

  1. functions {
  2. real confusion_model_lpmf(array[] int group,
  3. int start, int end,
  4. int n_classes,
  5. array[,] int vtc,
  6. array[,] int truth,
  7. array[] real clip_duration,
  8. array[] matrix lambda,
  9. array[] vector lambda_fp
  10. ) {
  11. real ll = 0;
  12. vector [4] bp;
  13. vector[16384] log_contrib_comb;
  14. int n = size(log_contrib_comb);
  15. for (k in start:end) {
  16. for (i in 1:n_classes) {
  17. log_contrib_comb[:n] = rep_vector(0, n);
  18. n = 1;
  19. for (chi in 0:(truth[k,1]>0?max(truth[k,1], vtc[k,i]):0)) {
  20. bp[1] = truth[k,1]==0?0:poisson_lpmf(chi | truth[k,1]*lambda[group[k-start+1],1,i]);
  21. for (och in 0:(truth[k,2]>0?max(truth[k,2], vtc[k,i]-chi):0)) {
  22. bp[2] = truth[k,2]==0?0:poisson_lpmf(och | truth[k,2]*lambda[group[k-start+1],2,i]);
  23. for (fem in 0:(truth[k,3]>0?max(truth[k,3], vtc[k,i]-chi-och):0)) {
  24. bp[3] = truth[k,3]==0?0:poisson_lpmf(fem | truth[k,3]*lambda[group[k-start+1],3,i]);
  25. for (mal in 0:(truth[k,4]>0?max(truth[k,4], vtc[k,i]-chi-och-fem):0)) {
  26. bp[4] = truth[k,4]==0?0:poisson_lpmf(mal | truth[k,4]*lambda[group[k-start+1],4,i]);
  27. int delta = vtc[k,i] - (mal+fem+och+chi);
  28. if (delta >= 0) {
  29. log_contrib_comb[n] += sum(bp);
  30. log_contrib_comb[n] += poisson_lpmf(
  31. delta | lambda_fp[group[k-start+1],i]*clip_duration[k]
  32. );
  33. n = n+1;
  34. }
  35. }
  36. }
  37. }
  38. }
  39. if (n>1) {
  40. ll += log_sum_exp(log_contrib_comb[1:n-1]);
  41. }
  42. }
  43. }
  44. return ll;
  45. }
  46. real model_lpmf(array[] int children,
  47. int start, int end,
  48. int n_recs,
  49. int n_classes,
  50. real duration,
  51. array [,] int vocs,
  52. matrix truth_vocs,
  53. array [] matrix actual_confusion,
  54. array [] vector actual_fp_rate
  55. ) {
  56. real ll = 0;
  57. vector [4] expect;
  58. //vector [4] sd;
  59. for (k in start:end) {
  60. expect = rep_vector(0, 4);
  61. //sd = rep_vector(0, 4);
  62. for (i in 1:n_classes) {
  63. expect[i] = dot_product(truth_vocs[k,:], actual_confusion[k,:,i]);
  64. expect[i] += actual_fp_rate[k,i] * duration;
  65. }
  66. ll += normal_lpdf(vocs[k,:] | expect, sqrt(expect));
  67. }
  68. return ll;
  69. }
  70. }
// TODO: use the speech rates to set the priors on truth_vocs.
  73. data {
  74. int<lower=1> n_classes; // number of classes
  75. // analysis data block
  76. int<lower=1> n_recs;
  77. int<lower=1> n_children;
  78. array[n_recs] int<lower=1> children;
  79. array[n_recs] real<lower=1> age;
  80. array[n_recs, n_classes] int<lower=0> vocs;
  81. array[n_children] int<lower=1> corpus;
  82. real<lower=0> recs_duration;
  83. // speaker confusion data block
  84. int<lower=1> n_clips; // number of clips
  85. int<lower=1> n_groups; // number of groups
  86. int<lower=1> n_corpora;
  87. array [n_clips] int group;
  88. array [n_clips] int conf_corpus;
  89. array [n_clips,n_classes] int<lower=0> vtc_total; // vtc vocs attributed to specific speakers
  90. array [n_clips,n_classes] int<lower=0> truth_total;
  91. array [n_clips] real<lower=0> clip_duration;
  92. int<lower=1> n_validation;
  93. // actual speech rates
  94. int<lower=1> n_rates;
  95. array [n_rates,n_classes] int<lower=0> speech_rates;
  96. array [n_rates] int group_corpus;
  97. array [n_rates] real<lower=0> durations;
  98. // parallel processing
  99. int<lower=1> threads;
  100. }
  101. parameters {
  102. matrix<lower=0> [n_recs, n_classes] truth_vocs;
  103. //array [n_children] matrix<lower=0>[n_classes,n_classes] actual_confusion_baseline;
  104. array [n_recs] matrix<lower=0>[n_classes,n_classes] actual_confusion_baseline;
  105. array [n_recs] vector<lower=0>[n_classes] actual_fp_rate;
  106. // confusion parameters
  107. matrix<lower=1>[n_classes,n_classes] alphas;
  108. matrix<lower=0>[n_classes,n_classes] mus;
  109. array [n_groups] matrix<lower=0>[n_classes,n_classes] lambda;
  110. vector<lower=1>[n_classes] alphas_fp;
  111. vector<lower=0>[n_classes] mus_fp;
  112. array [n_groups] vector<lower=0>[n_classes] lambda_fp;
  113. //array [n_corpora] matrix[n_classes,n_classes] corpus_bias;
  114. //matrix<lower=0>[n_classes,n_classes] corpus_sigma;
  115. // speech rates
  116. matrix<lower=1>[n_classes,n_corpora] speech_rate_alpha;
  117. matrix<lower=0>[n_classes,n_corpora] speech_rate_mu;
  118. matrix<lower=0> [n_classes,n_rates] speech_rate;
  119. }
  120. transformed parameters {
  121. // array [n_children] matrix<lower=0,upper=1>[n_classes,n_classes] actual_confusion;
  122. // for (c in 1:n_children) {
  123. // actual_confusion[c] = inv_logit(logit(actual_confusion_baseline[c])+corpus_bias[corpus[c]]);
  124. // }
  125. }
  126. model {
  127. //actual model
  128. target += reduce_sum(
  129. model_lpmf, children, 1,
  130. n_recs, n_classes, recs_duration,
  131. vocs,
  132. truth_vocs, actual_confusion_baseline, actual_fp_rate
  133. );
  134. for (k in 1:n_recs) {
  135. for (i in 1:n_classes) {
  136. actual_confusion_baseline[k,i] ~ gamma(alphas[i,:], alphas[i,:]./mus[i,:]);
  137. }
  138. actual_fp_rate[k] ~ gamma(alphas_fp, alphas_fp./mus_fp);
  139. }
  140. for (k in 1:n_recs) {
  141. truth_vocs[k,:] ~ gamma(
  142. speech_rate_alpha[:,corpus[children[k]]],
  143. (speech_rate_alpha[:,corpus[children[k]]]./speech_rate_mu[:,corpus[children[k]]])/1000/recs_duration
  144. );
  145. }
  146. target += reduce_sum(
  147. confusion_model_lpmf, group, n_clips%/%(threads*4),
  148. n_classes,
  149. vtc_total, truth_total, clip_duration,
  150. lambda, lambda_fp
  151. );
  152. mus_fp ~ exponential(1);
  153. alphas_fp ~ normal(1, 1);
  154. for (i in 1:n_classes) {
  155. lambda_fp[:,i] ~ gamma(alphas_fp[i], alphas_fp[i]/mus_fp[i]);
  156. for (j in 1:n_classes) {
  157. mus[i,j] ~ exponential(2);
  158. alphas[i,j] ~ normal(1,1);
  159. for (c in 1:n_groups) {
  160. lambda[c,i,j] ~ gamma(alphas[i,j], alphas[i,j]/mus[i,j]);
  161. }
  162. }
  163. }
  164. // speech rates
  165. for (i in 1:n_classes) {
  166. speech_rate_alpha[i,:] ~ normal(1, 1);
  167. speech_rate_mu[i,:] ~ exponential(2);
  168. }
  169. for (g in 1:n_rates) {
  170. for (i in 1:n_classes) {
  171. speech_rate[i,g] ~ gamma(
  172. speech_rate_alpha[i,group_corpus[g]],
  173. (speech_rate_alpha[i,group_corpus[g]]/speech_rate_mu[i,group_corpus[g]])/1000
  174. );
  175. speech_rates[g,i] ~ poisson(speech_rate[i,g]*durations[g]);
  176. }
  177. }
  178. }