compare.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import seaborn as sns
  4. from ChildProject.projects import ChildProject
  5. from ChildProject.annotations import AnnotationManager
  6. from ChildProject.metrics import segments_to_grid, conf_matrix
  7. categories = ['Adult', 'Youngster', 'Junk']
  8. # load VanDam dataset
  9. project = ChildProject('vandam-data')
  10. am = AnnotationManager(project)
  11. annotations = am.annotations[am.annotations['set'].isin(['vtc', 'zoo'])]
  12. segments = am.get_collapsed_segments(annotations)
  13. # map VTC speakers to a two-class categorization
  14. vtc_segments = segments.loc[segments['set'] == 'vtc']
  15. vtc_segments['speaker_age'] = vtc_segments['speaker_type'].replace({
  16. 'MAL': 'Adult',
  17. 'FEM': 'Adult',
  18. 'CHI': 'Youngster',
  19. 'OCH': 'Youngster'
  20. })
  21. # map Zooniverse-extracted speakers to the same two-class categorization
  22. zoo_segments = segments.loc[segments['set'] == 'zoo']
  23. zoo_segments['speaker_age'] = zoo_segments['speaker_age'].replace({
  24. 'Baby': 'Youngster',
  25. 'Child': 'Youngster',
  26. 'Adolescent': 'Adult'
  27. })
  28. # create matrices indicating if any of the classes is active at every 50 ms time step (unitizing)
  29. vtc = segments_to_grid(
  30. vtc_segments, 0, segments['segment_offset'].max(), 50, 'speaker_age', categories, none = True
  31. )
  32. zoo = segments_to_grid(
  33. zoo_segments, 0, segments['segment_offset'].max(), 50, 'speaker_age', categories, none = True
  34. )
  35. # keep only the units that have been classified on Zooniverse
  36. vtc = vtc[zoo[:,-1] == 0][:,:-1]
  37. zoo = zoo[zoo[:,-1] == 0][:,:-1]
  38. # compute and show confusion matrices
  39. confusion_counts = conf_matrix(vtc, zoo)
  40. plt.rcParams.update({'font.size': 12})
  41. plt.rc('xtick', labelsize = 10)
  42. plt.rc('ytick', labelsize = 10)
  43. fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
  44. confusion = confusion_counts/np.sum(vtc, axis = 0)[:,None]
  45. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
  46. axes[0].set_xlabel('zoo')
  47. axes[0].set_ylabel('vtc')
  48. axes[0].xaxis.set_ticklabels(categories)
  49. axes[0].yaxis.set_ticklabels(categories)
  50. confusion_counts = np.transpose(confusion_counts)
  51. confusion = confusion_counts/np.sum(zoo, axis = 0)[:,None]
  52. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
  53. axes[1].set_xlabel('vtc')
  54. axes[1].set_ylabel('zoo')
  55. axes[1].xaxis.set_ticklabels(categories)
  56. axes[1].yaxis.set_ticklabels(categories)
  57. plt.savefig('annotations/comparison.png', bbox_inches = 'tight')