# -*- coding: utf-8 -*-
"""
Created on Sun Oct 8 15:13:44 2023

@author: chen min,I authorize this code to be publicly available on gin.g-node.org website
"""
#%% LFP artifacts removal
from pathlib import Path

import matplotlib.pyplot as plt
import neo
import numpy as np
import quantities as pq
from elephant.signal_processing import butter
from elephant.spectral import welch_psd
from scipy import signal

config = {
    "font.family": 'Times New Roman',
    "mathtext.fontset": 'stix',
    "savefig.dpi": 300,
}
plt.rcParams.update(config)


def load_data(filename):
    """Read a NeuroExplorer file and return its first analog signal.

    Parameters
    ----------
    filename : str
        Path handed to ``neo.io.NeuroExplorerIO``.

    Returns
    -------
    list
        A one-element list holding ``segment.analogsignals[0]``.

    Raises
    ------
    IndexError
        If the segment read from the file has no analog signals.
    """
    io = neo.io.NeuroExplorerIO(filename)
    segment = io.read_segment()
    # Only the first analog signal of the segment is used downstream.
    return [segment.analogsignals[0]]


def artifacts_removal(WB_one, threshold=1.5, min_gap=60, peak_lookback=5,
                      pre_samples=40, post_samples=15):
    """Replace supra-threshold artifacts with linear interpolation.

    Samples above ``threshold`` are grouped into clusters: a new cluster
    starts whenever two consecutive supra-threshold samples lie more than
    ``min_gap`` indices apart.  For each cluster the peak is taken as the
    first maximum within the last ``peak_lookback + 1`` samples of the
    cluster, and the span ``[peak - pre_samples, peak + post_samples)`` is
    overwritten by a straight line between its two end samples.

    The signal is modified IN PLACE and also returned.

    Parameters
    ----------
    WB_one : array-like
        Signal indexed by sample along axis 0 (``filtering`` passes a neo
        analog signal).
    threshold : float, optional
        Amplitude above which a sample counts as artifact.
    min_gap : int, optional
        Minimum index distance separating two distinct artifacts.
    peak_lookback : int, optional
        Samples before the end of a cluster searched for the peak.
    pre_samples, post_samples : int, optional
        Extent of the interpolated span before / after the peak
        (``post_samples`` is exclusive and should be at least 1).

    Returns
    -------
    The same object as ``WB_one``, after in-place interpolation.
    """
    n_samples = WB_one.shape[0]

    # Row indices of every sample that exceeds the threshold.
    above = np.where(WB_one > threshold)[0]
    if above.size == 0:
        return WB_one

    # Last supra-threshold sample of each cluster.  The gap test alone only
    # reports clusters that are FOLLOWED by another one, so the end of the
    # final cluster is appended explicitly; otherwise the last (or only)
    # artifact would never be removed.
    cluster_ends = above[np.where(np.diff(above) > min_gap)[0]]
    cluster_ends = np.append(cluster_ends, above[-1])

    # Locate the peak of each artifact before modifying any data.
    peaks = []
    for end in cluster_ends:
        end = int(end)
        # Clamp to the start of the signal: a negative slice start would
        # select an empty or wrapped-around window.
        lo = max(end - peak_lookback, 0)
        window = WB_one[lo:end + 1]
        first_max = np.where(window == np.max(window))[0][0]
        peaks.append(lo + int(first_max))

    # Linear interpolation across each artifact, clamped to the signal
    # bounds so artifacts near either edge do not index out of range.
    for peak in peaks:
        start = max(peak - pre_samples, 0)
        stop = min(peak + post_samples, n_samples)
        WB_one[start:stop] = np.linspace(
            WB_one[start], WB_one[stop - 1], stop - start
        )
    return WB_one


def filtering(filenames):
    """Load each file, remove artifacts and filter down to the LFP band.

    For every signal returned by ``load_data`` the pipeline is: artifact
    removal, 300 Hz low-pass, decimation by taking every 40th sample, and
    five narrow band-stop filters centred on 50, 100, 150, 200 and 250 Hz.

    Parameters
    ----------
    filenames : iterable of str
        Paths of the recordings to process.

    Returns
    -------
    list of list
        ``result[file_index][signal_index]`` is the filtered LFP.
    """
    filtered_LFPs = []
    for filename in filenames:
        filtered_LFP = []
        for wide_band in load_data(filename):
            cleaned = artifacts_removal(wide_band)
            # lowpass filtering 300Hz
            low_passed = butter(cleaned, lowpass_frequency=300.0 * pq.Hz)
            # Downsampling: keep every 40th sample, starting at sample 5.
            # NOTE(review): presumably the wide-band signal is sampled at
            # 40 kHz, so a stride of 40 yields the 1 kHz set below - confirm.
            LFP_one = low_passed[5::40]
            LFP_one.sampling_rate = 1000 * pq.Hz
            # Notch filtering: passing highpass_frequency > lowpass_frequency
            # to elephant's butter() requests a band-stop filter.
            for harmonic in range(1, 6):
                center = 50.0 * harmonic
                LFP_one = butter(
                    LFP_one,
                    highpass_frequency=center + 0.5,
                    lowpass_frequency=center - 0.5,
                )
            filtered_LFP.append(LFP_one)
        filtered_LFPs.append(filtered_LFP)
    return filtered_LFPs


def _welch_PSD(filtered_LFPs):
    """Compute the Welch PSD of every filtered LFP.

    Parameters
    ----------
    filtered_LFPs : list of list
        Nested list as returned by ``filtering``.

    Returns
    -------
    freq, psd : list of list
        Nested lists with the same layout as the input, holding the two
        values returned by ``welch_psd(signal, len_segment=5000)``.
    """
    freq = []
    psd = []
    for filtered_LFP in filtered_LFPs:
        freq_ = []
        psd_ = []
        for lfp in filtered_LFP:
            freq_one, psd_one = welch_psd(lfp, len_segment=5000)
            freq_.append(freq_one)
            psd_.append(psd_one)
        freq.append(freq_)
        psd.append(psd_)
    return freq, psd


#%% main
if __name__ == '__main__':
    filenames = ['D:/Desktop/code/data/test.nex']
    filtered_LFPs = filtering(filenames)
    freq, psd = _welch_PSD(filtered_LFPs)

    # PSD
    plt.figure(figsize=(14, 9.5))
    # num: file_num  i: channel_num  mask: range of frequency
    num = 0
    i = 0
    # Build the mask from the same file/channel that is plotted, so that
    # changing num or i keeps the mask and the data consistent.
    mask = (freq[num][i] >= 0.5) & (freq[num][i] <= 100)  # 0.5~100Hz
    x = freq[num][i][mask]
    # Index 0 takes the first row of the PSD array
    # (presumably the first channel - confirm).
    y = 10 * np.log10(psd[num][i][0][mask].magnitude)
    plt.plot(x, y)
    plt.title(filtered_LFPs[num][i].name, fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Frequency(Hz)', fontsize=12)
    plt.ylabel('PSD(dB/Hz)', fontsize=12)
    plt.show()