# compare.py — compare annotation sets of the VanDam corpus.
  1. from numpy import mod
  2. import pandas as pd
  3. from ChildProject.projects import ChildProject
  4. from ChildProject.annotations import AnnotationManager
  5. from ChildProject.metrics import segments_to_grid, conf_matrix, segments_to_annotation
  6. from pathlib import Path
  7. import seaborn as sns
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. def compare_vandam(set1: str, set2: str) :
  11. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  12. project = ChildProject('inputs/vandam-data')
  13. am = AnnotationManager(project)
  14. am.read()
  15. #get segments that intercept between two annotations
  16. intersection = AnnotationManager.intersection(am.annotations, [set1, set2])
  17. #output directory
  18. dirName = "outputs/compare/" + set1.replace("/","") + "-" + set2.replace("/","")
  19. try:
  20. # Create target Directory
  21. Path(dirName).mkdir(parents= True)
  22. print("Directory " , dirName , " Created ")
  23. except FileExistsError:
  24. print("Directory " , dirName , " already exists")
  25. #opens output file
  26. file= open("{0}/{1}-{2}.txt".format(dirName, set1.replace("/",""), set2.replace("/","")),"a")
  27. for speaker in speakers:
  28. #retrieve contents
  29. segments = am.get_collapsed_segments(intersection)
  30. segments = segments[segments['speaker_type'].isin(pd.Series(speaker))]
  31. ref = segments_to_annotation(segments[segments['set'] == set1], 'speaker_type')
  32. hyp = segments_to_annotation(segments[segments['set'] == set2], 'speaker_type')
  33. if __name__ == '__main__':
  34. #compute metrics
  35. from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
  36. metric = DetectionPrecisionRecallFMeasure()
  37. detail = metric.compute_components(ref, hyp)
  38. precision, recall, f = metric.compute_metrics(detail)
  39. #saves metrics to output file
  40. metric_output = "precision: {0} / recall : {1} / f: {2}\n".format(precision, recall, f)
  41. file.write(speaker + ": " + metric_output)
  42. print("Metrics [precision & recall & f] saved! for {0}".format(speaker))
  43. file.close
  44. segments = am.get_collapsed_segments(intersection)
  45. #generates segments
  46. set1_segm = segments_to_grid(segments[segments['set'] == set1], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  47. set2_segm = segments_to_grid(segments[segments['set'] == set2], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  48. speakers.extend(['none'])
  49. confusion_counts = conf_matrix(set1_segm, set2_segm)
  50. plt.rcParams.update({'font.size': 12})
  51. plt.rc('xtick', labelsize = 10)
  52. plt.rc('ytick', labelsize = 10)
  53. fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
  54. confusion = confusion_counts/np.sum(set1_segm, axis = 0)[:,None]
  55. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
  56. axes[0].set_xlabel(set2)
  57. axes[0].set_ylabel(set1)
  58. axes[0].xaxis.set_ticklabels(speakers)
  59. axes[0].yaxis.set_ticklabels(speakers)
  60. confusion_counts = np.transpose(confusion_counts)
  61. confusion = confusion_counts/np.sum(set2_segm, axis = 0)[:,None]
  62. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
  63. axes[1].set_xlabel(set1)
  64. axes[1].set_ylabel(set2)
  65. axes[1].xaxis.set_ticklabels(speakers)
  66. axes[1].yaxis.set_ticklabels(speakers)
  67. plt.savefig("{0}/{1}-{2}-confusion-matrix.jpg".format(dirName, set1.replace("/",""), set2.replace("/",""), bbox_inches = 'tight'))
  68. #matrix_df = pd.DataFrame(conf_matrix(set1_segm, set2_segm))
  69. # matrix_df.to_csv("{0}/{1}-{2}-confusion-matrix.csv".format(dirName, set1.replace("/",""), set2.replace("/","")), mode = "w", index=False)
  70. print("Confusion matrix saved for {0} and {1}!".format(set1, set2))
  71. compare_vandam('eaf', 'cha')
  72. compare_vandam('eaf', 'cha/aligned')
  73. compare_vandam('cha', 'cha/aligned')