123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343 |
- # -*- coding: utf-8 -*-
- """
- Tests of neo.io.brainwaresrcio
- """
- # needed for python 3 compatibility
- from __future__ import absolute_import, division, print_function
- import logging
- import os.path
- import sys
- import unittest
- import numpy as np
- import quantities as pq
- from neo.core import (Block, Event,
- ChannelIndex, Segment, SpikeTrain, Unit)
- from neo.io import BrainwareSrcIO, brainwaresrcio
- from neo.test.iotest.common_io_test import BaseTestIO
- from neo.test.tools import (assert_same_sub_schema,
- assert_neo_object_is_compliant)
- from neo.test.iotest.tools import create_generic_reader
- PY_VER = sys.version_info[0]
- FILES_TO_TEST = ['block_300ms_4rep_1clust_part_ch1.src',
- 'block_500ms_5rep_empty_fullclust_ch1.src',
- 'block_500ms_5rep_empty_partclust_ch1.src',
- 'interleaved_500ms_5rep_ch2.src',
- 'interleaved_500ms_5rep_nospikes_ch1.src',
- 'interleaved_500ms_7rep_noclust_ch1.src',
- 'long_170s_1rep_1clust_ch2.src',
- 'multi_500ms_mulitrep_ch1.src',
- 'random_500ms_12rep_noclust_part_ch2.src',
- 'sequence_500ms_5rep_ch2.src']
- FILES_TO_COMPARE = ['block_300ms_4rep_1clust_part_ch1',
- 'block_500ms_5rep_empty_fullclust_ch1',
- 'block_500ms_5rep_empty_partclust_ch1',
- 'interleaved_500ms_5rep_ch2',
- 'interleaved_500ms_5rep_nospikes_ch1',
- 'interleaved_500ms_7rep_noclust_ch1',
- '',
- 'multi_500ms_mulitrep_ch1',
- 'random_500ms_12rep_noclust_part_ch2',
- 'sequence_500ms_5rep_ch2']
- def proc_src(filename):
- '''Load an src file that has already been processed by the official matlab
- file converter. That matlab data is saved to an m-file, which is then
- converted to a numpy '.npz' file. This numpy file is the file actually
- loaded. This function converts it to a neo block and returns the block.
- This block can be compared to the block produced by BrainwareSrcIO to
- make sure BrainwareSrcIO is working properly
- block = proc_src(filename)
- filename: The file name of the numpy file to load. It should end with
- '*_src_py?.npz'. This will be converted to a neo 'file_origin' property
- with the value '*.src', so the filename to compare should fit that pattern.
- 'py?' should be 'py2' for the python 2 version of the numpy file or 'py3'
- for the python 3 version of the numpy file.
- example: filename = 'file1_src_py2.npz'
- src file name = 'file1.src'
- '''
- with np.load(filename) as srcobj:
- srcfile = list(srcobj.items())[0][1]
- filename = os.path.basename(filename[:-12] + '.src')
- block = Block(file_origin=filename)
- NChannels = srcfile['NChannels'][0, 0][0, 0]
- side = str(srcfile['side'][0, 0][0])
- ADperiod = srcfile['ADperiod'][0, 0][0, 0]
- comm_seg = proc_src_comments(srcfile, filename)
- block.segments.append(comm_seg)
- chx = proc_src_units(srcfile, filename)
- chan_nums = np.arange(NChannels, dtype='int')
- chan_names = ['Chan{}'.format(i) for i in range(NChannels)]
- chx.index = chan_nums
- chx.channel_names = np.array(chan_names, dtype='string_')
- block.channel_indexes.append(chx)
- for rep in srcfile['sets'][0, 0].flatten():
- proc_src_condition(rep, filename, ADperiod, side, block)
- block.create_many_to_one_relationship()
- return block
- def proc_src_comments(srcfile, filename):
- '''Get the comments in an src file that has been#!N
- processed by the official
- matlab function. See proc_src for details'''
- comm_seg = Segment(name='Comments', file_origin=filename)
- commentarray = srcfile['comments'].flatten()[0]
- senders = [res[0] for res in commentarray['sender'].flatten()]
- texts = [res[0] for res in commentarray['text'].flatten()]
- timeStamps = [res[0, 0] for res in commentarray['timeStamp'].flatten()]
- timeStamps = np.array(timeStamps, dtype=np.float32)
- t_start = timeStamps.min()
- timeStamps = pq.Quantity(timeStamps - t_start, units=pq.d).rescale(pq.s)
- texts = np.array(texts, dtype='S')
- senders = np.array(senders, dtype='S')
- t_start = brainwaresrcio.convert_brainwaresrc_timestamp(t_start.tolist())
- comments = Event(times=timeStamps, labels=texts, senders=senders)
- comm_seg.events = [comments]
- comm_seg.rec_datetime = t_start
- return comm_seg
- def proc_src_units(srcfile, filename):
- '''Get the units in an src file that has been processed by the official
- matlab function. See proc_src for details'''
- chx = ChannelIndex(file_origin=filename,
- index=np.array([], dtype=int))
- un_unit = Unit(name='UnassignedSpikes', file_origin=filename,
- elliptic=[], boundaries=[], timestamp=[], max_valid=[])
- chx.units.append(un_unit)
- sortInfo = srcfile['sortInfo'][0, 0]
- timeslice = sortInfo['timeslice'][0, 0]
- maxValid = timeslice['maxValid'][0, 0]
- cluster = timeslice['cluster'][0, 0]
- if len(cluster):
- maxValid = maxValid[0, 0]
- elliptic = [res.flatten() for res in cluster['elliptic'].flatten()]
- boundaries = [res.flatten() for res in cluster['boundaries'].flatten()]
- fullclust = zip(elliptic, boundaries)
- for ielliptic, iboundaries in fullclust:
- unit = Unit(file_origin=filename,
- boundaries=[iboundaries],
- elliptic=[ielliptic], timeStamp=[],
- max_valid=[maxValid])
- chx.units.append(unit)
- return chx
- def proc_src_condition(rep, filename, ADperiod, side, block):
- '''Get the condition in a src file that has been processed by the official
- matlab function. See proc_src for details'''
- chx = block.channel_indexes[0]
- stim = rep['stim'].flatten()
- params = [str(res[0]) for res in stim['paramName'][0].flatten()]
- values = [res for res in stim['paramVal'][0].flatten()]
- stim = dict(zip(params, values))
- sweepLen = rep['sweepLen'][0, 0]
- if not len(rep):
- return
- unassignedSpikes = rep['unassignedSpikes'].flatten()
- if len(unassignedSpikes):
- damaIndexes = [res[0, 0] for res in unassignedSpikes['damaIndex']]
- timeStamps = [res[0, 0] for res in unassignedSpikes['timeStamp']]
- spikeunit = [res.flatten() for res in unassignedSpikes['spikes']]
- respWin = np.array([], dtype=np.int32)
- trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
- respWin, damaIndexes, timeStamps,
- filename)
- chx.units[0].spiketrains.extend(trains)
- atrains = [trains]
- else:
- damaIndexes = []
- timeStamps = []
- atrains = []
- clusters = rep['clusters'].flatten()
- if len(clusters):
- IdStrings = [res[0] for res in clusters['IdString']]
- sweepLens = [res[0, 0] for res in clusters['sweepLen']]
- respWins = [res.flatten() for res in clusters['respWin']]
- spikeunits = []
- for cluster in clusters['sweeps']:
- if len(cluster):
- spikes = [res.flatten() for res in
- cluster['spikes'].flatten()]
- else:
- spikes = []
- spikeunits.append(spikes)
- else:
- IdStrings = []
- sweepLens = []
- respWins = []
- spikeunits = []
- for unit, IdString in zip(chx.units[1:], IdStrings):
- unit.name = str(IdString)
- fullunit = zip(spikeunits, chx.units[1:], sweepLens, respWins)
- for spikeunit, unit, sweepLen, respWin in fullunit:
- trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
- respWin, damaIndexes, timeStamps,
- filename)
- atrains.append(trains)
- unit.spiketrains.extend(trains)
- atrains = zip(*atrains)
- for trains in atrains:
- segment = Segment(file_origin=filename, feature_type=-1,
- go_by_closest_unit_center=False,
- include_unit_bounds=False, **stim)
- block.segments.append(segment)
- segment.spiketrains = trains
- def proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod, respWin,
- damaIndexes, timeStamps, filename):
- '''Get the unit in a condition in a src file that has been processed by
- the official matlab function. See proc_src for details'''
- if not damaIndexes:
- damaIndexes = [0] * len(spikeunit)
- timeStamps = [0] * len(spikeunit)
- trains = []
- for sweep, damaIndex, timeStamp in zip(spikeunit, damaIndexes,
- timeStamps):
- timeStamp = brainwaresrcio.convert_brainwaresrc_timestamp(timeStamp)
- train = proc_src_condition_unit_repetition(sweep, damaIndex,
- timeStamp, sweepLen,
- side, ADperiod, respWin,
- filename)
- trains.append(train)
- return trains
- def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen,
- side, ADperiod, respWin, filename):
- '''Get the repetion for a unit in a condition in a src file that has been
- processed by the official matlab function. See proc_src for details'''
- damaIndex = damaIndex.astype('int32')
- if len(sweep):
- times = np.array([res[0, 0] for res in sweep['time']])
- shapes = np.concatenate([res.flatten()[np.newaxis][np.newaxis] for res
- in sweep['shape']], axis=0)
- trig2 = np.array([res[0, 0] for res in sweep['trig2']])
- else:
- times = np.array([])
- shapes = np.array([[[]]])
- trig2 = np.array([])
- times = pq.Quantity(times, units=pq.ms, dtype=np.float32)
- t_start = pq.Quantity(0, units=pq.ms, dtype=np.float32)
- t_stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32)
- trig2 = pq.Quantity(trig2, units=pq.ms, dtype=np.uint8)
- waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV)
- sampling_period = pq.Quantity(ADperiod, units=pq.us)
- train = SpikeTrain(times=times, t_start=t_start, t_stop=t_stop,
- trig2=trig2, dtype=np.float32, timestamp=timeStamp,
- dama_index=damaIndex, side=side, copy=True,
- respwin=respWin, waveforms=waveforms,
- file_origin=filename)
- train.annotations['side'] = side
- train.sampling_period = sampling_period
- return train
- class BrainwareSrcIOTestCase(BaseTestIO, unittest.TestCase):
- '''
- Unit test testcase for neo.io.BrainwareSrcIO
- '''
- ioclass = BrainwareSrcIO
- read_and_write_is_bijective = False
- # These are the files it tries to read and test for compliance
- files_to_test = FILES_TO_TEST
- # these are reference files to compare to
- files_to_compare = FILES_TO_COMPARE
- # add the appropriate suffix depending on the python version
- for i, fname in enumerate(files_to_compare):
- if fname:
- files_to_compare[i] += '_src_py%s.npz' % PY_VER
- # Will fetch from g-node if they don't already exist locally
- # How does it know to do this before any of the other tests?
- files_to_download = files_to_test + files_to_compare
- def setUp(self):
- super(BrainwareSrcIOTestCase, self).setUp()
- def test_reading_same(self):
- for ioobj, path in self.iter_io_objects(return_path=True):
- obj_reader_all = create_generic_reader(ioobj, readall=True)
- obj_reader_base = create_generic_reader(ioobj, target=False)
- obj_reader_next = create_generic_reader(ioobj, target='next_block')
- obj_reader_single = create_generic_reader(ioobj)
- obj_all = obj_reader_all()
- obj_base = obj_reader_base()
- obj_single = obj_reader_single()
- obj_next = [obj_reader_next()]
- while ioobj._isopen:
- obj_next.append(obj_reader_next())
- try:
- assert_same_sub_schema(obj_all[0], obj_base)
- assert_same_sub_schema(obj_all[0], obj_single)
- assert_same_sub_schema(obj_all, obj_next)
- except BaseException as exc:
- exc.args += ('from ' + os.path.basename(path),)
- raise
- self.assertEqual(len(obj_all), len(obj_next))
- def test_against_reference(self):
- for filename, refname in zip(self.files_to_test,
- self.files_to_compare):
- if not refname:
- continue
- obj = self.read_file(filename=filename, readall=True)[0]
- refobj = proc_src(self.get_filename_path(refname))
- try:
- assert_neo_object_is_compliant(obj)
- assert_neo_object_is_compliant(refobj)
- assert_same_sub_schema(obj, refobj)
- except BaseException as exc:
- exc.args += ('from ' + filename,)
- raise
- if __name__ == '__main__':
- logger = logging.getLogger(BrainwareSrcIO.__module__ +
- '.' +
- BrainwareSrcIO.__name__)
- logger.setLevel(100)
- unittest.main()
|