123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import matplotlib.pyplot as plt
- import numpy as np
- import yaml
- import aux
- from helpers import data_management as dm
- import modules.classifier as clf
- import analytics
- import os
- import itertools
- from collections.abc import Iterable
- from sklearn.decomposition.pca import PCA
# Load the run configuration, then the raw recordings. The speller mode only
# decides which extra flag dm.get_raw receives: 'exploration' -> exploration=True,
# 'feedback' -> feedback=True, any other mode -> no extra flag.
params = aux.load_config()
_speller_type = params.speller.type
if _speller_type == 'exploration':
    _get_raw_mode = {'exploration': True}
elif _speller_type == 'feedback':
    _get_raw_mode = {'feedback': True}
else:
    _get_raw_mode = {}
data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(
    n_triggers=params.classifier.n_classes,
    trigger_pos=params.classifier.trigger_pos,
    **_get_raw_mode
)
# Peri-trigger window, first in seconds, then converted to integer sample
# offsets. With trigger_pos == 'start' the window reaches further after the
# trigger; otherwise it reaches further before it.
_psth_win_sec = (-3.0, 5.0) if params.classifier.trigger_pos == 'start' else (-5.0, 3.0)
# seconds -> samples (presumably loop_interval is milliseconds per sample).
psth_win = (np.array(_psth_win_sec) * 1000.0 / params.daq.spike_rates.loop_interval).astype(int)
# For every recording: project the signal onto its principal components, cut
# peri-trigger windows (PSTHs) for the 'down' and 'up' trigger classes, and
# save one 2x2 summary figure per component.


def _valid_trigger_starts(triggers, win, n_samples):
    """Return the window start sample of every trigger whose full window fits.

    ``win`` holds the (start, stop) offsets in samples relative to a trigger.
    Triggers whose window would run past either end of an ``n_samples``-long
    recording are dropped, so they cannot add all-zero rows that would bias
    the averages. Each kept trigger is printed as ``(position, trigger)``.
    """
    starts = []
    for ii, tr_id in enumerate(triggers):
        # int() so the offsets are usable as slice indices even if the
        # trigger arrays are stored as floats.
        start = int(tr_id) + win[0]
        stop = int(tr_id) + win[1]
        if start >= 0 and stop <= n_samples:
            starts.append(start)
            print(ii, tr_id)
    return starts


def _epochs(signal, starts, length):
    """Stack ``signal[s:s + length, :]`` for each start ``s``.

    Returns an array of shape (len(starts), length, signal.shape[1]); with no
    starts the first dimension is 0.
    """
    out = np.zeros((len(starts), length, signal.shape[1]))
    for k, start in enumerate(starts):
        out[k] = signal[start:start + length]
    return out


def _plot_mean_median(xx, mean, median, color):
    """Overlay a bold solid mean trace and a bold dashed median trace."""
    plt.plot(xx, mean, color=color, alpha=1., lw=2)
    plt.plot(xx, median, '--', color=color, alpha=0.8, lw=2)


# psth_win is in samples; convert the axis back to seconds with the same
# loop_interval used to build psth_win, rather than a fixed divisor of 20
# (which is only right for a 50 ms interval).
win_len = int(np.diff(psth_win)[0])
psth_xx = np.arange(psth_win[0], psth_win[1]) * params.daq.spike_rates.loop_interval / 1000.0

for fids in range(data_tot.shape[0]):
    data = data_tot[fids, 0]
    triggers_down = triggers_tot[fids, 0][0]
    triggers_up = triggers_tot[fids, 1][0]

    # PCA raises if n_components > min(n_samples, n_features); cap the 30
    # requested components at what this recording can support.
    # NOTE(review): PCA is imported from `sklearn.decomposition.pca`, a module
    # made private in scikit-learn 0.22 and removed in 0.24; the public path
    # is `from sklearn.decomposition import PCA`.
    pca1 = PCA(n_components=min(30, *data.shape))
    data_pca = pca1.fit_transform(data)

    starts_down = _valid_trigger_starts(triggers_down, psth_win, data.shape[0])
    starts_up = _valid_trigger_starts(triggers_up, psth_win, data.shape[0])
    # trigger x window sample x channel (raw data) / component (PCA scores);
    # sized from this recording, not from the first one.
    psth_down = _epochs(data, starts_down, win_len)
    psth_up = _epochs(data, starts_up, win_len)
    psth_down_pca = _epochs(data_pca, starts_down, win_len)
    psth_up_pca = _epochs(data_pca, starts_up, win_len)

    # Name outputs after this recording so recordings don't overwrite each other.
    fname = os.path.basename(file_names[fids]).split('.')[0]
    if not starts_down or not starts_up:
        # min()/max()/mean() below are undefined for an empty trigger class.
        print(f'{fname}: no complete PSTH window for one trigger class, skipping plots')
        continue
    out_dir = f'{params.file_handling.results}/{fname}'
    os.makedirs(out_dir, exist_ok=True)

    for ch_id in range(data_pca.shape[1]):
        psth_down_agg = psth_down_pca[:, :, ch_id]
        psth_up_agg = psth_up_pca[:, :, ch_id]

        # Shared y-range so all four panels are directly comparable.
        ymin = min(psth_down_agg.min(), psth_up_agg.min())
        ymax = max(psth_down_agg.max(), psth_up_agg.max())
        mu1 = np.mean(psth_down_agg, axis=0)
        md1 = np.median(psth_down_agg, axis=0)
        mu2 = np.mean(psth_up_agg, axis=0)
        md2 = np.median(psth_up_agg, axis=0)

        plt.figure(1)
        plt.clf()

        # Top-left: individual 'down' trials with black mean/median.
        plt.subplot(221)
        plt.plot(psth_xx, psth_down_agg.T, 'C0', alpha=0.5)
        plt.plot(psth_xx, mu1, color='k', alpha=0.8)
        plt.plot(psth_xx, md1, '--', color='k', alpha=0.8)
        # NOTE(review): y values are PCA scores, not raw rates; label kept as-is.
        plt.ylabel('sp/sec')
        plt.ylim(ymin, ymax)
        plt.title(f'down, n={psth_down_agg.shape[0]}')

        # Top-right: individual 'up' trials with black mean/median.
        plt.subplot(222)
        plt.plot(psth_xx, psth_up_agg.T, 'C1', alpha=0.5)
        plt.plot(psth_xx, mu2, color='k', alpha=0.8)
        plt.plot(psth_xx, md2, '--', color='k', alpha=0.8)
        plt.title(f'up, n={psth_up_agg.shape[0]}')
        plt.ylim(ymin, ymax)

        # Bottom-left: both classes' trials overlaid with their summaries.
        plt.subplot(223)
        plt.plot(psth_xx, psth_down_agg.T, 'C0', alpha=0.5)
        plt.plot(psth_xx, psth_up_agg.T, 'C1', alpha=0.5)
        _plot_mean_median(psth_xx, mu1, md1, 'C0')
        _plot_mean_median(psth_xx, mu2, md2, 'C1')
        plt.ylim(ymin, ymax)
        plt.ylabel('sp/sec')
        plt.xlabel('sec')

        # Bottom-right: summaries only.
        plt.subplot(224)
        _plot_mean_median(psth_xx, mu1, md1, 'C0')
        _plot_mean_median(psth_xx, mu2, md2, 'C1')
        plt.ylim(ymin, ymax)
        plt.xlabel('sec')

        fname2 = f'{out_dir}/{fname}_nf_pc{ch_id}.png'
        print(fname2)
        plt.savefig(fname2)

plt.show()
|