signal_processing.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # -*- coding: utf-8 -*-
  2. '''
  3. Basic processing procedures for analog signals (e.g., performing a z-score of a
  4. signal, or filtering a signal).
  5. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  6. :license: Modified BSD, see LICENSE.txt for details.
  7. '''
  8. from __future__ import division, print_function
  9. import numpy as np
  10. import scipy.signal
  11. import quantities as pq
  12. import neo
  def zscore(signal, inplace=True):
      '''
      Apply a z-score operation to one or several AnalogSignal objects.

      The z-score operation subtracts the mean :math:`\\mu` of the signal, and
      divides by its standard deviation :math:`\\sigma`:

      .. math::
          Z(x(t))= \\frac{x(t)-\\mu}{\\sigma}

      If an AnalogSignal containing multiple signals is provided, the
      z-transform is always calculated for each signal individually.

      If a list (or tuple) of AnalogSignal objects is supplied, the mean and
      standard deviation are calculated across all objects of the list. Thus,
      all list elements are z-transformed by the same values of :math:`\\mu`
      and :math:`\\sigma`. For AnalogSignals, each signal of the array is
      treated separately across list elements. Therefore, the number of
      signals must be identical for each AnalogSignal of the list.

      Parameters
      ----------
      signal : neo.AnalogSignal or list/tuple of neo.AnalogSignal
          Signals for which to calculate the z-score.
      inplace : bool
          If True, the data of the input signal(s) is overwritten with the
          z-transformed values (the input objects keep their original units
          attribute), and dimensionless versions are returned. Otherwise, the
          input is left untouched and new AnalogSignal objects are returned.
          Default: True

      Returns
      -------
      neo.AnalogSignal or list of neo.AnalogSignal
          The output format matches the input format: for each supplied
          AnalogSignal object a corresponding object is returned containing
          the z-transformed signal with the unit dimensionless. A single
          object (not a list) is returned whenever exactly one AnalogSignal
          was supplied, including a one-element list or tuple.

      Raises
      ------
      ValueError
          If an empty list or tuple is supplied.

      Use Case
      --------
      You may supply a list of AnalogSignal objects, where each object in
      the list contains the data of one trial of the experiment, and each
      signal of the AnalogSignal corresponds to the recordings from one
      specific electrode in a particular trial. In this scenario, you will
      z-transform the signal of each electrode separately, but transform all
      trials of a given electrode in the same way.

      Examples
      --------
      >>> a = neo.AnalogSignal(
      ...     np.array([1, 2, 3, 4, 5, 6]).reshape(-1,1)*mV,
      ...     t_start=0*s, sampling_rate=1000*Hz)
      >>> b = neo.AnalogSignal(
      ...     np.transpose([[1, 2, 3, 4, 5, 6], [11, 12, 13, 14, 15, 16]])*mV,
      ...     t_start=0*s, sampling_rate=1000*Hz)
      >>> c = neo.AnalogSignal(
      ...     np.transpose([[21, 22, 23, 24, 25, 26], [31, 32, 33, 34, 35, 36]])*mV,
      ...     t_start=0*s, sampling_rate=1000*Hz)
      >>> print(zscore(a))
      [[-1.46385011]
       [-0.87831007]
       [-0.29277002]
       [ 0.29277002]
       [ 0.87831007]
       [ 1.46385011]] dimensionless
      >>> print(zscore(b))
      [[-1.46385011 -1.46385011]
       [-0.87831007 -0.87831007]
       [-0.29277002 -0.29277002]
       [ 0.29277002  0.29277002]
       [ 0.87831007  0.87831007]
       [ 1.46385011  1.46385011]] dimensionless

      Because ``inplace=True`` by default, ``b`` already holds its own
      z-transformed values when the next call is made; the output below
      reflects that.

      >>> print(zscore([b, c]))
      [<AnalogSignal(array([[-1.11669108, -1.08361877],
             [-1.0672076 , -1.04878252],
             [-1.01772411, -1.01394628],
             [-0.96824063, -0.97911003],
             [-0.91875714, -0.94427378],
             [-0.86927366, -0.90943753]]) * dimensionless, [0.0 s, 0.006 s],
             sampling rate: 1000.0 Hz)>,
       <AnalogSignal(array([[ 0.78170952,  0.84779261],
             [ 0.86621866,  0.90728682],
             [ 0.9507278 ,  0.96678104],
             [ 1.03523694,  1.02627526],
             [ 1.11974608,  1.08576948],
             [ 1.20425521,  1.1452637 ]]) * dimensionless, [0.0 s, 0.006 s],
             sampling rate: 1000.0 Hz)>]
      '''
      # Normalize the input to a list so that a single signal and a
      # collection of signals are handled by the same code path.
      if isinstance(signal, (list, tuple)):
          signals = list(signal)
      else:
          signals = [signal]
      if not signals:
          raise ValueError("`signal` must contain at least one AnalogSignal")

      # Pool all objects along the time axis (axis 0), then compute the
      # statistics per channel (column). Concatenate only once.
      pooled = np.concatenate(signals)
      m = np.mean(pooled, axis=0)
      s = np.std(pooled, axis=0)

      result = []
      for sig in signals:
          z_values = (sig.magnitude - m.magnitude) / s.magnitude
          if inplace:
              # Overwrite the input's data; its units attribute is unchanged,
              # so a dimensionless view is created separately for the result.
              sig[:] = pq.Quantity(z_values, units=sig.units)
              result.append(sig / sig.units)
          else:
              result.append(
                  sig.duplicate_with_new_array(z_values) / sig.units)

      # Return a single object, or a list of objects
      return result[0] if len(result) == 1 else result
  def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
             filter_function='filtfilt', fs=1.0, axis=-1):
      """
      Butterworth filtering function for neo.AnalogSignal. Filter type is
      determined according to how values of `highpass_freq` and `lowpass_freq`
      are given (see Parameters section for details).

      Parameters
      ----------
      signal : AnalogSignal or Quantity array or NumPy ndarray
          Time series data to be filtered. When given as Quantity array or
          NumPy ndarray, the sampling frequency should be given through the
          keyword argument `fs`.
      highpass_freq, lowpass_freq : Quantity or float
          High-pass and low-pass cut-off frequencies, respectively. When given
          as float, the given value is taken as frequency in Hz. A Quantity is
          rescaled to Hz.
          Filter type is determined depending on values of these arguments:

          * highpass_freq only (lowpass_freq = None): highpass filter
          * lowpass_freq only (highpass_freq = None): lowpass filter
          * highpass_freq < lowpass_freq: bandpass filter
          * highpass_freq > lowpass_freq: bandstop filter

          Note that the checks use truthiness, so a cut-off of 0 is treated
          the same as None.
      order : int
          Order of Butterworth filter. Default is 4.
      filter_function : string
          Filtering function to be used. Either 'filtfilt'
          (`scipy.signal.filtfilt()`) or 'lfilter' (`scipy.signal.lfilter()`).
          In most applications 'filtfilt' should be used, because it doesn't
          bring about phase shift due to filtering. Default is 'filtfilt'.
      fs : Quantity or float
          The sampling frequency of the input time series. When given as
          float, its value is taken as frequency in Hz; a Quantity is rescaled
          to Hz. When the input has a `sampling_rate` attribute (e.g. a neo
          AnalogSignal), that attribute is used to specify the sampling
          frequency and this parameter is ignored. Default is 1.0.
      axis : int
          Axis along which filter is applied. For an AnalogSignal input the
          time axis (first axis) is moved to the last position before
          filtering, so the default -1 filters along time. Default is -1.

      Returns
      -------
      filtered_signal : AnalogSignal or Quantity array or NumPy ndarray
          Filtered input data. The shape and type is identical to those of the
          input.

      Raises
      ------
      ValueError
          If neither `highpass_freq` nor `lowpass_freq` is given, or if
          `filter_function` is neither 'filtfilt' nor 'lfilter'.
      """
      def _to_hz(value):
          # Strip units from a Quantity after rescaling it to Hz; plain
          # numbers and None pass through unchanged.
          if isinstance(value, pq.quantity.Quantity):
              return value.rescale(pq.Hz).magnitude
          return value

      def _design_butterworth_filter(Fs, hpfreq=None, lpfreq=None, order=4):
          # Cut-off frequencies are normalized by the Nyquist frequency, as
          # expected by scipy.signal.butter.
          Fn = Fs / 2.
          # The filter type is determined by which cut-off frequencies are
          # given and by their relative order.
          if lpfreq and hpfreq:
              if hpfreq < lpfreq:
                  Wn = (hpfreq / Fn, lpfreq / Fn)
                  btype = 'bandpass'
              else:
                  Wn = (lpfreq / Fn, hpfreq / Fn)
                  btype = 'bandstop'
          elif lpfreq:
              Wn = lpfreq / Fn
              btype = 'lowpass'
          elif hpfreq:
              Wn = hpfreq / Fn
              btype = 'highpass'
          else:
              raise ValueError(
                  "Either highpass_freq or lowpass_freq must be given"
              )
          # return filter coefficients
          return scipy.signal.butter(order, Wn, btype=btype)

      # Design the filter; every frequency is expressed as a plain number in
      # Hz, including an `fs` given as Quantity.
      if hasattr(signal, 'sampling_rate'):
          Fs = _to_hz(signal.sampling_rate)
      else:
          Fs = _to_hz(fs)
      Fh = _to_hz(highpass_freq)
      Fl = _to_hz(lowpass_freq)
      b, a = _design_butterworth_filter(Fs, Fh, Fl, order)

      # When the input is an AnalogSignal, the time axis (i.e. the first
      # axis) needs to be moved to the last position.
      data = np.asarray(signal)
      if isinstance(signal, neo.AnalogSignal):
          data = np.moveaxis(data, 0, -1)

      # Apply the filter. Compare strings by value, not identity.
      if filter_function == 'lfilter':
          filtered_data = scipy.signal.lfilter(b, a, data, axis=axis)
      elif filter_function == 'filtfilt':
          filtered_data = scipy.signal.filtfilt(b, a, data, axis=axis)
      else:
          raise ValueError(
              "filter_function must be either 'filtfilt' or 'lfilter'"
          )

      # Restore the input's container type (and time axis position)
      if isinstance(signal, neo.AnalogSignal):
          return signal.duplicate_with_new_array(
              np.moveaxis(filtered_data, -1, 0))
      elif isinstance(signal, pq.quantity.Quantity):
          return filtered_data * signal.units
      else:
          return filtered_data
  def wavelet_transform(signal, freq, nco=6.0, fs=1.0, zero_padding=True):
      """
      Compute the wavelet transform of a given signal with Morlet mother
      wavelet. The parametrization of the wavelet is based on [1].

      Parameters
      ----------
      signal : neo.AnalogSignal or array_like
          Time series data to be wavelet-transformed. When multi-dimensional
          array_like is given, the time axis must be the last dimension of
          the array_like.
      freq : float or Quantity, or list/array of them
          Center frequency of the Morlet wavelet in Hz (a Quantity is
          rescaled to Hz). Multiple center frequencies can be given as a list,
          tuple or 1-d array, in which case the function computes the wavelet
          transforms for all the given frequencies at once.
      nco : float (optional)
          Size of the mother wavelet (approximate number of oscillation cycles
          within a wavelet; related to the wavelet number w as
          w ~ 2 pi nco / 6), as defined in [1]. A larger nco value leads to a
          higher frequency resolution and a lower temporal resolution, and
          vice versa. Typically used values are in a range of 3 - 8, but one
          should be cautious when using a value smaller than ~ 6, in which
          case the admissibility of the wavelet is not ensured (cf. [2]).
          Default value is 6.0.
      fs : float or Quantity (optional)
          Sampling rate of the input data in Hz (a Quantity is rescaled to
          Hz). When `signal` is given as an AnalogSignal, the sampling
          frequency is taken from its attribute and this parameter is
          ignored. Default value is 1.0.
      zero_padding : bool (optional)
          Specifies whether the data length is extended to the least power of
          2 greater than the original length, by padding zeros to the tail,
          for speeding up the computation. In the case of True, the extended
          part is cut out from the final result before returned, so that the
          output has the same length as the input. Default is True.

      Returns
      -------
      signal_wt: complex array
          Wavelet transform of the input data. When `freq` was given as a
          list, the way how the wavelet transforms for different frequencies
          are returned depends on the input type. When the input was an
          AnalogSignal of shape (Nt, Nch), where Nt and Nch are the numbers of
          time points and channels, respectively, the returned array has a
          shape (Nt, Nch, Nf), where Nf = `len(freq)`, such that the last
          dimension indexes the frequencies. When the input was an array_like
          of shape (a, b, ..., c, Nt), the returned array has a shape
          (a, b, ..., c, Nf, Nt), such that the second last dimension indexes
          the frequencies.
          To summarize, `signal_wt.ndim` = `signal.ndim` + 1, with the
          additional dimension in the last axis (for AnalogSignal input) or
          the second last axis (for array_like input) indexing the
          frequencies. When a single frequency is given, that extra
          dimension is dropped.

      Raises
      ------
      ValueError
          If `freq` (or one of the values in `freq` when it is a list) is not
          positive or is not less than the half of `fs`, or if `nco` is not
          positive.

      References
      ----------
      1. Le van Quyen et al. J Neurosci Meth 111:83-98 (2001)
      2. Farge, Annu Rev Fluid Mech 24:395-458 (1992)
      """
      def _morlet_wavelet_ft(freq, nco, fs, n):
          # Generate the Fourier transform of Morlet wavelet as defined
          # in Le van Quyen et al. J Neurosci Meth 111:83-98 (2001).
          # It is real-valued and non-zero only for positive frequencies.
          sigma = nco / (6. * freq)
          freqs = np.fft.fftfreq(n, 1.0 / fs)
          # Builtin float/complex: the np.float/np.complex aliases were
          # removed in NumPy 1.24.
          heaviside = np.array(freqs > 0., dtype=float)
          ft_real = np.sqrt(2 * np.pi * freq) * sigma * np.exp(
              -2 * (np.pi * sigma * (freqs - freq)) ** 2) * heaviside * fs
          return ft_real.astype(complex)

      def _to_hz(value):
          # Strip units from a Quantity after rescaling it to Hz; plain
          # numbers pass through unchanged.
          if isinstance(value, pq.quantity.Quantity):
              return value.rescale('Hz').magnitude
          return value

      data = np.asarray(signal)
      # Remember the input rank before any reshaping; array_like inputs such
      # as nested lists have no `.ndim` attribute of their own.
      input_ndim = data.ndim
      is_analogsignal = isinstance(signal, neo.AnalogSignal)
      # When the input is AnalogSignal, the axis for time index (i.e. the
      # first axis) needs to be moved to the last position.
      if is_analogsignal:
          data = np.moveaxis(data, 0, -1)

      # When the input has a sampling rate attribute (AnalogSignal), it
      # overrides the `fs` argument.
      if hasattr(signal, 'sampling_rate'):
          fs = signal.sampling_rate
      fs = _to_hz(fs)

      # Convert `freq` to a 1-d float array in Hz. A Quantity (scalar or
      # array) is rescaled first, because np.asarray would silently drop its
      # units. `multi_freq` records whether the caller asked for several
      # frequencies, which decides the output shape below.
      if isinstance(freq, pq.quantity.Quantity):
          freq = freq.rescale('Hz').magnitude
      multi_freq = isinstance(freq, (list, tuple)) or np.ndim(freq) > 0
      if multi_freq:
          freqs = np.array([_to_hz(f) for f in freq], dtype=float)
      else:
          freqs = np.array([freq], dtype=float)

      # check whether the given central frequencies are less than the
      # Nyquist frequency of the signal
      if np.any(freqs >= fs / 2):
          raise ValueError("`freq` must be less than the half of " +
                           "the sampling rate `fs` = {} Hz".format(fs))
      # Zero or negative center frequencies only produce NaNs in the wavelet
      if np.any(freqs <= 0):
          raise ValueError("`freq` must be positive")
      # check if nco is positive
      if nco <= 0:
          raise ValueError("`nco` must be positive")

      n_orig = data.shape[-1]
      if zero_padding:
          # Least power of 2 strictly greater than n_orig (exact integer
          # arithmetic; equals 2 ** (floor(log2(n_orig)) + 1)).
          n = 1 << int(n_orig).bit_length()
      else:
          n = n_orig

      # generate Morlet wavelets (in the frequency domain)
      wavelet_fts = np.empty([len(freqs), n], dtype=complex)
      for i, f in enumerate(freqs):
          wavelet_fts[i] = _morlet_wavelet_ft(f, nco, fs, n)

      # Perform the wavelet transform by convolving the signal with the
      # wavelets (a product in the frequency domain). A frequency axis is
      # inserted just before the time axis so it broadcasts against
      # `wavelet_fts` of shape (Nf, n).
      if data.ndim == 1:
          data = np.expand_dims(data, 0)
      data = np.expand_dims(data, data.ndim - 1)
      data = np.fft.ifft(np.fft.fft(data, n) * wavelet_fts)
      signal_wt = data[..., 0:n_orig]

      # reshape the result array according to the input
      if is_analogsignal:
          signal_wt = np.moveaxis(signal_wt, -1, 0)
          if not multi_freq:
              signal_wt = signal_wt[..., 0]
      else:
          if input_ndim == 1:
              signal_wt = signal_wt[0]
          if not multi_freq:
              signal_wt = signal_wt[..., 0, :]
      return signal_wt
  def hilbert(signal, N='nextpow'):
      '''
      Apply a Hilbert transform to an AnalogSignal object in order to obtain
      its (complex) analytic signal.

      The time series of the instantaneous angle and amplitude can be obtained
      as the angle (np.angle) and absolute value (np.abs) of the complex
      analytic signal, respectively.

      By default, the function will zero-pad the signal to a length
      corresponding to the next higher power of 2. This will provide higher
      computational efficiency at the expense of memory. In addition, this
      circumvents a situation where for some specific choices of the length
      of the input, scipy.signal.hilbert() will not terminate.

      Parameters
      -----------
      signal : neo.AnalogSignal
          Signal(s) to transform
      N : string or int
          Defines whether the signal is zero-padded.

          * 'none': no padding
          * 'nextpow': zero-pad to the smallest power of 2 that is not
            shorter than the signal
          * int (Python or NumPy integer): directly specify the length to
            zero-pad to (indicates the number of Fourier components, see
            parameter N of scipy.signal.hilbert()).

          Default: 'nextpow'.

      Returns
      -------
      neo.AnalogSignal
          Contains the complex analytic signal(s) corresponding to the input
          signals. The unit of the analytic signal is dimensionless.

      Raises
      ------
      ValueError
          If `N` is neither an integer nor one of 'nextpow' and 'none'.

      Example
      -------
      Create a sine signal at 5 Hz with increasing amplitude and calculate the
      instantaneous phases

      >>> t = np.arange(0, 5000) * ms
      >>> f = 5. * Hz
      >>> a = neo.AnalogSignal(
      ...     np.array(
      ...         (1 + t.magnitude / t[-1].magnitude) * np.sin(
      ...             2. * np.pi * f * t.rescale(s))).reshape((-1,1))*mV,
      ...     t_start=0*s, sampling_rate=1000*Hz)
      >>> analytic_signal = hilbert(a, N='nextpow')
      >>> angles = np.angle(analytic_signal)
      >>> amplitudes = np.abs(analytic_signal)
      >>> print(angles)
      [[-1.57079633]
       [-1.51334228]
       [-1.46047675]
       ...,
       [-1.73112977]
       [-1.68211683]
       [-1.62879501]]
      >>> plt.plot(t,angles)
      '''
      # Length of input signals (time runs along the first axis)
      n_org = signal.shape[0]

      # Determine the length to which the signal is zero-padded. Booleans are
      # excluded explicitly because bool is a subclass of int.
      if isinstance(N, (int, np.integer)) and not isinstance(N, bool):
          # User defined padding
          n = int(N)
      elif N == 'nextpow':
          # To speed up calculation of the Hilbert transform, make sure we
          # change the signal to be of a length that is a power of two.
          # Failure to do so results in computations of certain signal lengths
          # to not finish (or finish in absurd time). This might be a bug in
          # scipy (0.16), e.g., the following code will not terminate for
          # this value of k:
          #
          # import numpy
          # import scipy.signal
          # k=679346
          # t = np.arange(0, k) / 1000.
          # a = (1 + t / t[-1]) * np.sin(2 * np.pi * 5 * t)
          # analytic_signal = scipy.signal.hilbert(a)
          #
          # For this reason, nextpow is the default setting for now.
          #
          # Smallest power of two >= n_org, computed with integer arithmetic:
          # exact for large lengths and, unlike log2(n_org - 1), defined for a
          # single-sample signal.
          n = 1 << (n_org - 1).bit_length()
      elif N == 'none':
          # No padding
          n = n_org
      else:
          raise ValueError("'{}' is an unknown N.".format(N))

      # Transform along time, then drop the padded tail so the output has the
      # same length as the input.
      analytic = scipy.signal.hilbert(signal.magnitude, N=n, axis=0)[:n_org]
      output = signal.duplicate_with_new_array(analytic)
      return output / output.units