summary.py

#!/usr/bin/env python3
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_annotation

import argparse
import datalad.api
from os.path import join as opj
from os.path import basename, exists
import multiprocessing as mp

import numpy as np
from scipy.stats import binom
import pandas as pd
from pyannote.core import Annotation, Segment, Timeline
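
# Use matplotlib's pgf backend so the figure can be compiled with pdflatex
# and embedded directly in a LaTeX document.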
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
})
from matplotlib import pyplot as plt

parser = argparse.ArgumentParser(description = 'model3')
parser.add_argument('--group', default = 'child', choices = ['corpus', 'child'])
parser.add_argument('--chains', default = 4, type = int)
parser.add_argument('--samples', default = 2000, type = int)
args = parser.parse_args()

def set_size(width, ratio):
    # convert a width given in TeX points (72.27 pt per inch) into figure inches
    return width/72.27, ratio*width/72.27
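
# `extrude` is the complement of Timeline.crop: it returns the parts of a timeline
# that do NOT overlap `removed`, by cropping onto the gaps of `removed` within the
# timeline's extent.
# e.g. extrude(Timeline([Segment(0, 10)]), Segment(2, 4)) keeps {[0, 2], [4, 10]}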
def extrude(self, removed, mode: str = 'intersection'):
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"

    return self.crop(truncating_support, mode=mode)
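
# For one (corpus, annotator) pair, count vocalizations per child on the portions
# of audio annotated by both the VTC and the human annotator (split into 15 s clips):
#   truth_<A>       -- vocalizations the human attributed to speaker A
#   vtc_<A>_<A>     -- VTC detections of A that overlap a human vocalization of A
#   vtc_<A>_<B>     -- VTC false positives for A that overlap a human vocalization of B (confusions)
#   unexplained_<A> -- VTC detections of A that overlap no human vocalization at all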
def compute_counts(parameters):
    corpus = parameters['corpus']
    annotator = parameters['annotator']
    speakers = ['CHI', 'OCH', 'FEM', 'MAL']

    project = ChildProject(parameters['path'])
    am = AnnotationManager(project)
    am.read()

    # keep only the portions of the recordings covered by both the VTC and the annotator
    intersection = AnnotationManager.intersection(
        am.annotations, ['vtc', annotator]
    )

    # split the common portions into 15-second clips (timestamps are in milliseconds)
    intersection['onset'] = intersection.apply(
        lambda r: np.arange(r['range_onset'], r['range_offset'], 15000), axis = 1
    )
    intersection = intersection.explode('onset')
    intersection['range_onset'] = intersection['onset']
    intersection['range_offset'] = (intersection['range_onset']+15000).clip(upper = intersection['range_offset'])

    # fetch the converted annotation files through datalad
    intersection['path'] = intersection.apply(
        lambda r: opj(project.path, 'annotations', r['set'], 'converted', r['annotation_filename']),
        axis = 1
    )
    datalad.api.get(list(intersection['path'].unique()))

    intersection = intersection.merge(project.recordings[['recording_filename', 'child_id']], how = 'left')
    intersection['child'] = corpus + '_' + intersection['child_id'].astype(str)

    data = []
    for child, ann in intersection.groupby('child'):
        print(corpus, child)

        segments = am.get_collapsed_segments(ann)
        if 'speaker_type' not in segments.columns:
            continue

        segments = segments[segments['speaker_type'].isin(speakers)]

        # one pyannote Timeline per speaker, for the VTC and for the human annotator
        vtc = {
            speaker: segments_to_annotation(
                segments[(segments['set'] == 'vtc') & (segments['speaker_type'] == speaker)], 'speaker_type'
            ).get_timeline()
            for speaker in speakers
        }

        truth = {
            speaker: segments_to_annotation(
                segments[(segments['set'] == annotator) & (segments['speaker_type'] == speaker)], 'speaker_type'
            ).get_timeline()
            for speaker in speakers
        }

        for i, speaker_A in enumerate(speakers):
            # VTC detections of A explained by a human vocalization of A,
            # and the remaining false positives
            vtc[f'{speaker_A}_vocs_explained'] = vtc[speaker_A].crop(truth[speaker_A], mode = 'loose')
            vtc[f'{speaker_A}_vocs_fp'] = extrude(vtc[speaker_A], vtc[f'{speaker_A}_vocs_explained'])
            vtc[f'{speaker_A}_vocs_fn'] = extrude(truth[speaker_A], truth[speaker_A].crop(vtc[speaker_A], mode = 'loose'))
            vtc[f'{speaker_A}_vocs_unexplained'] = extrude(vtc[speaker_A], vtc[f'{speaker_A}_vocs_explained'])

            for speaker_B in speakers:
                # false positives for A that match a human vocalization of B (A/B confusions)
                vtc[f'{speaker_A}_vocs_fp_{speaker_B}'] = vtc[f'{speaker_A}_vocs_fp'].crop(truth[speaker_B], mode = 'loose')
                vtc[f'{speaker_A}_vocs_unexplained'] = extrude(
                    vtc[f'{speaker_A}_vocs_unexplained'],
                    vtc[f'{speaker_A}_vocs_unexplained'].crop(truth[speaker_B], mode = 'loose')
                )

                # do not count a confusion with B if the segment also matches another speaker C
                for speaker_C in speakers:
                    if speaker_C != speaker_B and speaker_C != speaker_A:
                        vtc[f'{speaker_A}_vocs_fp_{speaker_B}'] = extrude(
                            vtc[f'{speaker_A}_vocs_fp_{speaker_B}'],
                            vtc[f'{speaker_A}_vocs_fp_{speaker_B}'].crop(truth[speaker_C], mode = 'loose')
                        )

        d = {'child': child}
        for i, speaker_A in enumerate(speakers):
            for j, speaker_B in enumerate(speakers):
                if i != j:
                    z = len(vtc[f'{speaker_A}_vocs_fp_{speaker_B}'])
                else:
                    z = len(vtc[f'{speaker_A}_vocs_explained'])
                d[f'vtc_{speaker_A}_{speaker_B}'] = z

            if len(vtc[f'{speaker_A}_vocs_explained']) > len(truth[speaker_A]):
                print(speaker_A, child)

            d[f'truth_{speaker_A}'] = len(truth[speaker_A])
            d[f'unexplained_{speaker_A}'] = len(vtc[f'{speaker_A}_vocs_unexplained'])

        data.append(d)

    return pd.DataFrame(data).assign(
        corpus = corpus
    )
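
# Entry point: compute the counts for every (corpus, annotator) pair in parallel,
# save them to output/summary.csv, and draw a 4x4 figure comparing VTC counts to
# the human annotator's counts for each pair of speaker types.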
if __name__ == "__main__":
    annotators = pd.read_csv('input/annotators.csv')
    annotators = annotators[~annotators['annotator'].str.startswith('eaf_2021')]
    annotators['path'] = annotators['corpus'].apply(lambda c: opj('input', c))

    with mp.Pool(processes = 8) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient = 'records')))

    data.to_csv('output/summary.csv', index = False)
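
    # Each panel (row A, column B) plots, per child, the number of VTC vocalizations
    # attributed to speaker B that overlap a human vocalization of speaker A (y)
    # against the number of vocalizations the human attributed to A (x),
    # on log-log axes with 68% binomial confidence intervals.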
    speakers = ['CHI', 'OCH', 'FEM', 'MAL']
    colors = ['red', 'orange', 'green', 'blue']

    fig, axes = plt.subplots(4, 4, figsize = (6, 6))

    for i, speaker_A in enumerate(speakers):
        for j, speaker_B in enumerate(speakers):
            ax = axes.flatten()[4*i+j]

            x = data[f'truth_{speaker_A}'].values
            y = data[f'vtc_{speaker_B}_{speaker_A}'].values

            mask = (x > 0) & (y > 0)
            x = x[mask]
            y = y[mask]

            # 68% binomial confidence interval around y, with n = x trials and rate y/x
            low = binom.ppf((1-0.68)/2, x, y/x)
            high = binom.ppf(1-(1-0.68)/2, x, y/x)
            mask = (~np.isnan(low)) & (~np.isnan(high))
            yerr = np.array([
                y[mask]-low[mask], high[mask]-y[mask]
            ])

            # reference lines at y = x, y = 0.1x and y = 0.01x
            slopes_x = np.logspace(0, 3, num = 3)
            ax.plot(slopes_x, slopes_x, color = '#ddd', lw = 0.5)
            ax.plot(slopes_x, 0.1*slopes_x, color = '#ddd', lw = 0.5, linestyle = '--')
            ax.plot(slopes_x, 0.01*slopes_x, color = '#ddd', lw = 0.5, linestyle = '-.')

            ax.errorbar(
                x[mask], y[mask],
                yerr = yerr,
                color = colors[j],
                ls = 'none',
                elinewidth = 0.5
            )

            ax.scatter(
                x, y,
                s = 0.75,
                color = colors[j]
            )

            ax.set_xscale('log')
            ax.set_yscale('log')

            ax.set_xlim(1, 1000)
            ax.set_ylim(1, 1000)

            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])

            # label the outer rows/columns with the speaker types and log-scale ticks
            if i == 0:
                ax.xaxis.tick_top()
                ax.set_xticks([10**1.5])
                ax.set_xticklabels([speakers[j]])
            if i == 3:
                ax.set_xticks(np.power(10, np.arange(1, 4)))
                ax.set_xticklabels([f'10$^{k}$' for k in [1, 2, 3]])
            if j == 0:
                ax.set_yticks([10**1.5])
                ax.set_yticklabels([speakers[i]])
            if j == 3:
                ax.yaxis.tick_right()
                ax.set_yticks(np.power(10, np.arange(1, 4)))
                ax.set_yticklabels([f'10$^{k}$' for k in [1, 2, 3]])

    fig.subplots_adjust(wspace = 0, hspace = 0)
    fig.set_size_inches(set_size(450, 1))
    fig.savefig('output/summary.pdf')
    fig.savefig('output/summary.pgf')