# test_nixio_fr.py
"""
Tests of neo.io.nixio_fr, cross-checked against the full neo.io.nixio reader.
"""
import contextlib
import os
import unittest

import numpy as np
import quantities as pq
from quantities import s

from neo.core import Block, Segment, AnalogSignal, SpikeTrain, Event
from neo.io.nixio import NixIO
from neo.io.nixio_fr import NixIO as NixIOfr
from neo.test.iotest.common_io_test import BaseTestIO
from neo.test.iotest.tools import get_test_file_full_path

try:
    import nixio as nix
    HAVE_NIX = True
except ImportError:
    HAVE_NIX = False

  19. @unittest.skipUnless(HAVE_NIX, "Requires NIX")
  20. class TestNixfr(BaseTestIO, unittest.TestCase, ):
  21. ioclass = NixIOfr
  22. files_to_test = ['nixio_fr.nix']
  23. files_to_download = ['nixio_fr.nix']
  24. def setUp(self):
  25. super().setUp()
  26. self.testfilename = self.get_filename_path('nixio_fr.nix')
  27. self.reader_fr = NixIOfr(filename=self.testfilename)
  28. self.reader_norm = NixIO(filename=self.testfilename, mode='ro')
  29. self.blk = self.reader_fr.read_block(block_index=1, load_waveforms=True)
  30. # read block with NixIOfr
  31. self.blk1 = self.reader_norm.read_block(index=1) # read same block with NixIO
  32. def tearDown(self):
  33. self.reader_fr.file.close()
  34. self.reader_norm.close()
  35. def test_check_same_neo_structure(self):
  36. self.assertEqual(len(self.blk.segments), len(self.blk1.segments))
  37. for seg1, seg2 in zip(self.blk.segments, self.blk1.segments):
  38. self.assertEqual(len(seg1.analogsignals), len(seg2.analogsignals))
  39. self.assertEqual(len(seg1.spiketrains), len(seg2.spiketrains))
  40. self.assertEqual(len(seg1.events), len(seg2.events))
  41. self.assertEqual(len(seg1.epochs), len(seg2.epochs))
  42. def test_check_same_data_content(self):
  43. for seg1, seg2 in zip(self.blk.segments, self.blk1.segments):
  44. for asig1, asig2 in zip(seg1.analogsignals, seg2.analogsignals):
  45. np.testing.assert_almost_equal(asig1.magnitude, asig2.magnitude)
  46. # not completely equal
  47. for st1, st2 in zip(seg1.spiketrains, seg2.spiketrains):
  48. np.testing.assert_array_equal(st1.magnitude, st2.times)
  49. for wf1, wf2 in zip(st1.waveforms, st2.waveforms):
  50. np.testing.assert_array_equal(wf1.shape, wf2.shape)
  51. np.testing.assert_almost_equal(wf1.magnitude, wf2.magnitude)
  52. for ev1, ev2 in zip(seg1.events, seg2.events):
  53. np.testing.assert_almost_equal(ev1.times, ev2.times)
  54. assert np.all(ev1.labels == ev2.labels)
  55. for ep1, ep2 in zip(seg1.epochs, seg2.epochs):
  56. assert len(ep1.durations) == len(ep2.times)
  57. np.testing.assert_almost_equal(ep1.times, ep2.times)
  58. np.testing.assert_array_equal(ep1.durations, ep2.durations)
  59. np.testing.assert_array_equal(ep1.labels, ep2.labels)
  60. # Not testing for channel_index as rawio always read from seg
  61. for chid1, chid2 in zip(self.blk.channel_indexes, self.blk1.channel_indexes):
  62. for asig1, asig2 in zip(chid1.analogsignals, chid2.analogsignals):
  63. np.testing.assert_almost_equal(asig1.magnitude, asig2.magnitude)
  64. def test_analog_signal(self):
  65. seg1 = self.blk.segments[0]
  66. an_sig1 = seg1.analogsignals[0]
  67. assert len(an_sig1) == 30
  68. an_sig2 = seg1.analogsignals[1]
  69. assert an_sig2.shape == (50, 3)
  70. def test_spike_train(self):
  71. st1 = self.blk.segments[0].spiketrains[0]
  72. assert np.all(st1.times == np.cumsum(np.arange(0, 1, 0.1)).tolist() * pq.s + 10 * pq.s)
  73. def test_event(self):
  74. seg1 = self.blk.segments[0]
  75. event1 = seg1.events[0]
  76. raw_time = 10 + np.cumsum(np.array([0, 1, 2, 3, 4]))
  77. assert np.all(event1.times == np.array(raw_time * pq.s / 1000))
  78. assert np.all(event1.labels == np.array(['A', 'B', 'C', 'D', 'E'], dtype='U'))
  79. assert len(seg1.events) == 1
  80. def test_epoch(self):
  81. seg1 = self.blk.segments[1]
  82. seg2 = self.blk1.segments[1]
  83. epoch1 = seg1.epochs[0]
  84. epoch2 = seg2.epochs[0]
  85. assert len(epoch1.durations) == len(epoch1.times)
  86. assert np.all(epoch1.durations == epoch2.durations)
  87. assert np.all(epoch1.labels == epoch2.labels)
  88. def test_annotations(self):
  89. self.testfilename = self.get_filename_path('nixio_fr_ann.nix')
  90. with NixIO(filename=self.testfilename, mode='ow') as io:
  91. annotations = {'my_custom_annotation': 'hello block'}
  92. bl = Block(**annotations)
  93. annotations = {'something': 'hello hello000'}
  94. seg = Segment(**annotations)
  95. an =AnalogSignal([[1, 2, 3], [4, 5, 6]], units='V',
  96. sampling_rate=1*pq.Hz)
  97. an.annotations['ansigrandom'] = 'hello chars'
  98. sp = SpikeTrain([3, 4, 5]* s, t_stop=10.0)
  99. sp.annotations['railway'] = 'hello train'
  100. ev = Event(np.arange(0, 30, 10)*pq.Hz,
  101. labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S'))
  102. ev.annotations['venue'] = 'hello event'
  103. ev2 = Event(np.arange(0, 30, 10) * pq.Hz,
  104. labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S'))
  105. ev2.annotations['evven'] = 'hello ev'
  106. seg.spiketrains.append(sp)
  107. seg.events.append(ev)
  108. seg.events.append(ev2)
  109. seg.analogsignals.append(an)
  110. bl.segments.append(seg)
  111. io.write_block(bl)
  112. io.close()
  113. with NixIOfr(filename=self.testfilename) as frio:
  114. frbl = frio.read_block()
  115. assert 'my_custom_annotation' in frbl.annotations
  116. assert 'something' in frbl.segments[0].annotations
  117. # assert 'ansigrandom' in frbl.segments[0].analogsignals[0].annotations
  118. assert 'railway' in frbl.segments[0].spiketrains[0].annotations
  119. assert 'venue' in frbl.segments[0].events[0].annotations
  120. assert 'evven' in frbl.segments[0].events[1].annotations
  121. os.remove(self.testfilename)
  122. if __name__ == '__main__':
  123. unittest.main()