confusion_matrix.py (3.2 KB)
  1. from ChildProject.projects import ChildProject
  2. from ChildProject.annotations import AnnotationManager
  3. from ChildProject.metrics import segments_to_grid, conf_matrix
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  7. project = ChildProject('.')
  8. am = AnnotationManager(project)
  9. am.read()
  10. SET_1 = 'vtc'
  11. SET_2 = 'eaf'
  12. intersection = AnnotationManager.intersection(am.annotations, [SET_1, SET_2])
  13. segments = am.get_collapsed_segments(intersection)
  14. segments = segments[segments['speaker_type'].isin(speakers)]
  15. # Y
  16. #vtc = segments_to_grid(segments[segments['set'] == 'vtc'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers,none = False)
  17. vtc = segments_to_grid(segments[segments['set'] == SET_1], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  18. # X
  19. #its = segments_to_grid(segments[segments['set'] == 'its'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers,none = False)
  20. its = segments_to_grid(segments[segments['set'] == SET_2], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  21. confusion_counts = conf_matrix(vtc, its)
  22. print(confusion_counts)
  23. all_positive = np.delete(confusion_counts, -1, 0)
  24. all_negative = np.delete(confusion_counts, -1, 1)
  25. precision = np.delete(all_negative, -1, 0).trace() / all_positive.sum()
  26. recall = np.delete(all_negative, -1, 0).trace() / all_negative.sum()
  27. fscore = (2 * precision * recall) / (precision + recall)
  28. scores = {}
  29. i=0
  30. with open('extra/scores.txt','w') as f:
  31. for label in speakers:
  32. rec = confusion_counts[i,i] / confusion_counts[ :,i].sum()
  33. preci = confusion_counts[i,i] / confusion_counts[i,: ].sum()
  34. fsc = (2 * preci * rec) / (preci + rec)
  35. #scores[label] = (preci, rec, fsc)
  36. f.write(f"{label}: precision {preci}; recall {rec}; F-score {fsc}\n")
  37. i+=1
  38. f.write(f"General: precision {precision}; recall {recall}; F-score {fscore}\n")
  39. #print(f"General: precision {precision}; recall {recall}; F-score {fscore}")
  40. print(f"Results written to scores.txt")
  41. normalized = confusion_counts
  42. speakers.append("None")
  43. speakers = [""] + speakers
  44. print(normalized)
  45. print(speakers)
  46. fig, ax = plt.subplots(figsize=(7.5, 7.5))
  47. ax.set_xticklabels(speakers)
  48. ax.set_yticklabels(speakers)
  49. ax.matshow(normalized, cmap=plt.cm.Blues, alpha=0.3)
  50. for i in range(normalized.shape[0]):
  51. for j in range(normalized.shape[1]):
  52. ax.text(x=j, y=i,s=round(normalized[i, j],3), va='center', ha='center', size='xx-large')
  53. ax.xaxis.set_label_position("top")
  54. # set Y and X
  55. plt.ylabel(SET_1, fontsize=18)
  56. plt.xlabel(SET_2, fontsize=18)
  57. plt.title('Confusion Matrix', fontsize=18)
  58. plt.savefig('extra/conf_matrix.png')
  59. normalized = confusion_counts/(np.sum(vtc, axis = 0)[:,None])
  60. fig, ax = plt.subplots(figsize=(7.5, 7.5))
  61. ax.set_xticklabels(speakers)
  62. ax.set_yticklabels(speakers)
  63. ax.matshow(normalized, cmap=plt.cm.Blues, alpha=0.3)
  64. for i in range(normalized.shape[0]):
  65. for j in range(normalized.shape[1]):
  66. ax.text(x=j, y=i,s=round(normalized[i, j],3), va='center', ha='center', size='xx-large')
  67. ax.xaxis.set_label_position("top")
  68. plt.ylabel(SET_1, fontsize=18)
  69. plt.xlabel(SET_2, fontsize=18)
  70. plt.title('Confusion Matrix', fontsize=18)
  71. plt.savefig('extra/conf_matrix_normalized.png')