123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529 |
- # -*- coding: utf-8 -*-
- """
- Class for reading data from the Tucker-Davis Technologies (TDT) TTank format.
- Terminology:
- TDT holds data in tanks (each tank is actually a directory), and tanks hold
- sub-blocks (sub-directories).
- Tanks correspond to neo.Block and tdt block correspond to neo.Segment.
- Note that the name Block is ambiguous because it does not refer to the same
- thing in TDT terminology and in neo.
- In a directory there are several files:
- * TSQ timestamp index of data
- * TBK some kind of channel info and maybe more
- * TEV contains data : spike + event + signal (for old version)
- * SEV contains signals (for new version)
- * ./sort/ can contain offline spike-sorting labels for spikes, which can be
- used in place of the online sort codes.
- Units in this IO are not guaranteed.
- Author: Samuel Garcia, SummitKwan, Chadwick Boulay
- """
- from __future__ import print_function, division, absolute_import
- # from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3
- from .baserawio import BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype
- import numpy as np
- import os
- import re
- from collections import OrderedDict
class TdtRawIO(BaseRawIO):
    """Raw IO for a Tucker-Davis Technologies (TDT) tank directory.

    The tank directory maps to a single neo Block, and every TDT block
    (sub-directory accepted by ``is_tdtblock``) maps to one neo Segment.
    Segments are ordered by the start time found in their TSQ file.

    Signals (EVTYPE_STREAM stores) are read from per-channel SEV files when
    they exist, otherwise from the segment TEV file. Spike waveforms
    (EVTYPE_SNIP) are read from the TEV file, and event labels (EVTYPE_STRON)
    come from the TSQ index.
    """
    rawmode = 'one-dir'

    def __init__(self, dirname='', sortname=''):
        """
        Parameters
        ----------
        dirname : str
            Path of the tank directory. A single trailing '/' is removed so
            that the tank name can be taken from the directory basename.
        sortname : str
            Name of an external sort code set generated by offline spike
            sorting. If sortname == 'PLX', there should be a
            ./sort/PLX/*.SortResult file in each TDT block, storing the sort
            code of every spike. The default '' keeps the original online
            sort.
        """
        BaseRawIO.__init__(self)
        if dirname.endswith('/'):
            dirname = dirname[:-1]
        self.dirname = dirname
        self.sortname = sortname

    def _source_name(self):
        return self.dirname

    def _parse_header(self):
        """Scan the tank directory and fill ``self.header``.

        For every TDT block (= segment), this reads the TBK channel-group
        info, memory-maps the TEV file (which may be absent), and loads the
        TSQ index. It can then override spike sort codes using
        ``self.sortname``, orders the segments by start time, and builds the
        signal, unit and event channel tables.
        """
        import warnings

        tankname = os.path.basename(self.dirname)

        # Every sub-directory that looks like a TDT block becomes a segment.
        segment_names = []
        for segment_name in os.listdir(self.dirname):
            path = os.path.join(self.dirname, segment_name)
            if is_tdtblock(path):
                segment_names.append(segment_name)
        nb_segment = len(segment_names)

        # TBK (channel info): channel groups must be identical in all segments.
        info_channel_groups = None
        for segment_name in segment_names:
            path = os.path.join(self.dirname, segment_name)
            tbk_filename = os.path.join(path, tankname + '_' + segment_name + '.Tbk')
            _info_channel_groups = read_tbk(tbk_filename)
            if info_channel_groups is None:
                info_channel_groups = _info_channel_groups
            else:
                assert np.array_equal(info_channel_groups,
                                      _info_channel_groups), 'Channels differ across segments'

        # TEV (mixed data). It may be missing; signals can then come from SEV files.
        self._tev_datas = []
        for segment_name in segment_names:
            path = os.path.join(self.dirname, segment_name)
            tev_filename = os.path.join(path, tankname + '_' + segment_name + '.tev')
            if os.path.exists(tev_filename):
                tev_data = np.memmap(tev_filename, mode='r', offset=0, dtype='uint8')
            else:
                tev_data = None
            self._tev_datas.append(tev_data)

        # TSQ: timestamped index of the records.
        self._tsq = []
        self._seg_t_starts = []
        self._seg_t_stops = []
        start_mark = chr(EVMARK_STARTBLOCK).encode()
        stop_mark = chr(EVMARK_STOPBLOCK).encode()
        for segment_name in segment_names:
            path = os.path.join(self.dirname, segment_name)
            tsq_filename = os.path.join(path, tankname + '_' + segment_name + '.tsq')
            tsq = np.fromfile(tsq_filename, dtype=tsq_dtype)
            self._tsq.append(tsq)

            # Start and stop times are only found in the second and last
            # header rows, respectively.
            if tsq[1]['evname'] == start_mark:
                self._seg_t_starts.append(tsq[1]['timestamp'])
            else:
                self._seg_t_starts.append(np.nan)
                warnings.warn('segment start time not found')

            if tsq[-1]['evname'] == stop_mark:
                self._seg_t_stops.append(tsq[-1]['timestamp'])
            else:
                self._seg_t_stops.append(np.nan)
                warnings.warn('segment stop time not found')

            # Optionally replace the online sort codes with an offline sort
            # stored in ./sort/<sortname>/*.SortResult.
            if self.sortname:
                self._load_external_sortcodes(tsq, path)

        # Re-order segments according to their start times.
        sort_inds = np.argsort(self._seg_t_starts)
        if not np.array_equal(sort_inds, list(range(nb_segment))):
            segment_names = [segment_names[x] for x in sort_inds]
            self._tev_datas = [self._tev_datas[x] for x in sort_inds]
            self._seg_t_starts = [self._seg_t_starts[x] for x in sort_inds]
            self._seg_t_stops = [self._seg_t_stops[x] for x in sort_inds]
            self._tsq = [self._tsq[x] for x in sort_inds]
        # All times exposed by this IO are relative to the first segment start.
        self._global_t_start = self._seg_t_starts[0]

        # Signal channels (EVTYPE_STREAM stores).
        signal_channels = []
        self._sigs_data_buf = {seg_index: {} for seg_index in range(nb_segment)}  # seg -> chan_index
        self._sigs_index = {seg_index: {} for seg_index in range(nb_segment)}  # seg -> chan_index
        self._sig_dtype_by_group = {}  # key = group_id
        self._sig_sample_per_chunk = {}  # key = group_id
        self._sigs_lengths = {seg_index: {}
                              for seg_index in range(nb_segment)}  # key = seg_index then group_id
        self._sigs_t_start = {seg_index: {}
                              for seg_index in range(nb_segment)}  # key = seg_index then group_id
        keep = info_channel_groups['TankEvType'] == EVTYPE_STREAM
        for group_id, info in enumerate(info_channel_groups[keep]):
            self._sig_sample_per_chunk[group_id] = info['NumPoints']

            for c in range(info['NumChan']):
                chan_index = len(signal_channels)
                # NOTE: chan_id is only unique within one store; with several
                # StoreNames the same chan_id appears more than once.
                chan_id = c + 1

                # Loop over segments to get sampling_rate / data_index / data_buffer.
                sampling_rate = None
                dtype = None
                for seg_index, segment_name in enumerate(segment_names):
                    # TSQ rows indexing the data chunks of this channel.
                    tsq = self._tsq[seg_index]
                    mask = ((tsq['evtype'] == EVTYPE_STREAM) &
                            (tsq['evname'] == info['StoreName']) &
                            (tsq['channel'] == chan_id))
                    data_index = tsq[mask].copy()
                    self._sigs_index[seg_index][chan_index] = data_index

                    # Each indexed chunk holds NumPoints samples.
                    size = info['NumPoints'] * data_index.size
                    if group_id not in self._sigs_lengths[seg_index]:
                        self._sigs_lengths[seg_index][group_id] = size
                    else:
                        assert self._sigs_lengths[seg_index][group_id] == size

                    # Raw TSQ timestamp of the first chunk; made relative to
                    # the global start in _get_signal_t_start.
                    t_start = data_index['timestamp'][0]
                    if group_id not in self._sigs_t_start[seg_index]:
                        self._sigs_t_start[seg_index][group_id] = t_start
                    else:
                        assert self._sigs_t_start[seg_index][group_id] == t_start

                    # Sampling rate and dtype must not change across segments.
                    _sampling_rate = float(data_index['frequency'][0])
                    _dtype = data_formats[data_index['dataformat'][0]]
                    if sampling_rate is None:
                        sampling_rate = _sampling_rate
                        dtype = _dtype
                        if group_id not in self._sig_dtype_by_group:
                            self._sig_dtype_by_group[group_id] = np.dtype(dtype)
                        else:
                            assert self._sig_dtype_by_group[group_id] == dtype
                    else:
                        assert sampling_rate == _sampling_rate, 'sampling is changing!!!'
                        assert dtype == _dtype, 'dtype is changing!!!'

                    # Data buffer: the per-channel SEV file if it exists, otherwise the TEV.
                    path = os.path.join(self.dirname, segment_name)
                    sev_filename = os.path.join(path, tankname + '_' + segment_name + '_'
                                                + info['StoreName'].decode('ascii')
                                                + '_ch' + str(chan_id) + '.sev')
                    if os.path.exists(sev_filename):
                        data = np.memmap(sev_filename, mode='r', offset=0, dtype='uint8')
                    else:
                        data = self._tev_datas[seg_index]
                        assert data is not None, 'no TEV nor SEV'
                    self._sigs_data_buf[seg_index][chan_index] = data

                # NOTE(review): StoreName is bytes (S4), so on Python 3 this
                # renders as "b'Xxxx' 1"; kept unchanged so names stay stable.
                chan_name = '{} {}'.format(info['StoreName'], c + 1)
                units = 'V'  # WARNING: guessed, not read from the file
                gain = 1.
                offset = 0.
                signal_channels.append((chan_name, chan_id, sampling_rate, dtype,
                                        units, gain, offset, group_id))

        signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

        # Unit channels (EVTYPE_SNIP): one unit per (store, channel, sortcode).
        self.internal_unit_ids = {}
        self._waveforms_size = []
        self._waveforms_dtype = []
        unit_channels = []
        keep = info_channel_groups['TankEvType'] == EVTYPE_SNIP
        # Units are gathered over all segments. If the TSQ files could never
        # hold different units, tsq = self._tsq[0] would be enough.
        tsq = np.hstack(self._tsq)
        for info in info_channel_groups[keep]:
            for c in range(info['NumChan']):
                chan_id = c + 1
                mask = ((tsq['evtype'] == EVTYPE_SNIP) &
                        (tsq['evname'] == info['StoreName']) &
                        (tsq['channel'] == chan_id))
                unit_ids = np.unique(tsq[mask]['sortcode'])
                for unit_id in unit_ids:
                    unit_index = len(unit_channels)
                    self.internal_unit_ids[unit_index] = (info['StoreName'], chan_id, unit_id)

                    unit_name = "ch{}#{}".format(chan_id, unit_id)
                    wf_units = 'V'
                    wf_gain = 1.
                    wf_offset = 0.
                    wf_left_sweep = info['NumPoints'] // 2
                    wf_sampling_rate = info['SampleFreq']
                    unit_channels.append((unit_name, '{}'.format(unit_id),
                                          wf_units, wf_gain, wf_offset,
                                          wf_left_sweep, wf_sampling_rate))

                    self._waveforms_size.append(info['NumPoints'])
                    self._waveforms_dtype.append(np.dtype(data_formats[info['DataFormat']]))

        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)

        # Event channels (EVTYPE_STRON stores).
        event_channels = []
        keep = info_channel_groups['TankEvType'] == EVTYPE_STRON
        for info in info_channel_groups[keep]:
            chan_name = info['StoreName']
            chan_id = 1
            event_channels.append((chan_name, chan_id, 'event'))

        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # Fill the header dict.
        self.header = {}
        self.header['nb_block'] = 1
        self.header['nb_segment'] = [nb_segment]
        self.header['signal_channels'] = signal_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels

        # Only the standard annotations.
        self._generate_minimal_annotations()

    def _load_external_sortcodes(self, tsq, path):
        """Overwrite ``tsq['sortcode']`` in place with an offline sort result.

        Uses the first file ending in '.SortResult' (in ``os.listdir`` order)
        inside ``<path>/sort/<self.sortname>/``. The first 1024 bytes of that
        file are a header. The remaining int8 values are assigned to every
        TSQ row except the first and the last.

        This is best effort: if the directory or file cannot be read, or no
        .SortResult file exists, a warning is issued and the online sort
        codes are kept.
        """
        import warnings

        sort_dir = os.path.join(path, 'sort', self.sortname)
        try:
            for filename in os.listdir(sort_dir):
                if filename.endswith('.SortResult'):
                    sortresult_filename = os.path.join(sort_dir, filename)
                    # first 1024 bytes are header
                    newsortcode = np.fromfile(sortresult_filename, 'int8')[1024:]
                    tsq['sortcode'][1:-1] = newsortcode
                    return
        except (OSError, IOError) as err:
            warnings.warn('could not load sort codes from {}: {}'.format(sort_dir, err))
            return
        warnings.warn('no .SortResult file found in {}'.format(sort_dir))

    def _block_count(self):
        return 1

    def _segment_count(self, block_index):
        return self.header['nb_segment'][block_index]

    def _segment_t_start(self, block_index, seg_index):
        return self._seg_t_starts[seg_index] - self._global_t_start

    def _segment_t_stop(self, block_index, seg_index):
        return self._seg_t_stops[seg_index] - self._global_t_start

    def _get_signal_size(self, block_index, seg_index, channel_indexes):
        group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
        size = self._sigs_lengths[seg_index][group_id]
        return size

    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
        return self._sigs_t_start[seg_index][group_id] - self._global_t_start

    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
        """Return raw samples [i_start, i_stop) as an (n_samples, n_channels) array.

        Data are stored as consecutive chunks of ``sample_per_chunk`` samples.
        Each chunk starts at the byte ``offset`` given by its TSQ row.
        """
        # BaseRawIO already checked that all channel_indexes share one
        # group_id, so the first channel stands for the others.
        group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']

        if i_start is None:
            i_start = 0
        if i_stop is None:
            i_stop = self._sigs_lengths[seg_index][group_id]

        dt = self._sig_dtype_by_group[group_id]
        raw_signals = np.zeros((i_stop - i_start, len(channel_indexes)), dtype=dt)

        sample_per_chunk = self._sig_sample_per_chunk[group_id]
        bl0 = i_start // sample_per_chunk
        bl1 = -(-i_stop // sample_per_chunk)  # integer ceil division
        chunk_nb_bytes = sample_per_chunk * dt.itemsize

        for c, channel_index in enumerate(channel_indexes):
            data_index = self._sigs_index[seg_index][channel_index]
            data_buf = self._sigs_data_buf[seg_index][channel_index]

            # Loop over the data chunks that overlap [i_start, i_stop).
            ind = 0
            for bl in range(bl0, bl1):
                ind0 = data_index[bl]['offset']
                ind1 = ind0 + chunk_nb_bytes
                data = data_buf[ind0:ind1].view(dt)

                # Keep only the part of this chunk inside [i_start, i_stop).
                # This works whether bl is bl0, bl1 - 1, or both, and when
                # i_stop falls exactly on a chunk boundary.
                chunk_first = bl * sample_per_chunk
                local_start = max(i_start - chunk_first, 0)
                local_stop = min(i_stop - chunk_first, sample_per_chunk)
                data = data[local_start:local_stop]

                raw_signals[ind:ind + data.size, c] = data
                ind += data.size

        return raw_signals

    def _get_mask(self, tsq, seg_index, evtype, evname, chan_id, unit_id, t_start, t_stop):
        """Select the TSQ rows of one store/channel, optionally by unit and time.

        ``t_start``/``t_stop`` are relative to the global start, so the
        global start is added before they are compared with raw timestamps.
        Any filter given as None is not applied.
        """
        mask = (tsq['evtype'] == evtype) & \
            (tsq['evname'] == evname) & \
            (tsq['channel'] == chan_id)

        if unit_id is not None:
            mask &= (tsq['sortcode'] == unit_id)

        if t_start is not None:
            mask &= tsq['timestamp'] >= (t_start + self._global_t_start)

        if t_stop is not None:
            mask &= tsq['timestamp'] <= (t_stop + self._global_t_start)

        return mask

    def _spike_count(self, block_index, seg_index, unit_index):
        store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
        tsq = self._tsq[seg_index]
        mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
                              chan_id, unit_id, None, None)
        nb_spike = np.sum(mask)
        return nb_spike

    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
        tsq = self._tsq[seg_index]
        mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
                              chan_id, unit_id, t_start, t_stop)
        # Boolean indexing returns a copy, so the in-place shift below does
        # not modify the TSQ table.
        timestamps = tsq[mask]['timestamp']
        timestamps -= self._global_t_start
        return timestamps

    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        # already in s
        spike_times = spike_timestamps.astype(dtype)
        return spike_times

    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        """Return waveforms as a (nb_spike, 1, nb_sample) array read from the TEV."""
        store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
        tsq = self._tsq[seg_index]
        mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
                              chan_id, unit_id, t_start, t_stop)
        nb_spike = np.sum(mask)

        data = self._tev_datas[seg_index]

        dt = self._waveforms_dtype[unit_index]
        nb_sample = self._waveforms_size[unit_index]
        waveforms = np.zeros((nb_spike, 1, nb_sample), dtype=dt)

        for i, e in enumerate(tsq[mask]):
            ind0 = e['offset']
            ind1 = ind0 + nb_sample * dt.itemsize
            waveforms[i, 0, :] = data[ind0:ind1].view(dt)

        return waveforms

    def _event_count(self, block_index, seg_index, event_channel_index):
        h = self.header['event_channels'][event_channel_index]
        store_name = h['name'].encode('ascii')
        tsq = self._tsq[seg_index]
        chan_id = 0
        mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None)
        nb_event = np.sum(mask)
        return nb_event

    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        """Return (timestamps, durations, labels) for one event channel.

        Only events inside [t_start, t_stop] are returned (None = unbounded).
        Times are relative to the global start. Labels are the TSQ
        ``offset`` field converted to str.
        """
        h = self.header['event_channels'][event_channel_index]
        store_name = h['name'].encode('ascii')
        tsq = self._tsq[seg_index]
        chan_id = 0
        mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id,
                              None, t_start, t_stop)

        timestamps = tsq[mask]['timestamp']
        timestamps -= self._global_t_start
        labels = tsq[mask]['offset'].astype('U')
        durations = None
        # TODO: if users ask for events as epochs, use EVTYPE_STROFF=258 so
        # that durations are not None (not implemented in the previous IO).

        return timestamps, durations, labels

    def _rescale_event_timestamp(self, event_timestamps, dtype):
        # already in s
        ev_times = event_timestamps.astype(dtype)
        return ev_times
# Fields read from each [STOREHDRITEM] section of a TBK file, with the numpy
# type each raw VALUE is converted to.
tbk_field_types = [
    ('StoreName', 'S4'),
    ('HeadName', 'S16'),
    ('Enabled', 'bool'),
    ('CircType', 'int'),
    ('NumChan', 'int'),
    ('StrobeMode', 'int'),
    ('TankEvType', 'int32'),
    ('NumPoints', 'int'),
    ('DataFormat', 'int'),
    ('SampleFreq', 'float64'),
]

# One "NAME=...;TYPE=...;VALUE=...;" entry of a TBK store header.
_TBK_ENTRY_PATTERN = re.compile(br'NAME=(\S+);TYPE=(\S+);VALUE=(\S+);')


def _tbk_value(raw, dt):
    """Convert one raw TBK VALUE (bytes) to the numpy scalar type ``dt``."""
    np_type = np.dtype(dt)
    if np_type.kind == 'S':
        return np_type.type(raw)
    if np_type.kind == 'b':
        # Non-empty bytes such as b'0' can be truthy, so numeric flags are
        # parsed explicitly. Other spellings keep the plain numpy conversion.
        if raw.strip().isdigit():
            return np_type.type(int(raw))
        return np_type.type(raw)
    return np_type.type(raw.decode('ascii'))


def read_tbk(tbk_filename):
    """Read the channel-group (store) description from a TBK file.

    The TBK file holds a visible text header that describes the channel
    groups. It is split on ``[STOREHDRITEM]``. Each piece is parsed for
    ``NAME=..;TYPE=..;VALUE=..;`` entries. Parsing stops at the first piece
    that starts with ``[USERNOTEDELIMITER]``. Pieces with no entry at all are
    skipped.

    Returns a numpy structured array with dtype ``tbk_field_types`` and one
    row per channel group. Raises KeyError if a group lacks one of the
    required fields.
    """
    with open(tbk_filename, mode='rb') as f:
        txt_header = f.read()

    infos = []
    for chan_grp_header in txt_header.split(b'[STOREHDRITEM]'):
        if chan_grp_header.startswith(b'[USERNOTEDELIMITER]'):
            break

        # Parse into a dict (a later duplicate NAME overrides an earlier one).
        info = OrderedDict()
        for name, _type, value in _TBK_ENTRY_PATTERN.findall(chan_grp_header):
            info[name.decode('ascii')] = value
        if not info:
            # e.g. text before the first [STOREHDRITEM] marker
            continue
        infos.append(info)

    # Put into numpy.
    info_channel_groups = np.zeros(len(infos), dtype=tbk_field_types)
    for i, info in enumerate(infos):
        for k, dt in tbk_field_types:
            info_channel_groups[i][k] = _tbk_value(info[k], dt)

    return info_channel_groups
# Layout of one 40-byte TSQ index record. Byte ranges are 1-based and
# inclusive. 'offset' is the byte position of the record's data in the
# TEV/SEV buffer.
# NOTE(review): the native byte order is used; TDT files are presumably
# little-endian, so confirm this if big-endian hosts are ever supported.
tsq_dtype = [
    ('size', 'int32'),  # bytes 1-4
    ('evtype', 'int32'),  # bytes 5-8
    ('evname', 'S4'),  # bytes 9-12
    ('channel', 'uint16'),  # bytes 13-14
    ('sortcode', 'uint16'),  # bytes 15-16
    ('timestamp', 'float64'),  # bytes 17-24
    ('offset', 'int64'),  # bytes 25-32
    ('dataformat', 'int32'),  # bytes 33-36
    ('frequency', 'float32'),  # bytes 37-40
]

# Event types found in the TSQ 'evtype' field and the TBK 'TankEvType' field.
EVTYPE_UNKNOWN = 0x00000000  # 0
EVTYPE_STRON = 0x00000101  # 257
EVTYPE_STROFF = 0x00000102  # 258
EVTYPE_SCALAR = 0x00000201  # 513
EVTYPE_STREAM = 0x00008101  # 33025
EVTYPE_SNIP = 0x00008201  # 33281
EVTYPE_MARK = 0x00008801  # 34817
EVTYPE_HASDATA = 0x00008000  # 32768
EVTYPE_UCF = 0x00000010  # 16
EVTYPE_PHANTOM = 0x00000020  # 32
EVTYPE_MASK = 0x0000FF0F  # 65295
EVTYPE_INVALID_MASK = 0xFFFF0000  # 4294901760

# Values of 'evname' (as a single character) that mark the block start/stop rows.
EVMARK_STARTBLOCK = 0x0001  # 1
EVMARK_STOPBLOCK = 0x0002  # 2

# Maps the TSQ 'dataformat' / TBK 'DataFormat' code to a numpy dtype name.
data_formats = {
    0: 'float32',
    1: 'int32',
    2: 'int16',
    3: 'int8',
    4: 'float64',
}
def is_tdtblock(blockpath):
    """Tell whether ``blockpath`` is a TDT block (= neo.Segment).

    A TDT block is a directory containing at least one file for each of the
    extensions .tbk, .tdx, .tev and .tsq (compared case-insensitively).
    Any path that is not a directory is never a block.
    """
    if not os.path.isdir(blockpath):
        return False
    required_exts = {'.tbk', '.tdx', '.tev', '.tsq'}
    found_exts = {os.path.splitext(entry)[1].lower() for entry in os.listdir(blockpath)}
    return required_exts.issubset(found_exts)
|