// dev_siblings_overdispersion.stan
  1. functions {
  // Log-probability contribution of one true speaker class to one algorithm
  // class: the number of algorithm detections `n_det` produced by `n_true`
  // true vocalizations is negative binomial with mean n_true * rate and
  // variance mean * w (alpha = mean / (w - 1), beta = 1 / (w - 1)).
  // When n_true == 0 the caller only enumerates n_det == 0, whose log-probability is 0.
  real confusion_term(int n_det, int n_true, real rate, real w) {
    if (n_true == 0) {
      return 0.0;
    }
    return neg_binomial_lpmf(n_det | n_true * rate / (w - 1), 1 / (w - 1));
  }

  // Partial log-likelihood (reduce_sum slice) of the annotated clips.
  // For each clip k and algorithm class i, the observed count algo[k, i] is
  // split in every possible way among the four true speaker classes
  // (chi + och + fem + mal == algo[k, i], each class forced to 0 when it has
  // no true vocalizations); the probabilities of all splits are summed on the
  // log scale.
  //
  // group         : slice of clip -> group indices (indexed relative to start)
  // start, end    : global clip range of this slice
  // algo, truth   : per-clip algorithm / annotated counts (indexed by global k)
  // age, clip_duration : currently unused (kept for the reduce_sum signature)
  // lambda        : per-group confusion rates, lambda[g, true_class, algo_class]
  // omega         : per-cell overdispersion (variance / mean), assumed > 1
  //
  // NOTE: the loops read true classes 1..4 explicitly, so this assumes
  // n_classes == 4 on the truth side. If no split is possible (e.g.
  // detections with zero true vocalizations, since false positives are
  // disabled) the (clip, class) pair contributes nothing, as before.
  real confusion_model_lpmf(array[] int group,
                            int start, int end,
                            int n_classes,
                            array[,] int algo,
                            array[,] int truth,
                            array[] real age,
                            array[] real clip_duration,
                            array[] matrix lambda,
                            matrix omega
                            // disabled false-positive input: array[] vector lambda_fp
                            ) {
    real ll = 0;
    for (k in start:end) {
      int g = group[k - start + 1];
      for (i in 1:n_classes) {
        int total = algo[k, i];
        // Running log-sum-exp over the splits: no fixed-size buffer, so
        // large counts cannot overflow it.
        real lse = negative_infinity();
        int n_terms = 0;
        // Each loop only goes up to what is left of `total`; larger values
        // could never sum to exactly `total`.
        for (chi in 0:(truth[k, 1] > 0 ? total : 0)) {
          real lp_chi = confusion_term(chi, truth[k, 1], lambda[g, 1, i], omega[1, i]);
          for (och in 0:(truth[k, 2] > 0 ? total - chi : 0)) {
            real lp_och = lp_chi + confusion_term(och, truth[k, 2], lambda[g, 2, i], omega[2, i]);
            for (fem in 0:(truth[k, 3] > 0 ? total - chi - och : 0)) {
              real lp_fem = lp_och + confusion_term(fem, truth[k, 3], lambda[g, 3, i], omega[3, i]);
              // the last class absorbs the remainder, which must be 0 if it has no true vocs
              int mal = total - chi - och - fem;
              if (truth[k, 4] > 0 || mal == 0) {
                lse = log_sum_exp(lse, lp_fem + confusion_term(mal, truth[k, 4], lambda[g, 4, i], omega[4, i]));
                n_terms += 1;
              }
            }
          }
        }
        if (n_terms > 0) {
          ll += lse;
        }
      }
    }
    return ll;
  }
  // Partial log-likelihood (reduce_sum slice) of the recordings' algorithm
  // counts given their latent true counts.
  //
  // For each recording k:
  //   * every row i of the recording's confusion matrix gets a gamma prior
  //     with shape alphas[i, :] and mean mus[i, :];
  //   * the algorithm count of class i is approximated as normal with
  //       mean     = sum_j truth_vocs[k, j] * actual_confusion[k, j, i]
  //       variance = sum_j truth_vocs[k, j] * actual_confusion[k, j, i] * omega[j, i]
  //
  // children : reduce_sum slice (not read here)
  // duration : currently unused (was used by the disabled false-positive term)
  // age      : currently unused
  // Vectors are sized by n_classes, so any number of classes is supported.
  real inverse_model_lpmf(array[] int children,
                          int start, int end,
                          int n_recs,
                          int n_classes,
                          real duration,
                          array [,] int vocs,
                          array [] real age,
                          matrix truth_vocs,
                          array [] matrix actual_confusion,
                          matrix mus,
                          matrix alphas,
                          matrix omega
                          // disabled false-positive inputs: actual_fp_rate, mus_fp, alphas_fp
                          ) {
    real ll = 0;
    vector[n_classes] pred_mean;
    vector[n_classes] pred_var;
    for (k in start:end) {
      for (i in 1:n_classes) {
        // prior on row i (true class i -> each algorithm class)
        ll += gamma_lpdf(actual_confusion[k, i] | alphas[i, :], alphas[i, :] ./ mus[i, :]);
        // column i: contributions of every true class to algorithm class i
        pred_mean[i] = dot_product(truth_vocs[k, :], actual_confusion[k, :, i]);
        pred_var[i] = dot_product(truth_vocs[k, :], actual_confusion[k, :, i] .* omega[:, i]);
      }
      ll += normal_lpdf(vocs[k, :] | pred_mean, sqrt(pred_var));
    }
    return ll;
  }
  // Partial log-density (reduce_sum slice) of the latent true counts of each
  // recording, expressed as rates truth_vocs / 1000 / recs_duration.
  //
  // Class 1 rate  ~ gamma(shape alpha_child_level[1], mean chi_mu), where
  //   chi_mu = mu_pop_level[1] * exp((child_dev_age[child]
  //            + beta_dev * adult_excess) * age / 120)
  //   and adult_excess = mu_child_level[child, 2] + mu_child_level[child, 3]
  //                      - mu_pop_level[3] - mu_pop_level[4].
  // Class 2.. rates ~ gamma(shape alpha_child_level[2:], mean mu_child_level[child]).
  //
  // children : slice of recording -> child indices (indexed relative to start)
  // n_recs, n_classes : not read here (kept for the reduce_sum signature)
  real recs_priors_lpmf(array[] int children,
                        int start, int end,
                        int n_recs,
                        int n_classes,
                        real recs_duration,
                        array [] real age,
                        matrix truth_vocs,
                        vector mu_pop_level,
                        matrix mu_child_level,
                        vector alpha_child_level,
                        vector child_dev_age,
                        real beta_dev
                        ) {
    real total = 0;
    for (rec in start:end) {
      int child = children[rec - start + 1];
      real adult_excess = mu_child_level[child, 2] + mu_child_level[child, 3] - mu_pop_level[3] - mu_pop_level[4];
      real chi_mu = mu_pop_level[1] * exp(
        child_dev_age[child] * age[rec] / 12.0 / 10.0
        + beta_dev * adult_excess * age[rec] / 12.0 / 10.0
      );
      row_vector[cols(truth_vocs)] rates = truth_vocs[rec] / 1000 / recs_duration;
      total += gamma_lpdf(rates[1] | alpha_child_level[1], alpha_child_level[1] / chi_mu);
      total += gamma_lpdf(rates[2:] | alpha_child_level[2:], alpha_child_level[2:] ./ mu_child_level[child]');
    }
    return total;
  }
  111. }
// TODO: use the observed speech rates to set the priors on truth_vocs.
  data {
    // number of speaker classes; parts of the model index classes 1..4
    // explicitly (presumably CHI, OCH, FEM, MAL)
    int<lower=1> n_classes;

    // --- analysis data ---
    int<lower=1> n_recs;      // number of recordings
    int<lower=1> n_children;  // number of analysed children
    // declared before `corpus` / `group_corpus` so it can bound them
    int<lower=1> n_corpora;
    // child of each recording (indexes child-level arrays)
    array[n_recs] int<lower=1, upper=n_children> children;
    // NOTE(review): presumably age in months (the model divides by 12 * 10)
    array[n_recs] real<lower=1> age;
    // sibling status per recording: -1 = unknown, 0 = none, > 0 = has siblings
    array[n_recs] int<lower=-1> siblings;
    array[n_recs, n_classes] int<lower=0> vocs;               // algorithm counts
    array[n_children] int<lower=1, upper=n_corpora> corpus;   // corpus of each child
    real<lower=0> recs_duration;

    // --- speaker confusion data ---
    int<lower=1> n_clips;   // number of annotated clips
    int<lower=1> n_groups;  // number of confusion groups
    array [n_clips] int<lower=1, upper=n_groups> group;  // group of each clip
    array [n_clips] int conf_corpus;  // NOTE(review): not used by the model
    array [n_clips,n_classes] int<lower=0> algo_total;   // algo vocs attributed to specific speakers
    array [n_clips,n_classes] int<lower=0> truth_total;  // annotated (true) vocs per speaker
    array [n_clips] real<lower=0> clip_duration;
    array [n_clips] real<lower=0> clip_age;
    int<lower=0> n_validation;  // NOTE(review): not used by the model

    // --- actual speech rates ---
    int<lower=1> n_rates;
    int<lower=1> n_speech_rate_children;
    array [n_rates,n_classes] int<lower=0> speech_rates;
    array [n_rates] int<lower=1, upper=n_corpora> group_corpus;  // corpus of each rate observation
    array [n_rates] real<lower=0> durations;
    array [n_rates] real<lower=0> speech_rate_age;
    array [n_rates] int<lower=-1> speech_rate_siblings;
    array [n_rates] int<lower=1,upper=n_speech_rate_children> speech_rate_child;

    // --- parallel processing ---
    int<lower=1> threads;
  }
  transformed data {
    // Lookup tables copied from the per-record data (when an index repeats,
    // the last record processed wins).
    vector<lower=0>[n_groups] recording_age;  // NOTE(review): not used by the model block
    array[n_speech_rate_children] int<lower=1> speech_rate_child_corpus;
    array[n_children] int<lower=-1> child_siblings;  // -1 = unknown, 0 = none, > 0 = siblings
    array[n_speech_rate_children] int<lower=-1> speech_rate_child_siblings;
    // Counts of analysed children with known sibling status (unknown = -1 is skipped).
    int has_siblings = 0;
    int no_siblings = 0;

    for (clip in 1:n_clips) {
      recording_age[group[clip]] = clip_age[clip];
    }
    for (r in 1:n_rates) {
      speech_rate_child_corpus[speech_rate_child[r]] = group_corpus[r];
      speech_rate_child_siblings[speech_rate_child[r]] = speech_rate_siblings[r];
    }
    for (rec in 1:n_recs) {
      child_siblings[children[rec]] = siblings[rec];
    }
    for (c in 1:n_children) {
      if (child_siblings[c] > 0) {
        has_siblings += 1;
      } else if (child_siblings[c] == 0) {
        no_siblings += 1;
      }
    }
  }
  parameters {
    // ===== child / recording level =====
    // child-level mean rates of classes 2..n_classes (gamma means in recs_priors_lpmf)
    matrix<lower=0>[n_children, n_classes-1] mu_child_level;
    // per-child age slope of the class-1 log rate
    vector[n_children] child_dev_age;
    // latent true counts per recording and class
    matrix<lower=0>[n_recs, n_classes] truth_vocs;

    // ===== nuisance: per-recording confusion matrices =====
    array[n_recs] matrix<lower=0>[n_classes, n_classes] actual_confusion_baseline;
    // (disabled) array[n_recs] vector<lower=0>[n_classes] actual_fp_rate;

    // ===== confusion model =====
    matrix<lower=0>[n_classes, n_classes] alphas;   // gamma shapes of confusion rates
    matrix<lower=0>[n_classes, n_classes] mus;      // gamma means of confusion rates
    matrix<lower=0>[n_classes, n_classes] conf_sd;  // overdispersion: omega = exp(conf_sd / 10)
    array[n_groups] matrix<lower=0>[n_classes, n_classes] lambda;  // per-group confusion rates
    // (disabled false positives) alphas_fp, mus_fp, lambda_fp

    // ===== speech rates =====
    vector<lower=0>[n_classes] alpha_child_level;  // gamma shape: recordings around their child's mean
    // gamma shapes of children around corpus means; [1] = no siblings, [2] = siblings
    array[2] matrix<lower=0>[n_classes-1, n_corpora] alpha_corpus_level;
    matrix<lower=0>[n_classes-1, n_corpora] mu_corpus_level;  // corpus-level means
    vector<lower=0>[n_classes-1] alpha_pop_level;  // gamma shape: corpora around population means
    vector<lower=0>[n_classes] mu_pop_level;       // population-level means
    vector<lower=0>[n_classes-1] alpha_pop;        // prior scale for alpha_corpus_level
    matrix<lower=0>[n_classes, n_rates] speech_rate;  // true speech rates in annotated clips
    matrix<lower=0>[n_speech_rate_children, n_classes-1] speech_rate_child_level;  // child-level rates

    // ===== siblings =====
    real beta_sib_och;  // log multiplier on the OCH mean for children with siblings
    real beta_sib_adu;  // log multiplier on the adult means for children with siblings
    real<lower=0,upper=1> p_sib;  // probability of having siblings
    vector[n_speech_rate_children] child_dev_speech_age;

    // ===== age effects =====
    real alpha_dev;           // mean of the per-child age slopes
    real<lower=0> sigma_dev;  // sd of the per-child age slopes
    real beta_dev;            // effect of excess adult input on the class-1 rate
  }
  model {
    // Per-cell overdispersion of the confusion counts (per-term variance =
    // omega * mean in confusion_model_lpmf and inverse_model_lpmf).
    matrix[n_classes, n_classes] omega = exp(conf_sd / 10);
    // reduce_sum requires grainsize >= 1, but n_clips %/% (threads * 4) is 0
    // whenever n_clips < 4 * threads.
    int conf_grainsize = max(1, n_clips %/% (threads * 4));
    vector[2] ll;  // mixture components for children with unknown sibling status
    int distrib;   // first index of alpha_corpus_level: 1 = no siblings, 2 = siblings

    // inverse confusion model
    target += reduce_sum(
      inverse_model_lpmf, children, 1,
      n_recs, n_classes, recs_duration,
      vocs, age,
      truth_vocs, actual_confusion_baseline, mus, alphas, omega //, mus_fp, alphas_fp
    );
    // priors on actual speech
    target += reduce_sum(
      recs_priors_lpmf, children, 1,
      n_recs, n_classes, recs_duration, age,
      truth_vocs,
      mu_pop_level, mu_child_level, alpha_child_level,
      child_dev_age, beta_dev
    );

    // child-level means around the corpus means, with sibling effects
    for (c in 1:n_children) {
      if (child_siblings[c] >= 0) {
        // sibling status known
        distrib = child_siblings[c] > 0 ? 2 : 1;
        mu_child_level[c, 1] ~ gamma(
          alpha_corpus_level[distrib, 1, corpus[c]],
          alpha_corpus_level[distrib, 1, corpus[c]]
            / (mu_corpus_level[1, corpus[c]] * exp(child_siblings[c] > 0 ? beta_sib_och : 0))
        );
        // The parentheses matter: `./` and `*` share precedence and associate
        // left to right, so `alpha ./ mu * exp(b)` would scale the rate (and
        // divide the mean) instead of multiplying the mean like every other
        // sibling term in this block.
        mu_child_level[c, 2:] ~ gamma(
          alpha_corpus_level[distrib, 2:, corpus[c]],
          alpha_corpus_level[distrib, 2:, corpus[c]]
            ./ (mu_corpus_level[2:, corpus[c]] * exp(child_siblings[c] > 0 ? beta_sib_adu : 0))
        );
      }
      else {
        // sibling status unknown: marginalise over it
        // component 1 -- has siblings (weight p_sib): distrib 2, sibling multipliers
        ll[1] = log(p_sib) + gamma_lpdf(
          mu_child_level[c, 1] | alpha_corpus_level[2, 1, corpus[c]], alpha_corpus_level[2, 1, corpus[c]] / (mu_corpus_level[1, corpus[c]] * exp(beta_sib_och))
        );
        ll[1] += gamma_lpdf(
          mu_child_level[c, 2] | alpha_corpus_level[2, 2, corpus[c]], alpha_corpus_level[2, 2, corpus[c]] / (mu_corpus_level[2, corpus[c]] * exp(beta_sib_adu))
        );
        ll[1] += gamma_lpdf(
          mu_child_level[c, 3] | alpha_corpus_level[2, 3, corpus[c]], alpha_corpus_level[2, 3, corpus[c]] / (mu_corpus_level[3, corpus[c]] * exp(beta_sib_adu))
        );
        // component 2 -- no siblings (weight 1 - p_sib): distrib 1, no multipliers
        ll[2] = log(1 - p_sib) + gamma_lpdf(
          mu_child_level[c, 1] | alpha_corpus_level[1, 1, corpus[c]], alpha_corpus_level[1, 1, corpus[c]] / (mu_corpus_level[1, corpus[c]])
        );
        ll[2] += gamma_lpdf(
          mu_child_level[c, 2] | alpha_corpus_level[1, 2, corpus[c]], alpha_corpus_level[1, 2, corpus[c]] / (mu_corpus_level[2, corpus[c]])
        );
        ll[2] += gamma_lpdf(
          mu_child_level[c, 3] | alpha_corpus_level[1, 3, corpus[c]], alpha_corpus_level[1, 3, corpus[c]] / (mu_corpus_level[3, corpus[c]])
        );
        target += log_sum_exp(ll);
      }
    }
    alpha_child_level ~ gamma(2, 1);

    // confusion model on the annotated clips
    target += reduce_sum(
      confusion_model_lpmf, group, conf_grainsize,
      n_classes,
      algo_total, truth_total, clip_duration, clip_age,
      lambda, omega //, lambda_fp
    );
    //mus_fp ~ exponential(1);
    //alphas_fp ~ gamma(2, 1);
    for (i in 1:n_classes) {
      //lambda_fp[:, i] ~ gamma(alphas_fp[i], alphas_fp[i] / mus_fp[i]);
      conf_sd[i, :] ~ normal(0, 1);
      for (j in 1:n_classes) {
        mus[i, j] ~ exponential(i == j ? 2 : 8);
        alphas[i, j] ~ gamma(2, 1);
        for (c in 1:n_groups) {
          lambda[c, i, j] ~ gamma(alphas[i, j], alphas[i, j] / mus[i, j]);
        }
      }
    }

    // speech rates
    mu_pop_level ~ exponential(4); // 250 vocs/hour
    alpha_pop_level ~ gamma(8, 4); // sd = 0.35 x \mu
    alpha_pop ~ gamma(10, 10);
    for (i in 1:n_classes-1) {
      alpha_corpus_level[1, i, :] ~ gamma(4, 4 / alpha_pop[i]);
      alpha_corpus_level[2, i, :] ~ gamma(4, 4 / alpha_pop[i]);
      mu_corpus_level[i, :] ~ gamma(alpha_pop_level[i], alpha_pop_level[i] / mu_pop_level[i + 1]);
    }
    for (g in 1:n_rates) {
      real chi_mu = mu_pop_level[1] * exp(
        child_dev_speech_age[speech_rate_child[g]] * speech_rate_age[g] / 12.0 / 10.0 + beta_dev * (speech_rate_child_level[speech_rate_child[g], 2] + speech_rate_child_level[speech_rate_child[g], 3] - mu_pop_level[3] - mu_pop_level[4]) * speech_rate_age[g] / 12.0 / 10.0
      );
      speech_rate[1, g] ~ gamma(
        alpha_child_level[1],
        alpha_child_level[1] / chi_mu
      );
      speech_rate[2:, g] ~ gamma(
        alpha_child_level[2:],
        (alpha_child_level[2:] ./ (speech_rate_child_level[speech_rate_child[g], :]')) //'
      );
      speech_rates[g, :] ~ poisson(speech_rate[:, g] * durations[g] * 1000);
    }
    for (c in 1:n_speech_rate_children) {
      // c indexes speech-rate children, so use their own sibling table
      // (child_siblings is indexed by analysed child).
      // NOTE(review): unknown status (-1) is treated as "no siblings" here,
      // whereas analysed children with unknown status are marginalised.
      distrib = speech_rate_child_siblings[c] > 0 ? 2 : 1;
      speech_rate_child_level[c, 1] ~ gamma(
        alpha_corpus_level[distrib, 1, speech_rate_child_corpus[c]],
        (alpha_corpus_level[distrib, 1, speech_rate_child_corpus[c]] / (mu_corpus_level[1, speech_rate_child_corpus[c]] * exp(
          speech_rate_child_siblings[c] > 0 ? beta_sib_och : 0
        )))
      );
      speech_rate_child_level[c, 2:] ~ gamma(
        alpha_corpus_level[distrib, 2:, speech_rate_child_corpus[c]],
        (alpha_corpus_level[distrib, 2:, speech_rate_child_corpus[c]] ./ (mu_corpus_level[2:, speech_rate_child_corpus[c]] * exp(
          speech_rate_child_siblings[c] > 0 ? beta_sib_adu : 0
        )))
      );
    }
    child_dev_age ~ normal(alpha_dev, sigma_dev);
    child_dev_speech_age ~ normal(alpha_dev, sigma_dev);
    has_siblings ~ binomial(has_siblings + no_siblings, p_sib);
    p_sib ~ uniform(0, 1);
    beta_sib_och ~ normal(0, 1);
    beta_sib_adu ~ normal(0, 1);
    alpha_dev ~ normal(0, 1);
    sigma_dev ~ exponential(1);
    beta_dev ~ normal(0, 1);
  }