test_nsdfio.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.NSDFIO
  4. """
  5. # needed so the module behaves the same on Python 2 and Python 3
  6. from __future__ import absolute_import, division
  7. import numpy as np
  8. import quantities as pq
  9. from datetime import datetime
  10. import os
  11. try:
  12. import unittest2 as unittest
  13. except ImportError:
  14. import unittest
  15. from neo.io.nsdfio import HAVE_NSDF, NSDFIO
  16. from neo.test.iotest.common_io_test import BaseTestIO
  17. from neo.core import AnalogSignal, Segment, Block
  18. from neo.test.tools import assert_same_attributes, assert_same_annotations, assert_neo_object_is_compliant
  19. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  20. class CommonTests(BaseTestIO, unittest.TestCase):
  21. ioclass = NSDFIO
  22. read_and_write_is_bijective = False
  23. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  24. class NSDFIOTest(unittest.TestCase):
  25. """
  26. Base class for all NSDFIO tests.
  27. setUp and tearDown methods are responsible for respectively: setting up and cleaning after tests
  28. All create_{object} methods create and return an example {object}.
  29. """
  30. def setUp(self):
  31. self.filename = 'nsdfio_testfile.h5'
  32. self.io = NSDFIO(self.filename)
  33. def tearDown(self):
  34. os.remove(self.filename)
  35. def create_list_of_blocks(self):
  36. blocks = []
  37. for i in range(2):
  38. blocks.append(self.create_block(name='Block #{}'.format(i)))
  39. return blocks
  40. def create_block(self, name='Block'):
  41. block = Block()
  42. self._assign_basic_attributes(block, name=name)
  43. self._assign_datetime_attributes(block)
  44. self._assign_index_attribute(block)
  45. self._create_block_children(block)
  46. self._assign_annotations(block)
  47. return block
  48. def _create_block_children(self, block):
  49. for i in range(3):
  50. block.segments.append(self.create_segment(block, name='Segment #{}'.format(i)))
  51. def create_segment(self, parent=None, name='Segment'):
  52. segment = Segment()
  53. segment.block = parent
  54. self._assign_basic_attributes(segment, name=name)
  55. self._assign_datetime_attributes(segment)
  56. self._assign_index_attribute(segment)
  57. self._create_segment_children(segment)
  58. self._assign_annotations(segment)
  59. return segment
  60. def _create_segment_children(self, segment):
  61. for i in range(2):
  62. segment.analogsignals.append(self.create_analogsignal(segment, name='Signal #{}'.format(i * 3)))
  63. segment.analogsignals.append(self.create_analogsignal2(segment, name='Signal #{}'.format(i * 3 + 1)))
  64. segment.analogsignals.append(self.create_analogsignal3(segment, name='Signal #{}'.format(i * 3 + 2)))
  65. def create_analogsignal(self, parent=None, name='AnalogSignal1'):
  66. signal = AnalogSignal([[1.0, 2.5], [2.2, 3.1], [3.2, 4.4]], units='mV',
  67. sampling_rate=100 * pq.Hz, t_start=2 * pq.min)
  68. signal.segment = parent
  69. self._assign_basic_attributes(signal, name=name)
  70. self._assign_annotations(signal)
  71. return signal
  72. def create_analogsignal2(self, parent=None, name='AnalogSignal2'):
  73. signal = AnalogSignal([[1], [2], [3], [4], [5]], units='mA',
  74. sampling_period=0.5 * pq.ms)
  75. signal.segment = parent
  76. self._assign_annotations(signal)
  77. return signal
  78. def create_analogsignal3(self, parent=None, name='AnalogSignal3'):
  79. signal = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='mV',
  80. sampling_rate=2 * pq.kHz, t_start=100 * pq.s)
  81. signal.segment = parent
  82. self._assign_basic_attributes(signal, name=name)
  83. return signal
  84. def _assign_basic_attributes(self, object, name=None):
  85. if name is None:
  86. object.name = 'neo object'
  87. else:
  88. object.name = name
  89. object.description = 'Example of neo object'
  90. object.file_origin = 'datafile.pp'
  91. def _assign_datetime_attributes(self, object):
  92. object.file_datetime = datetime(2017, 6, 11, 14, 53, 23)
  93. object.rec_datetime = datetime(2017, 5, 29, 13, 12, 47)
  94. def _assign_index_attribute(self, object):
  95. object.index = 12
  96. def _assign_annotations(self, object):
  97. object.annotations = {'str': 'value',
  98. 'int': 56,
  99. 'float': 5.234}
  100. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  101. class NSDFIOTestWriteThenRead(NSDFIOTest):
  102. """
  103. Class for testing NSDFIO.
  104. It first creates example neo objects, then writes them to the file,
  105. reads the file and compares the result with the original ones.
  106. all test_{object} methods run "write then read" test for a/an {object}
  107. all compare_{object} methods check if the second {object} is a proper copy
  108. of the first one, read in suitable lazy and cascade mode
  109. """
  110. lazy_modes = [False, True]
  111. cascade_modes = [False, True]
  112. def test_list_of_blocks(self, lazy=False, cascade=True):
  113. blocks = self.create_list_of_blocks()
  114. self.io.write(blocks)
  115. for lazy in self.lazy_modes:
  116. for cascade in self.cascade_modes:
  117. blocks2 = self.io.read(lazy=lazy, cascade=cascade)
  118. self.compare_list_of_blocks(blocks, blocks2, lazy, cascade)
  119. def test_block(self, lazy=False, cascade=True):
  120. block = self.create_block()
  121. self.io.write_block(block)
  122. for lazy in self.lazy_modes:
  123. for cascade in self.cascade_modes:
  124. block2 = self.io.read_block(lazy=lazy, cascade=cascade)
  125. self.compare_blocks(block, block2, lazy, cascade)
  126. def test_segment(self, lazy=False, cascade=True):
  127. segment = self.create_segment()
  128. self.io.write_segment(segment)
  129. for lazy in self.lazy_modes:
  130. for cascade in self.cascade_modes:
  131. segment2 = self.io.read_segment(lazy=lazy, cascade=cascade)
  132. self.compare_segments(segment, segment2, lazy, cascade)
  133. def compare_list_of_blocks(self, blocks1, blocks2, lazy=False, cascade=True):
  134. assert len(blocks1) == len(blocks2)
  135. for block1, block2 in zip(blocks1, blocks2):
  136. self.compare_blocks(block1, block2, lazy, cascade)
  137. def compare_blocks(self, block1, block2, lazy=False, cascade=True):
  138. self._compare_objects(block1, block2)
  139. assert block2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  140. assert_neo_object_is_compliant(block2)
  141. if cascade:
  142. self._compare_blocks_children(block1, block2, lazy=lazy)
  143. else:
  144. assert len(block2.segments) == 0
  145. def _compare_blocks_children(self, block1, block2, lazy):
  146. assert len(block1.segments) == len(block2.segments)
  147. for segment1, segment2 in zip(block1.segments, block2.segments):
  148. self.compare_segments(segment1, segment2, lazy=lazy)
  149. def compare_segments(self, segment1, segment2, lazy=False, cascade=True):
  150. self._compare_objects(segment1, segment2)
  151. assert segment2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  152. if cascade:
  153. self._compare_segments_children(segment1, segment2, lazy=lazy)
  154. else:
  155. assert len(segment2.analogsignals) == 0
  156. def _compare_segments_children(self, segment1, segment2, lazy):
  157. assert len(segment1.analogsignals) == len(segment2.analogsignals)
  158. for signal1, signal2 in zip(segment1.analogsignals, segment2.analogsignals):
  159. self.compare_analogsignals(signal1, signal2, lazy=lazy)
  160. def compare_analogsignals(self, signal1, signal2, lazy=False, cascade=True):
  161. if not lazy:
  162. self._compare_objects(signal1, signal2)
  163. else:
  164. self._compare_objects(signal1, signal2, exclude_attr=['shape', 'signal'])
  165. assert signal2.lazy_shape == signal1.shape
  166. assert signal2.dtype == signal1.dtype
  167. def _compare_objects(self, object1, object2, exclude_attr=[]):
  168. assert object1.__class__.__name__ == object2.__class__.__name__
  169. assert object2.file_origin == self.filename
  170. assert_same_attributes(object1, object2, exclude=['file_origin', 'file_datetime'] + exclude_attr)
  171. assert_same_annotations(object1, object2)
if __name__ == "__main__":
    # Running this file directly executes the TestCase classes defined in it.
    unittest.main()