// confusion_overdispersion.stan (5.0 KB)

  functions {
    /**
     * Partial-sum log-likelihood of the speaker-confusion counts, written for
     * reduce_sum, which passes the slice group[start:end] as `group`.
     *
     * For clip k and algorithm class i, the class-i detections produced by the
     * truth[k, c] true vocalisations of class c are negative-binomial:
     *   neg_binomial(mu * phi, phi),  mu = truth[k, c] * lambda[g][c, i],
     *   phi = 1 / (omega[c, i] - 1),
     * which has mean mu and variance mu * omega[c, i]. The observed count
     * algo[k, i] is the sum of these independent contributions over every
     * class c with truth[k, c] > 0. Its log-probability is obtained by
     * convolving the per-class pmfs over 0..algo[k, i] in log space. The cost
     * is O(n_classes * algo[k, i]^2) per clip and class, with no fixed-size
     * buffer.
     *
     * When no class has a true vocalisation in clip k, the term contributes 0,
     * even if algo[k, i] > 0. Such detections could only come from the
     * false-positive term, which is currently disabled. That term was a Poisson
     * with rate lambda_fp[g, i] * clip_duration[k] on the surplus detections,
     * and would enter as one more convolution factor.
     *
     * @param group         group index of each clip in the slice start..end
     * @param start         first (global, 1-based) clip index of the slice
     * @param end           last clip index of the slice
     * @param n_classes     number of speaker classes (truth and algorithm)
     * @param algo          algorithm counts, algo[clip, class]
     * @param truth         ground-truth counts, truth[clip, class]
     * @param clip_duration clip durations; unused while the false-positive term is disabled
     * @param lambda        lambda[g][c, i]: expected class-i detections per true class-c vocalisation in group g
     * @param omega         omega[c, i] > 1: variance-to-mean ratio of those detections
     * @return log-likelihood summed over clips start..end and all algorithm classes
     */
    real confusion_model_lpmf(array[] int group,
                              int start, int end,
                              int n_classes,
                              array[,] int algo,
                              array[,] int truth,
                              array[] real clip_duration,
                              array[] matrix lambda,
                              matrix omega) {
      real ll = 0;
      for (k in start:end) {
        int g = group[k - start + 1];  // `group` is the reduce_sum slice, hence the offset
        for (i in 1:n_classes) {
          int a = algo[k, i];
          // lp[s + 1] = log P(class-i detections from the classes folded in so far == s), s = 0..a
          vector[a + 1] lp;
          int started = 0;  // becomes 1 once at least one truth class has been folded in
          for (c in 1:n_classes) {
            if (truth[k, c] > 0) {
              real phi = 1 / (omega[c, i] - 1);
              real alpha = truth[k, c] * lambda[g, c, i] * phi;
              vector[a + 1] lp_c;
              for (t in 0:a) {
                lp_c[t + 1] = neg_binomial_lpmf(t | alpha, phi);
              }
              if (started) {
                vector[a + 1] conv;
                for (s in 0:a) {
                  // log sum over splits s = u + t of P(total so far = u) * P(class c gives t)
                  conv[s + 1] = log_sum_exp(lp[1:(s + 1)] + reverse(lp_c[1:(s + 1)]));
                }
                lp = conv;
              } else {
                lp = lp_c;
                started = 1;
              }
            }
          }
          if (started) {
            ll += lp[a + 1];
          }
        }
      }
      return ll;
    }
  }
  // TODO: use speech rates to set priors on truth_vocs
  data {
    int<lower=1> n_classes; // number of speaker classes (shared by algorithm and ground truth)
    // analysis data block
    int<lower=1> n_recs;
    int<lower=1> n_children;
    array[n_recs] int<lower=1> children;
    array[n_recs] real<lower=1> age;
    array[n_recs] int<lower=-1> siblings;
    array[n_recs, n_classes] int<lower=0> vocs;
    array[n_children] int<lower=1> corpus;
    real<lower=0> recs_duration;
    // speaker confusion data block
    int<lower=1> n_clips; // number of clips
    int<lower=1> n_groups; // number of groups (recordings) the clips belong to
    int<lower=1> n_corpora;
    // Group of each clip. It indexes recording_age/recording_rural and lambda,
    // so it must lie in 1..n_groups; checking here reports bad data up front.
    array[n_clips] int<lower=1, upper=n_groups> group;
    array[n_clips] int conf_corpus;
    array[n_clips, n_classes] int<lower=0> algo_total; // algo vocs attributed to specific speakers
    array[n_clips, n_classes] int<lower=0> truth_total; // ground-truth vocs per speaker class
    array[n_clips] real<lower=0> clip_duration;
    array[n_clips] real<lower=0> clip_age;
    array[n_clips] int<lower=0> clip_rural; // > 0 marks a rural clip
    int<lower=0> n_validation;
    // parallel processing: used to derive the reduce_sum grainsize
    int<lower=1> threads;
  }
  transformed data {
    // Recording-level covariates: each group takes the age and rural flag of
    // its clips. If clips of one group disagree, the last such clip wins.
    // Every index 1..n_groups is expected to appear in `group`; otherwise that
    // entry is never assigned (NOTE(review): confirm the declared lower bounds
    // then reject it at the end of this block).
    vector<lower=0>[n_groups] recording_age;
    array[n_groups] int<lower=0> recording_rural;
    for (clip in 1:n_clips) {
      int g = group[clip];
      recording_age[g] = clip_age[clip];
      recording_rural[g] = clip_rural[clip];
    }
  }
  parameters {
    // Hyperparameters of the confusion-rate prior, stratified as
    // index 1 = urban, index 2 = rural.
    array[2] matrix<lower=0>[n_classes, n_classes] alphas; // gamma shapes
    array[2] matrix<lower=0>[n_classes, n_classes] mus;    // gamma means
    // lambda[g][c, i]: expected number of class-i algorithm detections per
    // true class-c vocalisation, for group g.
    array[n_groups] matrix<lower=0>[n_classes, n_classes] lambda;
    // omega[c, i]: variance-to-mean ratio (overdispersion) of those detection
    // counts; constrained above 1.
    matrix<lower=1>[n_classes, n_classes] omega;
    // Disabled alternative parameterisation:
    //matrix<lower=0>[n_classes,n_classes] conf_sd;
  }
  transformed parameters {
    // Intentionally empty: omega is sampled directly as a parameter.
    // Disabled alternative, deriving it from a log-scale parameter instead:
    //matrix<lower=1>[n_classes,n_classes] omega = exp(conf_sd/10);
  }
  model {
    // Likelihood of the speaker-confusion data, parallelised with reduce_sum.
    // reduce_sum requires grainsize >= 1, but n_clips %/% (threads * 4) is 0
    // whenever n_clips < 4 * threads, so clamp it to at least 1.
    target += reduce_sum(
      confusion_model_lpmf, group, max(1, n_clips %/% (threads * 4)),
      n_classes,
      algo_total, truth_total, clip_duration,
      lambda, omega // , lambda_fp (false-positive term disabled)
    );
    // Disabled false-positive priors:
    //mus_fp ~ exponential(1);
    //alphas_fp ~ gamma(2, 1);
    for (i in 1:n_classes) {
      //conf_sd[i,:] ~ normal(0, 1);
      omega[i, :] ~ pareto(1, 2); // overdispersion, support >= 1
      //lambda_fp[:,i] ~ gamma(alphas_fp[i], alphas_fp[i]/mus_fp[i]);
      for (j in 1:n_classes) {
        mus[1, i, j] ~ exponential(1);     // urban
        mus[2, i, j] ~ exponential(1);     // rural
        alphas[1, i, j] ~ lognormal(0, 1); // urban
        alphas[2, i, j] ~ lognormal(0, 1); // rural
        for (c in 1:n_groups) {
          // Stratum of group c: 2 (rural) if recording_rural[c] > 0, else 1 (urban).
          int r = recording_rural[c] > 0 ? 2 : 1;
          // Gamma prior with shape alphas[r, i, j] and mean mus[r, i, j].
          lambda[c, i, j] ~ gamma(alphas[r, i, j], alphas[r, i, j] / mus[r, i, j]);
        }
      }
    }
  }