recall.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #!/usr/bin/env python3
  2. from ChildProject.projects import ChildProject
  3. from ChildProject.annotations import AnnotationManager
  4. from ChildProject.metrics import segments_to_annotation
  5. from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import os
  9. import pandas as pd
  10. import random
  11. import sys
  12. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  13. sets = ['its', 'vtc (conf 50%)', 'vtc (drop 50%)', 'vtc (conf 75%)', 'vtc (drop 75%)']
  14. def confusion(segments, prob):
  15. segments['speaker_type'] = segments['speaker_type'].apply(
  16. lambda s: random.choice(speakers) if random.random() < prob else s
  17. )
  18. return segments
  19. def drop(segments, prob):
  20. return segments.sample(frac = 1-prob)
  21. if not os.path.exists('scores.csv'):
  22. path = sys.argv[1]
  23. project = ChildProject(path)
  24. am = AnnotationManager(project)
  25. am.read()
  26. intersection = AnnotationManager.intersection(am.annotations, ['vtc', 'its'])
  27. segments = am.get_collapsed_segments(intersection)
  28. segments = segments[segments['speaker_type'].isin(speakers)]
  29. segments.sort_values(['segment_onset', 'segment_offset']).to_csv('test.csv', index = False)
  30. conf50 = segments[segments['set'] == 'vtc'].copy()
  31. conf50 = confusion(conf50, 0.5)
  32. conf50['set'] = 'vtc (conf 50%)'
  33. conf75 = segments[segments['set'] == 'vtc'].copy()
  34. conf75 = confusion(conf75, 0.75)
  35. conf75['set'] = 'vtc (conf 75%)'
  36. drop50 = segments[segments['set'] == 'vtc'].copy()
  37. drop50 = drop(drop50, 0.5)
  38. drop50['set'] = 'vtc (drop 50%)'
  39. drop75 = segments[segments['set'] == 'vtc'].copy()
  40. drop75 = drop(drop75, 0.75)
  41. drop75['set'] = 'vtc (drop 75%)'
  42. segments = pd.concat([segments, conf50, conf75, drop50, drop75])
  43. metric = DetectionPrecisionRecallFMeasure()
  44. scores = []
  45. for speaker in speakers:
  46. ref = segments_to_annotation(segments[(segments['set'] == 'vtc') & (segments['speaker_type'] == speaker)], 'speaker_type')
  47. for s in sets:
  48. hyp = segments_to_annotation(segments[(segments['set'] == s) & (segments['speaker_type'] == speaker)], 'speaker_type')
  49. detail = metric.compute_components(ref, hyp)
  50. precision, recall, f = metric.compute_metrics(detail)
  51. scores.append({
  52. 'set': s,
  53. 'speaker': speaker,
  54. 'recall': recall,
  55. 'precision': precision,
  56. 'f': f
  57. })
  58. scores = pd.DataFrame(scores)
  59. scores.to_csv('scores.csv', index = False)
  60. scores = pd.read_csv('scores.csv')
  61. plt.rcParams.update({'font.size': 12})
  62. plt.rc('xtick', labelsize = 10)
  63. plt.rc('ytick', labelsize = 10)
  64. print(scores)
  65. styles = {
  66. 'recall': 's',
  67. 'precision': 'D',
  68. 'f': 'o'
  69. }
  70. labels = {
  71. 'recall': 'recall',
  72. 'precision': 'precision',
  73. 'f': 'F-measure'
  74. }
  75. plt.figure(figsize = (6.4*1, 4.8*1+0.25*4.8))
  76. for speaker in speakers:
  77. i = speakers.index(speaker)
  78. ax = plt.subplot(2, 2, i+1)
  79. ax.set_xlim(-0.5,len(sets)-0.5)
  80. ax.set_ylim(0, 1)
  81. if i >= 2:
  82. ax.set_xticks(range(len(sets)))
  83. ax.set_xticklabels(sets, rotation = 45, horizontalalignment = 'right')
  84. else:
  85. ax.set_xticklabels(['' for i in range(len(sets))])
  86. if i%2 == 1:
  87. ax.set_yticklabels(['' for i in range(6)])
  88. ax.set_xlabel(speaker)
  89. _scores = scores[scores['speaker'] == speaker]
  90. for metric in ['recall', 'precision', 'f']:
  91. ax.scatter(
  92. x = _scores['set'].apply(lambda s: sets.index(s)),
  93. y = _scores[metric],
  94. label = labels[metric],
  95. s = 15,
  96. marker = styles[metric]
  97. )
  98. ax = plt.subplot(2, 2, 2)
  99. ax.legend(loc = "upper right", borderaxespad = 0.1, bbox_to_anchor=(1, 1.25), ncol = 3)
  100. plt.subplots_adjust(wspace = 0.15)
  101. plt.savefig('Fig4.pdf', bbox_inches = 'tight')