123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- import matplotlib.pyplot as plt
- import numpy as np
- import yaml
- import aux
- from helpers import data_management as dm
- import modules.classifier as clf
- import analytics
- import os
- import itertools
- from collections.abc import Iterable
# Load the run configuration, then the raw recordings. Exploration and
# feedback sessions forward their session flag to dm.get_raw; any other
# speller type passes neither flag and relies on get_raw's defaults.
params = aux.load_config(force_reload=True)

speller_type = params.speller.type
if speller_type == 'exploration':
    session_kwargs = {'exploration': True}
elif speller_type == 'feedback':
    session_kwargs = {'feedback': True}
else:
    session_kwargs = {}

data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(
    n_triggers=params.classifier.n_classes,
    trigger_pos=params.classifier.trigger_pos,
    **session_kwargs,
)
# Peri-stimulus window as [pre, post] offsets from the trigger, in seconds:
# 2 s before the trigger when trigger_pos is 'start', 5 s otherwise; 3 s after.
pre_trigger_s = -2.0 if params.classifier.trigger_pos == 'start' else -5.0
# Convert seconds to loop samples (seconds * 1000 / loop_interval) and
# truncate toward zero with astype(int).
psth_win = (np.array([pre_trigger_s, 3.0]) * 1000.0
            / params.daq.spike_rates.loop_interval).astype(int)
def _extract_psth(data, trigger_ids, win):
    """Cut the peri-stimulus window around every trigger out of ``data``.

    ``data`` is indexed ``data[sample, channel]``; the window for a trigger at
    sample ``tr`` spans samples ``[tr + win[0], tr + win[1])``. Triggers whose
    window would run past either end of the recording are skipped, so every
    returned trial holds real data rather than a zero-filled placeholder that
    would drag down the mean/median and inflate the reported trial count.

    Returns a float array of shape ``(n_kept_triggers, win[1] - win[0], n_channels)``.
    """
    n_samples, n_channels = data.shape
    kept = [tr for tr in trigger_ids if tr + win[0] >= 0 and tr + win[1] <= n_samples]
    n_skipped = len(trigger_ids) - len(kept)
    if n_skipped:
        print(f'skipping {n_skipped} trigger(s) whose window exceeds the recording')
    if not kept:
        return np.empty((0, win[1] - win[0], n_channels))
    return np.stack([data[tr + win[0]:tr + win[1], :] for tr in kept]).astype(float)


def _normalize(rates, bottom, top, invert):
    """Clamp ``rates`` to ``[bottom, top]``, rescale to ``[0, 1]`` and optionally invert.

    ``bottom``, ``top`` and ``invert`` are scalars or per-channel arrays that
    broadcast against the last axis of ``rates``; wherever ``invert`` is true
    the result is ``1 - normalized``.
    """
    clamped = np.maximum(np.minimum(rates, top), bottom)
    scaled = (clamped - bottom) / (top - bottom)
    return np.where(invert, 1.0 - scaled, scaled)


def _plot_psth_summary(t, down, up, ylabel):
    """Redraw figure 1 as a 2x2 comparison of the 'down' and 'up' trials.

    ``down`` and ``up`` are ``(trials, time)`` arrays sampled at times ``t``
    (seconds). Panels: 221 down trials, 222 up trials, 223 both overlaid,
    224 the per-condition summaries only. Solid lines are trial means, dashed
    lines trial medians.
    """
    mu_down, md_down = np.mean(down, axis=0), np.median(down, axis=0)
    mu_up, md_up = np.mean(up, axis=0), np.median(up, axis=0)
    # fmax skips NaN; without a finite positive maximum, keep the autoscaled top
    # instead of passing NaN/identical limits to ylim.
    ymax = np.fmax(np.nanmax(down), np.nanmax(up))
    ylim = (0, ymax if np.isfinite(ymax) and ymax > 0 else None)

    plt.figure(1)
    plt.clf()

    plt.subplot(221)
    plt.plot(t, down.T, 'C0', alpha=0.5)
    plt.plot(t, mu_down, color='k', alpha=0.8)
    plt.plot(t, md_down, '--', color='k', alpha=0.8)
    plt.ylabel(ylabel)
    plt.ylim(*ylim)
    plt.title(f'down, n={down.shape[0]}')

    plt.subplot(222)
    plt.plot(t, up.T, 'C1', alpha=0.5)
    plt.plot(t, mu_up, color='k', alpha=0.8)
    plt.plot(t, md_up, '--', color='k', alpha=0.8)
    plt.title(f'up, n={up.shape[0]}')
    plt.ylim(*ylim)

    for panel in (223, 224):
        plt.subplot(panel)
        if panel == 223:
            # Only the overlay panel shows individual trials.
            plt.plot(t, down.T, 'C0', alpha=0.5)
            plt.plot(t, up.T, 'C1', alpha=0.5)
            plt.ylabel(ylabel)
        for mu, md, color in ((mu_down, md_down, 'C0'), (mu_up, md_up, 'C1')):
            plt.plot(t, mu, color=color, alpha=1., lw=2)
            plt.plot(t, md, '--', color=color, alpha=0.8, lw=2)
        plt.ylim(*ylim)
        plt.xlabel('sec')


RAW_YLABEL = 'sp/sec'
NORM_YLABEL = 'normalized rate'

# Time axis in seconds: inverts the seconds -> loop-sample conversion used to
# build psth_win. (The former hard-coded "/ 20." only matched a 50 ms interval.)
loop_interval = params.daq.spike_rates.loop_interval
psth_xx = np.arange(psth_win[0], psth_win[1]) * loop_interval / 1000.0

# Per-channel normalization settings; they come from the config, so they are
# the same for every file and are read once.
norm_cfg = params.daq.normalization.channels
norm_ids = np.asarray([ch.id for ch in norm_cfg], dtype=int)  # int dtype: an empty list must still index
norm_bottoms = np.asarray([ch.bottom for ch in norm_cfg], dtype=float)
norm_tops = np.asarray([ch.top for ch in norm_cfg], dtype=float)
norm_invs = np.asarray([ch.invert for ch in norm_cfg], dtype=bool)
achcfg = params.daq.normalization.all_channels
min_rate = params.plot.filter_min_rate

# Control-channel subsets, as positions into norm_ids: each single channel,
# each subset with one channel left out, and the whole set.
n_norm = len(norm_ids)
subset_sizes = sorted({k for k in (1, n_norm - 1, n_norm) if k > 0})
ix_list = [ix for k in subset_sizes for ix in itertools.combinations(range(n_norm), k)]

for fids in range(data_tot.shape[0]):
    data = data_tot[fids, 0]
    # Use this file's own triggers (class 1 -> 'down', class 0 -> 'up');
    # previously file 0's triggers were applied to every file.
    psth_down = _extract_psth(data, triggers_tot[fids, 1][0], psth_win)
    psth_up = _extract_psth(data, triggers_tot[fids, 0][0], psth_win)
    if psth_down.shape[0] == 0 or psth_up.shape[0] == 0:
        print(f'{file_names[fids]}: no complete trials for one condition, skipping')
        continue

    # Normalized firing rates of the configured control channels.
    norm_down = _normalize(psth_down[:, :, norm_ids], norm_bottoms, norm_tops, norm_invs)
    norm_up = _normalize(psth_up[:, :, norm_ids], norm_bottoms, norm_tops, norm_invs)
    # Average across all channels first, then normalize (simulates the
    # 'use_all_channels' option). No squeeze: it would drop the trial axis
    # when there is a single trial.
    all_invert = bool(achcfg.invert)
    all_down = _normalize(np.nanmean(psth_down, axis=2), achcfg.bottom, achcfg.top, all_invert)
    all_up = _normalize(np.nanmean(psth_up, axis=2), achcfg.bottom, achcfg.top, all_invert)

    # Name outputs after THIS file; using file_names[0] made every file
    # overwrite the previous file's plots.
    fname = os.path.basename(file_names[fids]).split('.')[0]
    out_dir = f'{params.file_handling.results}/{fname}'
    os.makedirs(out_dir, exist_ok=True)

    # Every recorded channel (raw rates), every control subset (mean of the
    # normalized rates), then the all-channel average.
    for ch_id in [*range(data.shape[1]), *ix_list, 'all']:
        if ch_id == 'all':
            down, up = all_down, all_up
            ylabel, filter_min_rate, ch_label = NORM_YLABEL, False, ch_id
        elif isinstance(ch_id, tuple):
            down = np.mean(norm_down[:, :, list(ch_id)], axis=2)
            up = np.mean(norm_up[:, :, list(ch_id)], axis=2)
            ylabel, filter_min_rate = NORM_YLABEL, False
            # Label subsets by their configured channel ids, not positions.
            ch_label = np.sort(norm_ids[np.array(ch_id)])
        else:
            down, up = psth_down[:, :, ch_id], psth_up[:, :, ch_id]
            ylabel, filter_min_rate, ch_label = RAW_YLABEL, min_rate, ch_id

        # Skip channels whose trial-mean rate stays below plot.filter_min_rate
        # in both conditions at every time point.
        if filter_min_rate and max(np.mean(down, axis=0).max(),
                                   np.mean(up, axis=0).max()) < filter_min_rate:
            continue

        _plot_psth_summary(psth_xx, down, up, ylabel)
        fname2 = f'{out_dir}/{fname}_nf_{ch_label}.png'
        print(fname2)
        plt.savefig(fname2)

plt.show()
|