123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512 |
- import pandas as pd
- import numpy as np
- from scipy import signal
- from multitaper import mtspec
- from scipy import interpolate
- import elephant
- import neo.core
- import quantities as pq
- '''
- File: LFP_functions.py
- Author: Pau Boncompte Carré
- Date: 3/02/2024
- This file contains the functions used to align the LFP to the presentation of each
- image in the stimulus table, as well as other functions used in the analysis of the LFP/CSD.
- '''
- ##############################################################################################################
- #Functions used to align the LFP
- ##############################################################################################################
def align_lfp(lfp, trial_window, alignment_times, trial_ids=None, time_windows=None):
    '''
    Cut the same window of LFP around every alignment time and stack the cuts by trial.
    lfp: data array with LFP data for one probe insertion; it must have a 'time' coordinate
    trial_window: vector of time offsets (s) to sample around each alignment time
    alignment_times: experiment times around which the windows are cut
    trial_ids: labels for the trials (e.g. indices in the session stim table). When None,
        the trials are labelled 0..len(alignment_times)-1
    time_windows: optional dictionary {presentation_id: [t_min, t_max]}. For each listed
        presentation, samples whose offset lies outside [t_min, t_max] are replaced by NaN;
        presentations not in the dictionary are left untouched
    Returns the aligned data array, with the 'time' dimension unstacked into
    'presentation_id' and 'time_from_presentation_onset'.
    '''
    if trial_ids is None:
        trial_ids = np.arange(len(alignment_times))
    # absolute sample times, one block of len(trial_window) per alignment time, in trial order
    sample_times = np.concatenate([trial_window + onset for onset in alignment_times])
    # (trial, offset) label for every sample picked above, in the same order
    trial_time_index = pd.MultiIndex.from_product(
        (trial_ids, trial_window),
        names=('presentation_id', 'time_from_presentation_onset'),
    )
    # nearest-sample selection, then relabel 'time' with the multi-index and unstack it
    picked = lfp.sel(time=sample_times, method='nearest').to_dataset(name='aligned_lfp')
    aligned = picked.assign(time=trial_time_index).unstack('time')['aligned_lfp']
    if time_windows is None:
        return aligned
    offsets = aligned.time_from_presentation_onset
    for pres_id, window in time_windows.items():
        lower, upper = window[0], window[1]
        outside = (offsets < lower) | (offsets > upper)
        selector = dict(presentation_id=pres_id)
        # blank (NaN) every sample of this presentation that falls outside its window
        aligned.loc[selector] = aligned.loc[selector].where(~outside)
    return aligned
def align_image_lfps(stim_active, lfp, is_norm = False, fs=500):
    '''
    Align the LFP to the change presentations of every image in the stimulus table.
    For each image (except 'omitted'), the change presentations are located with
    get_stim_indexes and the LFP is cut around them with align_lfp.
    stim_active: stimulus table; rows are used positionally and must contain the columns
        'start_time', 'image_name' and 'is_change'
    lfp: data array containing LFP data for one probe insertion
    is_norm: if True, energy_normalization (without trial averaging) is applied to each result
    fs: sampling rate (Hz) used to build the common time axis of the aligned windows
    Returns a dictionary {image_name: aligned lfp data array}. Images that never appear
    as a change presentation have nothing to align to and are left out.
    '''
    presentation_times = stim_active.start_time.values
    image_names = stim_active['image_name'].unique()
    image_names = image_names[image_names != 'omitted']
    aligned_lfps = {}
    for image_name in image_names:
        _, change_ids, n_ims = get_stim_indexes(stim_active, image_name)
        if not change_ids:
            # no change presentation for this image: max(n_ims) below would raise
            continue
        change_times = presentation_times[change_ids]
        # Shared time axis: 0.25 s before the change up to the end of the longest run of
        # repeats (presentation onsets are spaced 0.75 s apart).
        trial_window = np.arange(-0.25, 0.75 * max(n_ims) - 0.25, 1 / fs)
        # Per-presentation window: samples beyond the presentation's own run are masked (NaN).
        time_window_dict = {
            change_id: [-0.25, 0.75 * n_im - 0.25]
            for change_id, n_im in zip(change_ids, n_ims)
        }
        aligned_lfp = align_lfp(lfp, trial_window, change_times, change_ids, time_window_dict)
        if is_norm:
            aligned_lfp = energy_normalization(aligned_lfp, is_mean=False)
        aligned_lfps[image_name] = aligned_lfp
    return aligned_lfps
def get_stim_indexes(stim_active, im_name):
    '''
    Locate the change presentations of an image and the repeats that follow each change.
    All indexes are positional (row numbers), not index labels of stim_active.
    stim_active: stimulus table with 'image_name' and 'is_change' columns
    im_name: name of the image to look for
    Returns (im_indexes, change_indexes, n_ims):
    im_indexes: row positions of every presentation belonging to a run that starts with a
        change to im_name (the change row plus the consecutive repeats after it)
    change_indexes: row positions where im_name is shown as a change
    n_ims: for each change, the length of its run (change + consecutive repeats)
    '''
    # Extract the columns once; per-row stim_active.iloc[i][col] builds a whole Series
    # for every access, which is very slow on long tables.
    image_names = stim_active['image_name'].to_numpy()
    is_change = stim_active['is_change'].to_numpy()
    n_rows = len(stim_active)
    im_indexes = []
    change_indexes = []
    n_ims = []
    i = 0
    while i < n_rows:
        if is_change[i] == True and image_names[i] == im_name:  # noqa: E712 (keeps original comparison semantics)
            change_indexes.append(i)
            run_start = i
            i += 1
            # consume the consecutive repeats of the same image after the change
            while i < n_rows and image_names[i] == im_name:
                i += 1
            im_indexes.extend(range(run_start, i))
            n_ims.append(i - run_start)
        # the row that ended a run is a different image, so it can safely be skipped
        i += 1
    return im_indexes, change_indexes, n_ims
- ##############################################################################################################
- #Functions for energy normalization
- ##############################################################################################################
def energy_normalization(signal, is_mean = True):
    '''
    Rescale every channel so its energy equals that of the most energetic channel.
    Each channel trace is mean-centred. Its energy is the sum of squared centred samples,
    with NaNs skipped. It is then multiplied by sqrt(max_energy / energy).
    signal: data array with a 'channel' dimension. Either 2-D (channel x samples) or
        3-D with a 'presentation_id' dimension.
    is_mean: if True and the signal is not 2-D, average over 'presentation_id' first.
    Returns a copy of the signal, with the same dims, coords and name, holding the
    normalised values.
    For 3-D input the energies and their maximum are computed separately for every
    presentation. A channel with zero energy produces NaN values.
    '''
    if is_mean and signal.ndim != 2:
        signal = signal.mean(dim='presentation_id', skipna=True)
    # dimensions that hold the samples of a single channel trace
    if signal.ndim == 2:
        sample_dims = [dim for dim in signal.dims if dim != 'channel']
    else:
        sample_dims = [dim for dim in signal.dims if dim not in ('channel', 'presentation_id')]
    centered = signal - signal.mean(dim=sample_dims, skipna=True)
    energy = (centered ** 2).sum(dim=sample_dims, skipna=True)
    # Per-presentation maximum over channels (a scalar for 2-D input). This replaces the
    # original loops, which reused the last presentation's energies for every presentation.
    gain = np.sqrt(energy.max(dim='channel', skipna=True) / energy)
    normalized = (centered * gain).transpose(*signal.dims)
    return signal.copy(data=normalized.values)
- ##############################################################################################################
- #Functions for Visual Area Selection
- ##############################################################################################################
def select_area(lfps, chans, probe_id, area, is_norm = True, is_mean = False):
    '''
    Restrict aligned LFPs to the channels of one visual area.
    lfps: dictionary {image_name: aligned LFP data array}; all entries share the same channels
    chans: channel table indexed by channel id, with 'probe_id' and 'structure_acronym' columns
    probe_id: probe whose channels are considered
    area: structure acronym of the area. The channel range runs from the lowest to the
        highest channel id whose acronym equals `area` exactly.
    is_norm: if True, apply energy_normalization to each selected LFP
    is_mean: forwarded to energy_normalization (average over presentations first)
    Returns (aligned_lfps_area, chans_lfp_area):
    aligned_lfps_area: dictionary {image_name: LFP restricted to the area's channel range}
    chans_lfp_area: rows of chans for the channels present in the restricted LFP
    Raises ValueError if no channel of the probe is labelled `area`.
    '''
    image_names = list(lfps.keys())
    lfp_sample = lfps[image_names[0]]
    on_probe = chans['probe_id'] == probe_id
    # substring match (NaN acronyms count as no match); exact match defines the range below
    chans_in_area = chans[on_probe & chans['structure_acronym'].str.contains(area, na=False)]
    exact_ids = chans_in_area.index[chans_in_area['structure_acronym'] == area]
    if len(exact_ids) == 0:
        raise ValueError(f"No channels labelled {area!r} on probe {probe_id}")
    channel_range = slice(exact_ids.min(), exact_ids.max())
    aligned_lfps_area = {}
    for image_name in image_names:
        aligned_lfp_area = lfps[image_name].sel(channel=channel_range)
        if is_norm:
            aligned_lfp_area = energy_normalization(aligned_lfp_area, is_mean=is_mean)
        aligned_lfps_area[image_name] = aligned_lfp_area
    # The LFP only holds a subset of the probe's channels (the original comment notes about
    # a fourth), so keep only the area channels actually present in the LFP.
    # IMPORTANT: chans_lfp_area (channels used in the LFP) != chans_in_area (all area channels).
    lfp_channels = lfp_sample.sel(channel=channel_range).channel.values
    chans_lfp_area = chans_in_area[chans_in_area.index.isin(lfp_channels)]
    return aligned_lfps_area, chans_lfp_area
- ##############################################################################################################
- #Functions for CSD
- ##############################################################################################################
def get_csd(lfp, ele_pos, channel_start=None, channel_end=None, fs=500, method_csd = 'KCSD1D'):
    '''
    Estimate the current source density (CSD) of a trial-averaged LFP with elephant.
    lfp: LFP data array with a 'channel' dimension, plus 'presentation_id' when it is not 2-D
    ele_pos: electrode coordinates passed to elephant as `coordinates`; presumably one entry
        per selected channel (TODO confirm against the caller)
    channel_start, channel_end: label range of channels to keep (default: first/last channel)
    fs: sampling rate in Hz
    method_csd: CSD estimation method name passed to elephant (default 'KCSD1D')
    Returns the CSD estimate produced by elephant.current_source_density.estimate_csd.
    '''
    first = lfp.channel[0] if channel_start is None else channel_start
    last = lfp.channel[-1] if channel_end is None else channel_end
    # average over presentations unless the array is already 2-D
    trial_mean = lfp if lfp.ndim == 2 else lfp.mean(dim='presentation_id', skipna=True)
    selected = trial_mean.sel(channel=slice(first, last))
    # Build the neo signal, then transpose it. This assumes the selected array is
    # channel x time, so the transpose gives time x channel; verify the dims order.
    analog = neo.core.AnalogSignal(selected, units='mV', sampling_rate=fs * pq.Hz).T
    return elephant.current_source_density.estimate_csd(
        analog, coordinates=ele_pos, method=method_csd, process_estimate=True
    )
- ##############################################################################################################
- #Functions for power spectrum
- ##############################################################################################################
def butter_highpass(lowcut, fs, order=4):
    '''
    Design a digital high-pass Butterworth filter.
    lowcut: cutoff frequency in Hz
    fs: sampling frequency in Hz
    order: filter order
    Returns (b, a), the numerator and denominator polynomial coefficients.
    '''
    # scipy expects the cutoff as a fraction of the Nyquist frequency
    normalized_cutoff = lowcut / (0.5 * fs)
    numerator, denominator = signal.butter(order, normalized_cutoff, btype='highpass', analog=False)
    return numerator, denominator
def get_spectrum(mysignal, time_start, time_end, overlap= 0, time_window=None, is_interpolated=True, fs=500, highpass_cutoff=None):
    '''
    Multitaper spectrum of the presentation- and channel-averaged signal in a time window.
    mysignal: data array (2-D channel x time, or 3-D with 'presentation_id') with a
        'time_from_presentation_onset' coordinate
    time_start, time_end: limits (s) of the analysed window
    overlap: window overlap passed to mtspec.spectrogram (olap)
    time_window: analysis window length in s (default: time_end - time_start)
    is_interpolated: if True, cubic-interpolate the spectrum onto a 4x denser frequency grid
    fs: sampling frequency in Hz
    highpass_cutoff: if given (Hz), zero-phase high-pass filter the averaged trace before the
        multitaper estimate. The default None keeps the previous behaviour: the old code
        computed a 5 Hz filter but discarded its output, so no filtering was applied.
    Returns (f, Quad): the frequency vector and the spectrum returned by mtspec.spectrogram,
    interpolated when is_interpolated is True.
    '''
    mylfp = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))
    # average across presentations (if not done yet) and channels
    if mylfp.ndim != 2:
        mylfp = mylfp.mean(dim='presentation_id', skipna=True)
    lfp = mylfp.mean(dim='channel', skipna=True).values
    if highpass_cutoff is not None:
        b, a = butter_highpass(highpass_cutoff, fs)
        lfp = signal.filtfilt(b, a, lfp)
    # One window spanning the whole selection gives a single, high-resolution spectrum.
    # The multitaper parameters were tuned by the author for these LFP signals.
    if time_window is None:
        time_window = time_end - time_start
    t, f, Quad, MT = mtspec.spectrogram(lfp, 1/fs, time_window-1/fs, olap=overlap, nw=2, kspec=3, fmin=0, fmax=100, iadapt=0)
    f = f[:, 0]
    if not is_interpolated:
        return f, Quad
    f_inter = np.linspace(np.min(f), np.max(f), len(f) * 4)
    Quad_inter = np.zeros((np.size(Quad, axis=0) * 4, np.size(Quad, axis=1)))
    for i in range(np.size(Quad, 1)):
        func1 = interpolate.interp1d(f, Quad[:, i], kind='cubic')
        Quad_inter[:, i] = func1(f_inter)
    return f_inter, Quad_inter
def get_spectrum2(mysignal, time_start, time_end, overlap= 0, time_window=None, is_interpolated=True, fs=500):
    '''
    Multitaper spectrum computed per trace and averaged afterwards.
    Averaging spectra instead of signals avoids losing power that cancels when traces are
    averaged.
    mysignal: data array with 'channel', 'presentation_id' and
        'time_from_presentation_onset' dimensions
    time_start, time_end: limits (s) of the analysed window
    overlap: window overlap passed to mtspec.spectrogram (olap)
    time_window: analysis window length in s (default: time_end - time_start)
    is_interpolated: if True, cubic-interpolate the spectrum onto a 4x denser frequency grid
    fs: sampling frequency in Hz
    Returns (f, spectrum). The spectrum is averaged over presentations within each channel,
    then over channels. Traces whose estimate fails (e.g. because of NaNs) are skipped, and
    channels with no valid trace are left out of the channel average.
    Raises ValueError if no trace yields a spectrum.
    '''
    mylfp = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))
    if time_window is None:
        time_window = time_end - time_start
    f = None
    spectrum_list = []
    for i in range(np.size(mylfp.channel)):
        channel_list = []
        for j in range(np.size(mylfp.presentation_id)):
            lfp = mylfp.isel(channel=i).isel(presentation_id=j).values
            try:
                _, f, Quad, _ = mtspec.spectrogram(lfp, 1/fs, time_window-1/fs, olap=overlap, nw=2, kspec=3, fmin=0, fmax=100, iadapt=0)
            except Exception:
                # best-effort: skip traces mtspec cannot handle (mainly NaN-masked ones)
                continue
            channel_list.append(Quad)
        if channel_list:
            spectrum_list.append(np.mean(channel_list, axis=0))
    if not spectrum_list:
        raise ValueError('No spectrum could be computed for any channel/presentation in the selected window')
    spectrum = np.mean(spectrum_list, axis=0)
    f = f[:, 0]
    if not is_interpolated:
        return f, spectrum
    f_inter = np.linspace(np.min(f), np.max(f), len(f) * 4)
    spectrum_inter = np.zeros((np.size(spectrum, axis=0) * 4, np.size(spectrum, axis=1)))
    for i in range(np.size(spectrum, 1)):
        func1 = interpolate.interp1d(f, spectrum[:, i], kind='cubic')
        spectrum_inter[:, i] = func1(f_inter)
    return f_inter, spectrum_inter
def get_spectrum_chan(mysignal, time_start, time_end, overlap= 0, time_window=None, is_interpolated=True, fs=500):
    '''
    Multitaper spectrum of a single channel, averaged over presentations.
    mysignal: data array with 'presentation_id' and 'time_from_presentation_onset' dimensions
    time_start, time_end: limits (s) of the analysed window
    overlap: window overlap passed to mtspec.spectrogram (olap)
    time_window: analysis window length in s (default: time_end - time_start)
    is_interpolated: if True, cubic-interpolate the spectrum onto a 4x denser frequency grid
    fs: sampling frequency in Hz
    Returns (f, spectrum). Presentations whose estimate fails (e.g. because of NaNs) are
    skipped. Raises ValueError if no presentation yields a spectrum.
    '''
    mylfp = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))
    if time_window is None:
        time_window = time_end - time_start
    f = None
    channel_list = []
    for j in range(np.size(mylfp.presentation_id)):
        lfp = mylfp.isel(presentation_id=j).values
        try:
            _, f, Quad, _ = mtspec.spectrogram(lfp, 1/fs, time_window-1/fs, olap=overlap, nw=2, kspec=3, fmin=0, fmax=100, iadapt=0)
        except Exception:
            # best-effort: skip presentations mtspec cannot handle (mainly NaN-masked ones)
            continue
        channel_list.append(Quad)
    if not channel_list:
        raise ValueError('No spectrum could be computed for any presentation in the selected window')
    spectrum = np.mean(channel_list, axis=0)
    f = f[:, 0]
    if not is_interpolated:
        return f, spectrum
    f_inter = np.linspace(np.min(f), np.max(f), len(f) * 4)
    spectrum_inter = np.zeros((np.size(spectrum, axis=0) * 4, np.size(spectrum, axis=1)))
    for i in range(np.size(spectrum, 1)):
        func1 = interpolate.interp1d(f, spectrum[:, i], kind='cubic')
        spectrum_inter[:, i] = func1(f_inter)
    return f_inter, spectrum_inter
def get_spectrogram(mysignal, time_start, time_end, overlap= 0, time_window=None, is_interpolated=True ,channel_start=None, channel_end=None, fs=500):
    '''
    Multitaper spectrogram (via get_spectrum) for every presentation separately.
    mysignal: 3-D data array (channel x presentation_id x time_from_presentation_onset)
    time_start, time_end: limits (s) of the analysed window
    overlap: window overlap forwarded to get_spectrum
    time_window: analysis window length in s (default: time_end - time_start - 1/fs)
    is_interpolated: currently unused; spectra are always computed without interpolation
    channel_start, channel_end: optional label range of channels to include (default: all)
    fs: sampling frequency in Hz, forwarded to get_spectrum
    Returns (frespec, lfpspecs):
    frespec: frequency vector
    lfpspecs: array (frequency x time x presentation_id). Presentations for which
        get_spectrum raises ValueError are filled with NaN.
    Raises ValueError if no presentation yields a spectrogram.
    '''
    if time_window is None:
        time_window = time_end - time_start - (1 / fs)
    if channel_start is not None or channel_end is not None:
        mysignal = mysignal.sel(channel=slice(channel_start, channel_end))
    windowed = mysignal.sel(time_from_presentation_onset=slice(time_start, time_end))
    n_pres = np.size(mysignal.presentation_id)
    frespec = None
    lfpspecs = None
    for i in range(n_pres):
        try:
            freqs, sigspec = get_spectrum(windowed.isel(presentation_id=i), time_start, time_end,
                                          time_window=time_window, overlap=overlap,
                                          is_interpolated=False, fs=fs)
        except ValueError:
            continue  # leave this presentation as NaN
        if lfpspecs is None:
            # allocate from the first successful result; every slot starts as NaN so that
            # failed presentations are never mistaken for real (zero) power
            frespec = freqs
            lfpspecs = np.full((np.size(sigspec, 0), np.size(sigspec, 1), n_pres), np.nan)
        lfpspecs[:, :, i] = sigspec
    if lfpspecs is None:
        raise ValueError('No spectrogram could be computed for any presentation')
    return frespec, lfpspecs
- ##############################################################################################################
- #Miscellaneous functions
- ##############################################################################################################
def pres_times(n_pres):
    '''
    Onset and offset (s, relative to the first stimulus) of the n-th stimulus presentation.
    Each stimulus lasts 0.25 s and onsets are 0.75 s apart (0-0.25 s, 0.75-1 s,
    1.5-1.75 s, ...). n_pres = 0 selects the 0.25 s baseline before the first stimulus
    (-0.25 to 0 s).
    n_pres: 1-based stimulus number, or 0 for the baseline
    Returns (t_0, t_f), the start and end times of the interval.
    '''
    if n_pres == 0:
        return -0.25, 0
    onset = (n_pres - 1) * 0.75
    return onset, onset + 0.25