tdtrawio.py 21 KB


"""
Class for reading data from the Tucker Davis TTank format.

Terminology:
TDT holds data in tanks (actually a directory), and tanks hold blocks
(sub-directories).
A tank corresponds to a neo.Block and a TDT block corresponds to a neo.Segment.
Note that the name "block" is ambiguous because it does not refer to the same
thing in TDT terminology and in neo.

In a directory there are several files:
  * TSQ: timestamp index of the data
  * TBK: some kind of channel info and maybe more
  * TEV: contains data: spike + event + signal (for old versions)
  * SEV: contains signals (for new versions)
  * ./sort/ can contain offline spike-sorting labels for spikes
    and can be used in place of TEV.

Units in this IO are not guaranteed.

Author: Samuel Garcia, SummitKwan, Chadwick Boulay
"""
  19. from .baserawio import BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype
  20. import numpy as np
  21. import os
  22. import re
  23. from collections import OrderedDict
  24. class TdtRawIO(BaseRawIO):
  25. rawmode = 'one-dir'
  26. def __init__(self, dirname='', sortname=''):
  27. """
  28. 'sortname' is used to specify the external sortcode generated by offline spike sorting.
  29. if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
  30. which stores the sortcode for every spike; defaults to '',
  31. which uses the original online sort.
  32. """
  33. BaseRawIO.__init__(self)
  34. if dirname.endswith('/'):
  35. dirname = dirname[:-1]
  36. self.dirname = dirname
  37. self.sortname = sortname
  38. def _source_name(self):
  39. return self.dirname
  40. def _parse_header(self):
  41. tankname = os.path.basename(self.dirname)
  42. segment_names = []
  43. for segment_name in os.listdir(self.dirname):
  44. path = os.path.join(self.dirname, segment_name)
  45. if is_tdtblock(path):
  46. segment_names.append(segment_name)
  47. nb_segment = len(segment_names)
  48. # TBK (channel info)
  49. info_channel_groups = None
  50. for seg_index, segment_name in enumerate(segment_names):
  51. path = os.path.join(self.dirname, segment_name)
  52. # TBK contain channels
  53. tbk_filename = os.path.join(path, tankname + '_' + segment_name + '.Tbk')
  54. _info_channel_groups = read_tbk(tbk_filename)
  55. if info_channel_groups is None:
  56. info_channel_groups = _info_channel_groups
  57. else:
  58. assert np.array_equal(info_channel_groups,
  59. _info_channel_groups), 'Channels differ across segments'
  60. # TEV (mixed data)
  61. self._tev_datas = []
  62. for seg_index, segment_name in enumerate(segment_names):
  63. path = os.path.join(self.dirname, segment_name)
  64. tev_filename = os.path.join(path, tankname + '_' + segment_name + '.tev')
  65. if os.path.exists(tev_filename):
  66. tev_data = np.memmap(tev_filename, mode='r', offset=0, dtype='uint8')
  67. else:
  68. tev_data = None
  69. self._tev_datas.append(tev_data)
  70. # TSQ index with timestamp
  71. self._tsq = []
  72. self._seg_t_starts = []
  73. self._seg_t_stops = []
  74. for seg_index, segment_name in enumerate(segment_names):
  75. path = os.path.join(self.dirname, segment_name)
  76. tsq_filename = os.path.join(path, tankname + '_' + segment_name + '.tsq')
  77. tsq = np.fromfile(tsq_filename, dtype=tsq_dtype)
  78. self._tsq.append(tsq)
  79. # Start and stop times are only found in the second and last header row, respectively.
  80. if tsq[1]['evname'] == chr(EVMARK_STARTBLOCK).encode():
  81. self._seg_t_starts.append(tsq[1]['timestamp'])
  82. else:
  83. self._seg_t_starts.append(np.nan)
  84. print('segment start time not found')
  85. if tsq[-1]['evname'] == chr(EVMARK_STOPBLOCK).encode():
  86. self._seg_t_stops.append(tsq[-1]['timestamp'])
  87. else:
  88. self._seg_t_stops.append(np.nan)
  89. print('segment stop time not found')
  90. # If there exists an external sortcode in ./sort/[sortname]/*.SortResult
  91. # (generated after offline sorting)
  92. if self.sortname != '':
  93. try:
  94. for file in os.listdir(os.path.join(path, 'sort', sortname)):
  95. if file.endswith(".SortResult"):
  96. sortresult_filename = os.path.join(path, 'sort', sortname, file)
  97. # get new sortcode
  98. newsortcode = np.fromfile(sortresult_filename, 'int8')[
  99. 1024:] # first 1024 bytes are header
  100. # update the sort code with the info from this file
  101. tsq['sortcode'][1:-1] = newsortcode
  102. # print('sortcode updated')
  103. break
  104. except OSError:
  105. pass
  106. # Re-order segments according to their start times
  107. sort_inds = np.argsort(self._seg_t_starts)
  108. if not np.array_equal(sort_inds, list(range(nb_segment))):
  109. segment_names = [segment_names[x] for x in sort_inds]
  110. self._tev_datas = [self._tev_datas[x] for x in sort_inds]
  111. self._seg_t_starts = [self._seg_t_starts[x] for x in sort_inds]
  112. self._seg_t_stops = [self._seg_t_stops[x] for x in sort_inds]
  113. self._tsq = [self._tsq[x] for x in sort_inds]
  114. self._global_t_start = self._seg_t_starts[0]
  115. # signal channels EVTYPE_STREAM
  116. signal_channels = []
  117. self._sigs_data_buf = {seg_index: {} for seg_index in range(nb_segment)}
  118. self._sigs_index = {seg_index: {} for seg_index in range(nb_segment)}
  119. self._sig_dtype_by_group = {} # key = group_id
  120. self._sig_sample_per_chunk = {} # key = group_id
  121. self._sigs_lengths = {seg_index: {}
  122. for seg_index in range(nb_segment)} # key = seg_index then group_id
  123. self._sigs_t_start = {seg_index: {}
  124. for seg_index in range(nb_segment)} # key = seg_index then group_id
  125. keep = info_channel_groups['TankEvType'] == EVTYPE_STREAM
  126. for group_id, info in enumerate(info_channel_groups[keep]):
  127. self._sig_sample_per_chunk[group_id] = info['NumPoints']
  128. for c in range(info['NumChan']):
  129. chan_index = len(signal_channels)
  130. chan_id = c + 1 # If several StoreName then chan_id is not unique in TDT!!!!!
  131. # loop over segment to get sampling_rate/data_index/data_buffer
  132. sampling_rate = None
  133. dtype = None
  134. for seg_index, segment_name in enumerate(segment_names):
  135. # get data index
  136. tsq = self._tsq[seg_index]
  137. mask = (tsq['evtype'] == EVTYPE_STREAM) & \
  138. (tsq['evname'] == info['StoreName']) & \
  139. (tsq['channel'] == chan_id)
  140. data_index = tsq[mask].copy()
  141. self._sigs_index[seg_index][chan_index] = data_index
  142. size = info['NumPoints'] * data_index.size
  143. if group_id not in self._sigs_lengths[seg_index]:
  144. self._sigs_lengths[seg_index][group_id] = size
  145. else:
  146. assert self._sigs_lengths[seg_index][group_id] == size
  147. # signal start time, relative to start of segment
  148. t_start = data_index['timestamp'][0]
  149. if group_id not in self._sigs_t_start[seg_index]:
  150. self._sigs_t_start[seg_index][group_id] = t_start
  151. else:
  152. assert self._sigs_t_start[seg_index][group_id] == t_start
  153. # sampling_rate and dtype
  154. _sampling_rate = float(data_index['frequency'][0])
  155. _dtype = data_formats[data_index['dataformat'][0]]
  156. if sampling_rate is None:
  157. sampling_rate = _sampling_rate
  158. dtype = _dtype
  159. if group_id not in self._sig_dtype_by_group:
  160. self._sig_dtype_by_group[group_id] = np.dtype(dtype)
  161. else:
  162. assert self._sig_dtype_by_group[group_id] == dtype
  163. else:
  164. assert sampling_rate == _sampling_rate, 'sampling is changing!!!'
  165. assert dtype == _dtype, 'sampling is changing!!!'
  166. # data buffer test if SEV file exists otherwise TEV
  167. path = os.path.join(self.dirname, segment_name)
  168. sev_filename = os.path.join(path, tankname + '_' + segment_name + '_'
  169. + info['StoreName'].decode('ascii')
  170. + '_ch' + str(chan_id) + '.sev')
  171. if os.path.exists(sev_filename):
  172. data = np.memmap(sev_filename, mode='r', offset=0, dtype='uint8')
  173. else:
  174. data = self._tev_datas[seg_index]
  175. assert data is not None, 'no TEV nor SEV'
  176. self._sigs_data_buf[seg_index][chan_index] = data
  177. chan_name = '{} {}'.format(info['StoreName'], c + 1)
  178. sampling_rate = sampling_rate
  179. units = 'V' # WARNING this is not sur at all
  180. gain = 1.
  181. offset = 0.
  182. signal_channels.append((chan_name, chan_id, sampling_rate, dtype,
  183. units, gain, offset, group_id))
  184. signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
  185. # unit channels EVTYPE_SNIP
  186. self.internal_unit_ids = {}
  187. self._waveforms_size = []
  188. self._waveforms_dtype = []
  189. unit_channels = []
  190. keep = info_channel_groups['TankEvType'] == EVTYPE_SNIP
  191. tsq = np.hstack(self._tsq)
  192. # If there is no chance the differet TSQ files will have different units,
  193. # then we can do tsq = self._tsq[0]
  194. for info in info_channel_groups[keep]:
  195. for c in range(info['NumChan']):
  196. chan_id = c + 1
  197. mask = (tsq['evtype'] == EVTYPE_SNIP) & \
  198. (tsq['evname'] == info['StoreName']) & \
  199. (tsq['channel'] == chan_id)
  200. unit_ids = np.unique(tsq[mask]['sortcode'])
  201. for unit_id in unit_ids:
  202. unit_index = len(unit_channels)
  203. self.internal_unit_ids[unit_index] = (info['StoreName'], chan_id, unit_id)
  204. unit_name = "ch{}#{}".format(chan_id, unit_id)
  205. wf_units = 'V'
  206. wf_gain = 1.
  207. wf_offset = 0.
  208. wf_left_sweep = info['NumPoints'] // 2
  209. wf_sampling_rate = info['SampleFreq']
  210. unit_channels.append((unit_name, '{}'.format(unit_id),
  211. wf_units, wf_gain, wf_offset,
  212. wf_left_sweep, wf_sampling_rate))
  213. self._waveforms_size.append(info['NumPoints'])
  214. self._waveforms_dtype.append(np.dtype(data_formats[info['DataFormat']]))
  215. unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
  216. # signal channels EVTYPE_STRON
  217. event_channels = []
  218. keep = info_channel_groups['TankEvType'] == EVTYPE_STRON
  219. for info in info_channel_groups[keep]:
  220. chan_name = info['StoreName']
  221. chan_id = 1
  222. event_channels.append((chan_name, chan_id, 'event'))
  223. event_channels = np.array(event_channels, dtype=_event_channel_dtype)
  224. # fill into header dict
  225. self.header = {}
  226. self.header['nb_block'] = 1
  227. self.header['nb_segment'] = [nb_segment]
  228. self.header['signal_channels'] = signal_channels
  229. self.header['unit_channels'] = unit_channels
  230. self.header['event_channels'] = event_channels
  231. # Annotations only standard ones:
  232. self._generate_minimal_annotations()
  233. def _block_count(self):
  234. return 1
  235. def _segment_count(self, block_index):
  236. return self.header['nb_segment'][block_index]
  237. def _segment_t_start(self, block_index, seg_index):
  238. return self._seg_t_starts[seg_index] - self._global_t_start
  239. def _segment_t_stop(self, block_index, seg_index):
  240. return self._seg_t_stops[seg_index] - self._global_t_start
  241. def _get_signal_size(self, block_index, seg_index, channel_indexes):
  242. group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
  243. size = self._sigs_lengths[seg_index][group_id]
  244. return size
  245. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  246. group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
  247. return self._sigs_t_start[seg_index][group_id] - self._global_t_start
  248. def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
  249. # check of channel_indexes is same group_id is done outside (BaseRawIO)
  250. # so first is identique to others
  251. group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
  252. if i_start is None:
  253. i_start = 0
  254. if i_stop is None:
  255. i_stop = self._sigs_lengths[seg_index][group_id]
  256. dt = self._sig_dtype_by_group[group_id]
  257. raw_signals = np.zeros((i_stop - i_start, len(channel_indexes)), dtype=dt)
  258. sample_per_chunk = self._sig_sample_per_chunk[group_id]
  259. bl0 = i_start // sample_per_chunk
  260. bl1 = int(np.ceil(i_stop / sample_per_chunk))
  261. chunk_nb_bytes = sample_per_chunk * dt.itemsize
  262. for c, channel_index in enumerate(channel_indexes):
  263. data_index = self._sigs_index[seg_index][channel_index]
  264. data_buf = self._sigs_data_buf[seg_index][channel_index]
  265. # loop over data blocks and get chunks
  266. ind = 0
  267. for bl in range(bl0, bl1):
  268. ind0 = data_index[bl]['offset']
  269. ind1 = ind0 + chunk_nb_bytes
  270. data = data_buf[ind0:ind1].view(dt)
  271. if bl == bl1 - 1:
  272. # right border
  273. # be careful that bl could be both bl0 and bl1!!
  274. border = data.size - (i_stop % sample_per_chunk)
  275. data = data[:-border]
  276. if bl == bl0:
  277. # left border
  278. border = i_start % sample_per_chunk
  279. data = data[border:]
  280. raw_signals[ind:data.size + ind, c] = data
  281. ind += data.size
  282. return raw_signals
  283. def _get_mask(self, tsq, seg_index, evtype, evname, chan_id, unit_id, t_start, t_stop):
  284. """Used inside spike and events methods"""
  285. mask = (tsq['evtype'] == evtype) & \
  286. (tsq['evname'] == evname) & \
  287. (tsq['channel'] == chan_id)
  288. if unit_id is not None:
  289. mask &= (tsq['sortcode'] == unit_id)
  290. if t_start is not None:
  291. mask &= tsq['timestamp'] >= (t_start + self._global_t_start)
  292. if t_stop is not None:
  293. mask &= tsq['timestamp'] <= (t_stop + self._global_t_start)
  294. return mask
  295. def _spike_count(self, block_index, seg_index, unit_index):
  296. store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
  297. tsq = self._tsq[seg_index]
  298. mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
  299. chan_id, unit_id, None, None)
  300. nb_spike = np.sum(mask)
  301. return nb_spike
  302. def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
  303. store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
  304. tsq = self._tsq[seg_index]
  305. mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
  306. chan_id, unit_id, t_start, t_stop)
  307. timestamps = tsq[mask]['timestamp']
  308. timestamps -= self._global_t_start
  309. return timestamps
  310. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  311. # already in s
  312. spike_times = spike_timestamps.astype(dtype)
  313. return spike_times
  314. def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
  315. store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
  316. tsq = self._tsq[seg_index]
  317. mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
  318. chan_id, unit_id, t_start, t_stop)
  319. nb_spike = np.sum(mask)
  320. data = self._tev_datas[seg_index]
  321. dt = self._waveforms_dtype[unit_index]
  322. nb_sample = self._waveforms_size[unit_index]
  323. waveforms = np.zeros((nb_spike, 1, nb_sample), dtype=dt)
  324. for i, e in enumerate(tsq[mask]):
  325. ind0 = e['offset']
  326. ind1 = ind0 + nb_sample * dt.itemsize
  327. waveforms[i, 0, :] = data[ind0:ind1].view(dt)
  328. return waveforms
  329. def _event_count(self, block_index, seg_index, event_channel_index):
  330. h = self.header['event_channels'][event_channel_index]
  331. store_name = h['name'].encode('ascii')
  332. tsq = self._tsq[seg_index]
  333. chan_id = 0
  334. mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None)
  335. nb_event = np.sum(mask)
  336. return nb_event
  337. def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
  338. h = self.header['event_channels'][event_channel_index]
  339. store_name = h['name'].encode('ascii')
  340. tsq = self._tsq[seg_index]
  341. chan_id = 0
  342. mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None)
  343. timestamps = tsq[mask]['timestamp']
  344. timestamps -= self._global_t_start
  345. labels = tsq[mask]['offset'].astype('U')
  346. durations = None
  347. # TODO if user demand event to epoch
  348. # with EVTYPE_STROFF=258
  349. # and so durations would be not None
  350. # it was not implemented in previous IO.
  351. return timestamps, durations, labels
  352. def _rescale_event_timestamp(self, event_timestamps, dtype):
  353. # already in s
  354. ev_times = event_timestamps.astype(dtype)
  355. return ev_times
  356. tbk_field_types = [
  357. ('StoreName', 'S4'),
  358. ('HeadName', 'S16'),
  359. ('Enabled', 'bool'),
  360. ('CircType', 'int'),
  361. ('NumChan', 'int'),
  362. ('StrobeMode', 'int'),
  363. ('TankEvType', 'int32'),
  364. ('NumPoints', 'int'),
  365. ('DataFormat', 'int'),
  366. ('SampleFreq', 'float64'),
  367. ]
  368. def read_tbk(tbk_filename):
  369. """
  370. Tbk contains some visible header in txt mode to describe
  371. channel group info.
  372. """
  373. with open(tbk_filename, mode='rb') as f:
  374. txt_header = f.read()
  375. infos = []
  376. for chan_grp_header in txt_header.split(b'[STOREHDRITEM]'):
  377. if chan_grp_header.startswith(b'[USERNOTEDELIMITER]'):
  378. break
  379. # parse into a dict
  380. info = OrderedDict()
  381. pattern = br'NAME=(\S+);TYPE=(\S+);VALUE=(\S+);'
  382. r = re.findall(pattern, chan_grp_header)
  383. for name, _type, value in r:
  384. info[name.decode('ascii')] = value
  385. infos.append(info)
  386. # and put into numpy
  387. info_channel_groups = np.zeros(len(infos), dtype=tbk_field_types)
  388. for i, info in enumerate(infos):
  389. for k, dt in tbk_field_types:
  390. v = np.dtype(dt).type(info[k])
  391. info_channel_groups[i][k] = v
  392. return info_channel_groups
  393. tsq_dtype = [
  394. ('size', 'int32'), # bytes 0-4
  395. ('evtype', 'int32'), # bytes 5-8
  396. ('evname', 'S4'), # bytes 9-12
  397. ('channel', 'uint16'), # bytes 13-14
  398. ('sortcode', 'uint16'), # bytes 15-16
  399. ('timestamp', 'float64'), # bytes 17-24
  400. ('offset', 'int64'), # bytes 25-32
  401. ('dataformat', 'int32'), # bytes 33-36
  402. ('frequency', 'float32'), # bytes 37-40
  403. ]
  404. EVTYPE_UNKNOWN = int('00000000', 16) # 0
  405. EVTYPE_STRON = int('00000101', 16) # 257
  406. EVTYPE_STROFF = int('00000102', 16) # 258
  407. EVTYPE_SCALAR = int('00000201', 16) # 513
  408. EVTYPE_STREAM = int('00008101', 16) # 33025
  409. EVTYPE_SNIP = int('00008201', 16) # 33281
  410. EVTYPE_MARK = int('00008801', 16) # 34817
  411. EVTYPE_HASDATA = int('00008000', 16) # 32768
  412. EVTYPE_UCF = int('00000010', 16) # 16
  413. EVTYPE_PHANTOM = int('00000020', 16) # 32
  414. EVTYPE_MASK = int('0000FF0F', 16) # 65295
  415. EVTYPE_INVALID_MASK = int('FFFF0000', 16) # 4294901760
  416. EVMARK_STARTBLOCK = int('0001', 16) # 1
  417. EVMARK_STOPBLOCK = int('0002', 16) # 2
  418. data_formats = {
  419. 0: 'float32',
  420. 1: 'int32',
  421. 2: 'int16',
  422. 3: 'int8',
  423. 4: 'float64',
  424. }
  425. def is_tdtblock(blockpath):
  426. """Is tha path a TDT block (=neo.Segment) ?"""
  427. file_ext = list()
  428. if os.path.isdir(blockpath):
  429. # for every file, get extension, convert to lowercase and append
  430. for file in os.listdir(blockpath):
  431. file_ext.append(os.path.splitext(file)[1].lower())
  432. file_ext = set(file_ext)
  433. tdt_ext = {'.tbk', '.tdx', '.tev', '.tsq'}
  434. if file_ext >= tdt_ext: # if containing all the necessary files
  435. return True
  436. else:
  437. return False