Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

test_nixio.py 57 KB


  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2016, German Neuroinformatics Node (G-Node)
  3. # Achilleas Koutsou <achilleas.k@gmail.com>
  4. #
  5. # All rights reserved.
  6. #
  7. # Redistribution and use in source and binary forms, with or without
  8. # modification, are permitted under the terms of the BSD License. See
  9. # LICENSE file in the root of the Project.
  10. """
  11. Tests for NixIO
  12. """
import os
import shutil
import string
import unittest
try:
    # The ABC aliases in ``collections`` were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
from datetime import datetime
from tempfile import mkdtemp

import numpy as np
import quantities as pq

from neo.core import (Block, Segment, ChannelIndex, AnalogSignal,
                      IrregularlySampledSignal, Unit, SpikeTrain, Event, Epoch)
from neo.test.iotest.common_io_test import BaseTestIO
from neo.io.nixio import NixIO, create_quantity, units_to_string, neover
  26. try:
  27. import nixio as nix
  28. HAVE_NIX = True
  29. except ImportError:
  30. HAVE_NIX = False
  31. try:
  32. from unittest import mock
  33. SKIPMOCK = False
  34. except ImportError:
  35. SKIPMOCK = True
@unittest.skipUnless(HAVE_NIX, "Requires NIX")
class NixIOTest(unittest.TestCase):
    """
    Shared base class for the NixIO test cases.

    Holds the helpers that compare Neo objects with their NIX counterparts
    and the generators for random test data.
    """

    # Placeholders; NixIOWriteTest.setUp assigns the real values.
    io = None
    tempdir = None
    filename = None
  41. def compare_blocks(self, neoblocks, nixblocks):
  42. for neoblock, nixblock in zip(neoblocks, nixblocks):
  43. self.compare_attr(neoblock, nixblock)
  44. self.assertEqual(len(neoblock.segments), len(nixblock.groups))
  45. for idx, neoseg in enumerate(neoblock.segments):
  46. nixgrp = nixblock.groups[neoseg.annotations["nix_name"]]
  47. self.compare_segment_group(neoseg, nixgrp)
  48. self.assertEqual(len(neoblock.channel_indexes),
  49. len(nixblock.sources))
  50. for idx, neochx in enumerate(neoblock.channel_indexes):
  51. nixsrc = nixblock.sources[neochx.annotations["nix_name"]]
  52. self.compare_chx_source(neochx, nixsrc)
  53. self.check_refs(neoblock, nixblock)
  54. def compare_chx_source(self, neochx, nixsrc):
  55. self.compare_attr(neochx, nixsrc)
  56. nix_channels = list(src for src in nixsrc.sources
  57. if src.type == "neo.channelindex")
  58. self.assertEqual(len(neochx.index), len(nix_channels))
  59. if len(neochx.channel_ids):
  60. nix_chanids = list(src.metadata["channel_id"] for src
  61. in nixsrc.sources
  62. if src.type == "neo.channelindex")
  63. self.assertEqual(len(neochx.channel_ids), len(nix_chanids))
  64. for nixchan in nix_channels:
  65. nixchanidx = nixchan.metadata["index"]
  66. try:
  67. neochanpos = list(neochx.index).index(nixchanidx)
  68. except ValueError:
  69. self.fail("Channel indexes do not match.")
  70. if len(neochx.channel_names):
  71. neochanname = neochx.channel_names[neochanpos]
  72. if ((not isinstance(neochanname, str)) and
  73. isinstance(neochanname, bytes)):
  74. neochanname = neochanname.decode()
  75. nixchanname = nixchan.metadata["neo_name"]
  76. self.assertEqual(neochanname, nixchanname)
  77. if len(neochx.channel_ids):
  78. neochanid = neochx.channel_ids[neochanpos]
  79. nixchanid = nixchan.metadata["channel_id"]
  80. self.assertEqual(neochanid, nixchanid)
  81. elif "channel_id" in nixchan.metadata:
  82. self.fail("Channel ID not loaded")
  83. nix_units = list(src for src in nixsrc.sources
  84. if src.type == "neo.unit")
  85. self.assertEqual(len(neochx.units), len(nix_units))
  86. for neounit in neochx.units:
  87. nixunit = nixsrc.sources[neounit.annotations["nix_name"]]
  88. self.compare_attr(neounit, nixunit)
  89. def check_refs(self, neoblock, nixblock):
  90. """
  91. Checks whether the references between objects that are not nested are
  92. mapped correctly (e.g., SpikeTrains referenced by a Unit).
  93. :param neoblock: A Neo block
  94. :param nixblock: The corresponding NIX block
  95. """
  96. for idx, neochx in enumerate(neoblock.channel_indexes):
  97. nixchx = nixblock.sources[neochx.annotations["nix_name"]]
  98. # AnalogSignals referencing CHX
  99. neoasigs = list(sig.annotations["nix_name"]
  100. for sig in neochx.analogsignals)
  101. nixasigs = list(set(da.metadata.name for da in nixblock.data_arrays
  102. if da.type == "neo.analogsignal" and
  103. nixchx in da.sources))
  104. self.assertEqual(len(neoasigs), len(nixasigs))
  105. # IrregularlySampledSignals referencing CHX
  106. neoisigs = list(sig.annotations["nix_name"] for sig in
  107. neochx.irregularlysampledsignals)
  108. nixisigs = list(
  109. set(da.metadata.name for da in nixblock.data_arrays
  110. if da.type == "neo.irregularlysampledsignal" and
  111. nixchx in da.sources)
  112. )
  113. self.assertEqual(len(neoisigs), len(nixisigs))
  114. # SpikeTrains referencing CHX and Units
  115. for sidx, neounit in enumerate(neochx.units):
  116. nixunit = nixchx.sources[neounit.annotations["nix_name"]]
  117. neosts = list(st.annotations["nix_name"]
  118. for st in neounit.spiketrains)
  119. nixsts = list(mt for mt in nixblock.multi_tags
  120. if mt.type == "neo.spiketrain" and
  121. nixunit.name in mt.sources)
  122. # SpikeTrains must also reference CHX
  123. for nixst in nixsts:
  124. self.assertIn(nixchx.name, nixst.sources)
  125. nixsts = list(st.name for st in nixsts)
  126. self.assertEqual(len(neosts), len(nixsts))
  127. for neoname in neosts:
  128. if neoname:
  129. self.assertIn(neoname, nixsts)
  130. # Events and Epochs must reference all Signals in the Group (NIX only)
  131. for nixgroup in nixblock.groups:
  132. nixevep = list(mt for mt in nixgroup.multi_tags
  133. if mt.type in ["neo.event", "neo.epoch"])
  134. nixsigs = list(da.name for da in nixgroup.data_arrays
  135. if da.type in ["neo.analogsignal",
  136. "neo.irregularlysampledsignal"])
  137. for nee in nixevep:
  138. for ns in nixsigs:
  139. self.assertIn(ns, nee.references)
  140. def compare_segment_group(self, neoseg, nixgroup):
  141. self.compare_attr(neoseg, nixgroup)
  142. neo_signals = neoseg.analogsignals + neoseg.irregularlysampledsignals
  143. self.compare_signals_das(neo_signals, nixgroup.data_arrays)
  144. neo_eests = neoseg.epochs + neoseg.events + neoseg.spiketrains
  145. self.compare_eests_mtags(neo_eests, nixgroup.multi_tags)
  146. def compare_signals_das(self, neosignals, data_arrays):
  147. totalsignals = 0
  148. for sig in neosignals:
  149. dalist = list()
  150. nixname = sig.annotations["nix_name"]
  151. for da in data_arrays:
  152. if da.metadata.name == nixname:
  153. dalist.append(da)
  154. _, nsig = np.shape(sig)
  155. totalsignals += nsig
  156. self.assertEqual(nsig, len(dalist))
  157. self.compare_signal_dalist(sig, dalist)
  158. self.assertEqual(totalsignals, len(data_arrays))
  159. def compare_signal_dalist(self, neosig, nixdalist):
  160. """
  161. Check if a Neo Analog or IrregularlySampledSignal matches a list of
  162. NIX DataArrays.
  163. :param neosig: Neo Analog or IrregularlySampledSignal
  164. :param nixdalist: List of DataArrays
  165. """
  166. nixmd = nixdalist[0].metadata
  167. self.assertTrue(all(nixmd == da.metadata for da in nixdalist))
  168. neounit = neosig.units
  169. for sig, da in zip(np.transpose(neosig), nixdalist):
  170. self.compare_attr(neosig, da)
  171. daquant = create_quantity(da[:], da.unit)
  172. np.testing.assert_almost_equal(sig, daquant)
  173. nixunit = create_quantity(1, da.unit)
  174. self.assertEqual(neounit, nixunit)
  175. timedim = da.dimensions[0]
  176. if isinstance(neosig, AnalogSignal):
  177. self.assertEqual(timedim.dimension_type,
  178. nix.DimensionType.Sample)
  179. neosp = neosig.sampling_period
  180. nixsp = create_quantity(timedim.sampling_interval,
  181. timedim.unit)
  182. self.assertEqual(neosp, nixsp)
  183. tsunit = timedim.unit
  184. if "t_start.units" in da.metadata.props:
  185. tsunit = da.metadata["t_start.units"]
  186. neots = neosig.t_start
  187. nixts = create_quantity(timedim.offset, tsunit)
  188. self.assertEqual(neots, nixts)
  189. elif isinstance(neosig, IrregularlySampledSignal):
  190. self.assertEqual(timedim.dimension_type,
  191. nix.DimensionType.Range)
  192. np.testing.assert_almost_equal(neosig.times.magnitude,
  193. timedim.ticks)
  194. self.assertEqual(timedim.unit,
  195. units_to_string(neosig.times.units))
  196. def compare_eests_mtags(self, eestlist, mtaglist):
  197. self.assertEqual(len(eestlist), len(mtaglist))
  198. for eest in eestlist:
  199. mtag = mtaglist[eest.annotations["nix_name"]]
  200. if isinstance(eest, Epoch):
  201. self.compare_epoch_mtag(eest, mtag)
  202. elif isinstance(eest, Event):
  203. self.compare_event_mtag(eest, mtag)
  204. elif isinstance(eest, SpikeTrain):
  205. self.compare_spiketrain_mtag(eest, mtag)
  206. def compare_epoch_mtag(self, epoch, mtag):
  207. self.assertEqual(mtag.type, "neo.epoch")
  208. self.compare_attr(epoch, mtag)
  209. pos = mtag.positions
  210. posquant = create_quantity(pos[:], pos.unit)
  211. ext = mtag.extents
  212. extquant = create_quantity(ext[:], ext.unit)
  213. np.testing.assert_almost_equal(epoch.as_quantity(), posquant)
  214. np.testing.assert_almost_equal(epoch.durations, extquant)
  215. for neol, nixl in zip(epoch.labels,
  216. mtag.positions.dimensions[0].labels):
  217. # Dirty. Should find the root cause instead
  218. if isinstance(neol, bytes):
  219. neol = neol.decode()
  220. if isinstance(nixl, bytes):
  221. nixl = nixl.decode()
  222. self.assertEqual(neol, nixl)
  223. def compare_event_mtag(self, event, mtag):
  224. self.assertEqual(mtag.type, "neo.event")
  225. self.compare_attr(event, mtag)
  226. pos = mtag.positions
  227. posquant = create_quantity(pos[:], pos.unit)
  228. np.testing.assert_almost_equal(event.as_quantity(), posquant)
  229. for neol, nixl in zip(event.labels,
  230. mtag.positions.dimensions[0].labels):
  231. # Dirty. Should find the root cause instead
  232. # Only happens in 3.2
  233. if isinstance(neol, bytes):
  234. neol = neol.decode()
  235. if isinstance(nixl, bytes):
  236. nixl = nixl.decode()
  237. self.assertEqual(neol, nixl)
  238. def compare_spiketrain_mtag(self, spiketrain, mtag):
  239. self.assertEqual(mtag.type, "neo.spiketrain")
  240. self.compare_attr(spiketrain, mtag)
  241. pos = mtag.positions
  242. posquant = create_quantity(pos[:], pos.unit)
  243. np.testing.assert_almost_equal(spiketrain.as_quantity(), posquant)
  244. if len(mtag.features):
  245. neowfs = spiketrain.waveforms
  246. nixwfs = mtag.features[0].data
  247. self.assertEqual(np.shape(neowfs), np.shape(nixwfs))
  248. for nixwf, neowf in zip(nixwfs, neowfs):
  249. for nixrow, neorow in zip(nixwf, neowf):
  250. for nixv, neov in zip(nixrow, neorow):
  251. self.assertEqual(create_quantity(nixv, nixwfs.unit),
  252. neov)
  253. self.assertEqual(nixwfs.dimensions[0].dimension_type,
  254. nix.DimensionType.Set)
  255. self.assertEqual(nixwfs.dimensions[1].dimension_type,
  256. nix.DimensionType.Set)
  257. self.assertEqual(nixwfs.dimensions[2].dimension_type,
  258. nix.DimensionType.Sample)
  259. def compare_attr(self, neoobj, nixobj):
  260. if isinstance(neoobj, (AnalogSignal, IrregularlySampledSignal)):
  261. nix_name = ".".join(nixobj.name.split(".")[:-1])
  262. else:
  263. nix_name = nixobj.name
  264. self.assertEqual(neoobj.annotations["nix_name"], nix_name)
  265. self.assertEqual(neoobj.description, nixobj.definition)
  266. if hasattr(neoobj, "rec_datetime") and neoobj.rec_datetime:
  267. self.assertEqual(neoobj.rec_datetime,
  268. datetime.fromtimestamp(nixobj.created_at))
  269. if hasattr(neoobj, "file_datetime") and neoobj.file_datetime:
  270. self.assertEqual(neoobj.file_datetime,
  271. datetime.fromtimestamp(
  272. nixobj.metadata["file_datetime"]))
  273. if neoobj.annotations:
  274. nixmd = nixobj.metadata
  275. for k, v, in neoobj.annotations.items():
  276. if k == "nix_name":
  277. continue
  278. if isinstance(v, pq.Quantity):
  279. nixunit = nixmd.props[str(k)].unit
  280. self.assertEqual(nixunit, units_to_string(v.units))
  281. nixvalue = nixmd[str(k)]
  282. if isinstance(nixvalue, Iterable):
  283. nixvalue = np.array(nixvalue)
  284. np.testing.assert_almost_equal(nixvalue, v.magnitude)
  285. else:
  286. self.assertEqual(nixmd[str(k)], v,
  287. "Property value mismatch: {}".format(k))
  288. @classmethod
  289. def create_full_nix_file(cls, filename):
  290. nixfile = nix.File.open(filename, nix.FileMode.Overwrite)
  291. nix_block_a = nixfile.create_block(cls.rword(10), "neo.block")
  292. nix_block_a.definition = cls.rsentence(5, 10)
  293. nix_block_b = nixfile.create_block(cls.rword(10), "neo.block")
  294. nix_block_b.definition = cls.rsentence(3, 3)
  295. nix_block_a.metadata = nixfile.create_section(
  296. nix_block_a.name, nix_block_a.name + ".metadata"
  297. )
  298. nix_block_a.metadata["neo_name"] = cls.rword(5)
  299. nix_block_b.metadata = nixfile.create_section(
  300. nix_block_b.name, nix_block_b.name + ".metadata"
  301. )
  302. nix_block_b.metadata["neo_name"] = cls.rword(5)
  303. nix_blocks = [nix_block_a, nix_block_b]
  304. for blk in nix_blocks:
  305. for ind in range(3):
  306. group = blk.create_group(cls.rword(), "neo.segment")
  307. group.definition = cls.rsentence(10, 15)
  308. group_md = blk.metadata.create_section(
  309. group.name, group.name + ".metadata"
  310. )
  311. group.metadata = group_md
  312. blk = nix_blocks[0]
  313. group = blk.groups[0]
  314. allspiketrains = list()
  315. allsignalgroups = list()
  316. # analogsignals
  317. for n in range(5):
  318. siggroup = list()
  319. asig_name = "{}_asig{}".format(cls.rword(10), n)
  320. asig_definition = cls.rsentence(5, 5)
  321. asig_md = group.metadata.create_section(asig_name,
  322. asig_name + ".metadata")
  323. for idx in range(3):
  324. da_asig = blk.create_data_array(
  325. "{}.{}".format(asig_name, idx),
  326. "neo.analogsignal",
  327. data=cls.rquant(100, 1)
  328. )
  329. da_asig.definition = asig_definition
  330. da_asig.unit = "mV"
  331. da_asig.metadata = asig_md
  332. timedim = da_asig.append_sampled_dimension(0.01)
  333. timedim.unit = "ms"
  334. timedim.label = "time"
  335. timedim.offset = 10
  336. da_asig.append_set_dimension()
  337. group.data_arrays.append(da_asig)
  338. siggroup.append(da_asig)
  339. asig_md["t_start.dim"] = "ms"
  340. allsignalgroups.append(siggroup)
  341. # irregularlysampledsignals
  342. for n in range(2):
  343. siggroup = list()
  344. isig_name = "{}_isig{}".format(cls.rword(10), n)
  345. isig_definition = cls.rsentence(12, 12)
  346. isig_md = group.metadata.create_section(isig_name,
  347. isig_name + ".metadata")
  348. isig_times = cls.rquant(200, 1, True)
  349. for idx in range(10):
  350. da_isig = blk.create_data_array(
  351. "{}.{}".format(isig_name, idx),
  352. "neo.irregularlysampledsignal",
  353. data=cls.rquant(200, 1)
  354. )
  355. da_isig.definition = isig_definition
  356. da_isig.unit = "mV"
  357. da_isig.metadata = isig_md
  358. timedim = da_isig.append_range_dimension(isig_times)
  359. timedim.unit = "s"
  360. timedim.label = "time"
  361. da_isig.append_set_dimension()
  362. group.data_arrays.append(da_isig)
  363. siggroup.append(da_isig)
  364. allsignalgroups.append(siggroup)
  365. # SpikeTrains with Waveforms
  366. for n in range(4):
  367. stname = "{}-st{}".format(cls.rword(20), n)
  368. times = cls.rquant(40, 1, True)
  369. times_da = blk.create_data_array(
  370. "{}.times".format(stname),
  371. "neo.spiketrain.times",
  372. data=times
  373. )
  374. times_da.unit = "ms"
  375. mtag_st = blk.create_multi_tag(stname, "neo.spiketrain", times_da)
  376. group.multi_tags.append(mtag_st)
  377. mtag_st.definition = cls.rsentence(20, 30)
  378. mtag_st_md = group.metadata.create_section(
  379. mtag_st.name, mtag_st.name + ".metadata"
  380. )
  381. mtag_st.metadata = mtag_st_md
  382. mtag_st_md.create_property("t_stop", times[-1] + 1.0)
  383. waveforms = cls.rquant((10, 8, 5), 1)
  384. wfname = "{}.waveforms".format(mtag_st.name)
  385. wfda = blk.create_data_array(wfname, "neo.waveforms",
  386. data=waveforms)
  387. wfda.unit = "mV"
  388. mtag_st.create_feature(wfda, nix.LinkType.Indexed)
  389. wfda.append_set_dimension() # spike dimension
  390. wfda.append_set_dimension() # channel dimension
  391. wftimedim = wfda.append_sampled_dimension(0.1)
  392. wftimedim.unit = "ms"
  393. wftimedim.label = "time"
  394. wfda.metadata = mtag_st_md.create_section(
  395. wfname, "neo.waveforms.metadata"
  396. )
  397. wfda.metadata.create_property("left_sweep",
  398. [20] * 5)
  399. allspiketrains.append(mtag_st)
  400. # Epochs
  401. for n in range(3):
  402. epname = "{}-ep{}".format(cls.rword(5), n)
  403. times = cls.rquant(5, 1, True)
  404. times_da = blk.create_data_array(
  405. "{}.times".format(epname),
  406. "neo.epoch.times",
  407. data=times
  408. )
  409. times_da.unit = "s"
  410. extents = cls.rquant(5, 1)
  411. extents_da = blk.create_data_array(
  412. "{}.durations".format(epname),
  413. "neo.epoch.durations",
  414. data=extents
  415. )
  416. extents_da.unit = "s"
  417. mtag_ep = blk.create_multi_tag(
  418. epname, "neo.epoch", times_da
  419. )
  420. mtag_ep.metadata = group.metadata.create_section(
  421. epname, epname + ".metadata"
  422. )
  423. group.multi_tags.append(mtag_ep)
  424. mtag_ep.definition = cls.rsentence(2)
  425. mtag_ep.extents = extents_da
  426. label_dim = mtag_ep.positions.append_set_dimension()
  427. label_dim.labels = cls.rsentence(5).split(" ")
  428. # reference all signals in the group
  429. for siggroup in allsignalgroups:
  430. mtag_ep.references.extend(siggroup)
  431. # Events
  432. for n in range(2):
  433. evname = "{}-ev{}".format(cls.rword(5), n)
  434. times = cls.rquant(5, 1, True)
  435. times_da = blk.create_data_array(
  436. "{}.times".format(evname),
  437. "neo.event.times",
  438. data=times
  439. )
  440. times_da.unit = "s"
  441. mtag_ev = blk.create_multi_tag(
  442. evname, "neo.event", times_da
  443. )
  444. mtag_ev.metadata = group.metadata.create_section(
  445. evname, evname + ".metadata"
  446. )
  447. group.multi_tags.append(mtag_ev)
  448. mtag_ev.definition = cls.rsentence(2)
  449. label_dim = mtag_ev.positions.append_set_dimension()
  450. label_dim.labels = cls.rsentence(5).split(" ")
  451. # reference all signals in the group
  452. for siggroup in allsignalgroups:
  453. mtag_ev.references.extend(siggroup)
  454. # CHX
  455. nixchx = blk.create_source(cls.rword(10),
  456. "neo.channelindex")
  457. nixchx.metadata = nix_blocks[0].metadata.create_section(
  458. nixchx.name, "neo.channelindex.metadata"
  459. )
  460. chantype = "neo.channelindex"
  461. # 3 channels
  462. for idx, chan in enumerate([2, 5, 9]):
  463. channame = "{}.ChannelIndex{}".format(nixchx.name, idx)
  464. nixrc = nixchx.create_source(channame, chantype)
  465. nixrc.definition = cls.rsentence(13)
  466. nixrc.metadata = nixchx.metadata.create_section(
  467. nixrc.name, "neo.channelindex.metadata"
  468. )
  469. nixrc.metadata.create_property("index", chan)
  470. nixrc.metadata.create_property("channel_id", chan + 1)
  471. dims = cls.rquant(3, 1)
  472. coordprop = nixrc.metadata.create_property("coordinates", dims)
  473. coordprop.unit = "pm"
  474. nunits = 1
  475. stsperunit = np.array_split(allspiketrains, nunits)
  476. for idx in range(nunits):
  477. unitname = "{}-unit{}".format(cls.rword(5), idx)
  478. nixunit = nixchx.create_source(unitname, "neo.unit")
  479. nixunit.metadata = nixchx.metadata.create_section(
  480. unitname, unitname + ".metadata"
  481. )
  482. nixunit.definition = cls.rsentence(4, 10)
  483. for st in stsperunit[idx]:
  484. st.sources.append(nixchx)
  485. st.sources.append(nixunit)
  486. # pick a few signal groups to reference this CHX
  487. randsiggroups = np.random.choice(allsignalgroups, 5, False)
  488. for siggroup in randsiggroups:
  489. for sig in siggroup:
  490. sig.sources.append(nixchx)
  491. return nixfile
  492. @staticmethod
  493. def rdate():
  494. return datetime(year=np.random.randint(1980, 2020),
  495. month=np.random.randint(1, 13),
  496. day=np.random.randint(1, 29))
  497. @classmethod
  498. def populate_dates(cls, obj):
  499. obj.file_datetime = cls.rdate()
  500. obj.rec_datetime = cls.rdate()
  501. @staticmethod
  502. def rword(n=10):
  503. return "".join(np.random.choice(list(string.ascii_letters), n))
  504. @classmethod
  505. def rsentence(cls, n=3, maxwl=10):
  506. return " ".join(cls.rword(np.random.randint(1, maxwl))
  507. for _ in range(n))
  508. @classmethod
  509. def rdict(cls, nitems):
  510. rd = dict()
  511. for _ in range(nitems):
  512. key = cls.rword()
  513. value = cls.rword() if np.random.choice((0, 1)) \
  514. else np.random.uniform()
  515. rd[key] = value
  516. return rd
  517. @staticmethod
  518. def rquant(shape, unit, incr=False):
  519. try:
  520. dim = len(shape)
  521. except TypeError:
  522. dim = 1
  523. if incr and dim > 1:
  524. raise TypeError("Shape of quantity array may only be "
  525. "one-dimensional when incremental values are "
  526. "requested.")
  527. arr = np.random.random(shape)
  528. if incr:
  529. arr = np.array(np.cumsum(arr))
  530. return arr * unit
  531. @classmethod
  532. def create_all_annotated(cls):
  533. times = cls.rquant(1, pq.s)
  534. signal = cls.rquant(1, pq.V)
  535. blk = Block()
  536. blk.annotate(**cls.rdict(3))
  537. cls.populate_dates(blk)
  538. seg = Segment()
  539. seg.annotate(**cls.rdict(4))
  540. cls.populate_dates(seg)
  541. blk.segments.append(seg)
  542. asig = AnalogSignal(signal=signal, sampling_rate=pq.Hz)
  543. asig.annotate(**cls.rdict(2))
  544. seg.analogsignals.append(asig)
  545. isig = IrregularlySampledSignal(times=times, signal=signal,
  546. time_units=pq.s)
  547. isig.annotate(**cls.rdict(2))
  548. seg.irregularlysampledsignals.append(isig)
  549. epoch = Epoch(times=times, durations=times)
  550. epoch.annotate(**cls.rdict(4))
  551. seg.epochs.append(epoch)
  552. event = Event(times=times)
  553. event.annotate(**cls.rdict(4))
  554. seg.events.append(event)
  555. spiketrain = SpikeTrain(times=times, t_stop=pq.s, units=pq.s)
  556. d = cls.rdict(6)
  557. d["quantity"] = pq.Quantity(10, "mV")
  558. d["qarray"] = pq.Quantity(range(10), "mA")
  559. spiketrain.annotate(**d)
  560. seg.spiketrains.append(spiketrain)
  561. chx = ChannelIndex(name="achx", index=[1, 2], channel_ids=[0, 10])
  562. chx.annotate(**cls.rdict(5))
  563. blk.channel_indexes.append(chx)
  564. unit = Unit()
  565. unit.annotate(**cls.rdict(2))
  566. chx.units.append(unit)
  567. return blk
@unittest.skipUnless(HAVE_NIX, "Requires NIX")
class NixIOWriteTest(NixIOTest):
    """
    Tests that write Neo objects through NixIO and check the resulting
    file contents through a separate read-only NIX file handle.
    """
    def setUp(self):
        """
        Create a temporary directory, open a NixIO writer on a new file in
        it and open the same file a second time, read-only, with NIX.
        """
        self.tempdir = mkdtemp(prefix="nixiotest")
        self.filename = os.path.join(self.tempdir, "testnixio.nix")
        # The writer is opened before the reader, which uses ReadOnly mode.
        # NOTE(review): presumably "ow" creates/overwrites the file so that
        # it exists for the reader -- confirm against NixIO.
        self.writer = NixIO(self.filename, "ow")
        self.io = self.writer
        self.reader = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  576. def tearDown(self):
  577. self.writer.close()
  578. self.reader.close()
  579. shutil.rmtree(self.tempdir)
    def write_and_compare(self, blocks, use_obj_names=False):
        """
        Write Neo blocks with the writer and compare them with the contents
        of the file as seen by the read-only NIX handle.

        :param blocks: List of Neo Blocks to write
        :param use_obj_names: Passed on to NixIO.write_all_blocks
        """
        self.writer.write_all_blocks(blocks, use_obj_names)
        self.compare_blocks(blocks, self.reader.blocks)
        # Blocks read back through the writer must match the file as well.
        self.compare_blocks(self.writer.read_all_blocks(), self.reader.blocks)
        # Same comparison as the first one, repeated after reading back.
        self.compare_blocks(blocks, self.reader.blocks)
  585. def test_block_write(self):
  586. block = Block(name=self.rword(),
  587. description=self.rsentence())
  588. self.write_and_compare([block])
  589. block.annotate(**self.rdict(5))
  590. self.write_and_compare([block])
  591. def test_segment_write(self):
  592. block = Block(name=self.rword())
  593. segment = Segment(name=self.rword(), description=self.rword())
  594. block.segments.append(segment)
  595. self.write_and_compare([block])
  596. segment.annotate(**self.rdict(2))
  597. self.write_and_compare([block])
  598. def test_channel_index_write(self):
  599. block = Block(name=self.rword())
  600. chx = ChannelIndex(name=self.rword(),
  601. description=self.rsentence(),
  602. channel_ids=[10, 20, 30, 50, 80, 130],
  603. index=[1, 2, 3, 5, 8, 13])
  604. block.channel_indexes.append(chx)
  605. self.write_and_compare([block])
  606. chx.annotate(**self.rdict(3))
  607. self.write_and_compare([block])
  608. chx.channel_names = ["one", "two", "three", "five",
  609. "eight", "xiii"]
  610. chx.coordinates = self.rquant((6, 3), pq.um)
  611. self.write_and_compare([block])
  612. # add an empty channel index and check again
  613. newchx = ChannelIndex(np.array([]))
  614. block.channel_indexes.append(newchx)
  615. self.write_and_compare([block])
  616. def test_signals_write(self):
  617. block = Block()
  618. seg = Segment()
  619. block.segments.append(seg)
  620. asig = AnalogSignal(signal=self.rquant((19, 15), pq.mV),
  621. sampling_rate=pq.Quantity(10, "Hz"))
  622. seg.analogsignals.append(asig)
  623. self.write_and_compare([block])
  624. anotherblock = Block("ir signal block")
  625. seg = Segment("ir signal seg")
  626. anotherblock.segments.append(seg)
  627. irsig = IrregularlySampledSignal(
  628. signal=np.random.random((20, 30)),
  629. times=self.rquant(20, pq.ms, True),
  630. units=pq.A
  631. )
  632. seg.irregularlysampledsignals.append(irsig)
  633. self.write_and_compare([block, anotherblock])
  634. block.segments[0].analogsignals.append(
  635. AnalogSignal(signal=[10.0, 1.0, 3.0], units=pq.S,
  636. sampling_period=pq.Quantity(3, "s"),
  637. dtype=np.double, name="signal42",
  638. description="this is an analogsignal",
  639. t_start=45 * pq.ms),
  640. )
  641. self.write_and_compare([block, anotherblock])
  642. block.segments[0].irregularlysampledsignals.append(
  643. IrregularlySampledSignal(times=np.random.random(10),
  644. signal=np.random.random((10, 13)),
  645. units="mV", time_units="s",
  646. dtype=np.float,
  647. name="some sort of signal",
  648. description="the signal is described")
  649. )
  650. self.write_and_compare([block, anotherblock])
  651. def test_signals_compound_units(self):
  652. block = Block()
  653. seg = Segment()
  654. block.segments.append(seg)
  655. units = pq.CompoundUnit("1/30000*V")
  656. srate = pq.Quantity(10, pq.CompoundUnit("1.0/10 * Hz"))
  657. asig = AnalogSignal(signal=self.rquant((10, 23), units),
  658. sampling_rate=srate)
  659. seg.analogsignals.append(asig)
  660. self.write_and_compare([block])
  661. anotherblock = Block("ir signal block")
  662. seg = Segment("ir signal seg")
  663. anotherblock.segments.append(seg)
  664. irsig = IrregularlySampledSignal(
  665. signal=np.random.random((20, 3)),
  666. times=self.rquant(20, pq.CompoundUnit("0.1 * ms"), True),
  667. units=pq.CompoundUnit("10 * V / s")
  668. )
  669. seg.irregularlysampledsignals.append(irsig)
  670. self.write_and_compare([block, anotherblock])
  671. block.segments[0].analogsignals.append(
  672. AnalogSignal(signal=[10.0, 1.0, 3.0], units=pq.S,
  673. sampling_period=pq.Quantity(3, "s"),
  674. dtype=np.double, name="signal42",
  675. description="this is an analogsignal",
  676. t_start=45 * pq.CompoundUnit("3.14 * s")),
  677. )
  678. self.write_and_compare([block, anotherblock])
  679. times = self.rquant(10, pq.CompoundUnit("3 * year"), True)
  680. block.segments[0].irregularlysampledsignals.append(
  681. IrregularlySampledSignal(times=times,
  682. signal=np.random.random((10, 3)),
  683. units="mV", dtype=np.float,
  684. name="some sort of signal",
  685. description="the signal is described")
  686. )
  687. self.write_and_compare([block, anotherblock])
  688. def test_epoch_write(self):
  689. block = Block()
  690. seg = Segment()
  691. block.segments.append(seg)
  692. epoch = Epoch(times=[1, 1, 10, 3] * pq.ms,
  693. durations=[3, 3, 3, 1] * pq.ms,
  694. labels=np.array(["one", "two", "three", "four"]),
  695. name="test epoch", description="an epoch for testing")
  696. seg.epochs.append(epoch)
  697. self.write_and_compare([block])
  698. def test_event_write(self):
  699. block = Block()
  700. seg = Segment()
  701. block.segments.append(seg)
  702. event = Event(times=np.arange(0, 30, 10) * pq.s,
  703. labels=np.array(["0", "1", "2"]),
  704. name="event name",
  705. description="event description")
  706. seg.events.append(event)
  707. self.write_and_compare([block])
  708. def test_spiketrain_write(self):
  709. block = Block()
  710. seg = Segment()
  711. block.segments.append(seg)
  712. spiketrain = SpikeTrain(times=[3, 4, 5] * pq.s, t_stop=10.0,
  713. name="spikes!", description="sssssspikes")
  714. seg.spiketrains.append(spiketrain)
  715. self.write_and_compare([block])
  716. waveforms = self.rquant((3, 5, 10), pq.mV)
  717. spiketrain = SpikeTrain(times=[1, 1.1, 1.2] * pq.ms, t_stop=1.5 * pq.s,
  718. name="spikes with wf",
  719. description="spikes for waveform test",
  720. waveforms=waveforms)
  721. seg.spiketrains.append(spiketrain)
  722. self.write_and_compare([block])
  723. spiketrain.left_sweep = np.random.random(10) * pq.ms
  724. self.write_and_compare([block])
  725. spiketrain.left_sweep = pq.Quantity(-10, "ms")
  726. self.write_and_compare([block])
    def test_metadata_structure_write(self):
        """
        Write a fully annotated block and check that the metadata sections
        in the NIX file are nested like the objects they belong to.
        """
        neoblk = self.create_all_annotated()
        self.io.write_block(neoblk)
        blk = self.io.nix_file.blocks[0]

        # The block's metadata section carries the block's name.
        blkmd = blk.metadata
        self.assertEqual(blk.name, blkmd.name)

        grp = blk.groups[0]  # segment
        self.assertIn(grp.name, blkmd.sections)

        grpmd = blkmd.sections[grp.name]
        for da in grp.data_arrays:  # signals
            # Drop the last dot-separated part of the DataArray name to get
            # the name of the signal's metadata section.
            name = ".".join(da.name.split(".")[:-1])
            self.assertIn(name, grpmd.sections)

        for mtag in grp.multi_tags:  # spiketrains, events, and epochs
            self.assertIn(mtag.name, grpmd.sections)

        srcchx = blk.sources[0]  # chx
        self.assertIn(srcchx.name, blkmd.sections)

        # NOTE(review): labelled "units", but this iterates the block-level
        # sources again rather than the children of srcchx -- confirm
        # whether srcchx.sources was intended.
        for srcunit in blk.sources:  # units
            self.assertIn(srcunit.name, blkmd.sections)

        self.write_and_compare([neoblk])
  746. def test_anonymous_objects_write(self):
  747. nblocks = 2
  748. nsegs = 2
  749. nanasig = 4
  750. nirrseg = 2
  751. nepochs = 3
  752. nevents = 4
  753. nspiketrains = 3
  754. nchx = 5
  755. nunits = 10
  756. times = self.rquant(1, pq.s)
  757. signal = self.rquant(1, pq.V)
  758. blocks = []
  759. for blkidx in range(nblocks):
  760. blk = Block()
  761. blocks.append(blk)
  762. for segidx in range(nsegs):
  763. seg = Segment()
  764. blk.segments.append(seg)
  765. for anaidx in range(nanasig):
  766. seg.analogsignals.append(AnalogSignal(signal=signal,
  767. sampling_rate=pq.Hz))
  768. for irridx in range(nirrseg):
  769. seg.irregularlysampledsignals.append(
  770. IrregularlySampledSignal(times=times,
  771. signal=signal,
  772. time_units=pq.s)
  773. )
  774. for epidx in range(nepochs):
  775. seg.epochs.append(Epoch(times=times, durations=times))
  776. for evidx in range(nevents):
  777. seg.events.append(Event(times=times))
  778. for stidx in range(nspiketrains):
  779. seg.spiketrains.append(SpikeTrain(times=times,
  780. t_stop=times[-1] + pq.s,
  781. units=pq.s))
  782. for chidx in range(nchx):
  783. chx = ChannelIndex(index=[1, 2],
  784. channel_ids=[11, 22])
  785. blk.channel_indexes.append(chx)
  786. for unidx in range(nunits):
  787. unit = Unit()
  788. chx.units.append(unit)
  789. self.writer.write_all_blocks(blocks)
  790. self.compare_blocks(blocks, self.reader.blocks)
  791. with self.assertRaises(ValueError):
  792. self.writer.write_all_blocks(blocks, use_obj_names=True)
  793. def test_name_objects_write(self):
  794. nblocks = 2
  795. nsegs = 2
  796. nanasig = 4
  797. nirrseg = 2
  798. nepochs = 3
  799. nevents = 4
  800. nspiketrains = 3
  801. nchx = 5
  802. nunits = 10
  803. times = self.rquant(1, pq.s)
  804. signal = self.rquant(1, pq.V)
  805. blocks = []
  806. for blkidx in range(nblocks):
  807. blk = Block(name="block{}".format(blkidx))
  808. blocks.append(blk)
  809. for segidx in range(nsegs):
  810. seg = Segment(name="seg{}".format(segidx))
  811. blk.segments.append(seg)
  812. for anaidx in range(nanasig):
  813. asig = AnalogSignal(
  814. name="{}:as{}".format(seg.name, anaidx),
  815. signal=signal, sampling_rate=pq.Hz
  816. )
  817. seg.analogsignals.append(asig)
  818. for irridx in range(nirrseg):
  819. isig = IrregularlySampledSignal(
  820. name="{}:is{}".format(seg.name, irridx),
  821. times=times,
  822. signal=signal,
  823. time_units=pq.s
  824. )
  825. seg.irregularlysampledsignals.append(isig)
  826. for epidx in range(nepochs):
  827. seg.epochs.append(
  828. Epoch(name="{}:ep{}".format(seg.name, epidx),
  829. times=times, durations=times)
  830. )
  831. for evidx in range(nevents):
  832. seg.events.append(
  833. Event(name="{}:ev{}".format(seg.name, evidx),
  834. times=times)
  835. )
  836. for stidx in range(nspiketrains):
  837. seg.spiketrains.append(
  838. SpikeTrain(name="{}:st{}".format(seg.name, stidx),
  839. times=times,
  840. t_stop=times[-1] + pq.s,
  841. units=pq.s)
  842. )
  843. for chidx in range(nchx):
  844. chx = ChannelIndex(name="chx{}".format(chidx),
  845. index=[1, 2],
  846. channel_ids=[11, 22])
  847. blk.channel_indexes.append(chx)
  848. for unidx in range(nunits):
  849. unit = Unit(name="unit{}".format(unidx))
  850. chx.units.append(unit)
  851. # put guard on _generate_nix_name
  852. if not SKIPMOCK:
  853. nixgenmock = mock.Mock(name="_generate_nix_name",
  854. wraps=self.io._generate_nix_name)
  855. self.io._generate_nix_name = nixgenmock
  856. self.writer.write_block(blocks[0], use_obj_names=True)
  857. self.compare_blocks([blocks[0]], self.reader.blocks)
  858. self.compare_blocks(self.writer.read_all_blocks(), self.reader.blocks)
  859. self.compare_blocks(blocks, self.reader.blocks)
  860. if not SKIPMOCK:
  861. nixgenmock.assert_not_called()
  862. self.write_and_compare(blocks, use_obj_names=True)
  863. if not SKIPMOCK:
  864. nixgenmock.assert_not_called()
  865. self.assertEqual(self.reader.blocks[0].name, "block0")
  866. blocks[0].name = blocks[1].name # name conflict
  867. with self.assertRaises(ValueError):
  868. self.writer.write_all_blocks(blocks, use_obj_names=True)
  869. blocks[0].name = "new name"
  870. self.assertEqual(blocks[0].segments[1].spiketrains[1].name, "seg1:st1")
  871. st0 = blocks[0].segments[0].spiketrains[0].name
  872. blocks[0].segments[0].spiketrains[1].name = st0 # name conflict
  873. with self.assertRaises(ValueError):
  874. self.writer.write_all_blocks(blocks, use_obj_names=True)
  875. with self.assertRaises(ValueError):
  876. self.writer.write_block(blocks[0], use_obj_names=True)
  877. if not SKIPMOCK:
  878. nixgenmock.assert_not_called()
  879. def test_name_conflicts(self):
  880. # anon block
  881. blk = Block()
  882. with self.assertRaises(ValueError):
  883. self.io.write_block(blk, use_obj_names=True)
  884. # two anon blocks
  885. blocks = [Block(), Block()]
  886. with self.assertRaises(ValueError):
  887. self.io.write_all_blocks(blocks, use_obj_names=True)
  888. # same name blocks
  889. blocks = [Block(name="one"), Block(name="one")]
  890. with self.assertRaises(ValueError):
  891. self.io.write_all_blocks(blocks, use_obj_names=True)
  892. # one block, two same name segments
  893. blk = Block("new")
  894. seg = Segment("I am the segment", a="a annoation")
  895. blk.segments.append(seg)
  896. seg = Segment("I am the segment", a="b annotation")
  897. blk.segments.append(seg)
  898. with self.assertRaises(ValueError):
  899. self.io.write_block(blk, use_obj_names=True)
  900. times = self.rquant(1, pq.s)
  901. signal = self.rquant(1, pq.V)
  902. # name conflict: analog + irregular signals
  903. seg.analogsignals.append(
  904. AnalogSignal(name="signal", signal=signal, sampling_rate=pq.Hz)
  905. )
  906. seg.irregularlysampledsignals.append(
  907. IrregularlySampledSignal(name="signal", signal=signal, times=times)
  908. )
  909. blk = Block(name="Signal conflict Block")
  910. blk.segments.append(seg)
  911. with self.assertRaises(ValueError):
  912. self.io.write_block(blk, use_obj_names=True)
  913. # name conflict: event + spiketrain
  914. blk = Block(name="Event+SpikeTrain conflict Block")
  915. seg = Segment(name="Event+SpikeTrain conflict Segment")
  916. blk.segments.append(seg)
  917. seg.events.append(Event(name="TimeyStuff", times=times))
  918. seg.spiketrains.append(SpikeTrain(name="TimeyStuff", times=times,
  919. t_stop=pq.s))
  920. with self.assertRaises(ValueError):
  921. self.io.write_block(blk, use_obj_names=True)
  922. # make spiketrain anon
  923. blk.segments[0].spiketrains[0].name = None
  924. with self.assertRaises(ValueError):
  925. self.io.write_block(blk, use_obj_names=True)
  926. # name conflict in channel indexes
  927. blk = Block(name="ChannelIndex conflict Block")
  928. blk.channel_indexes.append(ChannelIndex(name="chax", index=[1]))
  929. blk.channel_indexes.append(ChannelIndex(name="chax", index=[2]))
  930. with self.assertRaises(ValueError):
  931. self.io.write_block(blk, use_obj_names=True)
  932. # name conflict in units
  933. blk = Block(name="unitconf")
  934. chx = ChannelIndex(name="ok", index=[100])
  935. blk.channel_indexes.append(chx)
  936. chx.units.append(Unit(name="IHAVEATWIN"))
  937. chx.units.append(Unit(name="IHAVEATWIN"))
  938. with self.assertRaises(ValueError):
  939. self.io.write_block(blk, use_obj_names=True)
  940. def test_multiref_write(self):
  941. blk = Block("blk1")
  942. signal = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
  943. sampling_period=pq.Quantity(1, "ms"))
  944. othersignal = IrregularlySampledSignal(name="i1", signal=[0, 0, 0],
  945. units="mV", times=[1, 2, 3],
  946. time_units="ms")
  947. event = Event(name="Evee", times=[0.3, 0.42], units="year")
  948. epoch = Epoch(name="epoche", times=[0.1, 0.2] * pq.min,
  949. durations=[0.5, 0.5] * pq.min)
  950. st = SpikeTrain(name="the train of spikes", times=[0.1, 0.2, 10.3],
  951. t_stop=11, units="us")
  952. for idx in range(3):
  953. segname = "seg" + str(idx)
  954. seg = Segment(segname)
  955. blk.segments.append(seg)
  956. seg.analogsignals.append(signal)
  957. seg.irregularlysampledsignals.append(othersignal)
  958. seg.events.append(event)
  959. seg.epochs.append(epoch)
  960. seg.spiketrains.append(st)
  961. chidx = ChannelIndex([10, 20, 29])
  962. seg = blk.segments[0]
  963. st = SpikeTrain(name="choochoo", times=[10, 11, 80], t_stop=1000,
  964. units="s")
  965. seg.spiketrains.append(st)
  966. blk.channel_indexes.append(chidx)
  967. for idx in range(6):
  968. unit = Unit("unit" + str(idx))
  969. chidx.units.append(unit)
  970. unit.spiketrains.append(st)
  971. self.writer.write_block(blk)
  972. self.compare_blocks([blk], self.reader.blocks)
  973. def test_no_segment_write(self):
  974. # Tests storing AnalogSignal, IrregularlySampledSignal, and SpikeTrain
  975. # objects in the secondary (ChannelIndex) substructure without them
  976. # being attached to a Segment.
  977. blk = Block("segmentless block")
  978. signal = AnalogSignal(name="sig1", signal=[0, 1, 2], units="mV",
  979. sampling_period=pq.Quantity(1, "ms"))
  980. othersignal = IrregularlySampledSignal(name="i1", signal=[0, 0, 0],
  981. units="mV", times=[1, 2, 3],
  982. time_units="ms")
  983. sta = SpikeTrain(name="the train of spikes", times=[0.1, 0.2, 10.3],
  984. t_stop=11, units="us")
  985. stb = SpikeTrain(name="the train of spikes b", times=[1.1, 2.2, 10.1],
  986. t_stop=100, units="ms")
  987. chidx = ChannelIndex([8, 13, 21])
  988. blk.channel_indexes.append(chidx)
  989. chidx.analogsignals.append(signal)
  990. chidx.irregularlysampledsignals.append(othersignal)
  991. unit = Unit()
  992. chidx.units.append(unit)
  993. unit.spiketrains.extend([sta, stb])
  994. self.writer.write_block(blk)
  995. self.writer.close()
  996. self.compare_blocks([blk], self.reader.blocks)
  997. reader = NixIO(self.filename, "ro")
  998. blk = reader.read_block(neoname="segmentless block")
  999. chx = blk.channel_indexes[0]
  1000. self.assertEqual(len(chx.analogsignals), 1)
  1001. self.assertEqual(len(chx.irregularlysampledsignals), 1)
  1002. self.assertEqual(len(chx.units[0].spiketrains), 2)
  1003. def test_rewrite_refs(self):
  1004. def checksignalcounts(fname):
  1005. with NixIO(fname, "ro") as r:
  1006. blk = r.read_block()
  1007. chidx = blk.channel_indexes[0]
  1008. seg = blk.segments[0]
  1009. self.assertEqual(len(chidx.analogsignals), 2)
  1010. self.assertEqual(len(chidx.units[0].spiketrains), 3)
  1011. self.assertEqual(len(seg.analogsignals), 1)
  1012. self.assertEqual(len(seg.spiketrains), 1)
  1013. blk = Block()
  1014. # ChannelIndex
  1015. chidx = ChannelIndex(index=[1])
  1016. blk.channel_indexes.append(chidx)
  1017. # Two signals on ChannelIndex
  1018. for idx in range(2):
  1019. asigchx = AnalogSignal(signal=[idx], units="mV",
  1020. sampling_rate=pq.Hz)
  1021. chidx.analogsignals.append(asigchx)
  1022. # Unit
  1023. unit = Unit()
  1024. chidx.units.append(unit)
  1025. # Three SpikeTrains on Unit
  1026. for idx in range(3):
  1027. st = SpikeTrain([idx], units="ms", t_stop=40)
  1028. unit.spiketrains.append(st)
  1029. # Segment
  1030. seg = Segment()
  1031. blk.segments.append(seg)
  1032. # One signal on Segment
  1033. asigseg = AnalogSignal(signal=[2], units="uA",
  1034. sampling_rate=pq.Hz)
  1035. seg.analogsignals.append(asigseg)
  1036. # One spiketrain on Segment
  1037. stseg = SpikeTrain([10], units="ms", t_stop=40)
  1038. seg.spiketrains.append(stseg)
  1039. # Write, compare, and check counts
  1040. self.writer.write_block(blk)
  1041. self.compare_blocks([blk], self.reader.blocks)
  1042. self.assertEqual(len(chidx.analogsignals), 2)
  1043. self.assertEqual(len(seg.analogsignals), 1)
  1044. self.assertEqual(len(chidx.analogsignals), 2)
  1045. self.assertEqual(len(chidx.units[0].spiketrains), 3)
  1046. self.assertEqual(len(seg.analogsignals), 1)
  1047. self.assertEqual(len(seg.spiketrains), 1)
  1048. # Check counts with separate reader
  1049. checksignalcounts(self.filename)
  1050. # Write again and check counts
  1051. secondwrite = os.path.join(self.tempdir, "testnixio-2.nix")
  1052. with NixIO(secondwrite, "ow") as w:
  1053. w.write_block(blk)
  1054. self.compare_blocks([blk], self.reader.blocks)
  1055. # Read back and check counts
  1056. scndreader = nix.File.open(secondwrite, mode=nix.FileMode.ReadOnly)
  1057. self.compare_blocks([blk], scndreader.blocks)
  1058. checksignalcounts(secondwrite)
  1059. def test_to_value(self):
  1060. section = self.io.nix_file.create_section("Metadata value test",
  1061. "Test")
  1062. writeprop = self.io._write_property
  1063. # quantity
  1064. qvalue = pq.Quantity(10, "mV")
  1065. writeprop(section, "qvalue", qvalue)
  1066. self.assertEqual(section["qvalue"], 10)
  1067. self.assertEqual(section.props["qvalue"].unit, "mV")
  1068. # datetime
  1069. dt = self.rdate()
  1070. writeprop(section, "dt", dt)
  1071. self.assertEqual(datetime.fromtimestamp(section["dt"]), dt)
  1072. # string
  1073. randstr = self.rsentence()
  1074. writeprop(section, "randstr", randstr)
  1075. self.assertEqual(section["randstr"], randstr)
  1076. # bytes
  1077. bytestring = b"bytestring"
  1078. writeprop(section, "randbytes", bytestring)
  1079. self.assertEqual(section["randbytes"], bytestring.decode())
  1080. # iterables
  1081. randlist = np.random.random(10).tolist()
  1082. writeprop(section, "randlist", randlist)
  1083. self.assertEqual(randlist, section["randlist"])
  1084. randarray = np.random.random(10)
  1085. writeprop(section, "randarray", randarray)
  1086. np.testing.assert_almost_equal(randarray, section["randarray"])
  1087. # numpy item
  1088. npval = np.float64(2398)
  1089. writeprop(section, "npval", npval)
  1090. self.assertEqual(npval, section["npval"])
  1091. # number
  1092. val = 42
  1093. writeprop(section, "val", val)
  1094. self.assertEqual(val, section["val"])
  1095. # empty string (gets stored as empty list)
  1096. writeprop(section, "emptystring", "")
  1097. self.assertEqual(list(), section["emptystring"])
  1098. def test_annotations_special_cases(self):
  1099. # Special cases for annotations: empty list, list of strings,
  1100. # multidimensional lists/arrays
  1101. # These are handled differently on read, so we test them on a block
  1102. # instead of just checking the property writer method
  1103. # empty value
  1104. # empty list
  1105. wblock = Block("block with empty list", an_empty_list=list())
  1106. self.writer.write_block(wblock)
  1107. rblock = self.writer.read_block(neoname="block with empty list")
  1108. self.assertEqual(rblock.annotations["an_empty_list"], list())
  1109. # empty tuple (gets read out as list)
  1110. wblock = Block("block with empty tuple", an_empty_tuple=tuple())
  1111. self.writer.write_block(wblock)
  1112. rblock = self.writer.read_block(neoname="block with empty tuple")
  1113. self.assertEqual(rblock.annotations["an_empty_tuple"], list())
  1114. # list of strings
  1115. losval = ["one", "two", "one million"]
  1116. wblock = Block("block with list of strings",
  1117. los=losval)
  1118. self.writer.write_block(wblock)
  1119. rblock = self.writer.read_block(neoname="block with list of strings")
  1120. self.assertEqual(rblock.annotations["los"], losval)
  1121. # TODO: multi dimensional value (GH Issue #501)
  1122. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  1123. class NixIOReadTest(NixIOTest):
  1124. nixfile = None
  1125. nix_blocks = None
  1126. @classmethod
  1127. def setUpClass(cls):
  1128. cls.tempdir = mkdtemp(prefix="nixiotest")
  1129. cls.filename = os.path.join(cls.tempdir, "testnixio.nix")
  1130. if HAVE_NIX:
  1131. cls.nixfile = cls.create_full_nix_file(cls.filename)
  1132. def setUp(self):
  1133. self.io = NixIO(self.filename, "ro")
  1134. @classmethod
  1135. def tearDownClass(cls):
  1136. if HAVE_NIX:
  1137. cls.nixfile.close()
  1138. shutil.rmtree(cls.tempdir)
  1139. def tearDown(self):
  1140. self.io.close()
  1141. def test_all_read(self):
  1142. neo_blocks = self.io.read_all_blocks()
  1143. nix_blocks = self.io.nix_file.blocks
  1144. self.compare_blocks(neo_blocks, nix_blocks)
  1145. def test_iter_read(self):
  1146. blocknames = [blk.name for blk in self.nixfile.blocks]
  1147. for blk, nixname in zip(self.io.iter_blocks(), blocknames):
  1148. self.assertEqual(blk.annotations["nix_name"], nixname)
  1149. def test_nix_name_read(self):
  1150. for nixblock in self.nixfile.blocks:
  1151. nixname = nixblock.name
  1152. neoblock = self.io.read_block(nixname=nixname)
  1153. self.assertEqual(neoblock.annotations["nix_name"], nixname)
  1154. def test_index_read(self):
  1155. for idx, nixblock in enumerate(self.nixfile.blocks):
  1156. neoblock = self.io.read_block(index=idx)
  1157. self.assertEqual(neoblock.annotations["nix_name"], nixblock.name)
  1158. self.assertEqual(neoblock.annotations["nix_name"],
  1159. self.nixfile.blocks[idx].name)
  1160. def test_auto_index_read(self):
  1161. for nixblock in self.nixfile.blocks:
  1162. neoblock = self.io.read_block() # don't specify index
  1163. self.assertEqual(neoblock.annotations["nix_name"], nixblock.name)
  1164. # No more blocks - should return None
  1165. self.assertIs(self.io.read_block(), None)
  1166. self.assertIs(self.io.read_block(), None)
  1167. self.assertIs(self.io.read_block(), None)
  1168. with NixIO(self.filename, "ro") as nf:
  1169. neoblock = nf.read_block(index=1)
  1170. self.assertEqual(self.nixfile.blocks[1].name,
  1171. neoblock.annotations["nix_name"])
  1172. neoblock = nf.read_block() # should start again from 0
  1173. self.assertEqual(self.nixfile.blocks[0].name,
  1174. neoblock.annotations["nix_name"])
  1175. def test_neo_name_read(self):
  1176. for nixblock in self.nixfile.blocks:
  1177. neoname = nixblock.metadata["neo_name"]
  1178. neoblock = self.io.read_block(neoname=neoname)
  1179. self.assertEqual(neoblock.annotations["nix_name"], nixblock.name)
  1180. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  1181. class NixIOContextTests(NixIOTest):
  1182. def setUp(self):
  1183. self.tempdir = mkdtemp(prefix="nixiotest")
  1184. self.filename = os.path.join(self.tempdir, "testnixio.nix")
  1185. def tearDown(self):
  1186. shutil.rmtree(self.tempdir)
  1187. def test_context_write(self):
  1188. neoblock = Block(name=self.rword(), description=self.rsentence())
  1189. with NixIO(self.filename, "ow") as iofile:
  1190. iofile.write_block(neoblock)
  1191. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  1192. self.compare_blocks([neoblock], nixfile.blocks)
  1193. nixfile.close()
  1194. neoblock.annotate(**self.rdict(5))
  1195. with NixIO(self.filename, "rw") as iofile:
  1196. iofile.write_block(neoblock)
  1197. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  1198. self.compare_blocks([neoblock], nixfile.blocks)
  1199. nixfile.close()
  1200. def test_context_read(self):
  1201. nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite)
  1202. name_one = self.rword()
  1203. name_two = self.rword()
  1204. nixfile.create_block(name_one, "neo.block")
  1205. nixfile.create_block(name_two, "neo.block")
  1206. nixfile.close()
  1207. with NixIO(self.filename, "ro") as iofile:
  1208. blocks = iofile.read_all_blocks()
  1209. self.assertEqual(blocks[0].annotations["nix_name"], name_one)
  1210. self.assertEqual(blocks[1].annotations["nix_name"], name_two)
  1211. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  1212. class NixIOVerTests(NixIOTest):
  1213. def setUp(self):
  1214. self.tempdir = mkdtemp(prefix="nixiotest")
  1215. self.filename = os.path.join(self.tempdir, "testnixio.nix")
  1216. def tearDown(self):
  1217. shutil.rmtree(self.tempdir)
  1218. def test_new_file(self):
  1219. with NixIO(self.filename, "ow") as iofile:
  1220. self.assertEqual(iofile._file_version, neover)
  1221. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  1222. filever = nixfile.sections["neo"]["version"]
  1223. self.assertEqual(filever, neover)
  1224. nixfile.close()
  1225. def test_oldfile_nover(self):
  1226. nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite)
  1227. nixfile.close()
  1228. with NixIO(self.filename, "ro") as iofile:
  1229. self.assertEqual(iofile._file_version, '0.5.2') # compat version
  1230. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  1231. self.assertNotIn("neo", nixfile.sections)
  1232. nixfile.close()
  1233. with NixIO(self.filename, "rw") as iofile:
  1234. self.assertEqual(iofile._file_version, '0.5.2') # compat version
  1235. # section should have been created now
  1236. nixfile = nix.File.open(self.filename, nix.FileMode.ReadOnly)
  1237. self.assertIn("neo", nixfile.sections)
  1238. self.assertEqual(nixfile.sections["neo"]["version"], '0.5.2')
  1239. nixfile.close()
  1240. def test_file_with_ver(self):
  1241. someversion = '0.100.10'
  1242. nixfile = nix.File.open(self.filename, nix.FileMode.Overwrite)
  1243. filemd = nixfile.create_section("neo", "neo.metadata")
  1244. filemd["version"] = someversion
  1245. nixfile.close()
  1246. with NixIO(self.filename, "ro") as iofile:
  1247. self.assertEqual(iofile._file_version, someversion)
  1248. with NixIO(self.filename, "rw") as iofile:
  1249. self.assertEqual(iofile._file_version, someversion)
  1250. with NixIO(self.filename, "ow") as iofile:
  1251. self.assertEqual(iofile._file_version, neover)
  1252. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  1253. class CommonTests(BaseTestIO, unittest.TestCase):
  1254. ioclass = NixIO
  1255. read_and_write_is_bijective = False
  1256. if __name__ == "__main__":
  1257. unittest.main()