# confusion_matrix2.py (3.4 KB)
  1. from ChildProject.projects import ChildProject
  2. from ChildProject.annotations import AnnotationManager
  3. from ChildProject.metrics import segments_to_grid, conf_matrix
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. speakers = ['CHI', 'FEM', 'OCH'] #PUT HERE THE LABELS YOU WANT TO INCLUDE
  7. project = ChildProject('.')
  8. am = AnnotationManager(project)
  9. am.read()
  10. SET_1 = 'eaf_2023/ak' #CHANGE THE FOLDER TO WHERE THE MANUAL ANNOTATIONS ARE
  11. SET_2 = 'vtc' #CHANGE THE FOLDER TO WHERE VTC GENERATED ANNOTATIONS ARE
  12. recording = ['77021_5/V20230127-070014.WAV']
  13. intersection = AnnotationManager.intersection(am.annotations, [SET_1, SET_2])
  14. intersection = intersection[intersection['recording_filename'].isin(recording)]
  15. print(intersection)
  16. segments = am.get_collapsed_segments(intersection)
  17. segments = segments[segments['speaker_type'].isin(speakers)]
  18. # Y
  19. #vtc = segments_to_grid(segments[segments['set'] == 'vtc'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers,none = False)
  20. vtc = segments_to_grid(segments[segments['set'] == SET_1], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  21. # X
  22. #its = segments_to_grid(segments[segments['set'] == 'its'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers,none = False)
  23. its = segments_to_grid(segments[segments['set'] == SET_2], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  24. confusion_counts = conf_matrix(vtc, its)
  25. all_positive = np.delete(confusion_counts, -1, 0)
  26. all_negative = np.delete(confusion_counts, -1, 1)
  27. precision = np.delete(all_negative, -1, 0).trace() / all_positive.sum()
  28. recall = np.delete(all_negative, -1, 0).trace() / all_negative.sum()
  29. fscore = (2 * precision * recall) / (precision + recall)
  30. scores = {}
  31. i=0
  32. with open('scores.txt','w') as f:
  33. for label in speakers:
  34. rec = confusion_counts[i,i] / confusion_counts[ :,i].sum()
  35. preci = confusion_counts[i,i] / confusion_counts[i,: ].sum()
  36. fsc = (2 * preci * rec) / (preci + rec)
  37. #scores[label] = (preci, rec, fsc)
  38. f.write(f"{label}: precision {preci}; recall {rec}; F-score {fsc}\n")
  39. i+=1
  40. f.write(f"General: precision {precision}; recall {recall}; F-score {fscore}\n")
  41. #print(f"General: precision {precision}; recall {recall}; F-score {fscore}")
  42. print(f"Results written to scores.txt")
  43. normalized = confusion_counts
  44. speakers.append("None")
  45. speakers = [""] + speakers
  46. fig, ax = plt.subplots(figsize=(7.5, 7.5))
  47. ax.set_xticklabels(speakers)
  48. ax.set_yticklabels(speakers)
  49. ax.matshow(normalized, cmap=plt.cm.Blues, alpha=0.3)
  50. for i in range(normalized.shape[0]):
  51. for j in range(normalized.shape[1]):
  52. ax.text(x=j, y=i,s=round(normalized[i, j],3), va='center', ha='center', size='xx-large')
  53. ax.xaxis.set_label_position("top")
  54. # set Y and X
  55. plt.ylabel(SET_1, fontsize=18)
  56. plt.xlabel(SET_2, fontsize=18)
  57. plt.title('Confusion Matrix', fontsize=18)
  58. plt.savefig('conf_matrix.png')
  59. normalized = confusion_counts/(np.sum(vtc, axis = 0)[:,None])
  60. fig, ax = plt.subplots(figsize=(7.5, 7.5))
  61. ax.set_xticklabels(speakers)
  62. ax.set_yticklabels(speakers)
  63. ax.matshow(normalized, cmap=plt.cm.Blues, alpha=0.3)
  64. for i in range(normalized.shape[0]):
  65. for j in range(normalized.shape[1]):
  66. ax.text(x=j, y=i,s=round(normalized[i, j],3), va='center', ha='center', size='xx-large')
  67. ax.xaxis.set_label_position("top")
  68. plt.ylabel(SET_1, fontsize=18)
  69. plt.xlabel(SET_2, fontsize=18)
  70. plt.title('Confusion Matrix', fontsize=18)
  71. plt.savefig('conf_matrix_normalized.png')