// cross_validation_multinomial.stan
  functions {
    // User-defined functions come entirely from this include. It must define
    // confusion_model_lpdf, the partial-sum function that the model block
    // passes to reduce_sum.
#include "blocks/confusion_model_multinomial.stan"
  }
// TODO: use speech rates to set priors on truth_vocs.
  data {
    int<lower=1> n_classes; // number of classes

    // Speaker confusion data: one entry per clip.
    int<lower=1> n_clips;  // number of clips
    int<lower=1> n_groups; // number of groups
    int<lower=1> n_corpora;
    // Group label of each clip. It is used directly as an index into
    // group_truth[n_groups, n_classes] in transformed data. Bounding it here
    // rejects out-of-range labels when the data are read, instead of failing
    // later with an indexing error.
    array[n_clips] int<lower=1, upper=n_groups> group;
    // NOTE(review): not referenced in the code visible in this file; possibly
    // an index into 1..n_corpora -- confirm in the included blocks before
    // adding bounds.
    array[n_clips] int conf_corpus;
    array[n_clips, n_classes] int<lower=0> algo_total;  // algo vocs attributed to specific speakers
    array[n_clips, n_classes] int<lower=0> truth_total; // truth counts per clip and class; summed per group in transformed data
    array[n_clips] real<lower=0> clip_duration;
    array[n_clips] real<lower=0> clip_age;
    int<lower=0> n_validation; // NOTE(review): not referenced in the code visible in this file

    // Parallel processing.
    // NOTE(review): not referenced in the code visible in this file. reduce_sum
    // takes its thread count from the runtime (e.g. CmdStan num_threads), not
    // from data -- confirm whether the included blocks use this.
    int<lower=1> threads;
  }
  transformed data {
    // Truth counts aggregated by group:
    //   group_truth[g, k] = sum of truth_total[c, k] over all clips c with group[c] == g.
    array[n_groups, n_classes] int group_truth = rep_array(0, n_groups, n_classes);
    for (c in 1:n_clips) {
      int g = group[c];
      for (k in 1:n_classes) {
        group_truth[g, k] += truth_total[c, k];
      }
    }
  }
  parameters {
    // All parameter declarations come from this include. The rest of the
    // program uses lambda (model and generated quantities) and mus and etas
    // (generated quantities), so the include presumably declares at least
    // those -- confirm in the included file.
#include "blocks/confusion_model_parameters_multinomial.stan"
  }
  model {
    // Likelihood, evaluated in parallel. reduce_sum slices lambda along its
    // first index and calls confusion_model_lpdf on each slice. Every slice
    // receives the remaining (shared) arguments in full.
    int reduce_sum_grainsize = 1; // 1 = no minimum slice size; the scheduler partitions freely
    target += reduce_sum(confusion_model_lpdf, lambda, reduce_sum_grainsize,
                         n_classes, n_clips,
                         algo_total, truth_total, group,
                         clip_duration, clip_age);

    // Remaining target terms come from the included priors file.
#include "blocks/confusion_model_priors_multinomial.stan"
  }
  generated quantities {
    // Simulated per-group, per-class counts. For every true class t, the
    // group's truth count group_truth[g, t] is split multinomially over
    // n_classes + 1 outcome categories:
    //   - sim_vocs_given_lambda uses the probability vector lambda[g, t];
    //   - sim_vocs uses a fresh draw dirichlet_rng(mus[t] * etas[t]).
    // Each output is the per-category total over true classes, restricted to
    // the first n_classes categories.
    array[n_groups, n_classes] int sim_vocs_given_lambda;
    array[n_groups, n_classes] int sim_vocs;
    for (g in 1:n_groups) {
      // Row t holds the split of class t's truth count. Rows whose truth count
      // is zero stay all-zero and consume no random numbers.
      array[n_classes, n_classes + 1] int counts_lambda = rep_array(0, n_classes, n_classes + 1);
      array[n_classes, n_classes + 1] int counts_pred = rep_array(0, n_classes, n_classes + 1);
      for (t in 1:n_classes) {
        int n_true = group_truth[g, t];
        if (n_true > 0) {
          counts_lambda[t] = multinomial_rng(lambda[g, t, :], n_true);
          counts_pred[t] = multinomial_rng(dirichlet_rng(mus[t] * etas[t]), n_true);
        }
      }
      // Column totals; the final category (column n_classes + 1) is not reported.
      for (k in 1:n_classes) {
        sim_vocs_given_lambda[g, k] = sum(counts_lambda[:, k]);
        sim_vocs[g, k] = sum(counts_pred[:, k]);
      }
    }
  }