#!/usr/bin/env python
# -*- coding: utf8 -*-
#
# Author: William N. Havard (based on various files by Lucas Gautheron)
#

import os
import argparse
import logging

import pandas as pd
pd.set_option('mode.chained_assignment',None) # Silences pandas' complaints

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager


# Extension of the raw annotation file for each accepted annotation format.
# Update these values if your annotation files are named differently.
_ANNOTATION_EXTENSIONS = {
    'vtc': 'rttm',
    'vcm': 'vcm',    # My VCM annotations end with '.vcm' and not '.rttm'.
    'alice': 'txt',  # My ALICE annotations have the same name as the recording.
    'cha': 'cha',
    'its': 'its',
}

# Formats for which a 'filter' column is set. CHA and ITS files do not need
# filtering as they only contain annotations for the file they are linked to.
_FILTERED_FORMATS = ('vtc', 'vcm', 'alice')

# Formats whose name in the 'format' column carries an '_rttm' suffix.
_RTTM_FORMATS = ('vcm', 'vtc')


def _raw_filename(path):
    """Return `path` without its file extension."""
    return os.path.splitext(path)[0]


def _filter_missing_annotation_files(annotations_index):
    """Keep only the rows whose raw annotation file exists on disk.

    The expected location of each file is
    ``annotations/<set>/raw/<raw_filename>``, relative to the current
    working directory. A warning listing every missing path is printed when
    at least one file cannot be found.

    :param annotations_index: dataframe with (at least) the columns 'set'
        and 'raw_filename'. It is not modified.
    :return: a copy of the rows whose annotation file exists, with an extra
        boolean 'exists' column (always True in the returned frame).
    """
    annotations_index = annotations_index.copy()

    # Built with plain lists rather than DataFrame.apply(axis=1) so that an
    # empty dataframe is handled without special-casing.
    expected_paths = [
        os.path.join('annotations', set_name, 'raw', raw_filename)
        for set_name, raw_filename in zip(annotations_index['set'],
                                          annotations_index['raw_filename'])
    ]
    found = [os.path.exists(path) for path in expected_paths]
    annotations_index['exists'] = pd.Series(found,
                                            index=annotations_index.index,
                                            dtype=bool)

    missing_annotations_path = sorted(
        path for path, exists in zip(expected_paths, found) if not exists)
    if missing_annotations_path:
        missing_message = "[WARNING] Some annotations you expected to have are missing.\n"\
            "Check whether these annotations exist and if so, if their expected path "\
            "reflect their true path.\n\t - {}".format('\n\t - '.join(missing_annotations_path))
        print(missing_message)

    return annotations_index[annotations_index['exists']]


def _check_importation(am, annotation_set, expected_ann_number):
    """Report on what ended up in the annotation index for `annotation_set`.

    Prints a warning when the number of indexed annotations differs from
    `expected_ann_number`, and another one when no segment can be retrieved
    for them.

    :param am: annotation manager exposing `annotations` and `get_segments`.
    :param annotation_set: name of the set that was just imported.
    :param expected_ann_number: number of annotations submitted for import.
    :return: tuple (number of annotations, number of segments).
    """
    annotations = am.annotations[am.annotations['set'] == annotation_set]

    if len(annotations) != expected_ann_number:
        # Expected count first, found count second (matches the message).
        print('[WARNING] Expected to import {} annotations, only found {}!'.
              format(expected_ann_number, len(annotations)))

    annotations_segments = am.get_segments(annotations)
    if len(annotations_segments) == 0:
        print('[WARNING] Annotations were imported, but they either contain no segments '
              'or the segments were not imported properly!')

    return len(annotations), len(annotations_segments)


def _get_recordings(project, annotation_set, annotation_format):
    """Build the base importation dataframe from the project's recordings.

    Rows with a missing value in 'recording_filename', 'duration' or
    'child_id', and rows whose 'recording_filename' is the string 'NA', are
    discarded. Each remaining recording is mapped to one annotation covering
    the range from 0 to the value of its 'duration' column.

    :param project: project object exposing a `recordings` dataframe.
    :param annotation_set: value written to the 'set' column.
    :param annotation_format: value written to the 'format' column.
    :return: dataframe with the columns 'recording_filename', 'child_id',
        'set', 'time_seek', 'range_onset', 'range_offset' and 'format'.
    """
    recordings = project.recordings[['recording_filename', 'duration', 'child_id']]
    recordings = recordings.dropna()
    recordings = recordings[recordings['recording_filename'] != 'NA'].copy()

    recordings['set'] = annotation_set
    recordings['time_seek'] = 0
    recordings['range_onset'] = 0
    recordings['range_offset'] = recordings['duration']
    recordings['format'] = annotation_format

    return recordings.drop(['duration'], axis=1)


def main(annotation_set, annotation_format):
    """Import the annotations of `annotation_set` into the project in '.'.

    Any existing set with the same name is removed before the importation.

    :param annotation_set: name of the annotation set (lower-cased before
        use); raw files are looked up under ``annotations/<set>/raw``.
    :param annotation_format: one of 'cha', 'vtc', 'vcm', 'alice', 'its'.
    :raises ValueError: if `annotation_format` is not one of the above.
    """
    annotation_set = annotation_set.lower()

    if annotation_format not in _ANNOTATION_EXTENSIONS:
        raise ValueError('Unknown annotation format `{}`!'.format(annotation_format))

    # The 'format' column uses the '<name>_rttm' spelling for vcm/vtc. The
    # user-facing name is kept in `annotation_format` so that the per-format
    # settings below can still be looked up with it.
    if annotation_format in _RTTM_FORMATS:
        importer_format = '{}_rttm'.format(annotation_format)
    else:
        importer_format = annotation_format

    # Load project
    project = ChildProject('.')
    am = AnnotationManager(project)

    # Get recordings and set up df
    recordings = _get_recordings(project, annotation_set, importer_format)

    # RAW FILENAME
    # /!\ recordings['raw_filename'] refers to the *subpath and name* of the
    #     *annotation file* to be imported, not the *target* recording file.
    #     It should hence match the name of the annotation files you want to
    #     import! The filenames of the annotation files usually have the same
    #     name or contain the name of the recording. Hence, their name depends
    #     on the value of recordings['recording_filename'] (unless in case of
    #     bulk importation, see below).
    #     (recordings['raw_filename'] = 'SUB/PATH/annotation_filename.txt')
    extension = _ANNOTATION_EXTENSIONS[annotation_format]
    recordings['raw_filename'] = recordings['recording_filename'].apply(
        lambda f: '{}.{}'.format(_raw_filename(f), extension))

    # FILTERING
    # /!\ Using recordings['filter'] to filter the input is not necessary but
    #     safer. Filtering the *input annotation file* is necessary in case of
    #     bulk importation where the annotation file contains annotations
    #     linked to *several recordings*. In such case, it should be specified
    #     for each recordings['recording_filename'] which lines of the
    #     annotation file should be associated to it.
    if annotation_format in _FILTERED_FORMATS:
        # We only keep lines for which the 'file' column is equal to filter
        # (unnecessary here, but safer)
        recordings['filter'] = recordings['recording_filename'].apply(
            lambda f: os.path.basename(_raw_filename(f)))

    # Filter out rows for which we do not find the matching annotation file
    to_import = _filter_missing_annotation_files(recordings)
    expected_ann_number = len(to_import)

    # Do importation
    if expected_ann_number == 0:
        print('[WARNING] Nothing to import!')
        return

    am.remove_set(annotation_set)
    am.import_annotations(to_import)

    len_ann, len_seg = _check_importation(am, annotation_set, expected_ann_number)
    print('Imported {} annotations resulting in {} segments!'.format(len_ann, len_seg))


if __name__ == '__main__':
    accepted_formats = ['cha', 'vtc', 'vcm', 'alice', 'its']

    # Listing the available sets is only used for the help text: do not let a
    # missing 'annotations' directory prevent the parser from being built.
    try:
        available_sets = os.listdir('annotations')
    except OSError:
        available_sets = []

    parser = argparse.ArgumentParser(description='Import {} annotations.'.format('/'.join(accepted_formats)))
    parser.add_argument('--annotation-set', required=True,
                        help='Annotations to import (amongst {})'.format('/'.join(available_sets)))
    parser.add_argument('--annotation-format', required=True,
                        help='Format of the annotations to be imported (amongst {}).'.format('/'.join(accepted_formats)),
                        choices=accepted_formats,
                        type=lambda s: s.lower())
    args = parser.parse_args()

    main(**vars(args))