#!/usr/bin/env python3
"""confusion_matrix.py -- compare 'vtc' and 'eaf' speaker annotations.

Usage::

    confusion_matrix.py <dataset-path>

Loads the ChildProject dataset found at ``<dataset-path>``, keeps the
portions annotated by both the ``vtc`` and ``eaf`` sets, and writes
``Fig7.pdf`` in the current directory.  The figure holds two heatmaps of the
same confusion counts: on the left with 'vtc' categories as rows (normalized
by the 'vtc' totals), on the right with 'eaf' categories as rows (normalized
by the 'eaf' totals).
"""
import sys

import matplotlib

# Select the backend up front, before pyplot is imported anywhere below.
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    "font.serif": "Times New Roman",
    'text.usetex': True,
    'pgf.rcfonts': False,
})

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, conf_matrix
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
import seaborn as sns
import matplotlib.pyplot as plt

# Speaker categories kept for the comparison; any other speaker_type is dropped.
speakers = ['CHI', 'OCH', 'FEM', 'MAL']


def _plot_normalized_confusion(counts, reference_grid, ax, row_name, column_name, labels):
    """Draw one heatmap of ``counts`` normalized by the reference grid totals.

    Args:
        counts: square array of confusion counts whose rows correspond to the
            categories of ``reference_grid``.
        reference_grid: grid returned by ``segments_to_grid`` for the
            annotation set shown on the rows.
        ax: matplotlib axes to draw on.
        row_name: y-axis label (annotation set shown on the rows).
        column_name: x-axis label (annotation set shown on the columns).
        labels: tick labels, one per category, used on both axes.
    """
    # Row i is divided by the summed values of category i in the reference
    # grid.
    # NOTE(review): a category with a zero total divides by zero here and
    # will show up as NaN/inf cells -- confirm this is acceptable.
    totals = np.sum(reference_grid, axis=0)[:, None]
    confusion = counts / totals

    sns.heatmap(confusion, annot=True, fmt='.2f', ax=ax, cmap='Reds')
    ax.set_xlabel(column_name)
    ax.set_ylabel(row_name)
    ax.xaxis.set_ticklabels(labels)
    ax.yaxis.set_ticklabels(labels)


def main(path):
    """Build the vtc/eaf confusion figure for the dataset at ``path``.

    Args:
        path: location of the ChildProject dataset to read.

    Side effects:
        Writes ``Fig7.pdf`` to the current working directory.
    """
    project = ChildProject(path)
    am = AnnotationManager(project)
    am.read()

    # Restrict both sets to the portions of audio that each of them covers.
    intersection = AnnotationManager.intersection(am.annotations, ['vtc', 'eaf'])
    segments = am.get_collapsed_segments(intersection)
    segments = segments[segments['speaker_type'].isin(speakers)]

    # Both grids share the same range so that their units line up.
    # NOTE(review): 100 is presumably the grid step in milliseconds --
    # confirm against the segments_to_grid documentation.
    range_end = segments['segment_offset'].max()
    vtc = segments_to_grid(
        segments[segments['set'] == 'vtc'], 0, range_end, 100, 'speaker_type', speakers
    )
    eaf = segments_to_grid(
        segments[segments['set'] == 'eaf'], 0, range_end, 100, 'speaker_type', speakers
    )

    # Tick labels get one extra 'none' entry after the speaker categories
    # (presumably segments_to_grid appends a column for units without any
    # speaker -- verify).  A new list is built so `speakers` is left intact.
    labels = speakers + ['none']

    confusion_counts = conf_matrix(vtc, eaf)

    plt.rcParams.update({'font.size': 12})
    plt.rc('xtick', labelsize=10)
    plt.rc('ytick', labelsize=10)

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.4*2, 4.8))

    # Left panel: rows are 'vtc' categories, normalized by the 'vtc' totals.
    _plot_normalized_confusion(confusion_counts, vtc, axes[0], 'vtc', 'eaf', labels)

    # Right panel: transposed counts so rows are 'eaf' categories,
    # normalized by the 'eaf' totals.
    _plot_normalized_confusion(
        np.transpose(confusion_counts), eaf, axes[1], 'eaf', 'vtc', labels
    )

    plt.savefig('Fig7.pdf', bbox_inches='tight')


if __name__ == '__main__':
    # Fail with a usage message instead of a bare IndexError when the
    # dataset path is missing.  Extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit('usage: {} <dataset-path>'.format(sys.argv[0]))
    main(sys.argv[1])