test_statistics.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
# -*- coding: utf-8 -*-
"""
Unit tests for the statistics module.

:copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""
  7. from __future__ import division
  8. import unittest
  9. import neo
  10. import numpy as np
  11. from numpy.testing.utils import assert_array_almost_equal, assert_array_equal
  12. import quantities as pq
  13. import scipy.integrate as spint
  14. import elephant.statistics as es
  15. import elephant.kernels as kernels
  16. import warnings
  17. class isi_TestCase(unittest.TestCase):
  18. def setUp(self):
  19. self.test_array_2d = np.array([[0.3, 0.56, 0.87, 1.23],
  20. [0.02, 0.71, 1.82, 8.46],
  21. [0.03, 0.14, 0.15, 0.92]])
  22. self.targ_array_2d_0 = np.array([[-0.28, 0.15, 0.95, 7.23],
  23. [0.01, -0.57, -1.67, -7.54]])
  24. self.targ_array_2d_1 = np.array([[0.26, 0.31, 0.36],
  25. [0.69, 1.11, 6.64],
  26. [0.11, 0.01, 0.77]])
  27. self.targ_array_2d_default = self.targ_array_2d_1
  28. self.test_array_1d = self.test_array_2d[0, :]
  29. self.targ_array_1d = self.targ_array_2d_1[0, :]
  30. def test_isi_with_spiketrain(self):
  31. st = neo.SpikeTrain(
  32. self.test_array_1d, units='ms', t_stop=10.0, t_start=0.29)
  33. target = pq.Quantity(self.targ_array_1d, 'ms')
  34. res = es.isi(st)
  35. assert_array_almost_equal(res, target, decimal=9)
  36. def test_isi_with_quantities_1d(self):
  37. st = pq.Quantity(self.test_array_1d, units='ms')
  38. target = pq.Quantity(self.targ_array_1d, 'ms')
  39. res = es.isi(st)
  40. assert_array_almost_equal(res, target, decimal=9)
  41. def test_isi_with_plain_array_1d(self):
  42. st = self.test_array_1d
  43. target = self.targ_array_1d
  44. res = es.isi(st)
  45. assert not isinstance(res, pq.Quantity)
  46. assert_array_almost_equal(res, target, decimal=9)
  47. def test_isi_with_plain_array_2d_default(self):
  48. st = self.test_array_2d
  49. target = self.targ_array_2d_default
  50. res = es.isi(st)
  51. assert not isinstance(res, pq.Quantity)
  52. assert_array_almost_equal(res, target, decimal=9)
  53. def test_isi_with_plain_array_2d_0(self):
  54. st = self.test_array_2d
  55. target = self.targ_array_2d_0
  56. res = es.isi(st, axis=0)
  57. assert not isinstance(res, pq.Quantity)
  58. assert_array_almost_equal(res, target, decimal=9)
  59. def test_isi_with_plain_array_2d_1(self):
  60. st = self.test_array_2d
  61. target = self.targ_array_2d_1
  62. res = es.isi(st, axis=1)
  63. assert not isinstance(res, pq.Quantity)
  64. assert_array_almost_equal(res, target, decimal=9)
  65. class isi_cv_TestCase(unittest.TestCase):
  66. def setUp(self):
  67. self.test_array_regular = np.arange(1, 6)
  68. def test_cv_isi_regular_spiketrain_is_zero(self):
  69. st = neo.SpikeTrain(self.test_array_regular, units='ms', t_stop=10.0)
  70. targ = 0.0
  71. res = es.cv(es.isi(st))
  72. self.assertEqual(res, targ)
  73. def test_cv_isi_regular_array_is_zero(self):
  74. st = self.test_array_regular
  75. targ = 0.0
  76. res = es.cv(es.isi(st))
  77. self.assertEqual(res, targ)
  78. class mean_firing_rate_TestCase(unittest.TestCase):
  79. def setUp(self):
  80. self.test_array_3d = np.ones([5, 7, 13])
  81. self.test_array_2d = np.array([[0.3, 0.56, 0.87, 1.23],
  82. [0.02, 0.71, 1.82, 8.46],
  83. [0.03, 0.14, 0.15, 0.92]])
  84. self.targ_array_2d_0 = np.array([3, 3, 3, 3])
  85. self.targ_array_2d_1 = np.array([4, 4, 4])
  86. self.targ_array_2d_None = 12
  87. self.targ_array_2d_default = self.targ_array_2d_None
  88. self.max_array_2d_0 = np.array([0.3, 0.71, 1.82, 8.46])
  89. self.max_array_2d_1 = np.array([1.23, 8.46, 0.92])
  90. self.max_array_2d_None = 8.46
  91. self.max_array_2d_default = self.max_array_2d_None
  92. self.test_array_1d = self.test_array_2d[0, :]
  93. self.targ_array_1d = self.targ_array_2d_1[0]
  94. self.max_array_1d = self.max_array_2d_1[0]
  95. def test_mean_firing_rate_with_spiketrain(self):
  96. st = neo.SpikeTrain(self.test_array_1d, units='ms', t_stop=10.0)
  97. target = pq.Quantity(self.targ_array_1d/10., '1/ms')
  98. res = es.mean_firing_rate(st)
  99. assert_array_almost_equal(res, target, decimal=9)
  100. def test_mean_firing_rate_with_spiketrain_set_ends(self):
  101. st = neo.SpikeTrain(self.test_array_1d, units='ms', t_stop=10.0)
  102. target = pq.Quantity(2/0.5, '1/ms')
  103. res = es.mean_firing_rate(st, t_start=0.4, t_stop=0.9)
  104. assert_array_almost_equal(res, target, decimal=9)
  105. def test_mean_firing_rate_with_quantities_1d(self):
  106. st = pq.Quantity(self.test_array_1d, units='ms')
  107. target = pq.Quantity(self.targ_array_1d/self.max_array_1d, '1/ms')
  108. res = es.mean_firing_rate(st)
  109. assert_array_almost_equal(res, target, decimal=9)
  110. def test_mean_firing_rate_with_quantities_1d_set_ends(self):
  111. st = pq.Quantity(self.test_array_1d, units='ms')
  112. target = pq.Quantity(2/0.6, '1/ms')
  113. res = es.mean_firing_rate(st, t_start=400*pq.us, t_stop=1.)
  114. assert_array_almost_equal(res, target, decimal=9)
  115. def test_mean_firing_rate_with_plain_array_1d(self):
  116. st = self.test_array_1d
  117. target = self.targ_array_1d/self.max_array_1d
  118. res = es.mean_firing_rate(st)
  119. assert not isinstance(res, pq.Quantity)
  120. assert_array_almost_equal(res, target, decimal=9)
  121. def test_mean_firing_rate_with_plain_array_1d_set_ends(self):
  122. st = self.test_array_1d
  123. target = self.targ_array_1d/(1.23-0.3)
  124. res = es.mean_firing_rate(st, t_start=0.3, t_stop=1.23)
  125. assert not isinstance(res, pq.Quantity)
  126. assert_array_almost_equal(res, target, decimal=9)
  127. def test_mean_firing_rate_with_plain_array_2d_default(self):
  128. st = self.test_array_2d
  129. target = self.targ_array_2d_default/self.max_array_2d_default
  130. res = es.mean_firing_rate(st)
  131. assert not isinstance(res, pq.Quantity)
  132. assert_array_almost_equal(res, target, decimal=9)
  133. def test_mean_firing_rate_with_plain_array_2d_0(self):
  134. st = self.test_array_2d
  135. target = self.targ_array_2d_0/self.max_array_2d_0
  136. res = es.mean_firing_rate(st, axis=0)
  137. assert not isinstance(res, pq.Quantity)
  138. assert_array_almost_equal(res, target, decimal=9)
  139. def test_mean_firing_rate_with_plain_array_2d_1(self):
  140. st = self.test_array_2d
  141. target = self.targ_array_2d_1/self.max_array_2d_1
  142. res = es.mean_firing_rate(st, axis=1)
  143. assert not isinstance(res, pq.Quantity)
  144. assert_array_almost_equal(res, target, decimal=9)
  145. def test_mean_firing_rate_with_plain_array_3d_None(self):
  146. st = self.test_array_3d
  147. target = np.sum(self.test_array_3d, None)/5.
  148. res = es.mean_firing_rate(st, axis=None, t_stop=5.)
  149. assert not isinstance(res, pq.Quantity)
  150. assert_array_almost_equal(res, target, decimal=9)
  151. def test_mean_firing_rate_with_plain_array_3d_0(self):
  152. st = self.test_array_3d
  153. target = np.sum(self.test_array_3d, 0)/5.
  154. res = es.mean_firing_rate(st, axis=0, t_stop=5.)
  155. assert not isinstance(res, pq.Quantity)
  156. assert_array_almost_equal(res, target, decimal=9)
  157. def test_mean_firing_rate_with_plain_array_3d_1(self):
  158. st = self.test_array_3d
  159. target = np.sum(self.test_array_3d, 1)/5.
  160. res = es.mean_firing_rate(st, axis=1, t_stop=5.)
  161. assert not isinstance(res, pq.Quantity)
  162. assert_array_almost_equal(res, target, decimal=9)
  163. def test_mean_firing_rate_with_plain_array_3d_2(self):
  164. st = self.test_array_3d
  165. target = np.sum(self.test_array_3d, 2)/5.
  166. res = es.mean_firing_rate(st, axis=2, t_stop=5.)
  167. assert not isinstance(res, pq.Quantity)
  168. assert_array_almost_equal(res, target, decimal=9)
  169. def test_mean_firing_rate_with_plain_array_2d_1_set_ends(self):
  170. st = self.test_array_2d
  171. target = np.array([4, 1, 3])/(1.23-0.14)
  172. res = es.mean_firing_rate(st, axis=1, t_start=0.14, t_stop=1.23)
  173. assert not isinstance(res, pq.Quantity)
  174. assert_array_almost_equal(res, target, decimal=9)
  175. def test_mean_firing_rate_with_plain_array_2d_None(self):
  176. st = self.test_array_2d
  177. target = self.targ_array_2d_None/self.max_array_2d_None
  178. res = es.mean_firing_rate(st, axis=None)
  179. assert not isinstance(res, pq.Quantity)
  180. assert_array_almost_equal(res, target, decimal=9)
  181. def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror(self):
  182. st = self.test_array_2d
  183. self.assertRaises(TypeError, es.mean_firing_rate, st,
  184. t_start=pq.Quantity(0, 'ms'))
  185. self.assertRaises(TypeError, es.mean_firing_rate, st,
  186. t_stop=pq.Quantity(10, 'ms'))
  187. self.assertRaises(TypeError, es.mean_firing_rate, st,
  188. t_start=pq.Quantity(0, 'ms'),
  189. t_stop=pq.Quantity(10, 'ms'))
  190. self.assertRaises(TypeError, es.mean_firing_rate, st,
  191. t_start=pq.Quantity(0, 'ms'),
  192. t_stop=10.)
  193. self.assertRaises(TypeError, es.mean_firing_rate, st,
  194. t_start=0.,
  195. t_stop=pq.Quantity(10, 'ms'))
  196. class FanoFactorTestCase(unittest.TestCase):
  197. def setUp(self):
  198. np.random.seed(100)
  199. num_st = 300
  200. self.test_spiketrains = []
  201. self.test_array = []
  202. self.test_quantity = []
  203. self.test_list = []
  204. self.sp_counts = np.zeros(num_st)
  205. for i in range(num_st):
  206. r = np.random.rand(np.random.randint(20) + 1)
  207. st = neo.core.SpikeTrain(r * pq.ms,
  208. t_start=0.0 * pq.ms,
  209. t_stop=20.0 * pq.ms)
  210. self.test_spiketrains.append(st)
  211. self.test_array.append(r)
  212. self.test_quantity.append(r * pq.ms)
  213. self.test_list.append(list(r))
  214. # for cross-validation
  215. self.sp_counts[i] = len(st)
  216. def test_fanofactor_spiketrains(self):
  217. # Test with list of spiketrains
  218. self.assertEqual(
  219. np.var(self.sp_counts) / np.mean(self.sp_counts),
  220. es.fanofactor(self.test_spiketrains))
  221. # One spiketrain in list
  222. st = self.test_spiketrains[0]
  223. self.assertEqual(es.fanofactor([st]), 0.0)
  224. def test_fanofactor_empty(self):
  225. # Test with empty list
  226. self.assertTrue(np.isnan(es.fanofactor([])))
  227. self.assertTrue(np.isnan(es.fanofactor([[]])))
  228. # Test with empty quantity
  229. self.assertTrue(np.isnan(es.fanofactor([] * pq.ms)))
  230. # Empty spiketrain
  231. st = neo.core.SpikeTrain([] * pq.ms, t_start=0 * pq.ms,
  232. t_stop=1.5 * pq.ms)
  233. self.assertTrue(np.isnan(es.fanofactor(st)))
  234. def test_fanofactor_spiketrains_same(self):
  235. # Test with same spiketrains in list
  236. sts = [self.test_spiketrains[0]] * 3
  237. self.assertEqual(es.fanofactor(sts), 0.0)
  238. def test_fanofactor_array(self):
  239. self.assertEqual(es.fanofactor(self.test_array),
  240. np.var(self.sp_counts) / np.mean(self.sp_counts))
  241. def test_fanofactor_array_same(self):
  242. lst = [self.test_array[0]] * 3
  243. self.assertEqual(es.fanofactor(lst), 0.0)
  244. def test_fanofactor_quantity(self):
  245. self.assertEqual(es.fanofactor(self.test_quantity),
  246. np.var(self.sp_counts) / np.mean(self.sp_counts))
  247. def test_fanofactor_quantity_same(self):
  248. lst = [self.test_quantity[0]] * 3
  249. self.assertEqual(es.fanofactor(lst), 0.0)
  250. def test_fanofactor_list(self):
  251. self.assertEqual(es.fanofactor(self.test_list),
  252. np.var(self.sp_counts) / np.mean(self.sp_counts))
  253. def test_fanofactor_list_same(self):
  254. lst = [self.test_list[0]] * 3
  255. self.assertEqual(es.fanofactor(lst), 0.0)
  256. class LVTestCase(unittest.TestCase):
  257. def setUp(self):
  258. self.test_seq = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12,
  259. 4, 12, 59, 2, 4, 18, 33, 25, 2, 34,
  260. 4, 1, 1, 14, 8, 1, 10, 1, 8, 20,
  261. 5, 1, 6, 5, 12, 2, 8, 8, 2, 8,
  262. 2, 10, 2, 1, 1, 2, 15, 3, 20, 6,
  263. 11, 6, 18, 2, 5, 17, 4, 3, 13, 6,
  264. 1, 18, 1, 16, 12, 2, 52, 2, 5, 7,
  265. 6, 25, 6, 5, 3, 15, 4, 3, 16, 3,
  266. 6, 5, 24, 21, 3, 3, 4, 8, 4, 11,
  267. 5, 7, 5, 6, 8, 11, 33, 10, 7, 4]
  268. self.target = 0.971826029994
  269. def test_lv_with_quantities(self):
  270. seq = pq.Quantity(self.test_seq, units='ms')
  271. assert_array_almost_equal(es.lv(seq), self.target, decimal=9)
  272. def test_lv_with_plain_array(self):
  273. seq = np.array(self.test_seq)
  274. assert_array_almost_equal(es.lv(seq), self.target, decimal=9)
  275. def test_lv_with_list(self):
  276. seq = self.test_seq
  277. assert_array_almost_equal(es.lv(seq), self.target, decimal=9)
  278. def test_lv_raise_error(self):
  279. seq = self.test_seq
  280. self.assertRaises(AttributeError, es.lv, [])
  281. self.assertRaises(AttributeError, es.lv, 1)
  282. self.assertRaises(ValueError, es.lv, np.array([seq, seq]))
  283. class CV2TestCase(unittest.TestCase):
  284. def setUp(self):
  285. self.test_seq = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12,
  286. 4, 12, 59, 2, 4, 18, 33, 25, 2, 34,
  287. 4, 1, 1, 14, 8, 1, 10, 1, 8, 20,
  288. 5, 1, 6, 5, 12, 2, 8, 8, 2, 8,
  289. 2, 10, 2, 1, 1, 2, 15, 3, 20, 6,
  290. 11, 6, 18, 2, 5, 17, 4, 3, 13, 6,
  291. 1, 18, 1, 16, 12, 2, 52, 2, 5, 7,
  292. 6, 25, 6, 5, 3, 15, 4, 3, 16, 3,
  293. 6, 5, 24, 21, 3, 3, 4, 8, 4, 11,
  294. 5, 7, 5, 6, 8, 11, 33, 10, 7, 4]
  295. self.target = 1.0022235296529176
  296. def test_cv2_with_quantities(self):
  297. seq = pq.Quantity(self.test_seq, units='ms')
  298. assert_array_almost_equal(es.cv2(seq), self.target, decimal=9)
  299. def test_cv2_with_plain_array(self):
  300. seq = np.array(self.test_seq)
  301. assert_array_almost_equal(es.cv2(seq), self.target, decimal=9)
  302. def test_cv2_with_list(self):
  303. seq = self.test_seq
  304. assert_array_almost_equal(es.cv2(seq), self.target, decimal=9)
  305. def test_cv2_raise_error(self):
  306. seq = self.test_seq
  307. self.assertRaises(AttributeError, es.cv2, [])
  308. self.assertRaises(AttributeError, es.cv2, 1)
  309. self.assertRaises(AttributeError, es.cv2, np.array([seq, seq]))
  310. class RateEstimationTestCase(unittest.TestCase):
  311. def setUp(self):
  312. # create a poisson spike train:
  313. self.st_tr = (0, 20.0) # seconds
  314. self.st_dur = self.st_tr[1] - self.st_tr[0] # seconds
  315. self.st_margin = 5.0 # seconds
  316. self.st_rate = 10.0 # Hertz
  317. st_num_spikes = np.random.poisson(
  318. self.st_rate*(self.st_dur-2*self.st_margin))
  319. spike_train = np.random.rand(
  320. st_num_spikes) * (self.st_dur-2*self.st_margin) + self.st_margin
  321. spike_train.sort()
  322. # convert spike train into neo objects
  323. self.spike_train = neo.SpikeTrain(spike_train*pq.s,
  324. t_start=self.st_tr[0]*pq.s,
  325. t_stop=self.st_tr[1]*pq.s)
  326. # generation of a multiply used specific kernel
  327. self.kernel = kernels.TriangularKernel(sigma=0.03*pq.s)
  328. def test_instantaneous_rate_and_warnings(self):
  329. st = self.spike_train
  330. sampling_period = 0.01*pq.s
  331. with warnings.catch_warnings(record=True) as w:
  332. inst_rate = es.instantaneous_rate(
  333. st, sampling_period, self.kernel, cutoff=0)
  334. message1 = "The width of the kernel was adjusted to a minimally " \
  335. "allowed width."
  336. message2 = "Instantaneous firing rate approximation contains " \
  337. "negative values, possibly caused due to machine " \
  338. "precision errors."
  339. warning_message = [str(m.message) for m in w]
  340. self.assertTrue(message1 in warning_message)
  341. self.assertTrue(message2 in warning_message)
  342. self.assertIsInstance(inst_rate, neo.core.AnalogSignal)
  343. self.assertEqual(
  344. inst_rate.sampling_period.simplified, sampling_period.simplified)
  345. self.assertEqual(inst_rate.simplified.units, pq.Hz)
  346. self.assertEqual(inst_rate.t_stop.simplified, st.t_stop.simplified)
  347. self.assertEqual(inst_rate.t_start.simplified, st.t_start.simplified)
  348. def test_error_instantaneous_rate(self):
  349. self.assertRaises(
  350. TypeError, es.instantaneous_rate, spiketrain=[1, 2, 3]*pq.s,
  351. sampling_period=0.01*pq.ms, kernel=self.kernel)
  352. self.assertRaises(
  353. TypeError, es.instantaneous_rate, spiketrain=[1, 2, 3],
  354. sampling_period=0.01*pq.ms, kernel=self.kernel)
  355. st = self.spike_train
  356. self.assertRaises(
  357. TypeError, es.instantaneous_rate, spiketrain=st,
  358. sampling_period=0.01, kernel=self.kernel)
  359. self.assertRaises(
  360. ValueError, es.instantaneous_rate, spiketrain=st,
  361. sampling_period=-0.01*pq.ms, kernel=self.kernel)
  362. self.assertRaises(
  363. TypeError, es.instantaneous_rate, spiketrain=st,
  364. sampling_period=0.01*pq.ms, kernel='NONE')
  365. self.assertRaises(TypeError, es.instantaneous_rate, self.spike_train,
  366. sampling_period=0.01*pq.s, kernel='wrong_string',
  367. t_start=self.st_tr[0]*pq.s, t_stop=self.st_tr[1]*pq.s,
  368. trim=False)
  369. self.assertRaises(
  370. TypeError, es.instantaneous_rate, spiketrain=st,
  371. sampling_period=0.01*pq.ms, kernel=self.kernel, cutoff=20*pq.ms)
  372. self.assertRaises(
  373. TypeError, es.instantaneous_rate, spiketrain=st,
  374. sampling_period=0.01*pq.ms, kernel=self.kernel, t_start=2)
  375. self.assertRaises(
  376. TypeError, es.instantaneous_rate, spiketrain=st,
  377. sampling_period=0.01*pq.ms, kernel=self.kernel, t_stop=20*pq.mV)
  378. self.assertRaises(
  379. TypeError, es.instantaneous_rate, spiketrain=st,
  380. sampling_period=0.01*pq.ms, kernel=self.kernel, trim=1)
  381. def test_rate_estimation_consistency(self):
  382. """
  383. Test, whether the integral of the rate estimation curve is (almost)
  384. equal to the number of spikes of the spike train.
  385. """
  386. kernel_types = [obj for obj in kernels.__dict__.values()
  387. if isinstance(obj, type) and
  388. issubclass(obj, kernels.Kernel) and
  389. hasattr(obj, "_evaluate") and
  390. obj is not kernels.Kernel and
  391. obj is not kernels.SymmetricKernel]
  392. kernel_list = [kernel_type(sigma=0.5*pq.s, invert=False)
  393. for kernel_type in kernel_types]
  394. kernel_resolution = 0.01*pq.s
  395. for kernel in kernel_list:
  396. rate_estimate_a0 = es.instantaneous_rate(self.spike_train,
  397. sampling_period=kernel_resolution,
  398. kernel='auto',
  399. t_start=self.st_tr[0]*pq.s,
  400. t_stop=self.st_tr[1]*pq.s,
  401. trim=False)
  402. rate_estimate0 = es.instantaneous_rate(self.spike_train,
  403. sampling_period=kernel_resolution,
  404. kernel=kernel)
  405. rate_estimate1 = es.instantaneous_rate(self.spike_train,
  406. sampling_period=kernel_resolution,
  407. kernel=kernel,
  408. t_start=self.st_tr[0]*pq.s,
  409. t_stop=self.st_tr[1]*pq.s,
  410. trim=False)
  411. rate_estimate2 = es.instantaneous_rate(self.spike_train,
  412. sampling_period=kernel_resolution,
  413. kernel=kernel,
  414. t_start=self.st_tr[0]*pq.s,
  415. t_stop=self.st_tr[1]*pq.s,
  416. trim=True)
  417. # test consistency
  418. rate_estimate_list = [rate_estimate0, rate_estimate1,
  419. rate_estimate2, rate_estimate_a0]
  420. for rate_estimate in rate_estimate_list:
  421. num_spikes = len(self.spike_train)
  422. auc = spint.cumtrapz(y=rate_estimate.magnitude[:, 0],
  423. x=rate_estimate.times.rescale('s').magnitude)[-1]
  424. self.assertAlmostEqual(num_spikes, auc, delta=0.05*num_spikes)
  425. def test_instantaneous_rate_spiketrainlist(self):
  426. st_num_spikes = np.random.poisson(
  427. self.st_rate*(self.st_dur-2*self.st_margin))
  428. spike_train2 = np.random.rand(
  429. st_num_spikes) * (self.st_dur - 2 * self.st_margin) + self.st_margin
  430. spike_train2.sort()
  431. spike_train2 = neo.SpikeTrain(spike_train2 * pq.s,
  432. t_start=self.st_tr[0] * pq.s,
  433. t_stop=self.st_tr[1] * pq.s)
  434. st_rate_1 = es.instantaneous_rate(self.spike_train,
  435. sampling_period=0.01*pq.s,
  436. kernel=self.kernel)
  437. st_rate_2 = es.instantaneous_rate(spike_train2,
  438. sampling_period=0.01*pq.s,
  439. kernel=self.kernel)
  440. combined_rate = es.instantaneous_rate([self.spike_train, spike_train2],
  441. sampling_period=0.01*pq.s,
  442. kernel=self.kernel)
  443. summed_rate = st_rate_1 + st_rate_2 # equivalent for identical kernels
  444. for a, b in zip(combined_rate.magnitude, summed_rate.magnitude):
  445. self.assertAlmostEqual(a, b, delta=0.0001)
  446. # Regression test for #144
  447. def test_instantaneous_rate_regression_144(self):
  448. # The following spike train contains spikes that are so close to each
  449. # other, that the optimal kernel cannot be detected. Therefore, the
  450. # function should react with a ValueError.
  451. st = neo.SpikeTrain([2.12, 2.13, 2.15] * pq.s, t_stop=10 * pq.s)
  452. self.assertRaises(ValueError, es.instantaneous_rate, st, 1 * pq.ms)
  453. class TimeHistogramTestCase(unittest.TestCase):
  454. def setUp(self):
  455. self.spiketrain_a = neo.SpikeTrain(
  456. [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s)
  457. self.spiketrain_b = neo.SpikeTrain(
  458. [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  459. self.spiketrains = [self.spiketrain_a, self.spiketrain_b]
  460. def tearDown(self):
  461. del self.spiketrain_a
  462. self.spiketrain_a = None
  463. del self.spiketrain_b
  464. self.spiketrain_b = None
  465. def test_time_histogram(self):
  466. targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0])
  467. histogram = es.time_histogram(self.spiketrains, binsize=pq.s)
  468. assert_array_equal(targ, histogram.magnitude[:, 0])
  469. def test_time_histogram_binary(self):
  470. targ = np.array([2, 2, 1, 1, 2, 2, 1, 0, 1, 0])
  471. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  472. binary=True)
  473. assert_array_equal(targ, histogram.magnitude[:, 0])
  474. def test_time_histogram_tstart_tstop(self):
  475. # Start, stop short range
  476. targ = np.array([2, 1])
  477. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  478. t_start=5 * pq.s, t_stop=7 * pq.s)
  479. assert_array_equal(targ, histogram.magnitude[:, 0])
  480. # Test without t_stop
  481. targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0])
  482. histogram = es.time_histogram(self.spiketrains, binsize=1 * pq.s,
  483. t_start=0 * pq.s)
  484. assert_array_equal(targ, histogram.magnitude[:, 0])
  485. # Test without t_start
  486. histogram = es.time_histogram(self.spiketrains, binsize=1 * pq.s,
  487. t_stop=10 * pq.s)
  488. assert_array_equal(targ, histogram.magnitude[:, 0])
  489. def test_time_histogram_output(self):
  490. # Normalization mean
  491. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  492. output='mean')
  493. targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0], dtype=float) / 2
  494. assert_array_equal(targ.reshape(targ.size, 1), histogram.magnitude)
  495. # Normalization rate
  496. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  497. output='rate')
  498. assert_array_equal(histogram.view(pq.Quantity),
  499. targ.reshape(targ.size, 1) * 1 / pq.s)
  500. # Normalization unspecified, raises error
  501. self.assertRaises(ValueError, es.time_histogram, self.spiketrains,
  502. binsize=pq.s, output=' ')
  503. class ComplexityPdfTestCase(unittest.TestCase):
  504. def setUp(self):
  505. self.spiketrain_a = neo.SpikeTrain(
  506. [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s)
  507. self.spiketrain_b = neo.SpikeTrain(
  508. [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  509. self.spiketrain_c = neo.SpikeTrain(
  510. [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  511. self.spiketrains = [
  512. self.spiketrain_a, self.spiketrain_b, self.spiketrain_c]
  513. def tearDown(self):
  514. del self.spiketrain_a
  515. self.spiketrain_a = None
  516. del self.spiketrain_b
  517. self.spiketrain_b = None
  518. def test_complexity_pdf(self):
  519. targ = np.array([0.92, 0.01, 0.01, 0.06])
  520. complexity = es.complexity_pdf(self.spiketrains, binsize=0.1*pq.s)
  521. assert_array_equal(targ, complexity.magnitude[:, 0])
  522. self.assertEqual(1, complexity.magnitude[:, 0].sum())
  523. self.assertEqual(len(self.spiketrains)+1, len(complexity))
  524. self.assertIsInstance(complexity, neo.AnalogSignal)
  525. self.assertEqual(complexity.units, 1*pq.dimensionless)
  526. if __name__ == '__main__':
  527. unittest.main()