spectral.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # -*- coding: utf-8 -*-
  2. """
  3. Identification of spectral properties in analog signals (e.g., the power
  4. spectrum).
  5. :copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
  6. :license: Modified BSD, see LICENSE.txt for details.
  7. """
  8. import warnings
  9. import numpy as np
  10. import scipy.signal
  11. import scipy.fftpack as fftpack
  12. import scipy.signal.signaltools as signaltools
  13. from scipy.signal.windows import get_window
  14. from six import string_types
  15. import quantities as pq
  16. import neo
  17. def _welch(x, y, fs=1.0, window='hanning', nperseg=256, noverlap=None,
  18. nfft=None, detrend='constant', scaling='density', axis=-1):
  19. """
  20. A helper function to estimate cross spectral density using Welch's method.
  21. This function is a slightly modified version of `scipy.signal.welch()` with
  22. modifications based on `matplotlib.mlab._spectral_helper()`.
  23. Welch's method [1]_ computes an estimate of the cross spectral density
  24. by dividing the data into overlapping segments, computing a modified
  25. periodogram for each segment and averaging the cross-periodograms.
  26. Parameters
  27. ----------
  28. x, y : array_like
  29. Time series of measurement values
  30. fs : float, optional
  31. Sampling frequency of the `x` and `y` time series in units of Hz.
  32. Defaults to 1.0.
  33. window : str or tuple or array_like, optional
  34. Desired window to use. See `get_window` for a list of windows and
  35. required parameters. If `window` is array_like it will be used
  36. directly as the window and its length will be used for nperseg.
  37. Defaults to 'hanning'.
  38. nperseg : int, optional
  39. Length of each segment. Defaults to 256.
  40. noverlap: int, optional
  41. Number of points to overlap between segments. If None,
  42. ``noverlap = nperseg / 2``. Defaults to None.
  43. nfft : int, optional
  44. Length of the FFT used, if a zero padded FFT is desired. If None,
  45. the FFT length is `nperseg`. Defaults to None.
  46. detrend : str or function, optional
  47. Specifies how to detrend each segment. If `detrend` is a string,
  48. it is passed as the ``type`` argument to `detrend`. If it is a
  49. function, it takes a segment and returns a detrended segment.
  50. Defaults to 'constant'.
  51. scaling : { 'density', 'spectrum' }, optional
  52. Selects between computing the power spectral density ('density')
  53. where Pxx has units of V**2/Hz if x is measured in V and computing
  54. the power spectrum ('spectrum') where Pxx has units of V**2 if x is
  55. measured in V. Defaults to 'density'.
  56. axis : int, optional
  57. Axis along which the periodogram is computed; the default is over
  58. the last axis (i.e. ``axis=-1``).
  59. Returns
  60. -------
  61. f : ndarray
  62. Array of sample frequencies.
  63. Pxy : ndarray
  64. Cross spectral density or cross spectrum of x and y.
  65. Notes
  66. -----
  67. An appropriate amount of overlap will depend on the choice of window
  68. and on your requirements. For the default 'hanning' window an
  69. overlap of 50% is a reasonable trade off between accurately estimating
  70. the signal power, while not over counting any of the data. Narrower
  71. windows may require a larger overlap.
  72. If `noverlap` is 0, this method is equivalent to Bartlett's method [2]_.
  73. References
  74. ----------
  75. .. [1] P. Welch, "The use of the fast Fourier transform for the
  76. estimation of power spectra: A method based on time averaging
  77. over short, modified periodograms", IEEE Trans. Audio
  78. Electroacoust. vol. 15, pp. 70-73, 1967.
  79. .. [2] M.S. Bartlett, "Periodogram Analysis and Continuous Spectra",
  80. Biometrika, vol. 37, pp. 1-16, 1950.
  81. """
  82. # TODO: This function should be replaced by `scipy.signal.csd()`, which
  83. # will appear in SciPy 0.16.0.
  84. # The checks for if y is x are so that we can use the same function to
  85. # obtain both power spectrum and cross spectrum without doing extra
  86. # calculations.
  87. same_data = y is x
  88. # Make sure we're dealing with a numpy array. If y and x were the same
  89. # object to start with, keep them that way
  90. x = np.asarray(x)
  91. if same_data:
  92. y = x
  93. else:
  94. if x.shape != y.shape:
  95. raise ValueError("x and y must be of the same shape.")
  96. y = np.asarray(y)
  97. if x.size == 0:
  98. return np.empty(x.shape), np.empty(x.shape)
  99. if axis != -1:
  100. x = np.rollaxis(x, axis, len(x.shape))
  101. if not same_data:
  102. y = np.rollaxis(y, axis, len(y.shape))
  103. if x.shape[-1] < nperseg:
  104. warnings.warn('nperseg = %d, is greater than x.shape[%d] = %d, using '
  105. 'nperseg = x.shape[%d]'
  106. % (nperseg, axis, x.shape[axis], axis))
  107. nperseg = x.shape[-1]
  108. if isinstance(window, string_types) or type(window) is tuple:
  109. win = get_window(window, nperseg)
  110. else:
  111. win = np.asarray(window)
  112. if len(win.shape) != 1:
  113. raise ValueError('window must be 1-D')
  114. if win.shape[0] > x.shape[-1]:
  115. raise ValueError('window is longer than x.')
  116. nperseg = win.shape[0]
  117. if scaling == 'density':
  118. scale = 1.0 / (fs * (win * win).sum())
  119. elif scaling == 'spectrum':
  120. scale = 1.0 / win.sum()**2
  121. else:
  122. raise ValueError('Unknown scaling: %r' % scaling)
  123. if noverlap is None:
  124. noverlap = nperseg // 2
  125. elif noverlap >= nperseg:
  126. raise ValueError('noverlap must be less than nperseg.')
  127. if nfft is None:
  128. nfft = nperseg
  129. elif nfft < nperseg:
  130. raise ValueError('nfft must be greater than or equal to nperseg.')
  131. if not hasattr(detrend, '__call__'):
  132. detrend_func = lambda seg: signaltools.detrend(seg, type=detrend)
  133. elif axis != -1:
  134. # Wrap this function so that it receives a shape that it could
  135. # reasonably expect to receive.
  136. def detrend_func(seg):
  137. seg = np.rollaxis(seg, -1, axis)
  138. seg = detrend(seg)
  139. return np.rollaxis(seg, axis, len(seg.shape))
  140. else:
  141. detrend_func = detrend
  142. step = nperseg - noverlap
  143. indices = np.arange(0, x.shape[-1] - nperseg + 1, step)
  144. for k, ind in enumerate(indices):
  145. x_dt = detrend_func(x[..., ind:ind + nperseg])
  146. xft = fftpack.fft(x_dt * win, nfft)
  147. if same_data:
  148. yft = xft
  149. else:
  150. y_dt = detrend_func(y[..., ind:ind + nperseg])
  151. yft = fftpack.fft(y_dt * win, nfft)
  152. if k == 0:
  153. Pxy = (xft * yft.conj())
  154. else:
  155. Pxy *= k / (k + 1.0)
  156. Pxy += (xft * yft.conj()) / (k + 1.0)
  157. Pxy *= scale
  158. f = fftpack.fftfreq(nfft, 1.0 / fs)
  159. if axis != -1:
  160. Pxy = np.rollaxis(Pxy, -1, axis)
  161. return f, Pxy
  162. def welch_psd(signal, num_seg=8, len_seg=None, freq_res=None, overlap=0.5,
  163. fs=1.0, window='hanning', nfft=None, detrend='constant',
  164. return_onesided=True, scaling='density', axis=-1):
  165. """
  166. Estimates power spectrum density (PSD) of a given AnalogSignal using
  167. Welch's method, which works in the following steps:
  168. 1. cut the given data into several overlapping segments. The degree of
  169. overlap can be specified by parameter *overlap* (default is 0.5,
  170. i.e. segments are overlapped by the half of their length).
  171. The number and the length of the segments are determined according
  172. to parameter *num_seg*, *len_seg* or *freq_res*. By default, the
  173. data is cut into 8 segments.
  174. 2. apply a window function to each segment. Hanning window is used by
  175. default. This can be changed by giving a window function or an
  176. array as parameter *window* (for details, see the docstring of
  177. `scipy.signal.welch()`)
  178. 3. compute the periodogram of each segment
  179. 4. average the obtained periodograms to yield PSD estimate
  180. These steps are implemented in `scipy.signal`, and this function is a
  181. wrapper which provides a proper set of parameters to
  182. `scipy.signal.welch()`. Some parameters for scipy.signal.welch(), such as
  183. `nfft`, `detrend`, `window`, `return_onesided` and `scaling`, also works
  184. for this function.
  185. Parameters
  186. ----------
  187. signal: Neo AnalogSignal or Quantity array or Numpy ndarray
  188. Time series data, of which PSD is estimated. When a Quantity array or
  189. Numpy ndarray is given, sampling frequency should be given through the
  190. keyword argument `fs`, otherwise the default value (`fs=1.0`) is used.
  191. num_seg: int, optional
  192. Number of segments. The length of segments is adjusted so that
  193. overlapping segments cover the entire stretch of the given data. This
  194. parameter is ignored if *len_seg* or *freq_res* is given. Default is 8.
  195. len_seg: int, optional
  196. Length of segments. This parameter is ignored if *freq_res* is given.
  197. Default is None (determined from other parameters).
  198. freq_res: Quantity or float, optional
  199. Desired frequency resolution of the obtained PSD estimate in terms of
  200. the interval between adjacent frequency bins. When given as a float, it
  201. is taken as frequency in Hz. Default is None (determined from other
  202. parameters).
  203. overlap: float, optional
  204. Overlap between segments represented as a float number between 0 (no
  205. overlap) and 1 (complete overlap). Default is 0.5 (half-overlapped).
  206. fs: Quantity array or float, optional
  207. Specifies the sampling frequency of the input time series. When the
  208. input is given as an AnalogSignal, the sampling frequency is taken
  209. from its attribute and this parameter is ignored. Default is 1.0.
  210. window, nfft, detrend, return_onesided, scaling, axis: optional
  211. These arguments are directly passed on to scipy.signal.welch(). See the
  212. respective descriptions in the docstring of `scipy.signal.welch()` for
  213. usage.
  214. Returns
  215. -------
  216. freqs: Quantity array or Numpy ndarray
  217. Frequencies associated with the power estimates in `psd`. `freqs` is
  218. always a 1-dimensional array irrespective of the shape of the input
  219. data. Quantity array is returned if `signal` is AnalogSignal or
  220. Quantity array. Otherwise Numpy ndarray containing frequency in Hz is
  221. returned.
  222. psd: Quantity array or Numpy ndarray
  223. PSD estimates of the time series in `signal`. Quantity array is
  224. returned if `data` is AnalogSignal or Quantity array. Otherwise
  225. Numpy ndarray is returned.
  226. """
  227. # initialize a parameter dict (to be given to scipy.signal.welch()) with
  228. # the parameters directly passed on to scipy.signal.welch()
  229. params = {'window': window, 'nfft': nfft,
  230. 'detrend': detrend, 'return_onesided': return_onesided,
  231. 'scaling': scaling, 'axis': axis}
  232. # add the input data to params. When the input is AnalogSignal, the
  233. # data is added after rolling the axis for time index to the last
  234. data = np.asarray(signal)
  235. if isinstance(signal, neo.AnalogSignal):
  236. data = np.rollaxis(data, 0, len(data.shape))
  237. params['x'] = data
  238. # if the data is given as AnalogSignal, use its attribute to specify
  239. # the sampling frequency
  240. if hasattr(signal, 'sampling_rate'):
  241. params['fs'] = signal.sampling_rate.rescale('Hz').magnitude
  242. else:
  243. params['fs'] = fs
  244. if overlap < 0:
  245. raise ValueError("overlap must be greater than or equal to 0")
  246. elif 1 <= overlap:
  247. raise ValueError("overlap must be less then 1")
  248. # determine the length of segments (i.e. *nperseg*) according to given
  249. # parameters
  250. if freq_res is not None:
  251. if freq_res <= 0:
  252. raise ValueError("freq_res must be positive")
  253. dF = freq_res.rescale('Hz').magnitude \
  254. if isinstance(freq_res, pq.quantity.Quantity) else freq_res
  255. nperseg = int(params['fs'] / dF)
  256. if nperseg > data.shape[axis]:
  257. raise ValueError("freq_res is too high for the given data size")
  258. elif len_seg is not None:
  259. if len_seg <= 0:
  260. raise ValueError("len_seg must be a positive number")
  261. elif data.shape[axis] < len_seg:
  262. raise ValueError("len_seg must be shorter than the data length")
  263. nperseg = len_seg
  264. else:
  265. if num_seg <= 0:
  266. raise ValueError("num_seg must be a positive number")
  267. elif data.shape[axis] < num_seg:
  268. raise ValueError("num_seg must be smaller than the data length")
  269. # when only *num_seg* is given, *nperseg* is determined by solving the
  270. # following equation:
  271. # num_seg * nperseg - (num_seg-1) * overlap * nperseg = data.shape[-1]
  272. # ----------------- =============================== ^^^^^^^^^^^
  273. # summed segment lengths total overlap data length
  274. nperseg = int(data.shape[axis] / (num_seg - overlap * (num_seg - 1)))
  275. params['nperseg'] = nperseg
  276. params['noverlap'] = int(nperseg * overlap)
  277. freqs, psd = scipy.signal.welch(**params)
  278. # attach proper units to return values
  279. if isinstance(signal, pq.quantity.Quantity):
  280. if 'scaling' in params and params['scaling'] is 'spectrum':
  281. psd = psd * signal.units * signal.units
  282. else:
  283. psd = psd * signal.units * signal.units / pq.Hz
  284. freqs = freqs * pq.Hz
  285. return freqs, psd
  286. def welch_cohere(x, y, num_seg=8, len_seg=None, freq_res=None, overlap=0.5,
  287. fs=1.0, window='hanning', nfft=None, detrend='constant',
  288. scaling='density', axis=-1):
  289. """
  290. Estimates coherence between a given pair of analog signals. The estimation
  291. is performed with Welch's method: the given pair of data are cut into short
  292. segments, cross-spectra are calculated for each pair of segments, and the
  293. cross-spectra are averaged and normalized by respective auto_spectra. By
  294. default the data are cut into 8 segments with 50% overlap between
  295. neighboring segments. These numbers can be changed through respective
  296. parameters.
  297. Parameters
  298. ----------
  299. x, y: Neo AnalogSignal or Quantity array or Numpy ndarray
  300. A pair of time series data, between which coherence is computed. The
  301. shapes and the sampling frequencies of `x` and `y` must be identical.
  302. When `x` and `y` are not of AnalogSignal, sampling frequency
  303. should be specified through the keyword argument `fs`, otherwise the
  304. default value (`fs=1.0`) is used.
  305. num_seg: int, optional
  306. Number of segments. The length of segments is adjusted so that
  307. overlapping segments cover the entire stretch of the given data. This
  308. parameter is ignored if *len_seg* or *freq_res* is given. Default is 8.
  309. len_seg: int, optional
  310. Length of segments. This parameter is ignored if *freq_res* is given.
  311. Default is None (determined from other parameters).
  312. freq_res: Quantity or float, optional
  313. Desired frequency resolution of the obtained coherence estimate in
  314. terms of the interval between adjacent frequency bins. When given as a
  315. float, it is taken as frequency in Hz. Default is None (determined from
  316. other parameters).
  317. overlap: float, optional
  318. Overlap between segments represented as a float number between 0 (no
  319. overlap) and 1 (complete overlap). Default is 0.5 (half-overlapped).
  320. fs: Quantity array or float, optional
  321. Specifies the sampling frequency of the input time series. When the
  322. input time series are given as AnalogSignal, the sampling
  323. frequency is taken from their attribute and this parameter is ignored.
  324. Default is 1.0.
  325. window, nfft, detrend, scaling, axis: optional
  326. These arguments are directly passed on to a helper function
  327. `elephant.spectral._welch()`. See the respective descriptions in the
  328. docstring of `elephant.spectral._welch()` for usage.
  329. Returns
  330. -------
  331. freqs: Quantity array or Numpy ndarray
  332. Frequencies associated with the estimates of coherency and phase lag.
  333. `freqs` is always a 1-dimensional array irrespective of the shape of
  334. the input data. Quantity array is returned if `x` and `y` are of
  335. AnalogSignal or Quantity array. Otherwise Numpy ndarray containing
  336. frequency in Hz is returned.
  337. coherency: Numpy ndarray
  338. Estimate of coherency between the input time series. For each frequency
  339. coherency takes a value between 0 and 1, with 0 or 1 representing no or
  340. perfect coherence, respectively. When the input arrays `x` and `y` are
  341. multi-dimensional, `coherency` is of the same shape as the inputs and
  342. frequency is indexed along either the first or the last axis depending
  343. on the type of the input: when the input is AnalogSignal, the
  344. first axis indexes frequency, otherwise the last axis does.
  345. phase_lag: Quantity array or Numpy ndarray
  346. Estimate of phase lag in radian between the input time series. For each
  347. frequency phase lag takes a value between -PI and PI, positive values
  348. meaning phase precession of `x` ahead of `y` and vice versa. Quantity
  349. array is returned if `x` and `y` are of AnalogSignal or Quantity
  350. array. Otherwise Numpy ndarray containing phase lag in radian is
  351. returned. The axis for frequency index is determined in the same way as
  352. for `coherency`.
  353. """
  354. # initialize a parameter dict (to be given to _welch()) with
  355. # the parameters directly passed on to _welch()
  356. params = {'window': window, 'nfft': nfft,
  357. 'detrend': detrend, 'scaling': scaling, 'axis': axis}
  358. # When the input is AnalogSignal, the axis for time index is rolled to
  359. # the last
  360. xdata = np.asarray(x)
  361. ydata = np.asarray(y)
  362. if isinstance(x, neo.AnalogSignal):
  363. xdata = np.rollaxis(xdata, 0, len(xdata.shape))
  364. ydata = np.rollaxis(ydata, 0, len(ydata.shape))
  365. # if the data is given as AnalogSignal, use its attribute to specify
  366. # the sampling frequency
  367. if hasattr(x, 'sampling_rate'):
  368. params['fs'] = x.sampling_rate.rescale('Hz').magnitude
  369. else:
  370. params['fs'] = fs
  371. if overlap < 0:
  372. raise ValueError("overlap must be greater than or equal to 0")
  373. elif 1 <= overlap:
  374. raise ValueError("overlap must be less then 1")
  375. # determine the length of segments (i.e. *nperseg*) according to given
  376. # parameters
  377. if freq_res is not None:
  378. if freq_res <= 0:
  379. raise ValueError("freq_res must be positive")
  380. dF = freq_res.rescale('Hz').magnitude \
  381. if isinstance(freq_res, pq.quantity.Quantity) else freq_res
  382. nperseg = int(params['fs'] / dF)
  383. if nperseg > xdata.shape[axis]:
  384. raise ValueError("freq_res is too high for the given data size")
  385. elif len_seg is not None:
  386. if len_seg <= 0:
  387. raise ValueError("len_seg must be a positive number")
  388. elif xdata.shape[axis] < len_seg:
  389. raise ValueError("len_seg must be shorter than the data length")
  390. nperseg = len_seg
  391. else:
  392. if num_seg <= 0:
  393. raise ValueError("num_seg must be a positive number")
  394. elif xdata.shape[axis] < num_seg:
  395. raise ValueError("num_seg must be smaller than the data length")
  396. # when only *num_seg* is given, *nperseg* is determined by solving the
  397. # following equation:
  398. # num_seg * nperseg - (num_seg-1) * overlap * nperseg = data.shape[-1]
  399. # ----------------- =============================== ^^^^^^^^^^^
  400. # summed segment lengths total overlap data length
  401. nperseg = int(xdata.shape[axis] / (num_seg - overlap * (num_seg - 1)))
  402. params['nperseg'] = nperseg
  403. params['noverlap'] = int(nperseg * overlap)
  404. freqs, Pxy = _welch(xdata, ydata, **params)
  405. freqs, Pxx = _welch(xdata, xdata, **params)
  406. freqs, Pyy = _welch(ydata, ydata, **params)
  407. coherency = np.abs(Pxy)**2 / (np.abs(Pxx) * np.abs(Pyy))
  408. phase_lag = np.angle(Pxy)
  409. # attach proper units to return values
  410. if isinstance(x, pq.quantity.Quantity):
  411. freqs = freqs * pq.Hz
  412. phase_lag = phase_lag * pq.rad
  413. # When the input is AnalogSignal, the axis for frequency index is
  414. # rolled to the first to comply with the Neo convention about time axis
  415. if isinstance(x, neo.AnalogSignal):
  416. coherency = np.rollaxis(coherency, -1)
  417. phase_lag = np.rollaxis(phase_lag, -1)
  418. return freqs, coherency, phase_lag