artifacts_removal.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun Oct 8 15:13:44 2023
  4. @author: chen min,I authorize this code to be publicly available on gin.g-node.org website
  5. """
  6. #%% LFP artifacts removal
  7. import neo
  8. import numpy as np
  9. import quantities as pq
  10. from elephant.signal_processing import butter
  11. from elephant.spectral import welch_psd
  12. import matplotlib.pyplot as plt
  13. from pathlib import Path
  14. from scipy import signal
  15. config = {"font.family": 'Times New Roman',
  16. "mathtext.fontset": 'stix',
  17. "savefig.dpi": 300 }
  18. plt.rcParams.update(config)
  19. def load_data(filename):
  20. """Load data extraction LFP"""
  21. io = neo.io.NeuroExplorerIO(filename)
  22. segment = io.read_segment()
  23. WB = [segment.analogsignals[i] for i in range(1)]
  24. return WB
  25. def artifacts_removal(WB_one):
  26. # Find the index of the value that exceeds the threshold 1.0
  27. indexes_ = np.where(WB_one>1.5)[0]
  28. index_diff = np.diff(indexes_) #
  29. index1 = np.where(index_diff>60)[0]
  30. indexes = indexes_[index1]
  31. indexes_max_val = []
  32. for i in indexes:
  33. index_max_val_ = np.where(WB_one[i-5:i+1]==np.max(WB_one[i-5:i+1]))[0][0]
  34. index_max_val = i - 5 + index_max_val_
  35. indexes_max_val.append(index_max_val)
  36. # The index of the maximum value in each artifact
  37. indexes_max_val = np.array(indexes_max_val)
  38. #linear interpolation
  39. for j in indexes_max_val:
  40. WB_one[j-40:j+15] = np.linspace(WB_one[j-40],WB_one[j+14],55)
  41. return WB_one
  42. def filtering(filenames):
  43. """Filtering"""
  44. filtered_LFPs = []
  45. for filename in filenames:
  46. WB = load_data(filename)
  47. filtered_LFP = []
  48. for i in range(len(WB)):
  49. WB_one = artifacts_removal(WB[i])
  50. # lowpass filtering 300Hz
  51. WB_one = butter(WB_one, lowpass_frequency=300.0 * pq.Hz)
  52. LFP_one = WB_one[5::40] # downsampling
  53. LFP_one.sampling_rate = 1000* pq.Hz
  54. # notching filter
  55. for j in range(5):
  56. LFP_one = butter(LFP_one, highpass_frequency=(50.5+50*j),
  57. lowpass_frequency=(49.5+50*j))
  58. filtered_LFP.append(LFP_one)
  59. filtered_LFPs.append(filtered_LFP)
  60. return filtered_LFPs
  61. def _welch_PSD(filtered_LFPs):
  62. """PSD of LFP"""
  63. freq = []
  64. psd = []
  65. for filtered_LFP in filtered_LFPs:
  66. freq_ = []
  67. psd_ = []
  68. for i in range(len(filtered_LFP)):
  69. freq_one, psd_one = welch_psd(filtered_LFP[i], len_segment=5000)
  70. freq_.append(freq_one)
  71. psd_.append(psd_one)
  72. freq.append(freq_)
  73. psd.append(psd_)
  74. return freq, psd
  75. #%% main
  76. if __name__ == '__main__':
  77. filenames = ['D:/Desktop/code/data/test.nex']
  78. filtered_LFPs = filtering(filenames)
  79. freq, psd = _welch_PSD(filtered_LFPs) # PSD
  80. fig = plt.figure(figsize=(14, 9.5))
  81. # num: file_num i: channel_num mask: range of frequency
  82. num = 0
  83. i = 0
  84. mask = (freq[0][0] >= 0.5) & (freq[0][0] <= 100) # 0.5~100Hz
  85. x = freq[num][i][mask]
  86. y = 10*np.log10(psd[num][i][0][mask].magnitude)
  87. plt.plot(x, y)
  88. plt.title(filtered_LFPs[0][i].name, fontsize = 16)
  89. plt.xticks(fontsize = 12)
  90. plt.yticks(fontsize = 12)
  91. plt.xlabel('Frequency(Hz)', fontsize=12)
  92. plt.ylabel('PSD(dB/Hz)', fontsize=12)
  93. plt.show()