recall.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from ChildProject.projects import ChildProject
  2. from ChildProject.annotations import AnnotationManager
  3. from ChildProject.metrics import segments_to_annotation
  4. from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
  5. import numpy as np
  6. import pandas as pd
  7. from sklearn.metrics import confusion_matrix
  8. from sklearn.preprocessing import normalize
  9. import random
  10. import seaborn as sns
  11. import matplotlib.pyplot as plt
  12. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  13. sets = ['its', 'vtc (conf 50%)', 'vtc (drop 50%)', 'vtc (conf 75%)', 'vtc (drop 75%)']
  14. project = ChildProject('.')
  15. am = AnnotationManager(project)
  16. am.read()
  17. def confusion(segments, prob):
  18. segments['speaker_type'] = segments['speaker_type'].apply(
  19. lambda s: random.choice(speakers) if random.random() < prob else s
  20. )
  21. return segments
  22. def drop(segments, prob):
  23. return segments.sample(frac = 1-prob)
  24. intersection = AnnotationManager.intersection(am.annotations, ['vtc', 'its'])
  25. segments = am.get_collapsed_segments(intersection)
  26. segments = segments[segments['speaker_type'].isin(speakers)]
  27. segments.sort_values(['segment_onset', 'segment_offset']).to_csv('test.csv', index = False)
  28. conf50 = segments[segments['set'] == 'vtc'].copy()
  29. conf50 = confusion(conf50, 0.5)
  30. conf50['set'] = 'vtc (conf 50%)'
  31. conf75 = segments[segments['set'] == 'vtc'].copy()
  32. conf75 = confusion(conf75, 0.75)
  33. conf75['set'] = 'vtc (conf 75%)'
  34. drop50 = segments[segments['set'] == 'vtc'].copy()
  35. drop50 = drop(drop50, 0.5)
  36. drop50['set'] = 'vtc (drop 50%)'
  37. drop75 = segments[segments['set'] == 'vtc'].copy()
  38. drop75 = drop(drop75, 0.75)
  39. drop75['set'] = 'vtc (drop 75%)'
  40. segments = pd.concat([segments, conf50, conf75, drop50, drop75])
  41. metric = DetectionPrecisionRecallFMeasure()
  42. scores = []
  43. for speaker in speakers:
  44. ref = segments_to_annotation(segments[(segments['set'] == 'vtc') & (segments['speaker_type'] == speaker)], 'speaker_type')
  45. for s in sets:
  46. hyp = segments_to_annotation(segments[(segments['set'] == s) & (segments['speaker_type'] == speaker)], 'speaker_type')
  47. detail = metric.compute_components(ref, hyp)
  48. precision, recall, f = metric.compute_metrics(detail)
  49. scores.append({
  50. 'set': s,
  51. 'speaker': speaker,
  52. 'recall': recall,
  53. 'precision': precision,
  54. 'f': f
  55. })
  56. scores = pd.DataFrame(scores)
  57. scores.to_csv('scores.csv', index = False)
  58. plt.rcParams.update({'font.size': 12})
  59. plt.rc('xtick', labelsize = 10)
  60. plt.rc('ytick', labelsize = 10)
  61. fig, axes = plt.subplots(nrows = 2, ncols = 2, figsize=(6.4*2, 4.8*2))
  62. plt.savefig('Fig4.pdf', bbox_inches = 'tight')