messages.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
#!/usr/bin/env python
  2. # -*- coding: utf8 -*-
  3. import csv
  4. # -----------------------------------------------------------------------------
  5. # File: messages.py (as part of project URUMETRICS)
  6. # Created: 29/07/2022 15:35
  7. # Last Modified: 29/07/2022 15:35
  8. # -----------------------------------------------------------------------------
  9. # Author: William N. Havard
  10. # Postdoctoral Researcher
  11. #
  12. # Mail : william.havard@ens.fr / william.havard@gmail.com
  13. #
  14. # Institution: ENS / Laboratoire de Sciences Cognitives et Psycholinguistique
  15. #
  16. # ------------------------------------------------------------------------------
  17. # Description:
  18. # •
  19. # -----------------------------------------------------------------------------
  20. import logging
  21. import os
  22. import sys
  23. from datetime import datetime
  24. from pprint import pprint
  25. import pandas as pd
  26. import yaml
  27. from ChildProject.annotations import AnnotationManager
  28. from ChildProject.projects import ChildProject
  29. logger = logging.getLogger(__name__)
  30. def _read_yaml(yaml_path):
  31. with open(yaml_path, 'r') as in_yaml:
  32. data = yaml.load(in_yaml, Loader=yaml.FullLoader)
  33. return data
  34. def get_metrics(project_path, metrics_file):
  35. project = ChildProject(project_path)
  36. am = AnnotationManager(project)
  37. am.read()
  38. # Read metrics file and get metrics columns
  39. metrics = pd.read_csv(metrics_file)
  40. metrics_columns = list(set(metrics.columns) - set(['recording_filename', 'child_id']))
  41. # Merge with recordings to get date_iso
  42. metrics_recordings = pd.merge(metrics, project.recordings, on='recording_filename', suffixes=('', '_drop'))
  43. metrics_recordings.drop([col for col in metrics_recordings.columns if 'drop' in col], axis=1, inplace=True)
  44. # Handle file with the same child_id that have the same date -> keep the longest one
  45. metrics_recordings = (metrics_recordings.groupby(['child_id', 'date_iso'], as_index=False)
  46. # Keep only the first segment for each candidate speaker
  47. .apply(lambda rows: (rows.sort_values(by='start_time', ascending=False) # take last instead
  48. .head(n=1))))
  49. return metrics_recordings, metrics_columns
  50. def fill_template(template_key, messages, metrics_evolution):
  51. template = messages['_{}_{}'.format(*template_key)]
  52. for positivity_item_index, (positivity_item, _, positivity_direction) in enumerate(metrics_evolution, 1):
  53. template = template.replace('#{}'.format(positivity_item_index),
  54. messages[positivity_item][positivity_direction])
  55. message_variables = [msg for msg in messages if msg.startswith('_')]
  56. for message_variable in message_variables:
  57. message_variable = message_variable[1:]
  58. template = template.replace('#{}'.format(message_variable),
  59. messages['_{}'.format(message_variable)])
  60. return template
  61. def build_messages(metrics_recordings, metrics_columns, message_file_path, date):
  62. try:
  63. date = datetime.strptime(date, "%Y%m%d").strftime("%Y-%m-%d")
  64. except:
  65. raise ValueError('--date format should be YYYYMMDD without any separators.')
  66. # Get metrics of interest and messages
  67. metric_messages = _read_yaml(message_file_path)
  68. metrics_of_interest = [item for item in list(metric_messages.keys()) if not item.startswith('_')]
  69. # Keep only rows for which the date is below or equal to the one we want
  70. metrics_recordings = metrics_recordings[metrics_recordings['date_iso'] <= date]
  71. # Generate messages
  72. output_messages = []
  73. metrics_grouped = metrics_recordings.groupby('child_id', as_index=False)
  74. for _, metrics_grouped_item in metrics_grouped:
  75. sorted_metrics_grouped_items = metrics_grouped_item.sort_values(by=['date_iso', 'imported_at'],
  76. ascending=False)
  77. # If the first row is not the desired date, skip as no message was/will be generated for this family as
  78. # this recording is too old
  79. if sorted_metrics_grouped_items.iloc[0]['date_iso'] != date:
  80. continue
  81. # Only one audio (first week), generated default message
  82. if len(metrics_grouped_item) == 1:
  83. recording_filename = metrics_grouped_item.iloc[0]['recording_filename']
  84. message = metric_messages['_default']
  85. # More than one audio file: generate a message
  86. else:
  87. todays_row = sorted_metrics_grouped_items.iloc[0]
  88. previous_row = sorted_metrics_grouped_items.iloc[1]
  89. # Compute the difference between the two sets of metrics
  90. diff_metrics = (todays_row[metrics_columns] - previous_row[metrics_columns])[metrics_of_interest]
  91. diff_metrics = diff_metrics.to_dict()
  92. metrics_evolution = [(metric, diff_metrics[metric], diff_metrics[metric] > 0) for metric in metrics_of_interest]
  93. # Message sorting
  94. metrics_evolution = sorted(metrics_evolution, key=lambda tup: (abs(tup[1]), tup[2]))
  95. template_key = list([tpl_key for (_, _, tpl_key) in metrics_evolution])
  96. recording_filename = metrics_grouped_item.iloc[0]['recording_filename']
  97. message = fill_template(template_key, metric_messages, metrics_evolution)
  98. output_messages.append({'recording_filename': recording_filename,
  99. 'message': message})
  100. df_out = pd.DataFrame(output_messages)
  101. return df_out
  102. def generate_messages(project_path, metrics_file, message_definition, date):
  103. message_out_path = os.path.join(project_path, 'extra', 'messages', 'generated', 'messages_{}.csv'.format(date))
  104. message_out_dir = os.path.dirname(message_out_path)
  105. if not os.path.exists(message_out_dir):
  106. os.makedirs(message_out_dir)
  107. # Make sure we have all the files we need
  108. metrics_recordings, metrics_columns = get_metrics(project_path, metrics_file)
  109. messages = build_messages(metrics_recordings, metrics_columns, message_definition, date)
  110. if not os.path.exists(message_out_path):
  111. if len(messages):
  112. messages.to_csv(message_out_path, index=False, quoting=csv.QUOTE_NONNUMERIC)
  113. logger.info('{} messages generated.'.format(len(messages)))
  114. else:
  115. logger.warning('No message needs to be generated for date {}.'.format(date))
  116. else:
  117. raise IOError('File {} already exists!'.format(message_out_path))
  118. def main(**kwargs):
  119. project_path = os.path.abspath('.')
  120. expected_metrics_file = os.path.join(project_path, 'extra', 'metrics', 'metrics.csv')
  121. expected_message_definition = os.path.join(project_path, 'extra', 'messages', 'definition', 'metrics_messages.yaml')
  122. assert os.path.exists(expected_metrics_file) and os.path.exists(expected_message_definition), \
  123. ValueError('Expected metrics ({}) and/or message definition file ({}) not found. Are you sure to be running this '
  124. 'command from the root of the data set?'.format(expected_metrics_file, expected_message_definition))
  125. generate_messages(project_path=project_path, metrics_file=expected_metrics_file,
  126. message_definition=expected_message_definition, **kwargs)
  127. def _parse_args(argv):
  128. import argparse
  129. parser = argparse.ArgumentParser(description='')
  130. parser.add_argument('--date', type=str, default=datetime.now().strftime("%Y%m%d"),
  131. help='Date for which to generate messages.')
  132. args = parser.parse_args(argv)
  133. return vars(args)
  134. if __name__ == '__main__':
  135. import sys
  136. pgrm_name, argv = sys.argv[0], sys.argv[1:]
  137. args = _parse_args(argv)
  138. logging.basicConfig(level=logging.INFO)
  139. try:
  140. main(**args)
  141. sys.exit(0)
  142. except Exception as e:
  143. logger.exception(e)
  144. sys.exit(1)