test_statistics.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the statistics module.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import unittest
  8. import neo
  9. import numpy as np
  10. from numpy.testing.utils import assert_array_almost_equal, assert_array_equal
  11. import quantities as pq
  12. import scipy.integrate as spint
  13. import elephant.statistics as es
  14. import elephant.kernels as kernels
  15. import warnings
  16. class isi_TestCase(unittest.TestCase):
  17. def setUp(self):
  18. self.test_array_2d = np.array([[0.3, 0.56, 0.87, 1.23],
  19. [0.02, 0.71, 1.82, 8.46],
  20. [0.03, 0.14, 0.15, 0.92]])
  21. self.targ_array_2d_0 = np.array([[-0.28, 0.15, 0.95, 7.23],
  22. [0.01, -0.57, -1.67, -7.54]])
  23. self.targ_array_2d_1 = np.array([[0.26, 0.31, 0.36],
  24. [0.69, 1.11, 6.64],
  25. [0.11, 0.01, 0.77]])
  26. self.targ_array_2d_default = self.targ_array_2d_1
  27. self.test_array_1d = self.test_array_2d[0, :]
  28. self.targ_array_1d = self.targ_array_2d_1[0, :]
  29. def test_isi_with_spiketrain(self):
  30. st = neo.SpikeTrain(
  31. self.test_array_1d, units='ms', t_stop=10.0, t_start=0.29)
  32. target = pq.Quantity(self.targ_array_1d, 'ms')
  33. res = es.isi(st)
  34. assert_array_almost_equal(res, target, decimal=9)
  35. def test_isi_with_quantities_1d(self):
  36. st = pq.Quantity(self.test_array_1d, units='ms')
  37. target = pq.Quantity(self.targ_array_1d, 'ms')
  38. res = es.isi(st)
  39. assert_array_almost_equal(res, target, decimal=9)
  40. def test_isi_with_plain_array_1d(self):
  41. st = self.test_array_1d
  42. target = self.targ_array_1d
  43. res = es.isi(st)
  44. assert not isinstance(res, pq.Quantity)
  45. assert_array_almost_equal(res, target, decimal=9)
  46. def test_isi_with_plain_array_2d_default(self):
  47. st = self.test_array_2d
  48. target = self.targ_array_2d_default
  49. res = es.isi(st)
  50. assert not isinstance(res, pq.Quantity)
  51. assert_array_almost_equal(res, target, decimal=9)
  52. def test_isi_with_plain_array_2d_0(self):
  53. st = self.test_array_2d
  54. target = self.targ_array_2d_0
  55. res = es.isi(st, axis=0)
  56. assert not isinstance(res, pq.Quantity)
  57. assert_array_almost_equal(res, target, decimal=9)
  58. def test_isi_with_plain_array_2d_1(self):
  59. st = self.test_array_2d
  60. target = self.targ_array_2d_1
  61. res = es.isi(st, axis=1)
  62. assert not isinstance(res, pq.Quantity)
  63. assert_array_almost_equal(res, target, decimal=9)
  64. class isi_cv_TestCase(unittest.TestCase):
  65. def setUp(self):
  66. self.test_array_regular = np.arange(1, 6)
  67. def test_cv_isi_regular_spiketrain_is_zero(self):
  68. st = neo.SpikeTrain(self.test_array_regular, units='ms', t_stop=10.0)
  69. targ = 0.0
  70. res = es.cv(es.isi(st))
  71. self.assertEqual(res, targ)
  72. def test_cv_isi_regular_array_is_zero(self):
  73. st = self.test_array_regular
  74. targ = 0.0
  75. res = es.cv(es.isi(st))
  76. self.assertEqual(res, targ)
  77. class mean_firing_rate_TestCase(unittest.TestCase):
  78. def setUp(self):
  79. self.test_array_3d = np.ones([5, 7, 13])
  80. self.test_array_2d = np.array([[0.3, 0.56, 0.87, 1.23],
  81. [0.02, 0.71, 1.82, 8.46],
  82. [0.03, 0.14, 0.15, 0.92]])
  83. self.targ_array_2d_0 = np.array([3, 3, 3, 3])
  84. self.targ_array_2d_1 = np.array([4, 4, 4])
  85. self.targ_array_2d_None = 12
  86. self.targ_array_2d_default = self.targ_array_2d_None
  87. self.max_array_2d_0 = np.array([0.3, 0.71, 1.82, 8.46])
  88. self.max_array_2d_1 = np.array([1.23, 8.46, 0.92])
  89. self.max_array_2d_None = 8.46
  90. self.max_array_2d_default = self.max_array_2d_None
  91. self.test_array_1d = self.test_array_2d[0, :]
  92. self.targ_array_1d = self.targ_array_2d_1[0]
  93. self.max_array_1d = self.max_array_2d_1[0]
  94. def test_mean_firing_rate_with_spiketrain(self):
  95. st = neo.SpikeTrain(self.test_array_1d, units='ms', t_stop=10.0)
  96. target = pq.Quantity(self.targ_array_1d/10., '1/ms')
  97. res = es.mean_firing_rate(st)
  98. assert_array_almost_equal(res, target, decimal=9)
  99. def test_mean_firing_rate_with_spiketrain_set_ends(self):
  100. st = neo.SpikeTrain(self.test_array_1d, units='ms', t_stop=10.0)
  101. target = pq.Quantity(2/0.5, '1/ms')
  102. res = es.mean_firing_rate(st, t_start=0.4, t_stop=0.9)
  103. assert_array_almost_equal(res, target, decimal=9)
  104. def test_mean_firing_rate_with_quantities_1d(self):
  105. st = pq.Quantity(self.test_array_1d, units='ms')
  106. target = pq.Quantity(self.targ_array_1d/self.max_array_1d, '1/ms')
  107. res = es.mean_firing_rate(st)
  108. assert_array_almost_equal(res, target, decimal=9)
  109. def test_mean_firing_rate_with_quantities_1d_set_ends(self):
  110. st = pq.Quantity(self.test_array_1d, units='ms')
  111. target = pq.Quantity(2/0.6, '1/ms')
  112. res = es.mean_firing_rate(st, t_start=400*pq.us, t_stop=1.)
  113. assert_array_almost_equal(res, target, decimal=9)
  114. def test_mean_firing_rate_with_plain_array_1d(self):
  115. st = self.test_array_1d
  116. target = self.targ_array_1d/self.max_array_1d
  117. res = es.mean_firing_rate(st)
  118. assert not isinstance(res, pq.Quantity)
  119. assert_array_almost_equal(res, target, decimal=9)
  120. def test_mean_firing_rate_with_plain_array_1d_set_ends(self):
  121. st = self.test_array_1d
  122. target = self.targ_array_1d/(1.23-0.3)
  123. res = es.mean_firing_rate(st, t_start=0.3, t_stop=1.23)
  124. assert not isinstance(res, pq.Quantity)
  125. assert_array_almost_equal(res, target, decimal=9)
  126. def test_mean_firing_rate_with_plain_array_2d_default(self):
  127. st = self.test_array_2d
  128. target = self.targ_array_2d_default/self.max_array_2d_default
  129. res = es.mean_firing_rate(st)
  130. assert not isinstance(res, pq.Quantity)
  131. assert_array_almost_equal(res, target, decimal=9)
  132. def test_mean_firing_rate_with_plain_array_2d_0(self):
  133. st = self.test_array_2d
  134. target = self.targ_array_2d_0/self.max_array_2d_0
  135. res = es.mean_firing_rate(st, axis=0)
  136. assert not isinstance(res, pq.Quantity)
  137. assert_array_almost_equal(res, target, decimal=9)
  138. def test_mean_firing_rate_with_plain_array_2d_1(self):
  139. st = self.test_array_2d
  140. target = self.targ_array_2d_1/self.max_array_2d_1
  141. res = es.mean_firing_rate(st, axis=1)
  142. assert not isinstance(res, pq.Quantity)
  143. assert_array_almost_equal(res, target, decimal=9)
  144. def test_mean_firing_rate_with_plain_array_3d_None(self):
  145. st = self.test_array_3d
  146. target = np.sum(self.test_array_3d, None)/5.
  147. res = es.mean_firing_rate(st, axis=None, t_stop=5.)
  148. assert not isinstance(res, pq.Quantity)
  149. assert_array_almost_equal(res, target, decimal=9)
  150. def test_mean_firing_rate_with_plain_array_3d_0(self):
  151. st = self.test_array_3d
  152. target = np.sum(self.test_array_3d, 0)/5.
  153. res = es.mean_firing_rate(st, axis=0, t_stop=5.)
  154. assert not isinstance(res, pq.Quantity)
  155. assert_array_almost_equal(res, target, decimal=9)
  156. def test_mean_firing_rate_with_plain_array_3d_1(self):
  157. st = self.test_array_3d
  158. target = np.sum(self.test_array_3d, 1)/5.
  159. res = es.mean_firing_rate(st, axis=1, t_stop=5.)
  160. assert not isinstance(res, pq.Quantity)
  161. assert_array_almost_equal(res, target, decimal=9)
  162. def test_mean_firing_rate_with_plain_array_3d_2(self):
  163. st = self.test_array_3d
  164. target = np.sum(self.test_array_3d, 2)/5.
  165. res = es.mean_firing_rate(st, axis=2, t_stop=5.)
  166. assert not isinstance(res, pq.Quantity)
  167. assert_array_almost_equal(res, target, decimal=9)
  168. def test_mean_firing_rate_with_plain_array_2d_1_set_ends(self):
  169. st = self.test_array_2d
  170. target = np.array([4, 1, 3])/(1.23-0.14)
  171. res = es.mean_firing_rate(st, axis=1, t_start=0.14, t_stop=1.23)
  172. assert not isinstance(res, pq.Quantity)
  173. assert_array_almost_equal(res, target, decimal=9)
  174. def test_mean_firing_rate_with_plain_array_2d_None(self):
  175. st = self.test_array_2d
  176. target = self.targ_array_2d_None/self.max_array_2d_None
  177. res = es.mean_firing_rate(st, axis=None)
  178. assert not isinstance(res, pq.Quantity)
  179. assert_array_almost_equal(res, target, decimal=9)
  180. def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror(self):
  181. st = self.test_array_2d
  182. self.assertRaises(TypeError, es.mean_firing_rate, st,
  183. t_start=pq.Quantity(0, 'ms'))
  184. self.assertRaises(TypeError, es.mean_firing_rate, st,
  185. t_stop=pq.Quantity(10, 'ms'))
  186. self.assertRaises(TypeError, es.mean_firing_rate, st,
  187. t_start=pq.Quantity(0, 'ms'),
  188. t_stop=pq.Quantity(10, 'ms'))
  189. self.assertRaises(TypeError, es.mean_firing_rate, st,
  190. t_start=pq.Quantity(0, 'ms'),
  191. t_stop=10.)
  192. self.assertRaises(TypeError, es.mean_firing_rate, st,
  193. t_start=0.,
  194. t_stop=pq.Quantity(10, 'ms'))
  195. class FanoFactorTestCase(unittest.TestCase):
  196. def setUp(self):
  197. np.random.seed(100)
  198. num_st = 300
  199. self.test_spiketrains = []
  200. self.test_array = []
  201. self.test_quantity = []
  202. self.test_list = []
  203. self.sp_counts = np.zeros(num_st)
  204. for i in range(num_st):
  205. r = np.random.rand(np.random.randint(20) + 1)
  206. st = neo.core.SpikeTrain(r * pq.ms,
  207. t_start=0.0 * pq.ms,
  208. t_stop=20.0 * pq.ms)
  209. self.test_spiketrains.append(st)
  210. self.test_array.append(r)
  211. self.test_quantity.append(r * pq.ms)
  212. self.test_list.append(list(r))
  213. # for cross-validation
  214. self.sp_counts[i] = len(st)
  215. def test_fanofactor_spiketrains(self):
  216. # Test with list of spiketrains
  217. self.assertEqual(
  218. np.var(self.sp_counts) / np.mean(self.sp_counts),
  219. es.fanofactor(self.test_spiketrains))
  220. # One spiketrain in list
  221. st = self.test_spiketrains[0]
  222. self.assertEqual(es.fanofactor([st]), 0.0)
  223. def test_fanofactor_empty(self):
  224. # Test with empty list
  225. self.assertTrue(np.isnan(es.fanofactor([])))
  226. self.assertTrue(np.isnan(es.fanofactor([[]])))
  227. # Test with empty quantity
  228. self.assertTrue(np.isnan(es.fanofactor([] * pq.ms)))
  229. # Empty spiketrain
  230. st = neo.core.SpikeTrain([] * pq.ms, t_start=0 * pq.ms,
  231. t_stop=1.5 * pq.ms)
  232. self.assertTrue(np.isnan(es.fanofactor(st)))
  233. def test_fanofactor_spiketrains_same(self):
  234. # Test with same spiketrains in list
  235. sts = [self.test_spiketrains[0]] * 3
  236. self.assertEqual(es.fanofactor(sts), 0.0)
  237. def test_fanofactor_array(self):
  238. self.assertEqual(es.fanofactor(self.test_array),
  239. np.var(self.sp_counts) / np.mean(self.sp_counts))
  240. def test_fanofactor_array_same(self):
  241. lst = [self.test_array[0]] * 3
  242. self.assertEqual(es.fanofactor(lst), 0.0)
  243. def test_fanofactor_quantity(self):
  244. self.assertEqual(es.fanofactor(self.test_quantity),
  245. np.var(self.sp_counts) / np.mean(self.sp_counts))
  246. def test_fanofactor_quantity_same(self):
  247. lst = [self.test_quantity[0]] * 3
  248. self.assertEqual(es.fanofactor(lst), 0.0)
  249. def test_fanofactor_list(self):
  250. self.assertEqual(es.fanofactor(self.test_list),
  251. np.var(self.sp_counts) / np.mean(self.sp_counts))
  252. def test_fanofactor_list_same(self):
  253. lst = [self.test_list[0]] * 3
  254. self.assertEqual(es.fanofactor(lst), 0.0)
  255. class LVTestCase(unittest.TestCase):
  256. def setUp(self):
  257. self.test_seq = [1, 28, 4, 47, 5, 16, 2, 5, 21, 12,
  258. 4, 12, 59, 2, 4, 18, 33, 25, 2, 34,
  259. 4, 1, 1, 14, 8, 1, 10, 1, 8, 20,
  260. 5, 1, 6, 5, 12, 2, 8, 8, 2, 8,
  261. 2, 10, 2, 1, 1, 2, 15, 3, 20, 6,
  262. 11, 6, 18, 2, 5, 17, 4, 3, 13, 6,
  263. 1, 18, 1, 16, 12, 2, 52, 2, 5, 7,
  264. 6, 25, 6, 5, 3, 15, 4, 3, 16, 3,
  265. 6, 5, 24, 21, 3, 3, 4, 8, 4, 11,
  266. 5, 7, 5, 6, 8, 11, 33, 10, 7, 4]
  267. self.target = 0.971826029994
  268. def test_lv_with_quantities(self):
  269. seq = pq.Quantity(self.test_seq, units='ms')
  270. assert_array_almost_equal(es.lv(seq), self.target, decimal=9)
  271. def test_lv_with_plain_array(self):
  272. seq = np.array(self.test_seq)
  273. assert_array_almost_equal(es.lv(seq), self.target, decimal=9)
  274. def test_lv_with_list(self):
  275. seq = self.test_seq
  276. assert_array_almost_equal(es.lv(seq), self.target, decimal=9)
  277. def test_lv_raise_error(self):
  278. seq = self.test_seq
  279. self.assertRaises(AttributeError, es.lv, [])
  280. self.assertRaises(AttributeError, es.lv, 1)
  281. self.assertRaises(ValueError, es.lv, np.array([seq, seq]))
  282. class RateEstimationTestCase(unittest.TestCase):
  283. def setUp(self):
  284. # create a poisson spike train:
  285. self.st_tr = (0, 20.0) # seconds
  286. self.st_dur = self.st_tr[1] - self.st_tr[0] # seconds
  287. self.st_margin = 5.0 # seconds
  288. self.st_rate = 10.0 # Hertz
  289. st_num_spikes = np.random.poisson(self.st_rate*(self.st_dur-2*self.st_margin))
  290. spike_train = np.random.rand(st_num_spikes) * (self.st_dur-2*self.st_margin) + self.st_margin
  291. spike_train.sort()
  292. # convert spike train into neo objects
  293. self.spike_train = neo.SpikeTrain(spike_train*pq.s,
  294. t_start=self.st_tr[0]*pq.s,
  295. t_stop=self.st_tr[1]*pq.s)
  296. # generation of a multiply used specific kernel
  297. self.kernel = kernels.TriangularKernel(sigma = 0.03*pq.s)
  298. def test_instantaneous_rate_and_warnings(self):
  299. st = self.spike_train
  300. sampling_period = 0.01*pq.s
  301. with warnings.catch_warnings(record=True) as w:
  302. inst_rate = es.instantaneous_rate(
  303. st, sampling_period, self.kernel, cutoff=0)
  304. self.assertEqual("The width of the kernel was adjusted to a minimally "
  305. "allowed width.", str(w[-2].message))
  306. self.assertEqual("Instantaneous firing rate approximation contains "
  307. "negative values, possibly caused due to machine "
  308. "precision errors.", str(w[-1].message))
  309. self.assertIsInstance(inst_rate, neo.core.AnalogSignal)
  310. self.assertEquals(
  311. inst_rate.sampling_period.simplified, sampling_period.simplified)
  312. self.assertEquals(inst_rate.simplified.units, pq.Hz)
  313. self.assertEquals(inst_rate.t_stop.simplified, st.t_stop.simplified)
  314. self.assertEquals(inst_rate.t_start.simplified, st.t_start.simplified)
  315. def test_error_instantaneous_rate(self):
  316. self.assertRaises(
  317. TypeError, es.instantaneous_rate, spiketrain=[1,2,3]*pq.s,
  318. sampling_period=0.01*pq.ms, kernel=self.kernel)
  319. self.assertRaises(
  320. TypeError, es.instantaneous_rate, spiketrain=[1,2,3],
  321. sampling_period=0.01*pq.ms, kernel=self.kernel)
  322. st = self.spike_train
  323. self.assertRaises(
  324. TypeError, es.instantaneous_rate, spiketrain=st,
  325. sampling_period=0.01, kernel=self.kernel)
  326. self.assertRaises(
  327. ValueError, es.instantaneous_rate, spiketrain=st,
  328. sampling_period=-0.01*pq.ms, kernel=self.kernel)
  329. self.assertRaises(
  330. TypeError, es.instantaneous_rate, spiketrain=st,
  331. sampling_period=0.01*pq.ms, kernel='NONE')
  332. self.assertRaises(TypeError, es.instantaneous_rate, self.spike_train,
  333. sampling_period=0.01*pq.s, kernel='wrong_string',
  334. t_start=self.st_tr[0]*pq.s, t_stop=self.st_tr[1]*pq.s,
  335. trim=False)
  336. self.assertRaises(
  337. TypeError, es.instantaneous_rate, spiketrain=st,
  338. sampling_period=0.01*pq.ms, kernel=self.kernel, cutoff=20*pq.ms)
  339. self.assertRaises(
  340. TypeError, es.instantaneous_rate, spiketrain=st,
  341. sampling_period=0.01*pq.ms, kernel=self.kernel, t_start=2)
  342. self.assertRaises(
  343. TypeError, es.instantaneous_rate, spiketrain=st,
  344. sampling_period=0.01*pq.ms, kernel=self.kernel, t_stop=20*pq.mV)
  345. self.assertRaises(
  346. TypeError, es.instantaneous_rate, spiketrain=st,
  347. sampling_period=0.01*pq.ms, kernel=self.kernel, trim=1)
  348. def test_rate_estimation_consistency(self):
  349. """
  350. Test, whether the integral of the rate estimation curve is (almost)
  351. equal to the number of spikes of the spike train.
  352. """
  353. kernel_types = [obj for obj in kernels.__dict__.values()
  354. if isinstance(obj, type) and
  355. issubclass(obj, kernels.Kernel) and
  356. hasattr(obj, "_evaluate") and
  357. obj is not kernels.Kernel and
  358. obj is not kernels.SymmetricKernel]
  359. kernel_list = [kernel_type(sigma=0.5*pq.s, invert=False)
  360. for kernel_type in kernel_types]
  361. kernel_resolution = 0.01*pq.s
  362. for kernel in kernel_list:
  363. rate_estimate_a0 = es.instantaneous_rate(self.spike_train,
  364. sampling_period=kernel_resolution,
  365. kernel='auto',
  366. t_start=self.st_tr[0]*pq.s,
  367. t_stop=self.st_tr[1]*pq.s,
  368. trim=False)
  369. rate_estimate0 = es.instantaneous_rate(self.spike_train,
  370. sampling_period=kernel_resolution,
  371. kernel=kernel)
  372. rate_estimate1 = es.instantaneous_rate(self.spike_train,
  373. sampling_period=kernel_resolution,
  374. kernel=kernel,
  375. t_start=self.st_tr[0]*pq.s,
  376. t_stop=self.st_tr[1]*pq.s,
  377. trim=False)
  378. rate_estimate2 = es.instantaneous_rate(self.spike_train,
  379. sampling_period=kernel_resolution,
  380. kernel=kernel,
  381. t_start=self.st_tr[0]*pq.s,
  382. t_stop=self.st_tr[1]*pq.s,
  383. trim=True)
  384. ### test consistency
  385. rate_estimate_list = [rate_estimate0, rate_estimate1,
  386. rate_estimate2, rate_estimate_a0]
  387. for rate_estimate in rate_estimate_list:
  388. num_spikes = len(self.spike_train)
  389. auc = spint.cumtrapz(y=rate_estimate.magnitude[:, 0],
  390. x=rate_estimate.times.rescale('s').magnitude)[-1]
  391. self.assertAlmostEqual(num_spikes, auc, delta=0.05*num_spikes)
  392. class TimeHistogramTestCase(unittest.TestCase):
  393. def setUp(self):
  394. self.spiketrain_a = neo.SpikeTrain(
  395. [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s)
  396. self.spiketrain_b = neo.SpikeTrain(
  397. [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  398. self.spiketrains = [self.spiketrain_a, self.spiketrain_b]
  399. def tearDown(self):
  400. del self.spiketrain_a
  401. self.spiketrain_a = None
  402. del self.spiketrain_b
  403. self.spiketrain_b = None
  404. def test_time_histogram(self):
  405. targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0])
  406. histogram = es.time_histogram(self.spiketrains, binsize=pq.s)
  407. assert_array_equal(targ, histogram.magnitude[:, 0])
  408. def test_time_histogram_binary(self):
  409. targ = np.array([2, 2, 1, 1, 2, 2, 1, 0, 1, 0])
  410. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  411. binary=True)
  412. assert_array_equal(targ, histogram.magnitude[:, 0])
  413. def test_time_histogram_tstart_tstop(self):
  414. # Start, stop short range
  415. targ = np.array([2, 1])
  416. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  417. t_start=5 * pq.s, t_stop=7 * pq.s)
  418. assert_array_equal(targ, histogram.magnitude[:, 0])
  419. # Test without t_stop
  420. targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0])
  421. histogram = es.time_histogram(self.spiketrains, binsize=1 * pq.s,
  422. t_start=0 * pq.s)
  423. assert_array_equal(targ, histogram.magnitude[:, 0])
  424. # Test without t_start
  425. histogram = es.time_histogram(self.spiketrains, binsize=1 * pq.s,
  426. t_stop=10 * pq.s)
  427. assert_array_equal(targ, histogram.magnitude[:, 0])
  428. def test_time_histogram_output(self):
  429. # Normalization mean
  430. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  431. output='mean')
  432. targ = np.array([4, 2, 1, 1, 2, 2, 1, 0, 1, 0], dtype=float) / 2
  433. assert_array_equal(targ.reshape(targ.size, 1), histogram.magnitude)
  434. # Normalization rate
  435. histogram = es.time_histogram(self.spiketrains, binsize=pq.s,
  436. output='rate')
  437. assert_array_equal(histogram.view(pq.Quantity),
  438. targ.reshape(targ.size, 1) * 1 / pq.s)
  439. # Normalization unspecified, raises error
  440. self.assertRaises(ValueError, es.time_histogram, self.spiketrains,
  441. binsize=pq.s, output=' ')
  442. class ComplexityPdfTestCase(unittest.TestCase):
  443. def setUp(self):
  444. self.spiketrain_a = neo.SpikeTrain(
  445. [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s)
  446. self.spiketrain_b = neo.SpikeTrain(
  447. [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  448. self.spiketrain_c = neo.SpikeTrain(
  449. [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  450. self.spiketrains = [
  451. self.spiketrain_a, self.spiketrain_b, self.spiketrain_c]
  452. def tearDown(self):
  453. del self.spiketrain_a
  454. self.spiketrain_a = None
  455. del self.spiketrain_b
  456. self.spiketrain_b = None
  457. def test_complexity_pdf(self):
  458. targ = np.array([0.92, 0.01, 0.01, 0.06])
  459. complexity = es.complexity_pdf(self.spiketrains, binsize=0.1*pq.s)
  460. assert_array_equal(targ, complexity.magnitude[:, 0])
  461. self.assertEqual(1, complexity.magnitude[:, 0].sum())
  462. self.assertEqual(len(self.spiketrains)+1, len(complexity))
  463. self.assertIsInstance(complexity, neo.AnalogSignal)
  464. self.assertEqual(complexity.units, 1*pq.dimensionless)
  465. if __name__ == '__main__':
  466. unittest.main()