test_sta.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests for the function sta module
  4. :copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import unittest
  8. import math
  9. import numpy as np
  10. import scipy
  11. from numpy.testing import assert_array_equal
  12. from numpy.testing.utils import assert_array_almost_equal
  13. import neo
  14. from neo import AnalogSignal, SpikeTrain
  15. from elephant.conversion import BinnedSpikeTrain
  16. import quantities as pq
  17. from quantities import ms, mV, Hz
  18. import elephant.sta as sta
  19. import warnings
  20. class sta_TestCase(unittest.TestCase):
  21. def setUp(self):
  22. self.asiga0 = AnalogSignal(np.array([
  23. np.sin(np.arange(0, 20 * math.pi, 0.1))]).T,
  24. units='mV', sampling_rate=10 / ms)
  25. self.asiga1 = AnalogSignal(np.array([
  26. np.sin(np.arange(0, 20 * math.pi, 0.1)),
  27. np.cos(np.arange(0, 20 * math.pi, 0.1))]).T,
  28. units='mV', sampling_rate=10 / ms)
  29. self.asiga2 = AnalogSignal(np.array([
  30. np.sin(np.arange(0, 20 * math.pi, 0.1)),
  31. np.cos(np.arange(0, 20 * math.pi, 0.1)),
  32. np.tan(np.arange(0, 20 * math.pi, 0.1))]).T,
  33. units='mV', sampling_rate=10 / ms)
  34. self.st0 = SpikeTrain(
  35. [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi],
  36. units='ms', t_stop=self.asiga0.t_stop)
  37. self.lst = [SpikeTrain(
  38. [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi],
  39. units='ms', t_stop=self.asiga1.t_stop),
  40. SpikeTrain([30, 35, 40], units='ms', t_stop=self.asiga1.t_stop)]
  41. #***********************************************************************
  42. #************************ Test for typical values **********************
  43. def test_spike_triggered_average_with_n_spikes_on_constant_function(self):
  44. '''Signal should average to the input'''
  45. const = 13.8
  46. x = const * np.ones(201)
  47. asiga = AnalogSignal(
  48. np.array([x]).T, units='mV', sampling_rate=10 / ms)
  49. st = SpikeTrain([3, 5.6, 7, 7.1, 16, 16.3], units='ms', t_stop=20)
  50. window_starttime = -2 * ms
  51. window_endtime = 2 * ms
  52. STA = sta.spike_triggered_average(
  53. asiga, st, (window_starttime, window_endtime))
  54. a = int(((window_endtime - window_starttime) *
  55. asiga.sampling_rate).simplified)
  56. cutout = asiga[0: a]
  57. cutout.t_start = window_starttime
  58. assert_array_almost_equal(STA, cutout, 12)
  59. def test_spike_triggered_average_with_shifted_sin_wave(self):
  60. '''Signal should average to zero'''
  61. STA = sta.spike_triggered_average(
  62. self.asiga0, self.st0, (-4 * ms, 4 * ms))
  63. target = 5e-2 * mV
  64. self.assertEqual(np.abs(STA).max().dimensionality.simplified,
  65. pq.Quantity(1, "V").dimensionality.simplified)
  66. self.assertLess(np.abs(STA).max(), target)
  67. def test_only_one_spike(self):
  68. '''The output should be the same as the input'''
  69. x = np.arange(0, 20, 0.1)
  70. y = x**2
  71. sr = 10 / ms
  72. z = AnalogSignal(np.array([y]).T, units='mV', sampling_rate=sr)
  73. spiketime = 8 * ms
  74. spiketime_in_ms = int((spiketime / ms).simplified)
  75. st = SpikeTrain([spiketime_in_ms], units='ms', t_stop=20)
  76. window_starttime = -3 * ms
  77. window_endtime = 5 * ms
  78. STA = sta.spike_triggered_average(
  79. z, st, (window_starttime, window_endtime))
  80. cutout = z[int(((spiketime + window_starttime) * sr).simplified):
  81. int(((spiketime + window_endtime) * sr).simplified)]
  82. cutout.t_start = window_starttime
  83. assert_array_equal(STA, cutout)
  84. def test_usage_of_spikes(self):
  85. st = SpikeTrain([16.5 * math.pi, 17.5 * math.pi,
  86. 18.5 * math.pi, 19.5 * math.pi], units='ms', t_stop=20 * math.pi)
  87. STA = sta.spike_triggered_average(
  88. self.asiga0, st, (-math.pi * ms, math.pi * ms))
  89. self.assertEqual(STA.annotations['used_spikes'], 3)
  90. self.assertEqual(STA.annotations['unused_spikes'], 1)
  91. #***********************************************************************
  92. #**** Test for an invalid value, to check that the function raises *****
  93. #********* an exception or returns an error code ***********************
  94. def test_analog_signal_of_wrong_type(self):
  95. '''Analog signal given as list, but must be AnalogSignal'''
  96. asiga = [0, 1, 2, 3, 4]
  97. self.assertRaises(TypeError, sta.spike_triggered_average,
  98. asiga, self.st0, (-2 * ms, 2 * ms))
  99. def test_spiketrain_of_list_type_in_wrong_sense(self):
  100. st = [10, 11, 12]
  101. self.assertRaises(TypeError, sta.spike_triggered_average,
  102. self.asiga0, st, (1 * ms, 2 * ms))
  103. def test_spiketrain_of_nonlist_and_nonspiketrain_type(self):
  104. st = (10, 11, 12)
  105. self.assertRaises(TypeError, sta.spike_triggered_average,
  106. self.asiga0, st, (1 * ms, 2 * ms))
  107. def test_forgotten_AnalogSignal_argument(self):
  108. self.assertRaises(TypeError, sta.spike_triggered_average,
  109. self.st0, (-2 * ms, 2 * ms))
  110. def test_one_smaller_nrspiketrains_smaller_nranalogsignals(self):
  111. '''Number of spiketrains between 1 and number of analogsignals'''
  112. self.assertRaises(ValueError, sta.spike_triggered_average,
  113. self.asiga2, self.lst, (-2 * ms, 2 * ms))
  114. def test_more_spiketrains_than_analogsignals_forbidden(self):
  115. self.assertRaises(ValueError, sta.spike_triggered_average,
  116. self.asiga0, self.lst, (-2 * ms, 2 * ms))
  117. def test_spike_earlier_than_analogsignal(self):
  118. st = SpikeTrain([-1 * math.pi, 2 * math.pi],
  119. units='ms', t_start=-2 * math.pi, t_stop=20 * math.pi)
  120. self.assertRaises(ValueError, sta.spike_triggered_average,
  121. self.asiga0, st, (-2 * ms, 2 * ms))
  122. def test_spike_later_than_analogsignal(self):
  123. st = SpikeTrain(
  124. [math.pi, 21 * math.pi], units='ms', t_stop=25 * math.pi)
  125. self.assertRaises(ValueError, sta.spike_triggered_average,
  126. self.asiga0, st, (-2 * ms, 2 * ms))
  127. def test_impossible_window(self):
  128. self.assertRaises(ValueError, sta.spike_triggered_average,
  129. self.asiga0, self.st0, (-2 * ms, -5 * ms))
  130. def test_window_larger_than_signal(self):
  131. self.assertRaises(ValueError, sta.spike_triggered_average,
  132. self.asiga0, self.st0, (-15 * math.pi * ms, 15 * math.pi * ms))
  133. def test_wrong_window_starttime_unit(self):
  134. self.assertRaises(TypeError, sta.spike_triggered_average,
  135. self.asiga0, self.st0, (-2 * mV, 2 * ms))
  136. def test_wrong_window_endtime_unit(self):
  137. self.assertRaises(TypeError, sta.spike_triggered_average,
  138. self.asiga0, self.st0, (-2 * ms, 2 * Hz))
  139. def test_window_borders_as_complex_numbers(self):
  140. self.assertRaises(TypeError, sta.spike_triggered_average, self.asiga0,
  141. self.st0, ((-2 * math.pi + 3j) * ms, (2 * math.pi + 3j) * ms))
  142. #***********************************************************************
  143. #**** Test for an empty value (where the argument is a list, array, ****
  144. #********* vector or other container datatype). ************************
  145. def test_empty_analogsignal(self):
  146. asiga = AnalogSignal([], units='mV', sampling_rate=10 / ms)
  147. st = SpikeTrain([5], units='ms', t_stop=10)
  148. self.assertRaises(ValueError, sta.spike_triggered_average,
  149. asiga, st, (-1 * ms, 1 * ms))
  150. def test_one_spiketrain_empty(self):
  151. '''Test for one empty SpikeTrain, but existing spikes in other'''
  152. st = [SpikeTrain(
  153. [9 * math.pi, 10 * math.pi, 11 * math.pi, 12 * math.pi],
  154. units='ms', t_stop=self.asiga1.t_stop),
  155. SpikeTrain([], units='ms', t_stop=self.asiga1.t_stop)]
  156. STA = sta.spike_triggered_average(self.asiga1, st, (-1 * ms, 1 * ms))
  157. cmp_array = AnalogSignal(np.array([np.zeros(20, dtype=float)]).T,
  158. units='mV', sampling_rate=10 / ms)
  159. cmp_array = cmp_array / 0.
  160. cmp_array.t_start = -1 * ms
  161. assert_array_equal(STA[:, 1], cmp_array[:, 0])
  162. def test_all_spiketrains_empty(self):
  163. st = SpikeTrain([], units='ms', t_stop=self.asiga1.t_stop)
  164. with warnings.catch_warnings(record=True) as w:
  165. # Cause all warnings to always be triggered.
  166. warnings.simplefilter("always")
  167. # Trigger warnings.
  168. STA = sta.spike_triggered_average(
  169. self.asiga1, st, (-1 * ms, 1 * ms))
  170. self.assertEqual("No spike at all was either found or used "
  171. "for averaging", str(w[-1].message))
  172. nan_array = np.empty(20)
  173. nan_array.fill(np.nan)
  174. cmp_array = AnalogSignal(np.array([nan_array, nan_array]).T,
  175. units='mV', sampling_rate=10 / ms)
  176. assert_array_equal(STA, cmp_array)
  177. # =========================================================================
# Tests for new scipy version (with scipy.signal.coherence)
  179. # =========================================================================
  180. @unittest.skipIf(not hasattr(scipy.signal, 'coherence'), "Please update scipy "
  181. "to a version >= 0.16")
  182. class sfc_TestCase_new_scipy(unittest.TestCase):
  183. def setUp(self):
  184. # standard testsignals
  185. tlen0 = 100 * pq.s
  186. f0 = 20. * pq.Hz
  187. fs0 = 1 * pq.ms
  188. t0 = np.arange(
  189. 0, tlen0.rescale(pq.s).magnitude,
  190. fs0.rescale(pq.s).magnitude) * pq.s
  191. self.anasig0 = AnalogSignal(
  192. np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
  193. units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
  194. self.st0 = SpikeTrain(
  195. np.arange(0, tlen0.rescale(pq.ms).magnitude, 50) * pq.ms,
  196. t_start=0 * pq.ms, t_stop=tlen0)
  197. self.bst0 = BinnedSpikeTrain(self.st0, binsize=fs0)
  198. # shortened analogsignals
  199. self.anasig1 = self.anasig0.time_slice(1 * pq.s, None)
  200. self.anasig2 = self.anasig0.time_slice(None, 99 * pq.s)
  201. # increased sampling frequency
  202. fs1 = 0.1 * pq.ms
  203. self.anasig3 = AnalogSignal(
  204. np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
  205. units=pq.mV, t_start=0 * pq.ms, sampling_period=fs1)
  206. self.bst1 = BinnedSpikeTrain(
  207. self.st0.time_slice(self.anasig3.t_start, self.anasig3.t_stop),
  208. binsize=fs1)
  209. # analogsignal containing multiple traces
  210. self.anasig4 = AnalogSignal(
  211. np.array([
  212. np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
  213. np.sin(4 * np.pi * (f0 * t0).simplified.magnitude)]).
  214. transpose(),
  215. units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
  216. # shortened spike train
  217. self.st3 = SpikeTrain(
  218. np.arange(
  219. (tlen0.rescale(pq.ms).magnitude * .25),
  220. (tlen0.rescale(pq.ms).magnitude * .75), 50) * pq.ms,
  221. t_start=0 * pq.ms, t_stop=tlen0)
  222. self.bst3 = BinnedSpikeTrain(self.st3, binsize=fs0)
  223. self.st4 = SpikeTrain(np.arange(
  224. (tlen0.rescale(pq.ms).magnitude * .25),
  225. (tlen0.rescale(pq.ms).magnitude * .75), 50) * pq.ms,
  226. t_start=5 * fs0, t_stop=tlen0 - 5 * fs0)
  227. self.bst4 = BinnedSpikeTrain(self.st4, binsize=fs0)
  228. # spike train with incompatible binsize
  229. self.bst5 = BinnedSpikeTrain(self.st3, binsize=fs0 * 2.)
  230. # spike train with same binsize as the analog signal, but with
  231. # bin edges not aligned to the time axis of the analog signal
  232. self.bst6 = BinnedSpikeTrain(
  233. self.st3, binsize=fs0, t_start=4.5 * fs0, t_stop=tlen0 - 4.5 * fs0)
  234. # =========================================================================
  235. # Tests for correct input handling
  236. # =========================================================================
  237. def test_wrong_input_type(self):
  238. self.assertRaises(TypeError,
  239. sta.spike_field_coherence,
  240. np.array([1, 2, 3]), self.bst0)
  241. self.assertRaises(TypeError,
  242. sta.spike_field_coherence,
  243. self.anasig0, [1, 2, 3])
  244. self.assertRaises(ValueError,
  245. sta.spike_field_coherence,
  246. self.anasig0.duplicate_with_new_array([]), self.bst0)
  247. def test_start_stop_times_out_of_range(self):
  248. self.assertRaises(ValueError,
  249. sta.spike_field_coherence,
  250. self.anasig1, self.bst0)
  251. self.assertRaises(ValueError,
  252. sta.spike_field_coherence,
  253. self.anasig2, self.bst0)
  254. def test_non_matching_input_binning(self):
  255. self.assertRaises(ValueError,
  256. sta.spike_field_coherence,
  257. self.anasig0, self.bst1)
  258. def test_incompatible_spiketrain_analogsignal(self):
  259. # These spike trains have incompatible binning (binsize or alignment to
  260. # time axis of analog signal)
  261. self.assertRaises(ValueError,
  262. sta.spike_field_coherence,
  263. self.anasig0, self.bst5)
  264. self.assertRaises(ValueError,
  265. sta.spike_field_coherence,
  266. self.anasig0, self.bst6)
  267. def test_signal_dimensions(self):
  268. # single analogsignal trace and single spike train
  269. s_single, f_single = sta.spike_field_coherence(self.anasig0, self.bst0)
  270. self.assertEqual(len(f_single.shape), 1)
  271. self.assertEqual(len(s_single.shape), 2)
  272. # multiple analogsignal traces and single spike train
  273. s_multi, f_multi = sta.spike_field_coherence(self.anasig4, self.bst0)
  274. self.assertEqual(len(f_multi.shape), 1)
  275. self.assertEqual(len(s_multi.shape), 2)
  276. # frequencies are identical since same sampling frequency was used
  277. # in both cases and data length is the same
  278. assert_array_equal(f_single, f_multi)
  279. # coherences of s_single and first signal in s_multi are identical,
  280. # since first analogsignal trace in anasig4 is same as in anasig0
  281. assert_array_equal(s_single[:, 0], s_multi[:, 0])
  282. def test_non_binned_spiketrain_input(self):
  283. s, f = sta.spike_field_coherence(self.anasig0, self.st0)
  284. f_ind = np.where(f >= 19.)[0][0]
  285. max_ind = np.argmax(s[1:]) + 1
  286. self.assertEqual(f_ind, max_ind)
  287. self.assertAlmostEqual(s[f_ind], 1., delta=0.01)
  288. # =========================================================================
  289. # Tests for correct return values
  290. # =========================================================================
  291. def test_spike_field_coherence_perfect_coherence(self):
  292. # check for detection of 20Hz peak in anasig0/bst0
  293. s, f = sta.spike_field_coherence(
  294. self.anasig0, self.bst0, window='boxcar')
  295. f_ind = np.where(f >= 19.)[0][0]
  296. max_ind = np.argmax(s[1:]) + 1
  297. self.assertEqual(f_ind, max_ind)
  298. self.assertAlmostEqual(s[f_ind], 1., delta=0.01)
  299. def test_output_frequencies(self):
  300. nfft = 256
  301. _, f = sta.spike_field_coherence(self.anasig3, self.bst1, nfft=nfft)
  302. # check number of frequency samples
  303. self.assertEqual(len(f), nfft / 2 + 1)
  304. # check values of frequency samples
  305. assert_array_almost_equal(
  306. f, np.linspace(
  307. 0, self.anasig3.sampling_rate.rescale('Hz').magnitude / 2,
  308. nfft / 2 + 1) * pq.Hz)
  309. def test_short_spiketrain(self):
  310. # this spike train has the same length as anasig0
  311. s1, f1 = sta.spike_field_coherence(
  312. self.anasig0, self.bst3, window='boxcar')
  313. # this spike train has the same spikes as above, but is shorter than
  314. # anasig0
  315. s2, f2 = sta.spike_field_coherence(
  316. self.anasig0, self.bst4, window='boxcar')
  317. # the results above should be the same, nevertheless
  318. assert_array_equal(s1.magnitude, s2.magnitude)
  319. assert_array_equal(f1.magnitude, f2.magnitude)
  320. # =========================================================================
# Tests for old scipy version (without scipy.signal.coherence)
  322. # =========================================================================
  323. @unittest.skipIf(hasattr(scipy.signal, 'coherence'), 'Applies only for old '
  324. 'scipy versions (<0.16)')
  325. class sfc_TestCase_old_scipy(unittest.TestCase):
  326. def setUp(self):
  327. # standard testsignals
  328. tlen0 = 100 * pq.s
  329. f0 = 20. * pq.Hz
  330. fs0 = 1 * pq.ms
  331. t0 = np.arange(
  332. 0, tlen0.rescale(pq.s).magnitude,
  333. fs0.rescale(pq.s).magnitude) * pq.s
  334. self.anasig0 = AnalogSignal(
  335. np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
  336. units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
  337. self.st0 = SpikeTrain(
  338. np.arange(0, tlen0.rescale(pq.ms).magnitude, 50) * pq.ms,
  339. t_start=0 * pq.ms, t_stop=tlen0)
  340. self.bst0 = BinnedSpikeTrain(self.st0, binsize=fs0)
  341. def test_old_scipy_version(self):
  342. self.assertRaises(AttributeError, sta.spike_field_coherence,
  343. self.anasig0, self.bst0)
  344. if __name__ == '__main__':
  345. unittest.main()