# compare.py
  1. import argparse
  2. import numpy as np
  3. import os
  4. import pandas as pd
  5. import seaborn as sns
  6. import matplotlib.pyplot as plt
  7. from ChildProject.projects import ChildProject
  8. from ChildProject.annotations import AnnotationManager
  9. from ChildProject.metrics import segments_to_grid, conf_matrix
  10. categories = ['Adult', 'Youngster', 'Junk']
  11. # load VanDam dataset
  12. project = ChildProject('vandam-data')
  13. am = AnnotationManager(project)
  14. annotations = am.annotations[am.annotations['set'].isin(['vtc', 'zoo'])]
  15. segments = am.get_collapsed_segments(annotations)
  16. # map VTC speakers to a two-class categorization
  17. vtc_segments = segments.loc[segments['set'] == 'vtc']
  18. vtc_segments['speaker_age'] = vtc_segments['speaker_type'].replace({
  19. 'MAL': 'Adult',
  20. 'FEM': 'Adult',
  21. 'CHI': 'Youngster',
  22. 'OCH': 'Youngster'
  23. })
  24. # map Zooniverse-extracted speakers to the same two-class categorization
  25. zoo_segments = segments.loc[segments['set'] == 'zoo']
  26. zoo_segments['speaker_age'] = zoo_segments['speaker_age'].replace({
  27. 'Baby': 'Youngster',
  28. 'Child': 'Youngster',
  29. 'Adolescent': 'Adult'
  30. })
  31. # create matrices indicating if any of the classes is active at every 50 ms time step (unitizing)
  32. vtc = segments_to_grid(
  33. vtc_segments, 0, segments['segment_offset'].max(), 50, 'speaker_age', categories, none = True
  34. )
  35. zoo = segments_to_grid(
  36. zoo_segments, 0, segments['segment_offset'].max(), 50, 'speaker_age', categories, none = True
  37. )
  38. # keep only the units that have been classified on Zooniverse
  39. vtc = vtc[zoo[:,-1] == 0][:,:-1]
  40. zoo = zoo[zoo[:,-1] == 0][:,:-1]
  41. # compute and show confusion matrices
  42. confusion_counts = conf_matrix(vtc, zoo)
  43. plt.rcParams.update({'font.size': 12})
  44. plt.rc('xtick', labelsize = 10)
  45. plt.rc('ytick', labelsize = 10)
  46. fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
  47. confusion = confusion_counts/np.sum(vtc, axis = 0)[:,None]
  48. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
  49. axes[0].set_xlabel('zoo')
  50. axes[0].set_ylabel('vtc')
  51. axes[0].xaxis.set_ticklabels(categories)
  52. axes[0].yaxis.set_ticklabels(categories)
  53. confusion_counts = np.transpose(confusion_counts)
  54. confusion = confusion_counts/np.sum(zoo, axis = 0)[:,None]
  55. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
  56. axes[1].set_xlabel('vtc')
  57. axes[1].set_ylabel('zoo')
  58. axes[1].xaxis.set_ticklabels(categories)
  59. axes[1].yaxis.set_ticklabels(categories)
  60. plt.savefig('annotations/comparison.png', bbox_inches = 'tight')