123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- """
- RawIO Class for NIX files
- The RawIO assumes all segments and all blocks have the same structure.
- It supports all kinds of NEO objects.
- Author: Chek Yin Choi
- """
- from __future__ import print_function, division, absolute_import
- from neo.rawio.baserawio import (BaseRawIO, _signal_channel_dtype,
- _unit_channel_dtype, _event_channel_dtype)
- import numpy as np
- try:
- import nixio as nix
- HAVE_NIX = True
- except ImportError:
- HAVE_NIX = False
- nix = None
class NIXRawIO(BaseRawIO):
    """RawIO for NIX files written by neo.

    Assumes all blocks and all segments share the same structure (see the
    module docstring): `_parse_header` reads the channel tables from the
    first segment of the first block only.
    """

    extensions = ['nix']   # recognised file extensions
    rawmode = 'one-file'   # one NIX file per reader

    def __init__(self, filename=''):
        # Only store the path here; the file is opened in _parse_header().
        BaseRawIO.__init__(self)
        self.filename = filename

    def _source_name(self):
        # The data source is identified by its file path.
        return self.filename
    def _parse_header(self):
        """Open the NIX file read-only and build the BaseRawIO header.

        Fills ``self.header`` (channel tables + block/segment counts) and
        two caches used by the data-reading methods:

        * ``self.da_list``   -- per block/segment: the analogsignal data
          arrays, their sizes and their names
        * ``self.unit_list`` -- per block/segment: spiketrain positions,
          ids and (optionally) their waveform data arrays
        """
        self.file = nix.File.open(self.filename, nix.FileMode.ReadOnly)

        # --- signal channels -------------------------------------------
        # Read from the FIRST segment of the FIRST block only (note the
        # two `break`s): the file is assumed to have the same channel
        # layout in every segment and block (see class docstring).
        sig_channels = []
        size_list = []
        for bl in self.file.blocks:
            for seg in bl.groups:
                for da_idx, da in enumerate(seg.data_arrays):
                    if da.type == "neo.analogsignal":
                        chan_id = da_idx
                        ch_name = da.metadata['neo_name']
                        units = str(da.unit)
                        dtype = str(da.dtype)
                        # dimension 0 is the sampled (time) dimension
                        sr = 1 / da.dimensions[0].sampling_interval
                        da_leng = da.size
                        if da_leng not in size_list:
                            size_list.append(da_leng)
                        group_id = 0
                        for sid, li_leng in enumerate(size_list):
                            if li_leng == da_leng:
                                group_id = sid
                                # very important! group_id use to store channel groups!!!
                                # use only for different signal length
                        gain = 1
                        offset = 0.
                        sig_channels.append((ch_name, chan_id, sr, dtype,
                                             units, gain, offset, group_id))
                break
            break
        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

        # --- unit (spiketrain) channels --------------------------------
        # Also taken from the first segment of the first block only.
        unit_channels = []
        unit_name = ""
        unit_id = ""
        for bl in self.file.blocks:
            for seg in bl.groups:
                for mt in seg.multi_tags:
                    if mt.type == "neo.spiketrain":
                        unit_name = mt.metadata['neo_name']
                        unit_id = mt.id
                        if mt.features:
                            # waveform feature present; presumably dimension 2
                            # is the sampled (time) axis of the waveform array
                            # -- TODO confirm against the writer
                            wf_units = mt.features[0].data.unit
                            wf_sampling_rate = 1 / mt.features[0].data.dimensions[2].sampling_interval
                        else:
                            wf_units = None
                            wf_sampling_rate = 0
                        wf_gain = 1
                        wf_offset = 0.
                        if mt.features and "left_sweep" in mt.features[0].data.metadata:
                            wf_left_sweep = mt.features[0].data.metadata["left_sweep"]
                        else:
                            wf_left_sweep = 0
                        unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
                                              wf_offset, wf_left_sweep, wf_sampling_rate))
                break
            break
        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)

        # --- event / epoch channels ------------------------------------
        # Events and epochs keep separate running ids, again read from
        # the first segment of the first block only.
        event_channels = []
        event_count = 0
        epoch_count = 0
        for bl in self.file.blocks:
            for seg in bl.groups:
                for mt in seg.multi_tags:
                    if mt.type == "neo.event":
                        ev_name = mt.metadata['neo_name']
                        ev_id = event_count
                        event_count += 1
                        ev_type = "event"
                        event_channels.append((ev_name, ev_id, ev_type))
                    if mt.type == "neo.epoch":
                        ep_name = mt.metadata['neo_name']
                        ep_id = epoch_count
                        epoch_count += 1
                        ep_type = "epoch"
                        event_channels.append((ep_name, ep_id, ep_type))
                break
            break
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # --- cache of analogsignal data arrays, per block/segment ------
        self.da_list = {'blocks': []}
        for block_index, blk in enumerate(self.file.blocks):
            d = {'segments': []}
            self.da_list['blocks'].append(d)
            for seg_index, seg in enumerate(blk.groups):
                d = {'signals': []}
                self.da_list['blocks'][block_index]['segments'].append(d)
                size_list = []
                data_list = []
                da_name_list = []
                for da in seg.data_arrays:
                    if da.type == 'neo.analogsignal':
                        size_list.append(da.size)
                        data_list.append(da)
                        da_name_list.append(da.metadata['neo_name'])
                self.da_list['blocks'][block_index]['segments'][seg_index]['data_size'] = size_list
                self.da_list['blocks'][block_index]['segments'][seg_index]['data'] = data_list
                self.da_list['blocks'][block_index]['segments'][seg_index]['ch_name'] = da_name_list

        # --- cache of spiketrains and waveforms, per block/segment -----
        self.unit_list = {'blocks': []}
        for block_index, blk in enumerate(self.file.blocks):
            d = {'segments': []}
            self.unit_list['blocks'].append(d)
            for seg_index, seg in enumerate(blk.groups):
                d = {'spiketrains': [], 'spiketrains_id': [], 'spiketrains_unit': []}
                self.unit_list['blocks'][block_index]['segments'].append(d)
                st_idx = 0
                for st in seg.multi_tags:
                    # NOTE(review): one waveform dict is appended for EVERY
                    # multi-tag (events/epochs included) while st_idx only
                    # advances for spiketrains, so trailing entries may be
                    # unused -- but spiketrain/waveform pairing stays aligned.
                    d = {'waveforms': []}
                    self.unit_list['blocks'][block_index]['segments'][seg_index]['spiketrains_unit'].append(d)
                    if st.type == 'neo.spiketrain':
                        # NOTE(review): `seg` (the nix group being iterated) is
                        # rebound here to the cache dict; iteration continues
                        # only because the multi_tags iterator already exists.
                        seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
                        seg['spiketrains'].append(st.positions)
                        seg['spiketrains_id'].append(st.id)
                        if st.features and st.features[0].data.type == "neo.waveforms":
                            waveforms = st.features[0].data
                            if waveforms:
                                seg['spiketrains_unit'][st_idx]['waveforms'] = waveforms
                            else:
                                seg['spiketrains_unit'][st_idx]['waveforms'] = None
                            # assume one spiketrain one waveform
                        st_idx += 1

        # --- final BaseRawIO header ------------------------------------
        self.header = {}
        self.header['nb_block'] = len(self.file.blocks)
        self.header['nb_segment'] = [len(bl.groups) for bl in self.file.blocks]
        self.header['signal_channels'] = sig_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels
        self._generate_minimal_annotations()
- def _segment_t_start(self, block_index, seg_index):
- t_start = 0
- for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
- if mt.type == "neo.spiketrain":
- t_start = mt.metadata['t_start']
- return t_start
- def _segment_t_stop(self, block_index, seg_index):
- t_stop = 0
- for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
- if mt.type == "neo.spiketrain":
- t_stop = mt.metadata['t_stop']
- return t_stop
- def _get_signal_size(self, block_index, seg_index, channel_indexes):
- if channel_indexes is None:
- channel_indexes = list(range(self.header['signal_channels'].size))
- ch_idx = channel_indexes[0]
- size = self.da_list['blocks'][block_index]['segments'][seg_index]['data_size'][ch_idx]
- return size # size is per signal, not the sum of all channel_indexes
- def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
- if channel_indexes is None:
- channel_indexes = list(range(self.header['signal_channels'].size))
- ch_idx = channel_indexes[0]
- da = [da for da in self.file.blocks[block_index].groups[seg_index].data_arrays][ch_idx]
- sig_t_start = float(da.metadata['t_start'])
- return sig_t_start # assume same group_id always same t_start
- def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
- if channel_indexes is None:
- channel_indexes = list(range(self.header['signal_channels'].size))
- if i_start is None:
- i_start = 0
- if i_stop is None:
- for c in channel_indexes:
- i_stop = self.da_list['blocks'][block_index]['segments'][seg_index]['data_size'][c]
- break
- raw_signals_list = []
- da_list = self.da_list['blocks'][block_index]['segments'][seg_index]
- for idx in channel_indexes:
- da = da_list['data'][idx]
- raw_signals_list.append(da[i_start:i_stop])
- raw_signals = np.array(raw_signals_list)
- raw_signals = np.transpose(raw_signals)
- return raw_signals
- def _spike_count(self, block_index, seg_index, unit_index):
- count = 0
- head_id = self.header['unit_channels'][unit_index][1]
- for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
- for src in mt.sources:
- if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
- if head_id == src.id:
- return len(mt.positions)
- return count
- def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
- spike_dict = self.unit_list['blocks'][block_index]['segments'][seg_index]['spiketrains']
- spike_timestamps = spike_dict[unit_index]
- spike_timestamps = np.transpose(spike_timestamps)
- if t_start is not None or t_stop is not None:
- lim0 = t_start
- lim1 = t_stop
- mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
- spike_timestamps = spike_timestamps[mask]
- return spike_timestamps
- def _rescale_spike_timestamp(self, spike_timestamps, dtype):
- spike_times = spike_timestamps.astype(dtype)
- return spike_times
    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        """Return the raw waveform array cached for one unit, or None.

        Expected shape is (nb_spike, nb_channel, nb_sample).
        """
        # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
        seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
        waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
        if not waveforms:
            return None
        raw_waveforms = np.array(waveforms)
        # NOTE(review): the masks below compare waveform *amplitudes* against
        # t_start/t_stop, which are times; spike times are not part of this
        # array, so this filtering looks wrong -- confirm intent before
        # relying on it.
        if t_start is not None:
            lim0 = t_start
            mask = (raw_waveforms >= lim0)
            raw_waveforms = np.where(mask, raw_waveforms, np.nan)  # use nan to keep the shape
        if t_stop is not None:
            lim1 = t_stop
            mask = (raw_waveforms <= lim1)
            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
        return raw_waveforms
- def _event_count(self, block_index, seg_index, event_channel_index):
- event_count = 0
- for event in self.file.blocks[block_index].groups[seg_index].multi_tags:
- if event.type == 'neo.event' or event.type == 'neo.epoch':
- if event_count == event_channel_index:
- return len(event.positions)
- else:
- event_count += 1
- return event_count
    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        """Return (timestamps, durations, labels) for one event/epoch channel.

        Timestamps and labels are clipped to [t_start, t_stop] when bounds
        are given; durations (epoch channels only) are returned unclipped.
        """
        timestamp = []
        labels = []
        durations = None
        if event_channel_index is None:
            raise IndexError
        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
            if mt.type == "neo.event" or mt.type == "neo.epoch":
                labels.append(mt.positions.dimensions[0].labels)
                po = mt.positions
                if po.type == "neo.event.times" or po.type == "neo.epoch.times":
                    timestamp.append(po)
                # Durations come from the FIRST epoch-typed multi-tag that
                # carries extents (note the break). NOTE(review): with
                # several epoch channels this may pick durations from a
                # different channel than event_channel_index -- verify.
                if self.header['event_channels'][event_channel_index]['type'] == b'epoch' \
                        and mt.extents:
                    if mt.extents.type == 'neo.epoch.durations':
                        durations = np.array(mt.extents)
                        break
        # select the requested channel, then copy into plain numpy arrays
        timestamp = timestamp[event_channel_index][:]
        timestamp = np.array(timestamp, dtype="float")
        labels = labels[event_channel_index][:]
        labels = np.array(labels, dtype='U')
        if t_start is not None:
            keep = timestamp >= t_start
            timestamp, labels = timestamp[keep], labels[keep]
        if t_stop is not None:
            keep = timestamp <= t_stop
            timestamp, labels = timestamp[keep], labels[keep]
        return timestamp, durations, labels  # only the first fits in rescale
- def _rescale_event_timestamp(self, event_timestamps, dtype='float64'):
- ev_unit = ''
- for mt in self.file.blocks[0].groups[0].multi_tags:
- if mt.type == "neo.event":
- ev_unit = mt.positions.unit
- break
- if ev_unit == 'ms':
- event_timestamps /= 1000
- event_times = event_timestamps.astype(dtype)
- # supposing unit is second, other possibilities maybe mS microS...
- return event_times # return in seconds
- def _rescale_epoch_duration(self, raw_duration, dtype='float64'):
- ep_unit = ''
- for mt in self.file.blocks[0].groups[0].multi_tags:
- if mt.type == "neo.epoch":
- ep_unit = mt.positions.unit
- break
- if ep_unit == 'ms':
- raw_duration /= 1000
- durations = raw_duration.astype(dtype)
- # supposing unit is second, other possibilities maybe mS microS...
- return durations # return in seconds
|