|
@@ -0,0 +1,74 @@
|
|
|
+import argparse
|
|
|
+import numpy as np
|
|
|
+import os
|
|
|
+import pandas as pd
|
|
|
+
|
|
|
+import seaborn as sns
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+
|
|
|
+from ChildProject.projects import ChildProject
|
|
|
+from ChildProject.annotations import AnnotationManager
|
|
|
+from ChildProject.metrics import segments_to_grid, conf_matrix
|
|
|
+
|
|
|
# Speaker classes shared by both annotation sets once their labels have been
# harmonised further down.
categories = ['Adult', 'Youngster', 'Junk']

# load VanDam dataset
project = ChildProject('vandam-data')
am = AnnotationManager(project)

# Restrict the annotation index to the two sets being compared, then fetch
# their segments through get_collapsed_segments.
compared_sets = ['vtc', 'zoo']
in_compared_sets = am.annotations['set'].isin(compared_sets)
annotations = am.annotations[in_compared_sets]
segments = am.get_collapsed_segments(annotations)
|
|
|
+
|
|
|
# Segments belonging to the 'vtc' set. The explicit copy makes the column
# assignment below write to an independent frame instead of a selection of
# `segments`, which is what triggers pandas' SettingWithCopyWarning.
vtc_segments = segments.loc[segments['set'] == 'vtc'].copy()

# Map the speaker_type labels onto the shared categories. Labels missing from
# the mapping are left unchanged by replace().
vtc_segments['speaker_age'] = vtc_segments['speaker_type'].replace({
    'MAL': 'Adult',
    'FEM': 'Adult',
    'CHI': 'Youngster',
    'OCH': 'Youngster',
})
|
|
|
+
|
|
|
# Segments belonging to the 'zoo' set. The explicit copy makes the column
# assignment below write to an independent frame instead of a selection of
# `segments`, which is what triggers pandas' SettingWithCopyWarning.
zoo_segments = segments.loc[segments['set'] == 'zoo'].copy()

# Fold the finer-grained speaker_age labels into the shared categories.
# Labels missing from the mapping (e.g. ones already named 'Adult' or 'Junk')
# are left unchanged by replace().
zoo_segments['speaker_age'] = zoo_segments['speaker_age'].replace({
    'Baby': 'Youngster',
    'Child': 'Youngster',
    'Adolescent': 'Adult',
})
|
|
|
+
|
|
|
# Both grids are built over the same range (0 up to the largest segment offset
# found in either set), with the same step of 50, column and categories, so
# that their rows line up one-to-one.
# NOTE(review): the step is presumably in milliseconds - confirm against
# segments_to_grid.
grid_end = segments['segment_offset'].max()
grid_args = (0, grid_end, 50, 'speaker_age', categories)

vtc = segments_to_grid(vtc_segments, *grid_args, none=True)
zoo = segments_to_grid(zoo_segments, *grid_args, none=True)

# Keep only the rows whose last zoo column equals 0, then drop that last
# column from both grids. The mask is taken from zoo before zoo itself is
# filtered, so both grids are cut down with the same selection.
# NOTE(review): the last column is presumably the "no category active" flag
# added by none=True - confirm against segments_to_grid.
zoo_rows_kept = zoo[:, -1] == 0
vtc = vtc[zoo_rows_kept][:, :-1]
zoo = zoo[zoo_rows_kept][:, :-1]

confusion_counts = conf_matrix(vtc, zoo)
print(confusion_counts)
|
|
|
+
|
|
|
# Global matplotlib text sizes for the figure.
plt.rcParams.update({'font.size': 12})
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

# Two side-by-side panels, each the default 6.4 x 4.8 size.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.4*2, 4.8))

# Left panel: each row of the count matrix is divided by the matching column
# total of the vtc grid ([:, None] turns the totals into a column vector so
# the division is applied row by row).
confusion = confusion_counts/np.sum(vtc, axis=0)[:, None]

sns.heatmap(confusion, annot=True, fmt='.2f', ax=axes[0], cmap='Reds')
axes[0].set_xlabel('zoo')
axes[0].set_ylabel('vtc')
axes[0].xaxis.set_ticklabels(categories)
axes[0].yaxis.set_ticklabels(categories)

# Right panel: same counts with the axes swapped, this time normalised by the
# column totals of the zoo grid.
confusion_counts = np.transpose(confusion_counts)
confusion = confusion_counts/np.sum(zoo, axis=0)[:, None]

sns.heatmap(confusion, annot=True, fmt='.2f', ax=axes[1], cmap='Reds')
axes[1].set_xlabel('vtc')
axes[1].set_ylabel('zoo')
axes[1].xaxis.set_ticklabels(categories)
axes[1].yaxis.set_ticklabels(categories)

# savefig does not create missing directories, so make sure the output folder
# exists before writing the figure.
os.makedirs('annotations', exist_ok=True)
plt.savefig('annotations/comparison.png', bbox_inches='tight')
|