import pandas as pd
import numpy as np
from scipy import signal
from multitaper import mtspec
from scipy import interpolate
import elephant
import neo.core
import quantities as pq

'''
File: LFP_functions.py
Author: Pau Boncompte Carré
Date: 3/02/2024

This file contains the functions used to align the LFP to the presentation of each image
in the stimulus table, as well as other functions used in the analysis of the LFP/CSD.
'''

##############################################################################################################
# Functions used to align the LFP
##############################################################################################################

def align_lfp(lfp, trial_window, alignment_times, trial_ids=None, time_windows=None):
    '''
    Aligns the LFP data array to experiment times of interest

    INPUTS:
    lfp: data array containing LFP data for one probe insertion
    trial_window: vector specifying the time points to excise around each alignment time
    alignment_times: experiment times around which to excise data
    trial_ids: indices in the session stim table specifying which stimuli to use for alignment.
               None if aligning to non-stimulus times
    time_windows: dictionary where keys are presentation_ids and values are time windows
                  within which to keep values in the aligned LFP (values outside are masked)

    OUTPUT:
    aligned data array with dimensions channels x trials x time
    '''
    # create a time vector for each trial
    time_selection = np.concatenate([trial_window + t for t in alignment_times])
    if trial_ids is None:
        trial_ids = np.arange(len(alignment_times))

    # create a multi-index for the trials and time points
    inds = pd.MultiIndex.from_product((trial_ids, trial_window),
                                      names=('presentation_id', 'time_from_presentation_onset'))

    # select the data and stack the time points into a new dimension
    ds = lfp.sel(time=time_selection, method='nearest').to_dataset(name='aligned_lfp')
    ds = ds.assign(time=inds).unstack('time')
    aligned_lfp = ds['aligned_lfp']

    # if time_windows is not None, we mask the values outside the time windows
    if time_windows is not None:
        for presentation_id, time_window in time_windows.items():
            mask = ((aligned_lfp.time_from_presentation_onset < time_window[0]) |
                    (aligned_lfp.time_from_presentation_onset > time_window[1]))
            aligned_lfp.loc[dict(presentation_id=presentation_id)] = \
                aligned_lfp.loc[dict(presentation_id=presentation_id)].where(~mask)

    return aligned_lfp
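
# The sketch below is a minimal, illustrative usage of align_lfp. It assumes `session` is an
# AllenSDK Visual Behavior Neuropixels session object with a `stimulus_presentations` table
# (columns 'active' and 'start_time') and that `lfp` is the LFP xarray for one probe; these
# names and the 0.5 s window are assumptions made for the example, not part of the pipeline above.
def _example_align_lfp(session, lfp):
    stim_table = session.stimulus_presentations
    active = stim_table[stim_table['active']]
    trial_window = np.arange(-0.25, 0.5, 1/500)      # 500 Hz samples around each onset
    aligned = align_lfp(lfp, trial_window,
                        alignment_times=active.start_time.values,
                        trial_ids=active.index.values)
    # `aligned` has dims channel x presentation_id x time_from_presentation_onset
    return aligned
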
def align_image_lfps(stim_active, lfp, is_norm=False):
    '''
    This is the main function that aligns the LFP to the presentation of each image in the
    stimulus table. Given the stimulus table, it finds the presentation times of each image
    and then calls the align_lfp function.

    INPUTS:
    stim_active: stimulus table
    lfp: data array containing LFP data for one probe insertion
    is_norm: boolean indicating whether to perform energy normalization

    OUTPUT:
    aligned_lfps: dictionary with keys being the image names and values being the aligned lfp xarrays
    '''
    # get the presentation times of each image and the image names
    presentation_times = stim_active.start_time.values
    image_names = stim_active['image_name'].unique()
    image_names = image_names[image_names != 'omitted']

    aligned_lfps = {}
    # for every image we obtain an xarray with the aligned lfp.
    # We save all of them in a dictionary (keys are the image names and values are the xarrays)
    for image_name in image_names:
        # we get the info needed for the align_lfp function
        im_ids, change_ids, n_ims = get_stim_indexes(stim_active, image_name)
        change_times = presentation_times[change_ids]

        # find the trial_window (max time window)
        trial_window = np.arange(-0.25, 0.75*max(n_ims)-0.25, 1/500)

        # create a dictionary with keys being the change ids and values being
        # the time window around each presentation
        time_window_dict = {}
        for i, change_id in enumerate(change_ids):
            time_window_dict[change_id] = [-0.25, 0.75*n_ims[i]-0.25]

        # we call the align_lfp function and save the xarray in the dictionary
        aligned_lfp = align_lfp(lfp, trial_window, change_times, change_ids, time_window_dict)

        # if is_norm is True, we perform energy normalization
        if is_norm:
            aligned_lfp = energy_normalization(aligned_lfp, is_mean=False)
        aligned_lfps[image_name] = aligned_lfp

    return aligned_lfps


def get_stim_indexes(stim_active, im_name):
    '''
    This function returns the indexes of the stimulus table where the image appears and the
    indexes where the image appears after a change.

    INPUTS:
    stim_active: stimulus table
    im_name: name of the image we want to align to

    OUTPUT:
    im_indexes: list with indexes of the stimulus table where the image appears
    change_indexes: list with indexes of the stimulus table where the image appears after a change
    n_ims: list with number of times the image appears after a change
    '''
    im_indexes = []
    change_indexes = []
    n_ims = []
    i = 0
    # we go through the stimulus table and find the indexes where the
    # image appears and the indexes where the image appears after a change
    while i < len(stim_active):
        if stim_active.iloc[i]['is_change'] and stim_active.iloc[i]['image_name'] == im_name:
            change_indexes.append(i)
            im_indexes.append(i)
            n_im = 1
            i += 1
            # we count how many times the image appears after a change
            while i < len(stim_active) and stim_active.iloc[i]['image_name'] == im_name:
                im_indexes.append(i)
                n_im += 1
                i += 1
            n_ims.append(n_im)
        i += 1

    return im_indexes, change_indexes, n_ims
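
# A small self-contained sketch of get_stim_indexes on a toy stimulus table. The column layout
# ('image_name', 'is_change', 'start_time') mirrors what align_image_lfps expects; the values
# below are invented purely for illustration.
def _example_get_stim_indexes():
    toy_stim = pd.DataFrame({
        'image_name': ['im_A', 'im_A', 'im_B', 'im_B', 'im_B', 'im_A'],
        'is_change':  [True,   False,  True,   False,  False,  True],
        'start_time': [0.0,    0.75,   1.5,    2.25,   3.0,    3.75],
    })
    im_ids, change_ids, n_ims = get_stim_indexes(toy_stim, 'im_B')
    # change_ids -> [2]  (row where im_B appears as a change)
    # n_ims      -> [3]  (im_B is then shown three times in a row)
    return im_ids, change_ids, n_ims
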
##############################################################################################################
# Functions for energy normalization
##############################################################################################################

def energy_normalization(signal, is_mean=True):
    '''
    This function normalizes the energy of each channel of a signal to the maximum energy of
    all channels.

    INPUTS:
    signal: data array containing the signal to be normalized
    is_mean: boolean indicating whether to average across presentations before normalizing

    OUTPUT:
    normalized_signal: data array containing the normalized signal
    '''
    # note: the parameter name shadows scipy's `signal` module inside this function
    # if is_mean is True and the data array is not 2D, we average across presentations first
    if is_mean and signal.ndim != 2:
        signal = signal.mean(dim='presentation_id', skipna=True)

    normalized_signal = signal.copy()
    energy = np.zeros(len(signal.channel))

    if signal.ndim != 2:
        # per-channel energy (the values computed on the last presentation are the ones used for scaling)
        for j, pres in enumerate(signal.presentation_id):
            for i, channel in enumerate(signal.channel):
                chan_sig = signal.sel(channel=channel).sel(presentation_id=pres)
                energy[i] = np.sum((chan_sig - np.mean(chan_sig))**2)
        for j, pres in enumerate(signal.presentation_id):
            for i, channel in enumerate(signal.channel):
                chan_sig = signal.sel(channel=channel).sel(presentation_id=pres)
                normalized_signal.loc[dict(channel=channel, presentation_id=pres)] = \
                    (chan_sig - np.mean(chan_sig)) * np.sqrt(np.max(energy)/energy[i])
    else:
        for i, channel in enumerate(signal.channel):
            chan_sig = signal.sel(channel=channel)
            energy[i] = np.sum((chan_sig - np.mean(chan_sig))**2)
        for i, channel in enumerate(signal.channel):
            chan_sig = signal.sel(channel=channel)
            normalized_signal.loc[dict(channel=channel)] = \
                (chan_sig - np.mean(chan_sig)) * np.sqrt(np.max(energy)/energy[i])

    return normalized_signal
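
# Toy demonstration of energy_normalization on a synthetic 2-channel (channel x time) array.
# The data are random and purely illustrative; after normalization both channels end up with
# approximately the same total energy.
def _example_energy_normalization():
    import xarray as xr
    rng = np.random.default_rng(0)
    toy = xr.DataArray(rng.normal(size=(2, 100)) * np.array([[1.0], [5.0]]),
                       dims=('channel', 'time_from_presentation_onset'),
                       coords={'channel': [0, 1]})
    normed = energy_normalization(toy, is_mean=False)
    return normed
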
##############################################################################################################
# Functions for Visual Area Selection
##############################################################################################################

def select_area(lfps, chans, probe_id, area, is_norm=True, is_mean=False):
    '''
    This function selects the channels in a visual area and returns the aligned LFPs in that area.

    INPUTS:
    lfps: dictionary with keys being the image names and values being the LFP xarrays
    chans: dataframe with the channels information
    probe_id: probe id
    area: string with the name of the area
    is_norm: boolean indicating whether to perform energy normalization
    is_mean: boolean passed to energy_normalization (average across presentations before normalizing)

    OUTPUT:
    aligned_lfps_area: dictionary with keys being the image names and values being the aligned
                       lfp xarrays in the area
    chans_lfp_area: dataframe with the channels information in the area that are used in the LFP
    '''
    # get image names and a sample of the lfp to find the channels of interest
    image_names = list(lfps.keys())
    lfp_sample = lfps[image_names[0]]

    # We observe the channels in the probe that are in the area
    chans_in_area = chans[(chans['probe_id'] == probe_id) & (chans['structure_acronym'].str.contains(area))]
    first_channel_id = chans_in_area[chans_in_area['structure_acronym'] == area].index.min()
    last_channel_id = chans_in_area[chans_in_area['structure_acronym'] == area].index.max()

    aligned_lfps_area = {}
    # for every image we obtain an xarray with the aligned lfp.
    # We save all of them in a dictionary (keys are the image names and values are the xarrays)
    for image_name in image_names:
        # Get the LFP data for the channels in the area
        lfp = lfps[image_name]
        aligned_lfp_area = lfp.sel(channel=slice(first_channel_id, last_channel_id))

        # if is_norm is True, we perform energy normalization
        if is_norm:
            aligned_lfp_area = energy_normalization(aligned_lfp_area, is_mean=is_mean)
        aligned_lfps_area[image_name] = aligned_lfp_area

    # save the channels in the area of interest that are used in the LFP
    # (a fourth of the total channels in the area!)
    # IMPORTANT: do not confuse chans_lfp_area with chans_in_area, since the latter contains the
    # channels in the area whilst the former contains the channels in the area that are used in the LFP.
    chans_lfp_area = chans_in_area[chans_in_area.index.isin(aligned_lfp_area.channel.values)]

    return aligned_lfps_area, chans_lfp_area


##############################################################################################################
# Functions for CSD
##############################################################################################################

def get_csd(lfp, ele_pos, channel_start=None, channel_end=None, fs=500, method_csd='KCSD1D'):
    '''
    This function returns the CSD of a signal.

    INPUTS:
    lfp: data array containing the signal
    ele_pos: array containing the positions of the electrodes
    channel_start: first channel to consider
    channel_end: last channel to consider
    fs: sampling frequency
    method_csd: CSD estimation method passed to elephant (default 'KCSD1D')

    OUTPUT:
    CSD: analog signal containing the CSD
    '''
    if channel_start is None:
        channel_start = lfp.channel[0]
    if channel_end is None:
        channel_end = lfp.channel[-1]

    # average across presentations (if not done yet) and select the channels
    if lfp.ndim != 2:
        lfp = lfp.mean(dim='presentation_id', skipna=True)
    lfp = lfp.sel(channel=slice(channel_start, channel_end))

    # we convert the data array to a neo analog signal
    lfp_neo = neo.core.AnalogSignal(lfp, units='mV', sampling_rate=fs*pq.Hz)
    lfp_neo = lfp_neo.T

    # we get the CSD using the KCSD method
    CSD = elephant.current_source_density.estimate_csd(lfp_neo, coordinates=ele_pos,
                                                       method=method_csd, process_estimate=True)

    return CSD


##############################################################################################################
# Functions for power spectrum
##############################################################################################################

def butter_highpass(lowcut, fs, order=4):
    '''
    This function returns the coefficients of a highpass butterworth filter.

    INPUTS:
    lowcut: cutoff frequency
    fs: sampling frequency
    order: order of the filter

    OUTPUT:
    b: numerator coefficients
    a: denominator coefficients
    '''
    nyq = 0.5 * fs          # nyquist frequency
    low = lowcut / nyq      # normalized cutoff frequency
    b, a = signal.butter(order, low, btype='highpass', analog=False)
    return b, a
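
# Minimal sketch of butter_highpass together with scipy's zero-phase filtfilt. The 5 Hz cutoff
# and 500 Hz sampling rate match the values used elsewhere in this file; the test signal is
# synthetic (slow 1 Hz drift plus a 40 Hz oscillation that should survive the filter).
def _example_butter_highpass():
    fs = 500
    t = np.arange(0, 2, 1/fs)
    x = np.sin(2*np.pi*1*t) + 0.2*np.sin(2*np.pi*40*t)
    b, a = butter_highpass(5, fs, order=4)
    x_filt = signal.filtfilt(b, a, x)
    return t, x, x_filt
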
def get_spectrum(mysignal, time_start, time_end, overlap=0, time_window=None, is_interpolated=True, fs=500):
    '''
    This function returns the spectrum of a signal in a time window using the multitaper method.
    It also interpolates the spectrum to have a higher resolution.

    INPUTS:
    mysignal: data array containing the signal (2D or 3D)
    time_start: start time of the time window
    time_end: end time of the time window
    overlap: overlap between successive windows (passed to mtspec.spectrogram)
    time_window: time window to use for the spectrum
    is_interpolated: boolean indicating whether to interpolate the spectrum
    fs: sampling frequency

    OUTPUT:
    f: frequency vector
    Quad: spectrum of the signal
    '''
    mylfp = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))

    # average across presentation IDs (if not done yet) and channels
    if mylfp.ndim != 2:
        mylfp = mylfp.mean(dim='presentation_id', skipna=True)
    mylfp = mylfp.mean(dim='channel', skipna=True).values

    # a 5 Hz high-pass filter is set up here but currently bypassed:
    # the unfiltered signal is the one passed to the multitaper estimate below
    b, a = butter_highpass(5, fs)
    lfp_filtered = signal.filtfilt(b, a, mylfp)
    lfp = mylfp

    # get the spectrum using the multitaper method. We use a time window equal to the time window
    # of the signal, so we only get one spectrum with higher resolution. The parameters have been
    # chosen to be optimal for the LFP signals we work with.
    if time_window is None:
        time_window = time_end - time_start
    t, f, Quad, MT = mtspec.spectrogram(lfp, 1/fs, time_window-1/fs, olap=overlap, nw=2, kspec=3,
                                        fmin=0, fmax=100, iadapt=0)
    f = f[:, 0]

    # interpolate the spectrum
    if is_interpolated:
        f_inter = np.linspace(np.min(f), np.max(f), len(f)*4)
        Quad_inter = np.zeros((np.size(Quad, axis=0)*4, np.size(Quad, axis=1)))
        for i in range(np.size(Quad, 1)):
            func1 = interpolate.interp1d(f, Quad[:, i], kind='cubic')
            Quad_inter[:, i] = func1(f_inter)
        return f_inter, Quad_inter
    else:
        return f, Quad


def get_spectrum2(mysignal, time_start, time_end, overlap=0, time_window=None, is_interpolated=True, fs=500):
    '''
    This function returns the spectrum of a signal in a time window using the multitaper method.
    Moreover, it performs the average after the multitaper method to avoid loss of information
    due to the averaging of the signal.

    INPUTS:
    mysignal: data array containing the signal (channel x presentation_id x time)
    time_start: start time of the time window
    time_end: end time of the time window
    overlap: overlap between successive windows (passed to mtspec.spectrogram)
    time_window: time window to use for the spectrum
    is_interpolated: boolean indicating whether to interpolate the spectrum
    fs: sampling frequency

    OUTPUT:
    f: frequency vector
    spectrum: spectrum of the signal
    '''
    # select the signal in the time window
    mylfp = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))

    # get the spectrum using the multitaper method. We use a time window equal to the time window
    # of the signal, so we only get one spectrum with higher resolution. The parameters have been
    # chosen to be optimal for the LFP signals we work with.
    if time_window is None:
        time_window = time_end - time_start

    # We save the spectrum for each presentation in channel_list and then average them in
    # spectrum_list to get the final spectrum of all channels. If there is an error in the
    # spectrum calculation (mainly due to the presence of nan values), we skip the presentation.
    spectrum_list = []
    for i in range(np.size(mylfp.channel)):
        channel_list = []
        for j in range(np.size(mylfp.presentation_id)):
            try:
                lfp = mylfp.isel(channel=i).isel(presentation_id=j).values
                _, f, Quad, _ = mtspec.spectrogram(lfp, 1/fs, time_window-1/fs, olap=overlap, nw=2,
                                                   kspec=3, fmin=0, fmax=100, iadapt=0)
                channel_list.append(Quad)
            except Exception:
                continue
        spectrum_list.append(np.mean(channel_list, axis=0))
    spectrum = np.mean(spectrum_list, axis=0)
    f = f[:, 0]

    # interpolate the spectrum
    if is_interpolated:
        f_inter = np.linspace(np.min(f), np.max(f), len(f)*4)
        spectrum_inter = np.zeros((np.size(spectrum, axis=0)*4, np.size(spectrum, axis=1)))
        for i in range(np.size(spectrum, 1)):
            func1 = interpolate.interp1d(f, spectrum[:, i], kind='cubic')
            spectrum_inter[:, i] = func1(f_inter)
        return f_inter, spectrum_inter
    else:
        return f, spectrum
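
# Hedged usage sketch for get_spectrum: the evoked-window spectrum of one image's aligned LFP.
# `aligned_lfps` is assumed to be the dictionary returned by align_image_lfps; the image key and
# the 0-0.25 s window are illustrative choices. get_spectrum2 takes the same arguments when the
# per-presentation averaging should happen after the multitaper estimate instead.
def _example_get_spectrum(aligned_lfps, image_name):
    f, spec = get_spectrum(aligned_lfps[image_name], time_start=0.0, time_end=0.25,
                           is_interpolated=True)
    return f, spec
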
def get_spectrum_chan(mysignal, time_start, time_end, overlap=0, time_window=None, is_interpolated=True, fs=500):
    '''
    This function returns the spectrum of a signal in a time window using the multitaper method.
    It is intended for single channel signals.

    INPUTS:
    mysignal: data array containing the signal (presentation_id x time)
    time_start: start time of the time window
    time_end: end time of the time window
    overlap: overlap between successive windows (passed to mtspec.spectrogram)
    time_window: time window to use for the spectrum
    is_interpolated: boolean indicating whether to interpolate the spectrum
    fs: sampling frequency

    OUTPUT:
    f: frequency vector
    spectrum: spectrum of the signal
    '''
    # select the signal in the time window
    mylfp = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))

    # get the spectrum using the multitaper method. We use a time window equal to the time window
    # of the signal, so we only get one spectrum with higher resolution. The parameters have been
    # chosen to be optimal for the LFP signals we work with.
    if time_window is None:
        time_window = time_end - time_start

    # We save the spectrum for each presentation in channel_list and then average them to get the
    # final spectrum of the channel. If there is an error in the spectrum calculation (mainly due
    # to the presence of nan values), we skip the presentation.
    channel_list = []
    for j in range(np.size(mylfp.presentation_id)):
        try:
            lfp = mylfp.isel(presentation_id=j).values
            _, f, Quad, _ = mtspec.spectrogram(lfp, 1/fs, time_window-1/fs, olap=overlap, nw=2,
                                               kspec=3, fmin=0, fmax=100, iadapt=0)
            channel_list.append(Quad)
        except Exception:
            continue
    spectrum = np.mean(channel_list, axis=0)
    f = f[:, 0]

    # interpolate the spectrum
    if is_interpolated:
        f_inter = np.linspace(np.min(f), np.max(f), len(f)*4)
        spectrum_inter = np.zeros((np.size(spectrum, axis=0)*4, np.size(spectrum, axis=1)))
        for i in range(np.size(spectrum, 1)):
            func1 = interpolate.interp1d(f, spectrum[:, i], kind='cubic')
            spectrum_inter[:, i] = func1(f_inter)
        return f_inter, spectrum_inter
    else:
        return f, spectrum


def get_spectrogram(mysignal, time_start, time_end, overlap=0, time_window=None, is_interpolated=True,
                    channel_start=None, channel_end=None, fs=500):
    '''
    This function returns the spectrogram of a signal in a time window using the multitaper method.

    INPUTS:
    mysignal: 3D data array containing the signal (channel x presentation_id x time)
    time_start: start time of the time window
    time_end: end time of the time window
    overlap: overlap between successive windows (passed to get_spectrum)
    time_window: time window to use for the spectrum
    is_interpolated: boolean indicating whether to interpolate the spectrum
    channel_start: first channel to consider (currently unused)
    channel_end: last channel to consider (currently unused)
    fs: sampling frequency

    OUTPUT:
    frespec: frequency vector
    lfpspecs: spectrogram of the signal, a 3D array with dimensions (frequency x time x presentation_id)
    '''
    if time_window is None:
        time_window = time_end - time_start - (1/fs)

    # for the first presentation that can be processed, we get the spectrogram and save the frequency vector
    for i0 in range(np.size(mysignal.presentation_id)):
        try:
            lfpspec = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end)).isel(presentation_id=i0)
            frespec, sigspec = get_spectrum(lfpspec, time_start, time_end, time_window=time_window,
                                            overlap=overlap, is_interpolated=False)
            break
        except ValueError:
            continue

    # we create the 3D array with the spectrogram of the signal for every presentation_id
    # (presentations that fail, e.g. because of nan values, are filled with nan)
    lfpspecs = np.full((np.size(sigspec, 0), np.size(sigspec, 1), np.size(mysignal.presentation_id)), np.nan)
    for i in range(i0, np.size(mysignal.presentation_id)):
        lfpspec = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end)).isel(presentation_id=i)
        try:
            _, sigspec = get_spectrum(lfpspec, time_start, time_end, time_window=time_window,
                                      overlap=overlap, is_interpolated=False)
            lfpspecs[:, :, i] = sigspec
        except ValueError:
            lfpspecs[:, :, i] = np.full((np.size(sigspec, 0), np.size(sigspec, 1)), np.nan)

    return frespec, lfpspecs
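
# Hedged usage sketch for get_spectrogram: trial-wise spectrograms for one image, then an average
# over presentations that ignores the NaN-filled (skipped) trials. `aligned_lfps` is assumed to be
# the dictionary returned by align_image_lfps; the 0.25 s window length is an illustrative choice.
def _example_get_spectrogram(aligned_lfps, image_name):
    frespec, lfpspecs = get_spectrogram(aligned_lfps[image_name], time_start=-0.25, time_end=0.75,
                                        time_window=0.25)
    mean_spec = np.nanmean(lfpspecs, axis=2)   # average across presentation_id
    return frespec, mean_spec
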
##############################################################################################################
# Miscellaneous functions
##############################################################################################################

def pres_times(n_pres):
    '''
    Given that the first stimulus is from 0 to 0.25 s and between stimuli there is a 0.5 s interval
    (therefore the second stimulus is 0.75-1 s, the third is 1.5-1.75 s, etc.), this function returns
    the presentation times (onset and offset) given the number of the stimulus.
    If n_pres is 0, return the times between -0.25 and 0 s (before the first stimulus).

    INPUTS:
    n_pres: number of the stimulus

    OUTPUT:
    t_0: start time of the stimulus
    t_f: end time of the stimulus
    '''
    if n_pres == 0:
        t_0 = -0.25
        t_f = 0
    else:
        t_0 = (n_pres-1)*0.75
        t_f = t_0 + 0.25

    return t_0, t_f
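
# Quick illustration of pres_times: onset/offset windows for the pre-stimulus period and the
# first few flashes, following the 0.25 s flash / 0.5 s gap timing described in the docstring.
def _example_pres_times(n_stimuli=3):
    # n_pres = 0 gives the pre-stimulus window; 1, 2, ... give successive flashes
    return [pres_times(n) for n in range(n_stimuli + 1)]
    # -> [(-0.25, 0), (0.0, 0.25), (0.75, 1.0), (1.5, 1.75)]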