12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010 |
- '''
- description: offline analysis of kiap data
- author: Ioannis Vlachos
- date: 02.04.2019
- Copyright (c) 2019 Ioannis Vlachos.
- All rights reserved.'''
- import re
- import os
- import glob
- import importlib
- import urllib
- import time
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- from scipy import signal
- from scipy.fftpack import fftshift
- from scipy import stats
- from cachetools import cached, LRUCache, TTLCache
- from neo.io.blackrockio import BlackrockIO
- import aux
- from aux import log
- from analytics import analytics1
- importlib.reload(analytics1)
class lfp_analytics:
    '''Offline analysis of LFP data recorded with two electrode arrays (Blackrock files).

    Results are stored as attributes, so the methods build on each other: read_data() loads
    data0/data01/data02, preprocess() creates the sub-band data data1/data2, get_triggers*() and
    compute_psth*() cut trial windows around triggers, and the calc_*()/plot_*() methods compute
    and draw correlations, spectra and spectrograms.
    '''

    def __init__(self, params):
        '''Keep the analysis parameters (read via params.lfp.* and params.file_handling.*).'''
        self.params = params

    # ------------------------------------------------------------------ private helpers

    def _fig_prefix(self):
        '''Return the part of base_name after the first '-' (the whole base_name if there is none).'''
        parts = self.base_name.split('-')
        return parts[1] if len(parts) > 1 else self.base_name

    def _save_fig(self, fh, name):
        '''Save the current figure as results_path + name, creating missing sub-directories.'''
        fig_name = f'{self.results_path}{name}'
        os.makedirs(os.path.dirname(fig_name) or '.', exist_ok=True)
        plt.savefig(fig_name)
        print(f'saving figure {fh.number} in {fig_name}\n')

    def _select_data(self, arr_nr, sub_band):
        '''Return the raw data (sub_band == -1) or one filtered sub-band of array arr_nr (1 or 2).'''
        if arr_nr not in (1, 2):
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if sub_band == -1:
            return self.data01 if arr_nr == 1 else self.data02
        banded = self.data1 if arr_nr == 1 else self.data2
        return banded[:, :, sub_band]

    @staticmethod
    def _epoch(data, triggers, psth_win):
        '''Cut data[trigger + psth_win[0] : trigger + psth_win[1], :] for every trigger.

        Returns an array (trials, samples, channels). Windows that do not lie completely inside
        data are left as zeros so that trial indices stay aligned with the triggers.
        '''
        win_start, win_stop = int(psth_win[0]), int(psth_win[1])
        epochs = np.zeros((len(triggers), win_stop - win_start, data.shape[1]))
        for ii, trigger in enumerate(triggers):
            start, stop = int(trigger) + win_start, int(trigger) + win_stop
            if start >= 0 and stop <= data.shape[0]:
                epochs[ii] = data[start:stop, :]
        return epochs

    @staticmethod
    def _iqr_outliers(values, w):
        '''Boolean mask of entries outside [Q1 - w*IQR, Q3 + w*IQR]; quartiles are taken along axis 0.'''
        values = np.asarray(values)
        if values.shape[0] == 0:  # nothing to compare against (e.g. no trials)
            return np.zeros(values.shape, dtype=bool)
        q1 = np.percentile(values, 25, axis=0, keepdims=True)
        q3 = np.percentile(values, 75, axis=0, keepdims=True)
        iqr = q3 - q1
        return (values < q1 - w * iqr) | (values > q3 + w * iqr)

    @classmethod
    def _channel_mask(cls, data, exclude, w):
        '''Sorted unique column indices of data that are variance outliers or listed in exclude.'''
        outliers = np.flatnonzero(cls._iqr_outliers(np.var(data, axis=0), w))
        # np.unique sorts and also drops channels that are both outliers and explicitly excluded,
        # so the channel counts derived from the mask match the number of deleted columns
        return np.unique(np.hstack((outliers, exclude))).astype(int)

    # ------------------------------------------------------------------ loading / preprocessing

    def read_data(self, file_name):
        '''Read the array1 channels followed by the array2 channels from a Blackrock recording.

        Does nothing if data has already been loaded into this instance. Sets data0 (all selected
        channels, columns ordered array1 then array2), data01 / data02 (the columns of each array),
        reader, file_name, base_name, results_path (created if needed), ch_nr1 and ch_nr2.
        '''
        if hasattr(self, 'data0'):
            return None
        params = self.params
        array1 = list(params.lfp.array1)
        array2 = list(params.lfp.array2)
        reader = BlackrockIO(filename=file_name)
        # NOTE(review): eval() executes arbitrary code taken from the parameter values;
        # only use this with trusted configuration files.
        i_start = eval(str(params.lfp.i_start))
        i_stop = eval(str(params.lfp.i_stop))
        # Concatenate instead of taking a set union: a set does not keep array1 first, but the
        # column slicing below relies on that order.
        data0 = reader.get_analogsignal_chunk(i_start=i_start, i_stop=i_stop, channel_indexes=array1 + array2)
        if params.lfp.zscore:
            # z-score (scipy default axis=0) before the data is split into the two arrays
            data0 = stats.zscore(data0)
        self.data0 = data0
        self.data01 = data0[:, :len(array1)]  # array-1 electrodes
        self.data02 = data0[:, len(array1):len(array1) + len(array2)]  # array-2 electrodes
        self.reader = reader
        self.file_name = file_name
        self.base_name = os.path.basename(file_name).split('.')[0]
        self.results_path = params.file_handling.results + 'lfp/'
        os.makedirs(self.results_path, exist_ok=True)
        self.ch_nr1 = len(array1)
        self.ch_nr2 = len(array2)
        return None

    def preprocess(self):
        '''Exclude noisy channels, apply common average referencing and split into three sub-bands.

        Controlled by params.lfp.exclude, params.lfp.car and the filter_fc_*/filter_order_*
        parameters. Creates data1 / data2 with shape (samples, channels, 3) holding the low, middle
        and high band, and stores the (possibly re-referenced) data in data01 / data02.
        Does nothing if data1 already exists.
        '''
        if hasattr(self, 'data1'):
            print('preprocessed data is already in memory')
            return None
        lfp = self.params.lfp
        if lfp.exclude:
            # drops columns from self.data01 / self.data02 and sets the masks and channel counts
            self.remove_channels()
        else:
            self.ch_nr1 = len(lfp.array1)
            self.ch_nr2 = len(lfp.array2)
            self.array1_msk = []
            self.array2_msk = []
        # Read the data only now so that excluded channels are really gone from everything below.
        data01 = self.data01
        data02 = self.data02
        if lfp.car:
            # common average reference: subtract the mean over channels at every sample
            data01 = data01 - data01.mean(axis=1, keepdims=True)
            data02 = data02 - data02.mean(axis=1, keepdims=True)
        bands = (
            (lfp.filter_fc_lb, lfp.filter_order_lb),
            (lfp.filter_fc_mb, lfp.filter_order_mb),
            (lfp.filter_fc_hb, lfp.filter_order_hb),
        )
        self.data1 = np.zeros(data01.shape + (len(bands),))  # (samples, channels, 3)
        self.data2 = np.zeros(data02.shape + (len(bands),))  # (samples, channels, 3)
        for raw, banded in ((data01, self.data1), (data02, self.data2)):
            for band_idx, (fc, order) in enumerate(bands):
                banded[:, :, band_idx], _, _ = analytics1.bw_filter(raw.T, fc, lfp.fs, order, plot=lfp.plot.filters)
        self.data01 = data01
        self.data02 = data02
        return None

    # ------------------------------------------------------------------ triggers

    def get_triggers(self, n_triggers=10, verbose=True, speller='question'):
        '''Read the response-start triggers of a speller session from the NEV comments.

        speller='question' collects 'yes' (triggers1) and 'no' (triggers2) responses,
        speller='feedback' collects 'up' (triggers1) and 'down' (triggers2). Timestamps are
        divided by params.lfp.sampling_ratio to map them onto ns2 samples; at most n_triggers of
        each kind are kept.

        Returns (None, None) if the comments cannot be read, otherwise None.
        Raises ValueError for an unknown speller.
        '''
        if speller == 'question':
            tname1, tname2 = 'yes', 'no'
        elif speller == 'feedback':
            tname1, tname2 = 'up', 'down'
        else:
            raise ValueError(f"speller must be 'question' or 'feedback', got {speller!r}")
        try:
            comments = self.reader.nev_data['Comments'][0]
            triggers_ts = [tt[0] for tt in comments]
            triggers_txt = [tt[5].decode('utf-8') for tt in comments]
        except (AttributeError, KeyError, IndexError, TypeError, UnicodeDecodeError):
            log.error('Triggers could not be loaded. Does Nev file exist?')
            return None, None
        if verbose:
            print(self.reader.nev_data['Comments'])
        triggers_01 = []
        triggers_02 = []
        for ts, txt in zip(triggers_ts, triggers_txt):
            if f'{tname1}, response, start' in txt:
                triggers_01.append(ts)
            elif f'{tname2}, response, start' in txt:
                triggers_02.append(ts)
        ratio = self.params.lfp.sampling_ratio  # map nev timestamps to ns2 samples
        self.triggers1 = (np.array(triggers_01) // ratio)[:n_triggers]
        self.triggers2 = (np.array(triggers_02) // ratio)[:n_triggers]
        return None

    def get_triggers_mapping(self, verbose=True):
        '''Collect the motor-mapping triggers from the NEV comments and map them to ns2 samples.

        Comments containing '<label>, response, start' are sorted into one list per label, in the
        order Zunge, Schliesse_Hand, Oeffne_Hand, Bewege_Augen, Bewege_Kopf (assumed to match the
        order of params.lfp.motor_mapping -- TODO confirm). Triggers that map beyond the end of the
        loaded data are skipped. Sets self.triggers_mapping (one array of ns2 sample indices per
        state).

        Returns (None, None) if the comments cannot be read, otherwise None.
        '''
        states = self.params.lfp.motor_mapping
        labels = ('Zunge', 'Schliesse_Hand', 'Oeffne_Hand', 'Bewege_Augen', 'Bewege_Kopf')
        try:
            comments = self.reader.nev_data['Comments'][0]
            triggers_ts = [tt[0] for tt in comments]
            triggers_txt = [tt[5].decode('utf-8') for tt in comments]
        except (AttributeError, KeyError, IndexError, TypeError, UnicodeDecodeError):
            log.error('Triggers could not be loaded. Does Nev file exist?')
            return None, None
        if verbose:
            print(self.reader.nev_data['Comments'])
        ratio = self.params.lfp.sampling_ratio
        n_samples = len(self.data02)
        triggers = [[] for _ in range(len(states))]
        for ts, txt in zip(triggers_ts, triggers_txt):
            if ts // ratio > n_samples:
                continue
            for idx, label in enumerate(labels):
                if f'{label}, response, start' in txt:
                    triggers[idx].append(ts)
                    break
        # map nev triggers to ns2 samples
        self.triggers_mapping = [np.array(trig) // ratio for trig in triggers]
        return None

    def get_valid_trials(self, triggers_no_ns2, triggers_yes_ns2):
        '''Flag trials whose sub-band data exceeds +/- params.lfp.artifact_thr inside psth_win.

        Returns a boolean mask with one row per trial (at least 10 rows) and the columns
        (no, yes); False marks an artifact trial, rows without a trigger stay True.
        '''
        lfp = self.params.lfp
        psth_win = lfp.psth_win
        thr = lfp.artifact_thr
        n_rows = max(10, len(triggers_no_ns2), len(triggers_yes_ns2))
        trials_msk = np.ones((n_rows, 2), dtype=bool)
        for col, triggers in enumerate((triggers_no_ns2, triggers_yes_ns2)):
            for ii, tt in enumerate(triggers):
                tmp = self.data1[tt + psth_win[0]:tt + psth_win[1], :]
                if np.any(tmp > thr) or np.any(tmp < -thr):
                    trials_msk[ii, col] = False
        print('\nTrials excluded', np.flatnonzero(~trials_msk[:, 0]), np.flatnonzero(~trials_msk[:, 1]))
        return trials_msk

    # ------------------------------------------------------------------ correlations

    def calc_cc_time(self, save_fig=False, cc_step=1000):
        '''Plot channel-pair correlation coefficients of both arrays, whole record and over time.

        Top: lower triangle of the correlation matrix of array 1 / array 2. Middle: histograms of
        the pairwise coefficients (each channel pair counted once, no self-correlations). Bottom:
        min, max and mean pairwise coefficient in consecutive windows of cc_step samples.
        '''
        def _pairs(data):
            # full correlation matrix and the coefficients strictly below its diagonal
            cc = np.atleast_2d(np.corrcoef(data.T))
            return cc, cc[np.tril_indices_from(cc, k=-1)]

        def _summary(values):
            if values.size == 0:  # a single channel has no pairs
                return np.nan, np.nan, np.nan
            return values.min(), values.max(), values.mean()

        rng = range(0, self.data0.shape[0], cc_step)
        cc1 = np.array([_summary(_pairs(self.data01[idx:idx + cc_step, :])[1]) for idx in rng]).reshape(-1, 3)
        cc2 = np.array([_summary(_pairs(self.data02[idx:idx + cc_step, :])[1]) for idx in rng]).reshape(-1, 3)
        full1, pairs1 = _pairs(self.data01)
        full2, pairs2 = _pairs(self.data02)

        fh = plt.figure(figsize=(7, 10))
        plt.clf()
        for col, (full, pairs, color) in enumerate(((full1, pairs1, 'C0'), (full2, pairs2, 'C1'))):
            plt.subplot(3, 2, 1 + col)
            plt.pcolormesh(np.tril(full, k=-1))
            plt.clim(-1, 1)
            plt.title(f'cc, arr {col + 1}')
            plt.subplot(3, 2, 3 + col)
            plt.hist(pairs, color=color, label=f'{pairs.mean():.2f}')
            plt.legend(loc=1)
        plt.subplot(6, 1, 5)
        plt.plot(cc1, 'C0')
        plt.ylim(-1, 1)
        plt.xticks([])
        plt.subplot(6, 1, 6)
        plt.plot(cc2, 'C1')
        plt.ylim(-1, 1)
        if save_fig:
            self._save_fig(fh, f'{self._fig_prefix()}_cc.png')
        return None

    # ------------------------------------------------------------------ PSTHs

    def compute_psth(self, arr_nr=1, sub_band=-1):
        '''Cut psth_win windows around the 'no' (triggers2) and 'yes' (triggers1) triggers.

        sub_band=-1 uses the raw data of array arr_nr, otherwise the given filtered sub-band.
        Outlier trials are removed with remove_trials. Returns (psth_no, psth_yes), each of shape
        (trials, samples, channels).
        '''
        data_tmp = self._select_data(arr_nr, sub_band)
        psth_win = self.params.lfp.psth_win
        psth_no = self._epoch(data_tmp, self.triggers2, psth_win)
        psth_yes = self._epoch(data_tmp, self.triggers1, psth_win)
        psth_yes, psth_no = self.remove_trials(psth_yes, psth_no)
        return psth_no, psth_yes

    def compute_psth_mapping(self, arr_nr=1, ch_nr=None, sub_band=-1):
        '''Cut a psth_win window of data around every motor-mapping trigger.

        sub_band=-1 uses the raw data of array arr_nr, otherwise the given filtered sub-band.
        ch_nr is accepted for backward compatibility and ignored (all channels are used).
        Stores one (trials, samples, channels) array per state in psth_mapping1 / psth_mapping2;
        windows that do not fit into the data stay zero.
        '''
        data_tmp = self._select_data(arr_nr, sub_band)
        psth_win = self.params.lfp.psth_win
        psth = [self._epoch(data_tmp, trig, psth_win) for trig in self.triggers_mapping]
        if arr_nr == 1:
            self.psth_mapping1 = psth
        else:
            self.psth_mapping2 = psth
        return None

    def remove_channels(self):
        '''Remove channels with outlier variance plus the channels listed in array*_exclude.

        A channel is an outlier when its variance lies outside [Q1 - 3*IQR, Q3 + 3*IQR] of the
        channel variances of its array. Sets array1_msk / array2_msk (sorted unique column indices),
        deletes those columns from data01 / data02 and updates ch_nr1 / ch_nr2.
        '''
        w = 3  # width of the accepted range in units of the inter-quartile range
        lfp = self.params.lfp
        array1_msk = self._channel_mask(self.data01, lfp.array1_exclude, w)
        array2_msk = self._channel_mask(self.data02, lfp.array2_exclude, w)
        print(f'excluded channels, array1: {array1_msk}')
        print(f'excluded channels, array2: {array2_msk}')
        self.data01 = np.delete(self.data01, array1_msk, axis=1)
        self.data02 = np.delete(self.data02, array2_msk, axis=1)
        self.array1_msk = array1_msk
        self.array2_msk = array2_msk
        self.ch_nr1 = len(lfp.array1) - len(array1_msk)
        self.ch_nr2 = len(lfp.array2) - len(array2_msk)
        return None

    @classmethod
    def remove_trials(cls, psth_yes, psth_no):
        '''Remove trials in which any channel has an outlier variance over time.

        Per channel, a trial variance is an outlier outside [Q1 - 10*IQR, Q3 + 10*IQR] of that
        channel's trial variances. Inputs are (trials, samples, channels); returns the reduced
        (psth_yes, psth_no).
        '''
        w = 10
        yes_msk = np.unique(np.nonzero(cls._iqr_outliers(np.var(psth_yes, axis=1), w))[0])
        no_msk = np.unique(np.nonzero(cls._iqr_outliers(np.var(psth_no, axis=1), w))[0])
        print(f'excluded yes-trials: {yes_msk}')
        print(f'excluded no-trials: {no_msk}')
        return np.delete(psth_yes, yes_msk, axis=0), np.delete(psth_no, no_msk, axis=0)

    # ------------------------------------------------------------------ time-domain plots

    def plot_channels(self, ch_ids, arr_id=1, i_start=0, i_stop=10000, step=5, color='C0', save_fig=False):
        '''Plot channels of array arr_id stacked in the time domain.

        Empty ch_ids (or None) plots all channels; i_stop=-1 (or beyond the data) plots to the end.
        Data is z-scored for display unless it already was. Channels listed in the exclusion mask
        are highlighted in red, but only while they are still present in the data (i.e. before
        preprocess() removed them).
        '''
        i_start = int(i_start)
        i_stop = int(i_stop)
        fs = self.params.lfp.fs
        if arr_id == 1:
            data = self.data01
            array_msk = getattr(self, 'array1_msk', [])
            n_full = len(self.params.lfp.array1)
        elif arr_id == 2:
            data = self.data02
            array_msk = getattr(self, 'array2_msk', [])
            n_full = len(self.params.lfp.array2)
        else:
            raise ValueError(f'arr_id must be 1 or 2, got {arr_id!r}')
        if not self.params.lfp.zscore:  # z-score only for visualization
            data = stats.zscore(data)
        if ch_ids is None or len(ch_ids) == 0:
            ch_ids = range(data.shape[1])
        ch_ids = list(ch_ids)
        if i_stop == -1 or i_stop > data.shape[0]:
            i_stop = data.shape[0]
        offset = np.cumsum(4 * np.var(data[:, ch_ids], axis=0))  # vertical position of each trace
        t = np.arange(i_start, i_stop) / fs
        fh = plt.figure(figsize=(19, 10))
        plt.plot(t, data[i_start:i_stop, ch_ids] + offset, color=color, lw=1, alpha=0.8)
        excluded = {int(ch) for ch in array_msk}
        highlight = [pos for pos, ch in enumerate(ch_ids) if ch in excluded]
        if highlight and data.shape[1] == n_full:  # excluded channels are still in the data
            cols = [ch_ids[pos] for pos in highlight]
            plt.plot(t, data[i_start:i_stop, cols] + offset[highlight], color='C3', lw=1)
        plt.xlabel('Time (sec)')
        plt.yticks(offset[::step], range(0, len(ch_ids), step))
        plt.ylim(0, offset[-1] + 4)
        plt.title(f'raw data from array {arr_id}, n_ch={data.shape[1]}')
        plt.tight_layout()
        if save_fig:
            self._save_fig(fh, f'{self._fig_prefix()}_channels_time.png')
        return None

    # ------------------------------------------------------------------ spectra

    def plot_spectra_overview(self, ch_ids=(0,), xmax=150):
        '''Plot raw and sub-band spectra of channel ch_ids[0] of both arrays next to the channel mean.

        Requires calc_spectra(arr_nr=1) and calc_spectra(arr_nr=2). Replaces an earlier
        plot_spectra(ch_ids) that was shadowed by the later plot_spectra definition and referred to
        module-level names that no longer exist.
        '''
        mpl.rcParams['axes.formatter.limits'] = (-2, 2)
        ch = ch_ids[0]
        plt.figure(1, figsize=(10, 10))
        plt.clf()
        columns = (
            (self.ff01, self.Y01, self.ff1, self.Y1, self.params.lfp.array1),
            (self.ff02, self.Y02, self.ff2, self.Y2, self.params.lfp.array2),
        )
        for col, (ff_raw, Y_raw, ff_band, Y_band, array) in enumerate(columns):
            arr_nr = col + 1
            plt.subplot(5, 2, 1 + col)  # raw spectrum
            plt.plot(ff_raw, Y_raw[:, ch], 'C0', label=f'channel: {array[ch]}')
            plt.plot(ff_raw, Y_raw.mean(1), 'C3', alpha=0.2, label='average')
            plt.xlim(0, xmax)
            plt.title(f'array{arr_nr} raw')
            plt.gca().set_xticklabels([])
            plt.legend()
            for band in range(3):  # sub-band spectra
                plt.subplot(5, 2, 3 + 2 * band + col)
                plt.plot(ff_band, Y_band[:, ch, band], 'C0', label=f'array{arr_nr}')
                plt.plot(ff_band, Y_band[:, :, band].mean(1), 'C3', alpha=0.2)
                plt.xlim(0, xmax)
                if band < 2:
                    plt.gca().set_xticklabels([])
                plt.title(f'sub_band {band}')
            plt.subplot(5, 2, 9 + col)  # raw spectrum, log-log
            plt.loglog(ff_raw, Y_raw[:, ch], 'C0')
            plt.loglog(ff_raw, Y_raw.mean(1), 'C3', alpha=0.2)
            plt.title(f'array{arr_nr} raw loglog')
            plt.xlabel('Frequency (Hz)')
        plt.tight_layout()
        plt.draw()
        plt.show()
        return None

    def calc_spectra(self, arr_nr=1, axis=0, nperseg=10000, force=False, detrend='constant', onesided=True):
        '''Welch power spectra of the raw (Y01/Y02) and sub-band (Y1/Y2) data of array arr_nr.

        Frequencies are stored in ff01/ff02 and ff1/ff2. Cached results are reused unless
        force=True. For two-sided spectra the output is fftshift-ed so frequencies ascend; one-sided
        spectra are already ordered and are left untouched.
        '''
        if arr_nr == 1:
            if hasattr(self, 'Y1') and not force:
                return None
            raw, banded = self.data01, self.data1
        elif arr_nr == 2:
            if hasattr(self, 'Y2') and not force:
                return None
            raw, banded = self.data02, self.data2
        else:
            return None
        welch_kw = dict(fs=self.params.lfp.fs, axis=axis, nperseg=nperseg, detrend=detrend,
                        return_onesided=onesided, scaling='spectrum')
        ff_raw, Y_raw = signal.welch(raw, **welch_kw)
        ff, Y = signal.welch(banded, **welch_kw)
        if not onesided:
            # two-sided output is in FFT order (0, positive, negative frequencies)
            ff_raw, Y_raw = fftshift(ff_raw), fftshift(Y_raw, axes=axis)
            ff, Y = fftshift(ff), fftshift(Y, axes=axis)
        if arr_nr == 1:
            self.Y01, self.ff01, self.Y1, self.ff1 = Y_raw, ff_raw, Y, ff
        else:
            self.Y02, self.ff02, self.Y2, self.ff2 = Y_raw, ff_raw, Y, ff
        return None

    def plot_spectra_original(self, arr_nr=1, ch_ids=None):
        '''Plot the raw-data spectrum of every channel in ch_ids (default: all) on a semilog grid.

        Requires calc_spectra(arr_nr). x ticks are only kept on panels without a panel below them,
        y ticks only in the first column.
        '''
        mpl.rcParams['axes.formatter.limits'] = (-2, 2)
        if arr_nr == 1:
            Y, ff = self.Y01, self.ff01
        elif arr_nr == 2:
            Y, ff = self.Y02, self.ff02
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        plt.figure(figsize=(10, 10))
        plt.clf()
        if ch_ids is None or len(ch_ids) == 0:
            ch_ids = np.arange(Y.shape[1])
        n_rows = max(1, int(np.round(np.sqrt(len(ch_ids)))))
        n_cols = int(np.ceil(len(ch_ids) / n_rows))
        for ii, ch_id in enumerate(ch_ids):
            ax = plt.subplot(n_rows, n_cols, ii + 1)
            plt.semilogy(ff, Y[:, ch_id], 'C0')
            if ii + n_cols < len(ch_ids):  # another panel is drawn below this one
                ax.set_xticks([])
            if ii % n_cols != 0:
                ax.set_yticks([])
        plt.draw()
        return None

    def plot_spectra_bands(self, arr_nr=1, band_id=0, ch_ids=None, log=False, erase=True, xmin=0, xmax=30, ymin=0, ymax=10, save_fig=False):
        '''Plot sub-band band_id of the spectrum of every channel in ch_ids in its own panel.

        Requires calc_spectra(arr_nr). The first panel also shows the mean over all channels
        (black). erase=False draws into the current figure; log=True uses a logarithmic y axis.
        '''
        mpl.rcParams['axes.formatter.limits'] = (-2, 2)
        if arr_nr == 1:
            Y, ff = self.Y1, self.ff1
        elif arr_nr == 2:
            Y, ff = self.Y2, self.ff2
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if erase:
            fh = plt.figure(figsize=(19, 10))
            plt.clf()
        else:
            fh = plt.gcf()
        if ch_ids is None or len(ch_ids) == 0:
            ch_ids = np.arange(Y.shape[1])
        p1 = int(np.round(np.sqrt(len(ch_ids))))  # p1 * (p1 + 1) panels always suffice
        fidx = (ff > xmin) & (ff < xmax)
        plot = plt.semilogy if log else plt.plot
        for ii, ch_id in enumerate(ch_ids):
            ax = plt.subplot(p1, p1 + 1, ii + 1)
            plot(ff[fidx], Y[fidx, ch_id, band_id], lw=1, label=f'{ch_id}')
            if ii == 0:
                plot(ff[fidx], Y[fidx, :, band_id].mean(axis=1), 'k', alpha=0.7, lw=1, label='mean')
            plt.xlim(xmin, xmax)
            plt.ylim(ymin, ymax)
            if ch_id < Y.shape[1] - 1:
                ax.set_xticks([])
            else:
                plt.xlabel('Hz')
            if ch_id > 0:
                ax.set_yticks([])
            plt.legend(loc=1, prop={'size': 6})
        plt.tight_layout()
        if save_fig:
            self._save_fig(fh, f'{self._fig_prefix()}_spectra.png')
        plt.draw()
        return None

    # ------------------------------------------------------------------ spectrograms

    def calc_stft(self, arr_nr=1, nperseg=10000):
        '''Spectrogram (90 % overlap) of all raw channels of array arr_nr, stored in S1 / S2.

        Frequencies and segment times go to ff_stft / tt_stft.
        '''
        if arr_nr == 1:
            data = self.data01
        elif arr_nr == 2:
            data = self.data02
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        ff, tt, S = signal.spectrogram(data, self.params.lfp.fs, axis=0, nperseg=nperseg, noverlap=int(nperseg * 0.9))
        self.ff_stft = ff
        self.tt_stft = tt
        if arr_nr == 1:
            self.S1 = S
        else:
            self.S2 = S
        return None

    def calc_stft_psth(self, arr_nr=1, nperseg=1000, baseline=(70000, 250000)):
        '''Spectrograms of every motor-mapping trial plus a baseline spectrum for normalization.

        The baseline is the time-averaged spectrogram of samples baseline[0]:baseline[1] of the raw
        array data (the default is the previously hard-coded window). Results:
        psth_S{arr_nr}_mapping (one array per condition, indexed [trial, freq, channel, time] by
        plot_psth_stft), psth_S{arr_nr}_mean (baseline repeated over the time bins, indexed
        [freq, channel, time]) and ff_stft_mapping / tt_stft_mapping.
        '''
        if arr_nr == 1:
            psth, data = self.psth_mapping1, self.data01
        elif arr_nr == 2:
            psth, data = self.psth_mapping2, self.data02
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        fs = self.params.lfp.fs
        noverlap = int(nperseg * 0.9)
        _, _, S_base = signal.spectrogram(data[baseline[0]:baseline[1]], fs, axis=0, nperseg=nperseg, noverlap=noverlap)
        S_base = S_base.mean(axis=2)  # mean over time bins, used for normalization
        S_tot = []
        for cond_psth in psth:
            ff, tt, S = signal.spectrogram(cond_psth, fs, axis=1, nperseg=nperseg, noverlap=noverlap)
            self.ff_stft_mapping = ff
            self.tt_stft_mapping = tt
            S_tot.append(S)
        Smean = np.repeat(S_base[:, :, np.newaxis], tt.size, axis=2)  # match the psth time bins
        if arr_nr == 1:
            self.psth_S1_mapping = S_tot
            self.psth_S1_mean = Smean
        else:
            self.psth_S2_mapping = S_tot
            self.psth_S2_mean = Smean
        return None

    def plot_stft(self, arr_nr=1, fmin=0, fmax=200, save_fig=True):
        '''Plot log10 of the channel-averaged spectrogram from calc_stft(arr_nr) between fmin and fmax.'''
        if arr_nr == 1:
            S = self.S1
        elif arr_nr == 2:
            S = self.S2
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        fh = plt.figure()
        fidx = (self.ff_stft > fmin) & (self.ff_stft < fmax)
        plt.pcolormesh(self.tt_stft, self.ff_stft[fidx], np.log10(S[fidx, :, :].mean(axis=1)))
        plt.xlabel('Time (sec)')
        plt.ylabel('Hz')
        plt.title(f'stft, arr: {arr_nr}, average {S.shape[1]} channels')
        if save_fig:
            self._save_fig(fh, f'{self._fig_prefix()}_arr_{arr_nr}_stft.png')
        return None

    def plot_psth(self, ff=[], sub_band=0, ch_id=0, trial_id=0, step=10):
        ''' plot psth for no and yes responses separately

        NOTE(review): this method still refers to names from an earlier script version (data01,
        data1, params, psth_win, triggers_no_ns2, triggers_yes_ns2, psth_no, psth_yes, tt, Sxx1,
        Sxx2, Sxx1_norm, Sxx2_norm) that this class does not define, so it raises NameError unless
        they exist as module globals. TODO: port to instance attributes.
        '''
        yy = max(data01.std(), data1[:, ch_id, sub_band].std()) * 2
        xx1 = np.arange(data1.shape[0]) / params.lfp.fs
        xx2 = np.arange(psth_win[0], psth_win[1]) / params.lfp.fs
        f_idx = ff < 150
        plt.figure(figsize=(10, 10))
        plt.clf()
        log.info(f'Selected sub-band: {sub_band}')
        plt.subplot(411)
        plt.plot(xx1[::step], data01[::step, ch_id], label='raw ns2')  # plot raw
        plt.plot(xx1[::step], data1[::step, ch_id, sub_band], 'C1', alpha=0.5, label=f'sub-band: {sub_band}')  # plot sub_band (low, middle, high)
        plt.stem(triggers_no_ns2 / params.lfp.fs, triggers_no_ns2 * 0 + yy * 0.8, markerfmt='C3o', label='no')  # red
        plt.stem(triggers_yes_ns2 / params.lfp.fs, triggers_yes_ns2 * 0 + yy * 0.8, markerfmt='C2o', label='yes')  # green
        plt.ylim(-yy, yy)
        plt.xlabel('sec')
        plt.legend(loc=1)
        plt.subplot(423)  # plot psth no-questions
        plt.plot(xx2, psth_no.mean(0)[:, ch_id], label='trial av. no')  # average across trials, plot 1 channel
        plt.plot(xx2, psth_no.mean(0).mean(1), 'C3', alpha=0.4, label='trial-ch av. no')  # average across trials and channels
        plt.xlabel('sec')
        plt.legend()
        plt.subplot(424)  # plot psth yes-questions
        plt.plot(xx2, psth_yes.mean(0)[:, ch_id], label='trial av. yes')  # average across trials, plot 1 channel
        plt.plot(xx2, psth_yes.mean(0).mean(1), 'C3', alpha=0.4, label='trial-ch av. yes')  # average across trials and channels
        plt.xlabel('sec')
        plt.legend()
        plt.subplot(425)  # plot spectrogram no-question, single channel and trial
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[trial_id, f_idx, ch_id, :] / Sxx1_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[trial_id, f_idx, ch_id, :])
        vmax = max(Sxx1[trial_id, f_idx, ch_id, :].max(), Sxx2[trial_id, f_idx, ch_id, :].max())
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.subplot(426)  # plot spectrogram yes-question, single channel and trial
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[trial_id, f_idx, ch_id, :] / Sxx2_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[trial_id, f_idx, ch_id, :])
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.subplot(427)  # plot spectrogram no-question, averaged
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[:, f_idx, ch_id, :].mean(0) / Sxx1_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[:, f_idx, ch_id, :].mean(0))
        vmax = max(Sxx1[trial_id, f_idx, ch_id, :].max(), Sxx2[trial_id, f_idx, ch_id, :].max())
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.subplot(428)  # plot spectrogram yes-question, averaged
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[:, f_idx, ch_id, :].mean(0) / Sxx2_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[:, f_idx, ch_id, :].mean(0))
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.tight_layout()
        plt.show()
        return None

    def plot_psth_time(self, ch_ids=None, save_fig=False, arr_nr=1):
        '''Plot the trial mean (blue) and median (black) PSTH of every condition per channel.

        Uses psth_mapping1 / psth_mapping2 from compute_psth_mapping(arr_nr). The vertical line
        marks the trigger, i.e. sample -psth_win[0] of each window. Figure 1 is redrawn for every
        channel in ch_ids and optionally saved under results_path/psth/.
        '''
        states = self.params.lfp.motor_mapping
        if arr_nr == 1:
            psth = self.psth_mapping1
        elif arr_nr == 2:
            psth = self.psth_mapping2
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if ch_ids is None:
            ch_ids = []
        trigger_idx = -self.params.lfp.psth_win[0]  # position of the trigger inside the window
        n_cond = len(psth)
        n_rows = max(2, -(-n_cond // 3))
        for ch_id in ch_ids:
            fh = plt.figure(1)
            plt.clf()
            for ii in range(n_cond):
                trials = psth[ii][:, :, ch_id]
                plt.subplot(n_rows, 3, ii + 1)
                plt.plot(np.mean(trials, axis=0), 'C0', alpha=0.5)
                plt.plot(np.median(trials, axis=0), 'k', alpha=0.5)
                plt.ylim(-200, 200)
                plt.vlines(trigger_idx, -500, 500, alpha=0.7)
                plt.title(f'{states[ii]}, n={psth[ii].shape[0]}')
                if ii + 3 < n_cond:  # another panel is drawn below this one
                    plt.xticks([])
            plt.tight_layout()
            if save_fig:
                self._save_fig(fh, f'psth/{self._fig_prefix()}_psth_time_ch_{ch_id}_.png')

    def plot_psth_stft(self, arr_nr=1, ch_ids=None, fmin=0, fmax=200, save_fig=False):
        '''Plot the trial-median spectrogram of every condition divided by the baseline spectrum.

        Requires calc_stft_psth(arr_nr). Figure 1 is redrawn for every channel in ch_ids and
        optionally saved under results_path/psth/stft/.
        '''
        states = self.params.lfp.motor_mapping
        if arr_nr == 1:
            psth = self.psth_S1_mapping
            S_mean = self.psth_S1_mean
        elif arr_nr == 2:
            psth = self.psth_S2_mapping
            S_mean = self.psth_S2_mean
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if ch_ids is None:
            ch_ids = []
        ff = self.ff_stft_mapping
        tt = self.tt_stft_mapping
        fidx = (ff > fmin) & (ff < fmax)
        n_cond = len(psth)
        n_rows = max(2, -(-n_cond // 3))
        for ch_id in ch_ids:
            fh = plt.figure(1)
            plt.clf()
            for cond in range(n_cond):
                plt.subplot(n_rows, 3, cond + 1)
                plt.pcolormesh(tt, ff[fidx], np.median(psth[cond][:, fidx, ch_id, :], axis=0) / S_mean[fidx, ch_id, :])
                plt.title(f'{states[cond]}, n={psth[cond].shape[0]}')
                if cond + 3 < n_cond:  # another panel is drawn below this one
                    plt.xticks([])
            plt.tight_layout()
            if save_fig:
                self._save_fig(fh, f'psth/stft/{self._fig_prefix()}_psth_stft_arr_{arr_nr}_ch_{ch_id}_.png')

    # ------------------------------------------------------------------ rate-data plots

    def plot_pdf(self, f_id=0, ch_ids=range(32), xmax=50, ymax=1000):
        '''Histogram of the values of every channel in ch_ids of self.data[f_id, 0] on an 8x8 grid.

        NOTE(review): self.data is not set anywhere in this class (the 'sp/sec' label suggests
        firing-rate data from another analysis) -- TODO confirm where it comes from.
        '''
        data = self.data[f_id, 0]
        plt.figure()
        plt.clf()
        for pos, ch in enumerate(ch_ids):
            ax = plt.subplot(8, 8, pos + 1)
            hist, bin_edges = np.histogram(data[:, ch], range(-1, 50))
            plt.bar(bin_edges[:-1], hist, width=1, label=f'{ch}')
            plt.xlim(0, xmax)
            plt.ylim(0, ymax)
            plt.legend(loc=1, prop={'size': 6})
            if pos < 56:  # only the bottom row of the 8x8 grid keeps its x ticks
                ax.set_xticks([])
            else:
                ax.set_xlabel('sp/sec')
            if pos % 8 > 0:
                ax.set_yticks([])
        return None

    def plot_spectra(self, f_id=0, recompute=True, fs=20, v_thr=5, ymin=0, ymax=7, arr_id=0):
        '''Plot Welch spectra (0.1-5 Hz) of 64 channels of self.data[f_id, 0] on an 8x8 grid.

        arr_id=0 selects channels 32-95, any other value channels 0-31 and 96-127. Channels whose
        mean value exceeds v_thr are drawn thick in blue, the others thin and semi-transparent.
        The spectra are cached in self.Y / self.ff and recomputed when recompute=True.

        NOTE(review): self.data is not set anywhere in this class -- TODO confirm its origin.
        '''
        data = self.data[f_id, 0]
        if arr_id == 0:
            ch_ids = range(32, 96)
        else:
            ch_ids = list(range(32)) + list(range(96, 128))
        if not hasattr(self, 'Y') or recompute:
            ff, Y = signal.welch(data, fs=fs, axis=0, nperseg=1000)
            self.Y = Y
            self.ff = ff
        fidx = (self.ff > 0.1) & (self.ff < 5)
        # channels with a mean value above v_thr
        active = set(np.flatnonzero(data.mean(axis=0) > v_thr).tolist())
        plt.figure()
        plt.clf()
        for ii, ch_id in enumerate(ch_ids):
            ax = plt.subplot(8, 8, ii + 1)
            if ch_id in active:
                plt.plot(self.ff[fidx], self.Y[fidx, ch_id], 'C0', lw=2, label=f'{ch_id}')
            else:
                plt.plot(self.ff[fidx], self.Y[fidx, ch_id], 'k', lw=1, alpha=0.5, label=f'{ch_id}')
            plt.legend(loc=1, prop={'size': 6})
            plt.ylim(ymin, ymax)
            if ii < 56:  # tick placement follows the panel position, not the channel number
                ax.set_xticks([])
            else:
                ax.set_xlabel('Hz')
            if ii % 8 > 0:
                ax.set_yticks([])
        return None
|