neuroexplorerrawio.py 12 KB


"""
Class for reading data from NeuroExplorer (.nex) files.

Note:
  * NeuroExplorer has introduced a new .nex5 file format
    with 64-bit timestamps. This is NOT implemented here.
    If someone has files in that new format, it could also
    be integrated in neo.
  * NeuroExplorer now provides its own Python class for
    reading/writing .nex and .nex5 files. This could be useful
    for testing this class.

Porting NeuroExplorerIO to NeuroExplorerRawIO has some
limitations, because in NeuroExplorer signals can have different
sampling rates and lengths. So NeuroExplorerRawIO can read only one
signal channel at a time.

Documentation for developers:
http://www.neuroexplorer.com/downloadspage/

Author: Samuel Garcia, Luc Estebanez, Mark Hollenbeck
"""
  19. # from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3
  20. from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
  21. _event_channel_dtype)
  22. import numpy as np
  23. from collections import OrderedDict
  24. import datetime
  25. class NeuroExplorerRawIO(BaseRawIO):
  26. extensions = ['nex']
  27. rawmode = 'one-file'
  28. def __init__(self, filename=''):
  29. BaseRawIO.__init__(self)
  30. self.filename = filename
  31. def _source_name(self):
  32. return self.filename
  33. def _parse_header(self):
  34. with open(self.filename, 'rb') as fid:
  35. self.global_header = read_as_dict(fid, GlobalHeader, offset=0)
  36. offset = 544
  37. self._entity_headers = []
  38. for i in range(self.global_header['nvar']):
  39. self._entity_headers.append(read_as_dict(
  40. fid, EntityHeader, offset=offset + i * 208))
  41. self._memmap = np.memmap(self.filename, dtype='u1', mode='r')
  42. self._sig_lengths = []
  43. self._sig_t_starts = []
  44. sig_channels = []
  45. unit_channels = []
  46. event_channels = []
  47. for i in range(self.global_header['nvar']):
  48. entity_header = self._entity_headers[i]
  49. name = entity_header['name']
  50. _id = i
  51. if entity_header['type'] == 0: # Unit
  52. unit_channels.append((name, _id, '', 0, 0, 0, 0))
  53. elif entity_header['type'] == 1: # Event
  54. event_channels.append((name, _id, 'event'))
  55. elif entity_header['type'] == 2: # interval = Epoch
  56. event_channels.append((name, _id, 'epoch'))
  57. elif entity_header['type'] == 3: # spiketrain and wavefoms
  58. wf_units = 'mV'
  59. wf_gain = entity_header['ADtoMV']
  60. wf_offset = entity_header['MVOffset']
  61. wf_left_sweep = 0
  62. wf_sampling_rate = entity_header['WFrequency']
  63. unit_channels.append((name, _id, wf_units, wf_gain, wf_offset,
  64. wf_left_sweep, wf_sampling_rate))
  65. elif entity_header['type'] == 4:
  66. # popvectors
  67. pass
  68. if entity_header['type'] == 5: # Signals
  69. units = 'mV'
  70. sampling_rate = entity_header['WFrequency']
  71. dtype = 'int16'
  72. gain = entity_header['ADtoMV']
  73. offset = entity_header['MVOffset']
  74. group_id = 0
  75. sig_channels.append((name, _id, sampling_rate, dtype, units,
  76. gain, offset, group_id))
  77. self._sig_lengths.append(entity_header['NPointsWave'])
  78. # sig t_start is the first timestamp if datablock
  79. offset = entity_header['offset']
  80. timestamps0 = self._memmap[offset:offset + 4].view('int32')
  81. t_start = timestamps0[0] / self.global_header['freq']
  82. self._sig_t_starts.append(t_start)
  83. elif entity_header['type'] == 6: # Markers
  84. event_channels.append((name, _id, 'event'))
  85. sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
  86. unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
  87. event_channels = np.array(event_channels, dtype=_event_channel_dtype)
  88. # each signal channel have a dierent groups that force reading
  89. # them one by one
  90. sig_channels['group_id'] = np.arange(sig_channels.size)
  91. # fill into header dict
  92. self.header = {}
  93. self.header['nb_block'] = 1
  94. self.header['nb_segment'] = [1]
  95. self.header['signal_channels'] = sig_channels
  96. self.header['unit_channels'] = unit_channels
  97. self.header['event_channels'] = event_channels
  98. # Annotations
  99. self._generate_minimal_annotations()
  100. bl_annotations = self.raw_annotations['blocks'][0]
  101. seg_annotations = bl_annotations['segments'][0]
  102. for d in (bl_annotations, seg_annotations):
  103. d['neuroexplorer_version'] = self.global_header['version']
  104. d['comment'] = self.global_header['comment']
  105. def _segment_t_start(self, block_index, seg_index):
  106. t_start = self.global_header['tbeg'] / self.global_header['freq']
  107. return t_start
  108. def _segment_t_stop(self, block_index, seg_index):
  109. t_stop = self.global_header['tend'] / self.global_header['freq']
  110. return t_stop
  111. def _get_signal_size(self, block_index, seg_index, channel_indexes):
  112. assert len(channel_indexes) == 1, 'only one channel by one channel'
  113. return self._sig_lengths[channel_indexes[0]]
  114. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  115. assert len(channel_indexes) == 1, 'only one channel by one channel'
  116. return self._sig_t_starts[channel_indexes[0]]
  117. def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
  118. assert len(channel_indexes) == 1, 'only one channel by one channel'
  119. channel_index = channel_indexes[0]
  120. entity_index = int(self.header['signal_channels'][channel_index]['id'])
  121. entity_header = self._entity_headers[entity_index]
  122. n = entity_header['n']
  123. nb_sample = entity_header['NPointsWave']
  124. # offset = entity_header['offset']
  125. # timestamps = self._memmap[offset:offset+n*4].view('int32')
  126. # offset2 = entity_header['offset'] + n*4
  127. # fragment_starts = self._memmap[offset2:offset2+n*4].view('int32')
  128. offset3 = entity_header['offset'] + n * 4 + n * 4
  129. raw_signal = self._memmap[offset3:offset3 + nb_sample * 2].view('int16')
  130. raw_signal = raw_signal[slice(i_start, i_stop), None] # 2D for compliance
  131. return raw_signal
  132. def _spike_count(self, block_index, seg_index, unit_index):
  133. entity_index = int(self.header['unit_channels'][unit_index]['id'])
  134. entity_header = self._entity_headers[entity_index]
  135. nb_spike = entity_header['n']
  136. return nb_spike
  137. def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
  138. entity_index = int(self.header['unit_channels'][unit_index]['id'])
  139. entity_header = self._entity_headers[entity_index]
  140. n = entity_header['n']
  141. offset = entity_header['offset']
  142. timestamps = self._memmap[offset:offset + n * 4].view('int32')
  143. if t_start is not None:
  144. keep = timestamps >= int(t_start * self.global_header['freq'])
  145. timestamps = timestamps[keep]
  146. if t_stop is not None:
  147. keep = timestamps <= int(t_stop * self.global_header['freq'])
  148. timestamps = timestamps[keep]
  149. return timestamps
  150. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  151. spike_times = spike_timestamps.astype(dtype)
  152. spike_times /= self.global_header['freq']
  153. return spike_times
  154. def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
  155. entity_index = int(self.header['unit_channels'][unit_index]['id'])
  156. entity_header = self._entity_headers[entity_index]
  157. if entity_header['type'] == 0:
  158. return None
  159. assert entity_header['type'] == 3
  160. n = entity_header['n']
  161. width = entity_header['NPointsWave']
  162. offset = entity_header['offset'] + n * 2
  163. waveforms = self._memmap[offset:offset + n * 2 * width].view('int16')
  164. waveforms = waveforms.reshape(n, 1, width)
  165. return waveforms
  166. def _event_count(self, block_index, seg_index, event_channel_index):
  167. entity_index = int(self.header['event_channels'][event_channel_index]['id'])
  168. entity_header = self._entity_headers[entity_index]
  169. nb_event = entity_header['n']
  170. return nb_event
  171. def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
  172. entity_index = int(self.header['event_channels'][event_channel_index]['id'])
  173. entity_header = self._entity_headers[entity_index]
  174. n = entity_header['n']
  175. offset = entity_header['offset']
  176. timestamps = self._memmap[offset:offset + n * 4].view('int32')
  177. if t_start is None:
  178. i_start = None
  179. else:
  180. i_start = np.searchsorted(timestamps, int(t_start * self.global_header['freq']))
  181. if t_stop is None:
  182. i_stop = None
  183. else:
  184. i_stop = np.searchsorted(timestamps, int(t_stop * self.global_header['freq']))
  185. keep = slice(i_start, i_stop)
  186. timestamps = timestamps[keep]
  187. if entity_header['type'] == 1: # Event
  188. durations = None
  189. labels = np.array([''] * timestamps.size, dtype='U')
  190. elif entity_header['type'] == 2: # Epoch
  191. offset2 = offset + n * 4
  192. stop_timestamps = self._memmap[offset2:offset2 + n * 4].view('int32')
  193. durations = stop_timestamps[keep] - timestamps
  194. labels = np.array([''] * timestamps.size, dtype='U')
  195. elif entity_header['type'] == 6: # Marker
  196. durations = None
  197. offset2 = offset + n * 4 + 64
  198. s = entity_header['MarkerLength']
  199. labels = self._memmap[offset2:offset2 + s * n].view('S' + str(s))
  200. labels = labels[keep].astype('U')
  201. return timestamps, durations, labels
  202. def _rescale_event_timestamp(self, event_timestamps, dtype):
  203. event_times = event_timestamps.astype(dtype)
  204. event_times /= self.global_header['freq']
  205. return event_times
  206. def _rescale_epoch_duration(self, raw_duration, dtype):
  207. durations = raw_duration.astype(dtype)
  208. durations /= self.global_header['freq']
  209. return durations
  210. def read_as_dict(fid, dtype, offset=None):
  211. """
  212. Given a file descriptor
  213. and a numpy.dtype of the binary struct return a dict.
  214. Make conversion for strings.
  215. """
  216. if offset is not None:
  217. fid.seek(offset)
  218. dt = np.dtype(dtype)
  219. h = np.frombuffer(fid.read(dt.itemsize), dt)[0]
  220. info = OrderedDict()
  221. for k in dt.names:
  222. v = h[k]
  223. if dt[k].kind == 'S':
  224. v = v.replace(b'\x00', b'')
  225. v = v.decode('utf8')
  226. info[k] = v
  227. return info
  228. GlobalHeader = [
  229. ('signature', 'S4'),
  230. ('version', 'int32'),
  231. ('comment', 'S256'),
  232. ('freq', 'float64'),
  233. ('tbeg', 'int32'),
  234. ('tend', 'int32'),
  235. ('nvar', 'int32'),
  236. ]
  237. EntityHeader = [
  238. ('type', 'int32'),
  239. ('varVersion', 'int32'),
  240. ('name', 'S64'),
  241. ('offset', 'int32'),
  242. ('n', 'int32'),
  243. ('WireNumber', 'int32'),
  244. ('UnitNumber', 'int32'),
  245. ('Gain', 'int32'),
  246. ('Filter', 'int32'),
  247. ('XPos', 'float64'),
  248. ('YPos', 'float64'),
  249. ('WFrequency', 'float64'),
  250. ('ADtoMV', 'float64'),
  251. ('NPointsWave', 'int32'),
  252. ('NMarkers', 'int32'),
  253. ('MarkerLength', 'int32'),
  254. ('MVOffset', 'float64'),
  255. ('dummy', 'S60'),
  256. ]
  257. MarkerHeader = [
  258. ('type', 'int32'),
  259. ('varVersion', 'int32'),
  260. ('name', 'S64'),
  261. ('offset', 'int32'),
  262. ('n', 'int32'),
  263. ('WireNumber', 'int32'),
  264. ('UnitNumber', 'int32'),
  265. ('Gain', 'int32'),
  266. ('Filter', 'int32'),
  267. ]