123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
#!/usr/bin/env python
# -*- coding: utf8 -*-
#
# Author: William N. Havard (based on various files by Lucas Gautheron)
- #
- import os
- import argparse
- import logging
- import pandas as pd
- pd.set_option('mode.chained_assignment',None) # Silences pandas' complaints
- from ChildProject.projects import ChildProject
- from ChildProject.annotations import AnnotationManager
def _raw_filename(path):
    """Return *path* without its last extension; any directory part is kept."""
    stem, _extension = os.path.splitext(path)
    return stem
def _filter_missing_annotation_files(input):
    """Keep only the rows whose raw annotation file exists on disk.

    Each row's file is expected at ``annotations/<set>/raw/<raw_filename>``,
    relative to the current working directory. Rows whose file is missing are
    dropped and listed, sorted, in a warning printed to stdout.

    Args:
        input: DataFrame with at least the 'set' and 'raw_filename' columns.
            It is not modified.

    Returns:
        A new DataFrame with the rows whose file exists, plus a boolean
        'exists' column (always True). An empty input yields an empty result.
    """
    expected_paths = [
        os.path.join('annotations', annotation_set, 'raw', raw_filename)
        for annotation_set, raw_filename in zip(input['set'], input['raw_filename'])
    ]
    # Built from a list rather than DataFrame.apply(axis=1): on an empty frame
    # apply() hands back a DataFrame instead of a Series, which cannot be
    # stored in a single column.
    exists = pd.Series([os.path.exists(path) for path in expected_paths],
                       index=input.index, dtype=bool)

    missing_annotations_path = sorted(
        path for path, found in zip(expected_paths, exists) if not found
    )
    if missing_annotations_path:
        missing_message = "[WARNING] Some annotations you expected to have are missing.\n"\
            "Check whether these annotations exist and if so, if their expected path "\
            "reflect their true path.\n\t - {}".format('\n\t - '.join(missing_annotations_path))
        print(missing_message)

    # assign() returns a new frame, so the caller's DataFrame stays untouched.
    return input.assign(exists=exists).loc[exists]
def _check_importation(am, annotation_set, expected_ann_number):
    """Sanity-check an import by counting what ended up in the annotation index.

    Args:
        am: the annotation manager; its ``annotations`` DataFrame (with a 'set'
            column) and ``get_segments()`` are used.
        annotation_set: name of the set that was just imported.
        expected_ann_number: number of annotation rows submitted for import.

    Returns:
        (number of annotations found for the set, number of their segments).
        A warning is printed when the annotation count differs from the
        expected one, or when no segment is found.
    """
    annotations = am.annotations[am.annotations['set'] == annotation_set]
    if len(annotations) != expected_ann_number:
        # Expected count first, then the count actually found.
        print('[WARNING] Expected to import {} annotations, only found {}!'.
              format(expected_ann_number, len(annotations)))
    annotations_segments = am.get_segments(annotations)
    if len(annotations_segments) == 0:
        print('[WARNING] Annotations were imported, but they either contain no segments '
              'or the segments were not imported properly!')
    return len(annotations), len(annotations_segments)
def _get_recordings(project, annotation_set, ann_format):
    """Build the import table: one row per usable recording of *project*.

    A recording is dropped when its recording_filename, duration or child_id is
    missing, or when its recording_filename is the literal string 'NA'. Each
    kept row covers the whole recording (time_seek 0, range 0 to duration) and
    is tagged with *annotation_set* and *ann_format*.

    Returns a new DataFrame with the columns recording_filename, child_id, set,
    time_seek, range_onset, range_offset and format, indexed like
    project.recordings.
    """
    wanted = ['recording_filename', 'duration', 'child_id']
    usable = project.recordings[wanted].dropna()
    usable = usable[usable['recording_filename'] != 'NA']
    return usable.assign(
        set=annotation_set,
        time_seek=0,
        range_onset=0,
        range_offset=usable['duration'],
        format=ann_format,
    ).drop(columns=['duration'])
# Extension of the raw annotation file expected for each supported set. The
# annotation file is named after its recording (same sub-path, extension
# replaced).
# My VCM annotations end with '.vcm' and not '.rttm'. Update this if it's not the case for you.
# My ALICE annotations have the same name as the recording filename. Update if it's not the case for you.
_RAW_EXTENSIONS = {
    'vtc': 'rttm',
    'vcm': 'vcm',
    'alice': 'txt',
    'cha': 'cha',
}

# Sets whose rows get a 'filter' value: only lines of the annotation file whose
# 'file' column equals it are kept. CHA files do not need filtering as they only
# contain annotations for the file they are linked to.
_FILTERED_SETS = frozenset({'vtc', 'vcm', 'alice'})


def _add_raw_filenames(recordings, annotation_set):
    """Return a copy of *recordings* with the columns that locate each annotation file.

    'raw_filename' is the *sub-path and name of the annotation file* to import
    (e.g. 'SUB/PATH/annotation_filename.txt'), not the target recording file.
    It must therefore match the names of the annotation files you want to
    import; here it is derived from 'recording_filename' by swapping the
    extension for the one listed in _RAW_EXTENSIONS.

    'filter' (added only for sets in _FILTERED_SETS) is the recording's base
    name without extension. Filtering is unnecessary when each annotation file
    covers a single recording, but it is safer. It is required for bulk
    importation, where one annotation file holds lines for several recordings.
    """
    extension = _RAW_EXTENSIONS[annotation_set]
    filenames = recordings['recording_filename']
    new_columns = {
        'raw_filename': filenames.apply(lambda f: '{}.{}'.format(_raw_filename(f), extension)),
    }
    if annotation_set in _FILTERED_SETS:
        new_columns['filter'] = filenames.apply(lambda f: os.path.basename(_raw_filename(f)))
    return recordings.assign(**new_columns)


def main(annotation_set):
    """Import every available annotation file of *annotation_set* into the project in '.'.

    The set name is case-insensitive and must be one of 'vtc', 'vcm', 'alice'
    or 'cha'. Existing annotations of that set are removed before importing.
    Recordings without a matching file under annotations/<set>/raw/ are
    reported and skipped. If nothing is left to import, a warning is printed
    and the existing set is left in place.

    Raises:
        ValueError: if the annotation set is unknown. This is checked before
            the project and annotation manager are created.
    """
    annotation_set = annotation_set.lower()
    # Fail fast, before touching the project.
    if annotation_set not in _RAW_EXTENSIONS:
        raise ValueError('Unknown annotation set `{}`!'.format(annotation_set))
    # VCM and VTC use the '<set>_rttm' format name; the other sets use their own name.
    ann_format = '{}_rttm'.format(annotation_set) if annotation_set in ('vcm', 'vtc') else annotation_set

    # Load project
    project = ChildProject('.')
    am = AnnotationManager(project)

    # One row per recording: import metadata, then the annotation file to read.
    recordings = _get_recordings(project, annotation_set, ann_format)
    recordings = _add_raw_filenames(recordings, annotation_set)
    # Drop (and report) rows whose annotation file cannot be found.
    recordings = _filter_missing_annotation_files(recordings)
    expected_ann_number = len(recordings)

    if recordings.empty:
        print('[WARNING] Nothing to import!')
        return

    # Replace any previously imported version of this set.
    am.remove_set(annotation_set)
    am.import_annotations(recordings)
    len_ann, len_seg = _check_importation(am, annotation_set, expected_ann_number)
    print('Imported {} annotations resulting in {} segments!'.format(len_ann, len_seg))
if __name__ == '__main__':
    accepted_annotations = ['cha', 'vtc', 'vcm', 'alice']
    accepted_list = '/'.join(accepted_annotations)

    # argparse applies `type` before checking `choices`, so the set name is
    # accepted case-insensitively.
    parser = argparse.ArgumentParser(description='Import {} annotations.'.format(accepted_list))
    parser.add_argument(
        'annotation_set',
        type=lambda raw: raw.lower(),
        choices=accepted_annotations,
        help='Set of annotations to import (amongst {}).'.format(accepted_list),
    )
    main(parser.parse_args().annotation_set)
|