main.py

#!/usr/bin/env python3

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_annotation

import argparse
import datalad.api
from os.path import join as opj
from os.path import basename, exists
import multiprocessing as mp

import numpy as np
import pandas as pd
import pickle

from pyannote.core import Annotation, Segment, Timeline

import stan

parser = argparse.ArgumentParser(description = 'main model described throughout the notes.')
parser.add_argument('--group', default = 'child', choices = ['corpus', 'child'])
parser.add_argument('--chains', default = 4, type = int)
parser.add_argument('--samples', default = 2000, type = int)
parser.add_argument('--validation', default = 0, type = float)
parser.add_argument('--output', default = 'model3')
args = parser.parse_args()
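# Example invocation (illustrative; flags and defaults as defined above):
#   python main.py --group child --chains 4 --samples 2000 --output model3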
def extrude(self, removed, mode: str = 'intersection'):
    # Remove from a pyannote Timeline everything that overlaps `removed`
    # (the complement of crop: crop on the gaps of `removed`).
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"

    return self.crop(truncating_support, mode=mode)
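# For one (corpus, annotator) pair, compute_counts builds a per-child table of
# counts: for every ordered pair of speaker classes (CHI, OCH, FEM, MAL), how many
# VTC detections of one class overlap human-annotated vocalizations of the other,
# plus the number of true vocalizations per class and the annotated duration.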
def compute_counts(parameters):
    corpus = parameters['corpus']
    annotator = parameters['annotator']
    speakers = ['CHI', 'OCH', 'FEM', 'MAL']

    project = ChildProject(parameters['path'])
    am = AnnotationManager(project)
    am.read()

    intersection = AnnotationManager.intersection(
        am.annotations, ['vtc', annotator]
    )
    intersection['path'] = intersection.apply(
        lambda r: opj(project.path, 'annotations', r['set'], 'converted', r['annotation_filename']),
        axis = 1
    )
    datalad.api.get(list(intersection['path'].unique()))

    intersection = intersection.merge(project.recordings[['recording_filename', 'child_id']], how = 'left')
    intersection['child'] = corpus + '_' + intersection['child_id'].astype(str)
    intersection['duration'] = intersection['range_offset'] - intersection['range_onset']
    print(corpus, annotator, (intersection['duration']/1000/2).sum()/3600)

    data = []
    for child, ann in intersection.groupby('child'):
        #print(corpus, child)
        segments = am.get_collapsed_segments(ann)
        if 'speaker_type' not in segments.columns:
            continue

        segments = segments[segments['speaker_type'].isin(speakers)]

        vtc = {
            speaker: segments_to_annotation(segments[(segments['set'] == 'vtc') & (segments['speaker_type'] == speaker)], 'speaker_type').get_timeline()
            for speaker in speakers
        }
        truth = {
            speaker: segments_to_annotation(segments[(segments['set'] == annotator) & (segments['speaker_type'] == speaker)], 'speaker_type').get_timeline()
            for speaker in speakers
        }

        for speaker_A in speakers:
            vtc[f'{speaker_A}_vocs_explained'] = vtc[speaker_A].crop(truth[speaker_A], mode = 'loose')
            vtc[f'{speaker_A}_vocs_fp'] = extrude(vtc[speaker_A], vtc[f'{speaker_A}_vocs_explained'])
            vtc[f'{speaker_A}_vocs_fn'] = extrude(truth[speaker_A], truth[speaker_A].crop(vtc[speaker_A], mode = 'loose'))

            for speaker_B in speakers:
                vtc[f'{speaker_A}_vocs_fp_{speaker_B}'] = vtc[f'{speaker_A}_vocs_fp'].crop(truth[speaker_B], mode = 'loose')

                for speaker_C in speakers:
                    if speaker_C != speaker_B and speaker_C != speaker_A:
                        vtc[f'{speaker_A}_vocs_fp_{speaker_B}'] = extrude(
                            vtc[f'{speaker_A}_vocs_fp_{speaker_B}'],
                            vtc[f'{speaker_A}_vocs_fp_{speaker_B}'].crop(truth[speaker_C], mode = 'loose')
                        )

        d = {}
        for i, speaker_A in enumerate(speakers):
            for j, speaker_B in enumerate(speakers):
                if i != j:
                    z = len(vtc[f'{speaker_A}_vocs_fp_{speaker_B}'])
                else:
                    z = min(len(vtc[f'{speaker_A}_vocs_explained']), len(truth[speaker_A]))
                d[f'vtc_{i}_{j}'] = z

            d[f'truth_{i}'] = len(truth[speaker_A])

        d['child'] = child
        d['duration'] = ann['duration'].sum()/2/1000
        data.append(d)

    return pd.DataFrame(data).assign(
        corpus = corpus,
    )
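# Stan model: a hierarchical beta-binomial confusion model. For each group
# (child or corpus), the number of VTC detections of class i among the truth[k,j]
# vocalizations of class j in clip k is binomial with a group-level confusion
# probability; those probabilities follow Beta distributions whose mean and
# concentration (mus, etas) are shared across groups. The generated quantities
# block draws posterior predictions, per-clip log-likelihoods, and simulated
# counts (sim_truth, sim_vtc) from the fitted confusion rates.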
stan_code = """
data {
  int<lower=1> n_clips;   // number of clips
  int<lower=1> n_groups;  // number of groups
  int<lower=1> n_classes; // number of classes
  int group[n_clips];
  int vtc[n_clips,n_classes,n_classes];
  int truth[n_clips,n_classes];

  int<lower=1> n_validation;
  int<lower=1> n_sim;

  real<lower=0> rates_alphas[n_classes];
  real<lower=0> rates_betas[n_classes];
}
parameters {
  matrix<lower=0,upper=1>[n_classes,n_classes] mus;
  matrix<lower=1>[n_classes,n_classes] etas;
  matrix<lower=0,upper=1>[n_classes,n_classes] group_confusion[n_groups];
}
transformed parameters {
  matrix<lower=0>[n_classes,n_classes] alphas;
  matrix<lower=0>[n_classes,n_classes] betas;

  // elementwise mean/concentration parameterization of the Beta priors
  alphas = mus .* etas;
  betas = (1-mus) .* etas;
}
model {
  for (k in n_validation:n_clips) {
    for (i in 1:n_classes) {
      for (j in 1:n_classes) {
        vtc[k,i,j] ~ binomial(truth[k,j], group_confusion[group[k],j,i]);
      }
    }
  }

  for (i in 1:n_classes) {
    for (j in 1:n_classes) {
      mus[i,j] ~ beta(1,1);
      etas[i,j] ~ pareto(1,1.5);
    }
  }

  for (c in 1:n_groups) {
    for (i in 1:n_classes) {
      for (j in 1:n_classes) {
        group_confusion[c,i,j] ~ beta(alphas[i,j], betas[i,j]);
      }
    }
  }
}
generated quantities {
  int pred[n_clips,n_classes,n_classes];
  matrix[n_classes,n_classes] probs[n_groups];
  matrix[n_classes,n_classes] log_lik[n_clips];

  int sim_truth[n_sim,n_classes];
  int sim_vtc[n_sim,n_classes];
  vector[n_classes] lambdas;
  real chi_adu_coef;

  if (uniform_rng(0,1) > 0.99) {
    chi_adu_coef = uniform_rng(0,1);
  }
  else {
    chi_adu_coef = 0;
  }

  for (c in 1:n_groups) {
    for (i in 1:n_classes) {
      for (j in 1:n_classes) {
        probs[c,i,j] = beta_rng(alphas[i,j], betas[i,j]);
      }
    }
  }

  for (k in 1:n_clips) {
    for (i in 1:n_classes) {
      for (j in 1:n_classes) {
        if (k >= n_validation) {
          pred[k,i,j] = binomial_rng(truth[k,j], group_confusion[group[k],i,j]);
          log_lik[k,i,j] = binomial_lpmf(vtc[k,i,j] | truth[k,j], group_confusion[group[k],j,i]);
        }
        else {
          pred[k,i,j] = binomial_rng(truth[k,j], probs[group[k],j,i]);
          log_lik[k,i,j] = beta_lpdf(probs[group[k],j,i] | alphas[j,i], betas[j,i]);
          log_lik[k,i,j] += binomial_lpmf(vtc[k,i,j] | truth[k,j], probs[group[k],j,i]);
        }
      }
    }
  }

  real lambda;
  for (k in 1:n_sim) {
    for (i in 2:n_classes) {
      lambda = gamma_rng(rates_alphas[i], rates_betas[i]);
      sim_truth[k,i] = poisson_rng(lambda);
    }
    lambda = gamma_rng(rates_alphas[1], rates_betas[1]);
    sim_truth[k,1] = poisson_rng(lambda + chi_adu_coef*(sim_truth[k,3]+sim_truth[k,4]));
  }

  for (k in 1:n_sim) {
    for (i in 1:n_classes) {
      sim_vtc[k,i] = 0;
      for (j in 1:n_classes) {
        real p = beta_rng(alphas[j,i], betas[j,i]);
        sim_vtc[k,i] += binomial_rng(sim_truth[k,j], p);
      }
    }
  }
}
"""
if __name__ == "__main__":
    annotators = pd.read_csv('input/annotators.csv')
    annotators['path'] = annotators['corpus'].apply(lambda c: opj('input', c))

    with mp.Pool(processes = 8) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient = 'records')))

    data = data.sample(frac = 1)
    duration = data['duration'].sum()

    vtc = np.moveaxis([[data[f'vtc_{j}_{i}'].values for i in range(4)] for j in range(4)], -1, 0)
    truth = np.transpose([data[f'truth_{i}'].values for i in range(4)])

    print(vtc.shape)

    rates = pd.read_csv('output/speech_dist.csv')

    data = {
        'n_clips': truth.shape[0],
        'n_classes': truth.shape[1],
        'n_groups': data[args.group].nunique(),
        'n_validation': max(1, int(truth.shape[0]*args.validation)),
        'n_sim': 40,
        'group': 1+data[args.group].astype('category').cat.codes.values,
        'truth': truth.astype(int),
        'vtc': vtc.astype(int),
        'rates_alphas': rates['alpha'].values,
        'rates_betas': rates['beta'].values
    }

    print(f"clips: {data['n_clips']}")
    print(f"groups: {data['n_groups']}")
    print("true vocs: {}".format(np.sum(data['truth'])))
    print("vtc vocs: {}".format(np.sum(data['vtc'])))
    print("duration: {}".format(duration))

    with open(f'output/samples/data_{args.output}.pickle', 'wb') as fp:
        pickle.dump(data, fp, pickle.HIGHEST_PROTOCOL)

    posterior = stan.build(stan_code, data = data)
    fit = posterior.sample(num_chains = args.chains, num_samples = args.samples)
    df = fit.to_frame()
    df.to_parquet(f'output/samples/fit_{args.output}.parquet')
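    # Example (illustrative, not part of the pipeline): the saved draws can be
    # reloaded later for inspection, e.g.
    #   df = pd.read_parquet(f'output/samples/fit_{args.output}.parquet')
    #   print(df.filter(regex = '^mus').describe())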