show_sensory_mapping_data.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import aux
  4. from helpers import data_management as dm
  5. import modules.classifier as clf
  6. import analytics
  7. # sensory mapping
  8. states = ['SchliesseHand','BeugeRechtenMittelfinger', 'BeugeRechtenZeigefinger','BeugeRechtenDaumen','OeffneHand',
  9. 'StreckeRechtenMittelfinger','StreckeRechtenZeigefinger','StreckeRechtenDaumen']
  10. # motor mapping
  11. # states = ['rechte_hand', 'rechter_daumen', 'linker_daumen', 'zunge', 'fuesse']
  12. params = aux.load_config()
  13. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes, exploration=True)
  14. psth_win = [-20,80]
  15. n_trials = 10
  16. psth_tot = np.zeros((len(states),0, np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  17. psth = np.zeros((len(states),n_trials, np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  18. for fids in range(data_tot.shape[0]):
  19. data = data_tot[fids,0]
  20. triggers = triggers_tot[fids,0][0]
  21. for state in range(len(triggers)):
  22. if triggers[state].size>0:
  23. for ii in range(triggers[state].size):
  24. jj = triggers[state][0,ii]
  25. if jj+(psth_win[0])>=0 and jj+(psth_win[1])<=data.shape[0]:
  26. psth[state,ii,:,:] = data[jj+psth_win[0]:jj+psth_win[1],:]
  27. print(state,ii,jj)
  28. print(file_names[fids])
  29. psth_tot = np.concatenate((psth_tot, psth), axis=1)
  30. psth_xx = np.arange(psth_win[0],psth_win[1])
  31. col = ['C0','C1','C2','C3','C4','C5','C6','C7','C8']
  32. plt.figure(1)
  33. # plt.subplot(321)
  34. # for ii in range(psth.shape[0]):
  35. # plt.plot(np.mean(psth[ii,0:,:,32], axis=0).T,col[ii])
  36. # plt.subplot(322)
  37. # for ii in range(psth.shape[0]):
  38. # plt.plot(np.mean(psth[ii,0:,:,33], axis=0).T,col[ii])
  39. plt.figure(1, figsize=[19.2 , 9.55])
  40. # plt.ioff()
  41. # for ch_id in range(0,128):
  42. for ch_id in range(87,88):
  43. plt.clf()
  44. print(f'ch_id: {ch_id}')
  45. for ii in range(psth_tot.shape[0]):
  46. # for ii in [1, 2, 3, 5, 6, 7,0,:
  47. ax = plt.subplot(3,4,ii+1)
  48. plt.gca().title.set_text(f'{states[ii]} n={psth_tot.shape[1]}')
  49. mu1 = np.mean(psth_tot[ii,0:,:,ch_id:ch_id+1], axis=2) # average accross channels
  50. plt.plot(psth_xx, mu1.T, color=col[ii], alpha=0.2)
  51. plt.plot(psth_xx, np.median(mu1, axis=0), color='k',lw=2,alpha=0.5)
  52. plt.plot(psth_xx, np.mean(mu1, axis=0), color=col[ii],lw=2)
  53. plt.ylim(0,20)
  54. ax.set_xticks([])
  55. for ii in range(4):
  56. ax=plt.subplot(3,4,8+ii+1)
  57. mu1 = np.mean(psth_tot[ii,0:,:,ch_id:ch_id+1], axis=2)
  58. mu2 = np.mean(psth_tot[ii+4,0:,:,ch_id:ch_id+1], axis=2)
  59. plt.plot(psth_xx, np.mean(mu1, axis=0), color=col[ii], lw=3)
  60. plt.plot(psth_xx, np.mean(mu2, axis=0), color=col[ii+4], lw=3)
  61. plt.plot(psth_xx, np.median(mu1, axis=0), 'k--', lw=2, alpha=0.6)
  62. plt.plot(psth_xx, np.median(mu2, axis=0), 'k--', lw=2, alpha=0.6)
  63. # plt.ylim(0,20)
  64. plt.savefig('/kiap/src/data/results/sensor_mapping/'+f'ch_id_{ch_id}.png')
  65. # plt.savefig('/data/clinical/neural_new/2019-03-23/results/'+f'ch_id_{ch_id}.png')
  66. # plt.savefig('/data/clinical/neural_new/2019-03-26/results/'+f'ch_id_{ch_id}.png')
  67. # plt.subplot(324)
  68. # for ii in range(psth.shape[0]):
  69. # plt.plot(np.mean(psth[ii,0:,:,0], axis=0).T,col[ii])
  70. # plt.subplot(325)
  71. # for ii in range(psth.shape[0]):
  72. # plt.plot(np.mean(psth[ii,1:,:,0:32], axis=2).T,col[ii])
  73. # plt.subplot(326)
  74. # for ii in range(psth.shape[0]):
  75. # plt.plot(np.mean(psth[ii,1:,:,32:96], axis=2).T,col[ii])
  76. # print('\nsubplot1: channel 0, averaged across stimuli')
  77. # print('subplot2: all stimuli, averaged across channels 0-64')
  78. # print('subplot3: all stimuli, averaged across channels 64-128')
  79. # plt.legend()
  80. plt.show()