|
@@ -0,0 +1,151 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+Created on Sun Oct 8 15:13:44 2023
|
|
|
+
|
|
|
+@author: chen min,I authorize this code to be publicly available on gin.g-node.org website
|
|
|
+"""
|
|
|
+
|
|
|
+#%% LFP artifacts removal
|
|
|
+
|
|
|
+import neo
|
|
|
+import numpy as np
|
|
|
+import quantities as pq
|
|
|
+from elephant.signal_processing import butter
|
|
|
+from elephant.spectral import welch_psd
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+from pathlib import Path
|
|
|
+from scipy import signal
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+config = {"font.family": 'Times New Roman',
|
|
|
+ "mathtext.fontset": 'stix',
|
|
|
+ "savefig.dpi": 300 }
|
|
|
+plt.rcParams.update(config)
|
|
|
+
|
|
|
+
|
|
|
+def load_data(filename):
|
|
|
+ """Load data extraction LFP"""
|
|
|
+ io = neo.io.NeuroExplorerIO(filename)
|
|
|
+ segment = io.read_segment()
|
|
|
+ WB = [segment.analogsignals[i] for i in range(1)]
|
|
|
+
|
|
|
+ return WB
|
|
|
+
|
|
|
+
|
|
|
+def artifacts_removal(WB_one):
|
|
|
+
|
|
|
+ # Find the index of the value that exceeds the threshold 1.0
|
|
|
+ indexes_ = np.where(WB_one>1.5)[0]
|
|
|
+
|
|
|
+ index_diff = np.diff(indexes_) #
|
|
|
+ index1 = np.where(index_diff>60)[0]
|
|
|
+ indexes = indexes_[index1]
|
|
|
+
|
|
|
+ indexes_max_val = []
|
|
|
+ for i in indexes:
|
|
|
+ index_max_val_ = np.where(WB_one[i-5:i+1]==np.max(WB_one[i-5:i+1]))[0][0]
|
|
|
+ index_max_val = i - 5 + index_max_val_
|
|
|
+ indexes_max_val.append(index_max_val)
|
|
|
+ # The index of the maximum value in each artifact
|
|
|
+ indexes_max_val = np.array(indexes_max_val)
|
|
|
+
|
|
|
+ #linear interpolation
|
|
|
+ for j in indexes_max_val:
|
|
|
+ WB_one[j-40:j+15] = np.linspace(WB_one[j-40],WB_one[j+14],55)
|
|
|
+
|
|
|
+ return WB_one
|
|
|
+
|
|
|
+
|
|
|
+def filtering(filenames):
|
|
|
+ """Filtering"""
|
|
|
+ filtered_LFPs = []
|
|
|
+ for filename in filenames:
|
|
|
+ WB = load_data(filename)
|
|
|
+ filtered_LFP = []
|
|
|
+
|
|
|
+ for i in range(len(WB)):
|
|
|
+ WB_one = artifacts_removal(WB[i])
|
|
|
+ # lowpass filtering 300Hz
|
|
|
+ WB_one = butter(WB_one, lowpass_frequency=300.0 * pq.Hz)
|
|
|
+
|
|
|
+ LFP_one = WB_one[5::40] # downsampling
|
|
|
+ LFP_one.sampling_rate = 1000* pq.Hz
|
|
|
+
|
|
|
+ # notching filter
|
|
|
+ for j in range(5):
|
|
|
+ LFP_one = butter(LFP_one, highpass_frequency=(50.5+50*j),
|
|
|
+ lowpass_frequency=(49.5+50*j))
|
|
|
+ filtered_LFP.append(LFP_one)
|
|
|
+ filtered_LFPs.append(filtered_LFP)
|
|
|
+
|
|
|
+ return filtered_LFPs
|
|
|
+
|
|
|
+
|
|
|
+def _welch_PSD(filtered_LFPs):
|
|
|
+ """PSD of LFP"""
|
|
|
+ freq = []
|
|
|
+ psd = []
|
|
|
+ for filtered_LFP in filtered_LFPs:
|
|
|
+ freq_ = []
|
|
|
+ psd_ = []
|
|
|
+ for i in range(len(filtered_LFP)):
|
|
|
+ freq_one, psd_one = welch_psd(filtered_LFP[i], len_segment=5000)
|
|
|
+ freq_.append(freq_one)
|
|
|
+ psd_.append(psd_one)
|
|
|
+ freq.append(freq_)
|
|
|
+ psd.append(psd_)
|
|
|
+
|
|
|
+ return freq, psd
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+#%% main
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+
|
|
|
+ filenames = ['D:/Desktop/code/data/test.nex']
|
|
|
+ filtered_LFPs = filtering(filenames)
|
|
|
+ freq, psd = _welch_PSD(filtered_LFPs) # PSD
|
|
|
+
|
|
|
+ fig = plt.figure(figsize=(14, 9.5))
|
|
|
+ # num: file_num i: channel_num mask: range of frequency
|
|
|
+ num = 0
|
|
|
+ i = 0
|
|
|
+ mask = (freq[0][0] >= 0.5) & (freq[0][0] <= 100) # 0.5~100Hz
|
|
|
+
|
|
|
+ x = freq[num][i][mask]
|
|
|
+ y = 10*np.log10(psd[num][i][0][mask].magnitude)
|
|
|
+ plt.plot(x, y)
|
|
|
+
|
|
|
+ plt.title(filtered_LFPs[0][i].name, fontsize = 16)
|
|
|
+ plt.xticks(fontsize = 12)
|
|
|
+ plt.yticks(fontsize = 12)
|
|
|
+ plt.xlabel('Frequency(Hz)', fontsize=12)
|
|
|
+ plt.ylabel('PSD(dB/Hz)', fontsize=12)
|
|
|
+
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|