test_utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. """
  2. Tests of the neo.utils module
  3. """
  4. import unittest
  5. import numpy as np
  6. import quantities as pq
  7. from neo.rawio.examplerawio import ExampleRawIO
  8. from neo.io.proxyobjects import (AnalogSignalProxy, SpikeTrainProxy,
  9. EventProxy, EpochProxy)
  10. from neo.core.dataobject import ArrayDict
  11. from neo.core import (Block, Segment, AnalogSignal, IrregularlySampledSignal,
  12. Epoch, Event, SpikeTrain)
  13. from neo.test.tools import (assert_arrays_almost_equal,
  14. assert_arrays_equal,
  15. assert_neo_object_is_compliant,
  16. assert_same_attributes,
  17. assert_same_annotations)
  18. from neo.utils import (get_events, get_epochs, add_epoch, match_events, cut_block_by_epochs)
  19. class BaseProxyTest(unittest.TestCase):
  20. def setUp(self):
  21. self.reader = ExampleRawIO(filename='my_filename.fake')
  22. self.reader.parse_header()
  23. class TestUtilsWithoutProxyObjects(unittest.TestCase):
  24. def test__get_events(self):
  25. starts_1 = Event(times=[0.5, 10.0, 25.2] * pq.s,
  26. labels=['label1', 'label2', 'label3'],
  27. name='pick_me')
  28. starts_1.annotate(event_type='trial start')
  29. starts_1.array_annotate(trial_id=[1, 2, 3])
  30. stops_1 = Event(times=[5.5, 14.9, 30.1] * pq.s)
  31. stops_1.annotate(event_type='trial stop')
  32. stops_1.array_annotate(trial_id=[1, 2, 3])
  33. starts_2 = Event(times=[33.2, 41.7, 52.4] * pq.s)
  34. starts_2.annotate(event_type='trial start')
  35. starts_2.array_annotate(trial_id=[4, 5, 6])
  36. stops_2 = Event(times=[37.6, 46.1, 57.0] * pq.s)
  37. stops_2.annotate(event_type='trial stop')
  38. stops_2.array_annotate(trial_id=[4, 5, 6])
  39. seg = Segment()
  40. seg2 = Segment()
  41. seg.events = [starts_1, stops_1]
  42. seg2.events = [starts_2, stops_2]
  43. block = Block()
  44. block.segments = [seg, seg2]
  45. # test getting one whole event via annotation or attribute
  46. extracted_starts1 = get_events(seg, event_type='trial start')
  47. extracted_starts1b = get_events(block, name='pick_me')
  48. self.assertEqual(len(extracted_starts1), 1)
  49. self.assertEqual(len(extracted_starts1b), 1)
  50. extracted_starts1 = extracted_starts1[0]
  51. extracted_starts1b = extracted_starts1b[0]
  52. assert_same_attributes(extracted_starts1, starts_1)
  53. assert_same_attributes(extracted_starts1b, starts_1)
  54. # test getting an empty list by searching for a non-existent property
  55. empty1 = get_events(seg, foo='bar')
  56. self.assertEqual(len(empty1), 0)
  57. # test getting an empty list by searching for a non-existent property value
  58. empty2 = get_events(seg, event_type='undefined')
  59. self.assertEqual(len(empty2), 0)
  60. # test getting only one event time of one event
  61. trial_2 = get_events(block, trial_id=2, event_type='trial start')
  62. self.assertEqual(len(trial_2), 1)
  63. trial_2 = trial_2[0]
  64. self.assertEqual(starts_1.name, trial_2.name)
  65. self.assertEqual(starts_1.description, trial_2.description)
  66. self.assertEqual(starts_1.file_origin, trial_2.file_origin)
  67. self.assertEqual(starts_1.annotations['event_type'], trial_2.annotations['event_type'])
  68. assert_arrays_equal(trial_2.array_annotations['trial_id'], np.array([2]))
  69. self.assertIsInstance(trial_2.array_annotations, ArrayDict)
  70. # test getting only one event time of more than one event
  71. trial_2b = get_events(block, trial_id=2)
  72. self.assertEqual(len(trial_2b), 2)
  73. start_idx = np.where(np.array([ev.annotations['event_type']
  74. for ev in trial_2b]) == 'trial start')[0][0]
  75. trial_2b_start = trial_2b[start_idx]
  76. trial_2b_stop = trial_2b[start_idx - 1]
  77. assert_same_attributes(trial_2b_start, trial_2)
  78. self.assertEqual(stops_1.name, trial_2b_stop.name)
  79. self.assertEqual(stops_1.description, trial_2b_stop.description)
  80. self.assertEqual(stops_1.file_origin, trial_2b_stop.file_origin)
  81. self.assertEqual(stops_1.annotations['event_type'],
  82. trial_2b_stop.annotations['event_type'])
  83. assert_arrays_equal(trial_2b_stop.array_annotations['trial_id'], np.array([2]))
  84. self.assertIsInstance(trial_2b_stop.array_annotations, ArrayDict)
  85. # test getting more than one event time of one event
  86. trials_1_2 = get_events(block, trial_id=[1, 2], event_type='trial start')
  87. self.assertEqual(len(trials_1_2), 1)
  88. trials_1_2 = trials_1_2[0]
  89. self.assertEqual(starts_1.name, trials_1_2.name)
  90. self.assertEqual(starts_1.description, trials_1_2.description)
  91. self.assertEqual(starts_1.file_origin, trials_1_2.file_origin)
  92. self.assertEqual(starts_1.annotations['event_type'], trials_1_2.annotations['event_type'])
  93. assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2]))
  94. self.assertIsInstance(trials_1_2.array_annotations, ArrayDict)
  95. # test selecting event times by label
  96. trials_1_2 = get_events(block, labels=['label1', 'label2'])
  97. self.assertEqual(len(trials_1_2), 1)
  98. trials_1_2 = trials_1_2[0]
  99. self.assertEqual(starts_1.name, trials_1_2.name)
  100. self.assertEqual(starts_1.description, trials_1_2.description)
  101. self.assertEqual(starts_1.file_origin, trials_1_2.file_origin)
  102. self.assertEqual(starts_1.annotations['event_type'], trials_1_2.annotations['event_type'])
  103. assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2]))
  104. self.assertIsInstance(trials_1_2.array_annotations, ArrayDict)
  105. # test getting more than one event time of more than one event
  106. trials_1_2b = get_events(block, trial_id=[1, 2])
  107. self.assertEqual(len(trials_1_2b), 2)
  108. start_idx = np.where(np.array([ev.annotations['event_type']
  109. for ev in trials_1_2b]) == 'trial start')[0][0]
  110. trials_1_2b_start = trials_1_2b[start_idx]
  111. trials_1_2b_stop = trials_1_2b[start_idx - 1]
  112. assert_same_attributes(trials_1_2b_start, trials_1_2)
  113. self.assertEqual(stops_1.name, trials_1_2b_stop.name)
  114. self.assertEqual(stops_1.description, trials_1_2b_stop.description)
  115. self.assertEqual(stops_1.file_origin, trials_1_2b_stop.file_origin)
  116. self.assertEqual(stops_1.annotations['event_type'],
  117. trials_1_2b_stop.annotations['event_type'])
  118. assert_arrays_equal(trials_1_2b_stop.array_annotations['trial_id'], np.array([1, 2]))
  119. self.assertIsInstance(trials_1_2b_stop.array_annotations, ArrayDict)
  120. def test__get_epochs(self):
  121. a_1 = Epoch([0.5, 10.0, 25.2] * pq.s, durations=[5.1, 4.8, 5.0] * pq.s)
  122. a_1.annotate(epoch_type='a', pick='me')
  123. a_1.array_annotate(trial_id=[1, 2, 3])
  124. b_1 = Epoch([5.5, 14.9, 30.1] * pq.s, durations=[4.7, 4.9, 5.2] * pq.s)
  125. b_1.annotate(epoch_type='b')
  126. b_1.array_annotate(trial_id=[1, 2, 3])
  127. a_2 = Epoch([33.2, 41.7, 52.4] * pq.s, durations=[5.3, 5.0, 5.1] * pq.s)
  128. a_2.annotate(epoch_type='a')
  129. a_2.array_annotate(trial_id=[4, 5, 6])
  130. b_2 = Epoch([37.6, 46.1, 57.0] * pq.s, durations=[4.9, 5.2, 5.1] * pq.s)
  131. b_2.annotate(epoch_type='b')
  132. b_2.array_annotate(trial_id=[4, 5, 6])
  133. seg = Segment()
  134. seg2 = Segment()
  135. seg.epochs = [a_1, b_1]
  136. seg2.epochs = [a_2, b_2]
  137. block = Block()
  138. block.segments = [seg, seg2]
  139. # test getting one whole event via annotation
  140. extracted_a_1 = get_epochs(seg, epoch_type='a')
  141. extracted_a_1b = get_epochs(block, pick='me')
  142. self.assertEqual(len(extracted_a_1), 1)
  143. self.assertEqual(len(extracted_a_1b), 1)
  144. extracted_a_1 = extracted_a_1[0]
  145. extracted_a_1b = extracted_a_1b[0]
  146. assert_same_attributes(extracted_a_1, a_1)
  147. assert_same_attributes(extracted_a_1b, a_1)
  148. # test getting an empty list by searching for a non-existent property
  149. empty1 = get_epochs(seg, foo='bar')
  150. self.assertEqual(len(empty1), 0)
  151. # test getting an empty list by searching for a non-existent property value
  152. empty2 = get_epochs(seg, epoch_type='undefined')
  153. self.assertEqual(len(empty2), 0)
  154. # test getting only one event time of one event
  155. trial_2 = get_epochs(block, trial_id=2, epoch_type='a')
  156. self.assertEqual(len(trial_2), 1)
  157. trial_2 = trial_2[0]
  158. self.assertEqual(a_1.name, trial_2.name)
  159. self.assertEqual(a_1.description, trial_2.description)
  160. self.assertEqual(a_1.file_origin, trial_2.file_origin)
  161. self.assertEqual(a_1.annotations['epoch_type'], trial_2.annotations['epoch_type'])
  162. assert_arrays_equal(trial_2.array_annotations['trial_id'], np.array([2]))
  163. self.assertIsInstance(trial_2.array_annotations, ArrayDict)
  164. # test getting only one event time of more than one event
  165. trial_2b = get_epochs(block, trial_id=2)
  166. self.assertEqual(len(trial_2b), 2)
  167. a_idx = np.where(np.array([ev.annotations['epoch_type'] for ev in trial_2b]) == 'a')[0][0]
  168. trial_2b_a = trial_2b[a_idx]
  169. trial_2b_b = trial_2b[a_idx - 1]
  170. assert_same_attributes(trial_2b_a, trial_2)
  171. self.assertEqual(b_1.name, trial_2b_b.name)
  172. self.assertEqual(b_1.description, trial_2b_b.description)
  173. self.assertEqual(b_1.file_origin, trial_2b_b.file_origin)
  174. self.assertEqual(b_1.annotations['epoch_type'], trial_2b_b.annotations['epoch_type'])
  175. assert_arrays_equal(trial_2b_b.array_annotations['trial_id'], np.array([2]))
  176. self.assertIsInstance(trial_2b_b.array_annotations, ArrayDict)
  177. # test getting more than one event time of one event
  178. trials_1_2 = get_epochs(block, trial_id=[1, 2], epoch_type='a')
  179. self.assertEqual(len(trials_1_2), 1)
  180. trials_1_2 = trials_1_2[0]
  181. self.assertEqual(a_1.name, trials_1_2.name)
  182. self.assertEqual(a_1.description, trials_1_2.description)
  183. self.assertEqual(a_1.file_origin, trials_1_2.file_origin)
  184. self.assertEqual(a_1.annotations['epoch_type'], trials_1_2.annotations['epoch_type'])
  185. assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2]))
  186. self.assertIsInstance(trials_1_2.array_annotations, ArrayDict)
  187. # test getting more than one event time of more than one event
  188. trials_1_2b = get_epochs(block, trial_id=[1, 2])
  189. self.assertEqual(len(trials_1_2b), 2)
  190. a_idx = np.where(np.array([ev.annotations['epoch_type']
  191. for ev in trials_1_2b]) == 'a')[0][0]
  192. trials_1_2b_a = trials_1_2b[a_idx]
  193. trials_1_2b_b = trials_1_2b[a_idx - 1]
  194. assert_same_attributes(trials_1_2b_a, trials_1_2)
  195. self.assertEqual(b_1.name, trials_1_2b_b.name)
  196. self.assertEqual(b_1.description, trials_1_2b_b.description)
  197. self.assertEqual(b_1.file_origin, trials_1_2b_b.file_origin)
  198. self.assertEqual(b_1.annotations['epoch_type'], trials_1_2b_b.annotations['epoch_type'])
  199. assert_arrays_equal(trials_1_2b_b.array_annotations['trial_id'], np.array([1, 2]))
  200. self.assertIsInstance(trials_1_2b_b.array_annotations, ArrayDict)
  201. def test__add_epoch(self):
  202. starts = Event(times=[0.5, 10.0, 25.2] * pq.s)
  203. starts.annotate(event_type='trial start')
  204. starts.array_annotate(trial_id=[1, 2, 3])
  205. stops = Event(times=[5.5, 14.9, 30.1] * pq.s)
  206. stops.annotate(event_type='trial stop')
  207. stops.array_annotate(trial_id=[1, 2, 3])
  208. seg = Segment()
  209. seg.events = [starts, stops]
  210. # test cutting with one event only
  211. ep_starts = add_epoch(seg, starts, pre=-300 * pq.ms, post=250 * pq.ms)
  212. assert_neo_object_is_compliant(ep_starts)
  213. assert_same_annotations(ep_starts, starts)
  214. assert_arrays_almost_equal(ep_starts.times, starts.times - 300 * pq.ms, 1e-12)
  215. assert_arrays_almost_equal(ep_starts.durations,
  216. (550 * pq.ms).rescale(ep_starts.durations.units)
  217. * np.ones(len(starts)), 1e-12)
  218. # test cutting with two events
  219. ep_trials = add_epoch(seg, starts, stops)
  220. assert_neo_object_is_compliant(ep_trials)
  221. assert_same_annotations(ep_trials, starts)
  222. assert_arrays_almost_equal(ep_trials.times, starts.times, 1e-12)
  223. assert_arrays_almost_equal(ep_trials.durations, stops - starts, 1e-12)
  224. def test__match_events(self):
  225. starts = Event(times=[0.5, 10.0, 25.2] * pq.s)
  226. starts.annotate(event_type='trial start')
  227. starts.array_annotate(trial_id=[1, 2, 3])
  228. stops = Event(times=[5.5, 14.9, 30.1] * pq.s)
  229. stops.annotate(event_type='trial stop')
  230. stops.array_annotate(trial_id=[1, 2, 3])
  231. stops2 = Event(times=[0.1, 5.5, 5.6, 14.9, 25.2, 30.1] * pq.s)
  232. stops2.annotate(event_type='trial stop')
  233. stops2.array_annotate(trial_id=[1, 1, 2, 2, 3, 3])
  234. # test for matching input events, should just return identical copies
  235. matched_starts, matched_stops = match_events(starts, stops)
  236. assert_same_attributes(matched_starts, starts)
  237. assert_same_attributes(matched_stops, stops)
  238. # test for non-matching input events, should find shortest positive non-zero durations
  239. matched_starts2, matched_stops2 = match_events(starts, stops2)
  240. assert_same_attributes(matched_starts2, starts)
  241. assert_same_attributes(matched_stops2, stops)
  242. def test__cut_block_by_epochs(self):
  243. epoch = Epoch([0.5, 10.0, 25.2] * pq.s, durations=[5.1, 4.8, 5.0] * pq.s,
  244. t_start=.1 * pq.s)
  245. epoch.annotate(epoch_type='a', pick='me')
  246. epoch.array_annotate(trial_id=[1, 2, 3])
  247. epoch2 = Epoch([0.6, 9.5, 16.8, 34.1] * pq.s, durations=[4.5, 4.8, 5.0, 5.0] * pq.s,
  248. t_start=.1 * pq.s)
  249. epoch2.annotate(epoch_type='b')
  250. epoch2.array_annotate(trial_id=[1, 2, 3, 4])
  251. event = Event(times=[0.5, 10.0, 25.2] * pq.s, t_start=.1 * pq.s)
  252. event.annotate(event_type='trial start')
  253. event.array_annotate(trial_id=[1, 2, 3])
  254. anasig = AnalogSignal(np.arange(50.0) * pq.mV, t_start=.1 * pq.s,
  255. sampling_rate=1.0 * pq.Hz)
  256. irrsig = IrregularlySampledSignal(signal=np.arange(50.0) * pq.mV,
  257. times=anasig.times, t_start=.1 * pq.s)
  258. st = SpikeTrain(np.arange(0.5, 50, 7) * pq.s, t_start=.1 * pq.s, t_stop=50.0 * pq.s,
  259. waveforms=np.array([[[0., 1.], [0.1, 1.1]], [[2., 3.], [2.1, 3.1]],
  260. [[4., 5.], [4.1, 5.1]], [[6., 7.], [6.1, 7.1]],
  261. [[8., 9.], [8.1, 9.1]], [[12., 13.], [12.1, 13.1]],
  262. [[14., 15.], [14.1, 15.1]],
  263. [[16., 17.], [16.1, 17.1]]]) * pq.mV,
  264. array_annotations={'spikenum': np.arange(1, 9)})
  265. # test without resetting the time
  266. seg = Segment()
  267. seg2 = Segment(name='NoCut')
  268. seg.epochs = [epoch, epoch2]
  269. seg.events = [event]
  270. seg.analogsignals = [anasig]
  271. seg.irregularlysampledsignals = [irrsig]
  272. seg.spiketrains = [st]
  273. original_block = Block()
  274. original_block.segments = [seg, seg2]
  275. original_block.create_many_to_one_relationship()
  276. block = cut_block_by_epochs(original_block, properties={'pick': 'me'})
  277. assert_neo_object_is_compliant(block)
  278. self.assertEqual(len(block.segments), 3)
  279. for epoch_idx in range(len(epoch)):
  280. self.assertEqual(len(block.segments[epoch_idx].events), 1)
  281. self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
  282. self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
  283. self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1)
  284. if epoch_idx != 0:
  285. self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
  286. else:
  287. self.assertEqual(len(block.segments[epoch_idx].epochs), 2)
  288. assert_same_attributes(block.segments[epoch_idx].spiketrains[0],
  289. st.time_slice(t_start=epoch.times[epoch_idx],
  290. t_stop=epoch.times[epoch_idx]
  291. + epoch.durations[epoch_idx]))
  292. assert_same_attributes(block.segments[epoch_idx].analogsignals[0],
  293. anasig.time_slice(t_start=epoch.times[epoch_idx],
  294. t_stop=epoch.times[epoch_idx]
  295. + epoch.durations[epoch_idx]))
  296. assert_same_attributes(block.segments[epoch_idx].irregularlysampledsignals[0],
  297. irrsig.time_slice(t_start=epoch.times[epoch_idx],
  298. t_stop=epoch.times[epoch_idx]
  299. + epoch.durations[epoch_idx]))
  300. assert_same_attributes(block.segments[epoch_idx].events[0],
  301. event.time_slice(t_start=epoch.times[epoch_idx],
  302. t_stop=epoch.times[epoch_idx]
  303. + epoch.durations[epoch_idx]))
  304. assert_same_attributes(block.segments[0].epochs[0],
  305. epoch.time_slice(t_start=epoch.times[0],
  306. t_stop=epoch.times[0] + epoch.durations[0]))
  307. assert_same_attributes(block.segments[0].epochs[1],
  308. epoch2.time_slice(t_start=epoch.times[0],
  309. t_stop=epoch.times[0] + epoch.durations[0]))
  310. # test with resetting the time
  311. seg = Segment()
  312. seg2 = Segment(name='NoCut')
  313. seg.epochs = [epoch, epoch2]
  314. seg.events = [event]
  315. seg.analogsignals = [anasig]
  316. seg.irregularlysampledsignals = [irrsig]
  317. seg.spiketrains = [st]
  318. original_block = Block()
  319. original_block.segments = [seg, seg2]
  320. original_block.create_many_to_one_relationship()
  321. block = cut_block_by_epochs(original_block, properties={'pick': 'me'}, reset_time=True)
  322. assert_neo_object_is_compliant(block)
  323. self.assertEqual(len(block.segments), 3)
  324. for epoch_idx in range(len(epoch)):
  325. self.assertEqual(len(block.segments[epoch_idx].events), 1)
  326. self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
  327. self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
  328. self.assertEqual(len(block.segments[epoch_idx].irregularlysampledsignals), 1)
  329. if epoch_idx != 0:
  330. self.assertEqual(len(block.segments[epoch_idx].epochs), 1)
  331. else:
  332. self.assertEqual(len(block.segments[epoch_idx].epochs), 2)
  333. assert_same_attributes(block.segments[epoch_idx].spiketrains[0],
  334. st.time_shift(- epoch.times[epoch_idx]).time_slice(
  335. t_start=0 * pq.s, t_stop=epoch.durations[epoch_idx]))
  336. anasig_target = anasig.time_shift(- epoch.times[epoch_idx])
  337. anasig_target = anasig_target.time_slice(t_start=0 * pq.s,
  338. t_stop=epoch.durations[epoch_idx])
  339. assert_same_attributes(block.segments[epoch_idx].analogsignals[0], anasig_target)
  340. irrsig_target = irrsig.time_shift(- epoch.times[epoch_idx])
  341. irrsig_target = irrsig_target.time_slice(t_start=0 * pq.s,
  342. t_stop=epoch.durations[epoch_idx])
  343. assert_same_attributes(block.segments[epoch_idx].irregularlysampledsignals[0],
  344. irrsig_target)
  345. assert_same_attributes(block.segments[epoch_idx].events[0],
  346. event.time_shift(- epoch.times[epoch_idx]).time_slice(
  347. t_start=0 * pq.s, t_stop=epoch.durations[epoch_idx]))
  348. assert_same_attributes(block.segments[0].epochs[0],
  349. epoch.time_shift(- epoch.times[0]).time_slice(t_start=0 * pq.s,
  350. t_stop=epoch.durations[0]))
  351. assert_same_attributes(block.segments[0].epochs[1],
  352. epoch2.time_shift(- epoch.times[0]).time_slice(t_start=0 * pq.s,
  353. t_stop=epoch.durations[0]))
  354. class TestUtilsWithProxyObjects(BaseProxyTest):
  355. def test__get_events(self):
  356. starts_1 = Event(times=[0.5, 10.0, 25.2] * pq.s)
  357. starts_1.annotate(event_type='trial start', pick='me')
  358. starts_1.array_annotate(trial_id=[1, 2, 3])
  359. stops_1 = Event(times=[5.5, 14.9, 30.1] * pq.s)
  360. stops_1.annotate(event_type='trial stop')
  361. stops_1.array_annotate(trial_id=[1, 2, 3])
  362. proxy_event = EventProxy(rawio=self.reader, event_channel_index=0,
  363. block_index=0, seg_index=0)
  364. proxy_event.annotate(event_type='trial start')
  365. seg = Segment()
  366. seg.events = [starts_1, stops_1, proxy_event]
  367. # test getting multiple events including a proxy
  368. extracted_starts = get_events(seg, event_type='trial start')
  369. self.assertEqual(len(extracted_starts), 2)
  370. # make sure the event is loaded and a neo.Event object is returned
  371. self.assertTrue(isinstance(extracted_starts[0], Event))
  372. self.assertTrue(isinstance(extracted_starts[1], Event))
  373. def test__get_epochs(self):
  374. a = Epoch([0.5, 10.0, 25.2] * pq.s, durations=[5.1, 4.8, 5.0] * pq.s)
  375. a.annotate(epoch_type='a', pick='me')
  376. a.array_annotate(trial_id=[1, 2, 3])
  377. b = Epoch([5.5, 14.9, 30.1] * pq.s, durations=[4.7, 4.9, 5.2] * pq.s)
  378. b.annotate(epoch_type='b')
  379. b.array_annotate(trial_id=[1, 2, 3])
  380. proxy_epoch = EpochProxy(rawio=self.reader, event_channel_index=1,
  381. block_index=0, seg_index=0)
  382. proxy_epoch.annotate(epoch_type='a')
  383. seg = Segment()
  384. seg.epochs = [a, b, proxy_epoch]
  385. # test getting multiple epochs including a proxy
  386. extracted_epochs = get_epochs(seg, epoch_type='a')
  387. self.assertEqual(len(extracted_epochs), 2)
  388. # make sure the epoch is loaded and a neo.Epoch object is returned
  389. self.assertTrue(isinstance(extracted_epochs[0], Epoch))
  390. self.assertTrue(isinstance(extracted_epochs[1], Epoch))
  391. def test__add_epoch(self):
  392. proxy_event = EventProxy(rawio=self.reader, event_channel_index=0,
  393. block_index=0, seg_index=0)
  394. loaded_event = proxy_event.load()
  395. regular_event = Event(times=loaded_event.times - 1 * loaded_event.units)
  396. seg = Segment()
  397. seg.events = [regular_event, proxy_event]
  398. # test cutting with two events one of which is a proxy
  399. epoch = add_epoch(seg, regular_event, proxy_event)
  400. assert_neo_object_is_compliant(epoch)
  401. assert_same_annotations(epoch, regular_event)
  402. assert_arrays_almost_equal(epoch.times, regular_event.times, 1e-12)
  403. assert_arrays_almost_equal(epoch.durations,
  404. np.ones(regular_event.shape) * loaded_event.units, 1e-12)
  405. def test__match_events(self):
  406. proxy_event = EventProxy(rawio=self.reader, event_channel_index=0,
  407. block_index=0, seg_index=0)
  408. loaded_event = proxy_event.load()
  409. regular_event = Event(times=loaded_event.times - 1 * loaded_event.units,
  410. labels=np.array(['trigger_a', 'trigger_b'] * 3, dtype='U12'))
  411. seg = Segment()
  412. seg.events = [regular_event, proxy_event]
  413. # test matching two events one of which is a proxy
  414. matched_regular, matched_proxy = match_events(regular_event, proxy_event)
  415. assert_same_attributes(matched_regular, regular_event)
  416. assert_same_attributes(matched_proxy, loaded_event)
  417. def test__cut_block_by_epochs(self):
  418. seg = Segment()
  419. proxy_anasig = AnalogSignalProxy(rawio=self.reader,
  420. global_channel_indexes=None,
  421. block_index=0, seg_index=0)
  422. seg.analogsignals.append(proxy_anasig)
  423. proxy_st = SpikeTrainProxy(rawio=self.reader, unit_index=0,
  424. block_index=0, seg_index=0)
  425. seg.spiketrains.append(proxy_st)
  426. proxy_event = EventProxy(rawio=self.reader, event_channel_index=0,
  427. block_index=0, seg_index=0)
  428. seg.events.append(proxy_event)
  429. proxy_epoch = EpochProxy(rawio=self.reader, event_channel_index=1,
  430. block_index=0, seg_index=0)
  431. proxy_epoch.annotate(pick='me')
  432. seg.epochs.append(proxy_epoch)
  433. loaded_epoch = proxy_epoch.load()
  434. loaded_event = proxy_event.load()
  435. loaded_st = proxy_st.load()
  436. loaded_anasig = proxy_anasig.load()
  437. original_block = Block()
  438. original_block.segments = [seg]
  439. original_block.create_many_to_one_relationship()
  440. block = cut_block_by_epochs(original_block, properties={'pick': 'me'})
  441. assert_neo_object_is_compliant(block)
  442. self.assertEqual(len(block.segments), proxy_epoch.shape[0])
  443. for epoch_idx in range(len(loaded_epoch)):
  444. sliced_event = loaded_event.time_slice(t_start=loaded_epoch.times[epoch_idx],
  445. t_stop=loaded_epoch.times[epoch_idx]
  446. + loaded_epoch.durations[epoch_idx])
  447. has_event = len(sliced_event) > 0
  448. sliced_anasig = loaded_anasig.time_slice(t_start=loaded_epoch.times[epoch_idx],
  449. t_stop=loaded_epoch.times[epoch_idx]
  450. + loaded_epoch.durations[epoch_idx])
  451. sliced_st = loaded_st.time_slice(t_start=loaded_epoch.times[epoch_idx],
  452. t_stop=loaded_epoch.times[epoch_idx]
  453. + loaded_epoch.durations[epoch_idx])
  454. self.assertEqual(len(block.segments[epoch_idx].events), int(has_event))
  455. self.assertEqual(len(block.segments[epoch_idx].spiketrains), 1)
  456. self.assertEqual(len(block.segments[epoch_idx].analogsignals), 1)
  457. self.assertTrue(isinstance(block.segments[epoch_idx].spiketrains[0],
  458. SpikeTrain))
  459. assert_same_attributes(block.segments[epoch_idx].spiketrains[0],
  460. sliced_st)
  461. self.assertTrue(isinstance(block.segments[epoch_idx].analogsignals[0],
  462. AnalogSignal))
  463. assert_same_attributes(block.segments[epoch_idx].analogsignals[0],
  464. sliced_anasig)
  465. if has_event:
  466. self.assertTrue(isinstance(block.segments[epoch_idx].events[0],
  467. Event))
  468. assert_same_attributes(block.segments[epoch_idx].events[0],
  469. sliced_event)
  470. block2 = Block()
  471. seg2 = Segment()
  472. epoch = Epoch(np.arange(10) * pq.s, durations=np.ones(10) * pq.s)
  473. epoch.annotate(pick='me instead')
  474. seg2.epochs = [proxy_epoch, epoch]
  475. block2.segments = [seg2]
  476. block2.create_many_to_one_relationship()
  477. # test correct loading and slicing of EpochProxy objects
  478. # (not tested above since we used the EpochProxy to cut the block)
  479. block3 = cut_block_by_epochs(block2, properties={'pick': 'me instead'})
  480. for epoch_idx in range(len(epoch)):
  481. sliced_epoch = loaded_epoch.time_slice(t_start=epoch.times[epoch_idx],
  482. t_stop=epoch.times[epoch_idx]
  483. + epoch.durations[epoch_idx])
  484. has_epoch = len(sliced_epoch) > 0
  485. if has_epoch:
  486. self.assertTrue(isinstance(block3.segments[epoch_idx].epochs[0],
  487. Epoch))
  488. assert_same_attributes(block3.segments[epoch_idx].epochs[0],
  489. sliced_epoch)
# Run the tests through unittest's command-line runner when this file is
# executed directly as a script.
if __name__ == "__main__":
    unittest.main()