test_conversion.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the conversion module.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import unittest
  8. import neo
  9. import numpy as np
  10. from numpy.testing.utils import assert_array_almost_equal
  11. import quantities as pq
  12. import elephant.conversion as cv
  13. def get_nearest(times, time):
  14. return (np.abs(times-time)).argmin()
  15. class binarize_TestCase(unittest.TestCase):
  16. def setUp(self):
  17. self.test_array_1d = np.array([1.23, 0.3, 0.87, 0.56])
  18. def test_binarize_with_spiketrain_exact(self):
  19. st = neo.SpikeTrain(self.test_array_1d, units='ms',
  20. t_stop=10.0, sampling_rate=100)
  21. times = np.arange(0, 10.+.01, .01)
  22. target = np.zeros_like(times).astype('bool')
  23. for time in self.test_array_1d:
  24. target[get_nearest(times, time)] = True
  25. times = pq.Quantity(times, units='ms')
  26. res, tres = cv.binarize(st, return_times=True)
  27. assert_array_almost_equal(res, target, decimal=9)
  28. assert_array_almost_equal(tres, times, decimal=9)
  29. def test_binarize_with_spiketrain_exact_set_ends(self):
  30. st = neo.SpikeTrain(self.test_array_1d, units='ms',
  31. t_stop=10.0, sampling_rate=100)
  32. times = np.arange(5., 10.+.01, .01)
  33. target = np.zeros_like(times).astype('bool')
  34. times = pq.Quantity(times, units='ms')
  35. res, tres = cv.binarize(st, return_times=True, t_start=5., t_stop=10.)
  36. assert_array_almost_equal(res, target, decimal=9)
  37. assert_array_almost_equal(tres, times, decimal=9)
  38. def test_binarize_with_spiketrain_round(self):
  39. st = neo.SpikeTrain(self.test_array_1d, units='ms',
  40. t_stop=10.0, sampling_rate=10.0)
  41. times = np.arange(0, 10.+.1, .1)
  42. target = np.zeros_like(times).astype('bool')
  43. for time in np.round(self.test_array_1d, 1):
  44. target[get_nearest(times, time)] = True
  45. times = pq.Quantity(times, units='ms')
  46. res, tres = cv.binarize(st, return_times=True)
  47. assert_array_almost_equal(res, target, decimal=9)
  48. assert_array_almost_equal(tres, times, decimal=9)
  49. def test_binarize_with_quantities_exact(self):
  50. st = pq.Quantity(self.test_array_1d, units='ms')
  51. times = np.arange(0, 1.23+.01, .01)
  52. target = np.zeros_like(times).astype('bool')
  53. for time in self.test_array_1d:
  54. target[get_nearest(times, time)] = True
  55. times = pq.Quantity(times, units='ms')
  56. res, tres = cv.binarize(st, return_times=True,
  57. sampling_rate=100.*pq.kHz)
  58. assert_array_almost_equal(res, target, decimal=9)
  59. assert_array_almost_equal(tres, times, decimal=9)
  60. def test_binarize_with_quantities_exact_set_ends(self):
  61. st = pq.Quantity(self.test_array_1d, units='ms')
  62. times = np.arange(0, 10.+.01, .01)
  63. target = np.zeros_like(times).astype('bool')
  64. for time in self.test_array_1d:
  65. target[get_nearest(times, time)] = True
  66. times = pq.Quantity(times, units='ms')
  67. res, tres = cv.binarize(st, return_times=True, t_stop=10.,
  68. sampling_rate=100.*pq.kHz)
  69. assert_array_almost_equal(res, target, decimal=9)
  70. assert_array_almost_equal(tres, times, decimal=9)
  71. def test_binarize_with_quantities_round_set_ends(self):
  72. st = pq.Quantity(self.test_array_1d, units='ms')
  73. times = np.arange(5., 10.+.1, .1)
  74. target = np.zeros_like(times).astype('bool')
  75. times = pq.Quantity(times, units='ms')
  76. res, tres = cv.binarize(st, return_times=True, t_start=5., t_stop=10.,
  77. sampling_rate=10.*pq.kHz)
  78. assert_array_almost_equal(res, target, decimal=9)
  79. assert_array_almost_equal(tres, times, decimal=9)
  80. def test_binarize_with_plain_array_exact(self):
  81. st = self.test_array_1d
  82. times = np.arange(0, 1.23+.01, .01)
  83. target = np.zeros_like(times).astype('bool')
  84. for time in self.test_array_1d:
  85. target[get_nearest(times, time)] = True
  86. res, tres = cv.binarize(st, return_times=True, sampling_rate=100)
  87. assert_array_almost_equal(res, target, decimal=9)
  88. assert_array_almost_equal(tres, times, decimal=9)
  89. def test_binarize_with_plain_array_exact_set_ends(self):
  90. st = self.test_array_1d
  91. times = np.arange(0, 10.+.01, .01)
  92. target = np.zeros_like(times).astype('bool')
  93. for time in self.test_array_1d:
  94. target[get_nearest(times, time)] = True
  95. res, tres = cv.binarize(st, return_times=True, t_stop=10., sampling_rate=100.)
  96. assert_array_almost_equal(res, target, decimal=9)
  97. assert_array_almost_equal(tres, times, decimal=9)
  98. def test_binarize_no_time(self):
  99. st = self.test_array_1d
  100. times = np.arange(0, 1.23+.01, .01)
  101. target = np.zeros_like(times).astype('bool')
  102. for time in self.test_array_1d:
  103. target[get_nearest(times, time)] = True
  104. res0, tres = cv.binarize(st, return_times=True, sampling_rate=100)
  105. res1 = cv.binarize(st, return_times=False, sampling_rate=100)
  106. res2 = cv.binarize(st, sampling_rate=100)
  107. assert_array_almost_equal(res0, res1, decimal=9)
  108. assert_array_almost_equal(res0, res2, decimal=9)
  109. def test_binariz_rate_with_plain_array_and_units_typeerror(self):
  110. st = self.test_array_1d
  111. self.assertRaises(TypeError, cv.binarize, st,
  112. t_start=pq.Quantity(0, 'ms'),
  113. sampling_rate=10.)
  114. self.assertRaises(TypeError, cv.binarize, st,
  115. t_stop=pq.Quantity(10, 'ms'),
  116. sampling_rate=10.)
  117. self.assertRaises(TypeError, cv.binarize, st,
  118. t_start=pq.Quantity(0, 'ms'),
  119. t_stop=pq.Quantity(10, 'ms'),
  120. sampling_rate=10.)
  121. self.assertRaises(TypeError, cv.binarize, st,
  122. t_start=pq.Quantity(0, 'ms'),
  123. t_stop=10.,
  124. sampling_rate=10.)
  125. self.assertRaises(TypeError, cv.binarize, st,
  126. t_start=0.,
  127. t_stop=pq.Quantity(10, 'ms'),
  128. sampling_rate=10.)
  129. self.assertRaises(TypeError, cv.binarize, st,
  130. sampling_rate=10.*pq.Hz)
  131. def test_binariz_without_sampling_rate_valueerror(self):
  132. st0 = self.test_array_1d
  133. st1 = pq.Quantity(st0, 'ms')
  134. self.assertRaises(ValueError, cv.binarize, st0)
  135. self.assertRaises(ValueError, cv.binarize, st0,
  136. t_start=0)
  137. self.assertRaises(ValueError, cv.binarize, st0,
  138. t_stop=10)
  139. self.assertRaises(ValueError, cv.binarize, st0,
  140. t_start=0, t_stop=10)
  141. self.assertRaises(ValueError, cv.binarize, st1,
  142. t_start=pq.Quantity(0, 'ms'), t_stop=10.)
  143. self.assertRaises(ValueError, cv.binarize, st1,
  144. t_start=0., t_stop=pq.Quantity(10, 'ms'))
  145. self.assertRaises(ValueError, cv.binarize, st1)
  146. class TimeHistogramTestCase(unittest.TestCase):
  147. def setUp(self):
  148. self.spiketrain_a = neo.SpikeTrain(
  149. [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s)
  150. self.spiketrain_b = neo.SpikeTrain(
  151. [0.1, 0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s)
  152. self.binsize = 1 * pq.s
  153. def tearDown(self):
  154. self.spiketrain_a = None
  155. del self.spiketrain_a
  156. self.spiketrain_b = None
  157. del self.spiketrain_b
  158. def test_binned_spiketrain_sparse(self):
  159. a = neo.SpikeTrain([1.7, 1.8, 4.3] * pq.s, t_stop=10.0 * pq.s)
  160. b = neo.SpikeTrain([1.7, 1.8, 4.3] * pq.s, t_stop=10.0 * pq.s)
  161. binsize = 1 * pq.s
  162. nbins = 10
  163. x = cv.BinnedSpikeTrain([a, b], num_bins=nbins, binsize=binsize,
  164. t_start=0 * pq.s)
  165. x_sparse = [2, 1, 2, 1]
  166. s = x.to_sparse_array()
  167. self.assertTrue(np.array_equal(s.data, x_sparse))
  168. self.assertTrue(
  169. np.array_equal(x.spike_indices, [[1, 1, 4], [1, 1, 4]]))
  170. def test_binned_spiketrain_shape(self):
  171. a = self.spiketrain_a
  172. x = cv.BinnedSpikeTrain(a, num_bins=10,
  173. binsize=self.binsize,
  174. t_start=0 * pq.s)
  175. x_bool = cv.BinnedSpikeTrain(a, num_bins=10, binsize=self.binsize,
  176. t_start=0 * pq.s)
  177. self.assertTrue(x.to_array().shape == (1, 10))
  178. self.assertTrue(x_bool.to_bool_array().shape == (1, 10))
  179. # shape of the matrix for a list of spike trains
  180. def test_binned_spiketrain_shape_list(self):
  181. a = self.spiketrain_a
  182. b = self.spiketrain_b
  183. c = [a, b]
  184. nbins = 5
  185. x = cv.BinnedSpikeTrain(c, num_bins=nbins, t_start=0 * pq.s,
  186. t_stop=10.0 * pq.s)
  187. x_bool = cv.BinnedSpikeTrain(c, num_bins=nbins, t_start=0 * pq.s,
  188. t_stop=10.0 * pq.s)
  189. self.assertTrue(x.to_array().shape == (2, 5))
  190. self.assertTrue(x_bool.to_bool_array().shape == (2, 5))
  191. def test_binned_spiketrain_neg_times(self):
  192. a = neo.SpikeTrain(
  193. [-6.5, 0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s,
  194. t_start=-6.5 * pq.s, t_stop=10.0 * pq.s)
  195. binsize = self.binsize
  196. nbins = 16
  197. x = cv.BinnedSpikeTrain(a, num_bins=nbins, binsize=binsize,
  198. t_start=-6.5 * pq.s)
  199. y = [
  200. np.array([1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0])]
  201. self.assertTrue(np.array_equal(x.to_bool_array(), y))
  202. def test_binned_spiketrain_neg_times_list(self):
  203. a = neo.SpikeTrain(
  204. [-6.5, 0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s,
  205. t_start=-7 * pq.s, t_stop=7 * pq.s)
  206. b = neo.SpikeTrain(
  207. [-0.1, -0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s,
  208. t_start=-1 * pq.s, t_stop=8 * pq.s)
  209. c = [a, b]
  210. binsize = self.binsize
  211. x_bool = cv.BinnedSpikeTrain(c, binsize=binsize)
  212. y_bool = [[0, 1, 1, 0, 1, 1, 1, 1],
  213. [1, 0, 1, 1, 0, 1, 1, 0]]
  214. self.assertTrue(
  215. np.array_equal(x_bool.to_bool_array(), y_bool))
  216. # checking spike_indices(f) and matrix(m) for 1 spiketrain
  217. def test_binned_spiketrain_indices(self):
  218. a = self.spiketrain_a
  219. binsize = self.binsize
  220. nbins = 10
  221. x = cv.BinnedSpikeTrain(a, num_bins=nbins, binsize=binsize,
  222. t_start=0 * pq.s)
  223. x_bool = cv.BinnedSpikeTrain(a, num_bins=nbins, binsize=binsize,
  224. t_start=0 * pq.s)
  225. y_matrix = [
  226. np.array([2., 1., 0., 1., 1., 1., 1., 0., 0., 0.])]
  227. y_bool_matrix = [
  228. np.array([1., 1., 0., 1., 1., 1., 1., 0., 0., 0.])]
  229. self.assertTrue(
  230. np.array_equal(x.to_array(),
  231. y_matrix))
  232. self.assertTrue(
  233. np.array_equal(x_bool.to_bool_array(), y_bool_matrix))
  234. self.assertTrue(
  235. np.array_equal(x_bool.to_bool_array(), y_bool_matrix))
  236. s = x_bool.to_sparse_bool_array()[
  237. x_bool.to_sparse_bool_array().nonzero()]
  238. self.assertTrue(np.array_equal(s, [[True]*6]))
  239. def test_binned_spiketrain_list(self):
  240. a = self.spiketrain_a
  241. b = self.spiketrain_b
  242. binsize = self.binsize
  243. nbins = 10
  244. c = [a, b]
  245. x = cv.BinnedSpikeTrain(c, num_bins=nbins, binsize=binsize,
  246. t_start=0 * pq.s)
  247. x_bool = cv.BinnedSpikeTrain(c, num_bins=nbins, binsize=binsize,
  248. t_start=0 * pq.s)
  249. y_matrix = np.array(
  250. [[2, 1, 0, 1, 1, 1, 1, 0, 0, 0],
  251. [2, 1, 1, 0, 1, 1, 0, 0, 1, 0]])
  252. y_matrix_bool = np.array(
  253. [[1, 1, 0, 1, 1, 1, 1, 0, 0, 0],
  254. [1, 1, 1, 0, 1, 1, 0, 0, 1, 0]])
  255. self.assertTrue(
  256. np.array_equal(x.to_array(),
  257. y_matrix))
  258. self.assertTrue(
  259. np.array_equal(x_bool.to_bool_array(), y_matrix_bool))
  260. # t_stop is None
  261. def test_binned_spiketrain_list_t_stop(self):
  262. a = self.spiketrain_a
  263. b = self.spiketrain_b
  264. c = [a, b]
  265. binsize = self.binsize
  266. nbins = 10
  267. x = cv.BinnedSpikeTrain(c, num_bins=nbins, binsize=binsize,
  268. t_start=0 * pq.s,
  269. t_stop=None)
  270. x_bool = cv.BinnedSpikeTrain(c, num_bins=nbins, binsize=binsize,
  271. t_start=0 * pq.s)
  272. self.assertTrue(x.t_stop == 10 * pq.s)
  273. self.assertTrue(x_bool.t_stop == 10 * pq.s)
  274. # Test number of bins
  275. def test_binned_spiketrain_list_numbins(self):
  276. a = self.spiketrain_a
  277. b = self.spiketrain_b
  278. c = [a, b]
  279. binsize = 1 * pq.s
  280. x = cv.BinnedSpikeTrain(c, binsize=binsize, t_start=0 * pq.s,
  281. t_stop=10. * pq.s)
  282. x_bool = cv.BinnedSpikeTrain(c, binsize=binsize, t_start=0 * pq.s,
  283. t_stop=10. * pq.s)
  284. self.assertTrue(x.num_bins == 10)
  285. self.assertTrue(x_bool.num_bins == 10)
  286. def test_binned_spiketrain_matrix(self):
  287. # Init
  288. a = self.spiketrain_a
  289. b = self.spiketrain_b
  290. x_bool_a = cv.BinnedSpikeTrain(a, binsize=pq.s, t_start=0 * pq.s,
  291. t_stop=10. * pq.s)
  292. x_bool_b = cv.BinnedSpikeTrain(b, binsize=pq.s, t_start=0 * pq.s,
  293. t_stop=10. * pq.s)
  294. # Assumed results
  295. y_matrix_a = [
  296. np.array([2, 1, 0, 1, 1, 1, 1, 0, 0, 0])]
  297. y_matrix_bool_a = [np.array([1, 1, 0, 1, 1, 1, 1, 0, 0, 0])]
  298. y_matrix_bool_b = [np.array([1, 1, 1, 0, 1, 1, 0, 0, 1, 0])]
  299. # Asserts
  300. self.assertTrue(
  301. np.array_equal(x_bool_a.to_bool_array(), y_matrix_bool_a))
  302. self.assertTrue(np.array_equal(x_bool_b.to_bool_array(),
  303. y_matrix_bool_b))
  304. self.assertTrue(
  305. np.array_equal(x_bool_a.to_array(), y_matrix_a))
  306. def test_binned_spiketrain_matrix_storing(self):
  307. a = self.spiketrain_a
  308. b = self.spiketrain_b
  309. x_bool = cv.BinnedSpikeTrain(a, binsize=pq.s, t_start=0 * pq.s,
  310. t_stop=10. * pq.s)
  311. x = cv.BinnedSpikeTrain(b, binsize=pq.s, t_start=0 * pq.s,
  312. t_stop=10. * pq.s)
  313. # Store Matrix in variable
  314. matrix_bool = x_bool.to_bool_array()
  315. matrix = x.to_array(store_array=True)
  316. # Check if same matrix
  317. self.assertTrue(np.array_equal(x._mat_u,
  318. matrix))
  319. # Get the stored matrix using method
  320. self.assertTrue(
  321. np.array_equal(x_bool.to_bool_array(),
  322. matrix_bool))
  323. self.assertTrue(
  324. np.array_equal(x.to_array(),
  325. matrix))
  326. # Test storing of sparse mat
  327. sparse_bool = x_bool.to_sparse_bool_array()
  328. self.assertTrue(np.array_equal(sparse_bool.toarray(),
  329. x_bool.to_sparse_bool_array().toarray()))
  330. # New class without calculating the matrix
  331. x = cv.BinnedSpikeTrain(b, binsize=pq.s, t_start=0 * pq.s,
  332. t_stop=10. * pq.s)
  333. # No matrix calculated, should be None
  334. self.assertEqual(x._mat_u, None)
  335. # Test with stored matrix
  336. self.assertFalse(np.array_equal(x, matrix))
  337. # Test matrix removing
  338. def test_binned_spiketrain_remove_matrix(self):
  339. a = self.spiketrain_a
  340. x = cv.BinnedSpikeTrain(a, binsize=1 * pq.s, num_bins=10,
  341. t_stop=10. * pq.s)
  342. # Store
  343. x.to_array(store_array=True)
  344. # Remove
  345. x.remove_stored_array()
  346. # Assert matrix is not stored
  347. self.assertIsNone(x._mat_u)
  348. # Test if t_start is calculated correctly
  349. def test_binned_spiketrain_parameter_calc_tstart(self):
  350. a = self.spiketrain_a
  351. x = cv.BinnedSpikeTrain(a, binsize=1 * pq.s, num_bins=10,
  352. t_stop=10. * pq.s)
  353. self.assertEqual(x.t_start, 0. * pq.s)
  354. self.assertEqual(x.t_stop, 10. * pq.s)
  355. self.assertEqual(x.binsize, 1 * pq.s)
  356. self.assertEqual(x.num_bins, 10)
  357. # Test if error raises when type of num_bins is not an integer
  358. def test_binned_spiketrain_numbins_type_error(self):
  359. a = self.spiketrain_a
  360. self.assertRaises(TypeError, cv.BinnedSpikeTrain, a, binsize=pq.s,
  361. num_bins=1.4, t_start=0 * pq.s,
  362. t_stop=10. * pq.s)
  363. # Test if error is raised when providing insufficient number of
  364. # parameters
  365. def test_binned_spiketrain_insufficient_arguments(self):
  366. a = self.spiketrain_a
  367. self.assertRaises(AttributeError, cv.BinnedSpikeTrain, a)
  368. self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, binsize=1 * pq.s,
  369. t_start=0 * pq.s, t_stop=0 * pq.s)
  370. def test_calc_attributes_error(self):
  371. self.assertRaises(ValueError, cv._calc_num_bins, 1, 1 * pq.s, 0 * pq.s)
  372. self.assertRaises(ValueError, cv._calc_binsize, 1, 1 * pq.s, 0 * pq.s)
  373. def test_different_input_types(self):
  374. a = self.spiketrain_a
  375. q = [1, 2, 3] * pq.s
  376. self.assertRaises(TypeError, cv.BinnedSpikeTrain, [a, q], binsize=pq.s)
  377. def test_get_start_stop(self):
  378. a = self.spiketrain_a
  379. b = neo.SpikeTrain(
  380. [-0.1, -0.7, 1.2, 2.2, 4.3, 5.5, 8.0] * pq.s,
  381. t_start=-1 * pq.s, t_stop=8 * pq.s)
  382. start, stop = cv._get_start_stop_from_input(a)
  383. self.assertEqual(start, a.t_start)
  384. self.assertEqual(stop, a.t_stop)
  385. start, stop = cv._get_start_stop_from_input([a, b])
  386. self.assertEqual(start, a.t_start)
  387. self.assertEqual(stop, b.t_stop)
  388. def test_consistency_errors(self):
  389. a = self.spiketrain_a
  390. b = neo.SpikeTrain([-2, -1] * pq.s, t_start=-2 * pq.s,
  391. t_stop=-1 * pq.s)
  392. self.assertRaises(ValueError, cv.BinnedSpikeTrain, [a, b], t_start=5,
  393. t_stop=0, binsize=pq.s, num_bins=10)
  394. b = neo.SpikeTrain([-7, -8, -9] * pq.s, t_start=-9 * pq.s,
  395. t_stop=-7 * pq.s)
  396. self.assertRaises(ValueError, cv.BinnedSpikeTrain, b, t_start=0,
  397. t_stop=10, binsize=pq.s, num_bins=10)
  398. self.assertRaises(ValueError, cv.BinnedSpikeTrain, a, t_start=0 * pq.s,
  399. t_stop=10 * pq.s, binsize=3 * pq.s, num_bins=10)
  400. b = neo.SpikeTrain([-4, -2, 0, 1] * pq.s, t_start=-4 * pq.s,
  401. t_stop=1 * pq.s)
  402. self.assertRaises(TypeError, cv.BinnedSpikeTrain, b, binsize=-2*pq.s,
  403. t_start=-4 * pq.s, t_stop=0 * pq.s)
  404. # Test edges
  405. def test_binned_spiketrain_bin_edges(self):
  406. a = self.spiketrain_a
  407. x = cv.BinnedSpikeTrain(a, binsize=1 * pq.s, num_bins=10,
  408. t_stop=10. * pq.s)
  409. # Test all edges
  410. edges = [float(i) for i in range(11)]
  411. self.assertTrue(np.array_equal(x.bin_edges, edges))
  412. # Test left edges
  413. edges = [float(i) for i in range(10)]
  414. self.assertTrue(np.array_equal(x.bin_edges[:-1], edges))
  415. # Test right edges
  416. edges = [float(i) for i in range(1, 11)]
  417. self.assertTrue(np.array_equal(x.bin_edges[1:], edges))
  418. # Test center edges
  419. edges = np.arange(0, 10) + 0.5
  420. self.assertTrue(np.array_equal(x.bin_centers, edges))
  421. # Test for different units but same times
  422. def test_binned_spiketrain_different_units(self):
  423. a = self.spiketrain_a
  424. b = a.rescale(pq.ms)
  425. binsize = 1 * pq.s
  426. xa = cv.BinnedSpikeTrain(a, binsize=binsize)
  427. xb = cv.BinnedSpikeTrain(b, binsize=binsize.rescale(pq.ms))
  428. self.assertTrue(
  429. np.array_equal(xa.to_bool_array(), xb.to_bool_array()))
  430. self.assertTrue(
  431. np.array_equal(xa.to_sparse_array().data,
  432. xb.to_sparse_array().data))
  433. self.assertTrue(
  434. np.array_equal(xa.bin_edges[:-1],
  435. xb.bin_edges[:-1].rescale(binsize.units)))
  436. def test_binnend_spiketrain_rescaling(self):
  437. train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s,
  438. t_start=1 * pq.s, t_stop=1.01 * pq.s)
  439. bst = cv.BinnedSpikeTrain(train,
  440. t_start=1 * pq.s, t_stop=1.01 * pq.s,
  441. binsize=1 * pq.ms)
  442. target_edges = np.array([1000, 1001, 1002, 1003, 1004, 1005, 1006,
  443. 1007, 1008, 1009, 1010], dtype=np.float)
  444. target_centers = np.array(
  445. [1000.5, 1001.5, 1002.5, 1003.5, 1004.5, 1005.5, 1006.5, 1007.5,
  446. 1008.5, 1009.5], dtype=np.float)
  447. self.assertTrue(np.allclose(bst.bin_edges.magnitude, target_edges))
  448. self.assertTrue(np.allclose(bst.bin_centers.magnitude, target_centers))
  449. self.assertTrue(bst.bin_centers.units == pq.ms)
  450. self.assertTrue(bst.bin_edges.units == pq.ms)
  451. bst = cv.BinnedSpikeTrain(train,
  452. t_start=1 * pq.s, t_stop=1010 * pq.ms,
  453. binsize=1 * pq.ms)
  454. self.assertTrue(np.allclose(bst.bin_edges.magnitude, target_edges))
  455. self.assertTrue(np.allclose(bst.bin_centers.magnitude, target_centers))
  456. self.assertTrue(bst.bin_centers.units == pq.ms)
  457. self.assertTrue(bst.bin_edges.units == pq.ms)
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when the file
    # is executed directly as a script.
    unittest.main()