confusion_matrix.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  #!/usr/bin/env python3
"""Plot row-normalized confusion matrices between the 'vtc' and 'its' annotation sets.

Usage: confusion_matrix.py PROJECT_PATH [OUTPUT_PDF]

The script loads the ChildProject dataset at PROJECT_PATH. It takes the
intersection of its 'vtc' and 'its' annotations and keeps only segments whose
speaker_type is one of SPEAKERS. It then discretizes each set into a grid of
TIMESCALE-sized frames with segments_to_grid and counts frame-level agreements
with conf_matrix.

It saves a two-panel figure to OUTPUT_PDF (default: Fig5.pdf):
  * left panel: rows labelled 'vtc', columns 'its', each row L1-normalized;
  * right panel: the transposed counts (rows 'its', columns 'vtc'), each row
    L1-normalized.
"""
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, conf_matrix
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
import seaborn as sns
import matplotlib.pyplot as plt
import sys

# Speaker classes compared between the two annotation sets.
SPEAKERS = ['CHI', 'OCH', 'FEM', 'MAL']
# Extra categories appended after SPEAKERS in the confusion matrix.
# NOTE(review): assumed to match the order of the extra columns that
# segments_to_grid adds to its grid; confirm against the installed
# ChildProject version.
EXTRA_LABELS = ['overlap', 'none']
# Frame size passed to segments_to_grid, in the same unit as segment
# onsets/offsets (milliseconds under ChildProject's convention; verify).
TIMESCALE = 100
DEFAULT_OUTPUT = 'Fig5.pdf'


def _plot_normalized_confusion(ax, counts, labels, xlabel, ylabel):
    """Draw `counts` on `ax` as an annotated heatmap after L1-normalizing each row.

    Args:
        ax: matplotlib Axes to draw on.
        counts: 2-D array of co-occurrence counts; rows map to the y axis.
        labels: tick labels used for both axes, in matrix order.
        xlabel: axis label for the columns.
        ylabel: axis label for the rows.
    """
    # Row-wise L1 normalization: each cell becomes its share of the row total
    # (sklearn leaves all-zero rows as zeros rather than dividing by zero).
    proportions = normalize(counts, axis=1, norm='l1')
    # Give the labels to seaborn directly, so its tick layout (including any
    # rotation to avoid overlap) is computed from the real strings.
    sns.heatmap(
        proportions,
        annot=True,
        fmt='.2f',
        ax=ax,
        cmap='Reds',
        xticklabels=labels,
        yticklabels=labels,
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)


def main(argv=None):
    """Build and save the vtc/its confusion figure.

    Args:
        argv: command-line arguments excluding the program name. Defaults to
            sys.argv[1:]. argv[0] is the project path; the optional argv[1] is
            the output file (default Fig5.pdf).

    Exits with a message when no project path is given, or when no segment
    with a speaker_type in SPEAKERS is found.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        sys.exit('usage: {} PROJECT_PATH [OUTPUT_PDF]'.format(sys.argv[0]))
    path = argv[0]
    output = argv[1] if len(argv) > 1 else DEFAULT_OUTPUT

    project = ChildProject(path)
    am = AnnotationManager(project)
    am.read()

    intersection = AnnotationManager.intersection(am.annotations, ['vtc', 'its'])
    segments = am.get_collapsed_segments(intersection)
    segments = segments[segments['speaker_type'].isin(SPEAKERS)]
    if segments.empty:
        # Otherwise segment_offset.max() would be NaN and be passed on as the grid end.
        sys.exit('no segments with speaker_type in {} found in {}'.format(SPEAKERS, path))

    # Both grids span the same range so that their frames line up one-to-one.
    range_offset = segments['segment_offset'].max()
    vtc = segments_to_grid(
        segments[segments['set'] == 'vtc'], 0, range_offset, TIMESCALE, 'speaker_type', SPEAKERS
    )
    its = segments_to_grid(
        segments[segments['set'] == 'its'], 0, range_offset, TIMESCALE, 'speaker_type', SPEAKERS
    )

    # Build a new list rather than extending SPEAKERS in place.
    labels = SPEAKERS + EXTRA_LABELS
    confusion_counts = conf_matrix(its, vtc, labels)

    plt.rcParams.update({'font.size': 12})
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.4 * 2, 4.8))

    # Left panel: rows labelled 'vtc', columns 'its'.
    _plot_normalized_confusion(axes[0], confusion_counts, labels, xlabel='its', ylabel='vtc')
    # Right panel: transposed counts, so rows are 'its' and columns are 'vtc'.
    _plot_normalized_confusion(
        axes[1], np.transpose(confusion_counts), labels, xlabel='vtc', ylabel='its'
    )

    fig.savefig(output, bbox_inches='tight')
    plt.close(fig)


if __name__ == '__main__':
    main()