# -*- coding: utf-8 -*-
# Copyright (c) 2016, German Neuroinformatics Node (G-Node)
#                     Achilleas Koutsou
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted under the terms of the BSD License. See
# LICENSE file in the root of the Project.
"""
Tests for NixIO
"""

import os
import shutil
try:
    # ABCs live in collections.abc; the old aliases in `collections`
    # were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # very old interpreters (Python 2)
    from collections import Iterable
from datetime import datetime
from tempfile import mkdtemp
import unittest
import string

import numpy as np
import quantities as pq

from neo.core import (Block, Segment, ChannelIndex, AnalogSignal,
                      IrregularlySampledSignal, Unit, SpikeTrain, Event, Epoch)
from neo.test.iotest.common_io_test import BaseTestIO
from neo.io.nixio import NixIO, create_quantity, units_to_string, neover

try:
    import nixio as nix
    HAVE_NIX = True
except ImportError:
    HAVE_NIX = False

try:
    from unittest import mock
    SKIPMOCK = False
except ImportError:
    SKIPMOCK = True


@unittest.skipUnless(HAVE_NIX, "Requires NIX")
class NixIOTest(unittest.TestCase):
    """
    Base class holding comparison helpers (Neo object vs. NIX entity) and
    random-data generators shared by the NixIO test cases.
    """

    io = None
    tempdir = None
    filename = None

    def compare_blocks(self, neoblocks, nixblocks):
        """
        Compare Neo Blocks with NIX Blocks pairwise.

        Pairs are formed with ``zip``, so surplus items in the longer
        sequence are ignored (some tests rely on this).

        :param neoblocks: Sequence of Neo Blocks
        :param nixblocks: Sequence of NIX Blocks
        """
        for neoblock, nixblock in zip(neoblocks, nixblocks):
            self.compare_attr(neoblock, nixblock)
            self.assertEqual(len(neoblock.segments), len(nixblock.groups))
            for neoseg in neoblock.segments:
                nixgrp = nixblock.groups[neoseg.annotations["nix_name"]]
                self.compare_segment_group(neoseg, nixgrp)
            self.assertEqual(len(neoblock.channel_indexes),
                             len(nixblock.sources))
            for neochx in neoblock.channel_indexes:
                nixsrc = nixblock.sources[neochx.annotations["nix_name"]]
                self.compare_chx_source(neochx, nixsrc)
            self.check_refs(neoblock, nixblock)

    def compare_chx_source(self, neochx, nixsrc):
        """
        Compare a Neo ChannelIndex with its NIX Source, including the
        per-channel child Sources and the Unit child Sources.

        :param neochx: Neo ChannelIndex
        :param nixsrc: The corresponding NIX Source
        """
        self.compare_attr(neochx, nixsrc)
        nix_channels = list(src for src in nixsrc.sources
                            if src.type == "neo.channelindex")
        self.assertEqual(len(neochx.index), len(nix_channels))

        if len(neochx.channel_ids):
            nix_chanids = list(src.metadata["channel_id"] for src
                               in nixsrc.sources
                               if src.type == "neo.channelindex")
            self.assertEqual(len(neochx.channel_ids), len(nix_chanids))

        neoindex = list(neochx.index)
        for nixchan in nix_channels:
            nixchanidx = nixchan.metadata["index"]
            try:
                neochanpos = neoindex.index(nixchanidx)
            except ValueError:
                self.fail("Channel indexes do not match.")
            if len(neochx.channel_names):
                neochanname = neochx.channel_names[neochanpos]
                # Neo may hold channel names as bytes; NIX returns str
                if ((not isinstance(neochanname, str)) and
                        isinstance(neochanname, bytes)):
                    neochanname = neochanname.decode()
                nixchanname = nixchan.metadata["neo_name"]
                self.assertEqual(neochanname, nixchanname)
            if len(neochx.channel_ids):
                neochanid = neochx.channel_ids[neochanpos]
                nixchanid = nixchan.metadata["channel_id"]
                self.assertEqual(neochanid, nixchanid)
            elif "channel_id" in nixchan.metadata:
                self.fail("Channel ID not loaded")

        nix_units = list(src for src in nixsrc.sources
                         if src.type == "neo.unit")
        self.assertEqual(len(neochx.units), len(nix_units))
        for neounit in neochx.units:
            nixunit = nixsrc.sources[neounit.annotations["nix_name"]]
            self.compare_attr(neounit, nixunit)

    def check_refs(self, neoblock, nixblock):
        """
        Checks whether the references between objects that are not nested
        are mapped correctly (e.g., SpikeTrains referenced by a Unit).

        :param neoblock: A Neo block
        :param nixblock: The corresponding NIX block
        """
        for neochx in neoblock.channel_indexes:
            nixchx = nixblock.sources[neochx.annotations["nix_name"]]
            # AnalogSignals referencing CHX
            neoasigs = list(sig.annotations["nix_name"]
                            for sig in neochx.analogsignals)
            # one signal is stored as several DataArrays sharing one
            # metadata section, hence the set of section names
            nixasigs = list(set(da.metadata.name for da in nixblock.data_arrays
                                if da.type == "neo.analogsignal" and
                                nixchx in da.sources))

            self.assertEqual(len(neoasigs), len(nixasigs))

            # IrregularlySampledSignals referencing CHX
            neoisigs = list(sig.annotations["nix_name"] for sig in
                            neochx.irregularlysampledsignals)
            nixisigs = list(
                set(da.metadata.name for da in nixblock.data_arrays
                    if da.type == "neo.irregularlysampledsignal" and
                    nixchx in da.sources)
            )
            self.assertEqual(len(neoisigs), len(nixisigs))
            # SpikeTrains referencing CHX and Units
            for neounit in neochx.units:
                nixunit = nixchx.sources[neounit.annotations["nix_name"]]
                neosts = list(st.annotations["nix_name"]
                              for st in neounit.spiketrains)
                nixsts = list(mt for mt in nixblock.multi_tags
                              if mt.type == "neo.spiketrain" and
                              nixunit.name in mt.sources)
                # SpikeTrains must also reference CHX
                for nixst in nixsts:
                    self.assertIn(nixchx.name, nixst.sources)
                nixsts = list(st.name for st in nixsts)
                self.assertEqual(len(neosts), len(nixsts))
                for neoname in neosts:
                    if neoname:
                        self.assertIn(neoname, nixsts)

        # Events and Epochs must reference all Signals in the Group (NIX only)
        for nixgroup in nixblock.groups:
            nixevep = list(mt for mt in nixgroup.multi_tags
                           if mt.type in ["neo.event", "neo.epoch"])
            nixsigs = list(da.name for da in nixgroup.data_arrays
                           if da.type in ["neo.analogsignal",
                                          "neo.irregularlysampledsignal"])
            for nee in nixevep:
                for ns in nixsigs:
                    self.assertIn(ns, nee.references)

    def compare_segment_group(self, neoseg, nixgroup):
        """
        Compare a Neo Segment with its NIX Group: attributes, signals
        (DataArrays) and epochs/events/spiketrains (MultiTags).
        """
        self.compare_attr(neoseg, nixgroup)
        neo_signals = neoseg.analogsignals + neoseg.irregularlysampledsignals
        self.compare_signals_das(neo_signals, nixgroup.data_arrays)
        neo_eests = neoseg.epochs + neoseg.events + neoseg.spiketrains
        self.compare_eests_mtags(neo_eests, nixgroup.multi_tags)

    def compare_signals_das(self, neosignals, data_arrays):
        """
        Compare Neo signals with NIX DataArrays. Each signal channel maps to
        one DataArray; a signal's DataArrays are found through their shared
        metadata section name, which equals the signal's ``nix_name``.

        :param neosignals: List of Neo Analog/IrregularlySampledSignals
        :param data_arrays: DataArrays of the corresponding NIX Group
        """
        totalsignals = 0
        for sig in neosignals:
            nixname = sig.annotations["nix_name"]
            dalist = [da for da in data_arrays
                      if da.metadata.name == nixname]
            _, nsig = np.shape(sig)
            totalsignals += nsig
            self.assertEqual(nsig, len(dalist))
            self.compare_signal_dalist(sig, dalist)
        self.assertEqual(totalsignals, len(data_arrays))

    def compare_signal_dalist(self, neosig, nixdalist):
        """
        Check if a Neo Analog or IrregularlySampledSignal matches a list of
        NIX DataArrays.

        :param neosig: Neo Analog or IrregularlySampledSignal
        :param nixdalist: List of DataArrays
        """
        nixmd = nixdalist[0].metadata
        self.assertTrue(all(nixmd == da.metadata for da in nixdalist))
        neounit = neosig.units
        # transpose so that each iteration yields one channel
        for sig, da in zip(np.transpose(neosig), nixdalist):
            self.compare_attr(neosig, da)
            daquant = create_quantity(da[:], da.unit)
            np.testing.assert_almost_equal(sig, daquant)
            nixunit = create_quantity(1, da.unit)
            self.assertEqual(neounit, nixunit)

            timedim = da.dimensions[0]
            if isinstance(neosig, AnalogSignal):
                self.assertEqual(timedim.dimension_type,
                                 nix.DimensionType.Sample)
                neosp = neosig.sampling_period
                nixsp = create_quantity(timedim.sampling_interval,
                                        timedim.unit)
                self.assertEqual(neosp, nixsp)
                # t_start may be stored with its own unit in the metadata
                tsunit = timedim.unit
                if "t_start.units" in da.metadata.props:
                    tsunit = da.metadata["t_start.units"]
                neots = neosig.t_start
                nixts = create_quantity(timedim.offset, tsunit)
                self.assertEqual(neots, nixts)
            elif isinstance(neosig, IrregularlySampledSignal):
                self.assertEqual(timedim.dimension_type,
                                 nix.DimensionType.Range)
                np.testing.assert_almost_equal(neosig.times.magnitude,
                                               timedim.ticks)
                self.assertEqual(timedim.unit,
                                 units_to_string(neosig.times.units))

    def compare_eests_mtags(self, eestlist, mtaglist):
        """
        Compare Epochs, Events and SpikeTrains with NIX MultiTags, looking
        each MultiTag up by the object's ``nix_name`` annotation.
        """
        self.assertEqual(len(eestlist), len(mtaglist))
        for eest in eestlist:
            mtag = mtaglist[eest.annotations["nix_name"]]
            if isinstance(eest, Epoch):
                self.compare_epoch_mtag(eest, mtag)
            elif isinstance(eest, Event):
                self.compare_event_mtag(eest, mtag)
            elif isinstance(eest, SpikeTrain):
                self.compare_spiketrain_mtag(eest, mtag)

    def _assert_labels_equal(self, neolabels, nixlabels):
        """
        Compare Neo labels with NIX set-dimension labels pairwise, decoding
        bytes values to str before comparing.
        """
        for neol, nixl in zip(neolabels, nixlabels):
            # Dirty. Should find the root cause instead
            # (original note: only observed on Python 3.2)
            if isinstance(neol, bytes):
                neol = neol.decode()
            if isinstance(nixl, bytes):
                nixl = nixl.decode()
            self.assertEqual(neol, nixl)

    def compare_epoch_mtag(self, epoch, mtag):
        """Compare a Neo Epoch with a ``neo.epoch`` MultiTag."""
        self.assertEqual(mtag.type, "neo.epoch")
        self.compare_attr(epoch, mtag)
        pos = mtag.positions
        posquant = create_quantity(pos[:], pos.unit)
        ext = mtag.extents
        extquant = create_quantity(ext[:], ext.unit)
        np.testing.assert_almost_equal(epoch.as_quantity(), posquant)
        np.testing.assert_almost_equal(epoch.durations, extquant)
        self._assert_labels_equal(epoch.labels,
                                  mtag.positions.dimensions[0].labels)

    def compare_event_mtag(self, event, mtag):
        """Compare a Neo Event with a ``neo.event`` MultiTag."""
        self.assertEqual(mtag.type, "neo.event")
        self.compare_attr(event, mtag)
        pos = mtag.positions
        posquant = create_quantity(pos[:], pos.unit)
        np.testing.assert_almost_equal(event.as_quantity(), posquant)
        self._assert_labels_equal(event.labels,
                                  mtag.positions.dimensions[0].labels)

    def compare_spiketrain_mtag(self, spiketrain, mtag):
        """
        Compare a Neo SpikeTrain with a ``neo.spiketrain`` MultiTag. When the
        MultiTag has a feature, its first feature's data is compared
        element-wise with the waveforms.
        """
        self.assertEqual(mtag.type, "neo.spiketrain")
        self.compare_attr(spiketrain, mtag)
        pos = mtag.positions
        posquant = create_quantity(pos[:], pos.unit)
        np.testing.assert_almost_equal(spiketrain.as_quantity(), posquant)
        if len(mtag.features):
            neowfs = spiketrain.waveforms
            nixwfs = mtag.features[0].data
            self.assertEqual(np.shape(neowfs), np.shape(nixwfs))
            for nixwf, neowf in zip(nixwfs, neowfs):
                for nixrow, neorow in zip(nixwf, neowf):
                    for nixv, neov in zip(nixrow, neorow):
                        self.assertEqual(create_quantity(nixv, nixwfs.unit),
                                         neov)
            self.assertEqual(nixwfs.dimensions[0].dimension_type,
                             nix.DimensionType.Set)
            self.assertEqual(nixwfs.dimensions[1].dimension_type,
                             nix.DimensionType.Set)
            self.assertEqual(nixwfs.dimensions[2].dimension_type,
                             nix.DimensionType.Sample)

    def compare_attr(self, neoobj, nixobj):
        """
        Compare common attributes: ``nix_name``, description/definition,
        recording and file datetimes, and annotations vs. metadata props.

        Signals are split over several DataArrays named ``<name>.<i>``, so
        for them the trailing ``.<i>`` is stripped before comparing names.
        """
        if isinstance(neoobj, (AnalogSignal, IrregularlySampledSignal)):
            nix_name = ".".join(nixobj.name.split(".")[:-1])
        else:
            nix_name = nixobj.name

        self.assertEqual(neoobj.annotations["nix_name"], nix_name)
        self.assertEqual(neoobj.description, nixobj.definition)
        if hasattr(neoobj, "rec_datetime") and neoobj.rec_datetime:
            self.assertEqual(neoobj.rec_datetime,
                             datetime.fromtimestamp(nixobj.created_at))
        if hasattr(neoobj, "file_datetime") and neoobj.file_datetime:
            self.assertEqual(neoobj.file_datetime,
                             datetime.fromtimestamp(
                                 nixobj.metadata["file_datetime"]))
        if neoobj.annotations:
            nixmd = nixobj.metadata
            for k, v in neoobj.annotations.items():
                if k == "nix_name":
                    continue
                if isinstance(v, pq.Quantity):
                    nixunit = nixmd.props[str(k)].unit
                    self.assertEqual(nixunit, units_to_string(v.units))
                    nixvalue = nixmd[str(k)]
                    if isinstance(nixvalue, Iterable):
                        nixvalue = np.array(nixvalue)
                    np.testing.assert_almost_equal(nixvalue, v.magnitude)
                else:
                    self.assertEqual(nixmd[str(k)], v,
                                     "Property value mismatch: {}".format(k))

    @classmethod
    def create_full_nix_file(cls, filename):
        """
        Create (overwrite) a NIX file at ``filename`` using the nixio API
        directly, laid out the way NixIO stores Neo objects.

        Two ``neo.block`` Blocks are created, each with three
        ``neo.segment`` Groups. The first Group of the first Block receives:
        5 analog signals (3 DataArrays each), 2 irregularly sampled signals
        (10 DataArrays each), 4 spike trains with waveforms, 3 epochs and
        2 events. It also gets one channel-index Source with 3 channel
        Sources and 1 unit Source.

        :param filename: Path of the file to create
        :return: The NIX File, still open
        """
        nixfile = nix.File.open(filename, nix.FileMode.Overwrite)

        nix_block_a = nixfile.create_block(cls.rword(10), "neo.block")
        nix_block_a.definition = cls.rsentence(5, 10)
        nix_block_b = nixfile.create_block(cls.rword(10), "neo.block")
        nix_block_b.definition = cls.rsentence(3, 3)

        nix_block_a.metadata = nixfile.create_section(
            nix_block_a.name, nix_block_a.name + ".metadata"
        )
        nix_block_a.metadata["neo_name"] = cls.rword(5)

        nix_block_b.metadata = nixfile.create_section(
            nix_block_b.name, nix_block_b.name + ".metadata"
        )
        nix_block_b.metadata["neo_name"] = cls.rword(5)

        nix_blocks = [nix_block_a, nix_block_b]

        for blk in nix_blocks:
            for _ in range(3):
                group = blk.create_group(cls.rword(), "neo.segment")
                group.definition = cls.rsentence(10, 15)

                group_md = blk.metadata.create_section(
                    group.name, group.name + ".metadata"
                )
                group.metadata = group_md

        blk = nix_blocks[0]
        group = blk.groups[0]
        allspiketrains = list()
        allsignalgroups = list()

        # analogsignals
        for n in range(5):
            siggroup = list()
            asig_name = "{}_asig{}".format(cls.rword(10), n)
            asig_definition = cls.rsentence(5, 5)
            asig_md = group.metadata.create_section(asig_name,
                                                    asig_name + ".metadata")
            for idx in range(3):
                da_asig = blk.create_data_array(
                    "{}.{}".format(asig_name, idx),
                    "neo.analogsignal",
                    data=cls.rquant(100, 1)
                )
                da_asig.definition = asig_definition
                da_asig.unit = "mV"

                da_asig.metadata = asig_md

                timedim = da_asig.append_sampled_dimension(0.01)
                timedim.unit = "ms"
                timedim.label = "time"
                timedim.offset = 10
                da_asig.append_set_dimension()
                group.data_arrays.append(da_asig)
                siggroup.append(da_asig)
            asig_md["t_start.dim"] = "ms"
            allsignalgroups.append(siggroup)

        # irregularlysampledsignals
        for n in range(2):
            siggroup = list()
            isig_name = "{}_isig{}".format(cls.rword(10), n)
            isig_definition = cls.rsentence(12, 12)
            isig_md = group.metadata.create_section(isig_name,
                                                    isig_name + ".metadata")
            isig_times = cls.rquant(200, 1, True)
            for idx in range(10):
                da_isig = blk.create_data_array(
                    "{}.{}".format(isig_name, idx),
                    "neo.irregularlysampledsignal",
                    data=cls.rquant(200, 1)
                )
                da_isig.definition = isig_definition
                da_isig.unit = "mV"

                da_isig.metadata = isig_md

                timedim = da_isig.append_range_dimension(isig_times)
                timedim.unit = "s"
                timedim.label = "time"
                da_isig.append_set_dimension()
                group.data_arrays.append(da_isig)
                siggroup.append(da_isig)
            allsignalgroups.append(siggroup)

        # SpikeTrains with Waveforms
        for n in range(4):
            stname = "{}-st{}".format(cls.rword(20), n)
            times = cls.rquant(40, 1, True)
            times_da = blk.create_data_array(
                "{}.times".format(stname),
                "neo.spiketrain.times",
                data=times
            )
            times_da.unit = "ms"
            mtag_st = blk.create_multi_tag(stname, "neo.spiketrain", times_da)
            group.multi_tags.append(mtag_st)
            mtag_st.definition = cls.rsentence(20, 30)
            mtag_st_md = group.metadata.create_section(
                mtag_st.name, mtag_st.name + ".metadata"
            )
            mtag_st.metadata = mtag_st_md
            mtag_st_md.create_property("t_stop", times[-1] + 1.0)

            waveforms = cls.rquant((10, 8, 5), 1)
            wfname = "{}.waveforms".format(mtag_st.name)
            wfda = blk.create_data_array(wfname, "neo.waveforms",
                                         data=waveforms)
            wfda.unit = "mV"
            mtag_st.create_feature(wfda, nix.LinkType.Indexed)
            wfda.append_set_dimension()  # spike dimension
            wfda.append_set_dimension()  # channel dimension
            wftimedim = wfda.append_sampled_dimension(0.1)
            wftimedim.unit = "ms"
            wftimedim.label = "time"
            wfda.metadata = mtag_st_md.create_section(
                wfname, "neo.waveforms.metadata"
            )
            wfda.metadata.create_property("left_sweep", [20] * 5)
            allspiketrains.append(mtag_st)

        # Epochs
        for n in range(3):
            epname = "{}-ep{}".format(cls.rword(5), n)
            times = cls.rquant(5, 1, True)
            times_da = blk.create_data_array(
                "{}.times".format(epname),
                "neo.epoch.times",
                data=times
            )
            times_da.unit = "s"
            extents = cls.rquant(5, 1)
            extents_da = blk.create_data_array(
                "{}.durations".format(epname),
                "neo.epoch.durations",
                data=extents
            )
            extents_da.unit = "s"

            mtag_ep = blk.create_multi_tag(
                epname, "neo.epoch", times_da
            )
            mtag_ep.metadata = group.metadata.create_section(
                epname, epname + ".metadata"
            )
            group.multi_tags.append(mtag_ep)
            mtag_ep.definition = cls.rsentence(2)
            mtag_ep.extents = extents_da
            label_dim = mtag_ep.positions.append_set_dimension()
            label_dim.labels = cls.rsentence(5).split(" ")
            # reference all signals in the group
            for siggroup in allsignalgroups:
                mtag_ep.references.extend(siggroup)

        # Events
        for n in range(2):
            evname = "{}-ev{}".format(cls.rword(5), n)
            times = cls.rquant(5, 1, True)
            times_da = blk.create_data_array(
                "{}.times".format(evname),
                "neo.event.times",
                data=times
            )
            times_da.unit = "s"

            mtag_ev = blk.create_multi_tag(
                evname, "neo.event", times_da
            )
            mtag_ev.metadata = group.metadata.create_section(
                evname, evname + ".metadata"
            )
            group.multi_tags.append(mtag_ev)
            mtag_ev.definition = cls.rsentence(2)
            label_dim = mtag_ev.positions.append_set_dimension()
            label_dim.labels = cls.rsentence(5).split(" ")
            # reference all signals in the group
            for siggroup in allsignalgroups:
                mtag_ev.references.extend(siggroup)

        # CHX
        nixchx = blk.create_source(cls.rword(10),
                                   "neo.channelindex")
        nixchx.metadata = nix_blocks[0].metadata.create_section(
            nixchx.name, "neo.channelindex.metadata"
        )
        chantype = "neo.channelindex"
        # 3 channels
        for idx, chan in enumerate([2, 5, 9]):
            channame = "{}.ChannelIndex{}".format(nixchx.name, idx)
            nixrc = nixchx.create_source(channame, chantype)
            nixrc.definition = cls.rsentence(13)
            nixrc.metadata = nixchx.metadata.create_section(
                nixrc.name, "neo.channelindex.metadata"
            )
            nixrc.metadata.create_property("index", chan)
            nixrc.metadata.create_property("channel_id", chan + 1)
            dims = cls.rquant(3, 1)
            coordprop = nixrc.metadata.create_property("coordinates", dims)
            coordprop.unit = "pm"

        nunits = 1
        stsperunit = np.array_split(allspiketrains, nunits)
        for idx in range(nunits):
            unitname = "{}-unit{}".format(cls.rword(5), idx)
            nixunit = nixchx.create_source(unitname, "neo.unit")
            nixunit.metadata = nixchx.metadata.create_section(
                unitname, unitname + ".metadata"
            )
            nixunit.definition = cls.rsentence(4, 10)
            for st in stsperunit[idx]:
                st.sources.append(nixchx)
                st.sources.append(nixunit)

        # pick a few signal groups to reference this CHX.
        # Sample indices rather than the groups themselves: the groups are
        # lists of DataArrays with different lengths (3 vs. 10), and
        # np.random.choice would first try to turn that ragged nesting into
        # an array, which raises ValueError on recent NumPy.
        randidxs = np.random.choice(len(allsignalgroups), 5, replace=False)
        for sgidx in randidxs:
            for sig in allsignalgroups[sgidx]:
                sig.sources.append(nixchx)

        return nixfile

    @staticmethod
    def rdate():
        """Random date between 1980-01-01 and 2019-12-28 (day <= 28)."""
        return datetime(year=np.random.randint(1980, 2020),
                        month=np.random.randint(1, 13),
                        day=np.random.randint(1, 29))

    @classmethod
    def populate_dates(cls, obj):
        """Set random ``file_datetime`` and ``rec_datetime`` on ``obj``."""
        obj.file_datetime = cls.rdate()
        obj.rec_datetime = cls.rdate()

    @staticmethod
    def rword(n=10):
        """Random string of ``n`` ASCII letters."""
        return "".join(np.random.choice(list(string.ascii_letters), n))

    @classmethod
    def rsentence(cls, n=3, maxwl=10):
        """
        ``n`` random words joined by single spaces. Each word length is
        drawn from ``[1, maxwl)``, so ``maxwl`` must be greater than 1.
        """
        return " ".join(cls.rword(np.random.randint(1, maxwl))
                        for _ in range(n))

    @classmethod
    def rdict(cls, nitems):
        """
        Dict of up to ``nitems`` entries with random-word keys; each value
        is either a random word or a random float in [0, 1).
        """
        rd = dict()
        for _ in range(nitems):
            key = cls.rword()
            value = cls.rword() if np.random.choice((0, 1)) \
                else np.random.uniform()
            rd[key] = value
        return rd

    @staticmethod
    def rquant(shape, unit, incr=False):
        """
        Random array of ``shape`` multiplied by ``unit``.

        :param shape: int or shape tuple passed to ``np.random.random``
        :param unit: Multiplier (a quantities unit, or 1 for plain arrays)
        :param incr: If True, values are a cumulative sum and therefore
            increasing; only allowed for one-dimensional shapes
        :raises TypeError: if ``incr`` is True and ``shape`` has more than
            one dimension
        """
        try:
            dim = len(shape)
        except TypeError:
            dim = 1
        if incr and dim > 1:
            raise TypeError("Shape of quantity array may only be "
                            "one-dimensional when incremental values are "
                            "requested.")
        arr = np.random.random(shape)
        if incr:
            arr = np.array(np.cumsum(arr))
        return arr * unit

    @classmethod
    def create_all_annotated(cls):
        """
        Build a Neo Block with one object of each type (Segment,
        AnalogSignal, IrregularlySampledSignal, Epoch, Event, SpikeTrain,
        ChannelIndex, Unit), each carrying random annotations. The
        SpikeTrain additionally has scalar and array Quantity annotations.
        """
        times = cls.rquant(1, pq.s)
        signal = cls.rquant(1, pq.V)
        blk = Block()
        blk.annotate(**cls.rdict(3))
        cls.populate_dates(blk)

        seg = Segment()
        seg.annotate(**cls.rdict(4))
        cls.populate_dates(seg)
        blk.segments.append(seg)

        asig = AnalogSignal(signal=signal, sampling_rate=pq.Hz)
        asig.annotate(**cls.rdict(2))
        seg.analogsignals.append(asig)

        isig = IrregularlySampledSignal(times=times, signal=signal,
                                        time_units=pq.s)
        isig.annotate(**cls.rdict(2))
        seg.irregularlysampledsignals.append(isig)

        epoch = Epoch(times=times, durations=times)
        epoch.annotate(**cls.rdict(4))
        seg.epochs.append(epoch)

        event = Event(times=times)
        event.annotate(**cls.rdict(4))
        seg.events.append(event)

        spiketrain = SpikeTrain(times=times, t_stop=pq.s, units=pq.s)
        d = cls.rdict(6)
        d["quantity"] = pq.Quantity(10, "mV")
        d["qarray"] = pq.Quantity(range(10), "mA")
        spiketrain.annotate(**d)
        seg.spiketrains.append(spiketrain)

        chx = ChannelIndex(name="achx", index=[1, 2], channel_ids=[0, 10])
        chx.annotate(**cls.rdict(5))
        blk.channel_indexes.append(chx)

        unit = Unit()
        unit.annotate(**cls.rdict(2))
        chx.units.append(unit)

        return blk


@unittest.skipUnless(HAVE_NIX, "Requires NIX")
class NixIOWriteTest(NixIOTest):

    def setUp(self):
        # fresh temp dir per test; the NixIO writer and a raw NIX reader
        # are opened on the same file
        self.tempdir = mkdtemp(prefix="nixiotest")
        self.filename = os.path.join(self.tempdir, "testnixio.nix")
        self.writer = NixIO(self.filename, "ow")
        self.io = self.writer
        self.reader = nix.File.open(self.filename,
                                    nix.FileMode.ReadOnly)

    def tearDown(self):
        self.writer.close()
        self.reader.close()
        shutil.rmtree(self.tempdir)

    def write_and_compare(self, blocks, use_obj_names=False):
        # compare written objects, then objects read back through NixIO,
        # then the original objects again after the read
        self.writer.write_all_blocks(blocks, use_obj_names)
        self.compare_blocks(blocks, self.reader.blocks)
        self.compare_blocks(self.writer.read_all_blocks(), self.reader.blocks)
        self.compare_blocks(blocks, self.reader.blocks)

    def test_block_write(self):
        block = Block(name=self.rword(),
                      description=self.rsentence())
        self.write_and_compare([block])

        block.annotate(**self.rdict(5))
        self.write_and_compare([block])

    def test_segment_write(self):
        block = Block(name=self.rword())
        segment = Segment(name=self.rword(), description=self.rword())
        block.segments.append(segment)
        self.write_and_compare([block])

        segment.annotate(**self.rdict(2))
        self.write_and_compare([block])

    def test_channel_index_write(self):
        block = Block(name=self.rword())
        chx = ChannelIndex(name=self.rword(),
                           description=self.rsentence(),
                           channel_ids=[10, 20, 30, 50, 80, 130],
                           index=[1, 2, 3, 5, 8, 13])
        block.channel_indexes.append(chx)
        self.write_and_compare([block])

        chx.annotate(**self.rdict(3))
self.write_and_compare([block]) chx.channel_names = ["one", "two", "three", "five", "eight", "xiii"] chx.coordinates = self.rquant((6, 3), pq.um) self.write_and_compare([block]) # add an empty channel index and check again newchx = ChannelIndex(np.array([])) block.channel_indexes.append(newchx) self.write_and_compare([block]) def test_signals_write(self): block = Block() seg = Segment() block.segments.append(seg) asig = AnalogSignal(signal=self.rquant((19, 15), pq.mV), sampling_rate=pq.Quantity(10, "Hz")) seg.analogsignals.append(asig) self.write_and_compare([block]) anotherblock = Block("ir signal block") seg = Segment("ir signal seg") anotherblock.segments.append(seg) irsig = IrregularlySampledSignal( signal=np.random.random((20, 30)), times=self.rquant(20, pq.ms, True), units=pq.A ) seg.irregularlysampledsignals.append(irsig) self.write_and_compare([block, anotherblock]) block.segments[0].analogsignals.append( AnalogSignal(signal=[10.0, 1.0, 3.0], units=pq.S, sampling_period=pq.Quantity(3, "s"), dtype=np.double, name="signal42", description="this is an analogsignal", t_start=45 * pq.ms), ) self.write_and_compare([block, anotherblock]) block.segments[0].irregularlysampledsignals.append( IrregularlySampledSignal(times=np.random.random(10), signal=np.random.random((10, 13)), units="mV", time_units="s", dtype=np.float, name="some sort of signal", description="the signal is described") ) self.write_and_compare([block, anotherblock]) def test_signals_compound_units(self): block = Block() seg = Segment() block.segments.append(seg) units = pq.CompoundUnit("1/30000*V") srate = pq.Quantity(10, pq.CompoundUnit("1.0/10 * Hz")) asig = AnalogSignal(signal=self.rquant((10, 23), units), sampling_rate=srate) seg.analogsignals.append(asig) self.write_and_compare([block]) anotherblock = Block("ir signal block") seg = Segment("ir signal seg") anotherblock.segments.append(seg) irsig = IrregularlySampledSignal( signal=np.random.random((20, 3)), times=self.rquant(20, 
pq.CompoundUnit("0.1 * ms"), True), units=pq.CompoundUnit("10 * V / s") ) seg.irregularlysampledsignals.append(irsig) self.write_and_compare([block, anotherblock]) block.segments[0].analogsignals.append( AnalogSignal(signal=[10.0, 1.0, 3.0], units=pq.S, sampling_period=pq.Quantity(3, "s"), dtype=np.double, name="signal42", description="this is an analogsignal", t_start=45 * pq.CompoundUnit("3.14 * s")), ) self.write_and_compare([block, anotherblock]) times = self.rquant(10, pq.CompoundUnit("3 * year"), True) block.segments[0].irregularlysampledsignals.append( IrregularlySampledSignal(times=times, signal=np.random.random((10, 3)), units="mV", dtype=np.float, name="some sort of signal", description="the signal is described") ) self.write_and_compare([block, anotherblock]) def test_epoch_write(self): block = Block() seg = Segment() block.segments.append(seg) epoch = Epoch(times=[1, 1, 10, 3] * pq.ms, durations=[3, 3, 3, 1] * pq.ms, labels=np.array(["one", "two", "three", "four"]), name="test epoch", description="an epoch for testing") seg.epochs.append(epoch) self.write_and_compare([block]) def test_event_write(self): block = Block() seg = Segment() block.segments.append(seg) event = Event(times=np.arange(0, 30, 10) * pq.s, labels=np.array(["0", "1", "2"]), name="event name", description="event description") seg.events.append(event) self.write_and_compare([block]) def test_spiketrain_write(self): block = Block() seg = Segment() block.segments.append(seg) spiketrain = SpikeTrain(times=[3, 4, 5] * pq.s, t_stop=10.0, name="spikes!", description="sssssspikes") seg.spiketrains.append(spiketrain) self.write_and_compare([block]) waveforms = self.rquant((3, 5, 10), pq.mV) spiketrain = SpikeTrain(times=[1, 1.1, 1.2] * pq.ms, t_stop=1.5 * pq.s, name="spikes with wf", description="spikes for waveform test", waveforms=waveforms) seg.spiketrains.append(spiketrain) self.write_and_compare([block]) spiketrain.left_sweep = np.random.random(10) * pq.ms self.write_and_compare([block]) 
# (continues NixIOWriteTest.test_spiketrain_write)
        spiketrain.left_sweep = pq.Quantity(-10, "ms")
        self.write_and_compare([block])

    def test_metadata_structure_write(self):
        """
        Write an annotated block and check the metadata section nesting:
        block section -> segment section -> signal / multi-tag sections, and
        block section -> channel-index section.
        """
        neoblk = self.create_all_annotated()
        self.io.write_block(neoblk)
        blk = self.io.nix_file.blocks[0]

        blkmd = blk.metadata
        self.assertEqual(blk.name, blkmd.name)

        grp = blk.groups[0]  # segment
        self.assertIn(grp.name, blkmd.sections)

        grpmd = blkmd.sections[grp.name]
        for da in grp.data_arrays:  # signals
            # strip the per-channel ".<i>" suffix of the DataArray name
            name = ".".join(da.name.split(".")[:-1])
            self.assertIn(name, grpmd.sections)
        for mtag in grp.multi_tags:  # spiketrains, events, and epochs
            self.assertIn(mtag.name, grpmd.sections)

        srcchx = blk.sources[0]  # chx
        self.assertIn(srcchx.name, blkmd.sections)

        # NOTE(review): this iterates the block's top-level sources (the
        # channel indexes), not the unit Sources nested under the CHX, so
        # unit metadata is not actually checked here -- confirm intent.
        for srcunit in blk.sources:  # units
            self.assertIn(srcunit.name, blkmd.sections)

        self.write_and_compare([neoblk])

    def test_anonymous_objects_write(self):
        """
        Write blocks whose objects have no names. Writing with generated
        names must succeed; writing them with use_obj_names=True must raise
        ValueError.
        """
        nblocks = 2
        nsegs = 2
        nanasig = 4
        nirrseg = 2
        nepochs = 3
        nevents = 4
        nspiketrains = 3
        nchx = 5
        nunits = 10

        times = self.rquant(1, pq.s)
        signal = self.rquant(1, pq.V)
        blocks = []
        for blkidx in range(nblocks):
            blk = Block()
            blocks.append(blk)
            for segidx in range(nsegs):
                seg = Segment()
                blk.segments.append(seg)
                for anaidx in range(nanasig):
                    seg.analogsignals.append(AnalogSignal(signal=signal,
                                                          sampling_rate=pq.Hz))
                for irridx in range(nirrseg):
                    seg.irregularlysampledsignals.append(
                        IrregularlySampledSignal(times=times,
                                                 signal=signal,
                                                 time_units=pq.s)
                    )
                for epidx in range(nepochs):
                    seg.epochs.append(Epoch(times=times, durations=times))
                for evidx in range(nevents):
                    seg.events.append(Event(times=times))
                for stidx in range(nspiketrains):
                    seg.spiketrains.append(SpikeTrain(times=times,
                                                      t_stop=times[-1] + pq.s,
                                                      units=pq.s))
            for chidx in range(nchx):
                chx = ChannelIndex(index=[1, 2],
                                   channel_ids=[11, 22])
                blk.channel_indexes.append(chx)
                for unidx in range(nunits):
                    unit = Unit()
                    chx.units.append(unit)

        self.writer.write_all_blocks(blocks)
        self.compare_blocks(blocks, self.reader.blocks)

        # anonymous objects cannot be written using their own names
        with self.assertRaises(ValueError):
            self.writer.write_all_blocks(blocks,
                                         use_obj_names=True)

    def test_name_objects_write(self):
        """
        Write fully named objects with use_obj_names=True. When mock is
        available, the NIX name generator is wrapped and must never be
        called. Duplicate names (blocks, then spiketrains) must raise
        ValueError.
        """
        nblocks = 2
        nsegs = 2
        nanasig = 4
        nirrseg = 2
        nepochs = 3
        nevents = 4
        nspiketrains = 3
        nchx = 5
        nunits = 10

        times = self.rquant(1, pq.s)
        signal = self.rquant(1, pq.V)
        blocks = []
        for blkidx in range(nblocks):
            blk = Block(name="block{}".format(blkidx))
            blocks.append(blk)
            for segidx in range(nsegs):
                seg = Segment(name="seg{}".format(segidx))
                blk.segments.append(seg)
                for anaidx in range(nanasig):
                    asig = AnalogSignal(
                        name="{}:as{}".format(seg.name, anaidx),
                        signal=signal, sampling_rate=pq.Hz
                    )
                    seg.analogsignals.append(asig)
                for irridx in range(nirrseg):
                    isig = IrregularlySampledSignal(
                        name="{}:is{}".format(seg.name, irridx),
                        times=times,
                        signal=signal,
                        time_units=pq.s
                    )
                    seg.irregularlysampledsignals.append(isig)
                for epidx in range(nepochs):
                    seg.epochs.append(
                        Epoch(name="{}:ep{}".format(seg.name, epidx),
                              times=times, durations=times)
                    )
                for evidx in range(nevents):
                    seg.events.append(
                        Event(name="{}:ev{}".format(seg.name, evidx),
                              times=times)
                    )
                for stidx in range(nspiketrains):
                    seg.spiketrains.append(
                        SpikeTrain(name="{}:st{}".format(seg.name, stidx),
                                   times=times,
                                   t_stop=times[-1] + pq.s,
                                   units=pq.s)
                    )
            for chidx in range(nchx):
                chx = ChannelIndex(name="chx{}".format(chidx),
                                   index=[1, 2],
                                   channel_ids=[11, 22])
                blk.channel_indexes.append(chx)
                for unidx in range(nunits):
                    unit = Unit(name="unit{}".format(unidx))
                    chx.units.append(unit)

        # put guard on _generate_nix_name
        if not SKIPMOCK:
            nixgenmock = mock.Mock(name="_generate_nix_name",
                                   wraps=self.io._generate_nix_name)
            self.io._generate_nix_name = nixgenmock
        self.writer.write_block(blocks[0], use_obj_names=True)
        self.compare_blocks([blocks[0]], self.reader.blocks)
        self.compare_blocks(self.writer.read_all_blocks(), self.reader.blocks)
        # only blocks[0] was written; compare_blocks pairs with zip, so the
        # unwritten second block is not compared here
        self.compare_blocks(blocks, self.reader.blocks)
        if not SKIPMOCK:
            nixgenmock.assert_not_called()

        self.write_and_compare(blocks, use_obj_names=True)
        if not SKIPMOCK:
            nixgenmock.assert_not_called()

        self.assertEqual(self.reader.blocks[0].name,
                         "block0")

        blocks[0].name = blocks[1].name  # name conflict
        with self.assertRaises(ValueError):
            self.writer.write_all_blocks(blocks, use_obj_names=True)

        blocks[0].name = "new name"
        self.assertEqual(blocks[0].segments[1].spiketrains[1].name,
                         "seg1:st1")
        st0 = blocks[0].segments[0].spiketrains[0].name
        blocks[0].segments[0].spiketrains[1].name = st0  # name conflict
        with self.assertRaises(ValueError):
            self.writer.write_all_blocks(blocks, use_obj_names=True)
        with self.assertRaises(ValueError):
            self.writer.write_block(blocks[0], use_obj_names=True)
        if not SKIPMOCK:
            nixgenmock.assert_not_called()

    def test_name_conflicts(self):
        """
        Every scenario below must raise ValueError when written with
        use_obj_names=True. The scenarios cover anonymous blocks and
        duplicate names among blocks, segments, signals, event/spiketrain,
        channel indexes and units. An anonymous spiketrain also raises.
        """
        # anon block
        blk = Block()
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

        # two anon blocks
        blocks = [Block(), Block()]
        with self.assertRaises(ValueError):
            self.io.write_all_blocks(blocks, use_obj_names=True)

        # same name blocks
        blocks = [Block(name="one"), Block(name="one")]
        with self.assertRaises(ValueError):
            self.io.write_all_blocks(blocks, use_obj_names=True)

        # one block, two same name segments
        blk = Block("new")
        seg = Segment("I am the segment", a="a annoation")
        blk.segments.append(seg)
        seg = Segment("I am the segment", a="b annotation")
        blk.segments.append(seg)
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

        times = self.rquant(1, pq.s)
        signal = self.rquant(1, pq.V)
        # name conflict: analog + irregular signals
        # (reuses the second "I am the segment" Segment from above)
        seg.analogsignals.append(
            AnalogSignal(name="signal", signal=signal, sampling_rate=pq.Hz)
        )
        seg.irregularlysampledsignals.append(
            IrregularlySampledSignal(name="signal", signal=signal, times=times)
        )
        blk = Block(name="Signal conflict Block")
        blk.segments.append(seg)
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

        # name conflict: event + spiketrain
        blk = Block(name="Event+SpikeTrain conflict Block")
        seg = Segment(name="Event+SpikeTrain conflict Segment")
        blk.segments.append(seg)
        seg.events.append(Event(name="TimeyStuff", times=times))
        seg.spiketrains.append(SpikeTrain(name="TimeyStuff", times=times,
                                          t_stop=pq.s))
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

        # make spiketrain anon
        blk.segments[0].spiketrains[0].name = None
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

        # name conflict in channel indexes
        blk = Block(name="ChannelIndex conflict Block")
        blk.channel_indexes.append(ChannelIndex(name="chax", index=[1]))
        blk.channel_indexes.append(ChannelIndex(name="chax", index=[2]))
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

        # name conflict in units
        blk = Block(name="unitconf")
        chx = ChannelIndex(name="ok", index=[100])
        blk.channel_indexes.append(chx)
        chx.units.append(Unit(name="IHAVEATWIN"))
        chx.units.append(Unit(name="IHAVEATWIN"))
        with self.assertRaises(ValueError):
            self.io.write_block(blk, use_obj_names=True)

    def test_multiref_write(self):
        """
        Write a block where the same signal, event, epoch and spiketrain
        objects are attached to three segments, and one spiketrain is
        shared by six units of a channel index.
        """
        blk = Block("blk1")
        signal = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
                              sampling_period=pq.Quantity(1, "ms"))
        othersignal = IrregularlySampledSignal(name="i1", signal=[0, 0, 0],
                                               units="mV", times=[1, 2, 3],
                                               time_units="ms")
        event = Event(name="Evee", times=[0.3, 0.42], units="year")
        epoch = Epoch(name="epoche", times=[0.1, 0.2] * pq.min,
                      durations=[0.5, 0.5] * pq.min)
        st = SpikeTrain(name="the train of spikes", times=[0.1, 0.2, 10.3],
                        t_stop=11, units="us")

        for idx in range(3):
            segname = "seg" + str(idx)
            seg = Segment(segname)
            blk.segments.append(seg)
            # same objects appended to every segment
            seg.analogsignals.append(signal)
            seg.irregularlysampledsignals.append(othersignal)
            seg.events.append(event)
            seg.epochs.append(epoch)
            seg.spiketrains.append(st)

        chidx = ChannelIndex([10, 20, 29])
        seg = blk.segments[0]
        st = SpikeTrain(name="choochoo", times=[10, 11, 80], t_stop=1000,
                        units="s")
        seg.spiketrains.append(st)
        blk.channel_indexes.append(chidx)
        for idx in range(6):
            unit = Unit("unit" + str(idx))
            chidx.units.append(unit)
            # one spiketrain referenced by every unit
            unit.spiketrains.append(st)

        self.writer.write_block(blk)
self.compare_blocks([blk], self.reader.blocks) def test_no_segment_write(self): # Tests storing AnalogSignal, IrregularlySampledSignal, and SpikeTrain # objects in the secondary (ChannelIndex) substructure without them # being attached to a Segment. blk = Block("segmentless block") signal = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV", sampling_period=pq.Quantity(1, "ms")) othersignal = IrregularlySampledSignal(name="i1", signal=[0, 0, 0], units="mV", times=[1, 2, 3], time_units="ms") sta = SpikeTrain(name="the train of spikes", times=[0.1, 0.2, 10.3], t_stop=11, units="us") stb = SpikeTrain(name="the train of spikes b", times=[1.1, 2.2, 10.1], t_stop=100, units="ms") chidx = ChannelIndex([8, 13, 21]) blk.channel_indexes.append(chidx) chidx.analogsignals.append(signal) chidx.irregularlysampledsignals.append(othersignal) unit = Unit() chidx.units.append(unit) unit.spiketrains.extend([sta, stb]) self.writer.write_block(blk) self.writer.close() self.compare_blocks([blk], self.reader.blocks) reader = NixIO(self.filename, "ro") blk = reader.read_block(neoname="segmentless block") chx = blk.channel_indexes[0] self.assertEqual(len(chx.analogsignals), 1) self.assertEqual(len(chx.irregularlysampledsignals), 1) self.assertEqual(len(chx.units[0].spiketrains), 2) def test_rewrite_refs(self): def checksignalcounts(fname): with NixIO(fname, "ro") as r: blk = r.read_block() chidx = blk.channel_indexes[0] seg = blk.segments[0] self.assertEqual(len(chidx.analogsignals), 2) self.assertEqual(len(chidx.units[0].spiketrains), 3) self.assertEqual(len(seg.analogsignals), 1) self.assertEqual(len(seg.spiketrains), 1) blk = Block() # ChannelIndex chidx = ChannelIndex(index=[1]) blk.channel_indexes.append(chidx) # Two signals on ChannelIndex for idx in range(2): asigchx = AnalogSignal(signal=[idx], units="mV", sampling_rate=pq.Hz) chidx.analogsignals.append(asigchx) # Unit unit = Unit() chidx.units.append(unit) # Three SpikeTrains on Unit for idx in range(3): st = SpikeTrain([idx], 
units="ms", t_stop=40) unit.spiketrains.append(st) # Segment seg = Segment() blk.segments.append(seg) # One signal on Segment asigseg = AnalogSignal(signal=[2], units="uA", sampling_rate=pq.Hz) seg.analogsignals.append(asigseg) # One spiketrain on Segment stseg = SpikeTrain([10], units="ms", t_stop=40) seg.spiketrains.append(stseg) # Write, compare, and check counts self.writer.write_block(blk) self.compare_blocks([blk], self.reader.blocks) self.assertEqual(len(chidx.analogsignals), 2) self.assertEqual(len(seg.analogsignals), 1) self.assertEqual(len(chidx.analogsignals), 2) self.assertEqual(len(chidx.units[0].spiketrains), 3) self.assertEqual(len(seg.analogsignals), 1) self.assertEqual(len(seg.spiketrains), 1) # Check counts with separate reader checksignalcounts(self.filename) # Write again and check counts secondwrite = os.path.join(self.tempdir, "testnixio-2.nix") with NixIO(secondwrite, "ow") as w: w.write_block(blk) self.compare_blocks([blk], self.reader.blocks) # Read back and check counts scndreader = nix.File.open(secondwrite, mode=nix.FileMode.ReadOnly) self.compare_blocks([blk], scndreader.blocks) checksignalcounts(secondwrite) def test_to_value(self): section = self.io.nix_file.create_section("Metadata value test", "Test") writeprop = self.io._write_property # quantity qvalue = pq.Quantity(10, "mV") writeprop(section, "qvalue", qvalue) self.assertEqual(section["qvalue"], 10) self.assertEqual(section.props["qvalue"].unit, "mV") # datetime dt = self.rdate() writeprop(section, "dt", dt) self.assertEqual(datetime.fromtimestamp(section["dt"]), dt) # string randstr = self.rsentence() writeprop(section, "randstr", randstr) self.assertEqual(section["randstr"], randstr) # bytes bytestring = b"bytestring" writeprop(section, "randbytes", bytestring) self.assertEqual(section["randbytes"], bytestring.decode()) # iterables randlist = np.random.random(10).tolist() writeprop(section, "randlist", randlist) self.assertEqual(randlist, section["randlist"]) randarray = 
np.random.random(10) writeprop(section, "randarray", randarray) np.testing.assert_almost_equal(randarray, section["randarray"]) # numpy item npval = np.float64(2398) writeprop(section, "npval", npval) self.assertEqual(npval, section["npval"]) # number val = 42 writeprop(section, "val", val) self.assertEqual(val, section["val"]) # empty string (gets stored as empty list) writeprop(section, "emptystring", "") self.assertEqual(list(), section["emptystring"]) def test_annotations_special_cases(self): # Special cases for annotations: empty list, list of strings, # multidimensional lists/arrays # These are handled differently on read, so we test them on a block # instead of just checking the property writer method # empty value # empty list wblock = Block("block with empty list", an_empty_list=list()) self.writer.write_block(wblock) rblock = self.writer.read_block(neoname="block with empty list") self.assertEqual(rblock.annotations["an_empty_list"], list()) # empty tuple (gets read out as list) wblock = Block("block with empty tuple", an_empty_tuple=tuple()) self.writer.write_block(wblock) rblock = self.writer.read_block(neoname="block with empty tuple") self.assertEqual(rblock.annotations["an_empty_tuple"], list()) # list of strings losval = ["one", "two", "one million"] wblock = Block("block with list of strings", los=losval) self.writer.write_block(wblock) rblock = self.writer.read_block(neoname="block with list of strings") self.assertEqual(rblock.annotations["los"], losval) # TODO: multi dimensional value (GH Issue #501) @unittest.skipUnless(HAVE_NIX, "Requires NIX") class NixIOReadTest(NixIOTest): nixfile = None nix_blocks = None @classmethod def setUpClass(cls): cls.tempdir = mkdtemp(prefix="nixiotest") cls.filename = os.path.join(cls.tempdir, "testnixio.nix") if HAVE_NIX: cls.nixfile = cls.create_full_nix_file(cls.filename) def setUp(self): self.io = NixIO(self.filename, "ro") @classmethod def tearDownClass(cls): if HAVE_NIX: cls.nixfile.close() 
shutil.rmtree(cls.tempdir) def tearDown(self): self.io.close() def test_all_read(self): neo_blocks = self.io.read_all_blocks() nix_blocks = self.io.nix_file.blocks self.compare_blocks(neo_blocks, nix_blocks) def test_iter_read(self): blocknames = [blk.name for blk in self.nixfile.blocks] for blk, nixname in zip(self.io.iter_blocks(), blocknames): self.assertEqual(blk.annotations["nix_name"], nixname) def test_nix_name_read(self): for nixblock in self.nixfile.blocks: nixname = nixblock.name neoblock = self.io.read_block(nixname=nixname) self.assertEqual(neoblock.annotations["nix_name"], nixname) def test_index_read(self): for idx, nixblock in enumerate(self.nixfile.blocks): neoblock = self.io.read_block(index=idx) self.assertEqual(neoblock.annotations["nix_name"], nixblock.name) self.assertEqual(neoblock.annotations["nix_name"], self.nixfile.blocks[idx].name) def test_auto_index_read(self): for nixblock in self.nixfile.blocks: neoblock = self.io.read_block() # don't specify index self.assertEqual(neoblock.annotations["nix_name"], nixblock.name) # No more blocks - should return None self.assertIs(self.io.read_block(), None) self.assertIs(self.io.read_block(), None) self.assertIs(self.io.read_block(), None) with NixIO(self.filename, "ro") as nf: neoblock = nf.read_block(index=1) self.assertEqual(self.nixfile.blocks[1].name, neoblock.annotations["nix_name"]) neoblock = nf.read_block() # should start again from 0 self.assertEqual(self.nixfile.blocks[0].name, neoblock.annotations["nix_name"]) def test_neo_name_read(self): for nixblock in self.nixfile.blocks: neoname = nixblock.metadata["neo_name"] neoblock = self.io.read_block(neoname=neoname) self.assertEqual(neoblock.annotations["nix_name"], nixblock.name) @unittest.skipUnless(HAVE_NIX, "Requires NIX") class NixIOContextTests(NixIOTest): def setUp(self): self.tempdir = mkdtemp(prefix="nixiotest") self.filename = os.path.join(self.tempdir, "testnixio.nix") def tearDown(self): shutil.rmtree(self.tempdir) def 
test_context_write(self): neoblock = Block(name=self.rword(), description=self.rsentence()) with NixIO(self.filename, "ow") as iofile: iofile.write_block(neoblock) nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly) self.compare_blocks([neoblock], nixfile.blocks) nixfile.close() neoblock.annotate(**self.rdict(5)) with NixIO(self.filename, "rw") as iofile: iofile.write_block(neoblock) nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly) self.compare_blocks([neoblock], nixfile.blocks) nixfile.close() def test_context_read(self): nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite) name_one = self.rword() name_two = self.rword() nixfile.create_block(name_one, "neo.block") nixfile.create_block(name_two, "neo.block") nixfile.close() with NixIO(self.filename, "ro") as iofile: blocks = iofile.read_all_blocks() self.assertEqual(blocks[0].annotations["nix_name"], name_one) self.assertEqual(blocks[1].annotations["nix_name"], name_two) @unittest.skipUnless(HAVE_NIX, "Requires NIX") class NixIOVerTests(NixIOTest): def setUp(self): self.tempdir = mkdtemp(prefix="nixiotest") self.filename = os.path.join(self.tempdir, "testnixio.nix") def tearDown(self): shutil.rmtree(self.tempdir) def test_new_file(self): with NixIO(self.filename, "ow") as iofile: self.assertEqual(iofile._file_version, neover) nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly) filever = nixfile.sections["neo"]["version"] self.assertEqual(filever, neover) nixfile.close() def test_oldfile_nover(self): nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite) nixfile.close() with NixIO(self.filename, "ro") as iofile: self.assertEqual(iofile._file_version, '0.5.2') # compat version nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly) self.assertNotIn("neo", nixfile.sections) nixfile.close() with NixIO(self.filename, "rw") as iofile: self.assertEqual(iofile._file_version, '0.5.2') # compat version # section should have been created now nixfile = 
nix.File.open(self.filename, nix.FileMode.ReadOnly) self.assertIn("neo", nixfile.sections) self.assertEqual(nixfile.sections["neo"]["version"], '0.5.2') nixfile.close() def test_file_with_ver(self): someversion = '0.100.10' nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite) filemd = nixfile.create_section("neo", "neo.metadata") filemd["version"] = someversion nixfile.close() with NixIO(self.filename, "ro") as iofile: self.assertEqual(iofile._file_version, someversion) with NixIO(self.filename, "rw") as iofile: self.assertEqual(iofile._file_version, someversion) with NixIO(self.filename, "ow") as iofile: self.assertEqual(iofile._file_version, neover) @unittest.skipUnless(HAVE_NIX, "Requires NIX") class CommonTests(BaseTestIO, unittest.TestCase): ioclass = NixIO read_and_write_is_bijective = False if __name__ == "__main__": unittest.main()