test_proxyobjects.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
"""
Tests for the proxy-object mechanisms, using ExampleRawIO.
"""
  4. import unittest
  5. import numpy as np
  6. import quantities as pq
  7. from neo.rawio.examplerawio import ExampleRawIO
  8. from neo.io.proxyobjects import (AnalogSignalProxy, SpikeTrainProxy,
  9. EventProxy, EpochProxy)
  10. from neo.core import (Segment, AnalogSignal,
  11. Epoch, Event, SpikeTrain)
  12. from neo.test.tools import (assert_arrays_almost_equal,
  13. assert_neo_object_is_compliant,
  14. assert_same_attributes)
  15. class BaseProxyTest(unittest.TestCase):
  16. def setUp(self):
  17. self.reader = ExampleRawIO(filename='my_filename.fake')
  18. self.reader.parse_header()
  19. class TestAnalogSignalProxy(BaseProxyTest):
  20. def test_AnalogSignalProxy(self):
  21. proxy_anasig = AnalogSignalProxy(rawio=self.reader, global_channel_indexes=None,
  22. block_index=0, seg_index=0,)
  23. assert proxy_anasig.sampling_rate == 10 * pq.kHz
  24. assert proxy_anasig.t_start == 0 * pq.s
  25. assert proxy_anasig.t_stop == 10 * pq.s
  26. assert proxy_anasig.duration == 10 * pq.s
  27. assert proxy_anasig.file_origin == 'my_filename.fake'
  28. # full load
  29. full_anasig = proxy_anasig.load(time_slice=None)
  30. assert isinstance(full_anasig, AnalogSignal)
  31. assert_same_attributes(proxy_anasig, full_anasig)
  32. # slice time
  33. anasig = proxy_anasig.load(time_slice=(2. * pq.s, 5 * pq.s))
  34. assert anasig.t_start == 2. * pq.s
  35. assert anasig.duration == 3. * pq.s
  36. assert anasig.shape == (30000, 16)
  37. assert_same_attributes(proxy_anasig.time_slice(2. * pq.s, 5 * pq.s), anasig)
  38. # ceil next sample when slicing
  39. anasig = proxy_anasig.load(time_slice=(1.99999 * pq.s, 5.000001 * pq.s))
  40. assert anasig.t_start == 2. * pq.s
  41. assert anasig.duration == 3. * pq.s
  42. assert anasig.shape == (30000, 16)
  43. # buggy time slice
  44. with self.assertRaises(AssertionError):
  45. anasig = proxy_anasig.load(time_slice=(2. * pq.s, 15 * pq.s))
  46. anasig = proxy_anasig.load(time_slice=(2. * pq.s, 15 * pq.s), strict_slicing=False)
  47. assert proxy_anasig.t_stop == 10 * pq.s
  48. # select channels
  49. anasig = proxy_anasig.load(channel_indexes=[3, 4, 9])
  50. assert anasig.shape[1] == 3
  51. # select channels and slice times
  52. anasig = proxy_anasig.load(time_slice=(2. * pq.s, 5 * pq.s), channel_indexes=[3, 4, 9])
  53. assert anasig.shape == (30000, 3)
  54. # magnitude mode rescaled
  55. anasig_float = proxy_anasig.load(magnitude_mode='rescaled')
  56. assert anasig_float.dtype == 'float32'
  57. assert anasig_float.units == pq.uV
  58. assert anasig_float.units == proxy_anasig.units
  59. # magnitude mode raw
  60. anasig_int = proxy_anasig.load(magnitude_mode='raw')
  61. assert anasig_int.dtype == 'int16'
  62. assert anasig_int.units == pq.CompoundUnit('0.0152587890625*uV')
  63. assert_arrays_almost_equal(anasig_float, anasig_int.rescale('uV'), 1e-9)
  64. # test array_annotations
  65. assert 'info' in proxy_anasig.array_annotations
  66. assert proxy_anasig.array_annotations['info'].size == 16
  67. assert 'info' in anasig_float.array_annotations
  68. assert anasig_float.array_annotations['info'].size == 16
  69. def test_global_local_channel_indexes(self):
  70. proxy_anasig = AnalogSignalProxy(rawio=self.reader,
  71. global_channel_indexes=slice(0, 10, 2), block_index=0, seg_index=0)
  72. assert proxy_anasig.shape == (100000, 5)
  73. assert '(ch0,ch2,ch4,ch6,ch8)' in proxy_anasig.name
  74. # should be channel ch0 and ch6
  75. anasig = proxy_anasig.load(channel_indexes=[0, 3])
  76. assert anasig.shape == (100000, 2)
  77. assert '(ch0,ch6)' in anasig.name
  78. class TestSpikeTrainProxy(BaseProxyTest):
  79. def test_SpikeTrainProxy(self):
  80. proxy_sptr = SpikeTrainProxy(rawio=self.reader, unit_index=0,
  81. block_index=0, seg_index=0)
  82. assert proxy_sptr.name == 'unit0'
  83. assert proxy_sptr.t_start == 0 * pq.s
  84. assert proxy_sptr.t_stop == 10 * pq.s
  85. assert proxy_sptr.shape == (20,)
  86. assert proxy_sptr.left_sweep == 0.002 * pq.s
  87. assert proxy_sptr.sampling_rate == 10 * pq.kHz
  88. # full load
  89. full_sptr = proxy_sptr.load(time_slice=None)
  90. assert isinstance(full_sptr, SpikeTrain)
  91. assert_same_attributes(proxy_sptr, full_sptr)
  92. assert full_sptr.shape == proxy_sptr.shape
  93. # slice time
  94. sptr = proxy_sptr.load(time_slice=(250 * pq.ms, 500 * pq.ms))
  95. assert sptr.t_start == .25 * pq.s
  96. assert sptr.t_stop == .5 * pq.s
  97. assert sptr.shape == (6,)
  98. assert_same_attributes(proxy_sptr.time_slice(250 * pq.ms, 500 * pq.ms), sptr)
  99. # buggy time slice
  100. with self.assertRaises(AssertionError):
  101. sptr = proxy_sptr.load(time_slice=(2. * pq.s, 15 * pq.s))
  102. sptr = proxy_sptr.load(time_slice=(2. * pq.s, 15 * pq.s), strict_slicing=False)
  103. assert sptr.t_stop == 10 * pq.s
  104. # magnitude mode rescaled
  105. sptr_float = proxy_sptr.load(magnitude_mode='rescaled')
  106. assert sptr_float.dtype == 'float64'
  107. assert sptr_float.units == pq.s
  108. # magnitude mode raw
  109. # TODO when raw mode implemented
  110. # sptr_int = proxy_sptr.load(magnitude_mode='raw')
  111. # assert sptr_int.dtype=='int64'
  112. # assert sptr_int.units==pq.CompoundUnit('1/10000*s')
  113. # assert_arrays_almost_equal(sptr_float, sptr_int.rescale('s'), 1e-9)
  114. # Without waveforms
  115. sptr = proxy_sptr.load(load_waveforms=False)
  116. assert sptr.waveforms is None
  117. # With waveforms
  118. sptr = proxy_sptr.load(load_waveforms=True, magnitude_mode='rescaled')
  119. assert sptr.waveforms is not None
  120. assert sptr.waveforms.shape == (20, 1, 50)
  121. assert sptr.waveforms.units == 1 * pq.uV
  122. # slice waveforms
  123. sptr = proxy_sptr.load(load_waveforms=True, time_slice=(250 * pq.ms, 500 * pq.ms))
  124. assert sptr.waveforms.shape == (6, 1, 50)
  125. class TestEventProxy(BaseProxyTest):
  126. def test_EventProxy(self):
  127. proxy_event = EventProxy(rawio=self.reader, event_channel_index=0,
  128. block_index=0, seg_index=0)
  129. assert proxy_event.name == 'Some events'
  130. assert proxy_event.shape == (6,)
  131. # full load
  132. full_event = proxy_event.load(time_slice=None)
  133. assert isinstance(full_event, Event)
  134. assert_same_attributes(proxy_event, full_event, exclude=('times', 'labels'))
  135. assert full_event.shape == proxy_event.shape
  136. # slice time
  137. event = proxy_event.load(time_slice=(1 * pq.s, 2 * pq.s))
  138. assert event.shape == (2,)
  139. assert event.labels.shape == (2,)
  140. assert_same_attributes(proxy_event.time_slice(1 * pq.s, 2 * pq.s), event)
  141. # buggy time slice
  142. with self.assertRaises(AssertionError):
  143. event = proxy_event.load(time_slice=(2 * pq.s, 15 * pq.s))
  144. event = proxy_event.load(time_slice=(2 * pq.s, 15 * pq.s), strict_slicing=False)
  145. class TestEpochProxy(BaseProxyTest):
  146. def test_EpochProxy(self):
  147. proxy_epoch = EpochProxy(rawio=self.reader, event_channel_index=1,
  148. block_index=0, seg_index=0)
  149. assert proxy_epoch.name == 'Some epochs'
  150. assert proxy_epoch.shape == (10,)
  151. # full load
  152. full_epoch = proxy_epoch.load(time_slice=None)
  153. assert isinstance(full_epoch, Epoch)
  154. assert_same_attributes(proxy_epoch, full_epoch, exclude=('times', 'labels', 'durations'))
  155. assert full_epoch.shape == proxy_epoch.shape
  156. # slice time
  157. epoch = proxy_epoch.load(time_slice=(1 * pq.s, 4 * pq.s))
  158. assert epoch.shape == (3,)
  159. assert epoch.labels.shape == (3,)
  160. assert epoch.durations.shape == (3,)
  161. assert_same_attributes(proxy_epoch.time_slice(1 * pq.s, 4 * pq.s), epoch)
  162. # buggy time slice
  163. with self.assertRaises(AssertionError):
  164. epoch = proxy_epoch.load(time_slice=(2 * pq.s, 15 * pq.s))
  165. epoch = proxy_epoch.load(time_slice=(2 * pq.s, 15 * pq.s), strict_slicing=False)
  166. class TestSegmentWithProxy(BaseProxyTest):
  167. def test_segment_with_proxy(self):
  168. seg = Segment()
  169. proxy_anasig = AnalogSignalProxy(rawio=self.reader,
  170. global_channel_indexes=None,
  171. block_index=0, seg_index=0,)
  172. seg.analogsignals.append(proxy_anasig)
  173. proxy_sptr = SpikeTrainProxy(rawio=self.reader, unit_index=0,
  174. block_index=0, seg_index=0)
  175. seg.spiketrains.append(proxy_sptr)
  176. proxy_event = EventProxy(rawio=self.reader, event_channel_index=0,
  177. block_index=0, seg_index=0)
  178. seg.events.append(proxy_event)
  179. proxy_epoch = EpochProxy(rawio=self.reader, event_channel_index=1,
  180. block_index=0, seg_index=0)
  181. seg.epochs.append(proxy_epoch)
  182. if __name__ == "__main__":
  183. unittest.main()