123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import matplotlib.pyplot as plt
- import numpy as np
- import os
- import aux
- from helpers import data_management as dm
- import modules.classifier as clf
- from analytics import analytics1
- import re
- import importlib
- importlib.reload(analytics1)
# Condition labels; the position of a label in this list is the index used
# along the first axis of the PSTH arrays allocated below.
#
# Label sets from other paradigms, kept for reference:
#   sensory mapping:
#     ['SchliesseHand', 'BeugeRechtenMittelfinger', 'BeugeRechtenZeigefinger',
#      'BeugeRechtenDaumen', 'OeffneHand', 'StreckeRechtenMittelfinger',
#      'StreckeRechtenZeigefinger', 'StreckeRechtenDaumen']
#   motor mapping:
#     ['rechte_hand', 'linke_hand', 'rechter_daumen', 'linker_daumen', 'zunge', 'fuesse']
#     ['Zunge', 'Schliesse_Hand', 'RechterDaumen', 'Oeffne_Hand', 'Fuss']
#   other:
#     ['ruhe', 'ja', 'nein', 'kopf', 'fuss']
states = [
    'Zunge',
    'Schliesse_Hand',
    'Oeffne_Hand',
    'Bewege_Augen',
    'Bewege_Kopf',
]

params = aux.load_config()
data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(
    n_triggers=params.classifier.n_classes,
    exploration=True,
)

# Window cut around each trigger, as sample offsets [start, stop).
psth_win = [-20, 80]
# Number of trial slots reserved per state and per file.
n_trials = 10
# When True, the stored baseline is added back to the data before the windows are cut.
restore_baseline = False

# Array layout: state x #stimuli x psth_window x #channels.
psth_shape = (len(states), n_trials, np.diff(psth_win)[0], data_tot[0, 0].shape[1])
# Accumulator over all files; starts with zero trials and grows along axis 1.
psth_tot = np.zeros((psth_shape[0], 0) + psth_shape[2:])
# Buffer holding the windows of a single file.
psth = np.zeros(psth_shape)
# Cut a peri-stimulus window around every trigger of every recording and
# append the per-file result to psth_tot along the trial axis (axis 1).
for fids in range(data_tot.shape[0]):
    if restore_baseline:
        # Use the baseline file that belongs to THIS recording: it sits next
        # to data_<id>.bin and is called bl_<id>.npy.
        # NOTE(review): assumes one baseline file per recording - confirm.
        bl_match = re.search(r'data_(.+?)\.bin', file_names[fids])
        if bl_match is None:
            raise ValueError(
                f'cannot derive baseline file name from {file_names[fids]}'
            )
        bl_name = f"{os.path.dirname(file_names[fids])}/bl_{bl_match.group(1)}.npy"
        bl = np.load(bl_name)
        # add baseline if it was removed
        data = data_tot[fids, 0] * (bl + params.daq.spike_rates.bl_offset) + bl
        print('renormalizing firing rates according to baseline\n')
    else:
        data = data_tot[fids, 0]
        # Forward baseline correction, kept for reference (disabled):
        # bl_idx = int(params.recording.timing.t_baseline_1 / params.daq.spike_rates.loop_interval*1000.)
        # bl = np.max(data[:bl_idx,:], axis=0) - np.min(data[:bl_idx,:], axis=0)
        # data = (data - bl)/(bl+params.daq.spike_rates.bl_offset)
        print('No renormalization of firing rates according to baseline\n')

    triggers = triggers_tot[fids, 0][0]

    # Start every file from an empty buffer, so that trial slots which are
    # not filled (fewer triggers than n_trials, or a window that runs past
    # the recording edges) stay zero instead of repeating the previous
    # file's data.
    psth = np.zeros_like(psth)

    for state, trig in enumerate(triggers):
        if trig.size == 0:
            continue
        for ii in range(trig.size):
            jj = trig[0, ii]  # trigger position, used as a row index into data
            start = jj + psth_win[0]
            stop = jj + psth_win[1]
            # Skip windows that are not fully contained in the recording.
            if start >= 0 and stop <= data.shape[0]:
                psth[state, ii, :, :] = data[start:stop, :]

    print(file_names[fids])
    psth_tot = np.concatenate((psth_tot, psth), axis=1)
# x axis shared by all PSTH plots: sample offsets covered by psth_win.
psth_xx = np.arange(*psth_win)

# One colour specifier per subplot ('C0' ... 'C8').
col = [f'C{idx}' for idx in range(9)]

plt.figure(num=1, figsize=[10, 7])
plt.ioff()

# NOTE(review): zeroes the window of state 0, trial 7, channel 0 -
# presumably a known bad trial in one data set; confirm it is still wanted.
psth_tot[0, 7, :, 0] *= 0

ymax = .3
# One figure per channel, with one panel per state showing the single trials,
# their mean, their median and median +/- 1 std. Each figure is saved as
# ch_id_<n>.png in the exploration results folder.

# The output folder does not depend on the channel, so build it once.
# NOTE(review): the folder name is the 5th '/'-separated component of the
# first file's path - this depends on the directory depth of the data files.
dir_name = params.file_handling.results + 'exploration/' + os.path.splitext(file_names[0])[0].split('/')[4]
os.makedirs(dir_name, exist_ok=True)

# Plot at most the first 128 channels, but never more than were recorded.
n_channels = min(128, psth_tot.shape[3])

for ch_id in range(n_channels):
    plt.clf()
    print(f'ch_id: {ch_id}')
    for ii in range(psth_tot.shape[0]):
        ax = plt.subplot(3, 3, ii + 1)
        ax.title.set_text(f'{states[ii]}, n={len(psth_tot[ii])}, ch={ch_id}')

        trials = psth_tot[ii, :, :, ch_id]  # trial x window sample
        avg1 = np.mean(trials, axis=0)
        med1 = np.median(trials, axis=0)
        std1 = np.std(trials, axis=0)

        ax.plot(psth_xx, trials.T, color=col[ii], alpha=0.2)
        ax.plot(psth_xx, med1, color='k', lw=2, alpha=0.5)
        ax.plot(psth_xx, avg1, color=col[ii], lw=2)
        ax.plot(psth_xx, med1 - std1, 'k--', lw=2, alpha=0.5)
        ax.plot(psth_xx, med1 + std1, 'k--', lw=2, alpha=0.5)

        # Frame the whole median trace with a margin of two (maximal) standard
        # deviations: the lower limit hangs below the minimum of the median,
        # the upper limit sits above its maximum.
        margin = 2 * np.max(std1)
        ymin, ymax = ax.set_ylim(np.min(med1) - margin, np.max(med1) + margin)
        # Mark the trigger position (sample offset 0).
        ax.vlines(0, ymin, ymax, alpha=0.5)

        # NOTE(review): the label rules below do not line up with the left
        # column / bottom row of the 3x3 grid - confirm the intended layout.
        if ii in [0, 2, 4]:
            ax.set_ylabel('Sp/sec')
        if ii < 4:
            ax.set_xticks([])
        else:
            ax.set_xlabel('samples')

    plt.savefig(dir_name + f'/ch_id_{ch_id}.png')

plt.show()
|