Selaa lähdekoodia

删除 'artifacts_removal.py'

925029326 3 kuukautta sitten
vanhempi
commit
1cc20bac6a
1 muutettua tiedostoa jossa 0 lisäystä ja 151 poistoa
  1. 0 151
      artifacts_removal.py

+ 0 - 151
artifacts_removal.py

@@ -1,151 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Sun Oct  8 15:13:44 2023
-
-@author: comeo
-"""
-
-#%% LFP artifacts removal
-
-import neo
-import numpy as np
-import quantities as pq
-from elephant.signal_processing import butter
-from elephant.spectral import welch_psd
-import matplotlib.pyplot as plt
-from pathlib import Path
-from scipy import signal
-
-
-
-
-config = {"font.family": 'Times New Roman',
-          "mathtext.fontset": 'stix',
-          "savefig.dpi": 300   }      
-plt.rcParams.update(config)  
-
-
-def load_data(filename):
-    """Load data    extraction LFP"""
-    io = neo.io.NeuroExplorerIO(filename)
-    segment = io.read_segment()
-    WB = [segment.analogsignals[i] for i in range(1)]
-    
-    return WB
-
-
-def artifacts_removal(WB_one):
-    
-    # Find the index of the value that exceeds the threshold 1.0
-    indexes_ = np.where(WB_one>1.5)[0] 
-    
-    index_diff = np.diff(indexes_)  # 
-    index1 = np.where(index_diff>60)[0]
-    indexes = indexes_[index1]
-
-    indexes_max_val = []
-    for i in indexes:
-        index_max_val_ = np.where(WB_one[i-5:i+1]==np.max(WB_one[i-5:i+1]))[0][0]
-        index_max_val = i - 5 + index_max_val_
-        indexes_max_val.append(index_max_val)
-    # The index of the maximum value in each artifact
-    indexes_max_val = np.array(indexes_max_val)
-    
-    #linear interpolation
-    for j in indexes_max_val:
-        WB_one[j-40:j+15] = np.linspace(WB_one[j-40],WB_one[j+14],55)  
-    
-    return WB_one
-
-
-def filtering(filenames):
-    """Filtering"""
-    filtered_LFPs = []
-    for filename in filenames:
-        WB = load_data(filename)
-        filtered_LFP = []
-        
-        for i in range(len(WB)):
-            WB_one = artifacts_removal(WB[i])
-            # lowpass filtering 300Hz
-            WB_one = butter(WB_one, lowpass_frequency=300.0 * pq.Hz)
-            
-            LFP_one = WB_one[5::40]  # downsampling
-            LFP_one.sampling_rate = 1000* pq.Hz
-            
-            # notching filter
-            for j in range(5):  
-                LFP_one = butter(LFP_one, highpass_frequency=(50.5+50*j),
-                                  lowpass_frequency=(49.5+50*j))  
-            filtered_LFP.append(LFP_one)
-        filtered_LFPs.append(filtered_LFP)
-        
-    return filtered_LFPs
-
-
-def _welch_PSD(filtered_LFPs): 
-    """PSD of LFP"""
-    freq = []
-    psd = []
-    for filtered_LFP in filtered_LFPs:
-        freq_ = []
-        psd_ = []
-        for i in range(len(filtered_LFP)):
-            freq_one, psd_one = welch_psd(filtered_LFP[i], len_segment=5000)
-            freq_.append(freq_one)
-            psd_.append(psd_one)
-        freq.append(freq_)
-        psd.append(psd_)
-        
-    return freq, psd
-
-
-
-
-
-
-
-#%% main
-
-if __name__ == '__main__':
-
-    filenames = ['D:/Desktop/code/data/test.nex']
-    filtered_LFPs = filtering(filenames)
-    freq, psd = _welch_PSD(filtered_LFPs)  # PSD
-
-    fig = plt.figure(figsize=(14, 9.5))
-    # num: file_num   i: channel_num   mask: range of frequency
-    num = 0
-    i = 0
-    mask = (freq[0][0] >= 0.5) & (freq[0][0] <= 100)  # 0.5~100Hz
-    
-    x = freq[num][i][mask]  
-    y = 10*np.log10(psd[num][i][0][mask].magnitude)  
-    plt.plot(x, y)
-
-    plt.title(filtered_LFPs[0][i].name, fontsize = 16)
-    plt.xticks(fontsize = 12)
-    plt.yticks(fontsize = 12)
-    plt.xlabel('Frequency(Hz)', fontsize=12) 
-    plt.ylabel('PSD(dB/Hz)', fontsize=12)  
-
-    plt.show()
-
-
-
-    
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-