"""Compare the 'vtc' and 'zoo' annotation sets of the VanDam dataset.

The script loads the project in ``vandam-data``, maps the speaker labels of
both annotation sets onto the shared ``categories`` list, turns the segments
into grids of 50-unit time steps, and saves two row-normalized confusion
matrices (vtc vs. zoo, and zoo vs. vtc) to ``annotations/comparison.png``.
"""
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, conf_matrix

categories = ['Adult', 'Youngster', 'Junk']

# load VanDam dataset
project = ChildProject('vandam-data')
am = AnnotationManager(project)

# restrict to the two annotation sets being compared
annotations = am.annotations[am.annotations['set'].isin(['vtc', 'zoo'])]
segments = am.get_collapsed_segments(annotations)

# map VTC speakers onto the shared categories.
# .copy() gives an independent frame: assigning a column to a bare `.loc`
# slice triggers pandas' SettingWithCopyWarning and is not guaranteed to
# behave the same across pandas versions.
vtc_segments = segments.loc[segments['set'] == 'vtc'].copy()
vtc_segments['speaker_age'] = vtc_segments['speaker_type'].replace({
    'MAL': 'Adult',
    'FEM': 'Adult',
    'CHI': 'Youngster',
    'OCH': 'Youngster',
})

# map Zooniverse-extracted speakers onto the same categories
# (labels not listed in the mapping are left untouched by `replace`)
zoo_segments = segments.loc[segments['set'] == 'zoo'].copy()
zoo_segments['speaker_age'] = zoo_segments['speaker_age'].replace({
    'Baby': 'Youngster',
    'Child': 'Youngster',
    'Adolescent': 'Adult',
})

# create matrices indicating if any of the classes is active at every 50 ms
# time step (unitizing); both grids span the same range so rows line up
grid_end = segments['segment_offset'].max()

vtc = segments_to_grid(
    vtc_segments, 0, grid_end, 50,
    'speaker_age', categories, none=True
)
zoo = segments_to_grid(
    zoo_segments, 0, grid_end, 50,
    'speaker_age', categories, none=True
)

# keep only the units that have been classified on Zooniverse.
# The mask is taken from the unfiltered `zoo` grid (rows whose last column
# is 0) and applied to both grids; the last column is then dropped so that
# only the `categories` columns remain.
classified_on_zoo = zoo[:, -1] == 0
vtc = vtc[classified_on_zoo][:, :-1]
zoo = zoo[classified_on_zoo][:, :-1]

# compute and show confusion matrices
confusion_counts = conf_matrix(vtc, zoo)

plt.rcParams.update({'font.size': 12})
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.4*2, 4.8))

# left panel: each row is divided by the number of units in which the
# corresponding vtc category is active
# NOTE(review): a category with no active unit yields a division by zero
# (NaN cells) — confirm this cannot happen with the data at hand.
confusion = confusion_counts/np.sum(vtc, axis=0)[:, None]

sns.heatmap(confusion, annot=True, fmt='.2f', ax=axes[0], cmap='Reds')
axes[0].set_xlabel('zoo')
axes[0].set_ylabel('vtc')
axes[0].xaxis.set_ticklabels(categories)
axes[0].yaxis.set_ticklabels(categories)

# right panel: same counts transposed, rows normalized by the zoo totals
confusion_counts = np.transpose(confusion_counts)
confusion = confusion_counts/np.sum(zoo, axis=0)[:, None]

sns.heatmap(confusion, annot=True, fmt='.2f', ax=axes[1], cmap='Reds')
axes[1].set_xlabel('vtc')
axes[1].set_ylabel('zoo')
axes[1].xaxis.set_ticklabels(categories)
axes[1].yaxis.set_ticklabels(categories)

plt.savefig('annotations/comparison.png', bbox_inches='tight')