recall.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #!/usr/bin/env python3
  2. from ChildProject.projects import ChildProject
  3. from ChildProject.annotations import AnnotationManager
  4. from ChildProject.metrics import segments_to_annotation
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7. matplotlib.use("pgf")
  8. matplotlib.rcParams.update({
  9. "pgf.texsystem": "pdflatex",
  10. 'font.family': 'serif',
  11. "font.serif" : "Times New Roman",
  12. 'text.usetex': True,
  13. 'pgf.rcfonts': False,
  14. })
  15. import numpy as np
  16. import os
  17. import pandas as pd
  18. import random
  19. import sys
  20. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  21. sets = {
  22. 'vtc': 'VTC',
  23. 'its': 'LENA',
  24. 'cha/aligned': 'chat+mfa'
  25. }
  26. if __name__ == '__main__':
  27. if not os.path.exists('scores.csv'):
  28. from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure
  29. path = sys.argv[1]
  30. project = ChildProject(path)
  31. am = AnnotationManager(project)
  32. am.read()
  33. intersection = AnnotationManager.intersection(am.annotations, ['eaf'] + list(sets.keys()))
  34. segments = am.get_collapsed_segments(intersection)
  35. segments = segments[segments['speaker_type'].isin(speakers)]
  36. metric = DetectionPrecisionRecallFMeasure()
  37. scores = []
  38. for speaker in speakers:
  39. ref = segments_to_annotation(segments[(segments['set'] == 'eaf') & (segments['speaker_type'] == speaker)], 'speaker_type')
  40. for s in sets:
  41. hyp = segments_to_annotation(segments[(segments['set'] == s) & (segments['speaker_type'] == speaker)], 'speaker_type')
  42. detail = metric.compute_components(ref, hyp)
  43. precision, recall, f = metric.compute_metrics(detail)
  44. scores.append({
  45. 'set': s,
  46. 'speaker': speaker,
  47. 'recall': recall,
  48. 'precision': precision,
  49. 'f': f
  50. })
  51. scores = pd.DataFrame(scores)
  52. scores.to_csv('scores.csv', index = False)
  53. scores = pd.read_csv('scores.csv')
  54. plt.rcParams.update({'font.size': 12})
  55. plt.rc('xtick', labelsize = 10)
  56. plt.rc('ytick', labelsize = 10)
  57. print(scores)
  58. styles = {
  59. 'recall': 's',
  60. 'precision': 'D',
  61. 'f': 'o'
  62. }
  63. labels = {
  64. 'recall': 'recall',
  65. 'precision': 'precision',
  66. 'f': 'F-measure'
  67. }
  68. plt.figure(figsize = (6.4*1, 4.8*1+0.25*4.8))
  69. for speaker in speakers:
  70. i = speakers.index(speaker)
  71. ax = plt.subplot(2, 2, i+1)
  72. ax.set_xlim(-0.5,len(sets)-0.5)
  73. ax.set_ylim(0, 1)
  74. if i >= 2:
  75. ax.set_xticks(range(len(sets)))
  76. ax.set_xticklabels(sets.values(), rotation = 45, horizontalalignment = 'right')
  77. else:
  78. ax.set_xticklabels(['' for i in range(len(sets))])
  79. if i%2 == 1:
  80. ax.set_yticklabels(['' for i in range(6)])
  81. ax.set_xlabel(speaker)
  82. _scores = scores[scores['speaker'] == speaker]
  83. for metric in ['recall', 'precision', 'f']:
  84. ax.scatter(
  85. x = _scores['set'].apply(lambda s: list(sets.keys()).index(s)),
  86. y = _scores[metric],
  87. label = labels[metric],
  88. s = 15,
  89. marker = styles[metric]
  90. )
  91. ax = plt.subplot(2, 2, 2)
  92. ax.legend(loc = "upper right", borderaxespad = 0.1, bbox_to_anchor=(1, 1.25), ncol = 3)
  93. plt.subplots_adjust(wspace = 0.15)
  94. plt.savefig('Fig6.pdf', bbox_inches = 'tight')