Browse Source

上传文件至 ''

925029326 2 months ago
parent
commit
5da57dd18a
1 changed files with 151 additions and 0 deletions
  1. 151 0
      artifacts_removal.py

+ 151 - 0
artifacts_removal.py

@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Oct  8 15:13:44 2023
+
+@author: chen min,I authorize this code to be publicly available on gin.g-node.org website
+"""
+
+#%% LFP artifacts removal
+
+import neo
+import numpy as np
+import quantities as pq
+from elephant.signal_processing import butter
+from elephant.spectral import welch_psd
+import matplotlib.pyplot as plt
+from pathlib import Path
+from scipy import signal
+
+
+
+
+config = {"font.family": 'Times New Roman',
+          "mathtext.fontset": 'stix',
+          "savefig.dpi": 300   }      
+plt.rcParams.update(config)  
+
+
+def load_data(filename):
+    """Load data    extraction LFP"""
+    io = neo.io.NeuroExplorerIO(filename)
+    segment = io.read_segment()
+    WB = [segment.analogsignals[i] for i in range(1)]
+    
+    return WB
+
+
+def artifacts_removal(WB_one):
+    
+    # Find the index of the value that exceeds the threshold 1.0
+    indexes_ = np.where(WB_one>1.5)[0] 
+    
+    index_diff = np.diff(indexes_)  # 
+    index1 = np.where(index_diff>60)[0]
+    indexes = indexes_[index1]
+
+    indexes_max_val = []
+    for i in indexes:
+        index_max_val_ = np.where(WB_one[i-5:i+1]==np.max(WB_one[i-5:i+1]))[0][0]
+        index_max_val = i - 5 + index_max_val_
+        indexes_max_val.append(index_max_val)
+    # The index of the maximum value in each artifact
+    indexes_max_val = np.array(indexes_max_val)
+    
+    #linear interpolation
+    for j in indexes_max_val:
+        WB_one[j-40:j+15] = np.linspace(WB_one[j-40],WB_one[j+14],55)  
+    
+    return WB_one
+
+
+def filtering(filenames):
+    """Filtering"""
+    filtered_LFPs = []
+    for filename in filenames:
+        WB = load_data(filename)
+        filtered_LFP = []
+        
+        for i in range(len(WB)):
+            WB_one = artifacts_removal(WB[i])
+            # lowpass filtering 300Hz
+            WB_one = butter(WB_one, lowpass_frequency=300.0 * pq.Hz)
+            
+            LFP_one = WB_one[5::40]  # downsampling
+            LFP_one.sampling_rate = 1000* pq.Hz
+            
+            # notching filter
+            for j in range(5):  
+                LFP_one = butter(LFP_one, highpass_frequency=(50.5+50*j),
+                                  lowpass_frequency=(49.5+50*j))  
+            filtered_LFP.append(LFP_one)
+        filtered_LFPs.append(filtered_LFP)
+        
+    return filtered_LFPs
+
+
+def _welch_PSD(filtered_LFPs): 
+    """PSD of LFP"""
+    freq = []
+    psd = []
+    for filtered_LFP in filtered_LFPs:
+        freq_ = []
+        psd_ = []
+        for i in range(len(filtered_LFP)):
+            freq_one, psd_one = welch_psd(filtered_LFP[i], len_segment=5000)
+            freq_.append(freq_one)
+            psd_.append(psd_one)
+        freq.append(freq_)
+        psd.append(psd_)
+        
+    return freq, psd
+
+
+
+
+
+
+
+#%% main
+
+if __name__ == '__main__':
+
+    filenames = ['D:/Desktop/code/data/test.nex']
+    filtered_LFPs = filtering(filenames)
+    freq, psd = _welch_PSD(filtered_LFPs)  # PSD
+
+    fig = plt.figure(figsize=(14, 9.5))
+    # num: file_num   i: channel_num   mask: range of frequency
+    num = 0
+    i = 0
+    mask = (freq[0][0] >= 0.5) & (freq[0][0] <= 100)  # 0.5~100Hz
+    
+    x = freq[num][i][mask]  
+    y = 10*np.log10(psd[num][i][0][mask].magnitude)  
+    plt.plot(x, y)
+
+    plt.title(filtered_LFPs[0][i].name, fontsize = 16)
+    plt.xticks(fontsize = 12)
+    plt.yticks(fontsize = 12)
+    plt.xlabel('Frequency(Hz)', fontsize=12) 
+    plt.ylabel('PSD(dB/Hz)', fontsize=12)  
+
+    plt.show()
+
+
+
+    
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+