- # -*- coding: utf-8 -*-
- # Copyright (c) 2016, German Neuroinformatics Node (G-Node)
- # Achilleas Koutsou <achilleas.k@gmail.com>
- #
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted under the terms of the BSD License. See
- # LICENSE file in the root of the Project.
- """
- Tests for neo.io.nixio
- """
- import os
- from datetime import datetime
- import unittest
- try:
- from unittest import mock
- except ImportError:
- import mock
- import string
- import numpy as np
- import quantities as pq
- from neo.core import (Block, Segment, ChannelIndex, AnalogSignal,
- IrregularlySampledSignal, Unit, SpikeTrain, Event, Epoch)
- from neo.test.iotest.common_io_test import BaseTestIO
- try:
- import nixio as nix
- HAVE_NIX = True
- except ImportError:
- HAVE_NIX = False
- from neo.io.nixio import NixIO
- from neo.io.nixio import string_types
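- # Note: the assertions in these tests assume the Neo-to-NIX mapping that
- # NixIO uses when writing: Block -> nix Block, Segment -> Group,
- # ChannelIndex and Unit -> Source, AnalogSignal/IrregularlySampledSignal ->
- # one DataArray per channel, Epoch/Event/SpikeTrain -> MultiTag, and Neo
- # annotations -> NIX metadata sections and properties.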
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class NixIOTest(unittest.TestCase):
- filename = None
- io = None
- def compare_blocks(self, neoblocks, nixblocks):
- for neoblock, nixblock in zip(neoblocks, nixblocks):
- self.compare_attr(neoblock, nixblock)
- self.assertEqual(len(neoblock.segments), len(nixblock.groups))
- for idx, neoseg in enumerate(neoblock.segments):
- nixgrp = nixblock.groups[neoseg.annotations["nix_name"]]
- self.compare_segment_group(neoseg, nixgrp)
- for idx, neochx in enumerate(neoblock.channel_indexes):
- nixsrc = nixblock.sources[neochx.annotations["nix_name"]]
- self.compare_chx_source(neochx, nixsrc)
- self.check_refs(neoblock, nixblock)
- def compare_chx_source(self, neochx, nixsrc):
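- """
- Compare a Neo ChannelIndex with its NIX Source: channel indexes, names
- and ids are checked against the child Sources of type
- "neo.channelindex", and each Unit against a child Source of type
- "neo.unit".
- """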
- self.compare_attr(neochx, nixsrc)
- nix_channels = list(src for src in nixsrc.sources
- if src.type == "neo.channelindex")
- self.assertEqual(len(neochx.index), len(nix_channels))
- if len(neochx.channel_ids):
- nix_chanids = list(src.metadata["channel_id"] for src
- in nixsrc.sources
- if src.type == "neo.channelindex")
- self.assertEqual(len(neochx.channel_ids), len(nix_chanids))
- for nixchan in nix_channels:
- nixchanidx = nixchan.metadata["index"]
- try:
- neochanpos = list(neochx.index).index(nixchanidx)
- except ValueError:
- self.fail("Channel indexes do not match.")
- if len(neochx.channel_names):
- neochanname = neochx.channel_names[neochanpos]
- if ((not isinstance(neochanname, str)) and
- isinstance(neochanname, bytes)):
- neochanname = neochanname.decode()
- nixchanname = nixchan.metadata["neo_name"]
- self.assertEqual(neochanname, nixchanname)
- if len(neochx.channel_ids):
- neochanid = neochx.channel_ids[neochanpos]
- nixchanid = nixchan.metadata["channel_id"]
- self.assertEqual(neochanid, nixchanid)
- elif "channel_id" in nixchan.metadata:
- self.fail("Channel ID not loaded")
- nix_units = list(src for src in nixsrc.sources
- if src.type == "neo.unit")
- self.assertEqual(len(neochx.units), len(nix_units))
- for neounit in neochx.units:
- nixunit = nixsrc.sources[neounit.annotations["nix_name"]]
- self.compare_attr(neounit, nixunit)
- def check_refs(self, neoblock, nixblock):
- """
- Checks whether the references between objects that are not nested are
- mapped correctly (e.g., SpikeTrains referenced by a Unit).
- :param neoblock: A Neo block
- :param nixblock: The corresponding NIX block
- """
- for idx, neochx in enumerate(neoblock.channel_indexes):
- nixchx = nixblock.sources[neochx.annotations["nix_name"]]
- # AnalogSignals referencing CHX
- neoasigs = list(sig.annotations["nix_name"]
- for sig in neochx.analogsignals)
- nixasigs = list(set(da.metadata.name for da in nixblock.data_arrays
- if da.type == "neo.analogsignal" and
- nixchx in da.sources))
- self.assertEqual(len(neoasigs), len(nixasigs),
- neochx.analogsignals)
- # IrregularlySampledSignals referencing CHX
- neoisigs = list(sig.annotations["nix_name"] for sig in
- neochx.irregularlysampledsignals)
- nixisigs = list(
- set(da.metadata.name for da in nixblock.data_arrays
- if da.type == "neo.irregularlysampledsignal" and
- nixchx in da.sources)
- )
- self.assertEqual(len(neoisigs), len(nixisigs))
- # SpikeTrains referencing CHX and Units
- for sidx, neounit in enumerate(neochx.units):
- nixunit = nixchx.sources[neounit.annotations["nix_name"]]
- neosts = list(st.annotations["nix_name"]
- for st in neounit.spiketrains)
- nixsts = list(mt for mt in nixblock.multi_tags
- if mt.type == "neo.spiketrain" and
- nixunit.name in mt.sources)
- # SpikeTrains must also reference CHX
- for nixst in nixsts:
- self.assertIn(nixchx.name, nixst.sources)
- nixsts = list(st.name for st in nixsts)
- self.assertEqual(len(neosts), len(nixsts))
- for neoname in neosts:
- if neoname:
- self.assertIn(neoname, nixsts)
- # Events and Epochs must reference all Signals in the Group (NIX only)
- for nixgroup in nixblock.groups:
- nixevep = list(mt for mt in nixgroup.multi_tags
- if mt.type in ["neo.event", "neo.epoch"])
- nixsigs = list(da.name for da in nixgroup.data_arrays
- if da.type in ["neo.analogsignal",
- "neo.irregularlysampledsignal"])
- for nee in nixevep:
- for ns in nixsigs:
- self.assertIn(ns, nee.references)
- def compare_segment_group(self, neoseg, nixgroup):
- self.compare_attr(neoseg, nixgroup)
- neo_signals = neoseg.analogsignals + neoseg.irregularlysampledsignals
- self.compare_signals_das(neo_signals, nixgroup.data_arrays)
- neo_eests = neoseg.epochs + neoseg.events + neoseg.spiketrains
- self.compare_eests_mtags(neo_eests, nixgroup.multi_tags)
- def compare_signals_das(self, neosignals, data_arrays):
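- """
- For each Neo signal, collect the DataArrays whose metadata name matches
- the signal's "nix_name" annotation (one DataArray per channel) and
- compare them against the signal.
- """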
- for sig in neosignals:
- if self.io._find_lazy_loaded(sig) is not None:
- sig = self.io.load_lazy_object(sig)
- dalist = list()
- nixname = sig.annotations["nix_name"]
- for da in data_arrays:
- if da.metadata.name == nixname:
- dalist.append(da)
- _, nsig = np.shape(sig)
- self.assertEqual(nsig, len(dalist))
- self.compare_signal_dalist(sig, dalist)
- def compare_signal_dalist(self, neosig, nixdalist):
- """
- Check that a Neo AnalogSignal or IrregularlySampledSignal matches a list
- of NIX DataArrays (one DataArray per signal channel).
- :param neosig: Neo AnalogSignal or IrregularlySampledSignal
- :param nixdalist: List of NIX DataArrays
- """
- nixmd = nixdalist[0].metadata
- self.assertTrue(all(nixmd == da.metadata for da in nixdalist))
- neounit = str(neosig.dimensionality)
- for sig, da in zip(np.transpose(neosig),
- sorted(nixdalist, key=lambda d: d.name)):
- self.compare_attr(neosig, da)
- np.testing.assert_almost_equal(sig.magnitude, da)
- self.assertEqual(neounit, da.unit)
- timedim = da.dimensions[0]
- if isinstance(neosig, AnalogSignal):
- self.assertEqual(timedim.dimension_type,
- nix.DimensionType.Sample)
- self.assertEqual(
- pq.Quantity(timedim.sampling_interval, timedim.unit),
- neosig.sampling_period
- )
- self.assertEqual(timedim.offset, neosig.t_start.magnitude)
- if "t_start.units" in da.metadata.props:
- self.assertEqual(da.metadata["t_start.units"],
- str(neosig.t_start.dimensionality))
- elif isinstance(neosig, IrregularlySampledSignal):
- self.assertEqual(timedim.dimension_type,
- nix.DimensionType.Range)
- np.testing.assert_almost_equal(neosig.times.magnitude,
- timedim.ticks)
- self.assertEqual(timedim.unit,
- str(neosig.times.dimensionality))
- def compare_eests_mtags(self, eestlist, mtaglist):
- self.assertEqual(len(eestlist), len(mtaglist))
- for eest in eestlist:
- if self.io._find_lazy_loaded(eest) is not None:
- eest = self.io.load_lazy_object(eest)
- mtag = mtaglist[eest.annotations["nix_name"]]
- if isinstance(eest, Epoch):
- self.compare_epoch_mtag(eest, mtag)
- elif isinstance(eest, Event):
- self.compare_event_mtag(eest, mtag)
- elif isinstance(eest, SpikeTrain):
- self.compare_spiketrain_mtag(eest, mtag)
- def compare_epoch_mtag(self, epoch, mtag):
- self.assertEqual(mtag.type, "neo.epoch")
- self.compare_attr(epoch, mtag)
- np.testing.assert_almost_equal(epoch.times.magnitude, mtag.positions)
- np.testing.assert_almost_equal(epoch.durations.magnitude, mtag.extents)
- self.assertEqual(mtag.positions.unit,
- str(epoch.times.units.dimensionality))
- self.assertEqual(mtag.extents.unit,
- str(epoch.durations.units.dimensionality))
- for neol, nixl in zip(epoch.labels,
- mtag.positions.dimensions[0].labels):
- # Dirty. Should find the root cause instead
- if isinstance(neol, bytes):
- neol = neol.decode()
- if isinstance(nixl, bytes):
- nixl = nixl.decode()
- self.assertEqual(neol, nixl)
- def compare_event_mtag(self, event, mtag):
- self.assertEqual(mtag.type, "neo.event")
- self.compare_attr(event, mtag)
- np.testing.assert_almost_equal(event.times.magnitude, mtag.positions)
- self.assertEqual(mtag.positions.unit, str(event.units.dimensionality))
- for neol, nixl in zip(event.labels,
- mtag.positions.dimensions[0].labels):
- # Dirty. Should find the root cause instead
- # Only happens in 3.2
- if isinstance(neol, bytes):
- neol = neol.decode()
- if isinstance(nixl, bytes):
- nixl = nixl.decode()
- self.assertEqual(neol, nixl)
- def compare_spiketrain_mtag(self, spiketrain, mtag):
- self.assertEqual(mtag.type, "neo.spiketrain")
- self.compare_attr(spiketrain, mtag)
- np.testing.assert_almost_equal(spiketrain.times.magnitude,
- mtag.positions)
- if len(mtag.features):
- neowf = spiketrain.waveforms
- nixwf = mtag.features[0].data
- self.assertEqual(np.shape(neowf), np.shape(nixwf))
- self.assertEqual(nixwf.unit, str(neowf.units.dimensionality))
- np.testing.assert_almost_equal(neowf.magnitude, nixwf)
- self.assertEqual(nixwf.dimensions[0].dimension_type,
- nix.DimensionType.Set)
- self.assertEqual(nixwf.dimensions[1].dimension_type,
- nix.DimensionType.Set)
- self.assertEqual(nixwf.dimensions[2].dimension_type,
- nix.DimensionType.Sample)
- def compare_attr(self, neoobj, nixobj):
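- """
- Compare name, description, datetimes and annotations between a Neo
- object and its NIX counterpart. Signals are split into one DataArray
- per channel, each named "<nix_name>.<index>", so the suffix is
- stripped before comparing names.
- """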
- if isinstance(neoobj, (AnalogSignal, IrregularlySampledSignal)):
- nix_name = ".".join(nixobj.name.split(".")[:-1])
- else:
- nix_name = nixobj.name
- self.assertEqual(neoobj.annotations["nix_name"], nix_name)
- self.assertEqual(neoobj.description, nixobj.definition)
- if hasattr(neoobj, "rec_datetime") and neoobj.rec_datetime:
- self.assertEqual(neoobj.rec_datetime,
- datetime.fromtimestamp(nixobj.created_at))
- if hasattr(neoobj, "file_datetime") and neoobj.file_datetime:
- self.assertEqual(neoobj.file_datetime,
- datetime.fromtimestamp(
- nixobj.metadata["file_datetime"]))
- if neoobj.annotations:
- nixmd = nixobj.metadata
- for k, v in neoobj.annotations.items():
- if k == "nix_name":
- continue
- if isinstance(v, pq.Quantity):
- self.assertEqual(nixmd.props[str(k)].unit,
- str(v.dimensionality))
- np.testing.assert_almost_equal(nixmd[str(k)],
- v.magnitude)
- else:
- self.assertEqual(nixmd[str(k)], v)
- @classmethod
- def create_full_nix_file(cls, filename):
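- """
- Create a NIX file populated directly through nixio (bypassing NixIO):
- two Blocks with three Groups each, and in the first Group of the first
- Block a set of analog and irregularly sampled signals, spiketrains
- with waveforms, epochs, events, and a channelindex Source with one
- unit.
- """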
- nixfile = nix.File.open(filename, nix.FileMode.Overwrite,
- backend="h5py")
- nix_block_a = nixfile.create_block(cls.rword(10), "neo.block")
- nix_block_a.definition = cls.rsentence(5, 10)
- nix_block_b = nixfile.create_block(cls.rword(10), "neo.block")
- nix_block_b.definition = cls.rsentence(3, 3)
- nix_block_a.metadata = nixfile.create_section(
- nix_block_a.name, nix_block_a.name+".metadata"
- )
- nix_block_b.metadata = nixfile.create_section(
- nix_block_b.name, nix_block_b.name+".metadata"
- )
- nix_blocks = [nix_block_a, nix_block_b]
- for blk in nix_blocks:
- for ind in range(3):
- group = blk.create_group(cls.rword(), "neo.segment")
- group.definition = cls.rsentence(10, 15)
- group_md = blk.metadata.create_section(group.name,
- group.name+".metadata")
- group.metadata = group_md
- blk = nix_blocks[0]
- group = blk.groups[0]
- allspiketrains = list()
- allsignalgroups = list()
- # analogsignals
- for n in range(3):
- siggroup = list()
- asig_name = "{}_asig{}".format(cls.rword(10), n)
- asig_definition = cls.rsentence(5, 5)
- asig_md = group.metadata.create_section(asig_name,
- asig_name+".metadata")
- for idx in range(3):
- da_asig = blk.create_data_array(
- "{}.{}".format(asig_name, idx),
- "neo.analogsignal",
- data=cls.rquant(100, 1)
- )
- da_asig.definition = asig_definition
- da_asig.unit = "mV"
- da_asig.metadata = asig_md
- timedim = da_asig.append_sampled_dimension(0.01)
- timedim.unit = "ms"
- timedim.label = "time"
- timedim.offset = 10
- da_asig.append_set_dimension()
- group.data_arrays.append(da_asig)
- siggroup.append(da_asig)
- allsignalgroups.append(siggroup)
- # irregularlysampledsignals
- for n in range(2):
- siggroup = list()
- isig_name = "{}_isig{}".format(cls.rword(10), n)
- isig_definition = cls.rsentence(12, 12)
- isig_md = group.metadata.create_section(isig_name,
- isig_name+".metadata")
- isig_times = cls.rquant(200, 1, True)
- for idx in range(10):
- da_isig = blk.create_data_array(
- "{}.{}".format(isig_name, idx),
- "neo.irregularlysampledsignal",
- data=cls.rquant(200, 1)
- )
- da_isig.definition = isig_definition
- da_isig.unit = "mV"
- da_isig.metadata = isig_md
- timedim = da_isig.append_range_dimension(isig_times)
- timedim.unit = "s"
- timedim.label = "time"
- da_isig.append_set_dimension()
- group.data_arrays.append(da_isig)
- siggroup.append(da_isig)
- allsignalgroups.append(siggroup)
- # SpikeTrains with Waveforms
- for n in range(4):
- stname = "{}-st{}".format(cls.rword(20), n)
- times = cls.rquant(400, 1, True)
- times_da = blk.create_data_array(
- "{}.times".format(stname),
- "neo.spiketrain.times",
- data=times
- )
- times_da.unit = "ms"
- mtag_st = blk.create_multi_tag(stname,
- "neo.spiketrain",
- times_da)
- group.multi_tags.append(mtag_st)
- mtag_st.definition = cls.rsentence(20, 30)
- mtag_st_md = group.metadata.create_section(
- mtag_st.name, mtag_st.name+".metadata"
- )
- mtag_st.metadata = mtag_st_md
- mtag_st_md.create_property(
- "t_stop", nix.Value(times[-1]+1.0)
- )
- waveforms = cls.rquant((10, 8, 5), 1)
- wfname = "{}.waveforms".format(mtag_st.name)
- wfda = blk.create_data_array(wfname, "neo.waveforms",
- data=waveforms)
- wfda.unit = "mV"
- mtag_st.create_feature(wfda, nix.LinkType.Indexed)
- wfda.append_set_dimension() # spike dimension
- wfda.append_set_dimension() # channel dimension
- wftimedim = wfda.append_sampled_dimension(0.1)
- wftimedim.unit = "ms"
- wftimedim.label = "time"
- wfda.metadata = mtag_st_md.create_section(
- wfname, "neo.waveforms.metadata"
- )
- wfda.metadata.create_property("left_sweep",
- [nix.Value(20)]*5)
- allspiketrains.append(mtag_st)
- # Epochs
- for n in range(3):
- epname = "{}-ep{}".format(cls.rword(5), n)
- times = cls.rquant(5, 1, True)
- times_da = blk.create_data_array(
- "{}.times".format(epname),
- "neo.epoch.times",
- data=times
- )
- times_da.unit = "s"
- extents = cls.rquant(5, 1)
- extents_da = blk.create_data_array(
- "{}.durations".format(epname),
- "neo.epoch.durations",
- data=extents
- )
- extents_da.unit = "s"
- mtag_ep = blk.create_multi_tag(
- epname, "neo.epoch", times_da
- )
- mtag_ep.metadata = group.metadata.create_section(
- epname, epname+".metadata"
- )
- group.multi_tags.append(mtag_ep)
- mtag_ep.definition = cls.rsentence(2)
- mtag_ep.extents = extents_da
- label_dim = mtag_ep.positions.append_set_dimension()
- label_dim.labels = cls.rsentence(5).split(" ")
- # reference all signals in the group
- for siggroup in allsignalgroups:
- mtag_ep.references.extend(siggroup)
- # Events
- for n in range(2):
- evname = "{}-ev{}".format(cls.rword(5), n)
- times = cls.rquant(5, 1, True)
- times_da = blk.create_data_array(
- "{}.times".format(evname),
- "neo.event.times",
- data=times
- )
- times_da.unit = "s"
- mtag_ev = blk.create_multi_tag(
- evname, "neo.event", times_da
- )
- mtag_ev.metadata = group.metadata.create_section(
- evname, evname+".metadata"
- )
- group.multi_tags.append(mtag_ev)
- mtag_ev.definition = cls.rsentence(2)
- label_dim = mtag_ev.positions.append_set_dimension()
- label_dim.labels = cls.rsentence(5).split(" ")
- # reference all signals in the group
- for siggroup in allsignalgroups:
- mtag_ev.references.extend(siggroup)
- # CHX
- nixchx = blk.create_source(cls.rword(10),
- "neo.channelindex")
- nixchx.metadata = nix_blocks[0].metadata.create_section(
- nixchx.name, "neo.channelindex.metadata"
- )
- chantype = "neo.channelindex"
- # 3 channels
- for idx, chan in enumerate([2, 5, 9]):
- channame = "{}.ChannelIndex{}".format(nixchx.name, idx)
- nixrc = nixchx.create_source(channame, chantype)
- nixrc.definition = cls.rsentence(13)
- nixrc.metadata = nixchx.metadata.create_section(
- nixrc.name, "neo.channelindex.metadata"
- )
- nixrc.metadata.create_property("index", nix.Value(chan))
- nixrc.metadata.create_property("channel_id", nix.Value(chan+1))
- dims = tuple(map(nix.Value, cls.rquant(3, 1)))
- nixrc.metadata.create_property("coordinates", dims)
- nixrc.metadata.create_property("coordinates.units",
- nix.Value("um"))
- nunits = 1
- stsperunit = np.array_split(allspiketrains, nunits)
- for idx in range(nunits):
- unitname = "{}-unit{}".format(cls.rword(5), idx)
- nixunit = nixchx.create_source(unitname, "neo.unit")
- nixunit.metadata = nixchx.metadata.create_section(
- unitname, unitname+".metadata"
- )
- nixunit.definition = cls.rsentence(4, 10)
- for st in stsperunit[idx]:
- st.sources.append(nixchx)
- st.sources.append(nixunit)
- # pick a few signal groups to reference this CHX
- randsiggroups = np.random.choice(allsignalgroups, 5, replace=False)
- for siggroup in randsiggroups:
- for sig in siggroup:
- sig.sources.append(nixchx)
- return nixfile
- @staticmethod
- def rdate():
- return datetime(year=np.random.randint(1980, 2020),
- month=np.random.randint(1, 13),
- day=np.random.randint(1, 29))
- @classmethod
- def populate_dates(cls, obj):
- obj.file_datetime = cls.rdate()
- obj.rec_datetime = cls.rdate()
- @staticmethod
- def rword(n=10):
- return "".join(np.random.choice(list(string.ascii_letters), n))
- @classmethod
- def rsentence(cls, n=3, maxwl=10):
- return " ".join(cls.rword(np.random.randint(1, maxwl))
- for _ in range(n))
- @classmethod
- def rdict(cls, nitems):
- rd = dict()
- for _ in range(nitems):
- key = cls.rword()
- value = cls.rword() if np.random.choice((0, 1)) \
- else np.random.uniform()
- rd[key] = value
- return rd
- @staticmethod
- def rquant(shape, unit, incr=False):
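- """
- Return a random quantity array of the given shape and unit, e.g.
- rquant((10, 3), pq.mV). With incr=True the values are cumulatively
- summed, yielding a monotonically increasing 1D array (useful for
- timestamps); incr is only valid for one-dimensional shapes.
- """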
- try:
- dim = len(shape)
- except TypeError:
- dim = 1
- if incr and dim > 1:
- raise TypeError("Shape of quantity array may only be "
- "one-dimensional when incremental values are "
- "requested.")
- arr = np.random.random(shape)
- if incr:
- arr = np.array(np.cumsum(arr))
- return arr*unit
- @classmethod
- def create_all_annotated(cls):
- times = cls.rquant(1, pq.s)
- signal = cls.rquant(1, pq.V)
- blk = Block()
- blk.annotate(**cls.rdict(3))
- seg = Segment()
- seg.annotate(**cls.rdict(4))
- blk.segments.append(seg)
- asig = AnalogSignal(signal=signal, sampling_rate=pq.Hz)
- asig.annotate(**cls.rdict(2))
- seg.analogsignals.append(asig)
- isig = IrregularlySampledSignal(times=times, signal=signal,
- time_units=pq.s)
- isig.annotate(**cls.rdict(2))
- seg.irregularlysampledsignals.append(isig)
- epoch = Epoch(times=times, durations=times)
- epoch.annotate(**cls.rdict(4))
- seg.epochs.append(epoch)
- event = Event(times=times)
- event.annotate(**cls.rdict(4))
- seg.events.append(event)
- spiketrain = SpikeTrain(times=times, t_stop=pq.s, units=pq.s)
- d = cls.rdict(6)
- d["quantity"] = pq.Quantity(10, "mV")
- d["qarray"] = pq.Quantity(range(10), "mA")
- spiketrain.annotate(**d)
- seg.spiketrains.append(spiketrain)
- chx = ChannelIndex(name="achx", index=[1, 2], channel_ids=[0, 10])
- chx.annotate(**cls.rdict(5))
- blk.channel_indexes.append(chx)
- unit = Unit()
- unit.annotate(**cls.rdict(2))
- chx.units.append(unit)
- return blk
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class NixIOWriteTest(NixIOTest):
- def setUp(self):
- self.filename = "nixio_testfile_write.h5"
- self.writer = NixIO(self.filename, "ow")
- self.io = self.writer
- self.reader = nix.File.open(self.filename,
- nix.FileMode.ReadOnly,
- backend="h5py")
- def tearDown(self):
- self.writer.close()
- self.reader.close()
- os.remove(self.filename)
- def write_and_compare(self, blocks):
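- """
- Write the blocks with NixIO, then compare both the re-read Neo blocks
- and the original Neo blocks against the raw nixio view of the same
- file.
- """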
- self.writer.write_all_blocks(blocks)
- self.compare_blocks(self.writer.read_all_blocks(), self.reader.blocks)
- self.compare_blocks(blocks, self.reader.blocks)
- def test_block_write(self):
- block = Block(name=self.rword(),
- description=self.rsentence())
- self.write_and_compare([block])
- block.annotate(**self.rdict(5))
- self.write_and_compare([block])
- def test_segment_write(self):
- block = Block(name=self.rword())
- segment = Segment(name=self.rword(), description=self.rword())
- block.segments.append(segment)
- self.write_and_compare([block])
- segment.annotate(**self.rdict(2))
- self.write_and_compare([block])
- def test_channel_index_write(self):
- block = Block(name=self.rword())
- chx = ChannelIndex(name=self.rword(),
- description=self.rsentence(),
- channel_ids=[10, 20, 30, 50, 80, 130],
- index=[1, 2, 3, 5, 8, 13])
- block.channel_indexes.append(chx)
- self.write_and_compare([block])
- chx.annotate(**self.rdict(3))
- self.write_and_compare([block])
- chx.channel_names = ["one", "two", "three", "five",
- "eight", "xiii"]
- self.write_and_compare([block])
- def test_signals_write(self):
- block = Block()
- seg = Segment()
- block.segments.append(seg)
- asig = AnalogSignal(signal=self.rquant((10, 3), pq.mV),
- sampling_rate=pq.Quantity(10, "Hz"))
- seg.analogsignals.append(asig)
- self.write_and_compare([block])
- anotherblock = Block("ir signal block")
- seg = Segment("ir signal seg")
- anotherblock.segments.append(seg)
- irsig = IrregularlySampledSignal(
- signal=np.random.random((20, 3)),
- times=self.rquant(20, pq.ms, True),
- units=pq.A
- )
- seg.irregularlysampledsignals.append(irsig)
- self.write_and_compare([block, anotherblock])
- block.segments[0].analogsignals.append(
- AnalogSignal(signal=[10.0, 1.0, 3.0], units=pq.S,
- sampling_period=pq.Quantity(3, "s"),
- dtype=np.double, name="signal42",
- description="this is an analogsignal",
- t_start=45 * pq.ms),
- )
- self.write_and_compare([block, anotherblock])
- block.segments[0].irregularlysampledsignals.append(
- IrregularlySampledSignal(times=np.random.random(10),
- signal=np.random.random((10, 3)),
- units="mV", time_units="s",
- dtype=np.float64,
- name="some sort of signal",
- description="the signal is described")
- )
- self.write_and_compare([block, anotherblock])
- def test_epoch_write(self):
- block = Block()
- seg = Segment()
- block.segments.append(seg)
- epoch = Epoch(times=[1, 1, 10, 3]*pq.ms, durations=[3, 3, 3, 1]*pq.ms,
- labels=np.array(["one", "two", "three", "four"]),
- name="test epoch", description="an epoch for testing")
- seg.epochs.append(epoch)
- self.write_and_compare([block])
- def test_event_write(self):
- block = Block()
- seg = Segment()
- block.segments.append(seg)
- event = Event(times=np.arange(0, 30, 10)*pq.s,
- labels=np.array(["0", "1", "2"]),
- name="event name",
- description="event description")
- seg.events.append(event)
- self.write_and_compare([block])
- def test_spiketrain_write(self):
- block = Block()
- seg = Segment()
- block.segments.append(seg)
- spiketrain = SpikeTrain(times=[3, 4, 5]*pq.s, t_stop=10.0,
- name="spikes!", description="sssssspikes")
- seg.spiketrains.append(spiketrain)
- self.write_and_compare([block])
- waveforms = self.rquant((3, 5, 10), pq.mV)
- spiketrain = SpikeTrain(times=[1, 1.1, 1.2]*pq.ms, t_stop=1.5*pq.s,
- name="spikes with wf",
- description="spikes for waveform test",
- waveforms=waveforms)
- seg.spiketrains.append(spiketrain)
- self.write_and_compare([block])
- spiketrain.left_sweep = np.random.random(10)*pq.ms
- self.write_and_compare([block])
- def test_metadata_structure_write(self):
- neoblk = self.create_all_annotated()
- self.io.write_block(neoblk)
- blk = self.io.nix_file.blocks[0]
- blkmd = blk.metadata
- self.assertEqual(blk.name, blkmd.name)
- grp = blk.groups[0] # segment
- self.assertIn(grp.name, blkmd.sections)
- grpmd = blkmd.sections[grp.name]
- for da in grp.data_arrays: # signals
- name = ".".join(da.name.split(".")[:-1])
- self.assertIn(name, grpmd.sections)
- for mtag in grp.multi_tags: # spiketrains, events, and epochs
- self.assertIn(mtag.name, grpmd.sections)
- srcchx = blk.sources[0] # chx
- self.assertIn(srcchx.name, blkmd.sections)
- for srcunit in blk.sources: # units
- self.assertIn(srcunit.name, blkmd.sections)
- self.write_and_compare([neoblk])
- def test_anonymous_objects_write(self):
- nblocks = 2
- nsegs = 2
- nanasig = 4
- nirrseg = 2
- nepochs = 3
- nevents = 4
- nspiketrains = 3
- nchx = 5
- nunits = 10
- times = self.rquant(1, pq.s)
- signal = self.rquant(1, pq.V)
- blocks = []
- for blkidx in range(nblocks):
- blk = Block()
- blocks.append(blk)
- for segidx in range(nsegs):
- seg = Segment()
- blk.segments.append(seg)
- for anaidx in range(nanasig):
- seg.analogsignals.append(AnalogSignal(signal=signal,
- sampling_rate=pq.Hz))
- for irridx in range(nirrseg):
- seg.irregularlysampledsignals.append(
- IrregularlySampledSignal(times=times,
- signal=signal,
- time_units=pq.s)
- )
- for epidx in range(nepochs):
- seg.epochs.append(Epoch(times=times, durations=times))
- for evidx in range(nevents):
- seg.events.append(Event(times=times))
- for stidx in range(nspiketrains):
- seg.spiketrains.append(SpikeTrain(times=times,
- t_stop=times[-1]+pq.s,
- units=pq.s))
- for chidx in range(nchx):
- chx = ChannelIndex(name="chx{}".format(chidx),
- index=[1, 2],
- channel_ids=[11, 22])
- blk.channel_indexes.append(chx)
- for unidx in range(nunits):
- unit = Unit()
- chx.units.append(unit)
- self.writer.write_all_blocks(blocks)
- self.compare_blocks(blocks, self.reader.blocks)
- def test_multiref_write(self):
- blk = Block("blk1")
- signal = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
- sampling_period=pq.Quantity(1, "ms"))
- for idx in range(3):
- segname = "seg" + str(idx)
- seg = Segment(segname)
- blk.segments.append(seg)
- seg.analogsignals.append(signal)
- chidx = ChannelIndex([10, 20, 29])
- seg = blk.segments[0]
- st = SpikeTrain(name="choochoo", times=[10, 11, 80], t_stop=1000,
- units="s")
- seg.spiketrains.append(st)
- blk.channel_indexes.append(chidx)
- for idx in range(6):
- unit = Unit("unit" + str(idx))
- chidx.units.append(unit)
- unit.spiketrains.append(st)
- self.writer.write_block(blk)
- self.compare_blocks([blk], self.reader.blocks)
- def test_to_value(self):
- section = self.io.nix_file.create_section("Metadata value test",
- "Test")
- writeprop = self.io._write_property
- # quantity
- qvalue = pq.Quantity(10, "mV")
- writeprop(section, "qvalue", qvalue)
- self.assertEqual(section["qvalue"], 10)
- self.assertEqual(section.props["qvalue"].unit, "mV")
- # datetime
- dt = self.rdate()
- writeprop(section, "dt", dt)
- self.assertEqual(datetime.fromtimestamp(section["dt"]), dt)
- # string
- randstr = self.rsentence()
- writeprop(section, "randstr", randstr)
- self.assertEqual(section["randstr"], randstr)
- # bytes
- bytestring = b"bytestring"
- writeprop(section, "randbytes", bytestring)
- self.assertEqual(section["randbytes"], bytestring.decode())
- # iterables
- randlist = np.random.random(10).tolist()
- writeprop(section, "randlist", randlist)
- self.assertEqual(randlist, section["randlist"])
- randarray = np.random.random(10)
- writeprop(section, "randarray", randarray)
- np.testing.assert_almost_equal(randarray, section["randarray"])
- # numpy item
- npval = np.float64(2398)
- writeprop(section, "npval", npval)
- self.assertEqual(npval, section["npval"])
- # number
- val = 42
- writeprop(section, "val", val)
- self.assertEqual(val, section["val"])
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class NixIOReadTest(NixIOTest):
- filename = "testfile_readtest.h5"
- nixfile = None
- nix_blocks = None
- original_methods = dict()
- @classmethod
- def setUpClass(cls):
- if HAVE_NIX:
- cls.nixfile = cls.create_full_nix_file(cls.filename)
- def setUp(self):
- self.io = NixIO(self.filename, "ro")
- self.original_methods["_read_cascade"] = self.io._read_cascade
- self.original_methods["_update_maps"] = self.io._update_maps
- @classmethod
- def tearDownClass(cls):
- if HAVE_NIX:
- cls.nixfile.close()
- os.remove(cls.filename)
- def tearDown(self):
- self.io.close()
- def test_all_read(self):
- neo_blocks = self.io.read_all_blocks(cascade=True, lazy=False)
- nix_blocks = self.io.nix_file.blocks
- self.compare_blocks(neo_blocks, nix_blocks)
- def test_lazyload_fullcascade_read(self):
- neo_blocks = self.io.read_all_blocks(cascade=True, lazy=True)
- nix_blocks = self.io.nix_file.blocks
- # data objects should be empty
- for block in neo_blocks:
- for seg in block.segments:
- for asig in seg.analogsignals:
- self.assertEqual(len(asig), 0)
- for isig in seg.irregularlysampledsignals:
- self.assertEqual(len(isig), 0)
- for epoch in seg.epochs:
- self.assertEqual(len(epoch), 0)
- for event in seg.events:
- self.assertEqual(len(event), 0)
- for st in seg.spiketrains:
- self.assertEqual(len(st), 0)
- self.compare_blocks(neo_blocks, nix_blocks)
- def test_lazyload_lazycascade_read(self):
- neo_blocks = self.io.read_all_blocks(cascade="lazy", lazy=True)
- nix_blocks = self.io.nix_file.blocks
- self.compare_blocks(neo_blocks, nix_blocks)
- def test_lazycascade_read(self):
- def getitem(self, index):
- return self._data.__getitem__(index)
- from neo.io.nixio import LazyList
- getitem_original = LazyList.__getitem__
- LazyList.__getitem__ = getitem
- neo_blocks = self.io.read_all_blocks(cascade="lazy", lazy=False)
- for block in neo_blocks:
- self.assertIsInstance(block.segments, LazyList)
- self.assertIsInstance(block.channel_indexes, LazyList)
- for seg in block.segments:
- self.assertIsInstance(seg, string_types)
- for chx in block.channel_indexes:
- self.assertIsInstance(chx, string_types)
- LazyList.__getitem__ = getitem_original
- def test_load_lazy_cascade(self):
- from neo.io.nixio import LazyList
- neo_blocks = self.io.read_all_blocks(cascade="lazy", lazy=False)
- for block in neo_blocks:
- self.assertIsInstance(block.segments, LazyList)
- self.assertIsInstance(block.channel_indexes, LazyList)
- name = block.annotations["nix_name"]
- block = self.io.load_lazy_cascade("/" + name, lazy=False)
- self.assertIsInstance(block.segments, list)
- self.assertIsInstance(block.channel_indexes, list)
- for seg in block.segments:
- self.assertIsInstance(seg.analogsignals, list)
- self.assertIsInstance(seg.irregularlysampledsignals, list)
- self.assertIsInstance(seg.epochs, list)
- self.assertIsInstance(seg.events, list)
- self.assertIsInstance(seg.spiketrains, list)
- def test_nocascade_read(self):
- self.io._read_cascade = mock.Mock()
- neo_blocks = self.io.read_all_blocks(cascade=False)
- self.io._read_cascade.assert_not_called()
- for block in neo_blocks:
- self.assertEqual(len(block.segments), 0)
- nix_block = self.io.nix_file.blocks[block.annotations["nix_name"]]
- self.compare_attr(block, nix_block)
- def test_lazy_load_subschema(self):
- blk = self.io.nix_file.blocks[0]
- segpath = "/" + blk.name + "/segments/" + blk.groups[0].name
- segment = self.io.load_lazy_cascade(segpath, lazy=True)
- self.assertIsInstance(segment, Segment)
- self.assertEqual(segment.annotations["nix_name"], blk.groups[0].name)
- self.assertIs(segment.block, None)
- self.assertEqual(len(segment.analogsignals[0]), 0)
- segment = self.io.load_lazy_cascade(segpath, lazy=False)
- self.assertEqual(np.shape(segment.analogsignals[0]), (100, 3))
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class NixIOHashTest(NixIOTest):
- def setUp(self):
- self.hash = NixIO._hash_object
- def _hash_test(self, objtype, argfuncs):
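- """
- Build two objects of objtype from identical arguments and check that
- their hashes match, then vary each argument in turn and check that the
- hash changes.
- """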
- attr = {}
- for arg, func in argfuncs.items():
- attr[arg] = func()
- obj_one = objtype(**attr)
- obj_two = objtype(**attr)
- hash_one = self.hash(obj_one)
- hash_two = self.hash(obj_two)
- self.assertEqual(hash_one, hash_two)
- for arg, func in argfuncs.items():
- chattr = attr.copy()
- chattr[arg] = func()
- obj_two = objtype(**chattr)
- hash_two = self.hash(obj_two)
- self.assertNotEqual(
- hash_one, hash_two,
- "Hash test failed with different '{}'".format(arg)
- )
- def test_block_seg_hash(self):
- argfuncs = {"name": self.rword,
- "description": self.rsentence,
- "rec_datetime": self.rdate,
- "file_datetime": self.rdate,
- # annotations
- self.rword(): self.rword,
- self.rword(): lambda: self.rquant((10, 10), pq.mV)}
- self._hash_test(Block, argfuncs)
- self._hash_test(Segment, argfuncs)
- self._hash_test(Unit, argfuncs)
- def test_chx_hash(self):
- argfuncs = {"name": self.rword,
- "description": self.rsentence,
- "index": lambda: np.random.random(10).tolist(),
- "channel_names": lambda: self.rsentence(10).split(" "),
- "coordinates": lambda: [(np.random.random() * pq.cm,
- np.random.random() * pq.cm,
- np.random.random() * pq.cm)]*10,
- # annotations
- self.rword(): self.rword,
- self.rword(): lambda: self.rquant((10, 10), pq.mV)}
- self._hash_test(ChannelIndex, argfuncs)
- def test_analogsignal_hash(self):
- argfuncs = {"name": self.rword,
- "description": self.rsentence,
- "signal": lambda: self.rquant((10, 10), pq.mV),
- "sampling_rate": lambda: np.random.random() * pq.Hz,
- "t_start": lambda: np.random.random() * pq.sec,
- "t_stop": lambda: np.random.random() * pq.sec,
- # annotations
- self.rword(): self.rword,
- self.rword(): lambda: self.rquant((10, 10), pq.mV)}
- self._hash_test(AnalogSignal, argfuncs)
- def test_irregularsignal_hash(self):
- argfuncs = {"name": self.rword,
- "description": self.rsentence,
- "signal": lambda: self.rquant((10, 10), pq.mV),
- "times": lambda: self.rquant(10, pq.ms, True),
- # annotations
- self.rword(): self.rword,
- self.rword(): lambda: self.rquant((10, 10), pq.mV)}
- self._hash_test(IrregularlySampledSignal, argfuncs)
- def test_event_hash(self):
- argfuncs = {"name": self.rword,
- "description": self.rsentence,
- "times": lambda: self.rquant(10, pq.ms),
- "durations": lambda: self.rquant(10, pq.ms),
- "labels": lambda: self.rsentence(10).split(" "),
- # annotations
- self.rword(): self.rword,
- self.rword(): lambda: self.rquant((10, 10), pq.mV)}
- self._hash_test(Event, argfuncs)
- self._hash_test(Epoch, argfuncs)
- def test_spiketrain_hash(self):
- argfuncs = {"name": self.rword,
- "description": self.rsentence,
- "times": lambda: self.rquant(10, pq.ms, True),
- "t_start": lambda: -np.random.random() * pq.sec,
- "t_stop": lambda: np.random.random() * 100 * pq.sec,
- "waveforms": lambda: self.rquant((10, 10, 20), pq.mV),
- # annotations
- self.rword(): self.rword,
- self.rword(): lambda: self.rquant((10, 10), pq.mV)}
- self._hash_test(SpikeTrain, argfuncs)
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class NixIOPartialWriteTest(NixIOTest):
- filename = "testfile_partialwrite.h5"
- nixfile = None
- neo_blocks = None
- original_methods = dict()
- @classmethod
- def setUpClass(cls):
- if HAVE_NIX:
- cls.nixfile = cls.create_full_nix_file(cls.filename)
- def setUp(self):
- self.io = NixIO(self.filename, "rw")
- self.neo_blocks = self.io.read_all_blocks()
- self.original_methods["_write_attr_annotations"] =\
- self.io._write_attr_annotations
- @classmethod
- def tearDownClass(cls):
- if HAVE_NIX:
- cls.nixfile.close()
- os.remove(cls.filename)
- def tearDown(self):
- self.restore_methods()
- self.io.close()
- def restore_methods(self):
- for name, method in self.original_methods.items():
- setattr(self.io, name, method)
- def _mock_write_attr(self, objclass):
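- """
- Wrap _write_attr_annotations with a side effect that fails if it is
- called for a NIX object of the given (unmodified) type, modify every
- object except instances of objclass, and rewrite: only the modified
- objects should be written again.
- """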
- typestr = str(objclass.__name__).lower()
- self.io._write_attr_annotations = mock.Mock(
- wraps=self.io._write_attr_annotations,
- side_effect=self.check_obj_type("neo.{}".format(typestr))
- )
- neo_blocks = self.neo_blocks
- self.modify_objects(neo_blocks, excludes=[objclass])
- self.io.write_all_blocks(neo_blocks)
- self.restore_methods()
- def check_obj_type(self, typestring):
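- """
- Return a side-effect function that asserts the NIX object (or list of
- objects) being written does not have the given type string.
- """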
- neq = self.assertNotEqual
- def side_effect_func(*args, **kwargs):
- obj = kwargs.get("nixobj", args[0])
- if isinstance(obj, list):
- for sig in obj:
- neq(sig.type, typestring)
- else:
- neq(obj.type, typestring)
- return side_effect_func
- @classmethod
- def modify_objects(cls, objs, excludes=()):
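- """
- Recursively assign a new description to every object that is not an
- instance of the excluded classes, descending into each object's child
- containers.
- """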
- excludes = tuple(excludes)
- for obj in objs:
- if not (excludes and isinstance(obj, excludes)):
- obj.description = cls.rsentence()
- for container in getattr(obj, "_child_containers", []):
- children = getattr(obj, container)
- cls.modify_objects(children, excludes)
- def test_partial(self):
- for objclass in NixIO.supported_objects:
- self._mock_write_attr(objclass)
- self.compare_blocks(self.neo_blocks, self.io.nix_file.blocks)
- def test_no_modifications(self):
- self.io._write_attr_annotations = mock.Mock()
- self.io.write_all_blocks(self.neo_blocks)
- self.io._write_attr_annotations.assert_not_called()
- self.compare_blocks(self.neo_blocks, self.io.nix_file.blocks)
- # clearing hashes and checking again
- for k in self.io._object_hashes.keys():
- self.io._object_hashes[k] = None
- self.io.write_all_blocks(self.neo_blocks)
- self.io._write_attr_annotations.assert_not_called()
- # changing hashes to force rewrite
- for k in self.io._object_hashes.keys():
- self.io._object_hashes[k] = "_"
- self.io.write_all_blocks(self.neo_blocks)
- callcount = self.io._write_attr_annotations.call_count
- self.assertEqual(callcount, len(self.io._object_hashes))
- self.compare_blocks(self.neo_blocks, self.io.nix_file.blocks)
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class NixIOContextTests(NixIOTest):
- filename = "context_test.h5"
- def test_context_write(self):
- neoblock = Block(name=self.rword(), description=self.rsentence())
- with NixIO(self.filename, "ow") as iofile:
- iofile.write_block(neoblock)
- nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly,
- backend="h5py")
- self.compare_blocks([neoblock], nixfile.blocks)
- nixfile.close()
- neoblock.annotate(**self.rdict(5))
- with NixIO(self.filename, "rw") as iofile:
- iofile.write_block(neoblock)
- nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly,
- backend="h5py")
- self.compare_blocks([neoblock], nixfile.blocks)
- nixfile.close()
- os.remove(self.filename)
- def test_context_read(self):
- nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite,
- backend="h5py")
- name_one = self.rword()
- name_two = self.rword()
- nixfile.create_block(name_one, "neo.block")
- nixfile.create_block(name_two, "neo.block")
- nixfile.close()
- with NixIO(self.filename, "ro") as iofile:
- blocks = iofile.read_all_blocks()
- self.assertEqual(blocks[0].annotations["nix_name"], name_one)
- self.assertEqual(blocks[1].annotations["nix_name"], name_two)
- os.remove(self.filename)
- @unittest.skipUnless(HAVE_NIX, "Requires NIX")
- class CommonTests(BaseTestIO, unittest.TestCase):
- ioclass = NixIO
- read_and_write_is_bijective = False
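- # Standard unittest entry point; assumed here, as it is not part of the
- # excerpt above.
- if __name__ == "__main__":
- unittest.main()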