#!/usr/bin/env python3
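"""
Confusion model between the VTC (voice type classifier) and human annotators.

For each corpus listed in input/annotators.csv, compute_counts() compares the
VTC annotations with a reference annotator over the time ranges covered by
both, and returns per-child confusion counts between the four speaker classes
(CHI, OCH, FEM, MAL). These counts are then fed to a hierarchical Bayesian
model (Stan) that estimates one confusion matrix per group (per child or per
corpus, depending on --group). Posterior draws are saved to fit.parquet.

Usage:
    python model4.py [--group {child,corpus}] [--chains N] [--samples N]
"""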
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_annotation

import argparse

import datalad.api

from os.path import join as opj
from os.path import basename, exists

import multiprocessing as mp

import numpy as np
import pandas as pd
from pyannote.core import Annotation, Segment, Timeline

import stan

parser = argparse.ArgumentParser(description = 'model3')
parser.add_argument('--group', default = 'child', choices = ['corpus', 'child'])
parser.add_argument('--chains', default = 4, type = int)
parser.add_argument('--samples', default = 2000, type = int)
args = parser.parse_args()
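# Helper used below to subtract one timeline from another: keep only the parts
# of `self` that do not overlap `removed`. The first argument is named `self`
# although the helper is used as a free function, not as a Timeline method.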
def extrude(self, removed, mode: str = 'intersection'):
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"

    return self.crop(truncating_support, mode=mode)
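# For one (corpus, annotator) pair: restrict both annotation sets to the time
# ranges covered by both the VTC and the human annotator (split into 15-second
# chunks), fetch the converted annotations with datalad, and pool the segments
# per child. For each ordered pair of speaker classes (i, j):
#   - vtc_i_j with i == j counts the annotator's vocalizations of class i that
#     the VTC also detects as class i;
#   - vtc_i_j with i != j counts VTC detections of class i not explained by a
#     true vocalization of class i but overlapping one of class j;
#   - truth_i is the total number of vocalizations of class i according to the
#     human annotator.
# Returns a dataframe with one row per child.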
def compute_counts(parameters):
    corpus = parameters['corpus']
    annotator = parameters['annotator']
    speakers = ['CHI', 'OCH', 'FEM', 'MAL']

    project = ChildProject(parameters['path'])
    am = AnnotationManager(project)
    am.read()

    intersection = AnnotationManager.intersection(
        am.annotations, ['vtc', annotator]
    )
    intersection['onset'] = intersection.apply(lambda r: np.arange(r['range_onset'], r['range_offset'], 15000), axis = 1)
    intersection = intersection.explode('onset')
    intersection['range_onset'] = intersection['onset']
    intersection['range_offset'] = (intersection['range_onset']+15000).clip(upper = intersection['range_offset'])
    intersection['path'] = intersection.apply(
        lambda r: opj(project.path, 'annotations', r['set'], 'converted', r['annotation_filename']),
        axis = 1
    )
    datalad.api.get(list(intersection['path'].unique()))

    intersection = intersection.merge(project.recordings[['recording_filename', 'child_id']], how = 'left')
    intersection['child'] = corpus + '_' + intersection['child_id'].astype(str)

    data = []
    for child, ann in intersection.groupby('child'):
        print(corpus, child)

        segments = am.get_collapsed_segments(ann)
        if 'speaker_type' not in segments.columns:
            continue

        segments = segments[segments['speaker_type'].isin(speakers)]

        vtc = {
            speaker: segments_to_annotation(segments[(segments['set'] == 'vtc') & (segments['speaker_type'] == speaker)], 'speaker_type').get_timeline()
            for speaker in speakers
        }

        truth = {
            speaker: segments_to_annotation(segments[(segments['set'] == annotator) & (segments['speaker_type'] == speaker)], 'speaker_type').get_timeline()
            for speaker in speakers
        }

        for speaker_A in speakers:
            vtc[f'{speaker_A}_vocs_explained'] = vtc[speaker_A].crop(truth[speaker_A], mode = 'loose')
            vtc[f'{speaker_A}_vocs_fp'] = extrude(vtc[speaker_A], vtc[f'{speaker_A}_vocs_explained'])
            vtc[f'{speaker_A}_vocs_fn'] = extrude(truth[speaker_A], truth[speaker_A].crop(vtc[speaker_A], mode = 'loose'))

            for speaker_B in speakers:
                vtc[f'{speaker_A}_vocs_fp_{speaker_B}'] = vtc[f'{speaker_A}_vocs_fp'].crop(truth[speaker_B], mode = 'loose')

        d = {}
        for i, speaker_A in enumerate(speakers):
            for j, speaker_B in enumerate(speakers):
                if i != j:
                    z = len(vtc[f'{speaker_A}_vocs_fp_{speaker_B}'])
                else:
                    z = len(truth[speaker_A]) - len(vtc[f'{speaker_A}_vocs_fn'])
                d[f'vtc_{i}_{j}'] = z

            d[f'truth_{i}'] = len(truth[speaker_A])

        d['child'] = child
        data.append(d)

    return pd.DataFrame(data).assign(
        corpus = corpus
    )
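# Hierarchical beta-binomial confusion model. Each group (child or corpus,
# depending on --group) has its own confusion matrix `group_confusion`, whose
# entry (j, i) is the probability that a vocalization of true class j triggers
# a VTC detection of class i. Group-level probabilities are drawn from Beta
# distributions whose parameters combine population-level (mus, etas) and
# group-level (group_mus, group_etas) means and concentrations, so that the
# prior mean of group_confusion[c,i,j] is
#   (mus[i,j]*etas[i,j] + group_mus[c,i,j]*group_etas[c,i,j]) / (etas[i,j] + group_etas[c,i,j]).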
stan_code = """
data {
    int<lower=1> n_clips;   // number of clips
    int<lower=1> n_groups;  // number of groups
    int<lower=1> n_classes; // number of classes
    int group[n_clips];
    int vtc[n_clips,n_classes,n_classes];
    int truth[n_clips,n_classes];
}
parameters {
    matrix<lower=0,upper=1>[n_classes,n_classes] mus;
    matrix<lower=1>[n_classes,n_classes] etas;
    matrix<lower=0,upper=1>[n_classes,n_classes] group_mus[n_groups];
    matrix<lower=1>[n_classes,n_classes] group_etas[n_groups];
    matrix<lower=0,upper=1>[n_classes,n_classes] group_confusion[n_groups];
}
transformed parameters {
    matrix<lower=0>[n_classes,n_classes] alphas[n_groups];
    matrix<lower=0>[n_classes,n_classes] betas[n_groups];

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                alphas[c,i,j] = mus[i,j] * etas[i,j] + group_mus[c,i,j] * group_etas[c,i,j];
                betas[c,i,j] = (1-mus[i,j]) * etas[i,j] + (1-group_mus[c,i,j]) * group_etas[c,i,j];
            }
        }
    }
}
model {
    for (k in 1:n_clips) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                vtc[k,i,j] ~ binomial(truth[k,j], group_confusion[group[k],j,i]);
            }
        }
    }

    for (i in 1:n_classes) {
        for (j in 1:n_classes) {
            mus[i,j] ~ beta(1,1);
            etas[i,j] ~ pareto(1,1.5);
        }
    }

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                group_mus[c,i,j] ~ beta(1,1);
                group_etas[c,i,j] ~ pareto(1, 1.5);
            }
        }
    }

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                group_confusion[c,i,j] ~ beta(alphas[c,i,j], betas[c,i,j]);
            }
        }
    }
}
"""
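# Data passed to the Stan program (see compute_counts): one row per child,
# indexed as "clips" in the model. `group` is the 1-based code of the grouping
# variable (--group), `truth[k,i]` the number of human-annotated vocalizations
# of class i in row k, and `vtc[k,i,j]` the number of VTC detections of class i
# attributable to true class j in row k.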
if __name__ == "__main__":
    annotators = pd.read_csv('input/annotators.csv')
    annotators['path'] = annotators['corpus'].apply(lambda c: opj('input', c))

    with mp.Pool(processes = 8) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient = 'records')))

    print(data)

    vtc = np.moveaxis([[data[f'vtc_{j}_{i}'].values for i in range(4)] for j in range(4)], -1, 0)
    truth = np.transpose([data[f'truth_{i}'].values for i in range(4)])

    print(vtc.shape)

    data = {
        'n_clips': truth.shape[0],
        'n_classes': truth.shape[1],
        'n_groups': data[args.group].nunique(),
        'group': 1+data[args.group].astype('category').cat.codes.values,
        'truth': truth.astype(int),
        'vtc': vtc.astype(int)
    }

    print(f"clips: {data['n_clips']}")
    print(f"groups: {data['n_groups']}")
    print("true vocs: {}".format(np.sum(data['truth'])))
    print("vtc vocs: {}".format(np.sum(data['vtc'])))

    posterior = stan.build(stan_code, data = data)
    fit = posterior.sample(num_chains = args.chains, num_samples = args.samples)
    df = fit.to_frame()
    df.to_parquet('fit.parquet')
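    # The saved draws can then be inspected with pandas, e.g. (column names
    # follow pystan's flattened parameter naming; check against df.columns):
    #
    #   fit_df = pd.read_parquet('fit.parquet')
    #   confusion_cols = [c for c in fit_df.columns if c.startswith('group_confusion')]
    #   print(fit_df[confusion_cols].mean())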