test_pynnio.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of the neo.io.pynnio.PyNNNumpyIO and neo.io.pynnio.PyNNTextIO classes
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division
  7. import os
  8. try:
  9. import unittest2 as unittest
  10. except ImportError:
  11. import unittest
  12. import numpy as np
  13. import quantities as pq
  14. from neo.core import Segment, AnalogSignal, SpikeTrain
  15. from neo.io import PyNNNumpyIO, PyNNTextIO
  16. from numpy.testing import assert_array_equal
  17. from neo.test.tools import assert_arrays_equal, assert_file_contents_equal
  18. from neo.test.iotest.common_io_test import BaseTestIO
  19. #class CommonTestPyNNNumpyIO(BaseTestIO, unittest.TestCase):
  20. # ioclass = PyNNNumpyIO
  21. NCELLS = 5
  22. class CommonTestPyNNTextIO(BaseTestIO, unittest.TestCase):
  23. ioclass = PyNNTextIO
  24. read_and_write_is_bijective = False
  25. def read_test_file(filename):
  26. contents = np.load(filename)
  27. data = contents["data"]
  28. metadata = {}
  29. for name, value in contents['metadata']:
  30. try:
  31. metadata[name] = eval(value)
  32. except Exception:
  33. metadata[name] = value
  34. return data, metadata
  35. read_test_file.__test__ = False
  36. class BaseTestPyNNIO(object):
  37. __test__ = False
  38. def tearDown(self):
  39. if os.path.exists(self.test_file):
  40. os.remove(self.test_file)
  41. def test_write_segment(self):
  42. in_ = self.io_cls(self.test_file)
  43. write_test_file = "write_test.%s" % self.file_extension
  44. out = self.io_cls(write_test_file)
  45. out.write_segment(in_.read_segment(lazy=False, cascade=True))
  46. assert_file_contents_equal(self.test_file, write_test_file)
  47. if os.path.exists(write_test_file):
  48. os.remove(write_test_file)
  49. def build_test_data(self, variable='v'):
  50. metadata = {
  51. 'size': NCELLS,
  52. 'first_index': 0,
  53. 'first_id': 0,
  54. 'n': 505,
  55. 'variable': variable,
  56. 'last_id': NCELLS - 1,
  57. 'last_index': NCELLS - 1,
  58. 'dt': 0.1,
  59. 'label': "population0",
  60. }
  61. if variable == 'v':
  62. metadata['units'] = 'mV'
  63. elif variable == 'spikes':
  64. metadata['units'] = 'ms'
  65. data = np.empty((505, 2))
  66. for i in range(NCELLS):
  67. # signal
  68. data[i*101:(i+1)*101, 0] = np.arange(i, i+101, dtype=float)
  69. # index
  70. data[i*101:(i+1)*101, 1] = i*np.ones((101,), dtype=float)
  71. return data, metadata
  72. build_test_data.__test__ = False
  73. class BaseTestPyNNIO_Signals(BaseTestPyNNIO):
  74. def setUp(self):
  75. self.test_file = "test_file_v.%s" % self.file_extension
  76. self.write_test_file("v")
  77. def test_read_segment_containing_analogsignals_using_eager_cascade(self):
  78. # eager == not lazy
  79. io = self.io_cls(self.test_file)
  80. segment = io.read_segment(lazy=False, cascade=True)
  81. self.assertIsInstance(segment, Segment)
  82. self.assertEqual(len(segment.analogsignals), 1)
  83. as0 = segment.analogsignals[0]
  84. self.assertIsInstance(as0, AnalogSignal)
  85. self.assertEqual(as0.shape, (101, NCELLS))
  86. assert_array_equal(as0[:, 0],
  87. AnalogSignal(np.arange(0, 101, dtype=float),
  88. sampling_period=0.1*pq.ms,
  89. t_start=0*pq.s,
  90. units=pq.mV))
  91. as4 = as0[:, 4]
  92. self.assertIsInstance(as4, AnalogSignal)
  93. assert_array_equal(as4,
  94. AnalogSignal(np.arange(4, 105, dtype=float),
  95. sampling_period=0.1*pq.ms,
  96. t_start=0*pq.s,
  97. units=pq.mV))
  98. # test annotations (stuff from file metadata)
  99. def test_read_analogsignal_using_eager(self):
  100. io = self.io_cls(self.test_file)
  101. sig = io.read_analogsignal(lazy=False)
  102. self.assertIsInstance(sig, AnalogSignal)
  103. assert_array_equal(sig[:, 3],
  104. AnalogSignal(np.arange(3, 104, dtype=float),
  105. sampling_period=0.1*pq.ms,
  106. t_start=0*pq.s,
  107. units=pq.mV))
  108. # should test annotations: 'channel_index', etc.
  109. def test_read_spiketrain_should_fail_with_analogsignal_file(self):
  110. io = self.io_cls(self.test_file)
  111. self.assertRaises(TypeError, io.read_spiketrain, channel_index=0)
  112. class BaseTestPyNNIO_Spikes(BaseTestPyNNIO):
  113. def setUp(self):
  114. self.test_file = "test_file_spikes.%s" % self.file_extension
  115. self.write_test_file("spikes")
  116. def test_read_segment_containing_spiketrains_using_eager_cascade(self):
  117. io = self.io_cls(self.test_file)
  118. segment = io.read_segment(lazy=False, cascade=True)
  119. self.assertIsInstance(segment, Segment)
  120. self.assertEqual(len(segment.spiketrains), NCELLS)
  121. st0 = segment.spiketrains[0]
  122. self.assertIsInstance(st0, SpikeTrain)
  123. assert_arrays_equal(st0,
  124. SpikeTrain(np.arange(0, 101, dtype=float),
  125. t_start=0*pq.s,
  126. t_stop=101*pq.ms,
  127. units=pq.ms))
  128. st4 = segment.spiketrains[4]
  129. self.assertIsInstance(st4, SpikeTrain)
  130. assert_arrays_equal(st4,
  131. SpikeTrain(np.arange(4, 105, dtype=float),
  132. t_start=0*pq.s,
  133. t_stop=105*pq.ms,
  134. units=pq.ms))
  135. # test annotations (stuff from file metadata)
  136. def test_read_spiketrain_using_eager(self):
  137. io = self.io_cls(self.test_file)
  138. st3 = io.read_spiketrain(lazy=False, channel_index=3)
  139. self.assertIsInstance(st3, SpikeTrain)
  140. assert_arrays_equal(st3,
  141. SpikeTrain(np.arange(3, 104, dtype=float),
  142. t_start=0*pq.s,
  143. t_stop=104*pq.s,
  144. units=pq.ms))
  145. # should test annotations: 'channel_index', etc.
  146. def test_read_analogsignal_should_fail_with_spiketrain_file(self):
  147. io = self.io_cls(self.test_file)
  148. self.assertRaises(TypeError, io.read_analogsignal, channel_index=2)
  149. class BaseTestPyNNNumpyIO(object):
  150. io_cls = PyNNNumpyIO
  151. file_extension = "npz"
  152. def write_test_file(self, variable='v', check=False):
  153. data, metadata = self.build_test_data(variable)
  154. metadata_array = np.array(sorted(metadata.items()))
  155. np.savez(self.test_file, data=data, metadata=metadata_array)
  156. if check:
  157. data1, metadata1 = read_test_file(self.test_file)
  158. assert metadata == metadata1, "%s != %s" % (metadata, metadata1)
  159. assert data.shape == data1.shape == (505, 2), \
  160. "%s, %s, (505, 2)" % (data.shape, data1.shape)
  161. assert (data == data1).all()
  162. assert metadata["n"] == 505
  163. write_test_file.__test__ = False
  164. class BaseTestPyNNTextIO(object):
  165. io_cls = PyNNTextIO
  166. file_extension = "txt"
  167. def write_test_file(self, variable='v', check=False):
  168. data, metadata = self.build_test_data(variable)
  169. with open(self.test_file, 'wb') as f:
  170. for item in sorted(metadata.items()):
  171. f.write(("# %s = %s\n" % item).encode('utf8'))
  172. np.savetxt(f, data)
  173. if check:
  174. raise NotImplementedError
  175. write_test_file.__test__ = False
  176. class TestPyNNNumpyIO_Signals(BaseTestPyNNNumpyIO, BaseTestPyNNIO_Signals,
  177. unittest.TestCase):
  178. __test__ = True
  179. class TestPyNNNumpyIO_Spikes(BaseTestPyNNNumpyIO, BaseTestPyNNIO_Spikes,
  180. unittest.TestCase):
  181. __test__ = True
  182. class TestPyNNTextIO_Signals(BaseTestPyNNTextIO, BaseTestPyNNIO_Signals,
  183. unittest.TestCase):
  184. __test__ = True
  185. class TestPyNNTextIO_Spikes(BaseTestPyNNTextIO, BaseTestPyNNIO_Spikes,
  186. unittest.TestCase):
  187. __test__ = True
  188. if __name__ == '__main__':
  189. unittest.main()