test_signal_processing.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the signal_processing module.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division, print_function
  8. import unittest
  9. import neo
  10. import numpy as np
  11. import scipy.signal as spsig
  12. import scipy.stats
  13. from numpy.testing.utils import assert_array_almost_equal
  14. import quantities as pq
  15. import elephant.signal_processing
  16. from numpy.ma.testutils import assert_array_equal, assert_allclose
  17. class ZscoreTestCase(unittest.TestCase):
  18. def setUp(self):
  19. self.test_seq1 = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12,
  20. 4, 12, 59, 2, 4, 18, 33, 25, 2, 34,
  21. 4, 1, 1, 14, 8, 1, 10, 1, 8, 20,
  22. 5, 1, 6, 5, 12, 2, 8, 8, 2, 8,
  23. 2, 10, 2, 1, 1, 2, 15, 3, 20, 6,
  24. 11, 6, 18, 2, 5, 17, 4, 3, 13, 6,
  25. 1, 18, 1, 16, 12, 2, 52, 2, 5, 7,
  26. 6, 25, 6, 5, 3, 15, 4, 3, 16, 3,
  27. 6, 5, 24, 21, 3, 3, 4, 8, 4, 11,
  28. 5, 7, 5, 6, 8, 11, 33, 10, 7, 4]
  29. self.test_seq2 = [6, 3, 0, 0, 18, 4, 14, 98, 3, 56,
  30. 7, 4, 6, 9, 11, 16, 13, 3, 2, 15,
  31. 24, 1, 0, 7, 4, 4, 9, 24, 12, 11,
  32. 9, 7, 9, 8, 5, 2, 7, 12, 15, 17,
  33. 3, 7, 2, 1, 0, 17, 2, 6, 3, 32,
  34. 22, 19, 11, 8, 5, 4, 3, 2, 7, 21,
  35. 24, 2, 5, 10, 11, 14, 6, 8, 4, 12,
  36. 6, 5, 2, 22, 25, 19, 16, 22, 13, 2,
  37. 19, 20, 17, 19, 2, 4, 1, 3, 5, 23,
  38. 20, 15, 4, 7, 10, 14, 15, 15, 20, 1]
  39. def test_zscore_single_dup(self):
  40. """
  41. Test z-score on a single AnalogSignal, asking to return a
  42. duplicate.
  43. """
  44. signal = neo.AnalogSignal(
  45. self.test_seq1, units='mV',
  46. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  47. m = np.mean(self.test_seq1)
  48. s = np.std(self.test_seq1)
  49. target = (self.test_seq1 - m) / s
  50. assert_array_equal(target, scipy.stats.zscore(self.test_seq1))
  51. result = elephant.signal_processing.zscore(signal, inplace=False)
  52. assert_array_almost_equal(
  53. result.magnitude, target.reshape(-1, 1), decimal=9)
  54. self.assertEqual(result.units, pq.Quantity(1. * pq.dimensionless))
  55. # Assert original signal is untouched
  56. self.assertEqual(signal[0].magnitude, self.test_seq1[0])
  57. def test_zscore_single_inplace(self):
  58. """
  59. Test z-score on a single AnalogSignal, asking for an inplace
  60. operation.
  61. """
  62. signal = neo.AnalogSignal(
  63. self.test_seq1, units='mV',
  64. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  65. m = np.mean(self.test_seq1)
  66. s = np.std(self.test_seq1)
  67. target = (self.test_seq1 - m) / s
  68. result = elephant.signal_processing.zscore(signal, inplace=True)
  69. assert_array_almost_equal(
  70. result.magnitude, target.reshape(-1, 1), decimal=9)
  71. self.assertEqual(result.units, pq.Quantity(1. * pq.dimensionless))
  72. # Assert original signal is overwritten
  73. self.assertEqual(signal[0].magnitude, target[0])
  74. def test_zscore_single_multidim_dup(self):
  75. """
  76. Test z-score on a single AnalogSignal with multiple dimensions, asking
  77. to return a duplicate.
  78. """
  79. signal = neo.AnalogSignal(
  80. np.transpose(
  81. np.vstack([self.test_seq1, self.test_seq2])), units='mV',
  82. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  83. m = np.mean(signal.magnitude, axis=0, keepdims=True)
  84. s = np.std(signal.magnitude, axis=0, keepdims=True)
  85. target = (signal.magnitude - m) / s
  86. assert_array_almost_equal(
  87. elephant.signal_processing.zscore(
  88. signal, inplace=False).magnitude, target, decimal=9)
  89. # Assert original signal is untouched
  90. self.assertEqual(signal[0, 0].magnitude, self.test_seq1[0])
  91. def test_zscore_single_multidim_inplace(self):
  92. """
  93. Test z-score on a single AnalogSignal with multiple dimensions, asking
  94. for an inplace operation.
  95. """
  96. signal = neo.AnalogSignal(
  97. np.vstack([self.test_seq1, self.test_seq2]), units='mV',
  98. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  99. m = np.mean(signal.magnitude, axis=0, keepdims=True)
  100. s = np.std(signal.magnitude, axis=0, keepdims=True)
  101. target = (signal.magnitude - m) / s
  102. assert_array_almost_equal(
  103. elephant.signal_processing.zscore(
  104. signal, inplace=True).magnitude, target, decimal=9)
  105. # Assert original signal is overwritten
  106. self.assertEqual(signal[0, 0].magnitude, target[0, 0])
  107. def test_zscore_single_dup_int(self):
  108. """
  109. Test if the z-score is correctly calculated even if the input is an
  110. AnalogSignal of type int, asking for a duplicate (duplicate should
  111. be of type float).
  112. """
  113. signal = neo.AnalogSignal(
  114. self.test_seq1, units='mV',
  115. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=int)
  116. m = np.mean(self.test_seq1)
  117. s = np.std(self.test_seq1)
  118. target = (self.test_seq1 - m) / s
  119. assert_array_almost_equal(
  120. elephant.signal_processing.zscore(signal, inplace=False).magnitude,
  121. target.reshape(-1, 1), decimal=9)
  122. # Assert original signal is untouched
  123. self.assertEqual(signal.magnitude[0], self.test_seq1[0])
  124. def test_zscore_single_inplace_int(self):
  125. """
  126. Test if the z-score is correctly calculated even if the input is an
  127. AnalogSignal of type int, asking for an inplace operation.
  128. """
  129. signal = neo.AnalogSignal(
  130. self.test_seq1, units='mV',
  131. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=int)
  132. m = np.mean(self.test_seq1)
  133. s = np.std(self.test_seq1)
  134. target = (self.test_seq1 - m) / s
  135. assert_array_almost_equal(
  136. elephant.signal_processing.zscore(signal, inplace=True).magnitude,
  137. target.reshape(-1, 1).astype(int), decimal=9)
  138. # Assert original signal is overwritten
  139. self.assertEqual(signal[0].magnitude, target.astype(int)[0])
  140. def test_zscore_list_dup(self):
  141. """
  142. Test zscore on a list of AnalogSignal objects, asking to return a
  143. duplicate.
  144. """
  145. signal1 = neo.AnalogSignal(
  146. np.transpose(np.vstack([self.test_seq1, self.test_seq1])),
  147. units='mV',
  148. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  149. signal2 = neo.AnalogSignal(
  150. np.transpose(np.vstack([self.test_seq1, self.test_seq2])),
  151. units='mV',
  152. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  153. signal_list = [signal1, signal2]
  154. m = np.mean(np.hstack([self.test_seq1, self.test_seq1]))
  155. s = np.std(np.hstack([self.test_seq1, self.test_seq1]))
  156. target11 = (self.test_seq1 - m) / s
  157. target21 = (self.test_seq1 - m) / s
  158. m = np.mean(np.hstack([self.test_seq1, self.test_seq2]))
  159. s = np.std(np.hstack([self.test_seq1, self.test_seq2]))
  160. target12 = (self.test_seq1 - m) / s
  161. target22 = (self.test_seq2 - m) / s
  162. # Call elephant function
  163. result = elephant.signal_processing.zscore(signal_list, inplace=False)
  164. assert_array_almost_equal(
  165. result[0].magnitude,
  166. np.transpose(np.vstack([target11, target12])), decimal=9)
  167. assert_array_almost_equal(
  168. result[1].magnitude,
  169. np.transpose(np.vstack([target21, target22])), decimal=9)
  170. # Assert original signal is untouched
  171. self.assertEqual(signal1.magnitude[0, 0], self.test_seq1[0])
  172. self.assertEqual(signal2.magnitude[0, 1], self.test_seq2[0])
  173. def test_zscore_list_inplace(self):
  174. """
  175. Test zscore on a list of AnalogSignal objects, asking for an
  176. inplace operation.
  177. """
  178. signal1 = neo.AnalogSignal(
  179. np.transpose(np.vstack([self.test_seq1, self.test_seq1])),
  180. units='mV',
  181. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  182. signal2 = neo.AnalogSignal(
  183. np.transpose(np.vstack([self.test_seq1, self.test_seq2])),
  184. units='mV',
  185. t_start=0. * pq.ms, sampling_rate=1000. * pq.Hz, dtype=float)
  186. signal_list = [signal1, signal2]
  187. m = np.mean(np.hstack([self.test_seq1, self.test_seq1]))
  188. s = np.std(np.hstack([self.test_seq1, self.test_seq1]))
  189. target11 = (self.test_seq1 - m) / s
  190. target21 = (self.test_seq1 - m) / s
  191. m = np.mean(np.hstack([self.test_seq1, self.test_seq2]))
  192. s = np.std(np.hstack([self.test_seq1, self.test_seq2]))
  193. target12 = (self.test_seq1 - m) / s
  194. target22 = (self.test_seq2 - m) / s
  195. # Call elephant function
  196. result = elephant.signal_processing.zscore(signal_list, inplace=True)
  197. assert_array_almost_equal(
  198. result[0].magnitude,
  199. np.transpose(np.vstack([target11, target12])), decimal=9)
  200. assert_array_almost_equal(
  201. result[1].magnitude,
  202. np.transpose(np.vstack([target21, target22])), decimal=9)
  203. # Assert original signal is overwritten
  204. self.assertEqual(signal1[0, 0].magnitude, target11[0])
  205. self.assertEqual(signal2[0, 0].magnitude, target21[0])
  206. class ButterTestCase(unittest.TestCase):
  207. def test_butter_filter_type(self):
  208. """
  209. Test if correct type of filtering is performed according to how cut-off
  210. frequencies are given
  211. """
  212. # generate white noise AnalogSignal
  213. noise = neo.AnalogSignal(
  214. np.random.normal(size=5000),
  215. sampling_rate=1000 * pq.Hz, units='mV')
  216. # test high-pass filtering: power at the lowest frequency
  217. # should be almost zero
  218. # Note: the default detrend function of scipy.signal.welch() seems to
  219. # cause artificial finite power at the lowest frequencies. Here I avoid
  220. # this by using an identity function for detrending
  221. filtered_noise = elephant.signal_processing.butter(
  222. noise, 250.0 * pq.Hz, None)
  223. _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0,
  224. detrend=lambda x: x)
  225. self.assertAlmostEqual(psd[0, 0], 0)
  226. # test low-pass filtering: power at the highest frequency
  227. # should be almost zero
  228. filtered_noise = elephant.signal_processing.butter(
  229. noise, None, 250.0 * pq.Hz)
  230. _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0)
  231. self.assertAlmostEqual(psd[0, -1], 0)
  232. # test band-pass filtering: power at the lowest and highest frequencies
  233. # should be almost zero
  234. filtered_noise = elephant.signal_processing.butter(
  235. noise, 200.0 * pq.Hz, 300.0 * pq.Hz)
  236. _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0,
  237. detrend=lambda x: x)
  238. self.assertAlmostEqual(psd[0, 0], 0)
  239. self.assertAlmostEqual(psd[0, -1], 0)
  240. # test band-stop filtering: power at the intermediate frequency
  241. # should be almost zero
  242. filtered_noise = elephant.signal_processing.butter(
  243. noise, 400.0 * pq.Hz, 100.0 * pq.Hz)
  244. _, psd = spsig.welch(filtered_noise.T, nperseg=1024, fs=1000.0)
  245. self.assertAlmostEqual(psd[0, 256], 0)
  246. def test_butter_filter_function(self):
  247. # generate white noise AnalogSignal
  248. noise = neo.AnalogSignal(
  249. np.random.normal(size=5000),
  250. sampling_rate=1000 * pq.Hz, units='mV')
  251. # test if the filter performance is as well with filftunc=lfilter as
  252. # with filtfunc=filtfilt (i.e. default option)
  253. kwds = {'signal': noise, 'highpass_freq': 250.0 * pq.Hz,
  254. 'lowpass_freq': None, 'filter_function': 'filtfilt'}
  255. filtered_noise = elephant.signal_processing.butter(**kwds)
  256. _, psd_filtfilt = spsig.welch(
  257. filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x)
  258. kwds['filter_function'] = 'lfilter'
  259. filtered_noise = elephant.signal_processing.butter(**kwds)
  260. _, psd_lfilter = spsig.welch(
  261. filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x)
  262. self.assertAlmostEqual(psd_filtfilt[0, 0], psd_lfilter[0, 0])
  263. def test_butter_invalid_filter_function(self):
  264. # generate a dummy AnalogSignal
  265. anasig_dummy = neo.AnalogSignal(
  266. np.zeros(5000), sampling_rate=1000 * pq.Hz, units='mV')
  267. # test exception upon invalid filtfunc string
  268. kwds = {'signal': anasig_dummy, 'highpass_freq': 250.0 * pq.Hz,
  269. 'filter_function': 'invalid_filter'}
  270. self.assertRaises(
  271. ValueError, elephant.signal_processing.butter, **kwds)
  272. def test_butter_missing_cutoff_freqs(self):
  273. # generate a dummy AnalogSignal
  274. anasig_dummy = neo.AnalogSignal(
  275. np.zeros(5000), sampling_rate=1000 * pq.Hz, units='mV')
  276. # test a case where no cut-off frequencies are given
  277. kwds = {'signal': anasig_dummy, 'highpass_freq': None,
  278. 'lowpass_freq': None}
  279. self.assertRaises(
  280. ValueError, elephant.signal_processing.butter, **kwds)
  281. def test_butter_input_types(self):
  282. # generate white noise data of different types
  283. noise_np = np.random.normal(size=5000)
  284. noise_pq = noise_np * pq.mV
  285. noise = neo.AnalogSignal(noise_pq, sampling_rate=1000.0 * pq.Hz)
  286. # check input as NumPy ndarray
  287. filtered_noise_np = elephant.signal_processing.butter(
  288. noise_np, 400.0, 100.0, fs=1000.0)
  289. self.assertTrue(isinstance(filtered_noise_np, np.ndarray))
  290. self.assertFalse(isinstance(filtered_noise_np, pq.quantity.Quantity))
  291. self.assertFalse(isinstance(filtered_noise_np, neo.AnalogSignal))
  292. self.assertEqual(filtered_noise_np.shape, noise_np.shape)
  293. # check input as Quantity array
  294. filtered_noise_pq = elephant.signal_processing.butter(
  295. noise_pq, 400.0 * pq.Hz, 100.0 * pq.Hz, fs=1000.0)
  296. self.assertTrue(isinstance(filtered_noise_pq, pq.quantity.Quantity))
  297. self.assertFalse(isinstance(filtered_noise_pq, neo.AnalogSignal))
  298. self.assertEqual(filtered_noise_pq.shape, noise_pq.shape)
  299. # check input as neo AnalogSignal
  300. filtered_noise = elephant.signal_processing.butter(noise,
  301. 400.0 * pq.Hz,
  302. 100.0 * pq.Hz)
  303. self.assertTrue(isinstance(filtered_noise, neo.AnalogSignal))
  304. self.assertEqual(filtered_noise.shape, noise.shape)
  305. # check if the results from different input types are identical
  306. self.assertTrue(np.all(
  307. filtered_noise_pq.magnitude == filtered_noise_np))
  308. self.assertTrue(np.all(
  309. filtered_noise.magnitude[:, 0] == filtered_noise_np))
  310. def test_butter_axis(self):
  311. noise = np.random.normal(size=(4, 5000))
  312. filtered_noise = elephant.signal_processing.butter(
  313. noise, 250.0, fs=1000.0)
  314. filtered_noise_transposed = elephant.signal_processing.butter(
  315. noise.T, 250.0, fs=1000.0, axis=0)
  316. self.assertTrue(np.all(filtered_noise == filtered_noise_transposed.T))
  317. def test_butter_multidim_input(self):
  318. noise_pq = np.random.normal(size=(4, 5000)) * pq.mV
  319. noise_neo = neo.AnalogSignal(
  320. noise_pq.T, sampling_rate=1000.0 * pq.Hz)
  321. noise_neo1d = neo.AnalogSignal(
  322. noise_pq[0], sampling_rate=1000.0 * pq.Hz)
  323. filtered_noise_pq = elephant.signal_processing.butter(
  324. noise_pq, 250.0, fs=1000.0)
  325. filtered_noise_neo = elephant.signal_processing.butter(
  326. noise_neo, 250.0)
  327. filtered_noise_neo1d = elephant.signal_processing.butter(
  328. noise_neo1d, 250.0)
  329. self.assertTrue(np.all(
  330. filtered_noise_pq.magnitude == filtered_noise_neo.T.magnitude))
  331. self.assertTrue(np.all(
  332. filtered_noise_neo1d.magnitude[:, 0] ==
  333. filtered_noise_neo.magnitude[:, 0]))
  334. class HilbertTestCase(unittest.TestCase):
  335. def setUp(self):
  336. # Generate test data of a harmonic function over a long time
  337. time = np.arange(0, 1000, 0.1) * pq.ms
  338. freq = 10 * pq.Hz
  339. self.amplitude = np.array([
  340. np.linspace(1, 10, len(time)),
  341. np.linspace(1, 10, len(time)),
  342. np.ones((len(time))),
  343. np.ones((len(time))) * 10.]).T
  344. self.phase = np.array([
  345. (time * freq).simplified.magnitude * 2. * np.pi,
  346. (time * freq).simplified.magnitude * 2. * np.pi + np.pi / 2,
  347. (time * freq).simplified.magnitude * 2. * np.pi + np.pi,
  348. (time * freq).simplified.magnitude * 2. * 2. * np.pi]).T
  349. self.phase = np.mod(self.phase + np.pi, 2. * np.pi) - np.pi
  350. # rising amplitude cosine, random ampl. sine, flat inverse cosine,
  351. # flat cosine at double frequency
  352. sigs = np.vstack([
  353. self.amplitude[:, 0] * np.cos(self.phase[:, 0]),
  354. self.amplitude[:, 1] * np.cos(self.phase[:, 1]),
  355. self.amplitude[:, 2] * np.cos(self.phase[:, 2]),
  356. self.amplitude[:, 3] * np.cos(self.phase[:, 3])])
  357. self.long_signals = neo.AnalogSignal(
  358. sigs.T, units='mV',
  359. t_start=0. * pq.ms,
  360. sampling_rate=(len(time) / (time[-1] - time[0])).rescale(pq.Hz),
  361. dtype=float)
  362. # Generate test data covering a single oscillation cycle in 1s only
  363. phases = np.arange(0, 2 * np.pi, np.pi / 256)
  364. sigs = np.vstack([
  365. np.sin(phases),
  366. np.cos(phases),
  367. np.sin(2 * phases),
  368. np.cos(2 * phases)])
  369. self.one_period = neo.AnalogSignal(
  370. sigs.T, units=pq.mV,
  371. sampling_rate=len(phases) * pq.Hz)
  372. def test_hilbert_pad_type_error(self):
  373. """
  374. Tests if incorrect pad_type raises ValueError.
  375. """
  376. padding = 'wrong_type'
  377. self.assertRaises(
  378. ValueError, elephant.signal_processing.hilbert,
  379. self.long_signals, N=padding)
  380. def test_hilbert_output_shape(self):
  381. """
  382. Tests if the length of the output is identical to the original signal,
  383. and the dimension is dimensionless.
  384. """
  385. true_shape = np.shape(self.long_signals)
  386. output = elephant.signal_processing.hilbert(
  387. self.long_signals, N='nextpow')
  388. self.assertEquals(np.shape(output), true_shape)
  389. self.assertEqual(output.units, pq.dimensionless)
  390. output = elephant.signal_processing.hilbert(
  391. self.long_signals, N=16384)
  392. self.assertEquals(np.shape(output), true_shape)
  393. self.assertEqual(output.units, pq.dimensionless)
  394. def test_hilbert_theoretical_long_signals(self):
  395. """
  396. Tests the output of the hilbert function with regard to amplitude and
  397. phase of long test signals
  398. """
  399. # Performing test using all pad types
  400. for padding in ['nextpow', 'none', 16384]:
  401. h = elephant.signal_processing.hilbert(
  402. self.long_signals, N=padding)
  403. phase = np.angle(h.magnitude)
  404. amplitude = np.abs(h.magnitude)
  405. real_value = np.real(h.magnitude)
  406. # The real part should be equal to the original long_signals
  407. assert_array_almost_equal(
  408. real_value,
  409. self.long_signals.magnitude,
  410. decimal=14)
  411. # Test only in the middle half of the array (border effects)
  412. ind1 = int(len(h.times) / 4)
  413. ind2 = int(3 * len(h.times) / 4)
  414. # Calculate difference in phase between signal and original phase
  415. # and use smaller of any two phase differences
  416. phasediff = np.abs(phase[ind1:ind2, :] - self.phase[ind1:ind2, :])
  417. phasediff[phasediff >= np.pi] = \
  418. 2 * np.pi - phasediff[phasediff >= np.pi]
  419. # Calculate difference in amplitude between signal and original
  420. # amplitude
  421. amplitudediff = \
  422. amplitude[ind1:ind2, :] - self.amplitude[ind1:ind2, :]
  423. #
  424. assert_allclose(phasediff, 0, atol=0.1)
  425. assert_allclose(amplitudediff, 0, atol=0.5)
  426. def test_hilbert_theoretical_one_period(self):
  427. """
  428. Tests the output of the hilbert function with regard to amplitude and
  429. phase of a short signal covering one cycle (more accurate estimate).
  430. This unit test is adapted from the scipy library of the hilbert()
  431. function.
  432. """
  433. # Precision of testing
  434. decimal = 14
  435. # Performing test using both pad types
  436. for padding in ['nextpow', 'none', 512]:
  437. h = elephant.signal_processing.hilbert(
  438. self.one_period, N=padding)
  439. amplitude = np.abs(h.magnitude)
  440. phase = np.angle(h.magnitude)
  441. real_value = np.real(h.magnitude)
  442. # The real part should be equal to the original long_signals:
  443. assert_array_almost_equal(
  444. real_value,
  445. self.one_period.magnitude,
  446. decimal=decimal)
  447. # The absolute value should be 1 everywhere, for this input:
  448. assert_array_almost_equal(
  449. amplitude,
  450. np.ones(self.one_period.magnitude.shape),
  451. decimal=decimal)
  452. # For the 'slow' sine - the phase should go from -pi/2 to pi/2 in
  453. # the first 256 bins:
  454. assert_array_almost_equal(
  455. phase[:256, 0],
  456. np.arange(-np.pi / 2, np.pi / 2, np.pi / 256),
  457. decimal=decimal)
  458. # For the 'slow' cosine - the phase should go from 0 to pi in the
  459. # same interval:
  460. assert_array_almost_equal(
  461. phase[:256, 1],
  462. np.arange(0, np.pi, np.pi / 256),
  463. decimal=decimal)
  464. # The 'fast' sine should make this phase transition in half the
  465. # time:
  466. assert_array_almost_equal(
  467. phase[:128, 2],
  468. np.arange(-np.pi / 2, np.pi / 2, np.pi / 128),
  469. decimal=decimal)
  470. # The 'fast' cosine should make this phase transition in half the
  471. # time:
  472. assert_array_almost_equal(
  473. phase[:128, 3],
  474. np.arange(0, np.pi, np.pi / 128),
  475. decimal=decimal)
  476. class WaveletTestCase(unittest.TestCase):
  477. def setUp(self):
  478. # generate a 10-sec test data of pure 50 Hz cosine wave
  479. self.fs = 1000.0
  480. self.times = np.arange(0, 10.0, 1/self.fs)
  481. self.test_freq1 = 50.0
  482. self.test_freq2 = 60.0
  483. self.test_data1 = np.cos(2*np.pi*self.test_freq1*self.times)
  484. self.test_data2 = np.sin(2*np.pi*self.test_freq2*self.times)
  485. self.test_data_arr = np.vstack([self.test_data1, self.test_data2])
  486. self.test_data = neo.AnalogSignal(
  487. self.test_data_arr.T*pq.mV, t_start=self.times[0]*pq.s,
  488. t_stop=self.times[-1]*pq.s, sampling_period=(1/self.fs)*pq.s)
  489. self.true_phase1 = np.angle(
  490. self.test_data1 + 1j*np.sin(2*np.pi*self.test_freq1*self.times))
  491. self.true_phase2 = np.angle(
  492. self.test_data2 - 1j*np.cos(2*np.pi*self.test_freq2*self.times))
  493. self.wt_freqs = [10, 20, 30]
  494. def test_wavelet_errors(self):
  495. """
  496. Tests if errors are raised as expected.
  497. """
  498. # too high center frequency
  499. kwds = {'signal': self.test_data, 'freq': self.fs/2}
  500. self.assertRaises(
  501. ValueError, elephant.signal_processing.wavelet_transform, **kwds)
  502. kwds = {'signal': self.test_data_arr, 'freq': self.fs/2, 'fs': self.fs}
  503. self.assertRaises(
  504. ValueError, elephant.signal_processing.wavelet_transform, **kwds)
  505. # too high center frequency in a list
  506. kwds = {'signal': self.test_data, 'freq': [self.fs/10, self.fs/2]}
  507. self.assertRaises(
  508. ValueError, elephant.signal_processing.wavelet_transform, **kwds)
  509. kwds = {'signal': self.test_data_arr,
  510. 'freq': [self.fs/10, self.fs/2], 'fs': self.fs}
  511. self.assertRaises(
  512. ValueError, elephant.signal_processing.wavelet_transform, **kwds)
  513. # nco is not positive
  514. kwds = {'signal': self.test_data, 'freq': self.fs/10, 'nco': 0}
  515. self.assertRaises(
  516. ValueError, elephant.signal_processing.wavelet_transform, **kwds)
  517. def test_wavelet_io(self):
  518. """
  519. Tests the data type and data shape of the output is consistent with
  520. that of the input, and also test the consistency between the outputs
  521. of different types
  522. """
  523. # check the shape of the result array
  524. # --- case of single center frequency
  525. wt = elephant.signal_processing.wavelet_transform(self.test_data,
  526. self.fs/10)
  527. self.assertTrue(wt.ndim == self.test_data.ndim)
  528. self.assertTrue(wt.shape[0] == self.test_data.shape[0]) # time axis
  529. self.assertTrue(wt.shape[1] == self.test_data.shape[1]) # channel axis
  530. wt_arr = elephant.signal_processing.wavelet_transform(
  531. self.test_data_arr, self.fs/10, fs=self.fs)
  532. self.assertTrue(wt_arr.ndim == self.test_data.ndim)
  533. # channel axis
  534. self.assertTrue(wt_arr.shape[0] == self.test_data_arr.shape[0])
  535. # time axis
  536. self.assertTrue(wt_arr.shape[1] == self.test_data_arr.shape[1])
  537. wt_arr1d = elephant.signal_processing.wavelet_transform(
  538. self.test_data1, self.fs/10, fs=self.fs)
  539. self.assertTrue(wt_arr1d.ndim == self.test_data1.ndim)
  540. # time axis
  541. self.assertTrue(wt_arr1d.shape[0] == self.test_data1.shape[0])
  542. # --- case of multiple center frequencies
  543. wt = elephant.signal_processing.wavelet_transform(
  544. self.test_data, self.wt_freqs)
  545. self.assertTrue(wt.ndim == self.test_data.ndim+1)
  546. self.assertTrue(wt.shape[0] == self.test_data.shape[0]) # time axis
  547. self.assertTrue(wt.shape[1] == self.test_data.shape[1]) # channel axis
  548. self.assertTrue(wt.shape[2] == len(self.wt_freqs)) # frequency axis
  549. wt_arr = elephant.signal_processing.wavelet_transform(
  550. self.test_data_arr, self.wt_freqs, fs=self.fs)
  551. self.assertTrue(wt_arr.ndim == self.test_data_arr.ndim+1)
  552. # channel axis
  553. self.assertTrue(wt_arr.shape[0] == self.test_data_arr.shape[0])
  554. # frequency axis
  555. self.assertTrue(wt_arr.shape[1] == len(self.wt_freqs))
  556. # time axis
  557. self.assertTrue(wt_arr.shape[2] == self.test_data_arr.shape[1])
  558. wt_arr1d = elephant.signal_processing.wavelet_transform(
  559. self.test_data1, self.wt_freqs, fs=self.fs)
  560. self.assertTrue(wt_arr1d.ndim == self.test_data1.ndim+1)
  561. # frequency axis
  562. self.assertTrue(wt_arr1d.shape[0] == len(self.wt_freqs))
  563. # time axis
  564. self.assertTrue(wt_arr1d.shape[1] == self.test_data1.shape[0])
  565. # check that the result does not depend on data type
  566. self.assertTrue(np.all(wt[:, 0, :] == wt_arr[0, :, :].T)) # channel 0
  567. self.assertTrue(np.all(wt[:, 1, :] == wt_arr[1, :, :].T)) # channel 1
  568. # check the data contents in the case where freq is given as a list
  569. # Note: there seems to be a bug in np.fft since NumPy 1.14.1, which
  570. # causes that the values of wt_1freq[:, 0] and wt_3freqs[:, 0, 0] are
  571. # not exactly equal, even though they use the same center frequency for
  572. # wavelet transform (in NumPy 1.13.1, they become identical). Here we
  573. # only check that they are almost equal.
  574. wt_1freq = elephant.signal_processing.wavelet_transform(
  575. self.test_data, self.wt_freqs[0])
  576. wt_3freqs = elephant.signal_processing.wavelet_transform(
  577. self.test_data, self.wt_freqs)
  578. assert_array_almost_equal(wt_1freq[:, 0], wt_3freqs[:, 0, 0],
  579. decimal=12)
  580. def test_wavelet_amplitude(self):
  581. """
  582. Tests amplitude properties of the obtained wavelet transform
  583. """
  584. # check that the amplitude of WT of a sinusoid is (almost) constant
  585. wt = elephant.signal_processing.wavelet_transform(self.test_data,
  586. self.test_freq1)
  587. # take a middle segment in order to avoid edge effects
  588. amp = np.abs(wt[int(len(wt)/3):int(len(wt)//3*2), 0])
  589. mean_amp = amp.mean()
  590. assert_array_almost_equal((amp - mean_amp) / mean_amp,
  591. np.zeros_like(amp), decimal=6)
  592. # check that the amplitude of WT is (almost) zero when center frequency
  593. # is considerably different from signal frequency
  594. wt_low = elephant.signal_processing.wavelet_transform(
  595. self.test_data, self.test_freq1/10)
  596. amp_low = np.abs(wt_low[int(len(wt)/3):int(len(wt)//3*2), 0])
  597. assert_array_almost_equal(amp_low, np.zeros_like(amp), decimal=6)
  598. # check that zero padding hardly affect the result
  599. wt_padded = elephant.signal_processing.wavelet_transform(
  600. self.test_data, self.test_freq1, zero_padding=False)
  601. amp_padded = np.abs(wt_padded[int(len(wt)/3):int(len(wt)//3*2), 0])
  602. assert_array_almost_equal(amp_padded, amp, decimal=9)
  603. def test_wavelet_phase(self):
  604. """
  605. Tests phase properties of the obtained wavelet transform
  606. """
  607. # check that the phase of WT is (almost) same as that of the original
  608. # sinusoid
  609. wt = elephant.signal_processing.wavelet_transform(self.test_data,
  610. self.test_freq1)
  611. phase = np.angle(wt[int(len(wt)/3):int(len(wt)//3*2), 0])
  612. true_phase = self.true_phase1[int(len(wt)/3):int(len(wt)//3*2)]
  613. assert_array_almost_equal(np.exp(1j*phase), np.exp(1j*true_phase),
  614. decimal=6)
  615. # check that zero padding hardly affect the result
  616. wt_padded = elephant.signal_processing.wavelet_transform(
  617. self.test_data, self.test_freq1, zero_padding=False)
  618. phase_padded = np.angle(wt_padded[int(len(wt)/3):int(len(wt)//3*2), 0])
  619. assert_array_almost_equal(np.exp(1j*phase_padded), np.exp(1j*phase),
  620. decimal=9)
  621. if __name__ == '__main__':
  622. unittest.main()