plots.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from ChildProject.projects import ChildProject
  2. from ChildProject.annotations import AnnotationManager
  3. from ChildProject.metrics import gamma, segments_to_grid
  4. import numpy as np
  5. import pandas as pd
  6. from sklearn.metrics import confusion_matrix
  7. from sklearn.preprocessing import normalize
  8. import seaborn as sns
  9. import matplotlib.pyplot as plt
  10. speakers = ['CHI', 'OCH', 'FEM', 'MAL']
  11. project = ChildProject('.')
  12. am = AnnotationManager(project)
  13. am.read()
  14. intersection = AnnotationManager.intersection(am.annotations, ['vtc', 'its'])
  15. segments = am.get_collapsed_segments(intersection)
  16. segments = segments[segments['speaker_type'].isin(speakers)]
  17. segments.sort_values(['segment_onset', 'segment_offset']).to_csv('test.csv', index = False)
  18. #print(gamma(segments, column = 'speaker_type'))
  19. print('creating grids')
  20. vtc = segments_to_grid(segments[segments['set'] == 'vtc'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  21. its = segments_to_grid(segments[segments['set'] == 'its'], 0, segments['segment_offset'].max(), 100, 'speaker_type', speakers)
  22. print('done creating grids')
  23. speakers.extend(['overlap', 'none'])
  24. def get_pick(row):
  25. for cat in reversed(speakers):
  26. if row[cat]:
  27. return cat
  28. def conf_matrix(horizontal, vertical, categories):
  29. vertical = pd.DataFrame(vertical, columns = categories)
  30. vertical['pick'] = vertical.apply(
  31. get_pick,
  32. axis = 1
  33. )
  34. vertical = vertical['pick'].values
  35. horizontal = pd.DataFrame(horizontal, columns = categories)
  36. horizontal['pick'] = horizontal.apply(
  37. get_pick,
  38. axis = 1
  39. )
  40. horizontal = horizontal['pick'].values
  41. confusion = confusion_matrix(vertical, horizontal, labels = categories)
  42. confusion = normalize(confusion, axis = 1, norm = 'l1')
  43. return confusion
  44. plt.rcParams.update({'font.size': 12})
  45. plt.rc('xtick', labelsize = 10)
  46. plt.rc('ytick', labelsize = 10)
  47. fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(6.4*2, 4.8))
  48. confusion = conf_matrix(its, vtc, speakers)
  49. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[0], cmap = 'Reds')
  50. axes[0].set_xlabel('its')
  51. axes[0].set_ylabel('vtc')
  52. axes[0].xaxis.set_ticklabels(speakers)
  53. axes[0].yaxis.set_ticklabels(speakers)
  54. confusion = conf_matrix(vtc, its, speakers)
  55. sns.heatmap(confusion, annot = True, fmt = '.2f', ax = axes[1], cmap = 'Reds')
  56. axes[1].set_xlabel('vtc')
  57. axes[1].set_ylabel('its')
  58. axes[1].xaxis.set_ticklabels(speakers)
  59. axes[1].yaxis.set_ticklabels(speakers)
  60. plt.savefig('Fig5.pdf', bbox_inches = 'tight')