show_feedback_data_pca.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
import itertools
import os
from collections.abc import Iterable

import matplotlib.pyplot as plt
import numpy as np
import yaml
# PCA is imported from the public package namespace; the private
# `sklearn.decomposition.pca` module path is not available in current
# scikit-learn releases.
from sklearn.decomposition import PCA

import aux
from helpers import data_management as dm
import modules.classifier as clf
import analytics
  12. params = aux.load_config()
  13. if params.speller.type == 'exploration':
  14. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes,
  15. exploration=True, trigger_pos=params.classifier.trigger_pos)
  16. elif params.speller.type == 'feedback':
  17. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes,
  18. feedback=True, trigger_pos=params.classifier.trigger_pos)
  19. else:
  20. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes,
  21. trigger_pos=params.classifier.trigger_pos)
  22. if params.classifier.trigger_pos == 'start':
  23. psth_win = [-3.0, 5.0] # in seconds!
  24. else:
  25. psth_win = [-5.0, 3.0] # in seconds!
  26. psth_win = np.array(psth_win) * 1000.0 / params.daq.spike_rates.loop_interval
  27. psth_win = psth_win.astype(int)
  28. for fids in range(data_tot.shape[0]):
  29. data = data_tot[fids,0]
  30. triggers_down = triggers_tot[fids,0][0]
  31. triggers_up = triggers_tot[fids,1][0]
  32. pca1 = PCA(n_components=30)
  33. pca1.fit(data)
  34. data_pca = pca1.transform(data)
  35. psth_down = np.zeros((triggers_down.shape[0], np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  36. psth_up = np.zeros((triggers_up.shape[0], np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  37. psth_down_pca = np.zeros((triggers_down.shape[0], np.diff(psth_win)[0], pca1.n_components))
  38. psth_up_pca = np.zeros((triggers_up.shape[0], np.diff(psth_win)[0], pca1.n_components))
  39. for ii,tr_id in enumerate(triggers_down):
  40. if tr_id+(psth_win[0])>=0 and tr_id+(psth_win[1])<=data.shape[0]:
  41. psth_down[ii,:,:] = data[tr_id+psth_win[0]:tr_id+psth_win[1],:]
  42. psth_down_pca[ii,:,:] = data_pca[tr_id+psth_win[0]:tr_id+psth_win[1],:]
  43. print(ii,tr_id)
  44. for ii,tr_id in enumerate(triggers_up):
  45. if tr_id+(psth_win[0])>=0 and tr_id+(psth_win[1])<=data.shape[0]:
  46. psth_up[ii,:,:] = data[tr_id+psth_win[0]:tr_id+psth_win[1],:]
  47. psth_up_pca[ii,:,:] = data_pca[tr_id+psth_win[0]:tr_id+psth_win[1],:]
  48. print(ii,tr_id)
  49. psth_xx = np.arange(psth_win[0],psth_win[1]) / 20.
  50. col = ['C0','C1']
  51. plt.figure(1)
  52. plt.clf()
  53. for ch_id in range(pca1.n_components):
  54. psth_down_agg = psth_down_pca[:, :, ch_id]
  55. psth_up_agg = psth_up_pca[:, :, ch_id]
  56. ymin = min(psth_down_agg.min(), psth_up_agg.min())
  57. ymax = max(psth_down_agg.max(), psth_up_agg.max())
  58. mu1 = np.mean(psth_down_agg, axis=0)
  59. md1 = np.median(psth_down_agg, axis=0)
  60. mu2 = np.mean(psth_up_agg, axis=0)
  61. md2 = np.median(psth_up_agg, axis=0)
  62. plt.figure(1)
  63. plt.clf()
  64. plt.subplot(221)
  65. plt.plot(psth_xx, psth_down_agg.T, 'C0', alpha=0.5)
  66. plt.plot(psth_xx, mu1.T, color='k', alpha=0.8)
  67. plt.plot(psth_xx, md1.T, '--', color='k', alpha=0.8)
  68. # plt.plot(psth_xx, np.median(mu1, axis=0), color=col[ii],lw=2)
  69. # plt.ylim(0,8)
  70. plt.ylabel('sp/sec')
  71. plt.ylim(ymin, ymax)
  72. plt.title(f'down, n={psth_down_agg.shape[0]}')
  73. plt.subplot(222)
  74. plt.plot(psth_xx, psth_up_agg.T, 'C1', alpha=0.5)
  75. plt.plot(psth_xx, mu2.T, color='k', alpha=0.8)
  76. plt.plot(psth_xx, md2.T, '--', color='k', alpha=0.8)
  77. plt.title(f'up, n={psth_up_agg.shape[0]}')
  78. plt.ylim(ymin, ymax)
  79. plt.subplot(223)
  80. plt.plot(psth_xx, psth_down_agg.T, 'C0', alpha=0.5)
  81. plt.plot(psth_xx, psth_up_agg.T, 'C1', alpha=0.5)
  82. plt.plot(psth_xx, mu1.T, color='C0', alpha=1., lw=2)
  83. plt.plot(psth_xx, md1.T, '--', color='C0', alpha=0.8, lw=2)
  84. plt.plot(psth_xx, mu2.T, color='C1', alpha=1., lw=2)
  85. plt.plot(psth_xx, md2.T, '--', color='C1', alpha=0.8, lw=2)
  86. plt.ylim(ymin, ymax)
  87. plt.ylabel('sp/sec')
  88. plt.xlabel('sec')
  89. plt.subplot(224)
  90. # plt.plot(psth_xx, psth_down[:, :, ch_id].T, 'C0', alpha=0.5)
  91. # plt.plot(psth_xx, psth_up[:, :, ch_id].T, 'C1', alpha=0.5)
  92. plt.plot(psth_xx, mu1.T, color='C0', alpha=1., lw=2)
  93. plt.plot(psth_xx, md1.T, '--', color='C0', alpha=0.8, lw=2)
  94. plt.plot(psth_xx, mu2.T, color='C1', alpha=1., lw=2)
  95. plt.plot(psth_xx, md2.T, '--', color='C1', alpha=0.8, lw=2)
  96. plt.ylim(ymin, ymax)
  97. plt.xlabel('sec')
  98. fname = os.path.basename(file_names[0]).split('.')[0]
  99. os.makedirs(f'{params.file_handling.results}/{fname}', exist_ok=True)
  100. # plt.savefig(f'/media/vlachos/bck_disk1/kiap/recordings/fr/results/neurofeedback/{fname}_nf_{ch_id}.png')
  101. fname2 = f'{params.file_handling.results}/{fname}/{fname}_nf_pc{ch_id}.png'
  102. print(fname2)
  103. plt.savefig(fname2)
  104. # plt.plot(psth_xx, np.median(mu1, axis=0), color=col[ii],lw=2)
  105. # plt.ylim(0,8)
  106. # input()
  107. # plt.draw()
  108. # plt.show()
  109. plt.show()