show_motor_mapping_data.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import os
  4. import aux
  5. from helpers import data_management as dm
  6. import modules.classifier as clf
  7. from analytics import analytics1
  8. import re
  9. import importlib
  10. importlib.reload(analytics1)
  11. # sensory mapping
  12. # states = ['SchliesseHand','BeugeRechtenMittelfinger', 'BeugeRechtenZeigefinger','BeugeRechtenDaumen','OeffneHand',
  13. # 'StreckeRechtenMittelfinger','StreckeRechtenZeigefinger','StreckeRechtenDaumen']
  14. # motor mapping
  15. # states = ['rechte_hand','linke_hand','rechter_daumen','linker_daumen','zunge','fuesse']
  16. # states = ['Zunge', 'Schliesse_Hand', 'RechterDaumen', 'Oeffne_Hand', 'Fuss']
  17. states = ['Zunge', 'Schliesse_Hand', 'Oeffne_Hand', 'Bewege_Augen', 'Bewege_Kopf']
  18. # states = ['ruhe','ja','nein','kopf','fuss']
  19. params = aux.load_config()
  20. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes, exploration=True)
  21. psth_win = [-20,80]
  22. n_trials = 10
  23. restore_baseline = False
  24. psth_tot = np.zeros((len(states),0, np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  25. psth = np.zeros((len(states),n_trials, np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  26. for fids in range(data_tot.shape[0]):
  27. if restore_baseline:
  28. bl_name = f"{os.path.dirname(file_names[fids])}/bl_{re.search('data_(.+?).bin', file_names[0]).group(1)}.npy"
  29. bl = np.load(bl_name)
  30. data = data_tot[fids,0] * (bl+params.daq.spike_rates.bl_offset) + bl # add baseline if it was removed
  31. print('renormalizing firing rates according to baseline\n')
  32. else:
  33. data = data_tot[fids,0]
  34. # correct baseline
  35. # bl_idx = int(params.recording.timing.t_baseline_1 / params.daq.spike_rates.loop_interval*1000.)
  36. # bl = np.max(data[:bl_idx,:], axis=0) - np.min(data[:bl_idx,:], axis=0)
  37. # data = (data - bl)/(bl+params.daq.spike_rates.bl_offset)
  38. print('No renormalization of firing rates according to baseline\n')
  39. triggers = triggers_tot[fids,0][0]
  40. for state in range(len(triggers)):
  41. if triggers[state].size>0:
  42. for ii in range(triggers[state].size):
  43. jj = triggers[state][0,ii]
  44. if jj+(psth_win[0])>=0 and jj+(psth_win[1])<=data.shape[0]:
  45. psth[state,ii,:,:] = data[jj+psth_win[0]:jj+psth_win[1],:]
  46. # print(state,ii,jj)
  47. print(file_names[fids])
  48. psth_tot = np.concatenate((psth_tot, psth), axis=1)
  49. psth_xx = np.arange(psth_win[0],psth_win[1])
  50. col = ['C0','C1','C2','C3','C4','C5','C6','C7','C8']
  51. plt.figure(1, figsize=[10 , 7])
  52. plt.ioff()
  53. psth_tot[0,7,:,0] = psth_tot[0,7,:,0]*0
  54. ymax = .3
  55. for ch_id in range(0,128):
  56. # for ch_id in range(0, 32):
  57. plt.clf()
  58. print(f'ch_id: {ch_id}')
  59. for ii in range(psth_tot.shape[0]):
  60. # for ii in [1, 2, 3, 5, 6, 7,0,:
  61. ax = plt.subplot(3, 3, ii + 1)
  62. plt.gca().title.set_text(f'{states[ii]}, n={len(psth_tot[ii])}, ch={ch_id}')
  63. # rows = np.argwhere(np.max(psth_tot[0,:,:,0], axis=1) > 25)
  64. # psth1 = np.delete(psth_tot[ii, :, :, ch_id], rows, axis=0)
  65. mu1 = np.mean(psth_tot[ii, 0:, :, ch_id:ch_id + 1], axis=2) # average accross channels
  66. # mu1 = np.mean(psth1, axis=0) # average accross channels
  67. med1 = np.median(mu1, axis=0)
  68. std1 = np.std(mu1, axis=0)
  69. plt.plot(psth_xx, mu1.T, color=col[ii], alpha=0.2)
  70. plt.plot(psth_xx, med1, color='k', lw=2,alpha=0.5)
  71. plt.plot(psth_xx, np.mean(mu1, axis=0), color=col[ii], lw=2)
  72. plt.plot(psth_xx, med1 - std1, 'k--', lw=2, alpha=0.5)
  73. plt.plot(psth_xx, med1 + std1, 'k--', lw=2, alpha=0.5)
  74. ymin = max(med1) - 2 * max(std1)
  75. ymax = max(med1) + 2 * max(std1)
  76. ymin, ymax = plt.ylim(ymin, ymax)
  77. plt.vlines(0, ymin, ymax, alpha=0.5)
  78. # plt.ylim(-ymax,ymax)
  79. if ii in [0, 2, 4]:
  80. plt.ylabel('Sp/sec')
  81. if ii<4:
  82. ax.set_xticks([])
  83. if ii>=4:
  84. plt.xlabel('samples')
  85. dir_name = params.file_handling.results + 'exploration/' + os.path.splitext(file_names[0])[0].split('/')[4]
  86. if not os.path.exists(dir_name):
  87. os.makedirs(dir_name)
  88. plt.savefig(dir_name + f'/ch_id_{ch_id}.png')
  89. plt.show()