'''
description: offline analysis of kiap data
author: Ioannis Vlachos
date: 02.04.2019

Copyright (c) 2019 Ioannis Vlachos.
All rights reserved.'''

import re
import os
import glob
import importlib
import urllib
import time

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy import signal
from scipy.fftpack import fftshift
from scipy import stats
from cachetools import cached, LRUCache, TTLCache
from neo.io.blackrockio import BlackrockIO

import aux
from aux import log
from analytics import analytics1

importlib.reload(analytics1)


def _iqr_outlier_mask(values, w):
    '''Boolean mask of the entries of `values` outside [Q1 - w * IQR, Q3 + w * IQR].

    Quartiles are taken along axis 0, so for a 2-D input every column gets its
    own bounds.'''
    q1 = np.percentile(values, 25, axis=0, keepdims=True)
    q3 = np.percentile(values, 75, axis=0, keepdims=True)
    iqr = q3 - q1
    return (values < q1 - w * iqr) | (values > q3 + w * iqr)


def _lower_triangle(mat):
    '''Entries strictly below the main diagonal of a square matrix, as a 1-D array.'''
    return mat[np.tril_indices_from(mat, k=-1)]


def _channel_list(ch_ids, n_channels):
    '''Return `ch_ids`, or all column indices 0..n_channels-1 if it is None or empty.'''
    if ch_ids is None or len(ch_ids) == 0:
        return np.arange(n_channels)
    return ch_ids


class lfp_analytics:
    '''Offline LFP analysis of a Blackrock recording from two electrode arrays.

    Methods communicate through instance attributes set by earlier steps, e.g.
    preprocess() needs read_data(), compute_psth_mapping() needs
    get_triggers_mapping(), calc_stft_psth() needs compute_psth_mapping().'''

    def __init__(self, params):
        # params: nested config object; this class reads params.lfp.* and
        # params.file_handling.results
        self.params = params

    # ------------------------------------------------------------------ helpers
    def _fig_stem(self):
        '''Second '-'-separated field of self.base_name, used as figure-file prefix
        (the whole base name if it contains no '-').'''
        parts = self.base_name.split('-')
        return parts[1] if len(parts) > 1 else parts[0]

    def _save_figure(self, fh, rel_name):
        '''Save figure `fh` as results_path + rel_name, creating sub-directories.'''
        fig_name = f'{self.results_path}{rel_name}'
        os.makedirs(os.path.dirname(fig_name) or '.', exist_ok=True)
        fh.savefig(fig_name)
        print(f'saving figure {fh.number} in {fig_name}\n')

    def _select_data(self, arr_nr, sub_band):
        '''Raw data of array arr_nr (sub_band == -1) or the given sub-band of it.'''
        if arr_nr not in (1, 2):
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if sub_band == -1:
            return self.data01 if arr_nr == 1 else self.data02
        return (self.data1 if arr_nr == 1 else self.data2)[:, :, sub_band]

    @staticmethod
    def _cut_epochs(data, triggers, psth_win):
        '''Cut data[t + psth_win[0]:t + psth_win[1]] around every trigger t.

        Returns (trials, samples, channels); trials whose window does not fit
        inside `data` are left as zeros.'''
        psth_len = int(np.diff(psth_win)[0])
        epochs = np.zeros((len(triggers), psth_len, data.shape[1]))
        for ii, trig in enumerate(triggers):
            start, stop = trig + psth_win[0], trig + psth_win[1]
            if start >= 0 and stop <= len(data):
                epochs[ii] = data[start:stop, :]
        return epochs

    def _psth_mapping(self, arr_nr=None):
        '''psth list from compute_psth_mapping for arr_nr.

        arr_nr=None: an externally assigned self.psth_mapping is used if present,
        otherwise array 1.'''
        if arr_nr is None:
            if hasattr(self, 'psth_mapping'):
                return self.psth_mapping
            arr_nr = 1
        if arr_nr == 1:
            return self.psth_mapping1
        if arr_nr == 2:
            return self.psth_mapping2
        raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')

    # ---------------------------------------------------------------- loading
    def read_data(self, file_name):
        '''Read the array1 channels followed by the array2 channels of `file_name`.

        Sets data0 (all read channels), data01 / data02 (per array, samples x
        channels), reader, file_name, base_name, results_path, ch_nr1, ch_nr2.
        Does nothing if data0 is already loaded.'''
        if hasattr(self, 'data0'):
            return None
        params = self.params
        array1 = list(params.lfp.array1)
        array2 = list(params.lfp.array2)
        reader = BlackrockIO(filename=file_name)
        # NOTE(review): eval() executes whatever the config holds -- only safe
        # for trusted config files.
        data0 = reader.get_analogsignal_chunk(
            i_start=eval(str(params.lfp.i_start)),
            i_stop=eval(str(params.lfp.i_stop)),
            # concatenation (not a set union) guarantees the column order
            # array1, array2 that the slicing below relies on
            channel_indexes=array1 + array2)
        if params.lfp.zscore:
            data0 = stats.zscore(data0)  # per-channel z-score
        self.data0 = data0
        self.data01 = data0[:, :len(array1)]  # array-1 electrodes
        self.data02 = data0[:, len(array1):len(array1) + len(array2)]  # array-2 electrodes
        self.reader = reader
        self.file_name = file_name
        self.base_name = os.path.basename(file_name).split('.')[0]
        self.results_path = params.file_handling.results + 'lfp/'
        os.makedirs(self.results_path, exist_ok=True)
        self.ch_nr1 = len(array1)
        self.ch_nr2 = len(array2)
        return None

    def preprocess(self):
        '''Exclude outlier channels (params.lfp.exclude), apply common average
        referencing (params.lfp.car) and band-pass filter both arrays into the
        lb / mb / hb sub-bands.

        Sets data1 / data2 with shape (samples, channels, 3) and updates data01 /
        data02, ch_nr1 / ch_nr2, array1_msk / array2_msk. Does nothing if data1
        already exists.'''
        if hasattr(self, 'data1'):
            print('preprocessed data is already in memory')
            return None
        params = self.params
        if params.lfp.exclude:
            # remove_channels() replaces self.data01/02 with the reduced arrays
            # and sets ch_nr*/array*_msk, so the data must be read AFTER it
            self.remove_channels()
        else:
            self.ch_nr1 = len(params.lfp.array1)
            self.ch_nr2 = len(params.lfp.array2)
            self.array1_msk = []
            self.array2_msk = []
        data01 = self.data01
        data02 = self.data02
        if params.lfp.car:
            # common average reference: subtract the across-channel mean per sample
            data01 = data01 - data01.mean(axis=1, keepdims=True)
            data02 = data02 - data02.mean(axis=1, keepdims=True)
        bands = ((params.lfp.filter_fc_lb, params.lfp.filter_order_lb),
                 (params.lfp.filter_fc_mb, params.lfp.filter_order_mb),
                 (params.lfp.filter_fc_hb, params.lfp.filter_order_hb))
        data1 = np.zeros((data01.shape[0], data01.shape[1], len(bands)))
        data2 = np.zeros((data02.shape[0], data02.shape[1], len(bands)))
        # all array-1 bands first, then all array-2 bands
        for out, raw in ((data1, data01), (data2, data02)):
            for band_id, (fc, order) in enumerate(bands):
                out[:, :, band_id], _, _ = analytics1.bw_filter(
                    raw.T, fc, params.lfp.fs, order, plot=params.lfp.plot.filters)
        self.data1 = data1
        self.data2 = data2
        self.data01 = data01
        self.data02 = data02
        return None

    # --------------------------------------------------------------- triggers
    def get_triggers(self, n_triggers=10, verbose=True, speller='question'):
        '''Collect "<name>, response, start" comment triggers from the nev data and
        map them to ns2 sample indices.

        speller='question': 'yes' -> self.triggers1, 'no' -> self.triggers2;
        speller='feedback': 'up' -> self.triggers1, 'down' -> self.triggers2.
        Only the first n_triggers of each are kept.
        Returns None on success, (None, None) if the triggers cannot be read.
        Raises ValueError for an unknown speller.'''
        speller_names = {'question': ('yes', 'no'), 'feedback': ('up', 'down')}
        if speller not in speller_names:
            raise ValueError(f'unknown speller: {speller!r}')
        tname1, tname2 = speller_names[speller]
        try:
            comments = self.reader.nev_data['Comments'][0]
            triggers_ts = [tt[0] for tt in comments]
            triggers_txt = [tt[5].decode('utf-8') for tt in comments]
        except Exception:
            # best effort: no reader, no nev file, or malformed comments
            log.error('Triggers could not be loaded. Does Nev file exist?')
            return None, None
        if verbose:
            print(self.reader.nev_data['Comments'])
        triggers_01 = []
        triggers_02 = []
        for ts, txt in zip(triggers_ts, triggers_txt):
            if f'{tname1}, response, start' in txt:
                triggers_01.append(ts)
            elif f'{tname2}, response, start' in txt:
                triggers_02.append(ts)
        ratio = self.params.lfp.sampling_ratio
        # map nev timestamps to ns2 sample indices
        self.triggers1 = (np.array(triggers_01) // ratio)[:n_triggers]
        self.triggers2 = (np.array(triggers_02) // ratio)[:n_triggers]
        return None

    def get_triggers_mapping(self, verbose=True):
        '''Collect motor-mapping triggers from the nev data as ns2 sample indices.

        Sets self.triggers_mapping: one array per entry of
        params.lfp.motor_mapping; slot k collects "<label_k>, response, start"
        comments for the labels below (labels without a slot are ignored).
        Triggers mapping past the end of data02 are dropped.
        Returns None on success, (None, None) if the triggers cannot be read.'''
        states = self.params.lfp.motor_mapping
        labels = ('Zunge', 'Schliesse_Hand', 'Oeffne_Hand', 'Bewege_Augen', 'Bewege_Kopf')
        try:
            comments = self.reader.nev_data['Comments'][0]
            triggers_ts = [tt[0] for tt in comments]
            triggers_txt = [tt[5].decode('utf-8') for tt in comments]
        except Exception:
            # best effort: no reader, no nev file, or malformed comments
            log.error('Triggers could not be loaded. Does Nev file exist?')
            return None, None
        if verbose:
            print(self.reader.nev_data['Comments'])
        ratio = self.params.lfp.sampling_ratio
        n_samples = len(self.data02)
        triggers = [[] for _ in states]
        for ts, txt in zip(triggers_ts, triggers_txt):
            if ts // ratio > n_samples:
                continue
            for state_id, label in enumerate(labels[:len(triggers)]):
                if f'{label}, response, start' in txt:
                    triggers[state_id].append(ts)
                    break  # first matching label wins
        # map nev timestamps to ns2 sample indices
        self.triggers_mapping = [np.array(tt) // ratio for tt in triggers]
        return None

    def get_valid_trials(self, triggers_no_ns2, triggers_yes_ns2):
        '''Flag trials whose data crosses +/- params.lfp.artifact_thr inside the
        psth window.

        Returns a boolean array (n_trials, 2): column 0 'no' trials, column 1
        'yes' trials, True = valid; n_trials is the larger trigger count.'''
        psth_win = self.params.lfp.psth_win
        thr = self.params.lfp.artifact_thr
        n_trials = max(len(triggers_no_ns2), len(triggers_yes_ns2))
        trials_msk = np.ones((n_trials, 2), dtype=bool)
        for col, trig_list in enumerate((triggers_no_ns2, triggers_yes_ns2)):
            for ii, tt in enumerate(trig_list):
                # NOTE(review): the legacy code read a module-level `data1`;
                # self.data1 (sub-band data) is assumed here -- confirm.
                segment = self.data1[tt + psth_win[0]:tt + psth_win[1]]
                if np.any(np.abs(segment) > thr):
                    trials_msk[ii, col] = False
        print('\nTrials excluded', np.where(~trials_msk[:, 0])[0], np.where(~trials_msk[:, 1])[0])
        return trials_msk

    # ------------------------------------------------------------ correlation
    def calc_cc_time(self, save_fig=False, cc_step=1000):
        '''Plot channel-pair correlation coefficients of both arrays.

        Top: lower-triangle correlation matrices of the whole recording and
        histograms of the distinct pairwise coefficients. Bottom: min / max /
        mean of the pairwise coefficients in consecutive windows of cc_step
        samples. Only pairs below the diagonal enter the statistics (the zeroed
        upper triangle and diagonal used to bias them).'''
        starts = range(0, self.data0.shape[0], cc_step)

        def window_stats(data):
            out = np.zeros((len(starts), 3))
            for ii, idx in enumerate(starts):
                pairs = _lower_triangle(np.corrcoef(data[idx:idx + cc_step, :].T))
                out[ii] = pairs.min(), pairs.max(), pairs.mean()
            return out

        cc1 = window_stats(self.data01)
        cc2 = window_stats(self.data02)
        cc01 = np.tril(np.corrcoef(self.data01.T), k=-1)  # matrices for display
        cc02 = np.tril(np.corrcoef(self.data02.T), k=-1)
        pairs01 = _lower_triangle(cc01)
        pairs02 = _lower_triangle(cc02)

        fh = plt.figure(figsize=(7, 10))
        plt.clf()
        plt.subplot(3, 2, 1)
        plt.pcolormesh(cc01)
        plt.clim(-1, 1)
        plt.title('cc, arr 1')
        plt.subplot(3, 2, 2)
        plt.pcolormesh(cc02)
        plt.clim(-1, 1)
        plt.title('cc, arr 2')
        plt.subplot(3, 2, 3)
        plt.hist(pairs01, color='C0', label=f'{pairs01.mean():.2f}')
        plt.legend(loc=1)
        plt.subplot(3, 2, 4)
        plt.hist(pairs02, color='C1', label=f'{pairs02.mean():.2f}')
        plt.legend(loc=1)
        plt.subplot(6, 1, 5)
        plt.plot(cc1, 'C0')
        plt.ylim(-1, 1)
        plt.xticks([])
        plt.subplot(6, 1, 6)
        plt.plot(cc2, 'C1')
        plt.ylim(-1, 1)
        if save_fig:
            self._save_figure(fh, f'{self._fig_stem()}_cc.png')
        return None

    # ------------------------------------------------------------------- psth
    def compute_psth(self, arr_nr=1, sub_band=-1):
        '''Cut 'no' (self.triggers2) and 'yes' (self.triggers1) trials from array
        arr_nr (raw data if sub_band == -1, else that sub-band) with the window
        params.lfp.psth_win (samples relative to the trigger), then drop outlier
        trials via remove_trials.

        Returns (psth_no, psth_yes), each (trials, samples, channels).'''
        data_tmp = self._select_data(arr_nr, sub_band)
        psth_win = self.params.lfp.psth_win
        psth_no = self._cut_epochs(data_tmp, self.triggers2, psth_win)
        psth_yes = self._cut_epochs(data_tmp, self.triggers1, psth_win)
        psth_yes, psth_no = self.remove_trials(psth_yes, psth_no)
        return psth_no, psth_yes

    def compute_psth_mapping(self, arr_nr=1, ch_nr=np.arange(32), sub_band=-1):
        '''Cut trials of every motor-mapping condition (self.triggers_mapping)
        from array arr_nr (raw data if sub_band == -1, else that sub-band).

        Sets self.psth_mapping1 / psth_mapping2: one (trials, samples, channels)
        array per condition; trials whose window leaves the data stay zero.
        `ch_nr` is ignored (kept for backward compatibility; all channels are used).'''
        data_tmp = self._select_data(arr_nr, sub_band)
        psth_win = self.params.lfp.psth_win
        psth = [self._cut_epochs(data_tmp, trig, psth_win) for trig in self.triggers_mapping]
        if arr_nr == 1:
            self.psth_mapping1 = psth
        else:
            self.psth_mapping2 = psth
        return None

    # ------------------------------------------------------- outlier removal
    def remove_channels(self):
        '''Drop outlier channels from data01 / data02.

        A channel is an outlier if its variance lies outside
        [Q1 - 3 * IQR, Q3 + 3 * IQR] of its array's channel variances; the
        column indices in params.lfp.array1_exclude / array2_exclude are always
        dropped too. Sets array1_msk / array2_msk (sorted, unique column indices)
        and ch_nr1 / ch_nr2.'''
        w = 3  # manual parameter that affects the valid range
        lfp = self.params.lfp
        masks = []
        for data, manual in ((self.data01, lfp.array1_exclude), (self.data02, lfp.array2_exclude)):
            outliers = np.flatnonzero(_iqr_outlier_mask(np.var(data, axis=0), w))
            # np.unique also removes channels listed twice, keeping ch_nr* correct
            masks.append(np.unique(np.hstack((outliers, manual)).astype(int)))
        array1_msk, array2_msk = masks
        print(f'excluded channels, array1: {array1_msk}')
        print(f'excluded channels, array2: {array2_msk}')
        self.data01 = np.delete(self.data01, array1_msk, axis=1)
        self.data02 = np.delete(self.data02, array2_msk, axis=1)
        self.array1_msk = array1_msk
        self.array2_msk = array2_msk
        self.ch_nr1 = len(lfp.array1) - len(array1_msk)
        self.ch_nr2 = len(lfp.array2) - len(array2_msk)
        return None

    @staticmethod
    def remove_trials(psth_yes, psth_no):
        '''Drop outlier trials from two psth arrays (trials, samples, channels).

        A trial is an outlier if on any channel its variance over time lies
        outside [Q1 - 10 * IQR, Q3 + 10 * IQR] of that channel's trial variances.
        Returns (psth_yes, psth_no) without those trials.'''
        w = 10
        yes_msk = np.unique(_iqr_outlier_mask(np.var(psth_yes, axis=1), w).nonzero()[0])
        no_msk = np.unique(_iqr_outlier_mask(np.var(psth_no, axis=1), w).nonzero()[0])
        print(f'excluded yes-trials: {yes_msk}')
        print(f'excluded no-trials: {no_msk}')
        return np.delete(psth_yes, yes_msk, axis=0), np.delete(psth_no, no_msk, axis=0)

    # ------------------------------------------------------------ time domain
    def plot_channels(self, ch_ids, arr_id=1, i_start=0, i_stop=10000, step=5, color='C0', save_fig=False):
        '''Plot channels of array arr_id stacked with vertical offsets.

        ch_ids: column indices (empty / None: all). Samples i_start..i_stop
        (-1 or past the end: until the end). Data are z-scored for display
        unless params.lfp.zscore already did so. Indices in array*_msk are
        overlaid in red.'''
        i_start = int(i_start)
        i_stop = int(i_stop)
        fs = self.params.lfp.fs
        if arr_id == 1:
            data, array_msk = self.data01, self.array1_msk
        elif arr_id == 2:
            data, array_msk = self.data02, self.array2_msk
        else:
            raise ValueError(f'arr_id must be 1 or 2, got {arr_id!r}')
        if not self.params.lfp.zscore:  # z-score only for visualization
            data = stats.zscore(data)
        ch_ids = _channel_list(ch_ids, data.shape[1])
        if i_stop == -1 or i_stop > data.shape[0]:
            i_stop = data.shape[0]
        # cumulative offsets (4 x channel variance) stack the traces
        offset = np.cumsum(4 * np.var(data[:, ch_ids], axis=0))
        time_sec = np.arange(i_start, i_stop) / fs
        fh = plt.figure(figsize=(19, 10))
        plt.plot(time_sec, data[i_start:i_stop, ch_ids] + offset, color=color, lw=1, alpha=0.8)
        array_msk = np.asarray(array_msk, dtype=int)
        if array_msk.size:  # highlight excluded channels
            # NOTE(review): after remove_channels() the excluded columns are
            # already deleted from data, so these indices now address remaining
            # channels -- confirm what should be highlighted.
            plt.plot(time_sec, data[i_start:i_stop, array_msk] + offset[array_msk], color='C3', lw=1)
        plt.xlabel('Time (sec)')
        plt.yticks(offset[::step], range(0, len(ch_ids), step))
        plt.ylim(0, offset[-1] + 4)
        plt.title(f'raw data from array {arr_id}, n_ch={data.shape[1]}')
        plt.tight_layout()
        if save_fig:
            self._save_figure(fh, f'{self._fig_stem()}_channels_time.png')
        return None

    # ---------------------------------------------------------------- spectra
    def plot_spectra_sub_bands(self, ch_ids=(0,), xmax=150):
        '''Plot raw and sub-band spectra of both arrays side by side.

        Needs calc_spectra(arr_nr=1) and calc_spectra(arr_nr=2). Column 1 is
        array 1, column 2 array 2; rows: raw, sub-bands 0-2, raw log-log. The
        first column index in ch_ids is drawn in C0, the channel average in C3.
        (Working port of a self-less plot_spectra that was shadowed by the
        plot_spectra method below and therefore unreachable.)'''
        mpl.rcParams['axes.formatter.limits'] = (-2, 2)
        ch = ch_ids[0]
        plt.figure(1, figsize=(10, 10))
        plt.clf()
        columns = ((self.ff01, self.Y01, self.ff1, self.Y1, 'array1'),
                   (self.ff02, self.Y02, self.ff2, self.Y2, 'array2'))
        for col, (ff_raw, Y_raw, ff_sb, Y_sb, name) in enumerate(columns):
            plt.subplot(5, 2, col + 1)
            plt.plot(ff_raw, Y_raw[:, ch], 'C0', label=f'channel: {ch}')
            plt.plot(ff_raw, Y_raw.mean(1), 'C3', alpha=0.2, label='average')
            plt.xlim(left=0, right=xmax)
            plt.gca().set_xticklabels([])
            plt.title(f'{name} raw')
            plt.legend()
            for band in range(3):
                plt.subplot(5, 2, 2 * (band + 1) + col + 1)
                plt.plot(ff_sb, Y_sb[:, ch, band], 'C0')
                plt.plot(ff_sb, Y_sb[:, :, band].mean(1), 'C3', alpha=0.2)
                plt.xlim(left=0, right=xmax)
                if band < 2:
                    plt.gca().set_xticklabels([])
                plt.title(f'sub_band {band}')
            plt.subplot(5, 2, 9 + col)
            plt.loglog(ff_raw, Y_raw[:, ch], 'C0')
            plt.loglog(ff_raw, Y_raw.mean(1), 'C3', alpha=0.2)
            plt.title(f'{name} raw loglog')
            plt.xlabel('Frequency (Hz)')
        plt.tight_layout()
        plt.draw()
        plt.show()
        return None

    def calc_spectra(self, arr_nr=1, axis=0, nperseg=10000, force=False, detrend='constant', onesided=True):
        '''Welch spectra (scaling='spectrum', along the sample axis) of the raw
        data and of the sub-bands of array arr_nr.

        Sets Y01/ff01 + Y1/ff1 (array 1) or Y02/ff02 + Y2/ff2 (array 2).
        Skipped if Y1 / Y2 already exists unless force=True. Two-sided output
        (onesided=False) is fftshift-ed along `axis`; one-sided output is
        already in ascending frequency order and is left as is.'''
        if arr_nr not in (1, 2):
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if hasattr(self, f'Y{arr_nr}') and not force:
            return None
        raw, bands = (self.data01, self.data1) if arr_nr == 1 else (self.data02, self.data2)
        welch_kw = dict(fs=self.params.lfp.fs, axis=0, nperseg=nperseg, detrend=detrend,
                        return_onesided=onesided, scaling='spectrum')
        ff_raw, Y_raw = signal.welch(raw, **welch_kw)
        ff_sb, Y_sb = signal.welch(bands, **welch_kw)
        if not onesided:
            # two-sided output is in FFT order (0, +f, -f); centre 0 Hz
            Y_raw = fftshift(Y_raw, axes=axis)
            ff_raw = fftshift(ff_raw)
            Y_sb = fftshift(Y_sb, axes=axis)
            ff_sb = fftshift(ff_sb)
        if arr_nr == 1:
            self.Y01, self.ff01, self.Y1, self.ff1 = Y_raw, ff_raw, Y_sb, ff_sb
        else:
            self.Y02, self.ff02, self.Y2, self.ff2 = Y_raw, ff_raw, Y_sb, ff_sb
        return None

    def plot_spectra_original(self, arr_nr=1, ch_ids=None):
        '''Plot the raw-data spectrum (from calc_spectra) of each channel in
        ch_ids (default: all) on a log y-axis, one panel per channel.'''
        mpl.rcParams['axes.formatter.limits'] = (-2, 2)
        if arr_nr == 1:
            Y, ff = self.Y01, self.ff01
        elif arr_nr == 2:
            Y, ff = self.Y02, self.ff02
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        plt.figure(figsize=(10, 10))
        plt.clf()
        ch_ids = _channel_list(ch_ids, Y.shape[1])
        n_rows = max(1, int(np.round(np.sqrt(len(ch_ids)))))
        n_cols = max(1, int(np.ceil(len(ch_ids) / n_rows)))
        for pos, ch_id in enumerate(ch_ids):
            ax = plt.subplot(n_rows, n_cols, pos + 1)
            plt.semilogy(ff, Y[:, ch_id], 'C0')
            if pos < len(ch_ids) - n_cols:  # x ticks only on the last n_cols panels
                ax.set_xticks([])
            if pos % n_cols != 0:  # y ticks only in the first column
                ax.set_yticks([])
        plt.draw()
        return None

    def plot_spectra_bands(self, arr_nr=1, band_id=0, ch_ids=None, log=False, erase=True,
                           xmin=0, xmax=30, ymin=0, ymax=10, save_fig=False):
        '''Plot sub-band `band_id` spectra (from calc_spectra) of each channel in
        ch_ids (default: all), one panel per channel; the first panel also
        shows the channel average in black. log=True uses a log y-axis;
        erase=False draws into the current figure.'''
        mpl.rcParams['axes.formatter.limits'] = (-2, 2)
        if arr_nr == 1:
            Y, ff = self.Y1, self.ff1
        elif arr_nr == 2:
            Y, ff = self.Y2, self.ff2
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if erase:
            fh = plt.figure(figsize=(19, 10))
            plt.clf()
        else:
            fh = plt.gcf()  # needed for save_fig
        ch_ids = _channel_list(ch_ids, Y.shape[1])
        p1 = int(np.round(np.sqrt(len(ch_ids))))
        fidx = (ff > xmin) & (ff < xmax)
        plot_fn = plt.semilogy if log else plt.plot
        for ii, ch_id in enumerate(ch_ids):
            ax = plt.subplot(p1, p1 + 1, ii + 1)
            plot_fn(ff[fidx], Y[fidx, ch_id, band_id], lw=1, label=f'{ch_id}')
            if ii == 0:
                plot_fn(ff[fidx], Y[fidx, :, band_id].mean(axis=1), 'k', alpha=0.7, lw=1, label='mean')
            plt.xlim(xmin, xmax)
            plt.ylim(ymin, ymax)
            if ii < len(ch_ids) - 1:  # x ticks only on the last panel
                ax.set_xticks([])
            else:
                plt.xlabel('Hz')
            if ii > 0:  # y ticks only on the first panel
                ax.set_yticks([])
            plt.legend(loc=1, prop={'size': 6})
        plt.tight_layout()
        if save_fig:
            self._save_figure(fh, f'{self._fig_stem()}_spectra.png')
        plt.draw()
        return None

    # ------------------------------------------------------------------- stft
    def calc_stft(self, arr_nr=1, nperseg=10000):
        '''Spectrogram (90 % overlap) of every channel of array arr_nr.

        Sets ff_stft, tt_stft and S1 / S2 with shape (freqs, channels, segments).'''
        if arr_nr == 1:
            data = self.data01
        elif arr_nr == 2:
            data = self.data02
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        ff, tt, S = signal.spectrogram(data, self.params.lfp.fs, axis=0, nperseg=nperseg,
                                       noverlap=int(nperseg * 0.9))
        self.ff_stft = ff
        self.tt_stft = tt
        if arr_nr == 1:
            self.S1 = S
        else:
            self.S2 = S
        return None

    def calc_stft_psth(self, arr_nr=1, nperseg=1000, baseline=(70000, 250000)):
        '''Spectrograms of every psth_mapping condition of array arr_nr plus a
        baseline for normalisation.

        The baseline is the time-averaged spectrogram of raw samples
        baseline[0]:baseline[1]. Sets ff_stft_mapping, tt_stft_mapping,
        psth_S1_mapping / psth_S2_mapping (one (trials, freqs, channels,
        segments) array per condition) and psth_S1_mean / psth_S2_mean
        (freqs, channels, segments). Raises ValueError if there are no conditions.'''
        if arr_nr == 1:
            psth, data = self.psth_mapping1, self.data01
        elif arr_nr == 2:
            psth, data = self.psth_mapping2, self.data02
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        if len(psth) == 0:
            raise ValueError('no psth conditions; run compute_psth_mapping first')
        fs = self.params.lfp.fs
        noverlap = int(nperseg * 0.9)
        _, _, S_base = signal.spectrogram(data[baseline[0]:baseline[1]], fs, axis=0,
                                          nperseg=nperseg, noverlap=noverlap)
        S_base = S_base.mean(axis=2)  # (freqs, channels): mean over time segments
        S_tot = []
        for cond in psth:
            ff, tt, S = signal.spectrogram(cond, fs, axis=1, nperseg=nperseg, noverlap=noverlap)
            S_tot.append(S)
        self.ff_stft_mapping = ff
        self.tt_stft_mapping = tt
        # repeat the baseline along time to match the psth spectrogram layout
        S_mean = np.repeat(S_base[:, :, np.newaxis], tt.size, axis=2)
        if arr_nr == 1:
            self.psth_S1_mapping = S_tot
            self.psth_S1_mean = S_mean
        else:
            self.psth_S2_mapping = S_tot
            self.psth_S2_mean = S_mean
        return None

    def plot_stft(self, arr_nr=1, fmin=0, fmax=200, save_fig=True):
        '''Plot log10 of the channel-averaged spectrogram from calc_stft for
        frequencies strictly between fmin and fmax.'''
        if arr_nr == 1:
            S = self.S1
        elif arr_nr == 2:
            S = self.S2
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        fh = plt.figure()
        fidx = (self.ff_stft > fmin) & (self.ff_stft < fmax)
        plt.pcolormesh(self.tt_stft, self.ff_stft[fidx], np.log10(S[fidx, :, :].mean(axis=1)))
        plt.xlabel('Time (sec)')
        plt.ylabel('Hz')
        plt.title(f'stft, arr: {arr_nr}, average {S.shape[1]} channels')
        if save_fig:
            self._save_figure(fh, f'{self._fig_stem()}_arr_{arr_nr}_stft.png')
        return None

    def plot_psth(ff=[], sub_band=0, ch_id=0, trial_id=0, step=10):
        ''' plot psth for no and yes responses separately

        NOTE(review): legacy script code -- it has no `self` and reads names not
        defined at module level in this file (data01, data1, params, psth_win,
        triggers_no_ns2, triggers_yes_ns2, psth_no, psth_yes, tt, Sxx1, Sxx2,
        Sxx1_norm, Sxx2_norm), so it cannot run as a method of this class.
        Kept unchanged; plot_psth_time / plot_psth_stft are the working psth plots.'''
        yy = max(data01.std(), data1[:, ch_id, sub_band].std()) * 2
        xx1 = np.arange(data1.shape[0]) / params.lfp.fs
        xx2 = np.arange(psth_win[0], psth_win[1]) / params.lfp.fs
        f_idx = ff < 150
        plt.figure(figsize=(10, 10))
        plt.clf()
        log.info(f'Selected sub-band: {sub_band}')
        plt.subplot(411)
        plt.plot(xx1[::step], data01[::step, ch_id], label='raw ns2')  # plot raw
        plt.plot(xx1[::step], data1[::step, ch_id, sub_band], 'C1', alpha=0.5,
                 label=f'sub-band: {sub_band}')  # plot sub_band (low, middle, high)
        plt.stem(triggers_no_ns2 / params.lfp.fs, triggers_no_ns2 * 0 + yy * 0.8, markerfmt='C3o', label='no')  # red
        plt.stem(triggers_yes_ns2 / params.lfp.fs, triggers_yes_ns2 * 0 + yy * 0.8, markerfmt='C2o', label='yes')  # green
        plt.ylim(-yy, yy)
        plt.xlabel('sec')
        plt.legend(loc=1)
        plt.subplot(423)  # plot psth no-questions
        plt.plot(xx2, psth_no.mean(0)[:, ch_id], label='trial av. no')  # trial average, 1 channel
        plt.plot(xx2, psth_no.mean(0).mean(1), 'C3', alpha=0.4, label='trial-ch av. no')  # trial and channel average
        plt.xlabel('sec')
        plt.legend()
        plt.subplot(424)  # plot psth yes-questions
        plt.plot(xx2, psth_yes.mean(0)[:, ch_id], label='trial av. yes')  # trial average, 1 channel
        plt.plot(xx2, psth_yes.mean(0).mean(1), 'C3', alpha=0.4, label='trial-ch av. yes')  # trial and channel average
        plt.xlabel('sec')
        plt.legend()
        plt.subplot(425)  # spectrogram no-question, single channel and trial
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[trial_id, f_idx, ch_id, :] / Sxx1_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[trial_id, f_idx, ch_id, :])
        vmax = max(Sxx1[trial_id, f_idx, ch_id, :].max(), Sxx2[trial_id, f_idx, ch_id, :].max())
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.subplot(426)  # spectrogram yes-question, single channel and trial
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[trial_id, f_idx, ch_id, :] / Sxx2_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[trial_id, f_idx, ch_id, :])
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.subplot(427)  # spectrogram no-question, trial average
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[:, f_idx, ch_id, :].mean(0) / Sxx1_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx1[:, f_idx, ch_id, :].mean(0))
        vmax = max(Sxx1[trial_id, f_idx, ch_id, :].max(), Sxx2[trial_id, f_idx, ch_id, :].max())
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.subplot(428)  # spectrogram yes-question, trial average
        if params.lfp.normalize:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[:, f_idx, ch_id, :].mean(0) / Sxx2_norm[f_idx, :, ch_id])
        else:
            plt.pcolormesh(tt, ff[f_idx], Sxx2[:, f_idx, ch_id, :].mean(0))
        plt.clim(vmax=vmax)
        plt.xlabel('sec')
        plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
        plt.tight_layout()
        plt.show()
        return None

    def plot_psth_time(self, ch_ids=(), save_fig=False, arr_nr=None):
        '''For each channel in ch_ids, plot the trial mean (C0) and median (black)
        of every motor-mapping condition, with a vertical line at the trigger.

        arr_nr selects psth_mapping1 / psth_mapping2 (None: an externally set
        self.psth_mapping if present, else array 1).'''
        states = self.params.lfp.motor_mapping
        psth = self._psth_mapping(arr_nr)
        # epochs start at trigger + psth_win[0], so the trigger sits at -psth_win[0]
        trigger_idx = -self.params.lfp.psth_win[0]
        n_rows = max(1, int(np.ceil(len(psth) / 3)))
        for ch_id in ch_ids:
            fh = plt.figure(1)
            plt.clf()
            for ii, cond in enumerate(psth):
                plt.subplot(n_rows, 3, ii + 1)
                plt.plot(np.mean(cond[:, :, ch_id], axis=0), 'C0', alpha=0.5)
                plt.plot(np.median(cond[:, :, ch_id], axis=0), 'k', alpha=0.5)
                plt.ylim(-200, 200)
                plt.vlines(trigger_idx, -500, 500, alpha=0.7)
                plt.title(f'{states[ii]}, n={cond.shape[0]}')
                if ii // 3 < n_rows - 1:  # x ticks only on the bottom row
                    plt.xticks([])
            plt.tight_layout()
            if save_fig:
                self._save_figure(fh, f'psth/{self._fig_stem()}_psth_time_ch_{ch_id}_.png')

    def plot_psth_stft(self, arr_nr=1, ch_ids=(), fmin=0, fmax=200, save_fig=False):
        '''For each channel in ch_ids, plot the trial-median spectrogram of every
        motor-mapping condition divided by the baseline from calc_stft_psth.'''
        states = self.params.lfp.motor_mapping
        if arr_nr == 1:
            psth, S_mean = self.psth_S1_mapping, self.psth_S1_mean
        elif arr_nr == 2:
            psth, S_mean = self.psth_S2_mapping, self.psth_S2_mean
        else:
            raise ValueError(f'arr_nr must be 1 or 2, got {arr_nr!r}')
        ff = self.ff_stft_mapping
        tt = self.tt_stft_mapping
        fidx = (ff > fmin) & (ff < fmax)
        n_rows = max(1, int(np.ceil(len(psth) / 3)))
        for ch_id in ch_ids:
            fh = plt.figure(1)
            plt.clf()
            for cond, S in enumerate(psth):
                plt.subplot(n_rows, 3, cond + 1)
                plt.pcolormesh(tt, ff[fidx], np.median(S[:, fidx, ch_id, :], axis=0) / S_mean[fidx, ch_id, :])
                plt.title(f'{states[cond]}, n={S.shape[0]}')
                if cond // 3 < n_rows - 1:  # x ticks only on the bottom row
                    plt.xticks([])
            plt.tight_layout()
            if save_fig:
                self._save_figure(fh, f'psth/stft/{self._fig_stem()}_psth_stft_arr_{arr_nr}_ch_{ch_id}_.png')

    # ------------------------------------------------- self.data based plots
    def plot_pdf(self, f_id=0, ch_ids=range(32), xmax=50, ymax=1000):
        '''Histogram (unit bins, edges -1..49) of each channel in ch_ids of
        self.data[f_id, 0], one panel per channel in an 8x8 grid.

        NOTE(review): self.data is never assigned in this class -- presumably set
        by the caller; confirm.'''
        data = self.data[f_id, 0]
        plt.figure()
        plt.clf()
        for pos, ch_id in enumerate(ch_ids):
            ax = plt.subplot(8, 8, pos + 1)
            # histogram the channel named in the legend
            hist, bin_edges = np.histogram(data[:, ch_id], range(-1, 50))
            plt.bar(bin_edges[:-1], hist, width=1, label=f'{ch_id}')
            plt.xlim(0, xmax)
            plt.ylim(0, ymax)
            plt.legend(loc=1, prop={'size': 6})
            if pos < 56:
                ax.set_xticks([])
            else:
                ax.set_xlabel('sp/sec')
            if pos % 8 > 0:
                ax.set_yticks([])
        return None

    def plot_spectra(self, f_id=0, recompute=True, fs=20, v_thr=5, ymin=0, ymax=7, arr_id=0):
        '''Welch spectra (0.1-5 Hz) of 64 channels of self.data[f_id, 0], 8x8 panels.

        arr_id=0 plots channels 32-95, otherwise channels 0-31 and 96-127.
        Channels whose mean value exceeds v_thr are drawn thick in C0. The
        spectra are cached in self.ff / self.Y unless recompute=True.

        NOTE(review): self.data is never assigned in this class -- presumably set
        by the caller; confirm.'''
        data = self.data[f_id, 0]
        if arr_id == 0:
            ch_ids = list(range(32, 96))
        else:
            ch_ids = list(range(32)) + list(range(96, 128))
        if recompute or not hasattr(self, 'Y'):
            self.ff, self.Y = signal.welch(data, fs=fs, axis=0, nperseg=1000)
        fidx = (self.ff > 0.1) & (self.ff < 5)
        active = set(np.flatnonzero(data.mean(axis=0) > v_thr))
        plt.figure()
        plt.clf()
        for pos, ch_id in enumerate(ch_ids):
            ax = plt.subplot(8, 8, pos + 1)
            if ch_id in active:
                plt.plot(self.ff[fidx], self.Y[fidx, ch_id], 'C0', lw=2, label=f'{ch_id}')
            else:
                plt.plot(self.ff[fidx], self.Y[fidx, ch_id], 'k', lw=1, alpha=0.5, label=f'{ch_id}')
            plt.legend(loc=1, prop={'size': 6})
            plt.ylim(ymin, ymax)
            # tick labels only on the bottom row / left column of the grid
            if pos < 56:
                ax.set_xticks([])
            else:
                ax.set_xlabel('Hz')
            if pos % 8 > 0:
                ax.set_yticks([])
        return None