#!/usr/bin/env python
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# File: messages.py (as part of project URUMETRICS)
# Created: 29/07/2022 15:35
# Last Modified: 29/07/2022 15:35
# -----------------------------------------------------------------------------
# Author: William N. Havard
#         Postdoctoral Researcher
#
# Mail : william.havard@ens.fr / william.havard@gmail.com
#
# Institution: ENS / Laboratoire de Sciences Cognitives et Psycholinguistique
#
# ------------------------------------------------------------------------------
# Description:
# • Generates per-child feedback messages from recording metrics: the metrics of
#   the most recent recording are compared to those of the previous one and the
#   matching message template is filled in.
# -----------------------------------------------------------------------------
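#
# Expected data set layout (as assumed by main() and generate_messages() below):
#   extra/metrics/metrics.csv                        -> input metrics, one row per recording
#   extra/messages/definition/metrics_messages.yaml  -> message templates and fragments
#   extra/messages/generated/messages_<date>.csv     -> output written by this script
#
# Example invocation, from the root of the data set:
#   python messages.py --date 20220729
# -----------------------------------------------------------------------------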
import csv
import logging
import os
import sys
from datetime import datetime

import pandas as pd
import yaml
from ChildProject.annotations import AnnotationManager
from ChildProject.projects import ChildProject

logger = logging.getLogger(__name__)


def _read_yaml(yaml_path):
    with open(yaml_path, 'r') as in_yaml:
        data = yaml.load(in_yaml, Loader=yaml.FullLoader)
    return data


def get_metrics(project_path, metrics_file):
    project = ChildProject(project_path)
    am = AnnotationManager(project)
    am.read()

    # Read the metrics file and list the metric columns (everything that is not
    # an identifier column)
    metrics = pd.read_csv(metrics_file)
    metrics_columns = list(set(metrics.columns) - {'recording_filename', 'child_id'})

    # Merge with the recordings to get date_iso, then drop the duplicated columns
    # introduced by the merge
    metrics_recordings = pd.merge(metrics, project.recordings, on='recording_filename', suffixes=('', '_drop'))
    metrics_recordings.drop([col for col in metrics_recordings.columns if 'drop' in col], axis=1, inplace=True)

    # Several files may share the same child_id and date_iso: keep only one row
    # per (child_id, date_iso) pair, namely the recording with the latest start_time
    metrics_recordings = (metrics_recordings.groupby(['child_id', 'date_iso'], as_index=False)
                          .apply(lambda rows: (rows.sort_values(by='start_time', ascending=False)
                                               .head(n=1))))

    return metrics_recordings, metrics_columns
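

# The message definition YAML (metrics_messages.yaml) appears to be structured as
# follows (inferred from fill_template() and build_messages() below):
#   - one key per metric of interest, mapping the direction of change (True = the
#     metric went up, False = it went down) to a sentence fragment;
#   - '_<dir>_<dir>' keys (e.g. '_True_False') holding the message templates, with
#     '#1', '#2', ... placeholders for the fragments above;
#   - other '_<name>' keys acting as variables substituted for '#<name>';
#   - a '_default' message used when only one recording is available.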
def fill_template(template_key, messages, metrics_evolution):
    # Pick the template matching the direction taken by each metric (e.g. '_True_False') ...
    template = messages['_{}_{}'.format(*template_key)]

    # ... and replace the '#1', '#2', ... placeholders with the sentence fragment
    # corresponding to each metric's direction of change
    for positivity_item_index, (positivity_item, _, positivity_direction) in enumerate(metrics_evolution, 1):
        template = template.replace('#{}'.format(positivity_item_index),
                                    messages[positivity_item][positivity_direction])

    # Finally, replace the '#<name>' variables with the value of the '_<name>' keys
    message_variables = [msg for msg in messages if msg.startswith('_')]
    for message_variable in message_variables:
        message_variable = message_variable[1:]
        template = template.replace('#{}'.format(message_variable),
                                    messages['_{}'.format(message_variable)])

    return template


def build_messages(metrics_recordings, metrics_columns, message_file_path, date):
    try:
        date = datetime.strptime(date, "%Y%m%d").strftime("%Y-%m-%d")
    except ValueError:
        raise ValueError('--date format should be YYYYMMDD without any separators.')

    # Get the metrics of interest and the message definitions
    metric_messages = _read_yaml(message_file_path)
    metrics_of_interest = [item for item in metric_messages.keys() if not item.startswith('_')]

    # Keep only the rows whose date is less than or equal to the requested one
    metrics_recordings = metrics_recordings[metrics_recordings['date_iso'] <= date]

    # Generate messages
    output_messages = []
    metrics_grouped = metrics_recordings.groupby('child_id', as_index=False)
    for _, metrics_grouped_item in metrics_grouped:
        sorted_metrics_grouped_items = metrics_grouped_item.sort_values(by=['date_iso', 'imported_at'],
                                                                        ascending=False)

        # If the most recent row is not for the requested date, skip this family:
        # its latest recording is too old, so no message was/will be generated for it
        if sorted_metrics_grouped_items.iloc[0]['date_iso'] != date:
            continue

        # Only one audio file (first week): use the default message
        if len(metrics_grouped_item) == 1:
            recording_filename = metrics_grouped_item.iloc[0]['recording_filename']
            message = metric_messages['_default']
        # More than one audio file: compare the two most recent recordings
        else:
            todays_row = sorted_metrics_grouped_items.iloc[0]
            previous_row = sorted_metrics_grouped_items.iloc[1]

            # Compute the difference between the two sets of metrics
            diff_metrics = (todays_row[metrics_columns] - previous_row[metrics_columns])[metrics_of_interest]
            diff_metrics = diff_metrics.to_dict()
            metrics_evolution = [(metric, diff_metrics[metric], diff_metrics[metric] > 0)
                                 for metric in metrics_of_interest]

            # Sort the metrics by magnitude of change, then by direction, and select
            # the template matching the resulting combination of directions
            metrics_evolution = sorted(metrics_evolution, key=lambda tup: (abs(tup[1]), tup[2]))
            template_key = [tpl_key for (_, _, tpl_key) in metrics_evolution]

            # Attach the message to today's recording
            recording_filename = todays_row['recording_filename']
            message = fill_template(template_key, metric_messages, metrics_evolution)

        output_messages.append({'recording_filename': recording_filename,
                                'message': message})

    df_out = pd.DataFrame(output_messages)
    return df_out


def generate_messages(project_path, metrics_file, message_definition, date):
    message_out_path = os.path.join(project_path, 'extra', 'messages', 'generated', 'messages_{}.csv'.format(date))
    message_out_dir = os.path.dirname(message_out_path)
    if not os.path.exists(message_out_dir):
        os.makedirs(message_out_dir)

    # Compute the metrics and build the messages
    metrics_recordings, metrics_columns = get_metrics(project_path, metrics_file)
    messages = build_messages(metrics_recordings, metrics_columns, message_definition, date)

    # Write the messages out, unless a file already exists for this date
    if not os.path.exists(message_out_path):
        if len(messages):
            messages.to_csv(message_out_path, index=False, quoting=csv.QUOTE_NONNUMERIC)
            logger.info('{} messages generated.'.format(len(messages)))
        else:
            logger.warning('No message needs to be generated for date {}.'.format(date))
    else:
        raise IOError('File {} already exists!'.format(message_out_path))


def main(**kwargs):
    project_path = os.path.abspath('.')
    expected_metrics_file = os.path.join(project_path, 'extra', 'metrics', 'metrics.csv')
    expected_message_definition = os.path.join(project_path, 'extra', 'messages', 'definition', 'metrics_messages.yaml')

    # Make sure we have all the files we need
    if not (os.path.exists(expected_metrics_file) and os.path.exists(expected_message_definition)):
        raise ValueError('Expected metrics ({}) and/or message definition file ({}) not found. Are you sure you are '
                         'running this command from the root of the data set?'.format(expected_metrics_file,
                                                                                      expected_message_definition))

    generate_messages(project_path=project_path, metrics_file=expected_metrics_file,
                      message_definition=expected_message_definition, **kwargs)


def _parse_args(argv):
    import argparse

    parser = argparse.ArgumentParser(description='Generate feedback messages from the latest metrics.')
    parser.add_argument('--date', type=str, default=datetime.now().strftime("%Y%m%d"),
                        help='Date (YYYYMMDD) for which to generate messages.')
    args = parser.parse_args(argv)
    return vars(args)


if __name__ == '__main__':
    pgrm_name, argv = sys.argv[0], sys.argv[1:]
    args = _parse_args(argv)
    logging.basicConfig(level=logging.INFO)

    try:
        main(**args)
        sys.exit(0)
    except Exception as e:
        logger.exception(e)
        sys.exit(1)