# File: test_nsdfio.py (8.7 KB)
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.NSDFIO
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division
  7. import numpy as np
  8. import quantities as pq
  9. from datetime import datetime
  10. import os
  11. import unittest
  12. from neo.io.nsdfio import HAVE_NSDF, NSDFIO
  13. from neo.test.iotest.common_io_test import BaseTestIO
  14. from neo.core import AnalogSignal, Segment, Block, ChannelIndex
  15. from neo.test.tools import assert_same_attributes, assert_same_annotations, \
  16. assert_neo_object_is_compliant
  17. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  18. class CommonTests(BaseTestIO, unittest.TestCase):
  19. ioclass = NSDFIO
  20. read_and_write_is_bijective = False
  21. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  22. class NSDFIOTest(unittest.TestCase):
  23. """
  24. Base class for all NSDFIO tests.
  25. setUp and tearDown methods are responsible for setting up and cleaning after tests,
  26. respectively
  27. All create_{object} methods create and return an example {object}.
  28. """
  29. def setUp(self):
  30. self.filename = 'nsdfio_testfile.h5'
  31. self.io = NSDFIO(self.filename)
  32. def tearDown(self):
  33. os.remove(self.filename)
  34. def create_list_of_blocks(self):
  35. blocks = []
  36. for i in range(2):
  37. blocks.append(self.create_block(name='Block #{}'.format(i)))
  38. return blocks
  39. def create_block(self, name='Block'):
  40. block = Block()
  41. self._assign_basic_attributes(block, name=name)
  42. self._assign_datetime_attributes(block)
  43. self._assign_index_attribute(block)
  44. self._create_block_children(block)
  45. self._assign_annotations(block)
  46. return block
  47. def _create_block_children(self, block):
  48. for i in range(3):
  49. block.segments.append(self.create_segment(block, name='Segment #{}'.format(i)))
  50. for i in range(3):
  51. block.channel_indexes.append(
  52. self.create_channelindex(block, name='ChannelIndex #{}'.format(i),
  53. analogsignals=[seg.analogsignals[i] for seg in
  54. block.segments]))
  55. def create_segment(self, parent=None, name='Segment'):
  56. segment = Segment()
  57. segment.block = parent
  58. self._assign_basic_attributes(segment, name=name)
  59. self._assign_datetime_attributes(segment)
  60. self._assign_index_attribute(segment)
  61. self._create_segment_children(segment)
  62. self._assign_annotations(segment)
  63. return segment
  64. def _create_segment_children(self, segment):
  65. for i in range(2):
  66. segment.analogsignals.append(self.create_analogsignal(
  67. segment, name='Signal #{}'.format(i * 3)))
  68. segment.analogsignals.append(self.create_analogsignal2(
  69. segment, name='Signal #{}'.format(i * 3 + 1)))
  70. segment.analogsignals.append(self.create_analogsignal3(
  71. segment, name='Signal #{}'.format(i * 3 + 2)))
  72. def create_analogsignal(self, parent=None, name='AnalogSignal1'):
  73. signal = AnalogSignal([[1.0, 2.5], [2.2, 3.1], [3.2, 4.4]], units='mV',
  74. sampling_rate=100 * pq.Hz, t_start=2 * pq.min)
  75. signal.segment = parent
  76. self._assign_basic_attributes(signal, name=name)
  77. self._assign_annotations(signal)
  78. return signal
  79. def create_analogsignal2(self, parent=None, name='AnalogSignal2'):
  80. signal = AnalogSignal([[1], [2], [3], [4], [5]], units='mA',
  81. sampling_period=0.5 * pq.ms)
  82. signal.segment = parent
  83. self._assign_annotations(signal)
  84. return signal
  85. def create_analogsignal3(self, parent=None, name='AnalogSignal3'):
  86. signal = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='mV',
  87. sampling_rate=2 * pq.kHz, t_start=100 * pq.s)
  88. signal.segment = parent
  89. self._assign_basic_attributes(signal, name=name)
  90. return signal
  91. def create_channelindex(self, parent=None, name='ChannelIndex', analogsignals=None):
  92. channels_num = min([signal.shape[1] for signal in analogsignals])
  93. channelindex = ChannelIndex(index=np.arange(channels_num),
  94. channel_names=['Channel{}'.format(
  95. i) for i in range(channels_num)],
  96. channel_ids=np.arange(channels_num),
  97. coordinates=([[1.87, -5.2, 4.0]] * channels_num) * pq.cm)
  98. for signal in analogsignals:
  99. channelindex.analogsignals.append(signal)
  100. self._assign_basic_attributes(channelindex, name)
  101. self._assign_annotations(channelindex)
  102. return channelindex
  103. def _assign_basic_attributes(self, object, name=None):
  104. if name is None:
  105. object.name = 'neo object'
  106. else:
  107. object.name = name
  108. object.description = 'Example of neo object'
  109. object.file_origin = 'datafile.pp'
  110. def _assign_datetime_attributes(self, object):
  111. object.file_datetime = datetime(2017, 6, 11, 14, 53, 23)
  112. object.rec_datetime = datetime(2017, 5, 29, 13, 12, 47)
  113. def _assign_index_attribute(self, object):
  114. object.index = 12
  115. def _assign_annotations(self, object):
  116. object.annotations = {'str': 'value',
  117. 'int': 56,
  118. 'float': 5.234}
  119. @unittest.skipUnless(HAVE_NSDF, "Requires NSDF")
  120. class NSDFIOTestWriteThenRead(NSDFIOTest):
  121. """
  122. Class for testing NSDFIO.
  123. It first creates example neo objects, then writes them to the file,
  124. reads the file and compares the result with the original ones.
  125. all test_{object} methods run "write then read" test for a/an {object}
  126. all compare_{object} methods check if the second {object} is a proper copy
  127. of the first one
  128. """
  129. lazy_modes = [False]
  130. def test_list_of_blocks(self, lazy=False):
  131. blocks = self.create_list_of_blocks()
  132. self.io.write(blocks)
  133. for lazy in self.lazy_modes:
  134. blocks2 = self.io.read(lazy=lazy)
  135. self.compare_list_of_blocks(blocks, blocks2, lazy)
  136. def test_block(self, lazy=False):
  137. block = self.create_block()
  138. self.io.write_block(block)
  139. for lazy in self.lazy_modes:
  140. block2 = self.io.read_block(lazy=lazy)
  141. self.compare_blocks(block, block2, lazy)
  142. def test_segment(self, lazy=False):
  143. segment = self.create_segment()
  144. self.io.write_segment(segment)
  145. for lazy in self.lazy_modes:
  146. segment2 = self.io.read_segment(lazy=lazy)
  147. self.compare_segments(segment, segment2, lazy)
  148. def compare_list_of_blocks(self, blocks1, blocks2, lazy=False):
  149. assert len(blocks1) == len(blocks2)
  150. for block1, block2 in zip(blocks1, blocks2):
  151. self.compare_blocks(block1, block2, lazy)
  152. def compare_blocks(self, block1, block2, lazy=False):
  153. self._compare_objects(block1, block2)
  154. assert block2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  155. assert_neo_object_is_compliant(block2)
  156. self._compare_blocks_children(block1, block2, lazy=lazy)
  157. def _compare_blocks_children(self, block1, block2, lazy):
  158. assert len(block1.segments) == len(block2.segments)
  159. for segment1, segment2 in zip(block1.segments, block2.segments):
  160. self.compare_segments(segment1, segment2, lazy=lazy)
  161. def compare_segments(self, segment1, segment2, lazy=False):
  162. self._compare_objects(segment1, segment2)
  163. assert segment2.file_datetime == datetime.fromtimestamp(os.stat(self.filename).st_mtime)
  164. self._compare_segments_children(segment1, segment2, lazy=lazy)
  165. def _compare_segments_children(self, segment1, segment2, lazy):
  166. assert len(segment1.analogsignals) == len(segment2.analogsignals)
  167. for signal1, signal2 in zip(segment1.analogsignals, segment2.analogsignals):
  168. self.compare_analogsignals(signal1, signal2, lazy=lazy)
  169. def compare_analogsignals(self, signal1, signal2, lazy=False):
  170. if not lazy:
  171. self._compare_objects(signal1, signal2)
  172. else:
  173. self._compare_objects(signal1, signal2, exclude_attr=['shape', 'signal'])
  174. assert signal2.lazy_shape == signal1.shape
  175. assert signal2.dtype == signal1.dtype
  176. def _compare_objects(self, object1, object2, exclude_attr=[]):
  177. assert object1.__class__.__name__ == object2.__class__.__name__
  178. assert object2.file_origin == self.filename
  179. assert_same_attributes(object1, object2, exclude=[
  180. 'file_origin',
  181. 'file_datetime'] + exclude_attr)
  182. assert_same_annotations(object1, object2)
  183. if __name__ == "__main__":
  184. unittest.main()