123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- # -*- coding: utf-8 -*-
- """
- Created on Sun Oct 8 15:13:44 2023
- @author: comeo
- """
- #%% LFP artifacts removal
- import neo
- import numpy as np
- import quantities as pq
- from elephant.signal_processing import butter
- from elephant.spectral import welch_psd
- import matplotlib.pyplot as plt
- from pathlib import Path
- from scipy import signal
# Global matplotlib settings applied once at import time:
# Times New Roman text, STIX math glyphs, 300 dpi for saved figures.
config = {
    "font.family": 'Times New Roman',
    "mathtext.fontset": 'stix',
    "savefig.dpi": 300,
}
plt.rcParams.update(config)
def load_data(filename):
    """Load the wide-band signal used for LFP extraction.

    Opens ``filename`` with ``neo.io.NeuroExplorerIO``, reads a single
    segment and keeps only its first analog signal.

    Parameters
    ----------
    filename
        Path handed unchanged to ``neo.io.NeuroExplorerIO``.

    Returns
    -------
    list
        A one-element list containing ``segment.analogsignals[0]``.
    """
    reader = neo.io.NeuroExplorerIO(filename)
    recording = reader.read_segment()

    # Only the first analog signal of the segment is used downstream.
    first_signal = recording.analogsignals[0]
    return [first_signal]
def artifacts_removal(WB_one, threshold=1.5, min_gap=60, peak_lookback=5,
                      pre_samples=40, post_samples=14):
    """Replace supra-threshold artifacts with a linear interpolation.

    Samples greater than ``threshold`` are grouped into runs; a run ends
    where the next supra-threshold sample is more than ``min_gap`` samples
    away. For each run the maximum is searched in the ``peak_lookback``
    samples preceding (and including) the end of the run, and the span from
    ``pre_samples`` before that maximum to ``post_samples`` after it
    (inclusive) is overwritten with a straight line between the two span
    end points.

    Parameters
    ----------
    WB_one
        Signal indexed by sample along its first axis. It is modified
        **in place**.
    threshold
        Amplitude above which a sample counts as part of an artifact.
    min_gap
        Minimum index distance separating two distinct artifacts.
    peak_lookback
        Number of samples before the end of a run searched for the peak.
    pre_samples, post_samples
        Extent of the interpolated span before / after the peak.

    Returns
    -------
    The same object as ``WB_one``, after in-place modification.
    """
    n_samples = len(WB_one)
    above = np.where(WB_one > threshold)[0]
    if above.size == 0:
        # Nothing exceeds the threshold: leave the signal untouched.
        return WB_one

    # A run ends wherever the following supra-threshold index is far away.
    # The last run has no successor, so its end is appended explicitly;
    # otherwise the final (or only) artifact would never be removed.
    run_ends = np.append(above[:-1][np.diff(above) > min_gap], above[-1])

    # Locate every peak first, on the untouched data, then interpolate.
    peaks = []
    for end in run_ends:
        # Clamp at 0 so a run near the start cannot wrap to the signal tail.
        lo = max(end - peak_lookback, 0)
        window = np.asarray(WB_one[lo:end + 1])
        # First-axis position of the first maximum in the window.
        offset = np.unravel_index(np.argmax(window), window.shape)[0]
        peaks.append(lo + offset)

    for peak in peaks:
        # Clamp the span to the valid index range at both signal edges.
        start = max(peak - pre_samples, 0)
        stop = min(peak + post_samples, n_samples - 1)
        WB_one[start:stop + 1] = np.linspace(WB_one[start], WB_one[stop],
                                             stop - start + 1)

    return WB_one
def filtering(filenames):
    """Turn each recording file into a list of filtered LFP signals.

    For every signal returned by ``load_data`` the steps are, in order:
    artifact removal, a 300 Hz low-pass, decimation, and five narrow
    filter passes around 50, 100, 150, 200 and 250 Hz.

    Parameters
    ----------
    filenames
        Iterable of file paths, each passed to ``load_data``.

    Returns
    -------
    list
        One inner list per file, holding one filtered signal per channel.
    """
    all_files = []
    for path in filenames:
        per_file = []

        for raw in load_data(path):
            cleaned = artifacts_removal(raw)
            # Low-pass at 300 Hz before decimating.
            low_passed = butter(cleaned, lowpass_frequency=300.0 * pq.Hz)

            # Keep every 40th sample starting at index 5 and label the
            # result as 1000 Hz.
            # NOTE(review): this assumes the raw rate is 40 kHz - confirm.
            lfp = low_passed[5::40]
            lfp.sampling_rate = 1000 * pq.Hz

            # Notch passes: highpass is set above lowpass, 0.5 Hz either
            # side of each multiple of 50 Hz.
            # NOTE(review): presumably elephant's butter treats this as
            # band-stop - confirm against its documentation.
            for harmonic in range(5):
                lfp = butter(lfp,
                             highpass_frequency=(50.5 + 50 * harmonic),
                             lowpass_frequency=(49.5 + 50 * harmonic))
            per_file.append(lfp)

        all_files.append(per_file)

    return all_files
- def _welch_PSD(filtered_LFPs):
- """PSD of LFP"""
- freq = []
- psd = []
- for filtered_LFP in filtered_LFPs:
- freq_ = []
- psd_ = []
- for i in range(len(filtered_LFP)):
- freq_one, psd_one = welch_psd(filtered_LFP[i], len_segment=5000)
- freq_.append(freq_one)
- psd_.append(psd_one)
- freq.append(freq_)
- psd.append(psd_)
-
- return freq, psd
- #%% main
if __name__ == '__main__':
    # Script entry point: filter one recording, compute its PSD and plot
    # the 0.5-100 Hz range of one channel in dB.
    filenames = ['D:/Desktop/code/data/test.nex']
    filtered_LFPs = filtering(filenames)
    freq, psd = _welch_PSD(filtered_LFPs)  # PSD
    fig = plt.figure(figsize=(14, 9.5))

    # num: index of the file to plot, i: index of the channel to plot.
    num = 0
    i = 0
    # The mask, the data and the title all use the same (num, i) selection,
    # so changing either selector keeps the three consistent.
    mask = (freq[num][i] >= 0.5) & (freq[num][i] <= 100)  # 0.5~100Hz

    x = freq[num][i][mask]
    # First row of the PSD array, converted to decibels.
    y = 10 * np.log10(psd[num][i][0][mask].magnitude)
    plt.plot(x, y)
    plt.title(filtered_LFPs[num][i].name, fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Frequency(Hz)', fontsize=12)
    plt.ylabel('PSD(dB/Hz)', fontsize=12)
    plt.show()
-
|