Browse Source

confusion matrix calculation relying on the package more

Lucas Gautheron 3 years ago
parent
commit
e374b340dd
2 changed files with 7 additions and 15 deletions
  1. BIN
      Fig5.pdf
  2. 7 15
      code/confusion_matrix.py

BIN
Fig5.pdf


+ 7 - 15
code/confusion_matrix.py

@@ -2,7 +2,7 @@
 
 from ChildProject.projects import ChildProject
 from ChildProject.annotations import AnnotationManager
-from ChildProject.metrics import gamma, segments_to_grid
+from ChildProject.metrics import segments_to_grid, conf_matrix
 
 import numpy as np
 import pandas as pd
@@ -30,18 +30,7 @@ its = segments_to_grid(segments[segments['set'] == 'its'], 0, segments['segment_
 
 speakers.extend(['overlap', 'none'])
 
-def grid_to_vector(grid):
-    return np.argmax(grid[:,::-1], axis = 1)
-
-def conf_matrix(horizontal, vertical, categories):
-    n = len(categories)-1
-    vertical = np.vectorize(lambda x: categories[n-x])(grid_to_vector(vertical))
-    horizontal = np.vectorize(lambda x: categories[n-x])(grid_to_vector(horizontal))
-
-    confusion = confusion_matrix(vertical, horizontal, labels = categories)
-    confusion = normalize(confusion, axis = 1, norm = 'l1')
-
-    return confusion
+confusion_counts = conf_matrix(its, vtc, speakers)
 
 plt.rcParams.update({'font.size': 12})
 plt.rc('xtick', labelsize = 10)
@@ -49,14 +38,17 @@ plt.rc('ytick', labelsize = 10)
 
 fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
 
-confusion = conf_matrix(its, vtc, speakers)
+confusion = normalize(confusion_counts, axis = 1, norm = 'l1')
+
 sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
 axes[0].set_xlabel('its')
 axes[0].set_ylabel('vtc')
 axes[0].xaxis.set_ticklabels(speakers)
 axes[0].yaxis.set_ticklabels(speakers)
 
-confusion = conf_matrix(vtc, its, speakers)
+confusion_counts = np.transpose(confusion_counts)
+confusion = normalize(confusion_counts, axis = 1, norm = 'l1')
+
 sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
 axes[1].set_xlabel('vtc')
 axes[1].set_ylabel('its')