Browse Source

compute precisions

Loann Peurey 1 year ago
parent
commit
6aeaafa0d0
2 changed files with 12 additions and 28 deletions
  1. 1 0
      extra/reliability/precisions.csv
  2. 11 28
      scripts/confusion_matrix.py

+ 1 - 0
extra/reliability/precisions.csv

@@ -0,0 +1 @@
+../../.git/annex/objects/zK/88/MD5E-s462--e662c1bd60a293a206e2d5579f2db67e.csv/MD5E-s462--e662c1bd60a293a206e2d5579f2db67e.csv

+ 11 - 28
scripts/confusion_matrix.py

@@ -12,7 +12,7 @@ We allow annotations to be compared between different recordings only because al
 """
 speakers = ['CHI', 'OCH', 'FEM', 'MAL']
 speakers_plot = speakers
-speakers_plot.append("None")
+#speakers_plot.append("None")
 speakers_plot = [""] + speakers_plot
 
 project = ChildProject('.')
@@ -24,26 +24,7 @@ anns = pd.DataFrame.copy(am.annotations)
 human_ann = anns[anns['set'] == 'eaf_cha']
 vtc_ann = anns[anns['set'] == 'vtc']
 
-def initiate_plt():
-    fig, ax = plt.subplots(figsize=(7.5, 7.5))
-    ax.set_xticklabels(speakers)  
-    ax.set_yticklabels(speakers)
-    ax.xaxis.set_label_position("top")
-    plt.xlabel('HUMAN', fontsize=18)
-    plt.ylabel('VTC', fontsize=18)
-    plt.title('Confusion Matrix', fontsize=18)
-    return ax
-    
-#ax = initiate_plt()
-
-def build_matrix(array, output):
-    fig = plt.figure(figsize=(12, 6.75))
-    
-    ax.matshow(array, cmap=plt.cm.Blues, alpha=0.3)
-    for i in range(array.shape[0]):
-        for j in range(array.shape[1]):
-            ax.text(x=j, y=i,s=round(array[i, j],3), va='center', ha='center', size='xx-large')
-    plt.savefig(output)
+precisions = pd.DataFrame()
 
 for i, row in vtc_ann.iterrows():
     audio_rec = row['recording_filename']
@@ -62,12 +43,11 @@ for i, row in vtc_ann.iterrows():
     eaf_cha = segments_to_grid(segments[segments['set'] == 'eaf_cha'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
 
     confusion_counts = conf_matrix(vtc, eaf_cha)
-
-    #build_matrix(confusion_counts, 'extra/reliability/{}.png'.format(audio_rec),ax)
-
-    #normalized = confusion_counts/(np.sum(vtc, axis = 0)[:,None])
-
-    #build_matrix(normalized,'extra/reliability/{}_normalized.png'.format(audio_rec),ax)
+    
+    total = np.sum(confusion_counts)
+    precision = np.trace(confusion_counts) / total if total != 0 else 0
+    
+    precisions = pd.concat([precisions, pd.DataFrame(data = {'recording_filename': [audio_rec],'precision': [precision]})])
     
     normalized = confusion_counts
     
@@ -101,4 +81,7 @@ for i, row in vtc_ann.iterrows():
     plt.title('Confusion Matrix', fontsize=18)
     plt.savefig('extra/reliability/{}_normalized.png'.format(audio_rec))
     
-    plt.close()
+    plt.close()
+
+print(precisions)
+precisions.to_csv('extra/reliability/precisions.csv',index=False)