# -*- coding: utf-8 -*-
"""
Class for reading data from NeuroExplorer (.nex)

Note:
  * NeuroExplorer has introduced a new .nex5 file format
    with 64-bit timestamps. This is NOT implemented here.
    If someone has a file in that new format we could also
    integrate it in neo.
  * NeuroExplorer now provides its own python class for
    reading/writing nex and nex5. This could be useful
    for testing this class.

Porting NeuroExplorerIO to NeuroExplorerRawIO has some
limitations because in NeuroExplorer signals can have different
sampling rates and shapes. So NeuroExplorerRawIO can read only
one channel at a time.

Documentation for dev:
http://www.neuroexplorer.com/downloadspage/

Author: Samuel Garcia, luc estebanez, mark hollenbeck
"""
from __future__ import print_function, division, absolute_import
# from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3

from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
                        _event_channel_dtype)

import numpy as np

from collections import OrderedDict
import datetime


class NeuroExplorerRawIO(BaseRawIO):
    extensions = ['nex']
    rawmode = 'one-file'

    def __init__(self, filename=''):
        BaseRawIO.__init__(self)
        self.filename = filename

    def _source_name(self):
        return self.filename

    def _parse_header(self):
        with open(self.filename, 'rb') as fid:
            self.global_header = read_as_dict(fid, GlobalHeader, offset=0)
            offset = 544
            self._entity_headers = []
            for i in range(self.global_header['nvar']):
                self._entity_headers.append(read_as_dict(
                    fid, EntityHeader, offset=offset + i * 208))

        self._memmap = np.memmap(self.filename, dtype='u1', mode='r')

        self._sig_lengths = []
        self._sig_t_starts = []
        sig_channels = []
        unit_channels = []
        event_channels = []
        for i in range(self.global_header['nvar']):
            entity_header = self._entity_headers[i]
            name = entity_header['name']
            _id = i
            if entity_header['type'] == 0:  # Unit
                unit_channels.append((name, _id, '', 0, 0, 0, 0))

            elif entity_header['type'] == 1:  # Event
                event_channels.append((name, _id, 'event'))

            elif entity_header['type'] == 2:  # interval = Epoch
                event_channels.append((name, _id, 'epoch'))

            elif entity_header['type'] == 3:  # spiketrain and waveforms
                wf_units = 'mV'
                wf_gain = entity_header['ADtoMV']
                wf_offset = entity_header['MVOffset']
                wf_left_sweep = 0
                wf_sampling_rate = entity_header['WFrequency']
                unit_channels.append((name, _id, wf_units, wf_gain, wf_offset,
                                      wf_left_sweep, wf_sampling_rate))

            elif entity_header['type'] == 4:
                # popvectors
                pass

            elif entity_header['type'] == 5:  # Signals
                units = 'mV'
                sampling_rate = entity_header['WFrequency']
                dtype = 'int16'
                gain = entity_header['ADtoMV']
                offset = entity_header['MVOffset']
                group_id = 0
                sig_channels.append((name, _id, sampling_rate, dtype, units,
                                     gain, offset, group_id))
                self._sig_lengths.append(entity_header['NPointsWave'])
                # sig t_start is the first timestamp of the data block
                offset = entity_header['offset']
                timestamps0 = self._memmap[offset:offset + 4].view('int32')
                t_start = timestamps0[0] / self.global_header['freq']
                self._sig_t_starts.append(t_start)

            elif entity_header['type'] == 6:  # Markers
                event_channels.append((name, _id, 'event'))

        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # each signal channel has its own group, which forces reading
        # them one by one
        sig_channels['group_id'] = np.arange(sig_channels.size)

        # fill into header dict
        self.header = {}
        self.header['nb_block'] = 1
        self.header['nb_segment'] = [1]
        self.header['signal_channels'] = sig_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels

        # Annotations
        self._generate_minimal_annotations()
        bl_annotations = self.raw_annotations['blocks'][0]
        seg_annotations = bl_annotations['segments'][0]
        for d in (bl_annotations, seg_annotations):
            d['neuroexplorer_version'] = self.global_header['version']
            d['comment'] = self.global_header['comment']

    def _segment_t_start(self, block_index, seg_index):
        t_start = self.global_header['tbeg'] / self.global_header['freq']
        return t_start

    def _segment_t_stop(self, block_index, seg_index):
        t_stop = self.global_header['tend'] / self.global_header['freq']
        return t_stop

    def _get_signal_size(self, block_index, seg_index, channel_indexes):
        assert len(channel_indexes) == 1, 'only one channel at a time'
        return self._sig_lengths[channel_indexes[0]]

    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        assert len(channel_indexes) == 1, 'only one channel at a time'
        return self._sig_t_starts[channel_indexes[0]]

    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
        assert len(channel_indexes) == 1, 'only one channel at a time'
        channel_index = channel_indexes[0]
        entity_index = int(self.header['signal_channels'][channel_index]['id'])
        entity_header = self._entity_headers[entity_index]
        n = entity_header['n']
        nb_sample = entity_header['NPointsWave']

        # data block layout for a continuous variable:
        #   n int32 timestamps, then n int32 fragment start indexes,
        #   then NPointsWave int16 samples
        # offset = entity_header['offset']
        # timestamps = self._memmap[offset:offset + n * 4].view('int32')
        # offset2 = entity_header['offset'] + n * 4
        # fragment_starts = self._memmap[offset2:offset2 + n * 4].view('int32')
        offset3 = entity_header['offset'] + n * 4 + n * 4
        raw_signal = self._memmap[offset3:offset3 + nb_sample * 2].view('int16')
        raw_signal = raw_signal[slice(i_start, i_stop), None]  # 2D for compliance
        return raw_signal

    def _spike_count(self, block_index, seg_index, unit_index):
        entity_index = int(self.header['unit_channels'][unit_index]['id'])
        entity_header = self._entity_headers[entity_index]
        nb_spike = entity_header['n']
        return nb_spike

    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        entity_index = int(self.header['unit_channels'][unit_index]['id'])
        entity_header = self._entity_headers[entity_index]
        n = entity_header['n']
        offset = entity_header['offset']
        timestamps = self._memmap[offset:offset + n * 4].view('int32')
        if t_start is not None:
            keep = timestamps >= int(t_start * self.global_header['freq'])
            timestamps = timestamps[keep]
        if t_stop is not None:
            keep = timestamps <= int(t_stop * self.global_header['freq'])
            timestamps = timestamps[keep]
        return timestamps

    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        spike_times = spike_timestamps.astype(dtype)
        spike_times /= self.global_header['freq']
        return spike_times

    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        entity_index = int(self.header['unit_channels'][unit_index]['id'])
        entity_header = self._entity_headers[entity_index]
        if entity_header['type'] == 0:
            return None
        assert entity_header['type'] == 3

        n = entity_header['n']
        width = entity_header['NPointsWave']
        # waveform samples (int16) start after the n int32 timestamps
        offset = entity_header['offset'] + n * 4
        waveforms = self._memmap[offset:offset + n * 2 * width].view('int16')
        waveforms = waveforms.reshape(n, 1, width)
        return waveforms

    def _event_count(self, block_index, seg_index, event_channel_index):
        entity_index = int(self.header['event_channels'][event_channel_index]['id'])
        entity_header = self._entity_headers[entity_index]
        nb_event = entity_header['n']
        return nb_event

    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        entity_index = int(self.header['event_channels'][event_channel_index]['id'])
        entity_header = self._entity_headers[entity_index]
        n = entity_header['n']
        offset = entity_header['offset']
        timestamps = self._memmap[offset:offset + n * 4].view('int32')

        if t_start is None:
            i_start = None
        else:
            i_start = np.searchsorted(timestamps, int(t_start * self.global_header['freq']))
        if t_stop is None:
            i_stop = None
        else:
            i_stop = np.searchsorted(timestamps, int(t_stop * self.global_header['freq']))
        keep = slice(i_start, i_stop)
        timestamps = timestamps[keep]

        if entity_header['type'] == 1:  # Event
            durations = None
            labels = np.array([''] * timestamps.size, dtype='U')
        elif entity_header['type'] == 2:  # Epoch
            # stop timestamps are stored just after the start timestamps
            offset2 = offset + n * 4
            stop_timestamps = self._memmap[offset2:offset2 + n * 4].view('int32')
            durations = stop_timestamps[keep] - timestamps
            labels = np.array([''] * timestamps.size, dtype='U')
        elif entity_header['type'] == 6:  # Marker
            durations = None
            # skip 64 bytes (marker field name), then read n labels of
            # MarkerLength characters each
            offset2 = offset + n * 4 + 64
            s = entity_header['MarkerLength']
            labels = self._memmap[offset2:offset2 + s * n].view('S' + str(s))
            labels = labels[keep].astype('U')
        return timestamps, durations, labels

    def _rescale_event_timestamp(self, event_timestamps, dtype):
        event_times = event_timestamps.astype(dtype)
        event_times /= self.global_header['freq']
        return event_times

    def _rescale_epoch_duration(self, raw_duration, dtype):
        durations = raw_duration.astype(dtype)
        durations /= self.global_header['freq']
        return durations


def read_as_dict(fid, dtype, offset=None):
    """
    Given a file descriptor and a numpy.dtype describing the binary struct,
    return the fields as an OrderedDict. String fields are stripped of
    null bytes and decoded.
    """
    if offset is not None:
        fid.seek(offset)
    dt = np.dtype(dtype)
    h = np.frombuffer(fid.read(dt.itemsize), dt)[0]
    info = OrderedDict()
    for k in dt.names:
        v = h[k]
        if dt[k].kind == 'S':
            v = v.replace(b'\x00', b'')
            v = v.decode('utf8')
        info[k] = v
    return info
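

# Example sketch for read_as_dict (hypothetical file name), mirroring the call
# made in _parse_header with the GlobalHeader dtype defined below:
#
#     with open('some_file.nex', 'rb') as fid:
#         global_header = read_as_dict(fid, GlobalHeader, offset=0)
#         print(global_header['signature'], global_header['freq'])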


GlobalHeader = [
    ('signature', 'S4'),
    ('version', 'int32'),
    ('comment', 'S256'),
    ('freq', 'float64'),
    ('tbeg', 'int32'),
    ('tend', 'int32'),
    ('nvar', 'int32'),
]


EntityHeader = [
    ('type', 'int32'),
    ('varVersion', 'int32'),
    ('name', 'S64'),
    ('offset', 'int32'),
    ('n', 'int32'),
    ('WireNumber', 'int32'),
    ('UnitNumber', 'int32'),
    ('Gain', 'int32'),
    ('Filter', 'int32'),
    ('XPos', 'float64'),
    ('YPos', 'float64'),
    ('WFrequency', 'float64'),
    ('ADtoMV', 'float64'),
    ('NPointsWave', 'int32'),
    ('NMarkers', 'int32'),
    ('MarkerLength', 'int32'),
    ('MVOffset', 'float64'),
    ('dummy', 'S60'),
]
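
# Note: the EntityHeader fields above add up to 208 bytes, matching the stride
# used in _parse_header (entity headers start at offset 544 and are read every
# 208 bytes).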


MarkerHeader = [
    ('type', 'int32'),
    ('varVersion', 'int32'),
    ('name', 'S64'),
    ('offset', 'int32'),
    ('n', 'int32'),
    ('WireNumber', 'int32'),
    ('UnitNumber', 'int32'),
    ('Gain', 'int32'),
    ('Filter', 'int32'),
]