test_nsdfio.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
"""
Tests of neo.io.NSDFIO.
"""
  4. import numpy as np
  5. import quantities as pq
  6. from datetime import datetime
  7. import os
  8. import unittest
  9. from neo.io.nsdfio import HAVE_NSDF, NSDFIO
  10. from neo.test.iotest.common_io_test import BaseTestIO
  11. from neo.core import AnalogSignal, Segment, Block, ChannelIndex
  12. from neo.test.tools import assert_same_attributes, assert_same_annotations, \
  13. assert_neo_object_is_compliant
  14. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  15. class CommonTests(BaseTestIO, unittest.TestCase):
  16. ioclass = NSDFIO
  17. read_and_write_is_bijective = False
  18. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  19. class NSDFIOTest(unittest.TestCase):
  20. """
  21. Base class for all NSDFIO tests.
  22. setUp and tearDown methods are responsible for setting up and cleaning after tests,
  23. respectively
  24. All create_{object} methods create and return an example {object}.
  25. """
  26. def setUp(self):
  27. self.filename = 'nsdfio_testfile.h5'
  28. self.io = NSDFIO(self.filename)
  29. def tearDown(self):
  30. os.remove(self.filename)
  31. def create_list_of_blocks(self):
  32. blocks = []
  33. for i in range(2):
  34. blocks.append(self.create_block(name='Block #{}'.format(i)))
  35. return blocks
  36. def create_block(self, name='Block'):
  37. block = Block()
  38. self._assign_basic_attributes(block, name=name)
  39. self._assign_datetime_attributes(block)
  40. self._assign_index_attribute(block)
  41. self._create_block_children(block)
  42. self._assign_annotations(block)
  43. return block
  44. def _create_block_children(self, block):
  45. for i in range(3):
  46. block.segments.append(self.create_segment(block, name='Segment #{}'.format(i)))
  47. for i in range(3):
  48. block.channel_indexes.append(
  49. self.create_channelindex(block, name='ChannelIndex #{}'.format(i),
  50. analogsignals=[seg.analogsignals[i] for seg in
  51. block.segments]))
  52. def create_segment(self, parent=None, name='Segment'):
  53. segment = Segment()
  54. segment.block = parent
  55. self._assign_basic_attributes(segment, name=name)
  56. self._assign_datetime_attributes(segment)
  57. self._assign_index_attribute(segment)
  58. self._create_segment_children(segment)
  59. self._assign_annotations(segment)
  60. return segment
  61. def _create_segment_children(self, segment):
  62. for i in range(2):
  63. segment.analogsignals.append(self.create_analogsignal(
  64. segment, name='Signal #{}'.format(i * 3)))
  65. segment.analogsignals.append(self.create_analogsignal2(
  66. segment, name='Signal #{}'.format(i * 3 + 1)))
  67. segment.analogsignals.append(self.create_analogsignal3(
  68. segment, name='Signal #{}'.format(i * 3 + 2)))
  69. def create_analogsignal(self, parent=None, name='AnalogSignal1'):
  70. signal = AnalogSignal([[1.0, 2.5], [2.2, 3.1], [3.2, 4.4]], units='mV',
  71. sampling_rate=100 * pq.Hz, t_start=2 * pq.min)
  72. signal.segment = parent
  73. self._assign_basic_attributes(signal, name=name)
  74. self._assign_annotations(signal)
  75. return signal
  76. def create_analogsignal2(self, parent=None, name='AnalogSignal2'):
  77. signal = AnalogSignal([[1], [2], [3], [4], [5]], units='mA',
  78. sampling_period=0.5 * pq.ms)
  79. signal.segment = parent
  80. self._assign_annotations(signal)
  81. return signal
  82. def create_analogsignal3(self, parent=None, name='AnalogSignal3'):
  83. signal = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='mV',
  84. sampling_rate=2 * pq.kHz, t_start=100 * pq.s)
  85. signal.segment = parent
  86. self._assign_basic_attributes(signal, name=name)
  87. return signal
  88. def create_channelindex(self, parent=None, name='ChannelIndex', analogsignals=None):
  89. channels_num = min([signal.shape[1] for signal in analogsignals])
  90. channelindex = ChannelIndex(index=np.arange(channels_num),
  91. channel_names=['Channel{}'.format(
  92. i) for i in range(channels_num)],
  93. channel_ids=np.arange(channels_num),
  94. coordinates=([[1.87, -5.2, 4.0]] * channels_num) * pq.cm)
  95. for signal in analogsignals:
  96. channelindex.analogsignals.append(signal)
  97. self._assign_basic_attributes(channelindex, name)
  98. self._assign_annotations(channelindex)
  99. return channelindex
  100. def _assign_basic_attributes(self, object, name=None):
  101. if name is None:
  102. object.name = 'neo object'
  103. else:
  104. object.name = name
  105. object.description = 'Example of neo object'
  106. object.file_origin = 'datafile.pp'
  107. def _assign_datetime_attributes(self, object):
  108. object.file_datetime = datetime(2017, 6, 11, 14, 53, 23)
  109. object.rec_datetime = datetime(2017, 5, 29, 13, 12, 47)
  110. def _assign_index_attribute(self, object):
  111. object.index = 12
  112. def _assign_annotations(self, object):
  113. object.annotations = {'str': 'value',
  114. 'int': 56,
  115. 'float': 5.234}
  116. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  117. class NSDFIOTestWriteThenRead(NSDFIOTest):
  118. """
  119. Class for testing NSDFIO.
  120. It first creates example neo objects, then writes them to the file,
  121. reads the file and compares the result with the original ones.
  122. all test_{object} methods run "write then read" test for a/an {object}
  123. all compare_{object} methods check if the second {object} is a proper copy
  124. of the first one
  125. """
  126. lazy_modes = [False]
  127. def test_list_of_blocks(self, lazy=False):
  128. blocks = self.create_list_of_blocks()
  129. self.io.write(blocks)
  130. for lazy in self.lazy_modes:
  131. blocks2 = self.io.read(lazy=lazy)
  132. self.compare_list_of_blocks(blocks, blocks2, lazy)
  133. def test_block(self, lazy=False):
  134. block = self.create_block()
  135. self.io.write_block(block)
  136. for lazy in self.lazy_modes:
  137. block2 = self.io.read_block(lazy=lazy)
  138. self.compare_blocks(block, block2, lazy)
  139. def test_segment(self, lazy=False):
  140. segment = self.create_segment()
  141. self.io.write_segment(segment)
  142. for lazy in self.lazy_modes:
  143. segment2 = self.io.read_segment(lazy=lazy)
  144. self.compare_segments(segment, segment2, lazy)
  145. def compare_list_of_blocks(self, blocks1, blocks2, lazy=False):
  146. assert len(blocks1) == len(blocks2)
  147. for block1, block2 in zip(blocks1, blocks2):
  148. self.compare_blocks(block1, block2, lazy)
  149. def compare_blocks(self, block1, block2, lazy=False):
  150. self._compare_objects(block1, block2)
  151. assert block2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  152. assert_neo_object_is_compliant(block2)
  153. self._compare_blocks_children(block1, block2, lazy=lazy)
  154. def _compare_blocks_children(self, block1, block2, lazy):
  155. assert len(block1.segments) == len(block2.segments)
  156. for segment1, segment2 in zip(block1.segments, block2.segments):
  157. self.compare_segments(segment1, segment2, lazy=lazy)
  158. def compare_segments(self, segment1, segment2, lazy=False):
  159. self._compare_objects(segment1, segment2)
  160. assert segment2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  161. self._compare_segments_children(segment1, segment2, lazy=lazy)
  162. def _compare_segments_children(self, segment1, segment2, lazy):
  163. assert len(segment1.analogsignals) == len(segment2.analogsignals)
  164. for signal1, signal2 in zip(segment1.analogsignals, segment2.analogsignals):
  165. self.compare_analogsignals(signal1, signal2, lazy=lazy)
  166. def compare_analogsignals(self, signal1, signal2, lazy=False):
  167. if not lazy:
  168. self._compare_objects(signal1, signal2)
  169. else:
  170. self._compare_objects(signal1, signal2, exclude_attr=['shape', 'signal'])
  171. assert signal2.lazy_shape == signal1.shape
  172. assert signal2.dtype == signal1.dtype
  173. def _compare_objects(self, object1, object2, exclude_attr=[]):
  174. assert object1.__class__.__name__ == object2.__class__.__name__
  175. assert object2.file_origin == self.filename
  176. assert_same_attributes(object1, object2, exclude=[
  177. 'file_origin',
  178. 'file_datetime'] + exclude_attr)
  179. assert_same_annotations(object1, object2)
  180. if __name__ == "__main__":
  181. unittest.main()