123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691 |
- # -*- coding: utf-8 -*-
- '''
- Test to make sure generated datasets are sane.
- '''
- # needed for python 3 compatibility
- from __future__ import absolute_import, division
- import unittest
- from datetime import datetime
- import numpy as np
- import quantities as pq
- from neo.core import (class_by_name, Block, Segment,
- ChannelIndex, Unit,
- AnalogSignal,
- IrregularlySampledSignal, SpikeTrain,
- Event, Epoch)
- from neo.test.generate_datasets import (generate_one_simple_block,
- generate_one_simple_segment,
- generate_from_supported_objects,
- get_fake_value, get_fake_values,
- fake_neo, TEST_ANNOTATIONS)
- from neo.test.tools import assert_arrays_equal, assert_neo_object_is_compliant
class Test__generate_one_simple_segment(unittest.TestCase):
    '''
    Checks on the object returned by ``generate_one_simple_segment``.
    '''

    # Segment child containers, in the order their sizes are verified.
    _containers = ('analogsignals', 'irregularlysampledsignals',
                   'spiketrains', 'events', 'epochs')

    def _check_segment(self, seg, counts):
        '''
        Assert that ``seg`` is a compliant Segment whose child containers
        hold ``counts`` items (one count per entry of ``_containers``).
        '''
        self.assertTrue(isinstance(seg, Segment))
        assert_neo_object_is_compliant(seg)
        for container, expected in zip(self._containers, counts):
            self.assertEqual(len(getattr(seg, container)), expected)

    def test_defaults(self):
        '''Called without arguments, every child container is empty.'''
        self._check_segment(generate_one_simple_segment(),
                            (0, 0, 0, 0, 0))

    def test_all_supported(self):
        '''All nine object types supported.'''
        supported = [Block, Segment,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        seg = generate_one_simple_segment(supported_objects=supported)
        self._check_segment(seg, (4, 0, 6, 3, 2))

    def test_half_supported(self):
        '''Only Segment, IrregularlySampledSignal, SpikeTrain and Epoch.'''
        supported = [Segment,
                     IrregularlySampledSignal, SpikeTrain,
                     Epoch]
        seg = generate_one_simple_segment(supported_objects=supported)
        self._check_segment(seg, (0, 0, 6, 0, 2))

    def test_all_without_block(self):
        '''Everything except Block gives the same counts as all types.'''
        supported = [Segment,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        seg = generate_one_simple_segment(supported_objects=supported)
        self._check_segment(seg, (4, 0, 6, 3, 2))

    def test_all_without_segment_valueerror(self):
        '''Leaving Segment out of the supported objects is a ValueError.'''
        supported = [Block,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        with self.assertRaises(ValueError):
            generate_one_simple_segment(supported_objects=supported)
class Test__generate_one_simple_block(unittest.TestCase):
    '''
    Checks on the object returned by ``generate_one_simple_block``.
    '''

    # Segment child containers, in the order their sizes are verified.
    _containers = ('analogsignals', 'irregularlysampledsignals',
                   'spiketrains', 'events', 'epochs')

    def _check_block(self, blk, nsegments, counts=()):
        '''
        Assert that ``blk`` is a compliant Block with ``nsegments`` segments
        and that each segment's child containers hold ``counts`` items
        (one count per entry of ``_containers``).
        '''
        self.assertTrue(isinstance(blk, Block))
        assert_neo_object_is_compliant(blk)
        self.assertEqual(len(blk.segments), nsegments)
        for seg in blk.segments:
            for container, expected in zip(self._containers, counts):
                self.assertEqual(len(getattr(seg, container)), expected)

    def test_defaults(self):
        '''Called without arguments, the block has no segments.'''
        self._check_block(generate_one_simple_block(), 0)

    def test_all_supported(self):
        '''All nine object types supported: three fully populated segments.'''
        supported = [Block, Segment,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        blk = generate_one_simple_block(supported_objects=supported)
        self._check_block(blk, 3, (4, 0, 6, 3, 2))

    def test_half_supported(self):
        '''Block, Segment, IrregularlySampledSignal, SpikeTrain and Epoch.'''
        supported = [Block, Segment,
                     IrregularlySampledSignal, SpikeTrain,
                     Epoch]
        blk = generate_one_simple_block(supported_objects=supported)
        self._check_block(blk, 3, (0, 0, 6, 0, 2))

    def test_all_without_segment(self):
        '''Without Segment support the block comes back with no segments.'''
        supported = [Block,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        blk = generate_one_simple_block(supported_objects=supported)
        self._check_block(blk, 0)

    def test_all_without_block_valueerror(self):
        '''Leaving Block out of the supported objects is a ValueError.'''
        supported = [Segment,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        with self.assertRaises(ValueError):
            generate_one_simple_block(supported_objects=supported)
class Test__generate_from_supported_objects(unittest.TestCase):
    '''
    Checks on the object returned by ``generate_from_supported_objects``.
    '''

    # Segment child containers, in the order their sizes are verified.
    _containers = ('analogsignals', 'irregularlysampledsignals',
                   'spiketrains', 'events', 'epochs')

    def _check_counts(self, seg, counts):
        '''Assert the sizes of ``seg``'s child containers match ``counts``.'''
        for container, expected in zip(self._containers, counts):
            self.assertEqual(len(getattr(seg, container)), expected)

    def _check_segment(self, seg, counts):
        '''Assert ``seg`` is a compliant Segment with the given counts.'''
        self.assertTrue(isinstance(seg, Segment))
        assert_neo_object_is_compliant(seg)
        self._check_counts(seg, counts)

    def _check_block(self, blk, nsegments, counts=()):
        '''
        Assert ``blk`` is a compliant Block with ``nsegments`` segments,
        each of which has the given container counts.
        '''
        self.assertTrue(isinstance(blk, Block))
        assert_neo_object_is_compliant(blk)
        self.assertEqual(len(blk.segments), nsegments)
        for seg in blk.segments:
            self._check_counts(seg, counts)

    def test_no_object_valueerror(self):
        '''An empty list of supported objects is a ValueError.'''
        with self.assertRaises(ValueError):
            generate_from_supported_objects([])

    def test_all(self):
        '''All nine object types: a Block with three populated segments.'''
        supported = [Block, Segment,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        blk = generate_from_supported_objects(supported)
        self._check_block(blk, 3, (4, 0, 6, 3, 2))

    def test_block(self):
        '''Block alone: a Block without segments.'''
        blk = generate_from_supported_objects([Block])
        self._check_block(blk, 0)

    def test_block_segment(self):
        '''Segment and Block only: a Block with three empty segments.'''
        blk = generate_from_supported_objects([Segment, Block])
        self._check_block(blk, 3, (0, 0, 0, 0, 0))

    def test_segment(self):
        '''Segment alone: an empty Segment is returned.'''
        seg = generate_from_supported_objects([Segment])
        self._check_segment(seg, (0, 0, 0, 0, 0))

    def test_all_without_block(self):
        '''Everything except Block: a populated Segment is returned.'''
        supported = [Segment,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        seg = generate_from_supported_objects(supported)
        self._check_segment(seg, (4, 0, 6, 3, 2))

    def test_all_without_segment(self):
        '''Everything except Segment: a Block without segments.'''
        supported = [Block,
                     ChannelIndex, Unit,
                     AnalogSignal,
                     IrregularlySampledSignal, SpikeTrain,
                     Event, Epoch]
        blk = generate_from_supported_objects(supported_objects=supported)
        self._check_block(blk, 0)
class Test__get_fake_value(unittest.TestCase):
    '''
    Checks on the individual values produced by ``get_fake_value``.

    Expected random values are drawn from numpy's global generator (seeded
    in ``setUp``) *before* ``get_fake_value`` is called with ``seed=0``, so
    the order of the statements inside each test matters.
    '''

    def setUp(self):
        '''Seed numpy's global random generator before every test.'''
        np.random.seed(0)

    def _assert_rejects_dim_and_ndarray(self, name, datatype):
        '''
        Assert that ``name`` raises ValueError when requested with ``dim=1``
        and when requested as an ``np.ndarray``.
        '''
        self.assertRaises(ValueError, get_fake_value, name, datatype, dim=1)
        self.assertRaises(ValueError, get_fake_value, name, np.ndarray)

    def _check_fixed_quantity(self, name, expected, units):
        '''
        Assert that the fake value for ``name`` is a Quantity equal to
        ``expected`` with the given ``units``.
        '''
        value = get_fake_value(name, pq.Quantity)
        self.assertTrue(isinstance(value, pq.Quantity))
        self.assertEqual(value.units, units)
        assert_arrays_equal(expected, value)
        self._assert_rejects_dim_and_ndarray(name, pq.Quantity)

    @staticmethod
    def _draw_shape(dim):
        '''Draw ``dim`` axis lengths in 1..5 from the global generator.'''
        return [np.random.randint(5) + 1 for _ in range(int(dim))]

    def test__t_start(self):
        '''``t_start`` is 0 ms.'''
        self._check_fixed_quantity('t_start', 0.0 * pq.millisecond,
                                   pq.millisecond)

    def test__t_stop(self):
        '''``t_stop`` is 1 ms.'''
        self._check_fixed_quantity('t_stop', 1.0 * pq.millisecond,
                                   pq.millisecond)

    def test__sampling_rate(self):
        '''``sampling_rate`` is 10 kHz.'''
        self._check_fixed_quantity('sampling_rate', 10000.0 * pq.Hz, pq.Hz)

    def test__str(self):
        '''A generic str attribute is a stringified random integer.'''
        expected = str(np.random.randint(100000))
        value = get_fake_value('test__str', str, seed=0)
        self.assertTrue(isinstance(value, str))
        self.assertEqual(expected, value)

    def test__name(self):
        '''``name`` is the object name followed by a random integer.'''
        expected = 'Block'+str(np.random.randint(100000))
        value = get_fake_value('name', str, seed=0, obj='Block')
        self.assertTrue(isinstance(value, str))
        self.assertEqual(expected, value)
        self._assert_rejects_dim_and_ndarray('name', str)

    def test__description(self):
        '''``description`` embeds the class name and a random integer.'''
        expected = 'test Segment '+str(np.random.randint(100000))
        value = get_fake_value('description', str, seed=0, obj=Segment)
        self.assertTrue(isinstance(value, str))
        self.assertEqual(expected, value)
        self._assert_rejects_dim_and_ndarray('description', str)

    def test__file_origin(self):
        '''``file_origin`` is the fixed string ``test_file.txt``.'''
        value = get_fake_value('file_origin', str, seed=0)
        self.assertTrue(isinstance(value, str))
        self.assertEqual('test_file.txt', value)
        self._assert_rejects_dim_and_ndarray('file_origin', str)

    def test__int(self):
        '''A generic int attribute is a random integer below 100.'''
        expected = np.random.randint(100)
        value = get_fake_value('test__int', int, seed=0)
        self.assertTrue(isinstance(value, int))
        self.assertEqual(expected, value)

    def test__float(self):
        '''A generic float attribute is a random float scaled by 1000.'''
        expected = 1000. * np.random.random()
        value = get_fake_value('test__float', float, seed=0)
        self.assertTrue(isinstance(value, float))
        self.assertEqual(expected, value)

    def test__datetime(self):
        '''A generic datetime attribute comes from a random timestamp.'''
        expected = datetime.fromtimestamp(1000000000*np.random.random())
        value = get_fake_value('test__datetime', datetime, seed=0)
        self.assertTrue(isinstance(value, datetime))
        self.assertEqual(value, expected)

    def test__quantity(self):
        '''A generic Quantity gets random units, shape and data.'''
        # Draw order: units first, then the shape, then the data.
        units = np.random.choice(['nA', 'mA', 'A', 'mV', 'V'])
        shape = self._draw_shape(2)
        expected = pq.Quantity(np.random.random(shape)*1000, units=units)

        value = get_fake_value('test__quantity', pq.Quantity, dim=2, seed=0)
        self.assertTrue(isinstance(value, pq.Quantity))
        self.assertEqual(value.units, getattr(pq, units))
        assert_arrays_equal(expected, value)

    def test__quantity_force_units(self):
        '''Explicit ``units`` with an ndarray datatype.'''
        units = pq.ohm
        shape = self._draw_shape(2)
        expected = pq.Quantity(np.random.random(shape)*1000, units=units)

        value = get_fake_value('test__quantity', np.ndarray, dim=2, seed=0,
                               units=units)
        self.assertTrue(isinstance(value, np.ndarray))
        assert_arrays_equal(expected, value)

    def test__ndarray(self):
        '''A generic ndarray gets a random shape and data.'''
        shape = self._draw_shape(2)
        expected = np.random.random(shape)*1000

        value = get_fake_value('test__ndarray', np.ndarray, dim=2, seed=0)
        self.assertTrue(isinstance(value, np.ndarray))
        assert_arrays_equal(expected, value)

    def test__list(self):
        '''A generic list is the ndarray case converted with ``tolist``.'''
        shape = self._draw_shape(2)
        expected = (np.random.random(shape)*1000).tolist()

        value = get_fake_value('test__list', list, dim=2, seed=0)
        self.assertTrue(isinstance(value, list))
        self.assertEqual(expected, value)

    def test__other_valueerror(self):
        '''An unsupported ``datatype`` argument is a ValueError.'''
        with self.assertRaises(ValueError):
            get_fake_value('test__other_fail', {1, 2, 3, 4})
class Test__get_fake_values(unittest.TestCase):
    '''
    Checks on the attribute dictionaries produced by ``get_fake_values``.
    '''

    def setUp(self):
        '''Seed numpy and build the annotations expected with seed 0.'''
        np.random.seed(0)
        self.annotations = {str(x): TEST_ANNOTATIONS[x]
                            for x in range(len(TEST_ANNOTATIONS))}
        self.annotations['seed'] = 0

    def subcheck__get_fake_values(self, cls):
        '''
        Compare ``get_fake_values`` output, with and without annotations,
        against the attributes declared by ``cls``.

        ``cls`` is either a neo class or its name; a name is resolved
        through ``class_by_name`` after the fake values were generated.
        '''
        plain = get_fake_values(cls, annotate=False, seed=0)
        annotated = get_fake_values(cls, annotate=True, seed=0)

        # Strings have ``lower``; classes do not.
        if hasattr(cls, 'lower'):
            cls = class_by_name[cls]

        attrs = cls._necessary_attrs + cls._recommended_attrs
        attrnames = [spec[0] for spec in attrs]
        declared = [(spec[0], spec[1]) for spec in attrs]
        names_with_annotations = attrnames + list(self.annotations.keys())

        # The keys are the declared attributes (plus annotations, if any).
        self.assertEqual(sorted(attrnames), sorted(plain.keys()))
        self.assertEqual(sorted(names_with_annotations),
                         sorted(annotated.keys()))

        # Every generated value has exactly the declared type.
        generated = [(key, type(val)) for key, val in plain.items()]
        self.assertEqual(sorted(declared), sorted(generated))

        # Annotating must not change the attribute values.
        for key, val in plain.items():
            try:
                self.assertEqual(annotated[key], val)
            except ValueError:
                # assertEqual raised ValueError; compare as arrays instead.
                assert_arrays_equal(annotated[key], val)

        # Annotations appear only in the annotated result.
        for key, val in self.annotations.items():
            self.assertFalse(key in plain)
            self.assertEqual(annotated[key], val)

        # Optional third and fourth spec entries: dimensionality and dtype.
        for spec in attrs:
            key = spec[0]
            if len(spec) >= 3:
                ndim = spec[2]
                self.assertEqual(ndim, plain[key].ndim)
                self.assertEqual(ndim, annotated[key].ndim)
                if len(spec) >= 4:
                    kind = spec[3].kind
                    self.assertEqual(kind, plain[key].dtype.kind)
                    self.assertEqual(kind, annotated[key].dtype.kind)

    def check__get_fake_values(self, cls):
        '''Run the sub-check with the class, then with its name.'''
        for target in (cls, cls.__name__):
            self.subcheck__get_fake_values(target)

    def test__analogsignalarray(self):
        self.check__get_fake_values(AnalogSignal)

    def test__block(self):
        self.check__get_fake_values(Block)

    def test__epoch(self):
        self.check__get_fake_values(Epoch)

    def test__event(self):
        self.check__get_fake_values(Event)

    def test__irregularlysampledsignal(self):
        self.check__get_fake_values(IrregularlySampledSignal)

    def test__channelindex(self):
        self.check__get_fake_values(ChannelIndex)

    def test__segment(self):
        self.check__get_fake_values(Segment)

    def test__spiketrain(self):
        self.check__get_fake_values(SpikeTrain)

    def test__unit(self):
        self.check__get_fake_values(Unit)
class Test__generate_datasets(unittest.TestCase):
    '''
    Checks on the objects built by ``fake_neo``.
    '''

    def setUp(self):
        '''Build the annotations every generated object should carry.'''
        self.annotations = {str(x): TEST_ANNOTATIONS[x]
                            for x in range(len(TEST_ANNOTATIONS))}

    def check__generate_datasets(self, cls):
        '''
        Run the sub-check for ``cls`` given as a class and as a class name,
        with and without cascading, each time unseeded and with ``seed=0``.
        '''
        for obj_type in (cls, cls.__name__):
            for cascade in (True, False):
                self.subcheck__generate_datasets(obj_type, cascade=cascade)
                self.subcheck__generate_datasets(obj_type, cascade=cascade,
                                                 seed=0)

    def subcheck__generate_datasets(self, cls, cascade, seed=None):
        '''
        Generate one object with ``fake_neo`` and verify its type,
        compliance, annotations, attribute values and children.

        ``cls`` is a neo class or its name.  Attribute values are compared
        with ``get_fake_values(..., seed=0)`` only when a seed was given.
        '''
        self.annotations['seed'] = seed

        # Only forward ``seed`` when one was requested.
        if seed is None:
            res = fake_neo(obj_type=cls, cascade=cascade)
        else:
            res = fake_neo(obj_type=cls, cascade=cascade, seed=seed)

        # Strings have ``lower``; classes do not.
        if not hasattr(cls, 'lower'):
            self.assertTrue(isinstance(res, cls))
        else:
            self.assertEqual(res.__class__.__name__, cls)

        assert_neo_object_is_compliant(res)
        self.assertEqual(res.annotations, self.annotations)

        resattr = get_fake_values(cls, annotate=False, seed=0)
        if seed is not None:
            for name, value in resattr.items():
                # These attributes are not compared with the fake values.
                if name in ['channel_names',
                            'channel_indexes',
                            'channel_index',
                            'coordinates']:
                    continue

                try:
                    try:
                        resvalue = getattr(res, name)
                    except AttributeError:
                        # A missing 'signal' attribute is tolerated.
                        if name == 'signal':
                            continue
                        raise

                    try:
                        self.assertEqual(resvalue, value)
                    except ValueError:
                        # assertEqual raised ValueError; compare as arrays.
                        assert_arrays_equal(resvalue, value)
                except BaseException as exc:
                    # Tag any failure with the offending attribute name.
                    exc.args += ('from %s' % name,)
                    raise

        if not getattr(res, '_child_objects', ()):
            # No child object types declared: nothing to check.
            pass
        elif not cascade:
            self.assertEqual(res.children, ())
        else:
            self.assertNotEqual(res.children, ())

        if cls in ['ChannelIndex', ChannelIndex]:
            # The i-th unit's first channel index must match entry i of
            # every analog signal's channel index.
            for i, unit in enumerate(res.units):
                for sigarr in res.analogsignals:
                    self.assertEqual(unit.get_channel_indexes()[0],
                                     sigarr.get_channel_index()[i])

    def test__analogsignalarray(self):
        self.check__generate_datasets(AnalogSignal)

    def test__block(self):
        # Exercise Block here; AnalogSignal has its own test above.
        self.check__generate_datasets(Block)

    def test__epoch(self):
        self.check__generate_datasets(Epoch)

    def test__event(self):
        self.check__generate_datasets(Event)

    def test__irregularlysampledsignal(self):
        self.check__generate_datasets(IrregularlySampledSignal)

    def test__channelindex(self):
        self.check__generate_datasets(ChannelIndex)

    def test__segment(self):
        self.check__generate_datasets(Segment)

    def test__spiketrain(self):
        self.check__generate_datasets(SpikeTrain)

    def test__unit(self):
        self.check__generate_datasets(Unit)
|