compute_derived_annotations.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. #!/usr/bin/env python
  2. # -*- coding: utf8 -*-
  3. # -----------------------------------------------------------------------------
  4. # File: compute_derived_annotations.py (as part of project URUMETRICS-CODE)
  5. # Created: 03/08/2022 17:32
  6. # Last Modified: 03/08/2022 17:32
  7. # -----------------------------------------------------------------------------
  8. # Author: William N. Havard
  9. # Postdoctoral Researcher
  10. #
  11. # Mail : william.havard@ens.fr / william.havard@gmail.com
  12. #
  13. # Institution: ENS / Laboratoire de Sciences Cognitives et Psycholinguistique
  14. #
  15. # ------------------------------------------------------------------------------
  16. # Description:
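  17. # • Compute derived annotations for a ChildProject data set from existing base annotation sets and save them as CSV files.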
  18. # -----------------------------------------------------------------------------
  19. import logging
  20. import os
  21. from functools import partial
  22. import csv
  23. import pandas as pd
  24. from ChildProject.annotations import AnnotationManager
  25. from ChildProject.projects import ChildProject
  26. import annotations_functions
  27. logger = logging.getLogger(__name__)
  28. def _annotation_function_wrapper(func, parser_args, **kwargs):
  29. return partial(func,parser_args=parser_args, **kwargs)
  30. def get_available_segments(project_path, set_name, base_sets, raw_recording_available=False):
  31. """
  32. Get the annotation segments that will be used to construct new annotations for a set.
  33. This returns the segments in the sets used (base_sets) to compute the new annotations (for set_name).
  34. This will exclude segments for which the annotations already exist in the target set.
  35. :param project_path: path to the dataset
  36. :type project_path: str
  37. :param set_name: set to which we will add new annotation files. segments of base sets are excluded if they already have annotation files in the target set
  38. :type set_name: str
  39. :param base_sets: sets from which to get the segments
  40. :type base_sets: list[str]
  41. :param raw_recording_available: if True, exclude annotations for which the actual recording is not present (for when the process requires the audio)
  42. :type raw_recording_available: bool
  43. """
  44. project = ChildProject(project_path)
  45. am = AnnotationManager(project)
  46. am.read()
  47. for base_set in base_sets:
  48. assert os.path.exists(os.path.join(project.path, 'annotations', base_set)), \
  49. ValueError('BASE_SET `{}` not found!'.format(base_set))
  50. # Get available VTC annotations (the files are readable)
  51. base_sets_df = am.annotations[am.annotations['set'].isin(base_sets)]
  52. available_base_sets_anns = base_sets_df[base_sets_df.apply(
  53. lambda r: os.path.exists(os.path.join(project.path, 'annotations', r['set'], 'converted', r['annotation_filename'])),
  54. axis=1)]
  55. # Get already existing conversation annotations and only compute annotations for the files that do not already
  56. # have conversational annotations
  57. set_name = set_name.lower()
  58. if set_name in set(am.annotations['set']):
  59. target_set_anns = am.annotations[am.annotations['set'] == set_name]
  60. available_base_sets_anns = available_base_sets_anns[
  61. ~available_base_sets_anns['recording_filename'].isin(target_set_anns['recording_filename'])]
  62. # We check that the recording is available if the user wants
  63. if raw_recording_available:
  64. available_base_sets_anns = available_base_sets_anns[available_base_sets_anns['recording_filename'].apply(
  65. lambda fn: os.path.exists(os.path.join(project.path, 'recordings', 'raw', fn)))]
  66. # Get the segments that are left
  67. data = am.get_segments(available_base_sets_anns)
  68. return data
  69. def _compute_annotations(project_path, annotation_type, annotation_function, base_sets, raw_recording_available):
  70. """
  71. Computes annotations for the ChildProject in directory project_path, of a specific set, from a list of sets
  72. :param project_path: path to ChildProject dataset
  73. :type project_path: str
  74. :param annotation_type: name of the set to compute for
  75. :type annotation_type: str
  76. :param annotation_function: callable that creates the annotations (stored in annotations_functions)
  77. :type annotation_function: callable
  78. :base_sets: sets that are required to compute the new annotations
  79. :type base_sets: list[str]
  80. :param raw_recording_available: is the actual recording file needed
  81. :type raw_recording_available: bool
  82. :return: annotations
  83. :rtype: pd.DataFrame
  84. """
  85. data = get_available_segments(project_path,
  86. set_name=annotation_type,
  87. base_sets=base_sets,
  88. raw_recording_available=raw_recording_available)
  89. if not len(data):
  90. return pd.DataFrame()
  91. data = data[~data['speaker_type'].isnull()]
  92. annotations = []
  93. data_grouped = data.groupby('recording_filename')
  94. for data_grouped_name, data_grouped_line in data_grouped:
  95. df_annotations = annotation_function(recording_filename=data_grouped_name, segments=data_grouped_line,
  96. project_path = project_path)
  97. annotations.append(df_annotations)
  98. output = pd.concat(annotations, axis=0)
  99. return output
  100. def save_annotations(save_path, annotations, annotation_type):
  101. """
  102. Save the computed annotations
  103. :param save_path: path where to save the annotations (use annotation raw folder)
  104. :type save_path: str
  105. :param annotations: annotations to be saved
  106. :type annotations: pd.DataFrame
  107. :param annotation_type: annotation type, only used to name the raw file
  108. :type annotation_type: str
  109. :return: None
  110. :rtype: None
  111. """
  112. annotations_grouped = annotations.groupby('raw_filename')
  113. for annotation_group_name, annotation_group_data in annotations_grouped:
  114. output_filename = '{}_{}'.format(annotation_type.upper(),annotation_group_name.replace('.rttm', '.csv'))
  115. full_save_path = os.path.join(save_path, output_filename)
  116. if os.path.exists(full_save_path):
  117. logger.warning('File {} already exists! If you want to recompute annotations for this file, '
  118. 'please delete it first!'.format(full_save_path))
  119. continue
  120. annotation_group_data = annotation_group_data.drop(columns=
  121. ['raw_filename',
  122. 'set',
  123. 'time_seek',
  124. 'range_onset',
  125. 'range_offset',
  126. 'format',
  127. 'filter',
  128. 'annotation_filename',
  129. 'imported_at',
  130. 'package_version',
  131. 'error',
  132. 'merged_from',
  133. ])
  134. annotation_group_data.to_csv(full_save_path, index=False, quoting=csv.QUOTE_NONNUMERIC)
  135. logger.info('Saved to {}.'.format(full_save_path))
  136. def main(project_path, annotation_type, save_path, unknown_args):
  137. # Check if running the script from the root of the data set
  138. expected_annotation_path = os.path.join(project_path, 'annotations')
  139. expected_recordings_path = os.path.join(project_path, 'recordings')
  140. assert os.path.exists(expected_annotation_path) and os.path.exists(expected_recordings_path), \
  141. ValueError('Expected annotation ({}) or recording path ({}) not found. Are you sure to be running this '
  142. 'command from the root of the data set?'.format(expected_annotation_path, expected_recordings_path))
  143. assert os.path.exists(os.path.abspath(save_path)), IOError('Path {} does not exist!'.format(save_path))
  144. assert hasattr(annotations_functions, '{}_annotations'.format(annotation_type.lower())), \
  145. ValueError('Annotation function {}_annotations not found.'.format(annotation_type.lower()))
  146. annotation_function = getattr(annotations_functions, '{}_annotations'.format(annotation_type.lower()))
  147. annotation_function_base_sets = getattr(annotation_function, 'BASE_SETS')
  148. raw_recording_available = getattr(annotation_function, 'RAW_RECORDING_AVAILABLE', False)
  149. annotations = _compute_annotations(project_path=project_path,
  150. annotation_type = annotation_type,
  151. annotation_function=_annotation_function_wrapper(
  152. func=annotation_function,
  153. parser_args=unknown_args),
  154. base_sets=annotation_function_base_sets,
  155. raw_recording_available=raw_recording_available)
  156. if not len(annotations):
  157. logger.warning('Apparently nothing needs to be computed!')
  158. return
  159. save_annotations(save_path, annotations, annotation_type)
  160. def _parse_args(argv):
  161. import argparse
  162. parser = argparse.ArgumentParser(description='Compute acoustic annotations.')
  163. parser.add_argument('--project-path', required=False, type=str, default='',
  164. help="Path to a ChildProject/datalad project (useful for debugging purposes).")
  165. parser.add_argument('--annotation-type', required=True,
  166. help='Which type of annotations should be computed.')
  167. parser.add_argument('--save-path', required=True,
  168. help='Path were the annotations should be saved.')
  169. args, unknown_args = parser.parse_known_args(argv)
  170. return vars(args), unknown_args
  171. if __name__ == '__main__':
  172. import sys
  173. pgrm_name, argv = sys.argv[0], sys.argv[1:]
  174. args, unknown_args = _parse_args(argv)
  175. logging.basicConfig(level=logging.INFO)
  176. try:
  177. main(unknown_args=unknown_args, **args)
  178. sys.exit(0)
  179. except Exception as e:
  180. logger.exception(e)
  181. sys.exit(1)