123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- # -*- coding: utf-8 -*-
- """
- Tests of the neo.io.pynnio.PyNNNumpyIO and neo.io.pynnio.PyNNTextIO classes
- """
- # needed for python 3 compatibility
- from __future__ import absolute_import, division
- import os
- import unittest
- import numpy as np
- import quantities as pq
- from neo.core import Segment, AnalogSignal, SpikeTrain
- from neo.io import PyNNNumpyIO, PyNNTextIO
- from numpy.testing import assert_array_equal
- from neo.test.tools import assert_arrays_equal, assert_file_contents_equal
- from neo.test.iotest.common_io_test import BaseTestIO
# NOTE(review): a ``CommonTestPyNNNumpyIO(BaseTestIO, unittest.TestCase)``
# suite with ``ioclass = PyNNNumpyIO`` was declared here but left disabled
# (commented out); only the PyNNTextIO common suite is active in this module.

# Number of cells (channels) in the synthetic recordings built by the tests.
NCELLS = 5
class CommonTestPyNNTextIO(BaseTestIO, unittest.TestCase):
    """Run the generic ``BaseTestIO`` suite against ``PyNNTextIO``."""

    # Tell BaseTestIO not to expect a read/write round trip to be bijective
    # for this IO (exact meaning is defined by BaseTestIO).
    read_and_write_is_bijective = False
    # IO class exercised by the shared suite.
    ioclass = PyNNTextIO
def read_test_file(filename):
    """Load an ``.npz`` test file and return ``(data, metadata)``.

    ``data`` is the array stored under the ``"data"`` key. ``metadata`` is a
    dict built from the ``(name, value)`` pairs stored under ``"metadata"``.
    Each value is converted back to a Python literal where possible (for
    example ``"505"`` -> ``505`` and ``"0.1"`` -> ``0.1``). Values that are
    not literals, such as ``"mV"`` or ``"population0"``, are kept as strings.
    """
    import ast

    # np.load on an .npz returns an NpzFile that holds an open file handle;
    # the context manager guarantees it is closed (the old code leaked it).
    with np.load(filename) as contents:
        data = contents["data"]
        metadata_pairs = contents["metadata"]

    metadata = {}
    for name, value in metadata_pairs:
        try:
            # literal_eval only accepts Python literals, so unlike eval()
            # nothing read from the file can be executed.
            metadata[name] = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            # Not a literal (e.g. a bare word like "mV"): keep the string.
            metadata[name] = value
    return data, metadata


# Keep test collectors from treating this helper as a test function.
read_test_file.__test__ = False
class BaseTestPyNNIO(object):
    """Tests and fixtures shared by every PyNN IO test case.

    Subclasses are expected to provide ``io_cls``, ``file_extension``, a
    ``test_file`` attribute (set in ``setUp``) and a ``write_test_file``
    method that creates that file.
    """
    # Abstract mixin: keep test collectors from running it directly.
    __test__ = False

    def tearDown(self):
        """Remove the fixture file created in ``setUp``, if present."""
        if os.path.exists(self.test_file):
            os.remove(self.test_file)

    def test_write_segment(self):
        """Reading ``test_file`` and writing the segment back must
        reproduce the same file contents."""
        in_ = self.io_cls(self.test_file)
        # Named so it does not shadow the ``write_test_file`` method.
        output_file = "write_test.%s" % self.file_extension
        try:
            out = self.io_cls(output_file)
            out.write_segment(in_.read_segment(lazy=False, cascade=True))
            assert_file_contents_equal(self.test_file, output_file)
        finally:
            # Clean up even when the comparison fails, so a failed run does
            # not leave a stale file behind.
            if os.path.exists(output_file):
                os.remove(output_file)

    def build_test_data(self, variable='v'):
        """Return ``(data, metadata)`` for a small synthetic recording.

        ``data`` has ``NCELLS * 101`` rows and two columns. Column 0 holds
        the values (cell ``i`` gets ``i, i+1, ..., i+100``) and column 1
        holds the cell index ``i``. ``metadata`` describes the recording;
        ``units`` is set only for ``variable`` ``'v'`` (mV) and
        ``'spikes'`` (ms).
        """
        n_per_cell = 101
        # Derived from NCELLS so the buffer always matches the fill loop.
        n_rows = NCELLS * n_per_cell
        metadata = {
            'size': NCELLS,
            'first_index': 0,
            'first_id': 0,
            'n': n_rows,
            'variable': variable,
            'last_id': NCELLS - 1,
            'last_index': NCELLS - 1,
            'dt': 0.1,
            'label': "population0",
        }
        if variable == 'v':
            metadata['units'] = 'mV'
        elif variable == 'spikes':
            metadata['units'] = 'ms'
        data = np.empty((n_rows, 2))
        for i in range(NCELLS):
            rows = slice(i * n_per_cell, (i + 1) * n_per_cell)
            # signal
            data[rows, 0] = np.arange(i, i + n_per_cell, dtype=float)
            # index
            data[rows, 1] = float(i)
        return data, metadata

    # Keep test collectors from treating this helper as a test.
    build_test_data.__test__ = False
class BaseTestPyNNIO_Signals(BaseTestPyNNIO):
    """Read tests for a PyNN file recording the ``'v'`` variable."""

    def setUp(self):
        self.test_file = "test_file_v.%s" % self.file_extension
        self.write_test_file("v")

    @staticmethod
    def _expected_signal(first_value):
        """AnalogSignal expected for the channel whose samples start at
        ``first_value`` (101 consecutive values, dt = 0.1 ms, in mV)."""
        samples = np.arange(first_value, first_value + 101, dtype=float)
        return AnalogSignal(samples,
                            sampling_period=0.1 * pq.ms,
                            t_start=0 * pq.s,
                            units=pq.mV)

    def test_read_segment_containing_analogsignals_using_eager_cascade(self):
        # eager == not lazy
        reader = self.io_cls(self.test_file)
        seg = reader.read_segment(lazy=False, cascade=True)
        self.assertIsInstance(seg, Segment)
        self.assertEqual(len(seg.analogsignals), 1)

        signal = seg.analogsignals[0]
        self.assertIsInstance(signal, AnalogSignal)
        self.assertEqual(signal.shape, (101, NCELLS))
        assert_array_equal(signal[:, 0], self._expected_signal(0))

        last_channel = signal[:, 4]
        self.assertIsInstance(last_channel, AnalogSignal)
        assert_array_equal(last_channel, self._expected_signal(4))
        # TODO: also check annotations derived from the file metadata.

    def test_read_analogsignal_using_eager(self):
        reader = self.io_cls(self.test_file)
        signal = reader.read_analogsignal(lazy=False)
        self.assertIsInstance(signal, AnalogSignal)
        assert_array_equal(signal[:, 3], self._expected_signal(3))
        # TODO: also check annotations such as 'channel_index'.

    def test_read_spiketrain_should_fail_with_analogsignal_file(self):
        reader = self.io_cls(self.test_file)
        with self.assertRaises(TypeError):
            reader.read_spiketrain(channel_index=0)
class BaseTestPyNNIO_Spikes(BaseTestPyNNIO):
    """Read tests for a PyNN file recording the ``'spikes'`` variable."""

    def setUp(self):
        self.test_file = "test_file_spikes.%s" % self.file_extension
        self.write_test_file("spikes")

    @staticmethod
    def _expected_spiketrain(first_spike):
        """SpikeTrain expected for the cell whose spikes start at
        ``first_spike`` (101 consecutive spike times, in ms).

        ``t_stop`` is ``first_spike + 101`` ms, which matches the values
        this class already expected for cells 0 (101 ms) and 4 (105 ms).
        """
        spike_times = np.arange(first_spike, first_spike + 101, dtype=float)
        return SpikeTrain(spike_times,
                          t_start=0 * pq.s,
                          t_stop=(first_spike + 101) * pq.ms,
                          units=pq.ms)

    def test_read_segment_containing_spiketrains_using_eager_cascade(self):
        io = self.io_cls(self.test_file)
        segment = io.read_segment(lazy=False, cascade=True)
        self.assertIsInstance(segment, Segment)
        self.assertEqual(len(segment.spiketrains), NCELLS)
        st0 = segment.spiketrains[0]
        self.assertIsInstance(st0, SpikeTrain)
        assert_arrays_equal(st0, self._expected_spiketrain(0))
        st4 = segment.spiketrains[4]
        self.assertIsInstance(st4, SpikeTrain)
        assert_arrays_equal(st4, self._expected_spiketrain(4))
        # TODO: also check annotations derived from the file metadata.

    def test_read_spiketrain_using_eager(self):
        io = self.io_cls(self.test_file)
        st3 = io.read_spiketrain(lazy=False, channel_index=3)
        self.assertIsInstance(st3, SpikeTrain)
        # Expected t_stop is 104 ms; it was previously written as 104 s,
        # a unit typo inconsistent with the other expectations.
        assert_arrays_equal(st3, self._expected_spiketrain(3))
        # TODO: also check annotations such as 'channel_index'.

    def test_read_analogsignal_should_fail_with_spiketrain_file(self):
        io = self.io_cls(self.test_file)
        self.assertRaises(TypeError, io.read_analogsignal, channel_index=2)
class BaseTestPyNNNumpyIO(object):
    """Fixture mixin that exercises ``PyNNNumpyIO`` with ``.npz`` files."""
    io_cls = PyNNNumpyIO
    file_extension = "npz"

    def write_test_file(self, variable='v', check=False):
        """Save ``build_test_data(variable)`` to ``self.test_file``.

        ``numpy.savez`` stores the data array under ``data`` and the
        metadata, as a sorted array of ``(name, value)`` pairs, under
        ``metadata``. With ``check=True`` the file is read back through
        ``read_test_file`` and compared with what was written; a mismatch
        raises ``AssertionError``.
        """
        data, metadata = self.build_test_data(variable)
        np.savez(self.test_file, data=data,
                 metadata=np.array(sorted(metadata.items())))
        if not check:
            return
        reloaded_data, reloaded_meta = read_test_file(self.test_file)
        assert metadata == reloaded_meta, \
            "%s != %s" % (metadata, reloaded_meta)
        assert data.shape == reloaded_data.shape == (505, 2), \
            "%s, %s, (505, 2)" % (data.shape, reloaded_data.shape)
        assert (data == reloaded_data).all()
        assert metadata["n"] == 505

    # Keep test collectors from treating this helper as a test.
    write_test_file.__test__ = False
class BaseTestPyNNTextIO(object):
    """Fixture mixin that exercises ``PyNNTextIO`` with ``.txt`` files."""
    io_cls = PyNNTextIO
    file_extension = "txt"

    def write_test_file(self, variable='v', check=False):
        """Write ``build_test_data(variable)`` to ``self.test_file`` as text.

        The file starts with one ``# name = value`` line per metadata item,
        sorted by name and UTF-8 encoded. The data rows follow, written
        with ``numpy.savetxt``. ``check=True`` is unsupported: it raises
        ``NotImplementedError`` after the file has been written.
        """
        data, metadata = self.build_test_data(variable)
        header = "".join("# %s = %s\n" % pair
                         for pair in sorted(metadata.items()))
        with open(self.test_file, 'wb') as handle:
            handle.write(header.encode('utf8'))
            np.savetxt(handle, data)
        if check:
            raise NotImplementedError

    # Keep test collectors from treating this helper as a test.
    write_test_file.__test__ = False
class TestPyNNNumpyIO_Signals(BaseTestPyNNNumpyIO,
                              BaseTestPyNNIO_Signals,
                              unittest.TestCase):
    """Signal-file read tests run against PyNNNumpyIO (``.npz``)."""
    # Re-enable collection that BaseTestPyNNIO switches off.
    __test__ = True
class TestPyNNNumpyIO_Spikes(BaseTestPyNNNumpyIO,
                             BaseTestPyNNIO_Spikes,
                             unittest.TestCase):
    """Spike-file read tests run against PyNNNumpyIO (``.npz``)."""
    # Re-enable collection that BaseTestPyNNIO switches off.
    __test__ = True
class TestPyNNTextIO_Signals(BaseTestPyNNTextIO,
                             BaseTestPyNNIO_Signals,
                             unittest.TestCase):
    """Signal-file read tests run against PyNNTextIO (``.txt``)."""
    # Re-enable collection that BaseTestPyNNIO switches off.
    __test__ = True
class TestPyNNTextIO_Spikes(BaseTestPyNNTextIO,
                            BaseTestPyNNIO_Spikes,
                            unittest.TestCase):
    """Spike-file read tests run against PyNNTextIO (``.txt``)."""
    # Re-enable collection that BaseTestPyNNIO switches off.
    __test__ = True
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|