tdtrawio.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # -*- coding: utf-8 -*-
"""
Class for reading data from the Tucker Davis TTank format.

Terminology:
TDT holds data in tanks (actually a directory), and tanks hold sub-blocks
(sub-directories).
Tanks correspond to neo.Block and TDT blocks correspond to neo.Segment.
Note that the name "Block" is ambiguous because it does not refer to the same
thing in TDT terminology and in neo.

In a directory there are several files:
  * TSQ: timestamp index of data
  * TBK: some kind of channel info and maybe more
  * TEV: contains data: spike + event + signal (for old version)
  * SEV: contains signals (for new version)
  * ./sort/ can contain offline spike-sorting labels for spikes
    and can be used in place of TEV.

Units in this IO are not guaranteed.

Author: Samuel Garcia, SummitKwan, Chadwick Boulay
"""
  20. from __future__ import print_function, division, absolute_import
  21. # from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3
  22. from .baserawio import BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype
  23. import numpy as np
  24. import os
  25. import re
  26. from collections import OrderedDict
  27. class TdtRawIO(BaseRawIO):
  28. rawmode = 'one-dir'
  29. def __init__(self, dirname='', sortname=''):
  30. """
  31. 'sortname' is used to specify the external sortcode generated by offline spike sorting.
  32. if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
  33. which stores the sortcode for every spike; defaults to '',
  34. which uses the original online sort.
  35. """
  36. BaseRawIO.__init__(self)
  37. if dirname.endswith('/'):
  38. dirname = dirname[:-1]
  39. self.dirname = dirname
  40. self.sortname = sortname
  41. def _source_name(self):
  42. return self.dirname
  43. def _parse_header(self):
  44. tankname = os.path.basename(self.dirname)
  45. segment_names = []
  46. for segment_name in os.listdir(self.dirname):
  47. path = os.path.join(self.dirname, segment_name)
  48. if is_tdtblock(path):
  49. segment_names.append(segment_name)
  50. nb_segment = len(segment_names)
  51. # TBK (channel info)
  52. info_channel_groups = None
  53. for seg_index, segment_name in enumerate(segment_names):
  54. path = os.path.join(self.dirname, segment_name)
  55. # TBK contain channels
  56. tbk_filename = os.path.join(path, tankname + '_' + segment_name + '.Tbk')
  57. _info_channel_groups = read_tbk(tbk_filename)
  58. if info_channel_groups is None:
  59. info_channel_groups = _info_channel_groups
  60. else:
  61. assert np.array_equal(info_channel_groups,
  62. _info_channel_groups), 'Channels differ across segments'
  63. # TEV (mixed data)
  64. self._tev_datas = []
  65. for seg_index, segment_name in enumerate(segment_names):
  66. path = os.path.join(self.dirname, segment_name)
  67. tev_filename = os.path.join(path, tankname + '_' + segment_name + '.tev')
  68. if os.path.exists(tev_filename):
  69. tev_data = np.memmap(tev_filename, mode='r', offset=0, dtype='uint8')
  70. else:
  71. tev_data = None
  72. self._tev_datas.append(tev_data)
  73. # TSQ index with timestamp
  74. self._tsq = []
  75. self._seg_t_starts = []
  76. self._seg_t_stops = []
  77. for seg_index, segment_name in enumerate(segment_names):
  78. path = os.path.join(self.dirname, segment_name)
  79. tsq_filename = os.path.join(path, tankname + '_' + segment_name + '.tsq')
  80. tsq = np.fromfile(tsq_filename, dtype=tsq_dtype)
  81. self._tsq.append(tsq)
  82. # Start and stop times are only found in the second and last header row, respectively.
  83. if tsq[1]['evname'] == chr(EVMARK_STARTBLOCK).encode():
  84. self._seg_t_starts.append(tsq[1]['timestamp'])
  85. else:
  86. self._seg_t_starts.append(np.nan)
  87. print('segment start time not found')
  88. if tsq[-1]['evname'] == chr(EVMARK_STOPBLOCK).encode():
  89. self._seg_t_stops.append(tsq[-1]['timestamp'])
  90. else:
  91. self._seg_t_stops.append(np.nan)
  92. print('segment stop time not found')
  93. # If there exists an external sortcode in ./sort/[sortname]/*.SortResult
  94. # (generated after offline sorting)
  95. if self.sortname is not '':
  96. try:
  97. for file in os.listdir(os.path.join(path, 'sort', sortname)):
  98. if file.endswith(".SortResult"):
  99. sortresult_filename = os.path.join(path, 'sort', sortname, file)
  100. # get new sortcode
  101. newsortcode = np.fromfile(sortresult_filename, 'int8')[
  102. 1024:] # first 1024 bytes are header
  103. # update the sort code with the info from this file
  104. tsq['sortcode'][1:-1] = newsortcode
  105. # print('sortcode updated')
  106. break
  107. except OSError:
  108. pass
  109. except IOError:
  110. pass
  111. # Re-order segments according to their start times
  112. sort_inds = np.argsort(self._seg_t_starts)
  113. if not np.array_equal(sort_inds, list(range(nb_segment))):
  114. segment_names = [segment_names[x] for x in sort_inds]
  115. self._tev_datas = [self._tev_datas[x] for x in sort_inds]
  116. self._seg_t_starts = [self._seg_t_starts[x] for x in sort_inds]
  117. self._seg_t_stops = [self._seg_t_stops[x] for x in sort_inds]
  118. self._tsq = [self._tsq[x] for x in sort_inds]
  119. self._global_t_start = self._seg_t_starts[0]
  120. # signal channels EVTYPE_STREAM
  121. signal_channels = []
  122. self._sigs_data_buf = {seg_index: {} for seg_index in range(nb_segment)}
  123. self._sigs_index = {seg_index: {} for seg_index in range(nb_segment)}
  124. self._sig_dtype_by_group = {} # key = group_id
  125. self._sig_sample_per_chunk = {} # key = group_id
  126. self._sigs_lengths = {seg_index: {}
  127. for seg_index in range(nb_segment)} # key = seg_index then group_id
  128. self._sigs_t_start = {seg_index: {}
  129. for seg_index in range(nb_segment)} # key = seg_index then group_id
  130. keep = info_channel_groups['TankEvType'] == EVTYPE_STREAM
  131. for group_id, info in enumerate(info_channel_groups[keep]):
  132. self._sig_sample_per_chunk[group_id] = info['NumPoints']
  133. for c in range(info['NumChan']):
  134. chan_index = len(signal_channels)
  135. chan_id = c + 1 # If several StoreName then chan_id is not unique in TDT!!!!!
  136. # loop over segment to get sampling_rate/data_index/data_buffer
  137. sampling_rate = None
  138. dtype = None
  139. for seg_index, segment_name in enumerate(segment_names):
  140. # get data index
  141. tsq = self._tsq[seg_index]
  142. mask = (tsq['evtype'] == EVTYPE_STREAM) & \
  143. (tsq['evname'] == info['StoreName']) & \
  144. (tsq['channel'] == chan_id)
  145. data_index = tsq[mask].copy()
  146. self._sigs_index[seg_index][chan_index] = data_index
  147. size = info['NumPoints'] * data_index.size
  148. if group_id not in self._sigs_lengths[seg_index]:
  149. self._sigs_lengths[seg_index][group_id] = size
  150. else:
  151. assert self._sigs_lengths[seg_index][group_id] == size
  152. # signal start time, relative to start of segment
  153. t_start = data_index['timestamp'][0]
  154. if group_id not in self._sigs_t_start[seg_index]:
  155. self._sigs_t_start[seg_index][group_id] = t_start
  156. else:
  157. assert self._sigs_t_start[seg_index][group_id] == t_start
  158. # sampling_rate and dtype
  159. _sampling_rate = float(data_index['frequency'][0])
  160. _dtype = data_formats[data_index['dataformat'][0]]
  161. if sampling_rate is None:
  162. sampling_rate = _sampling_rate
  163. dtype = _dtype
  164. if group_id not in self._sig_dtype_by_group:
  165. self._sig_dtype_by_group[group_id] = np.dtype(dtype)
  166. else:
  167. assert self._sig_dtype_by_group[group_id] == dtype
  168. else:
  169. assert sampling_rate == _sampling_rate, 'sampling is changing!!!'
  170. assert dtype == _dtype, 'sampling is changing!!!'
  171. # data buffer test if SEV file exists otherwise TEV
  172. path = os.path.join(self.dirname, segment_name)
  173. sev_filename = os.path.join(path, tankname + '_' + segment_name + '_'
  174. + info['StoreName'].decode('ascii')
  175. + '_ch' + str(chan_id) + '.sev')
  176. if os.path.exists(sev_filename):
  177. data = np.memmap(sev_filename, mode='r', offset=0, dtype='uint8')
  178. else:
  179. data = self._tev_datas[seg_index]
  180. assert data is not None, 'no TEV nor SEV'
  181. self._sigs_data_buf[seg_index][chan_index] = data
  182. chan_name = '{} {}'.format(info['StoreName'], c + 1)
  183. sampling_rate = sampling_rate
  184. units = 'V' # WARNING this is not sur at all
  185. gain = 1.
  186. offset = 0.
  187. signal_channels.append((chan_name, chan_id, sampling_rate, dtype,
  188. units, gain, offset, group_id))
  189. signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
  190. # unit channels EVTYPE_SNIP
  191. self.internal_unit_ids = {}
  192. self._waveforms_size = []
  193. self._waveforms_dtype = []
  194. unit_channels = []
  195. keep = info_channel_groups['TankEvType'] == EVTYPE_SNIP
  196. tsq = np.hstack(self._tsq)
  197. # If there is no chance the differet TSQ files will have different units,
  198. # then we can do tsq = self._tsq[0]
  199. for info in info_channel_groups[keep]:
  200. for c in range(info['NumChan']):
  201. chan_id = c + 1
  202. mask = (tsq['evtype'] == EVTYPE_SNIP) & \
  203. (tsq['evname'] == info['StoreName']) & \
  204. (tsq['channel'] == chan_id)
  205. unit_ids = np.unique(tsq[mask]['sortcode'])
  206. for unit_id in unit_ids:
  207. unit_index = len(unit_channels)
  208. self.internal_unit_ids[unit_index] = (info['StoreName'], chan_id, unit_id)
  209. unit_name = "ch{}#{}".format(chan_id, unit_id)
  210. wf_units = 'V'
  211. wf_gain = 1.
  212. wf_offset = 0.
  213. wf_left_sweep = info['NumPoints'] // 2
  214. wf_sampling_rate = info['SampleFreq']
  215. unit_channels.append((unit_name, '{}'.format(unit_id),
  216. wf_units, wf_gain, wf_offset,
  217. wf_left_sweep, wf_sampling_rate))
  218. self._waveforms_size.append(info['NumPoints'])
  219. self._waveforms_dtype.append(np.dtype(data_formats[info['DataFormat']]))
  220. unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
  221. # signal channels EVTYPE_STRON
  222. event_channels = []
  223. keep = info_channel_groups['TankEvType'] == EVTYPE_STRON
  224. for info in info_channel_groups[keep]:
  225. chan_name = info['StoreName']
  226. chan_id = 1
  227. event_channels.append((chan_name, chan_id, 'event'))
  228. event_channels = np.array(event_channels, dtype=_event_channel_dtype)
  229. # fill into header dict
  230. self.header = {}
  231. self.header['nb_block'] = 1
  232. self.header['nb_segment'] = [nb_segment]
  233. self.header['signal_channels'] = signal_channels
  234. self.header['unit_channels'] = unit_channels
  235. self.header['event_channels'] = event_channels
  236. # Annotations only standard ones:
  237. self._generate_minimal_annotations()
  238. def _block_count(self):
  239. return 1
  240. def _segment_count(self, block_index):
  241. return self.header['nb_segment'][block_index]
  242. def _segment_t_start(self, block_index, seg_index):
  243. return self._seg_t_starts[seg_index] - self._global_t_start
  244. def _segment_t_stop(self, block_index, seg_index):
  245. return self._seg_t_stops[seg_index] - self._global_t_start
  246. def _get_signal_size(self, block_index, seg_index, channel_indexes):
  247. group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
  248. size = self._sigs_lengths[seg_index][group_id]
  249. return size
  250. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  251. group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
  252. return self._sigs_t_start[seg_index][group_id] - self._global_t_start
  253. def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
  254. # check of channel_indexes is same group_id is done outside (BaseRawIO)
  255. # so first is identique to others
  256. group_id = self.header['signal_channels'][channel_indexes[0]]['group_id']
  257. if i_start is None:
  258. i_start = 0
  259. if i_stop is None:
  260. i_stop = self._sigs_lengths[seg_index][group_id]
  261. dt = self._sig_dtype_by_group[group_id]
  262. raw_signals = np.zeros((i_stop - i_start, len(channel_indexes)), dtype=dt)
  263. sample_per_chunk = self._sig_sample_per_chunk[group_id]
  264. bl0 = i_start // sample_per_chunk
  265. bl1 = int(np.ceil(i_stop / sample_per_chunk))
  266. chunk_nb_bytes = sample_per_chunk * dt.itemsize
  267. for c, channel_index in enumerate(channel_indexes):
  268. data_index = self._sigs_index[seg_index][channel_index]
  269. data_buf = self._sigs_data_buf[seg_index][channel_index]
  270. # loop over data blocks and get chunks
  271. ind = 0
  272. for bl in range(bl0, bl1):
  273. ind0 = data_index[bl]['offset']
  274. ind1 = ind0 + chunk_nb_bytes
  275. data = data_buf[ind0:ind1].view(dt)
  276. if bl == bl1 - 1:
  277. # right border
  278. # be careful that bl could be both bl0 and bl1!!
  279. border = data.size - (i_stop % sample_per_chunk)
  280. data = data[:-border]
  281. if bl == bl0:
  282. # left border
  283. border = i_start % sample_per_chunk
  284. data = data[border:]
  285. raw_signals[ind:data.size + ind, c] = data
  286. ind += data.size
  287. return raw_signals
  288. def _get_mask(self, tsq, seg_index, evtype, evname, chan_id, unit_id, t_start, t_stop):
  289. """Used inside spike and events methods"""
  290. mask = (tsq['evtype'] == evtype) & \
  291. (tsq['evname'] == evname) & \
  292. (tsq['channel'] == chan_id)
  293. if unit_id is not None:
  294. mask &= (tsq['sortcode'] == unit_id)
  295. if t_start is not None:
  296. mask &= tsq['timestamp'] >= (t_start + self._global_t_start)
  297. if t_stop is not None:
  298. mask &= tsq['timestamp'] <= (t_stop + self._global_t_start)
  299. return mask
  300. def _spike_count(self, block_index, seg_index, unit_index):
  301. store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
  302. tsq = self._tsq[seg_index]
  303. mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
  304. chan_id, unit_id, None, None)
  305. nb_spike = np.sum(mask)
  306. return nb_spike
  307. def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
  308. store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
  309. tsq = self._tsq[seg_index]
  310. mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
  311. chan_id, unit_id, t_start, t_stop)
  312. timestamps = tsq[mask]['timestamp']
  313. timestamps -= self._global_t_start
  314. return timestamps
  315. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  316. # already in s
  317. spike_times = spike_timestamps.astype(dtype)
  318. return spike_times
  319. def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
  320. store_name, chan_id, unit_id = self.internal_unit_ids[unit_index]
  321. tsq = self._tsq[seg_index]
  322. mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name,
  323. chan_id, unit_id, t_start, t_stop)
  324. nb_spike = np.sum(mask)
  325. data = self._tev_datas[seg_index]
  326. dt = self._waveforms_dtype[unit_index]
  327. nb_sample = self._waveforms_size[unit_index]
  328. waveforms = np.zeros((nb_spike, 1, nb_sample), dtype=dt)
  329. for i, e in enumerate(tsq[mask]):
  330. ind0 = e['offset']
  331. ind1 = ind0 + nb_sample * dt.itemsize
  332. waveforms[i, 0, :] = data[ind0:ind1].view(dt)
  333. return waveforms
  334. def _event_count(self, block_index, seg_index, event_channel_index):
  335. h = self.header['event_channels'][event_channel_index]
  336. store_name = h['name'].encode('ascii')
  337. tsq = self._tsq[seg_index]
  338. chan_id = 0
  339. mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None)
  340. nb_event = np.sum(mask)
  341. return nb_event
  342. def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
  343. h = self.header['event_channels'][event_channel_index]
  344. store_name = h['name'].encode('ascii')
  345. tsq = self._tsq[seg_index]
  346. chan_id = 0
  347. mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None)
  348. timestamps = tsq[mask]['timestamp']
  349. timestamps -= self._global_t_start
  350. labels = tsq[mask]['offset'].astype('U')
  351. durations = None
  352. # TODO if user demand event to epoch
  353. # with EVTYPE_STROFF=258
  354. # and so durations would be not None
  355. # it was not implemented in previous IO.
  356. return timestamps, durations, labels
  357. def _rescale_event_timestamp(self, event_timestamps, dtype):
  358. # already in s
  359. ev_times = event_timestamps.astype(dtype)
  360. return ev_times
  361. tbk_field_types = [
  362. ('StoreName', 'S4'),
  363. ('HeadName', 'S16'),
  364. ('Enabled', 'bool'),
  365. ('CircType', 'int'),
  366. ('NumChan', 'int'),
  367. ('StrobeMode', 'int'),
  368. ('TankEvType', 'int32'),
  369. ('NumPoints', 'int'),
  370. ('DataFormat', 'int'),
  371. ('SampleFreq', 'float64'),
  372. ]
  373. def read_tbk(tbk_filename):
  374. """
  375. Tbk contains some visible header in txt mode to describe
  376. channel group info.
  377. """
  378. with open(tbk_filename, mode='rb') as f:
  379. txt_header = f.read()
  380. infos = []
  381. for chan_grp_header in txt_header.split(b'[STOREHDRITEM]'):
  382. if chan_grp_header.startswith(b'[USERNOTEDELIMITER]'):
  383. break
  384. # parse into a dict
  385. info = OrderedDict()
  386. pattern = br'NAME=(\S+);TYPE=(\S+);VALUE=(\S+);'
  387. r = re.findall(pattern, chan_grp_header)
  388. for name, _type, value in r:
  389. info[name.decode('ascii')] = value
  390. infos.append(info)
  391. # and put into numpy
  392. info_channel_groups = np.zeros(len(infos), dtype=tbk_field_types)
  393. for i, info in enumerate(infos):
  394. for k, dt in tbk_field_types:
  395. v = np.dtype(dt).type(info[k])
  396. info_channel_groups[i][k] = v
  397. return info_channel_groups
  398. tsq_dtype = [
  399. ('size', 'int32'), # bytes 0-4
  400. ('evtype', 'int32'), # bytes 5-8
  401. ('evname', 'S4'), # bytes 9-12
  402. ('channel', 'uint16'), # bytes 13-14
  403. ('sortcode', 'uint16'), # bytes 15-16
  404. ('timestamp', 'float64'), # bytes 17-24
  405. ('offset', 'int64'), # bytes 25-32
  406. ('dataformat', 'int32'), # bytes 33-36
  407. ('frequency', 'float32'), # bytes 37-40
  408. ]
  409. EVTYPE_UNKNOWN = int('00000000', 16) # 0
  410. EVTYPE_STRON = int('00000101', 16) # 257
  411. EVTYPE_STROFF = int('00000102', 16) # 258
  412. EVTYPE_SCALAR = int('00000201', 16) # 513
  413. EVTYPE_STREAM = int('00008101', 16) # 33025
  414. EVTYPE_SNIP = int('00008201', 16) # 33281
  415. EVTYPE_MARK = int('00008801', 16) # 34817
  416. EVTYPE_HASDATA = int('00008000', 16) # 32768
  417. EVTYPE_UCF = int('00000010', 16) # 16
  418. EVTYPE_PHANTOM = int('00000020', 16) # 32
  419. EVTYPE_MASK = int('0000FF0F', 16) # 65295
  420. EVTYPE_INVALID_MASK = int('FFFF0000', 16) # 4294901760
  421. EVMARK_STARTBLOCK = int('0001', 16) # 1
  422. EVMARK_STOPBLOCK = int('0002', 16) # 2
  423. data_formats = {
  424. 0: 'float32',
  425. 1: 'int32',
  426. 2: 'int16',
  427. 3: 'int8',
  428. 4: 'float64',
  429. }
  430. def is_tdtblock(blockpath):
  431. """Is tha path a TDT block (=neo.Segment) ?"""
  432. file_ext = list()
  433. if os.path.isdir(blockpath):
  434. # for every file, get extension, convert to lowercase and append
  435. for file in os.listdir(blockpath):
  436. file_ext.append(os.path.splitext(file)[1].lower())
  437. file_ext = set(file_ext)
  438. tdt_ext = {'.tbk', '.tdx', '.tev', '.tsq'}
  439. if file_ext >= tdt_ext: # if containing all the necessary files
  440. return True
  441. else:
  442. return False