Quellcode durchsuchen

updated comparison.py

Martin Frébourg vor 2 Jahren
Ursprung
Commit
3816ea2f3d

+ 38 - 3
code/compare.py

@@ -4,6 +4,9 @@ from ChildProject.projects import ChildProject
 from ChildProject.annotations import AnnotationManager
 from ChildProject.metrics import segments_to_grid, conf_matrix, segments_to_annotation
 from pathlib import Path
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
 
 def compare_vandam(set1: str, set2: str) :
 
@@ -54,9 +57,41 @@ def compare_vandam(set1: str, set2: str) :
     #generates segments
     set1_segm = segments_to_grid(segments[segments['set'] == set1], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
     set2_segm = segments_to_grid(segments[segments['set'] == set2], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
-    matrix_df = pd.DataFrame(conf_matrix(set1_segm, set2_segm))
-    matrix_df.to_csv("{0}/{1}-{2}-confusion-matrix.csv".format(dirName, set1.replace("/",""), set2.replace("/","")), mode = "w", index=False)
+    
+    speakers.extend(['none'])
+
+    confusion_counts = conf_matrix(set1_segm, set2_segm)
+
+    plt.rcParams.update({'font.size': 12})
+    plt.rc('xtick', labelsize = 10)
+    plt.rc('ytick', labelsize = 10)
+
+    fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
+
+    confusion = confusion_counts/np.sum(set1_segm, axis = 0)[:,None]
+
+    sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
+    axes[0].set_xlabel(set2)
+    axes[0].set_ylabel(set1)
+    axes[0].xaxis.set_ticklabels(speakers)
+    axes[0].yaxis.set_ticklabels(speakers)
+
+    confusion_counts = np.transpose(confusion_counts)
+    confusion = confusion_counts/np.sum(set2_segm, axis = 0)[:,None]
+
+    sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
+    axes[1].set_xlabel(set1)
+    axes[1].set_ylabel(set2)
+    axes[1].xaxis.set_ticklabels(speakers)
+    axes[1].yaxis.set_ticklabels(speakers)
+
+    plt.savefig("{0}/{1}-{2}-confusion-matrix.jpg".format(dirName, set1.replace("/",""), set2.replace("/",""), bbox_inches = 'tight'))
+    
+    
+    #matrix_df = pd.DataFrame(conf_matrix(set1_segm, set2_segm))
+   # matrix_df.to_csv("{0}/{1}-{2}-confusion-matrix.csv".format(dirName, set1.replace("/",""), set2.replace("/","")), mode = "w", index=False)
     print("Confusion matrix saved for {0} and {1}!".format(set1, set2))
 
 compare_vandam('eaf', 'cha')
-compare_vandam('eaf', 'cha/aligned')
+compare_vandam('eaf', 'cha/aligned')
+compare_vandam('cha', 'cha/aligned')

+ 1 - 0
outputs/compare/cha-chaaligned/cha-chaaligned-confusion-matrix.jpg

@@ -0,0 +1 @@
+../../../.git/annex/objects/q2/9W/MD5E-s61222--47d0b98048bd0222793bb5e7da2f91de.jpg/MD5E-s61222--47d0b98048bd0222793bb5e7da2f91de.jpg

+ 1 - 0
outputs/compare/cha-chaaligned/cha-chaaligned.txt

@@ -0,0 +1 @@
+../../../.git/annex/objects/x0/k2/MD5E-s301--b279b27726c2a333fa467568cc6b1a7c.txt/MD5E-s301--b279b27726c2a333fa467568cc6b1a7c.txt

+ 0 - 1
outputs/compare/eaf-cha/eaf-cha-confusion-matrix.csv

@@ -1 +0,0 @@
-../../../.git/annex/objects/02/qX/MD5E-s74--22ae56390035e28008fa9cdc4ef4f910.csv/MD5E-s74--22ae56390035e28008fa9cdc4ef4f910.csv

+ 1 - 0
outputs/compare/eaf-cha/eaf-cha-confusion-matrix.jpg

@@ -0,0 +1 @@
+../../../.git/annex/objects/GG/Kq/MD5E-s60314--703409554750fac610274aec7308c082.jpg/MD5E-s60314--703409554750fac610274aec7308c082.jpg

+ 0 - 1
outputs/compare/eaf-chaaligned/eaf-chaaligned-confusion-matrix.csv

@@ -1 +0,0 @@
-../../../.git/annex/objects/pX/51/MD5E-s76--45d91b7c921302f3916c697f755a8c9c.csv/MD5E-s76--45d91b7c921302f3916c697f755a8c9c.csv

+ 1 - 0
outputs/compare/eaf-chaaligned/eaf-chaaligned-confusion-matrix.jpg

@@ -0,0 +1 @@
+../../../.git/annex/objects/KZ/7q/MD5E-s62300--77324ff50d887a02827122abd6d82c6c.jpg/MD5E-s62300--77324ff50d887a02827122abd6d82c6c.jpg