- """
- baserawio
- ======
- Classes
- -------
- BaseRawIO
- abstract class which should be overridden to write a RawIO.
- RawIO is a new API in neo that is supposed to acces as fast as possible
- raw data. All IO with theses characteristics should/could be rewritten:
- * internally use of memmap (or hdf5)
- * reading header is quite cheap (not read all the file)
- * neo tree object is symetric and logical: same channel/units/event
- along all block and segments.
- So this handles **only** one simplified but very frequent case of dataset:
- * Only one channel set for AnalogSignal (aka ChannelIndex) stable along Segment
- * Only one channel set for SpikeTrain (aka Unit) stable along Segment
- * AnalogSignal have all the same sampling_rate acroos all Segment
- * t_start/t_stop are the same for many object (SpikeTrain, Event) inside a Segment
- * AnalogSignal should all have the same sampling_rate otherwise the won't be read
- a the same time. So signal_group_mode=='split-all' in BaseFromRaw
- A helper class `neo.io.basefromrawio.BaseFromRaw` should transform a RawIO to
- neo legacy IO from free.
- With this API the IO have an attributes `header` with necessary keys.
- See ExampleRawIO as example.
- BaseRawIO implement a possible presistent cache system that can be used
- by some IOs to avoid very long parse_header(). The idea is that some variable
- or vector can be store somewhere (near the fiel, /tmp, any path)
- """
import logging
import os
import sys

import numpy as np

from neo import logging_handler

try:
    import joblib

    HAVE_JOBLIB = True
except ImportError:
    HAVE_JOBLIB = False

possible_raw_modes = ['one-file', 'multi-file', 'one-dir', ]  # 'multi-dir', 'url', 'other'

error_header = 'Header is not read yet, do parse_header() first'

_signal_channel_dtype = [
    ('name', 'U64'),
    ('id', 'int64'),
    ('sampling_rate', 'float64'),
    ('dtype', 'U16'),
    ('units', 'U64'),
    ('gain', 'float64'),
    ('offset', 'float64'),
    ('group_id', 'int64'),
]

_common_sig_characteristics = ['sampling_rate', 'dtype', 'group_id']

_unit_channel_dtype = [
    ('name', 'U64'),
    ('id', 'U64'),
    # for waveforms
    ('wf_units', 'U64'),
    ('wf_gain', 'float64'),
    ('wf_offset', 'float64'),
    ('wf_left_sweep', 'int64'),
    ('wf_sampling_rate', 'float64'),
]

_event_channel_dtype = [
    ('name', 'U64'),
    ('id', 'U64'),
    ('type', 'S5'),  # 'epoch' or 'event'
]

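# Illustrative sketch: a subclass's _parse_header() typically fills self.header
# with structured arrays built from the dtypes above (all values here are made up):
#
#     sig_channels = np.array([('ch0', 0, 30000., 'int16', 'uV', 0.195, 0., 0)],
#                             dtype=_signal_channel_dtype)
#     unit_channels = np.array([('unit0', '0', 'uV', 0.195, 0., 20, 30000.)],
#                              dtype=_unit_channel_dtype)
#     event_channels = np.array([('trigger', '0', 'event')],
#                               dtype=_event_channel_dtype)
#     self.header = dict(nb_block=1, nb_segment=[1],
#                        signal_channels=sig_channels,
#                        unit_channels=unit_channels,
#                        event_channels=event_channels)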
class BaseRawIO:
    """
    Generic base class for all RawIO classes.
    """

    name = 'BaseIO'
    description = ''
    extensions = []

    rawmode = None  # one key in possible_raw_modes

    def __init__(self, use_cache=False, cache_path='same_as_resource', **kargs):
        """
        When rawmode=='one-file' kargs MUST contain 'filename', the filename.
        When rawmode=='multi-file' kargs MUST contain 'filename', one of the filenames.
        When rawmode=='one-dir' kargs MUST contain 'dirname', the dirname.
        """
        # create a logger for the IO class
        fullname = self.__class__.__module__ + '.' + self.__class__.__name__
        self.logger = logging.getLogger(fullname)
        # create a logger for 'neo' and add a handler to it if it doesn't
        # have one already.
        # (it will also not add one if the root logger has a handler)
        corename = self.__class__.__module__.split('.')[0]
        corelogger = logging.getLogger(corename)
        rootlogger = logging.getLogger()
        if not corelogger.handlers and not rootlogger.handlers:
            corelogger.addHandler(logging_handler)

        self.use_cache = use_cache
        if use_cache:
            assert HAVE_JOBLIB, 'You need to install joblib to use the cache'
            self.setup_cache(cache_path)
        else:
            self._cache = None

        self.header = None

    def parse_header(self):
        """
        This must parse the file header to get all the information needed for
        fast access later on.

        This must create:
            self.header['nb_block']
            self.header['nb_segment']
            self.header['signal_channels']
            self.header['unit_channels']
            self.header['event_channels']
        """
        self._parse_header()
        self._group_signal_channel_characteristics()

    def source_name(self):
        """Return fancy name of file source"""
        return self._source_name()

    def __repr__(self):
        txt = '{}: {}\n'.format(self.__class__.__name__, self.source_name())
        if self.header is not None:
            nb_block = self.block_count()
            txt += 'nb_block: {}\n'.format(nb_block)
            nb_seg = [self.segment_count(i) for i in range(nb_block)]
            txt += 'nb_segment: {}\n'.format(nb_seg)
            for k in ('signal_channels', 'unit_channels', 'event_channels'):
                ch = self.header[k]
                if len(ch) > 8:
                    chantxt = "[{} ... {}]".format(', '.join(e for e in ch['name'][:4]),
                                                   ', '.join(e for e in ch['name'][-4:]))
                else:
                    chantxt = "[{}]".format(', '.join(e for e in ch['name']))
                txt += '{}: {}\n'.format(k, chantxt)
        return txt

    def _generate_minimal_annotations(self):
        """
        Helper function that generates a nested dict
        of all annotations.

        Must be called once these are OK:
          * block_count()
          * segment_count()
          * signal_channels_count()
          * unit_channels_count()
          * event_channels_count()

        Usage:
            raw_annotations['blocks'][block_index] = {'nickname': 'super block', 'segments': ...}
            raw_annotations['blocks'][block_index]['segments'][seg_index]['signals'][channel_index] = {'nickname': 'super channel'}
            raw_annotations['blocks'][block_index]['segments'][seg_index]['units'][unit_index] = {'nickname': 'super neuron'}
            raw_annotations['blocks'][block_index]['segments'][seg_index]['events'][ev_chan] = {'nickname': 'super trigger'}

        These annotations are used directly in objects at the neo.io API level.
        Standard annotations like name/id/file_origin are already generated here.
        """
        signal_channels = self.header['signal_channels']
        unit_channels = self.header['unit_channels']
        event_channels = self.header['event_channels']

        a = {'blocks': [], 'signal_channels': [], 'unit_channels': [], 'event_channels': []}
        for block_index in range(self.block_count()):
            d = {'segments': []}
            d['file_origin'] = self.source_name()
            a['blocks'].append(d)
            for seg_index in range(self.segment_count(block_index)):
                d = {'signals': [], 'units': [], 'events': []}
                d['file_origin'] = self.source_name()
                a['blocks'][block_index]['segments'].append(d)

                for c in range(signal_channels.size):
                    # used for AnalogSignal.annotations
                    d = {}
                    d['name'] = signal_channels['name'][c]
                    d['channel_id'] = signal_channels['id'][c]
                    a['blocks'][block_index]['segments'][seg_index]['signals'].append(d)

                for c in range(unit_channels.size):
                    # used for SpikeTrain.annotations
                    d = {}
                    d['name'] = unit_channels['name'][c]
                    d['id'] = unit_channels['id'][c]
                    a['blocks'][block_index]['segments'][seg_index]['units'].append(d)

                for c in range(event_channels.size):
                    # used for Event.annotations
                    d = {}
                    d['name'] = event_channels['name'][c]
                    d['id'] = event_channels['id'][c]
                    d['file_origin'] = self._source_name()
                    a['blocks'][block_index]['segments'][seg_index]['events'].append(d)

        for c in range(signal_channels.size):
            # used for ChannelIndex.annotations
            d = {}
            d['name'] = signal_channels['name'][c]
            d['channel_id'] = signal_channels['id'][c]
            d['file_origin'] = self._source_name()
            a['signal_channels'].append(d)

        for c in range(unit_channels.size):
            # used for Unit.annotations
            d = {}
            d['name'] = unit_channels['name'][c]
            d['id'] = unit_channels['id'][c]
            d['file_origin'] = self._source_name()
            a['unit_channels'].append(d)

        for c in range(event_channels.size):
            # not used in neo.io at the moment, but could be useful one day
            d = {}
            d['name'] = event_channels['name'][c]
            d['id'] = event_channels['id'][c]
            d['file_origin'] = self._source_name()
            a['event_channels'].append(d)

        self.raw_annotations = a

    def _raw_annotate(self, obj_name, chan_index=0, block_index=0, seg_index=0, **kargs):
        """
        Annotate an object in the list/dict tree of annotations.
        """
        bl_annotations = self.raw_annotations['blocks'][block_index]
        seg_annotations = bl_annotations['segments'][seg_index]
        if obj_name == 'blocks':
            bl_annotations.update(kargs)
        elif obj_name == 'segments':
            seg_annotations.update(kargs)
        elif obj_name in ['signals', 'events', 'units']:
            obj_annotations = seg_annotations[obj_name][chan_index]
            obj_annotations.update(kargs)
        elif obj_name in ['signal_channels', 'unit_channels', 'event_channels']:
            obj_annotations = self.raw_annotations[obj_name][chan_index]
            obj_annotations.update(kargs)

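    # Illustrative sketch: inside a subclass's _parse_header(), after
    # _generate_minimal_annotations(), extra metadata can be attached like this
    # (the 'quality' key is hypothetical):
    #
    #     self._raw_annotate('signals', chan_index=0, block_index=0, seg_index=0,
    #                        quality='good')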
    def _repr_annotations(self):
        txt = 'Raw annotations\n'
        for block_index in range(self.block_count()):
            bl_a = self.raw_annotations['blocks'][block_index]
            txt += '*Block {}\n'.format(block_index)
            for k, v in bl_a.items():
                if k in ('segments',):
                    continue
                txt += '  -{}: {}\n'.format(k, v)
            for seg_index in range(self.segment_count(block_index)):
                seg_a = bl_a['segments'][seg_index]
                txt += '  *Segment {}\n'.format(seg_index)
                for k, v in seg_a.items():
                    if k in ('signals', 'units', 'events',):
                        continue
                    txt += '    -{}: {}\n'.format(k, v)
                for child in ('signals', 'units', 'events'):
                    n = self.header[child[:-1] + '_channels'].shape[0]
                    for c in range(n):
                        neo_name = {'signals': 'AnalogSignal',
                                    'units': 'SpikeTrain', 'events': 'Event/Epoch'}[child]
                        txt += '    *{} {}\n'.format(neo_name, c)
                        child_a = seg_a[child][c]
                        for k, v in child_a.items():
                            txt += '      -{}: {}\n'.format(k, v)
        return txt

    def print_annotations(self):
        """Print formatted raw_annotations"""
        print(self._repr_annotations())

    def block_count(self):
        """Return the number of blocks."""
        return self.header['nb_block']

    def segment_count(self, block_index):
        """Return the number of segments for a given block."""
        return self.header['nb_segment'][block_index]

    def signal_channels_count(self):
        """Return the number of signal channels.
        Same along all Blocks and Segments.
        """
        return len(self.header['signal_channels'])

    def unit_channels_count(self):
        """Return the number of unit (aka spike) channels.
        Same along all Blocks and Segments.
        """
        return len(self.header['unit_channels'])

    def event_channels_count(self):
        """Return the number of event/epoch channels.
        Same along all Blocks and Segments.
        """
        return len(self.header['event_channels'])

    def segment_t_start(self, block_index, seg_index):
        """Global t_start of a Segment in s. Shared by all objects except
        for AnalogSignal.
        """
        return self._segment_t_start(block_index, seg_index)

    def segment_t_stop(self, block_index, seg_index):
        """Global t_stop of a Segment in s. Shared by all objects except
        for AnalogSignal.
        """
        return self._segment_t_stop(block_index, seg_index)

    ###
    # signal and channel zone

    def _group_signal_channel_characteristics(self):
        """
        Useful for a few IOs (TdtRawIO, NeuroExplorerRawIO, ...).

        Group signal channels by common characteristics:
          * sampling_rate (global across blocks and segments)
          * group_id (explicit channel group)

        If all channels share the same characteristics, then
        `get_analogsignal_chunk` can be called without restriction.
        If not, then **channel_indexes** must be specified
        in `get_analogsignal_chunk`, and only channels with the same
        characteristics can be read at the same time.

        This is useful for IOs that internally contain several
        families of signal channels.

        For many RawIOs all channels share the same
        sampling_rate/size/t_start. In that case, the internal flag
        **self._several_channel_groups** is set to False, so
        `get_analogsignal_chunk(..)` won't suffer in performance.

        Note that at the neo.io level this has an impact on
        `signal_group_mode`. 'split-all' will work in any situation,
        but grouping channels in the same AnalogSignal
        with 'group-by-XXX' of course depends on common characteristics.
        """
        characteristics = self.header['signal_channels'][_common_sig_characteristics]
        unique_characteristics = np.unique(characteristics)
        if len(unique_characteristics) == 1:
            self._several_channel_groups = False
        else:
            self._several_channel_groups = True

    def _check_common_characteristics(self, channel_indexes):
        """
        Useful for a few IOs (TdtRawIO, NeuroExplorerRawIO, ...).

        Check that a set of signal channel_indexes shares common
        characteristics (**sampling_rate/t_start/size**).
        Useful only when a RawIO offers different channel groups,
        for instance with different sampling_rates.
        """
        assert channel_indexes is not None, \
            'You must specify channel_indexes'
        characteristics = self.header['signal_channels'][_common_sig_characteristics]
        assert np.unique(characteristics[channel_indexes]).size == 1, \
            'This channel set has varying characteristics'

    def get_group_signal_channel_indexes(self):
        """
        Useful for a few IOs (TdtRawIO, NeuroExplorerRawIO, ...).

        Return a list of channel_indexes groups that share the same
        characteristics.
        """
        if self._several_channel_groups:
            characteristics = self.header['signal_channels'][_common_sig_characteristics]
            unique_characteristics = np.unique(characteristics)
            channel_indexes_list = []
            for e in unique_characteristics:
                channel_indexes, = np.nonzero(characteristics == e)
                channel_indexes_list.append(channel_indexes)
            return channel_indexes_list
        else:
            return [None]

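    # Illustrative sketch: when several channel groups coexist
    # (self._several_channel_groups is True), chunks must be read group by
    # group (hypothetical reader):
    #
    #     for channel_indexes in reader.get_group_signal_channel_indexes():
    #         sr = reader.get_signal_sampling_rate(channel_indexes=channel_indexes)
    #         chunk = reader.get_analogsignal_chunk(block_index=0, seg_index=0,
    #                                               channel_indexes=channel_indexes)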
    def channel_name_to_index(self, channel_names):
        """
        Transform channel_names into channel_indexes.
        Based on self.header['signal_channels'].
        """
        ch = self.header['signal_channels']
        channel_indexes, = np.nonzero(np.in1d(ch['name'], channel_names))
        assert len(channel_indexes) == len(channel_names), 'not all channel names were found'
        return channel_indexes

    def channel_id_to_index(self, channel_ids):
        """
        Transform channel_ids into channel_indexes.
        Based on self.header['signal_channels'].
        """
        ch = self.header['signal_channels']
        channel_indexes, = np.nonzero(np.in1d(ch['id'], channel_ids))
        assert len(channel_indexes) == len(channel_ids), 'not all channel ids were found'
        return channel_indexes

    def _get_channel_indexes(self, channel_indexes, channel_names, channel_ids):
        """
        Select channel_indexes from channel_indexes/channel_names/channel_ids,
        depending on which is not None.
        """
        if channel_indexes is None and channel_names is not None:
            channel_indexes = self.channel_name_to_index(channel_names)
        if channel_indexes is None and channel_ids is not None:
            channel_indexes = self.channel_id_to_index(channel_ids)
        return channel_indexes

    def get_signal_size(self, block_index, seg_index, channel_indexes=None):
        if self._several_channel_groups:
            self._check_common_characteristics(channel_indexes)
        return self._get_signal_size(block_index, seg_index, channel_indexes)

    def get_signal_t_start(self, block_index, seg_index, channel_indexes=None):
        if self._several_channel_groups:
            self._check_common_characteristics(channel_indexes)
        return self._get_signal_t_start(block_index, seg_index, channel_indexes)

    def get_signal_sampling_rate(self, channel_indexes=None):
        if self._several_channel_groups:
            self._check_common_characteristics(channel_indexes)
            chan_index0 = channel_indexes[0]
        else:
            chan_index0 = 0
        sr = self.header['signal_channels'][chan_index0]['sampling_rate']
        return float(sr)

    def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_stop=None,
                               channel_indexes=None, channel_names=None, channel_ids=None):
        """
        Return a chunk of raw signal.
        """
        channel_indexes = self._get_channel_indexes(channel_indexes, channel_names, channel_ids)
        if self._several_channel_groups:
            self._check_common_characteristics(channel_indexes)

        raw_chunk = self._get_analogsignal_chunk(
            block_index, seg_index, i_start, i_stop, channel_indexes)

        return raw_chunk

    def rescale_signal_raw_to_float(self, raw_signal, dtype='float32',
                                    channel_indexes=None, channel_names=None, channel_ids=None):
        channel_indexes = self._get_channel_indexes(channel_indexes, channel_names, channel_ids)
        if channel_indexes is None:
            channel_indexes = slice(None)
        channels = self.header['signal_channels'][channel_indexes]

        float_signal = raw_signal.astype(dtype)

        if np.any(channels['gain'] != 1.):
            float_signal *= channels['gain']

        if np.any(channels['offset'] != 0.):
            float_signal += channels['offset']

        return float_signal

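    # Illustrative sketch: the per-channel conversion applied above is
    # float = raw * gain + offset (hypothetical reader and channel selection):
    #
    #     raw = reader.get_analogsignal_chunk(block_index=0, seg_index=0,
    #                                         channel_indexes=[0, 1])
    #     sig = reader.rescale_signal_raw_to_float(raw, dtype='float64',
    #                                              channel_indexes=[0, 1])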
    # spiketrain and unit zone
    def spike_count(self, block_index=0, seg_index=0, unit_index=0):
        return self._spike_count(block_index, seg_index, unit_index)

    def get_spike_timestamps(self, block_index=0, seg_index=0, unit_index=0,
                             t_start=None, t_stop=None):
        """
        The timestamps are kept as close as possible to the native format;
        they can be float/int32/int64. Sometimes they are indexes into the
        signal, but not always.
        The conversion to seconds or to index_on_signal is done outside of here.

        t_start/t_stop are limits in seconds.
        """
        timestamp = self._get_spike_timestamps(block_index, seg_index, unit_index, t_start, t_stop)
        return timestamp

    def rescale_spike_timestamp(self, spike_timestamps, dtype='float64'):
        """
        Rescale spike timestamps to seconds.
        """
        return self._rescale_spike_timestamp(spike_timestamps, dtype)

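    # Illustrative sketch (hypothetical reader): raw timestamps become spike
    # times in seconds in two explicit steps, keeping the native dtype until
    # the end:
    #
    #     ts = reader.get_spike_timestamps(block_index=0, seg_index=0,
    #                                      unit_index=0)
    #     spike_times = reader.rescale_spike_timestamp(ts, dtype='float64')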
    # spiketrain waveform zone
    def get_spike_raw_waveforms(self, block_index=0, seg_index=0, unit_index=0,
                                t_start=None, t_stop=None):
        wf = self._get_spike_raw_waveforms(block_index, seg_index, unit_index, t_start, t_stop)
        return wf

    def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32', unit_index=0):
        wf_gain = self.header['unit_channels']['wf_gain'][unit_index]
        wf_offset = self.header['unit_channels']['wf_offset'][unit_index]

        float_waveforms = raw_waveforms.astype(dtype)

        if wf_gain != 1.:
            float_waveforms *= wf_gain
        if wf_offset != 0.:
            float_waveforms += wf_offset

        return float_waveforms

    # event and epoch zone
    def event_count(self, block_index=0, seg_index=0, event_channel_index=0):
        return self._event_count(block_index, seg_index, event_channel_index)

    def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0,
                             t_start=None, t_stop=None):
        """
        The timestamps are kept as close as possible to the native format;
        they can be float/int32/int64. Sometimes they are indexes into the
        signal, but not always.
        The conversion to seconds or to index_on_signal is done outside of here.

        t_start/t_stop are limits in seconds.

        Returns:
            timestamp
            durations
            labels
        """
        timestamp, durations, labels = self._get_event_timestamps(
            block_index, seg_index, event_channel_index, t_start, t_stop)
        return timestamp, durations, labels

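    # Illustrative sketch (hypothetical reader):
    #
    #     timestamps, durations, labels = reader.get_event_timestamps(
    #         block_index=0, seg_index=0, event_channel_index=0)
    #     times = reader.rescale_event_timestamp(timestamps, dtype='float64')
    #
    # For channels of type 'event', durations is typically None; for 'epoch'
    # channels it holds raw durations that rescale_epoch_duration() converts
    # to seconds.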
    def rescale_event_timestamp(self, event_timestamps, dtype='float64'):
        """
        Rescale event timestamps to seconds.
        """
        return self._rescale_event_timestamp(event_timestamps, dtype)

    def rescale_epoch_duration(self, raw_duration, dtype='float64'):
        """
        Rescale raw epoch durations to seconds.
        """
        return self._rescale_epoch_duration(raw_duration, dtype)

    def setup_cache(self, cache_path, **init_kargs):
        if self.rawmode in ('one-file', 'multi-file'):
            resource_name = self.filename
        elif self.rawmode == 'one-dir':
            resource_name = self.dirname
        else:
            raise NotImplementedError

        if cache_path == 'home':
            if sys.platform.startswith('win'):
                dirname = os.path.join(os.environ['APPDATA'], 'neo_rawio_cache')
            elif sys.platform.startswith('darwin'):
                dirname = os.path.expanduser('~/Library/Application Support/neo_rawio_cache')
            else:
                dirname = os.path.expanduser('~/.config/neo_rawio_cache')
            dirname = os.path.join(dirname, self.__class__.__name__)
            if not os.path.exists(dirname):
                os.makedirs(dirname)
        elif cache_path == 'same_as_resource':
            dirname = os.path.dirname(resource_name)
        else:
            assert os.path.exists(cache_path), \
                'cache_path does not exist; use "home" or "same_as_resource" to make this automatic'
            dirname = cache_path

        # the hash of the resource (dir or file) is based on the filename and mtime
        # TODO make something more sophisticated when rawmode='one-dir' that uses all
        # filenames and mtimes
        d = dict(resource_name=resource_name, mtime=os.path.getmtime(resource_name))
        hash = joblib.hash(d, hash_name='md5')

        # the cache name is constructed from the resource basename and the hash
        name = '{}_{}'.format(os.path.basename(resource_name), hash)
        self.cache_filename = os.path.join(dirname, name)

        if os.path.exists(self.cache_filename):
            self.logger.warning('Use existing cache file {}'.format(self.cache_filename))
            self._cache = joblib.load(self.cache_filename)
        else:
            self.logger.warning('Create cache file {}'.format(self.cache_filename))
            self._cache = {}
            self.dump_cache()

    def add_in_cache(self, **kargs):
        assert self.use_cache
        self._cache.update(kargs)
        self.dump_cache()

    def dump_cache(self):
        assert self.use_cache
        joblib.dump(self._cache, self.cache_filename)

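    # Illustrative sketch: a subclass with an expensive parse step can reuse
    # the cache inside _parse_header() (the 'spike_index' key and the
    # _scan_spikes helper are hypothetical):
    #
    #     if self.use_cache and 'spike_index' in self._cache:
    #         spike_index = self._cache['spike_index']
    #     else:
    #         spike_index = self._scan_spikes()  # hypothetical expensive scan
    #         if self.use_cache:
    #             self.add_in_cache(spike_index=spike_index)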
    ##################
    # Functions to be implemented in the IO below here

    def _parse_header(self):
        raise NotImplementedError
        # must call
        # self._generate_minimal_annotations()

    def _source_name(self):
        raise NotImplementedError

    def _segment_t_start(self, block_index, seg_index):
        raise NotImplementedError

    def _segment_t_stop(self, block_index, seg_index):
        raise NotImplementedError

    ###
    # signal and channel zone
    def _get_signal_size(self, block_index, seg_index, channel_indexes):
        raise NotImplementedError

    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        raise NotImplementedError

    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
        raise NotImplementedError

    ###
    # spiketrain and unit zone
    def _spike_count(self, block_index, seg_index, unit_index):
        raise NotImplementedError

    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        raise NotImplementedError

    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        raise NotImplementedError

    ###
    # spike waveforms zone
    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        raise NotImplementedError

    ###
    # event and epoch zone
    def _event_count(self, block_index, seg_index, event_channel_index):
        raise NotImplementedError

    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        raise NotImplementedError

    def _rescale_event_timestamp(self, event_timestamps, dtype):
        raise NotImplementedError

    def _rescale_epoch_duration(self, raw_duration, dtype):
        raise NotImplementedError