|
@@ -2,7 +2,7 @@
|
|
|
|
|
|
from ChildProject.projects import ChildProject
|
|
|
from ChildProject.annotations import AnnotationManager
|
|
|
-from ChildProject.metrics import gamma, segments_to_grid
|
|
|
+from ChildProject.metrics import segments_to_grid, conf_matrix
|
|
|
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
@@ -30,18 +30,7 @@ its = segments_to_grid(segments[segments['set'] == 'its'], 0, segments['segment_
|
|
|
|
|
|
speakers.extend(['overlap', 'none'])
|
|
|
|
|
|
-def grid_to_vector(grid):
|
|
|
- return np.argmax(grid[:,::-1], axis = 1)
|
|
|
-
|
|
|
-def conf_matrix(horizontal, vertical, categories):
|
|
|
- n = len(categories)-1
|
|
|
- vertical = np.vectorize(lambda x: categories[n-x])(grid_to_vector(vertical))
|
|
|
- horizontal = np.vectorize(lambda x: categories[n-x])(grid_to_vector(horizontal))
|
|
|
-
|
|
|
- confusion = confusion_matrix(vertical, horizontal, labels = categories)
|
|
|
- confusion = normalize(confusion, axis = 1, norm = 'l1')
|
|
|
-
|
|
|
- return confusion
|
|
|
+confusion_counts = conf_matrix(its, vtc, speakers)
|
|
|
|
|
|
plt.rcParams.update({'font.size': 12})
|
|
|
plt.rc('xtick', labelsize = 10)
|
|
@@ -49,14 +38,17 @@ plt.rc('ytick', labelsize = 10)
|
|
|
|
|
|
fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
|
|
|
|
|
|
-confusion = conf_matrix(its, vtc, speakers)
|
|
|
+confusion = normalize(confusion_counts, axis = 1, norm = 'l1')
|
|
|
+
|
|
|
sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
|
|
|
axes[0].set_xlabel('its')
|
|
|
axes[0].set_ylabel('vtc')
|
|
|
axes[0].xaxis.set_ticklabels(speakers)
|
|
|
axes[0].yaxis.set_ticklabels(speakers)
|
|
|
|
|
|
-confusion = conf_matrix(vtc, its, speakers)
|
|
|
+confusion_counts = np.transpose(confusion_counts)
|
|
|
+confusion = normalize(confusion_counts, axis = 1, norm = 'l1')
|
|
|
+
|
|
|
sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
|
|
|
axes[1].set_xlabel('vtc')
|
|
|
axes[1].set_ylabel('its')
|