12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
# Compare the speaker labels of the 'vtc' annotation set with those of the
# 'zoo' (Zooniverse) annotation set on the VanDam dataset, and save two
# row-normalized confusion matrices (vtc vs. zoo, then zoo vs. vtc) as a figure.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, conf_matrix

categories = ['Adult', 'Youngster', 'Junk']


def _plot_confusion(ax, matrix, xlabel, ylabel):
    """Draw `matrix` as an annotated heatmap on `ax`, labelled with `categories`."""
    sns.heatmap(matrix, annot=True, fmt='.2f', ax=ax, cmap='Reds')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_ticklabels(categories)
    ax.yaxis.set_ticklabels(categories)


# load VanDam dataset
project = ChildProject('vandam-data')
am = AnnotationManager(project)

# keep only the two annotation sets being compared
annotations = am.annotations[am.annotations['set'].isin(['vtc', 'zoo'])]
segments = am.get_collapsed_segments(annotations)

# map VTC speakers to a two-class categorization
# (.copy() gives an independent frame, so the column assignment below does not
# write into a slice of `segments` and does not raise SettingWithCopyWarning)
vtc_segments = segments.loc[segments['set'] == 'vtc'].copy()
vtc_segments['speaker_age'] = vtc_segments['speaker_type'].replace({
    'MAL': 'Adult',
    'FEM': 'Adult',
    'CHI': 'Youngster',
    'OCH': 'Youngster',
})

# map Zooniverse-extracted speakers to the same two-class categorization
zoo_segments = segments.loc[segments['set'] == 'zoo'].copy()
zoo_segments['speaker_age'] = zoo_segments['speaker_age'].replace({
    'Baby': 'Youngster',
    'Child': 'Youngster',
    'Adolescent': 'Adult',
})

# both grids span the same time range so that their rows line up unit by unit
grid_offset = segments['segment_offset'].max()

# create matrices indicating if any of the classes is active at every 50 ms time step (unitizing)
vtc = segments_to_grid(
    vtc_segments, 0, grid_offset, 50, 'speaker_age', categories, none=True
)
zoo = segments_to_grid(
    zoo_segments, 0, grid_offset, 50, 'speaker_age', categories, none=True
)

# keep only the units that have been classified on Zooniverse
# NOTE(review): this relies on none=True appending a trailing "no class active"
# column to each grid -- confirm against segments_to_grid. That column is used
# as the filter and then dropped from both grids.
classified = zoo[:, -1] == 0
vtc = vtc[classified][:, :-1]
zoo = zoo[classified][:, :-1]

# compute and show confusion matrices
confusion_counts = conf_matrix(vtc, zoo)

plt.rcParams.update({'font.size': 12})
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.4 * 2, 4.8))

# left panel: each row is divided by the number of units in which the
# corresponding vtc category is active
confusion = confusion_counts / np.sum(vtc, axis=0)[:, None]
_plot_confusion(axes[0], confusion, 'zoo', 'vtc')

# right panel: same counts transposed, each row divided by the number of units
# in which the corresponding zoo category is active
confusion_counts = np.transpose(confusion_counts)
confusion = confusion_counts / np.sum(zoo, axis=0)[:, None]
_plot_confusion(axes[1], confusion, 'vtc', 'zoo')

plt.savefig('annotations/comparison.png', bbox_inches='tight')
|