compare.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import os
  2. from ChildProject.projects import ChildProject
  3. from ChildProject.annotations import AnnotationManager
  4. from ChildProject.metrics import segments_to_grid, conf_matrix, segments_to_annotation
  5. from pathlib import Path
  6. def compare_vandam(set1: str, set2: str) :
  7. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  8. project = ChildProject('vandam-data')
  9. am = AnnotationManager(project)
  10. am.read()
  11. #get segments that intercept between two annotations
  12. intersection = AnnotationManager.intersection(am.annotations, [set1, set2])
  13. #retrieve contents
  14. segments = am.get_collapsed_segments(intersection)
  15. segments = segments[segments['speaker_type'].isin(speakers)]
  16. set1_segm = segments_to_grid(segments[segments['set'] == set1], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  17. set2_segm = segments_to_grid(segments[segments['set'] == set2], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  18. print(set1_segm.shape)
  19. print(set2_segm)
  20. ref = segments_to_annotation(segments[segments['set'] == set1], 'speaker_type')
  21. hyp = segments_to_annotation(segments[segments['set'] == set2], 'speaker_type')
  22. if __name__ == '__main__':
  23. from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
  24. metric = DetectionPrecisionRecallFMeasure()
  25. detail = metric.compute_components(ref, hyp)
  26. precision, recall, f = metric.compute_metrics(detail)
  27. dirName = "outputs/compare/" + set1.replace("/","") + "-" + set2.replace("/","")
  28. try:
  29. # Create target Directory
  30. Path(dirName).mkdir(parents= True)
  31. print("Directory " , dirName , " Created ")
  32. except FileExistsError:
  33. print("Directory " , dirName , " already exists")
  34. file= open("{0}/{1}-{2}.txt".format(dirName, set1.replace("/",""), set2.replace("/","")),"w+")
  35. # metric_output = str(f'{precision:.2f}/{recall:.2f}/{f:.2f}')
  36. metric_output = "precision: {0}/recall : {1}/ f: {2}".format(precision, recall, f)
  37. file.write(metric_output)
  38. file.close
  39. print("Metrics [precision & recall & f] saved!")
  40. compare_vandam('cha','eaf')
  41. compare_vandam('cha', 'cha/aligned')