// dev_siblings.stan
  functions {
    /**
     * reduce_sum partial function: log-likelihood of the speaker-confusion
     * data for clips start..end (inclusive). `group` is the slice of the
     * clip -> group index array for those clips.
     *
     * For clip k and detected class i, the detected count vtc[k, i] is split
     * into latent counts (chi, och, fem, mal) coming from true classes 1..4,
     * each Poisson with rate truth[k, c] * lambda[g][c, i], plus a remainder
     * `delta` of false positives, Poisson with rate
     * lambda_fp[g][i] * clip_duration[k]. The split is marginalised by
     * enumerating every non-negative split with
     * chi + och + fem + mal <= vtc[k, i] and log-sum-exp'ing the joint
     * log-probabilities. A true class with truth[k, c] == 0 is fixed at
     * 0 detections (log-probability 0).
     *
     * The enumeration is written for exactly four true classes (columns
     * 1..4 of `truth`, rows 1..4 of each lambda[g]). The number of splits
     * per (clip, class) grows roughly as vtc[k, i]^4 / 24 when all four
     * true classes are present.
     *
     * @param group         slice of the clip -> group array for start..end
     * @param start         first clip of the slice (global clip index)
     * @param end           last clip of the slice (global clip index)
     * @param n_classes     number of detected classes iterated over
     * @param vtc           detected counts, vtc[clip, class]
     * @param truth         annotated true counts, truth[clip, class]
     * @param age           not used by this function
     * @param clip_duration per-clip duration, scales the false-positive rate
     * @param lambda        per-group confusion rates, lambda[g][true, detected]
     * @param lambda_fp     per-group false-positive rates, lambda_fp[g][detected]
     * @return              summed marginal log-likelihood over the slice
     */
    real confusion_model_lpmf(array[] int group,
                              int start, int end,
                              int n_classes,
                              array[,] int vtc,
                              array[,] int truth,
                              array[] real age,
                              array[] real clip_duration,
                              array[] matrix lambda,
                              array[] vector lambda_fp) {
      real ll = 0;
      for (k in start:end) {
        // `group` is the reduce_sum slice, hence indexed relative to `start`;
        // every other array is passed whole and indexed by the global k.
        int g = group[k - start + 1];
        for (i in 1:n_classes) {
          int n_obs = vtc[k, i];
          real fp_rate = lambda_fp[g, i] * clip_duration[k];
          // Running log-sum-exp over all latent splits. No preallocated
          // buffer, so the number of splits is not capped.
          real lse = negative_infinity();
          // Each loop only runs up to what is left of n_obs, so the
          // false-positive remainder `delta` is never negative.
          for (chi in 0:(truth[k, 1] > 0 ? n_obs : 0)) {
            real lp_chi = truth[k, 1] == 0
                          ? 0
                          : poisson_lpmf(chi | truth[k, 1] * lambda[g, 1, i]);
            for (och in 0:(truth[k, 2] > 0 ? n_obs - chi : 0)) {
              real lp_och = truth[k, 2] == 0
                            ? 0
                            : poisson_lpmf(och | truth[k, 2] * lambda[g, 2, i]);
              for (fem in 0:(truth[k, 3] > 0 ? n_obs - chi - och : 0)) {
                real lp_fem = truth[k, 3] == 0
                              ? 0
                              : poisson_lpmf(fem | truth[k, 3] * lambda[g, 3, i]);
                for (mal in 0:(truth[k, 4] > 0 ? n_obs - chi - och - fem : 0)) {
                  real lp_mal = truth[k, 4] == 0
                                ? 0
                                : poisson_lpmf(mal | truth[k, 4] * lambda[g, 4, i]);
                  int delta = n_obs - (chi + och + fem + mal);
                  lse = log_sum_exp(lse,
                                    lp_chi + lp_och + lp_fem + lp_mal
                                    + poisson_lpmf(delta | fp_rate));
                }
              }
            }
          }
          // The all-zero split always exists, so `lse` is finite here.
          ll += lse;
        }
      }
      return ll;
    }

    /**
     * reduce_sum partial function: log-likelihood of the detected counts of
     * recordings start..end (the sliced argument `children` is only used to
     * define the slice).
     *
     * For recording k and detected class i the expected count is
     *   expect[i] = sum_j truth_vocs[k, j] * actual_confusion[k][j, i]
     *               + actual_fp_rate[k][i] * duration,
     * and vocs[k, i] ~ normal(expect[i], sqrt(expect[i])), i.e. a normal
     * with variance equal to its mean.
     *
     * @param children          slice being reduced over (values unused)
     * @param start             first recording of the slice (global index)
     * @param end               last recording of the slice (global index)
     * @param n_recs            not used by this function
     * @param n_classes         number of classes
     * @param duration          duration multiplying the false-positive rates
     * @param vocs              detected counts, vocs[recording, class]
     * @param age               not used by this function
     * @param truth_vocs        latent true counts, [recording, class]
     * @param actual_confusion  per-recording confusion, [true, detected]
     * @param actual_fp_rate    per-recording false-positive rates
     * @return                  summed log-likelihood over the slice
     */
    real model_lpmf(array[] int children,
                    int start, int end,
                    int n_recs,
                    int n_classes,
                    real duration,
                    array[,] int vocs,
                    array[] real age,
                    matrix truth_vocs,
                    array[] matrix actual_confusion,
                    array[] vector actual_fp_rate) {
      real ll = 0;
      for (k in start:end) {
        // Sized by n_classes (previously hard-coded to 4); the row-vector x
        // matrix product yields all per-class dot products at once.
        vector[n_classes] expect = (truth_vocs[k] * actual_confusion[k])'
                                   + actual_fp_rate[k] * duration;
        ll += normal_lpdf(vocs[k] | expect, sqrt(expect));
      }
      return ll;
    }
  }
  // TODO: use the speech-rate data to set priors on truth_vocs.
  data {
    // ---- Analysis data: one entry per recording ----
    // Number of speaker classes. The confusion likelihood and the priors in
    // the model block hard-code four classes (chi, och, fem, mal), so any
    // other value is rejected here instead of failing or being silently
    // mishandled later.
    int<lower=4, upper=4> n_classes;
    int<lower=1> n_recs;
    int<lower=1> n_children;
    // Number of corpora; declared before `corpus` so it can bound it.
    int<lower=1> n_corpora;
    array[n_recs] int<lower=1, upper=n_children> children; // child of each recording
    // Age at recording. The model rescales it as age / 12 / 10
    // (presumably months; confirm).
    array[n_recs] real<lower=1> age;
    // Sibling code per recording; the model only tests `> 0`
    // (NOTE(review): -1 presumably means unknown; confirm).
    array[n_recs] int<lower=-1> siblings;
    array[n_recs, n_classes] int<lower=0> vocs;             // detected counts per class
    array[n_children] int<lower=1, upper=n_corpora> corpus; // corpus of each child
    real<lower=0> recs_duration;                            // duration shared by all recordings

    // ---- Speaker-confusion data: one entry per annotated clip ----
    int<lower=1> n_clips;
    int<lower=1> n_groups;
    array[n_clips] int<lower=1, upper=n_groups> group;      // group of each clip
    array[n_clips] int conf_corpus;                          // not referenced by the model
    array[n_clips, n_classes] int<lower=0> vtc_total;        // VTC detections attributed to each class
    array[n_clips, n_classes] int<lower=0> truth_total;      // annotated true counts per class
    array[n_clips] real<lower=0> clip_duration;
    array[n_clips] real<lower=0> clip_age;
    int<lower=0> n_validation;                               // not referenced by the model

    // ---- Annotated speech rates ----
    int<lower=1> n_rates;
    int<lower=1> n_speech_rate_children;
    array[n_rates, n_classes] int<lower=0> speech_rates;
    array[n_rates] int<lower=1, upper=n_corpora> group_corpus; // corpus of each observation
    array[n_rates] real<lower=0> durations;
    array[n_rates] real<lower=0> speech_rate_age;
    array[n_rates] int<lower=-1> speech_rate_siblings;
    array[n_rates] int<lower=1, upper=n_speech_rate_children> speech_rate_child;

    // ---- Parallel processing ----
    int<lower=1> threads; // used to derive the reduce_sum grainsize
  }
  transformed data {
    // Age of each confusion group, taken from its clips (the last clip
    // listed for a group wins). Only a commented-out prior in the model
    // block reads it. Entries of groups without any clip are never
    // assigned, and the <lower=0> bound is then expected to reject them.
    vector<lower=0>[n_groups] recording_age;
    // Corpus and sibling code of each speech-rate child, and sibling code
    // of each analysed child, scattered from the per-observation arrays
    // (the last observation of a child wins).
    array[n_speech_rate_children] int<lower=1> speech_rate_child_corpus;
    array[n_children] int<lower=-1> child_siblings;
    array[n_speech_rate_children] int<lower=-1> speech_rate_child_siblings;

    for (clip in 1:n_clips) {
      recording_age[group[clip]] = clip_age[clip];
    }

    for (r in 1:n_recs) {
      child_siblings[children[r]] = siblings[r];
    }

    // One pass over the speech-rate observations fills both per-child tables.
    for (s in 1:n_rates) {
      int sr_child = speech_rate_child[s];
      speech_rate_child_corpus[sr_child] = group_corpus[s];
      speech_rate_child_siblings[sr_child] = speech_rate_siblings[s];
    }
  }
  parameters {
    // ---- Latent quantities of the analysed recordings ----
    // Child-level expected rates for classes 2..n_classes:
    // column 1 = OCH, columns 2..3 = FEM, MAL (as used in the model block).
    matrix<lower=0>[n_children, n_classes - 1] mu_child_level;
    // Per-child coefficient of rescaled age in the CHI log-mean.
    vector[n_children] child_dev_age;
    // Latent true counts, [recording, class].
    matrix<lower=0>[n_recs, n_classes] truth_vocs;

    // ---- Nuisance parameters: recording-level detector behaviour ----
    array[n_recs] matrix<lower=0>[n_classes, n_classes] actual_confusion_baseline;
    array[n_recs] vector<lower=0>[n_classes] actual_fp_rate;

    // ---- Detector confusion, fitted on the annotated clips ----
    // Gamma shape and mean of the confusion rates, [true class, detected class],
    // and the per-group rates drawn from them.
    matrix<lower=0>[n_classes, n_classes] alphas;
    matrix<lower=0>[n_classes, n_classes] mus;
    array[n_groups] matrix<lower=0>[n_classes, n_classes] lambda;
    // False positives: gamma shape and mean per detected class, then per group.
    vector<lower=0>[n_classes] alphas_fp;
    vector<lower=0>[n_classes] mus_fp;
    array[n_groups] vector<lower=0>[n_classes] lambda_fp;

    // ---- Speech-rate hierarchy ----
    // The alpha_* below are gamma SHAPES: larger means less dispersion
    // (variance = mean^2 / shape), not variances.
    vector<lower=0>[n_classes] alpha_child_level;                 // recording rates around their child-specific mean
    matrix<lower=0>[n_classes - 1, n_corpora] alpha_corpus_level; // child-level means around their corpus mean
    matrix<lower=0>[n_classes - 1, n_corpora] mu_corpus_level;    // corpus-level mean of child-level rates
    vector<lower=0>[n_classes - 1] alpha_pop_level;               // corpus means around the population means
    vector<lower=0>[n_classes] mu_pop_level;                      // population-level means, one per class
    vector<lower=0>[n_classes - 1] alpha_pop;                     // prior mean of alpha_corpus_level
    matrix<lower=0>[n_classes, n_rates] speech_rate;              // true rates behind the annotated counts, [class, observation]
    matrix<lower=0>[n_speech_rate_children, n_classes - 1] speech_rate_child_level; // child-level expected rates
    real<lower=0> beta_sib_och;                                   // log-multiplier of the OCH mean when siblings > 0
    vector[n_speech_rate_children] child_dev_speech_age;          // per-child age coefficient, speech-rate children

    // ---- Age effect on CHI ----
    real alpha_dev;          // mean of child_dev_age and child_dev_speech_age
    real<lower=0> sigma_dev; // sd of child_dev_age and child_dev_speech_age
    real beta_dev;           // effect of excess FEM + MAL input on the CHI age term
  }
  model {
    // ---------------------------------------------------------------
    // Analysed recordings: observed counts given the latent truth, the
    // recording-level confusion matrix and false-positive rates
    // (see model_lpmf). Sliced over `children` with grainsize 1.
    // ---------------------------------------------------------------
    target += reduce_sum(
      model_lpmf, children, 1,
      n_recs, n_classes, recs_duration,
      vocs, age,
      truth_vocs, actual_confusion_baseline, actual_fp_rate
    );

    // Recording-level detector behaviour: gamma with shape `alphas` and
    // mean `mus` (rate = shape / mean), likewise for false positives.
    for (k in 1:n_recs) {
      for (i in 1:n_classes) {
        // Disabled alternative for i == 1 (needs a delta_chi_age parameter):
        // mean mus[i,:] .* exp(delta_chi_age' * age[k] / 12.0)
        actual_confusion_baseline[k, i] ~ gamma(alphas[i, :], alphas[i, :] ./ mus[i, :]);
      }
      actual_fp_rate[k] ~ gamma(alphas_fp, alphas_fp ./ mus_fp);
    }

    // Latent true counts, expressed as rates (count / 1000 / recs_duration).
    for (k in 1:n_recs) {
      int child = children[k];
      // Age rescaled as age / 12 / 10 (presumably months -> decades; confirm).
      real age_scaled = age[k] / 12.0 / 10.0;
      // Child-level FEM + MAL means minus the population FEM + MAL means.
      real adult_excess = mu_child_level[child, 2] + mu_child_level[child, 3]
                          - mu_pop_level[3] - mu_pop_level[4];
      real chi_mu = mu_pop_level[1]
                    * exp(child_dev_age[child] * age_scaled
                          + beta_dev * adult_excess * age_scaled);
      // OCH mean is multiplied by exp(beta_sib_och) when siblings > 0.
      real och_mu = mu_child_level[child, 1]
                    * exp(child_siblings[child] > 0 ? beta_sib_och : 0);
      (truth_vocs[k, 1] / 1000 / recs_duration)
        ~ gamma(alpha_child_level[1], alpha_child_level[1] / chi_mu);
      (truth_vocs[k, 2] / 1000 / recs_duration)
        ~ gamma(alpha_child_level[2], alpha_child_level[2] / och_mu);
      (truth_vocs[k, 3:] / 1000 / recs_duration)
        ~ gamma(alpha_child_level[3:],
                alpha_child_level[3:] ./ mu_child_level[child, 2:]');
    }

    // Child-level means around their corpus means.
    for (c in 1:n_children) {
      mu_child_level[c] ~ gamma(
        alpha_corpus_level[:, corpus[c]],
        alpha_corpus_level[:, corpus[c]] ./ mu_corpus_level[:, corpus[c]]
      );
    }
    alpha_child_level ~ gamma(2, 1);

    // ---------------------------------------------------------------
    // Speaker-confusion model on the annotated clips. reduce_sum requires
    // grainsize >= 1; the integer division alone yields 0 whenever
    // n_clips < 4 * threads.
    // ---------------------------------------------------------------
    target += reduce_sum(
      confusion_model_lpmf, group, max(1, n_clips %/% (threads * 4)),
      n_classes,
      vtc_total, truth_total, clip_duration, clip_age,
      lambda, lambda_fp
    );
    mus_fp ~ exponential(1);
    alphas_fp ~ gamma(2, 1);
    for (i in 1:n_classes) {
      lambda_fp[:, i] ~ gamma(alphas_fp[i], alphas_fp[i] / mus_fp[i]);
      for (j in 1:n_classes) {
        // Prior mean 1/2 on the diagonal, 1/8 off the diagonal.
        mus[i, j] ~ exponential(i == j ? 2 : 8);
        alphas[i, j] ~ gamma(2, 1);
        // Disabled alternative for i == 1: mean
        // mus[i,j] * exp(delta_chi_age[j] * recording_age[c] / 12.0)
        for (c in 1:n_groups) {
          lambda[c, i, j] ~ gamma(alphas[i, j], alphas[i, j] / mus[i, j]);
        }
      }
    }
    // delta_chi_age ~ normal(0, 0.1);  (disabled with the variants above)

    // ---------------------------------------------------------------
    // Speech-rate hierarchy: population -> corpus -> child -> observation.
    // ---------------------------------------------------------------
    mu_pop_level ~ exponential(4);
    alpha_pop_level ~ gamma(8, 4);
    alpha_pop ~ gamma(10, 10);
    for (i in 1:n_classes - 1) {
      alpha_corpus_level[i, :] ~ gamma(4, 4 / alpha_pop[i]);
      mu_corpus_level[i, :] ~ gamma(alpha_pop_level[i],
                                    alpha_pop_level[i] / mu_pop_level[i + 1]);
    }
    for (g in 1:n_rates) {
      int child = speech_rate_child[g];
      real age_scaled = speech_rate_age[g] / 12.0 / 10.0;
      real adult_excess = speech_rate_child_level[child, 2]
                          + speech_rate_child_level[child, 3]
                          - mu_pop_level[3] - mu_pop_level[4];
      real chi_mu = mu_pop_level[1]
                    * exp(child_dev_speech_age[child] * age_scaled
                          + beta_dev * adult_excess * age_scaled);
      real och_mu = speech_rate_child_level[child, 1]
                    * exp(speech_rate_child_siblings[child] > 0 ? beta_sib_och : 0);
      speech_rate[1, g] ~ gamma(alpha_child_level[1], alpha_child_level[1] / chi_mu);
      speech_rate[2, g] ~ gamma(alpha_child_level[2], alpha_child_level[2] / och_mu);
      speech_rate[3:, g] ~ gamma(alpha_child_level[3:],
                                 alpha_child_level[3:] ./ speech_rate_child_level[child, 2:]');
      // Annotated counts given the true rates (durations scaled by 1000).
      speech_rates[g, :] ~ poisson(speech_rate[:, g] * durations[g] * 1000);
    }
    for (c in 1:n_speech_rate_children) {
      speech_rate_child_level[c, :] ~ gamma(
        alpha_corpus_level[:, speech_rate_child_corpus[c]],
        alpha_corpus_level[:, speech_rate_child_corpus[c]]
          ./ mu_corpus_level[:, speech_rate_child_corpus[c]]
      );
    }

    // Shared age-effect and sibling-effect priors.
    child_dev_age ~ normal(alpha_dev, sigma_dev);
    child_dev_speech_age ~ normal(alpha_dev, sigma_dev);
    beta_sib_och ~ exponential(1);
    alpha_dev ~ normal(0, 1);
    sigma_dev ~ exponential(1);
    beta_dev ~ normal(0, 1);
  }