// cross_validation_overdispersion.stan (6.1 KB)

  functions {
    /*
     * Partial-sum log-likelihood of the speaker-confusion model, written for
     * use with reduce_sum (the first argument is the slice group[start:end]).
     *
     * For clip k and algorithm class i, algo[k, i] is modelled as the sum, over
     * true classes j with truth[k, j] > 0, of independent negative-binomial
     * counts NB(alpha = truth[k, j] * lambda[g, j, i] / (omega[j, i] - 1),
     *           beta  = 1 / (omega[j, i] - 1)),
     * i.e. mean truth[k, j] * lambda[g, j, i] and variance omega[j, i] times the
     * mean, where g is the clip's group. The log-probability of the observed
     * total is obtained by convolving the per-class pmfs in log space,
     * truncated at algo[k, i] (larger partial sums can never match it).
     *
     * If a clip has no true vocalizations at all (every truth[k, j] == 0), the
     * observation contributes 0 to the log-likelihood. With the false-positive
     * term disabled, an algo count > 0 would otherwise have probability 0 and
     * make the target -inf.
     *
     * @param group          slice of clip groups for clips start..end
     * @param start          global index of the first clip in the slice
     * @param end            global index of the last clip in the slice
     * @param n_classes      number of classes (true and algorithm)
     * @param algo           algorithm counts, indexed [clip, class]
     * @param truth          reference counts, indexed [clip, class]
     * @param clip_duration  unused while the false-positive term is disabled
     * @param lambda         per-group rates, lambda[g][true class, algo class]
     * @param omega          variance-to-mean ratios, omega[true class, algo class]
     * @return               summed log-likelihood of clips start..end
     */
    real confusion_model_lpmf(array[] int group,
                              int start, int end,
                              int n_classes,
                              array[,] int algo,
                              array[,] int truth,
                              array[] real clip_duration,
                              array[] matrix lambda,
                              matrix omega
                              // , array[] vector lambda_fp  (false-positive model, disabled)
                              ) {
      real ll = 0;
      for (k in start:end) {
        // `group` is the reduce_sum slice, so it is indexed relative to start;
        // the shared arrays (algo, truth) are indexed with the global k.
        int g = group[k - start + 1];
        for (i in 1:n_classes) {
          int n_algo = algo[k, i];
          // acc[m + 1] = log P(contributions of the true classes folded in so
          // far sum to m), for m = 0..n_algo.
          vector[n_algo + 1] acc;
          int n_active = 0;
          for (j in 1:n_classes) {
            if (truth[k, j] > 0) {
              real phi = 1 / (omega[j, i] - 1);
              vector[n_algo + 1] lp;
              for (x in 0:n_algo) {
                lp[x + 1] = neg_binomial_lpmf(x | truth[k, j] * lambda[g, j, i] * phi, phi);
              }
              if (n_active == 0) {
                acc = lp;
              } else {
                vector[n_algo + 1] conv;
                for (m in 0:n_algo) {
                  // log sum_{x = 0}^{m} exp(acc[m - x + 1] + lp[x + 1])
                  conv[m + 1] = log_sum_exp(acc[1:(m + 1)] + reverse(lp[1:(m + 1)]));
                }
                acc = conv;
              }
              n_active += 1;
            }
          }
          // No true vocalizations in this clip: skip (see header comment).
          if (n_active > 0) {
            ll += acc[n_algo + 1];
          }
        }
      }
      return ll;
    }
  }
  52. // TODO
  53. // use speech rates to set priors on truth_vocs
  data {
    int<lower=1> n_classes; // number of classes (rows/columns of the confusion matrices)
    // --- analysis data ---
    int<lower=1> n_recs;
    int<lower=1> n_children;
    // Used as an index into child_siblings, so it must lie in 1..n_children.
    array[n_recs] int<lower=1, upper=n_children> children;
    array[n_recs] real<lower=1> age;
    array[n_recs] int<lower=-1> siblings;
    array[n_recs, n_classes] int<lower=0> vocs;
    array[n_children] int<lower=1> corpus;
    real<lower=0> recs_duration;
    // --- speaker confusion data ---
    int<lower=1> n_clips; // number of clips
    int<lower=1> n_groups; // number of groups
    int<lower=1> n_corpora;
    // Used as an index into lambda, recording_age and group_truth, so it must
    // lie in 1..n_groups.
    array[n_clips] int<lower=1, upper=n_groups> group;
    array[n_clips] int conf_corpus;
    array[n_clips, n_classes] int<lower=0> algo_total; // algo vocs attributed to specific speakers
    array[n_clips, n_classes] int<lower=0> truth_total; // reference ("truth") vocs per class
    array[n_clips] real<lower=0> clip_duration;
    array[n_clips] real<lower=0> clip_age;
    int<lower=0> n_validation;
    // --- parallel processing ---
    int<lower=1> threads; // used to choose the reduce_sum grainsize
  }
  transformed data {
    // Age per confusion group; each clip overwrites its group's entry, so the
    // last clip listed for a group wins (presumably all clips of a group share
    // one age -- TODO confirm).
    vector<lower=0>[n_groups] recording_age;
    // Sibling code per child, copied from that child's recordings (last wins).
    array[n_children] int<lower=-1> child_siblings;
    int no_siblings = 0;  // children whose sibling code is exactly 0
    int has_siblings = 0; // children whose sibling code is positive
    // Reference counts summed over all clips of each group, per class.
    array[n_groups, n_classes] int group_truth = rep_array(0, n_groups, n_classes);

    for (clip in 1:n_clips) {
      int grp = group[clip];
      recording_age[grp] = clip_age[clip];
      for (cls in 1:n_classes) {
        group_truth[grp, cls] += truth_total[clip, cls];
      }
    }

    for (rec in 1:n_recs) {
      child_siblings[children[rec]] = siblings[rec];
    }

    for (child in 1:n_children) {
      // The two conditions are mutually exclusive; comparisons yield 0/1.
      no_siblings += child_siblings[child] == 0;
      has_siblings += child_siblings[child] > 0;
    }
  }
  parameters {
    // --- confusion model ---
    // Hierarchical gamma prior on the rates (see the model block):
    // lambda[g][r, c] ~ gamma(alphas[r, c], alphas[r, c] / mus[r, c]),
    // so alphas is the gamma shape and mus the gamma mean.
    matrix<lower=0>[n_classes, n_classes] alphas;
    matrix<lower=0>[n_classes, n_classes] mus;
    // Variance-to-mean ratio of the negative-binomial counts; the likelihood
    // divides by (omega - 1), hence the lower bound of 1.
    matrix<lower=1>[n_classes, n_classes] omega;
    // Per-group rates, indexed [group][true class, algorithm class].
    array[n_groups] matrix<lower=0>[n_classes, n_classes] lambda;
    // --- false positives (currently disabled) ---
    // vector<lower=0>[n_classes] alphas_fp;
    // vector<lower=0>[n_classes] mus_fp;
    // array[n_groups] vector<lower=0>[n_classes] lambda_fp;
  }
  model {
    // reduce_sum requires grainsize >= 1, but n_clips %/% (threads * 4) is 0
    // whenever n_clips < 4 * threads, so clamp it.
    int grainsize = max(1, n_clips %/% (threads * 4));

    // Confusion-model likelihood, parallelised over clips.
    target += reduce_sum(
      confusion_model_lpmf, group, grainsize,
      n_classes,
      algo_total, truth_total, clip_duration,
      lambda, omega // , lambda_fp
    );

    // False-positive priors (disabled):
    // mus_fp ~ exponential(1);
    // alphas_fp ~ gamma(2, 1);
    for (i in 1:n_classes) {
      // lambda_fp[:, i] ~ gamma(alphas_fp[i], alphas_fp[i] / mus_fp[i]);
      omega[i, :] ~ pareto(1, 2);
      for (j in 1:n_classes) {
        mus[i, j] ~ exponential(2);
        alphas[i, j] ~ inv_gamma(1, 1);
        // gamma(shape, shape / mean) has mean mus[i, j].
        for (c in 1:n_groups) {
          lambda[c, i, j] ~ gamma(alphas[i, j], alphas[i, j] / mus[i, j]);
        }
      }
    }
  }
  generated quantities {
    // Simulated algorithm counts per group (g) and algorithm class (i), built
    // from the group's summed reference counts:
    //  - sim_vocs_given_lambda uses the fitted rates lambda[g];
    //  - sim_vocs redraws the rates from gamma(alphas, alphas ./ mus).
    array[n_groups, n_classes] int sim_vocs_given_lambda;
    array[n_groups, n_classes] int sim_vocs;
    for (g in 1:n_groups) {
      vector[n_classes] truth_g = to_vector(group_truth[g, :]);
      for (i in 1:n_classes) {
        vector[n_classes] mu_given_lambda = lambda[g, :, i] .* truth_g;
        vector[n_classes] mu = to_vector(gamma_rng(alphas[:, i], alphas[:, i] ./ mus[:, i])) .* truth_g;
        sim_vocs_given_lambda[g, i] = 0;
        sim_vocs[g, i] = 0;
        for (j in 1:n_classes) {
          if (group_truth[g, j] > 0) {
            // neg_binomial_rng requires a positive shape. A zero mean (e.g.
            // gamma_rng underflowing to 0 for a small shape) is the degenerate
            // point mass at 0, so it adds nothing.
            if (mu_given_lambda[j] > 0) {
              sim_vocs_given_lambda[g, i] += neg_binomial_rng(mu_given_lambda[j] / (omega[j, i] - 1), 1 / (omega[j, i] - 1));
            }
            if (mu[j] > 0) {
              sim_vocs[g, i] += neg_binomial_rng(mu[j] / (omega[j, i] - 1), 1 / (omega[j, i] - 1));
            }
          }
        }
      }
    }
  }