import_annotations.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
#!/usr/bin/env python
  2. # -*- coding: utf8 -*-
  3. #
# Author: William N. Havard (based on various files by Lucas Gautheron)
  5. #
  6. import os
  7. import argparse
  8. import logging
  9. import pandas as pd
  10. pd.set_option('mode.chained_assignment',None) # Silences pandas' complaints
  11. from ChildProject.projects import ChildProject
  12. from ChildProject.annotations import AnnotationManager
  13. def _raw_filename(path):
  14. return os.path.splitext(path)[0]
  15. def _filter_missing_annotation_files(input):
  16. annotation_path = lambda row: os.path.join('annotations', row['set'], 'raw', row['raw_filename'])
  17. input['exists'] = input.apply(lambda row: os.path.exists(annotation_path(row)), axis=1)
  18. missing_annotations = input[input['exists'] == False]
  19. input = input[input['exists'] == True]
  20. if len(missing_annotations):
  21. missing_annotations['expected_path'] = missing_annotations.apply(lambda row: annotation_path(row), axis=1)
  22. missing_annotations_path = sorted(missing_annotations['expected_path'].tolist())
  23. missing_message = "[WARNING] Some annotations you expected to have are missing.\n"\
  24. "Check whether these annotations exist and if so, if their expected path "\
  25. "reflect their true path.\n\t - {}".format('\n\t - '.join(missing_annotations_path))
  26. print(missing_message)
  27. return input
  28. def _check_importation(am, annotation_set, expected_ann_number):
  29. annotations = am.annotations[am.annotations['set'] == annotation_set]
  30. if len(annotations) != expected_ann_number:
  31. print('[WARNING] Expected to import {} annotations, only found {}!'.
  32. format(len(annotations), expected_ann_number))
  33. annotations_segments = am.get_segments(annotations)
  34. if len(annotations_segments) == 0:
  35. print('[WARNING] Annotations were imported, but they either contain no segments '
  36. 'or the segments were not imported properly!')
  37. return len(annotations), len(annotations_segments)
  38. def _get_recordings(project, annotation_set, ann_format):
  39. input = project.recordings[['recording_filename', 'duration', 'child_id']]
  40. input.dropna(inplace=True)
  41. input = input[input['recording_filename'] != 'NA']
  42. input['set'] = annotation_set
  43. input['time_seek'] = 0
  44. input['range_onset'] = 0
  45. input['range_offset'] = input['duration']
  46. input['format'] = ann_format
  47. input.drop(['duration'], axis=1, inplace=True)
  48. return input
  49. def main(annotation_set):
  50. annotation_set = annotation_set.lower()
  51. ann_format = '{}_rttm'.format(annotation_set) if annotation_set in ('vcm', 'vtc') else annotation_set
  52. # Load project
  53. project = ChildProject('.')
  54. am = AnnotationManager(project)
  55. # Get recordings and set up df
  56. input = _get_recordings(project, annotation_set, ann_format)
  57. # RAW FILENAME
  58. # /!\ input['raw_filename'] refers to the *subpath and name* of the *annotation file* to be imported
  59. # not the *target* recording file. input['raw_filename'] should hence match the
  60. # name of the annotation files you want to import! The filenames of the annotation files
  61. # usually have the same name or contain the name of the recording. Hence, their name depends
  62. # on the value of input['recording_filename'] (unless in case of bulk importation, see below).
  63. # (input['raw_filename'] = 'SUB/PATH/annotation_filename.txt')
  64. # FILTERING
  65. # /!\ Using input['filter'] to filter the input is not necessary but safer.
  66. # Filtering the *input annotation file* is necessary in case of bulk importation
  67. # where the annotation file contains annotations linked to *several recordings*. In such case,
  68. # it should be specified for each input['recording_filename'] which lines of the annotation file
  69. # should be associated to it.
  70. # Set up 'raw_filename' and 'filter' depending on the annotation set to import
  71. if annotation_set in ('vcm', 'vtc'):
  72. # My VCM annotations end with '.vcm' and not '.rttm'. Update this if it's not the case for you.
  73. rttm_ext = 'rttm' if annotation_set == 'vtc' else 'vcm'
  74. input['raw_filename'] = input['recording_filename'].apply(lambda f: '{}.{}'.format(_raw_filename(f), rttm_ext))
  75. # We only keep lines for which the 'file' column is equal to filter (unnecessary here, but safer)
  76. input['filter'] = input['recording_filename'].apply(lambda f: os.path.basename(_raw_filename(f)))
  77. elif annotation_set == 'alice':
  78. # My ALICE annotations have the same name as the recording filename. Update if it's not the case for you.
  79. input['raw_filename'] = input['recording_filename'].apply(lambda f: '{}.{}'.format(_raw_filename(f), 'txt'))
  80. # We only keep lines for which the 'file' column is equal to filter (unnecessary here, but safer)
  81. input['filter'] = input['recording_filename'].apply(lambda f: os.path.basename(_raw_filename(f)))
  82. elif annotation_set == 'cha':
  83. input['raw_filename'] = input['recording_filename'].apply(lambda f: '{}.{}'.format(_raw_filename(f), 'cha'))
  84. # CHA files do not need filtering as they only contain annotations for the file they are linked to
  85. else:
  86. raise ValueError('Unknown annotation set `{}`!'.format(annotation_set))
  87. # Filter out rows for which we do not find the matching annotation file
  88. input = _filter_missing_annotation_files(input)
  89. expected_ann_number = len(input)
  90. # Do importation
  91. if len(input) > 0:
  92. am.remove_set(annotation_set)
  93. am.import_annotations(input)
  94. len_ann, len_seg = _check_importation(am, annotation_set, expected_ann_number)
  95. print('Imported {} annotations resulting in {} segments!'.format(len_ann, len_seg))
  96. else:
  97. print('[WARNING] Nothing to import!')
  98. if __name__ == '__main__':
  99. accepted_annotations = ['cha', 'vtc', 'vcm', 'alice']
  100. parser = argparse.ArgumentParser(description='Import {} annotations.'.format('/'.join(accepted_annotations)))
  101. parser.add_argument('annotation_set',
  102. help='Set of annotations to import (amongst {}).'.format('/'.join(accepted_annotations)),
  103. choices=accepted_annotations,
  104. type= lambda s: s.lower())
  105. args = parser.parse_args()
  106. main(args.annotation_set)