# -*- coding: utf-8 -*-
"""
Class for reading data from the Tucker Davis TTank format.

Terminology:
TDT stores data in tanks (which are actually directories), and a tank holds
blocks (sub-directories).
A tank corresponds to a neo.Block and a TDT block corresponds to a neo.Segment.
Note that the name "Block" is ambiguous because it does not refer to the same
thing in TDT terminology and in neo.

Depends on:

Supported: Read

Author: sgarcia
"""
  14. import os
  15. import struct
  16. import sys
  17. import numpy as np
  18. import quantities as pq
  19. import itertools
  20. from neo.io.baseio import BaseIO
  21. from neo.core import Block, Segment, AnalogSignal, SpikeTrain, Event
  22. PY3K = (sys.version_info[0] == 3)
  23. if not PY3K:
  24. zip = itertools.izip
  25. def get_chunks(sizes, offsets, big_array):
  26. # offsets are octect count
  27. # sizes are not!!
  28. # so need this (I really do not knwo why...):
  29. sizes = (sizes -10) * 4 #
  30. all = np.concatenate([ big_array[o:o+s] for s, o in zip(sizes, offsets) ])
  31. return all
  32. class TdtIO(BaseIO):
  33. """
  34. Class for reading data from from Tucker Davis TTank format.
  35. Usage:
  36. >>> from neo import io
  37. >>> r = io.TdtIO(dirname='aep_05')
  38. >>> bl = r.read_block(lazy=False, cascade=True)
  39. >>> print bl.segments
  40. [<neo.core.segment.Segment object at 0x1060a4d10>]
  41. >>> print bl.segments[0].analogsignals
  42. [<AnalogSignal(array([ 2.18811035, 2.19726562, 2.21252441, ...,
  43. 1.33056641, 1.3458252 , 1.3671875 ], dtype=float32) * pA,
  44. [0.0 s, 191.2832 s], sampling rate: 10000.0 Hz)>]
  45. >>> print bl.segments[0].events
  46. []
  47. """
  48. is_readable = True
  49. is_writable = False
  50. supported_objects = [Block, Segment , AnalogSignal, Event]
  51. readable_objects = [Block, Segment]
  52. writeable_objects = []
  53. has_header = False
  54. is_streameable = False
  55. read_params = {
  56. Block : [],
  57. Segment : []
  58. }
  59. write_params = None
  60. name = 'TDT'
  61. extensions = [ ]
  62. mode = 'dir'
  63. def __init__(self , dirname=None) :
  64. """
  65. **Arguments**
  66. Arguments:
  67. dirname: path of the TDT tank (a directory)
  68. """
  69. BaseIO.__init__(self)
  70. self.dirname = dirname
  71. if self.dirname.endswith('/'):
  72. self.dirname = self.dirname[:-1]
  73. def read_segment(self, blockname=None, lazy=False, cascade=True, sortname=''):
  74. """
  75. Read a single segment from the tank. Note that TDT blocks are Neo
  76. segments, and TDT tanks are Neo blocks, so here the 'blockname' argument
  77. refers to the TDT block's name, which will be the Neo segment name.
  78. 'sortname' is used to specify the external sortcode generated by offline spike sorting.
  79. if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block,
  80. which stores the sortcode for every spike; defaults to '', which uses the original online sort
  81. """
  82. if not blockname:
  83. blockname = os.listdir(self.dirname)[0]
  84. if blockname == 'TempBlk': return None
  85. if not self.is_tdtblock(blockname): return None # if not a tdt block
  86. subdir = os.path.join(self.dirname, blockname)
  87. if not os.path.isdir(subdir): return None
  88. seg = Segment(name=blockname)
  89. tankname = os.path.basename(self.dirname)
  90. #TSQ is the global index
  91. tsq_filename = os.path.join(subdir, tankname+'_'+blockname+'.tsq')
  92. dt = [('size','int32'),
  93. ('evtype','int32'),
  94. ('code','S4'),
  95. ('channel','uint16'),
  96. ('sortcode','uint16'),
  97. ('timestamp','float64'),
  98. ('eventoffset','int64'),
  99. ('dataformat','int32'),
  100. ('frequency','float32'),
  101. ]
  102. tsq = np.fromfile(tsq_filename, dtype=dt)
  103. #0x8801: 'EVTYPE_MARK' give the global_start
  104. global_t_start = tsq[tsq['evtype']==0x8801]['timestamp'][0]
  105. #TEV is the old data file
  106. try:
  107. tev_filename = os.path.join(subdir, tankname+'_'+blockname+'.tev')
  108. #tev_array = np.memmap(tev_filename, mode = 'r', dtype = 'uint8') # if memory problem use this instead
  109. tev_array = np.fromfile(tev_filename, dtype='uint8')
  110. except IOError:
  111. tev_filename = None
  112. #if there exists an external sortcode in ./sort/[sortname]/*.SortResult (generated after offline sortting)
  113. sortresult_filename = None
  114. if sortname is not '':
  115. try:
  116. for file in os.listdir(os.path.join(subdir, 'sort', sortname)):
  117. if file.endswith(".SortResult"):
  118. sortresult_filename = os.path.join(subdir, 'sort', sortname, file)
  119. # get new sortcode
  120. newsorcode = np.fromfile(sortresult_filename,'int8')[1024:] # the first 1024 byte is file header
  121. # update the sort code with the info from this file
  122. tsq['sortcode'][1:-1]=newsorcode
  123. # print('sortcode updated')
  124. break
  125. except OSError:
  126. sortresult_filename = None
  127. except IOError:
  128. sortresult_filename = None
  129. for type_code, type_label in tdt_event_type:
  130. mask1 = tsq['evtype']==type_code
  131. codes = np.unique(tsq[mask1]['code'])
  132. for code in codes:
  133. mask2 = mask1 & (tsq['code']==code)
  134. channels = np.unique(tsq[mask2]['channel'])
  135. for channel in channels:
  136. mask3 = mask2 & (tsq['channel']==channel)
  137. if type_label in ['EVTYPE_STRON', 'EVTYPE_STROFF']:
  138. if lazy:
  139. times = [ ]*pq.s
  140. labels = np.array([ ], dtype='S')
  141. else:
  142. times = (tsq[mask3]['timestamp'] - global_t_start) * pq.s
  143. labels = tsq[mask3]['eventoffset'].view('float64').astype('S')
  144. ea = Event(times=times,
  145. name=str(code),
  146. channel_index=int(channel),
  147. labels=labels)
  148. if lazy:
  149. ea.lazy_shape = np.sum(mask3)
  150. seg.events.append(ea)
  151. elif type_label == 'EVTYPE_SNIP':
  152. sortcodes = np.unique(tsq[mask3]['sortcode'])
  153. for sortcode in sortcodes:
  154. mask4 = mask3 & (tsq['sortcode']==sortcode)
  155. nb_spike = np.sum(mask4)
  156. sr = tsq[mask4]['frequency'][0]
  157. waveformsize = tsq[mask4]['size'][0]-10
  158. if lazy:
  159. times = [ ]*pq.s
  160. waveforms = None
  161. else:
  162. times = (tsq[mask4]['timestamp'] - global_t_start) * pq.s
  163. dt = np.dtype(data_formats[ tsq[mask3]['dataformat'][0]])
  164. waveforms = get_chunks(tsq[mask4]['size'],tsq[mask4]['eventoffset'], tev_array).view(dt)
  165. waveforms = waveforms.reshape(nb_spike, -1, waveformsize)
  166. waveforms = waveforms * pq.mV
  167. if nb_spike > 0:
  168. # t_start = (tsq['timestamp'][0] - global_t_start) * pq.s # this hould work but not
  169. t_start = 0 *pq.s
  170. t_stop = (tsq['timestamp'][-1] - global_t_start) * pq.s
  171. else:
  172. t_start = 0 *pq.s
  173. t_stop = 0 *pq.s
  174. st = SpikeTrain(times = times,
  175. name = 'Chan{0} Code{1}'.format(channel,sortcode),
  176. t_start = t_start,
  177. t_stop = t_stop,
  178. waveforms = waveforms,
  179. left_sweep = waveformsize/2./sr * pq.s,
  180. sampling_rate = sr * pq.Hz,
  181. )
  182. st.annotate(channel_index=channel)
  183. if lazy:
  184. st.lazy_shape = nb_spike
  185. seg.spiketrains.append(st)
  186. elif type_label == 'EVTYPE_STREAM':
  187. dt = np.dtype(data_formats[ tsq[mask3]['dataformat'][0]])
  188. shape = np.sum(tsq[mask3]['size']-10)
  189. sr = tsq[mask3]['frequency'][0]
  190. if lazy:
  191. signal = [ ]
  192. else:
  193. if PY3K:
  194. signame = code.decode('ascii')
  195. else:
  196. signame = code
  197. sev_filename = os.path.join(subdir, tankname+'_'+blockname+'_'+signame+'_ch'+str(channel)+'.sev')
  198. try:
  199. #sig_array = np.memmap(sev_filename, mode = 'r', dtype = 'uint8') # if memory problem use this instead
  200. sig_array = np.fromfile(sev_filename, dtype='uint8')
  201. except IOError:
  202. sig_array = tev_array
  203. signal = get_chunks(tsq[mask3]['size'],tsq[mask3]['eventoffset'], sig_array).view(dt)
  204. anasig = AnalogSignal(signal = signal* pq.V,
  205. name = '{0} {1}'.format(code, channel),
  206. sampling_rate = sr * pq.Hz,
  207. t_start = (tsq[mask3]['timestamp'][0] - global_t_start) * pq.s,
  208. channel_index = int(channel)
  209. )
  210. if lazy:
  211. anasig.lazy_shape = shape
  212. seg.analogsignals.append(anasig)
  213. return seg
  214. def read_block(self, lazy=False, cascade=True, sortname=''):
  215. bl = Block()
  216. tankname = os.path.basename(self.dirname)
  217. bl.file_origin = tankname
  218. if not cascade : return bl
  219. for blockname in os.listdir(self.dirname):
  220. if self.is_tdtblock(blockname): # if the folder is a tdt block
  221. seg = self.read_segment(blockname, lazy, cascade, sortname)
  222. bl.segments.append(seg)
  223. bl.create_many_to_one_relationship()
  224. return bl
  225. # to determine if this folder is a TDT block, based on the extension of the files inside it
  226. # to deal with unexpected files in the tank, e.g. .DS_Store on Mac machines
  227. def is_tdtblock(self, blockname):
  228. file_ext = list()
  229. blockpath = os.path.join(self.dirname, blockname) # get block path
  230. if os.path.isdir(blockpath):
  231. for file in os.listdir( blockpath ): # for every file, get extension, convert to lowercase and append
  232. file_ext.append( os.path.splitext( file )[1].lower() )
  233. file_ext = set(file_ext)
  234. tdt_ext = set(['.tbk', '.tdx', '.tev', '.tsq'])
  235. if file_ext >= tdt_ext: # if containing all the necessary files
  236. return True
  237. else:
  238. return False
  239. tdt_event_type = [
  240. #(0x0,'EVTYPE_UNKNOWN'),
  241. (0x101, 'EVTYPE_STRON'),
  242. (0x102,'EVTYPE_STROFF'),
  243. #(0x201,'EVTYPE_SCALER'),
  244. (0x8101, 'EVTYPE_STREAM'),
  245. (0x8201, 'EVTYPE_SNIP'),
  246. #(0x8801, 'EVTYPE_MARK'),
  247. ]
  248. data_formats = {
  249. 0 : np.float32,
  250. 1 : np.int32,
  251. 2 : np.int16,
  252. 3 : np.int8,
  253. 4 : np.float64,
  254. }