test_spike_train_generation.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the spike_train_generation module.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division
  8. import unittest
  9. import os
  10. import warnings
  11. import neo
  12. import numpy as np
  13. from numpy.testing.utils import assert_array_almost_equal
  14. from scipy.stats import kstest, expon, poisson
  15. from quantities import V, s, ms, second, Hz, kHz, mV, dimensionless
  16. import elephant.spike_train_generation as stgen
  17. from elephant.statistics import isi
  18. from scipy.stats import expon
  19. def pdiff(a, b):
  20. """Difference between a and b as a fraction of a
  21. i.e. abs((a - b)/a)
  22. """
  23. return abs((a - b)/a)
  24. class AnalogSignalThresholdDetectionTestCase(unittest.TestCase):
  25. def setUp(self):
  26. pass
  27. def test_threshold_detection(self):
  28. # Test whether spikes are extracted at the correct times from
  29. # an analog signal.
  30. # Load membrane potential simulated using Brian2
  31. # according to make_spike_extraction_test_data.py.
  32. curr_dir = os.path.dirname(os.path.realpath(__file__))
  33. raw_data_file_loc = os.path.join(curr_dir,'spike_extraction_test_data.txt')
  34. raw_data = []
  35. with open(raw_data_file_loc, 'r') as f:
  36. for x in (f.readlines()):
  37. raw_data.append(float(x))
  38. vm = neo.AnalogSignal(raw_data, units=V, sampling_period=0.1*ms)
  39. spike_train = stgen.threshold_detection(vm)
  40. try:
  41. len(spike_train)
  42. except TypeError: # Handles an error in Neo related to some zero length
  43. # spike trains being treated as unsized objects.
  44. warnings.warn(("The spike train may be an unsized object. This may be related "
  45. "to an issue in Neo with some zero-length SpikeTrain objects. "
  46. "Bypassing this by creating an empty SpikeTrain object."))
  47. spike_train = neo.core.SpikeTrain([],t_start=spike_train.t_start,
  48. t_stop=spike_train.t_stop,
  49. units=spike_train.units)
  50. # Correct values determined previously.
  51. true_spike_train = [0.0123, 0.0354, 0.0712, 0.1191,
  52. 0.1694, 0.22, 0.2711]
  53. # Does threshold_detection gives the correct number of spikes?
  54. self.assertEqual(len(spike_train),len(true_spike_train))
  55. # Does threshold_detection gives the correct times for the spikes?
  56. try:
  57. assert_array_almost_equal(spike_train,spike_train)
  58. except AttributeError: # If numpy version too old to have allclose
  59. self.assertTrue(np.array_equal(spike_train,spike_train))
  60. class AnalogSignalPeakDetectionTestCase(unittest.TestCase):
  61. def setUp(self):
  62. curr_dir = os.path.dirname(os.path.realpath(__file__))
  63. raw_data_file_loc = os.path.join(curr_dir, 'spike_extraction_test_data.txt')
  64. raw_data = []
  65. with open(raw_data_file_loc, 'r') as f:
  66. for x in (f.readlines()):
  67. raw_data.append(float(x))
  68. self.vm = neo.AnalogSignal(raw_data, units=V, sampling_period=0.1*ms)
  69. self.true_time_stamps = [0.0124, 0.0354, 0.0713, 0.1192, 0.1695,
  70. 0.2201, 0.2711] * second
  71. def test_peak_detection_time_stamps(self):
  72. # Test with default arguments
  73. result = stgen.peak_detection(self.vm)
  74. self.assertEqual(len(self.true_time_stamps), len(result))
  75. self.assertIsInstance(result, neo.core.SpikeTrain)
  76. try:
  77. assert_array_almost_equal(result, self.true_time_stamps)
  78. except AttributeError:
  79. self.assertTrue(np.array_equal(result, self.true_time_stamps))
  80. def test_peak_detection_threshold(self):
  81. # Test for empty SpikeTrain when threshold is too high
  82. result = stgen.peak_detection(self.vm, threshold=30 * mV)
  83. self.assertEqual(len(result), 0)
  84. class AnalogSignalSpikeExtractionTestCase(unittest.TestCase):
  85. def setUp(self):
  86. curr_dir = os.path.dirname(os.path.realpath(__file__))
  87. raw_data_file_loc = os.path.join(curr_dir, 'spike_extraction_test_data.txt')
  88. raw_data = []
  89. with open(raw_data_file_loc, 'r') as f:
  90. for x in (f.readlines()):
  91. raw_data.append(float(x))
  92. self.vm = neo.AnalogSignal(raw_data, units=V, sampling_period=0.1*ms)
  93. self.first_spike = np.array([-0.04084546, -0.03892033, -0.03664779,
  94. -0.03392689, -0.03061474, -0.02650277,
  95. -0.0212756, -0.01443531, -0.00515365,
  96. 0.00803962, 0.02797951, -0.07,
  97. -0.06974495, -0.06950466, -0.06927778,
  98. -0.06906314, -0.06885969, -0.06866651,
  99. -0.06848277, -0.06830773, -0.06814071,
  100. -0.06798113, -0.06782843, -0.06768213,
  101. -0.06754178, -0.06740699, -0.06727737,
  102. -0.06715259, -0.06703235, -0.06691635])
  103. def test_spike_extraction_waveform(self):
  104. spike_train = stgen.spike_extraction(self.vm.reshape(-1),
  105. extr_interval = (-1*ms, 2*ms))
  106. try:
  107. assert_array_almost_equal(spike_train.waveforms[0][0].magnitude.reshape(-1),
  108. self.first_spike)
  109. except AttributeError:
  110. self.assertTrue(
  111. np.array_equal(spike_train.waveforms[0][0].magnitude,
  112. self.first_spike))
  113. class HomogeneousPoissonProcessTestCase(unittest.TestCase):
  114. def setUp(self):
  115. pass
  116. def test_statistics(self):
  117. # This is a statistical test that has a non-zero chance of failure
  118. # during normal operation. Thus, we set the random seed to a value that
  119. # creates a realization passing the test.
  120. np.random.seed(seed=12345)
  121. for rate in [123.0*Hz, 0.123*kHz]:
  122. for t_stop in [2345*ms, 2.345*second]:
  123. spiketrain = stgen.homogeneous_poisson_process(rate, t_stop=t_stop)
  124. intervals = isi(spiketrain)
  125. expected_spike_count = int((rate * t_stop).simplified)
  126. self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.2) # should fail about 1 time in 1000
  127. expected_mean_isi = (1/rate)
  128. self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.2)
  129. expected_first_spike = 0*ms
  130. self.assertLess(spiketrain[0] - expected_first_spike, 7*expected_mean_isi)
  131. expected_last_spike = t_stop
  132. self.assertLess(expected_last_spike - spiketrain[-1], 7*expected_mean_isi)
  133. # Kolmogorov-Smirnov test
  134. D, p = kstest(intervals.rescale(t_stop.units),
  135. "expon",
  136. args=(0, expected_mean_isi.rescale(t_stop.units)), # args are (loc, scale)
  137. alternative='two-sided')
  138. self.assertGreater(p, 0.001)
  139. self.assertLess(D, 0.12)
  140. def test_low_rates(self):
  141. spiketrain = stgen.homogeneous_poisson_process(0*Hz, t_stop=1000*ms)
  142. self.assertEqual(spiketrain.size, 0)
  143. # not really a test, just making sure that all code paths are covered
  144. for i in range(10):
  145. spiketrain = stgen.homogeneous_poisson_process(1*Hz, t_stop=1000*ms)
  146. def test_buffer_overrun(self):
  147. np.random.seed(6085) # this seed should produce a buffer overrun
  148. t_stop=1000*ms
  149. rate = 10*Hz
  150. spiketrain = stgen.homogeneous_poisson_process(rate, t_stop=t_stop)
  151. expected_last_spike = t_stop
  152. expected_mean_isi = (1/rate).rescale(ms)
  153. self.assertLess(expected_last_spike - spiketrain[-1], 4*expected_mean_isi)
  154. class InhomogeneousPoissonProcessTestCase(unittest.TestCase):
  155. def setUp(self):
  156. rate_list = [[20] for i in range(1000)] + [[200] for i in range(1000)]
  157. self.rate_profile = neo.AnalogSignal(
  158. rate_list * Hz, sampling_period=0.001*s)
  159. rate_0 = [[0] for i in range(1000)]
  160. self.rate_profile_0 = neo.AnalogSignal(
  161. rate_0 * Hz, sampling_period=0.001*s)
  162. rate_negative = [[-1] for i in range(1000)]
  163. self.rate_profile_negative = neo.AnalogSignal(
  164. rate_negative * Hz, sampling_period=0.001 * s)
  165. pass
  166. def test_statistics(self):
  167. # This is a statistical test that has a non-zero chance of failure
  168. # during normal operation. Thus, we set the random seed to a value that
  169. # creates a realization passing the test.
  170. np.random.seed(seed=12345)
  171. for rate in [self.rate_profile, self.rate_profile.rescale(kHz)]:
  172. spiketrain = stgen.inhomogeneous_poisson_process(rate)
  173. intervals = isi(spiketrain)
  174. # Computing expected statistics and percentiles
  175. expected_spike_count = (np.sum(
  176. rate) * rate.sampling_period).simplified
  177. percentile_count = poisson.ppf(.999, expected_spike_count)
  178. expected_min_isi = (1 / np.min(rate))
  179. expected_max_isi = (1 / np.max(rate))
  180. percentile_min_isi = expon.ppf(.999, expected_min_isi)
  181. percentile_max_isi = expon.ppf(.999, expected_max_isi)
  182. # Testing (each should fail 1 every 1000 times)
  183. self.assertLess(spiketrain.size, percentile_count)
  184. self.assertLess(np.min(intervals), percentile_min_isi)
  185. self.assertLess(np.max(intervals), percentile_max_isi)
  186. # Testing t_start t_stop
  187. self.assertEqual(rate.t_stop, spiketrain.t_stop)
  188. self.assertEqual(rate.t_start, spiketrain.t_start)
  189. # Testing type
  190. spiketrain_as_array = stgen.inhomogeneous_poisson_process(
  191. rate, as_array=True)
  192. self.assertTrue(isinstance(spiketrain_as_array, np.ndarray))
  193. self.assertTrue(isinstance(spiketrain, neo.SpikeTrain))
  194. def test_low_rates(self):
  195. spiketrain = stgen.inhomogeneous_poisson_process(self.rate_profile_0)
  196. self.assertEqual(spiketrain.size, 0)
  197. def test_negative_rates(self):
  198. self.assertRaises(
  199. ValueError, stgen.inhomogeneous_poisson_process,
  200. self.rate_profile_negative)
  201. class HomogeneousGammaProcessTestCase(unittest.TestCase):
  202. def setUp(self):
  203. pass
  204. def test_statistics(self):
  205. # This is a statistical test that has a non-zero chance of failure
  206. # during normal operation. Thus, we set the random seed to a value that
  207. # creates a realization passing the test.
  208. np.random.seed(seed=12345)
  209. a = 3.0
  210. for b in (67.0*Hz, 0.067*kHz):
  211. for t_stop in (2345*ms, 2.345*second):
  212. spiketrain = stgen.homogeneous_gamma_process(a, b, t_stop=t_stop)
  213. intervals = isi(spiketrain)
  214. expected_spike_count = int((b/a * t_stop).simplified)
  215. self.assertLess(pdiff(expected_spike_count, spiketrain.size), 0.25) # should fail about 1 time in 1000
  216. expected_mean_isi = (a/b).rescale(ms)
  217. self.assertLess(pdiff(expected_mean_isi, intervals.mean()), 0.3)
  218. expected_first_spike = 0*ms
  219. self.assertLess(spiketrain[0] - expected_first_spike, 4*expected_mean_isi)
  220. expected_last_spike = t_stop
  221. self.assertLess(expected_last_spike - spiketrain[-1], 4*expected_mean_isi)
  222. # Kolmogorov-Smirnov test
  223. D, p = kstest(intervals.rescale(t_stop.units),
  224. "gamma",
  225. args=(a, 0, (1/b).rescale(t_stop.units)), # args are (a, loc, scale)
  226. alternative='two-sided')
  227. self.assertGreater(p, 0.001)
  228. self.assertLess(D, 0.25)
  229. class _n_poisson_TestCase(unittest.TestCase):
  230. def setUp(self):
  231. self.n = 4
  232. self.rate = 10*Hz
  233. self.rates = range(1, self.n + 1)*Hz
  234. self.t_stop = 10000*ms
  235. def test_poisson(self):
  236. # Check the output types for input rate + n number of neurons
  237. pp = stgen._n_poisson(rate=self.rate, t_stop=self.t_stop, n=self.n)
  238. self.assertIsInstance(pp, list)
  239. self.assertIsInstance(pp[0], neo.core.spiketrain.SpikeTrain)
  240. self.assertEqual(pp[0].simplified.units, 1000*ms)
  241. self.assertEqual(len(pp), self.n)
  242. # Check the output types for input list of rates
  243. pp = stgen._n_poisson(rate=self.rates, t_stop=self.t_stop)
  244. self.assertIsInstance(pp, list)
  245. self.assertIsInstance(pp[0], neo.core.spiketrain.SpikeTrain)
  246. self.assertEqual(pp[0].simplified.units, 1000*ms)
  247. self.assertEqual(len(pp), self.n)
  248. def test_poisson_error(self):
  249. # Dimensionless rate
  250. self.assertRaises(
  251. ValueError, stgen._n_poisson, rate=5, t_stop=self.t_stop)
  252. # Negative rate
  253. self.assertRaises(
  254. ValueError, stgen._n_poisson, rate=-5*Hz, t_stop=self.t_stop)
  255. # Negative value when rate is a list
  256. self.assertRaises(
  257. ValueError, stgen._n_poisson, rate=[-5, 3]*Hz, t_stop=self.t_stop)
  258. # Negative n
  259. self.assertRaises(
  260. ValueError, stgen._n_poisson, rate=self.rate, t_stop=self.t_stop,
  261. n=-1)
  262. # t_start>t_stop
  263. self.assertRaises(
  264. ValueError, stgen._n_poisson, rate=self.rate, t_start=4*ms,
  265. t_stop=3*ms, n=3)
  266. class singleinteractionprocess_TestCase(unittest.TestCase):
  267. def setUp(self):
  268. self.n = 4
  269. self.rate = 10*Hz
  270. self.rates = range(1, self.n + 1)*Hz
  271. self.t_stop = 10000*ms
  272. self.rate_c = 1*Hz
  273. def test_sip(self):
  274. # Generate an example SIP mode
  275. sip, coinc = stgen.single_interaction_process(
  276. n=self.n, t_stop=self.t_stop, rate=self.rate,
  277. rate_c=self.rate_c, return_coinc=True)
  278. # Check the output types
  279. self.assertEqual(type(sip), list)
  280. self.assertEqual(type(sip[0]), neo.core.spiketrain.SpikeTrain)
  281. self.assertEqual(type(coinc[0]), neo.core.spiketrain.SpikeTrain)
  282. self.assertEqual(sip[0].simplified.units, 1000*ms)
  283. self.assertEqual(coinc[0].simplified.units, 1000*ms)
  284. # Check the output length
  285. self.assertEqual(len(sip), self.n)
  286. self.assertEqual(
  287. len(coinc[0]), (self.rate_c*self.t_stop).rescale(dimensionless))
  288. # Generate an example SIP mode giving a list of rates as imput
  289. sip, coinc = stgen.single_interaction_process(
  290. t_stop=self.t_stop, rate=self.rates,
  291. rate_c=self.rate_c, return_coinc=True)
  292. # Check the output types
  293. self.assertEqual(type(sip), list)
  294. self.assertEqual(type(sip[0]), neo.core.spiketrain.SpikeTrain)
  295. self.assertEqual(type(coinc[0]), neo.core.spiketrain.SpikeTrain)
  296. self.assertEqual(sip[0].simplified.units, 1000*ms)
  297. self.assertEqual(coinc[0].simplified.units, 1000*ms)
  298. # Check the output length
  299. self.assertEqual(len(sip), self.n)
  300. self.assertEqual(
  301. len(coinc[0]), (self.rate_c*self.t_stop).rescale(dimensionless))
  302. # Generate an example SIP mode stochastic number of coincidences
  303. sip = stgen.single_interaction_process(
  304. n=self.n, t_stop=self.t_stop, rate=self.rate,
  305. rate_c=self.rate_c, coincidences='stochastic', return_coinc=False)
  306. # Check the output types
  307. self.assertEqual(type(sip), list)
  308. self.assertEqual(type(sip[0]), neo.core.spiketrain.SpikeTrain)
  309. self.assertEqual(sip[0].simplified.units, 1000*ms)
  310. def test_sip_error(self):
  311. # Negative rate
  312. self.assertRaises(
  313. ValueError, stgen.single_interaction_process, n=self.n, rate=-5*Hz,
  314. rate_c=self.rate_c, t_stop=self.t_stop)
  315. # Negative coincidence rate
  316. self.assertRaises(
  317. ValueError, stgen.single_interaction_process, n=self.n,
  318. rate=self.rate, rate_c=-3*Hz, t_stop=self.t_stop)
  319. # Negative value when rate is a list
  320. self.assertRaises(
  321. ValueError, stgen.single_interaction_process, n=self.n,
  322. rate=[-5, 3, 4, 2]*Hz, rate_c=self.rate_c, t_stop=self.t_stop)
  323. # Negative n
  324. self.assertRaises(
  325. ValueError, stgen.single_interaction_process, n=-1,
  326. rate=self.rate, rate_c=self.rate_c, t_stop=self.t_stop)
  327. # Rate_c < rate
  328. self.assertRaises(
  329. ValueError, stgen.single_interaction_process, n=self.n,
  330. rate=self.rate, rate_c=self.rate + 1*Hz, t_stop=self.t_stop)
  331. class cppTestCase(unittest.TestCase):
  332. def test_cpp_hom(self):
  333. # testing output with generic inputs
  334. A = [0, .9, .1]
  335. t_stop = 10 * 1000 * ms
  336. t_start = 5 * 1000 * ms
  337. rate = 3 * Hz
  338. cpp_hom = stgen.cpp(rate, A, t_stop, t_start=t_start)
  339. # testing the ouput formats
  340. self.assertEqual(
  341. [type(train) for train in cpp_hom], [neo.SpikeTrain]*len(cpp_hom))
  342. self.assertEqual(cpp_hom[0].simplified.units, 1000 * ms)
  343. self.assertEqual(type(cpp_hom), list)
  344. # testing quantities format of the output
  345. self.assertEqual(
  346. [train.simplified.units for train in cpp_hom], [1000 * ms]*len(
  347. cpp_hom))
  348. # testing output t_start t_stop
  349. for st in cpp_hom:
  350. self.assertEqual(st.t_stop, t_stop)
  351. self.assertEqual(st.t_start, t_start)
  352. self.assertEqual(len(cpp_hom), len(A) - 1)
  353. # testing the units
  354. A = [0, 0.9, 0.1]
  355. t_stop = 10000*ms
  356. t_start = 5 * 1000 * ms
  357. rate = 3 * Hz
  358. cpp_unit = stgen.cpp(rate, A, t_stop, t_start=t_start)
  359. self.assertEqual(cpp_unit[0].units, t_stop.units)
  360. self.assertEqual(cpp_unit[0].t_stop.units, t_stop.units)
  361. self.assertEqual(cpp_unit[0].t_start.units, t_stop.units)
  362. # testing output without copy of spikes
  363. A = [1]
  364. t_stop = 10 * 1000 * ms
  365. t_start = 5 * 1000 * ms
  366. rate = 3 * Hz
  367. cpp_hom_empty = stgen.cpp(rate, A, t_stop, t_start=t_start)
  368. self.assertEqual(
  369. [len(train) for train in cpp_hom_empty], [0]*len(cpp_hom_empty))
  370. # testing output with rate equal to 0
  371. A = [0, .9, .1]
  372. t_stop = 10 * 1000 * ms
  373. t_start = 5 * 1000 * ms
  374. rate = 0 * Hz
  375. cpp_hom_empty_r = stgen.cpp(rate, A, t_stop, t_start=t_start)
  376. self.assertEqual(
  377. [len(train) for train in cpp_hom_empty_r], [0]*len(
  378. cpp_hom_empty_r))
  379. # testing output with same spike trains in output
  380. A = [0, 0, 1]
  381. t_stop = 10 * 1000 * ms
  382. t_start = 5 * 1000 * ms
  383. rate = 3 * Hz
  384. cpp_hom_eq = stgen.cpp(rate, A, t_stop, t_start=t_start)
  385. self.assertTrue(
  386. np.allclose(cpp_hom_eq[0].magnitude, cpp_hom_eq[1].magnitude))
  387. def test_cpp_hom_errors(self):
  388. # testing raises of ValueError (wrong inputs)
  389. # testing empty amplitude
  390. self.assertRaises(
  391. ValueError, stgen.cpp, A=[], t_stop=10*1000 * ms, rate=3*Hz)
  392. # testing sum of amplitude>1
  393. self.assertRaises(
  394. ValueError, stgen.cpp, A=[1, 1, 1], t_stop=10*1000 * ms, rate=3*Hz)
  395. # testing negative value in the amplitude
  396. self.assertRaises(
  397. ValueError, stgen.cpp, A=[-1, 1, 1], t_stop=10*1000 * ms,
  398. rate=3*Hz)
  399. # test negative rate
  400. self.assertRaises(
  401. AssertionError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  402. rate=-3*Hz)
  403. # test wrong unit for rate
  404. self.assertRaises(
  405. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  406. rate=3*1000 * ms)
  407. # testing raises of AttributeError (missing input units)
  408. # Testing missing unit to t_stop
  409. self.assertRaises(
  410. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10, rate=3*Hz)
  411. # Testing missing unit to t_start
  412. self.assertRaises(
  413. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms, rate=3*Hz,
  414. t_start=3)
  415. # testing rate missing unit
  416. self.assertRaises(
  417. AttributeError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  418. rate=3)
  419. def test_cpp_het(self):
  420. # testing output with generic inputs
  421. A = [0, .9, .1]
  422. t_stop = 10 * 1000 * ms
  423. t_start = 5 * 1000 * ms
  424. rate = [3, 4] * Hz
  425. cpp_het = stgen.cpp(rate, A, t_stop, t_start=t_start)
  426. # testing the ouput formats
  427. self.assertEqual(
  428. [type(train) for train in cpp_het], [neo.SpikeTrain]*len(cpp_het))
  429. self.assertEqual(cpp_het[0].simplified.units, 1000 * ms)
  430. self.assertEqual(type(cpp_het), list)
  431. # testing units
  432. self.assertEqual(
  433. [train.simplified.units for train in cpp_het], [1000 * ms]*len(
  434. cpp_het))
  435. # testing output t_start and t_stop
  436. for st in cpp_het:
  437. self.assertEqual(st.t_stop, t_stop)
  438. self.assertEqual(st.t_start, t_start)
  439. # testing the number of output spiketrains
  440. self.assertEqual(len(cpp_het), len(A) - 1)
  441. self.assertEqual(len(cpp_het), len(rate))
  442. # testing the units
  443. A = [0, 0.9, 0.1]
  444. t_stop = 10000*ms
  445. t_start = 5 * 1000 * ms
  446. rate = [3, 4] * Hz
  447. cpp_unit = stgen.cpp(rate, A, t_stop, t_start=t_start)
  448. self.assertEqual(cpp_unit[0].units, t_stop.units)
  449. self.assertEqual(cpp_unit[0].t_stop.units, t_stop.units)
  450. self.assertEqual(cpp_unit[0].t_start.units, t_stop.units)
  451. # testing without copying any spikes
  452. A = [1, 0, 0]
  453. t_stop = 10 * 1000 * ms
  454. t_start = 5 * 1000 * ms
  455. rate = [3, 4] * Hz
  456. cpp_het_empty = stgen.cpp(rate, A, t_stop, t_start=t_start)
  457. self.assertEqual(len(cpp_het_empty[0]), 0)
  458. # testing output with rate equal to 0
  459. A = [0, .9, .1]
  460. t_stop = 10 * 1000 * ms
  461. t_start = 5 * 1000 * ms
  462. rate = [0, 0] * Hz
  463. cpp_het_empty_r = stgen.cpp(rate, A, t_stop, t_start=t_start)
  464. self.assertEqual(
  465. [len(train) for train in cpp_het_empty_r], [0]*len(
  466. cpp_het_empty_r))
  467. # testing completely sync spiketrains
  468. A = [0, 0, 1]
  469. t_stop = 10 * 1000 * ms
  470. t_start = 5 * 1000 * ms
  471. rate = [3, 3] * Hz
  472. cpp_het_eq = stgen.cpp(rate, A, t_stop, t_start=t_start)
  473. self.assertTrue(np.allclose(
  474. cpp_het_eq[0].magnitude, cpp_het_eq[1].magnitude))
  475. def test_cpp_het_err(self):
  476. # testing raises of ValueError (wrong inputs)
  477. # testing empty amplitude
  478. self.assertRaises(
  479. ValueError, stgen.cpp, A=[], t_stop=10*1000 * ms, rate=[3, 4]*Hz)
  480. # testing sum amplitude>1
  481. self.assertRaises(
  482. ValueError, stgen.cpp, A=[1, 1, 1], t_stop=10*1000 * ms,
  483. rate=[3, 4]*Hz)
  484. # testing amplitude negative value
  485. self.assertRaises(
  486. ValueError, stgen.cpp, A=[-1, 1, 1], t_stop=10*1000 * ms,
  487. rate=[3, 4]*Hz)
  488. # testing negative rate
  489. self.assertRaises(
  490. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  491. rate=[-3, 4]*Hz)
  492. # testing empty rate
  493. self.assertRaises(
  494. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms, rate=[]*Hz)
  495. # testing empty amplitude
  496. self.assertRaises(
  497. ValueError, stgen.cpp, A=[], t_stop=10*1000 * ms, rate=[3, 4]*Hz)
  498. # testing different len(A)-1 and len(rate)
  499. self.assertRaises(
  500. ValueError, stgen.cpp, A=[0, 1], t_stop=10*1000 * ms, rate=[3, 4]*Hz)
  501. # testing rate with different unit from Hz
  502. self.assertRaises(
  503. ValueError, stgen.cpp, A=[0, 1], t_stop=10*1000 * ms,
  504. rate=[3, 4]*1000 * ms)
  505. # Testing analytical constrain between amplitude and rate
  506. self.assertRaises(
  507. ValueError, stgen.cpp, A=[0, 0, 1], t_stop=10*1000 * ms,
  508. rate=[3, 4]*Hz, t_start=3)
  509. # testing raises of AttributeError (missing input units)
  510. # Testing missing unit to t_stop
  511. self.assertRaises(
  512. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10, rate=[3, 4]*Hz)
  513. # Testing missing unit to t_start
  514. self.assertRaises(
  515. ValueError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  516. rate=[3, 4]*Hz, t_start=3)
  517. # Testing missing unit to rate
  518. self.assertRaises(
  519. AttributeError, stgen.cpp, A=[0, 1, 0], t_stop=10*1000 * ms,
  520. rate=[3, 4])
  521. def test_cpp_jttered(self):
  522. # testing output with generic inputs
  523. A = [0, .9, .1]
  524. t_stop = 10 * 1000 * ms
  525. t_start = 5 * 1000 * ms
  526. rate = 3 * Hz
  527. cpp_shift = stgen.cpp(
  528. rate, A, t_stop, t_start=t_start, shift=3*ms)
  529. # testing the ouput formats
  530. self.assertEqual(
  531. [type(train) for train in cpp_shift], [neo.SpikeTrain]*len(
  532. cpp_shift))
  533. self.assertEqual(cpp_shift[0].simplified.units, 1000 * ms)
  534. self.assertEqual(type(cpp_shift), list)
  535. # testing quantities format of the output
  536. self.assertEqual(
  537. [train.simplified.units for train in cpp_shift],
  538. [1000 * ms]*len(cpp_shift))
  539. # testing output t_start t_stop
  540. for st in cpp_shift:
  541. self.assertEqual(st.t_stop, t_stop)
  542. self.assertEqual(st.t_start, t_start)
  543. self.assertEqual(len(cpp_shift), len(A) - 1)
  544. if __name__ == '__main__':
  545. unittest.main()