# bci2000rawio.py
  1. """
  2. BCI2000RawIO is a class to read BCI2000 .dat files.
  3. https://www.bci2000.org/mediawiki/index.php/Technical_Reference:BCI2000_File_Format
  4. """
from .baserawio import BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype

import numpy as np
import re

try:
    from urllib.parse import unquote
except ImportError:
    from urllib import url2pathname as unquote
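
# A minimal usage sketch (illustrative only; 'myrecording.dat' is a hypothetical
# file name). Like any BaseRawIO subclass, the reader must have parse_header()
# called before data access:
#
#     reader = BCI2000RawIO(filename='myrecording.dat')
#     reader.parse_header()
#     raw = reader.get_analogsignal_chunk(block_index=0, seg_index=0,
#                                         i_start=0, i_stop=1000,
#                                         channel_indexes=None)
#     scaled = reader.rescale_signal_raw_to_float(raw, dtype='float64')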


class BCI2000RawIO(BaseRawIO):
    """
    Class for reading data from a BCI2000 .dat file, either version 1.0 or 1.1.
    """
    extensions = ['dat']
    rawmode = 'one-file'

    def __init__(self, filename=''):
        BaseRawIO.__init__(self)
        self.filename = filename
        self._my_events = None

    def _source_name(self):
        return self.filename

    def _parse_header(self):
        file_info, state_defs, param_defs = parse_bci2000_header(self.filename)

        self.header = {}
        self.header['nb_block'] = 1
        self.header['nb_segment'] = [1]

        sig_channels = []
        for chan_ix in range(file_info['SourceCh']):
            ch_name = param_defs['ChannelNames']['value'][chan_ix] \
                if 'ChannelNames' in param_defs \
                and param_defs['ChannelNames']['value'] is not np.nan \
                else 'ch' + str(chan_ix)
            chan_id = chan_ix + 1
            sr = param_defs['SamplingRate']['value']  # Hz
            dtype = file_info['DataFormat']
            units = 'uV'
            gain = param_defs['SourceChGain']['value'][chan_ix]
            if isinstance(gain, str):
                r = re.findall(r'(\d+)(\D+)', gain)
                # Some files have strange units attached to the gain;
                # in that case the unit part is ignored.
                if len(r) == 1:
                    gain = r[0][0]
                gain = float(gain)
            offset = param_defs['SourceChOffset']['value'][chan_ix]
            if isinstance(offset, str):
                offset = float(offset)
            group_id = 0
            sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, group_id))
        self.header['signal_channels'] = np.array(sig_channels, dtype=_signal_channel_dtype)

        self.header['unit_channels'] = np.array([], dtype=_unit_channel_dtype)

        # Create an event channel for each state variable.
        event_channels = []
        for st_ix, st_tup in enumerate(state_defs):
            event_channels.append((st_tup[0], 'ev_' + str(st_ix), 'event'))
        self.header['event_channels'] = np.array(event_channels, dtype=_event_channel_dtype)

        # Add annotations.
        # Generates basic annotations in nested dict self.raw_annotations
        self._generate_minimal_annotations()
        self.raw_annotations['blocks'][0].update({
            'file_info': file_info,
            'param_defs': param_defs
        })
        for ev_ix, ev_dict in enumerate(self.raw_annotations['event_channels']):
            ev_dict.update({
                'length': state_defs[ev_ix][1],
                'startVal': state_defs[ev_ix][2],
                'bytePos': state_defs[ev_ix][3],
                'bitPos': state_defs[ev_ix][4]
            })

        import time
        time_formats = ['%a %b %d %H:%M:%S %Y', '%Y-%m-%dT%H:%M:%S']
        try:
            self._global_time = time.mktime(time.strptime(param_defs['StorageTime']['value'],
                                                          time_formats[0]))
        except ValueError:
            # Fall back to the second known timestamp format.
            self._global_time = time.mktime(time.strptime(param_defs['StorageTime']['value'],
                                                          time_formats[1]))

        # Save variables to make it easier to load the binary data.
        self._read_info = {
            'header_len': file_info['HeaderLen'],
            'n_chans': file_info['SourceCh'],
            'sample_dtype': {
                'int16': np.int16,
                'int32': np.int32,
                'float32': np.float32}.get(file_info['DataFormat']),
            'state_vec_len': file_info['StatevectorLen'],
            'sampling_rate': param_defs['SamplingRate']['value']
        }

        # Calculate the dtype for a single timestamp of data. This contains the data + statevector.
        self._read_info['line_dtype'] = [
            ('raw_vector', self._read_info['sample_dtype'], self._read_info['n_chans']),
            ('state_vector', np.uint8, self._read_info['state_vec_len'])]
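        # For example (illustrative numbers only): with 16 int16 channels and a
        # 4-byte state vector, each sample occupies 16 * 2 + 4 = 36 bytes, so
        # np.dtype(self._read_info['line_dtype']).itemsize would be 36.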

        import os
        self._read_info['n_samps'] = int((os.stat(self.filename).st_size - file_info['HeaderLen'])
                                         / np.dtype(self._read_info['line_dtype']).itemsize)

        # memmap is fast so we can get the data ready for reading now.
        self._memmap = np.memmap(self.filename, dtype=self._read_info['line_dtype'],
                                 offset=self._read_info['header_len'], mode='r')

    def _segment_t_start(self, block_index, seg_index):
        return 0.

    def _segment_t_stop(self, block_index, seg_index):
        return self._read_info['n_samps'] / self._read_info['sampling_rate']

    def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
        return self._read_info['n_samps']

    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        return 0.

    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
        if i_start is None:
            i_start = 0
        if i_stop is None:
            i_stop = self._read_info['n_samps']
        assert (0 <= i_start <= self._read_info['n_samps']), "i_start outside data range"
        assert (0 <= i_stop <= self._read_info['n_samps']), "i_stop outside data range"
        if channel_indexes is None:
            channel_indexes = np.arange(self.header['signal_channels'].size)
        return self._memmap['raw_vector'][i_start:i_stop, channel_indexes]

    def _spike_count(self, block_index, seg_index, unit_index):
        return 0

    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        return None

    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        return None

    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        return None

    def _event_count(self, block_index, seg_index, event_channel_index):
        return self._event_arrays_list[event_channel_index][0].shape[0]

    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        # Return 3 numpy arrays: timestamps, durations, labels.
        # durations must be None for 'event'
        # labels must have dtype='U'
        ts, dur, labels = self._event_arrays_list[event_channel_index]
        # seg_t_start = self._segment_t_start(block_index, seg_index)
        keep = np.ones(ts.shape, dtype=bool)
        if t_start is not None:
            keep = np.logical_and(keep, ts >= t_start)
        if t_stop is not None:
            keep = np.logical_and(keep, ts <= t_stop)
        return ts[keep], dur[keep], labels[keep]

    def _rescale_event_timestamp(self, event_timestamps, dtype):
        event_times = (event_timestamps / float(self._read_info['sampling_rate'])).astype(dtype)
        return event_times

    def _rescale_epoch_duration(self, raw_duration, dtype):
        durations = (raw_duration / float(self._read_info['sampling_rate'])).astype(dtype)
        return durations

    @property
    def _event_arrays_list(self):
        if self._my_events is None:
            self._my_events = []
            for s_ix, sd in enumerate(self.raw_annotations['event_channels']):
                ev_times = durs = vals = np.array([])
                # Skip these big but mostly useless (?) states.
                if sd['name'] not in ['SourceTime', 'StimulusTime']:
                    # Determine which bytes of self._memmap['state_vector'] are needed.
                    nbytes = int(np.ceil((sd['bitPos'] + sd['length']) / 8))
                    byte_slice = slice(sd['bytePos'], sd['bytePos'] + nbytes)

                    # Then determine how to mask those bytes to get only the needed bits.
                    bit_mask = np.array([255] * nbytes, dtype=np.uint8)
                    bit_mask[0] &= 255 & (255 << sd['bitPos'])  # Fix the mask for the first byte
                    # Mask off the unused high-order bits of the last byte; when the
                    # state ends exactly on a byte boundary there are none, so guard
                    # against shifting by a full 8 bits.
                    extra_bits = (8 - (sd['bitPos'] + sd['length']) % 8) % 8
                    bit_mask[-1] &= 255 & (255 >> extra_bits)  # Fix the mask for the last byte
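                    # Worked example (hypothetical state): bytePos=1, bitPos=2,
                    # length=3 gives nbytes=1, byte_slice=slice(1, 2), and the
                    # combined mask works out to 0b00011100 (bits 2-4 of byte 1).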

                    # When converting to an int, we need to know which integer type it will become.
                    n_max_bytes = 1 << (nbytes - 1).bit_length()
                    view_type = {1: np.int8, 2: np.int16,
                                 4: np.int32, 8: np.int64}.get(n_max_bytes)

                    # Slice and mask the data
                    masked_byte_array = self._memmap['state_vector'][:, byte_slice] & bit_mask

                    # Convert the byte array to a vector of ints:
                    # pad the columns (not the rows) out to the integer width,
                    # then view as the larger int type.
                    state_vec = np.pad(masked_byte_array,
                                       ((0, 0), (0, n_max_bytes - nbytes)),
                                       'constant').view(dtype=view_type)
                    state_vec = np.right_shift(state_vec, sd['bitPos'])[:, 0]

                    # In the state vector, find 'events' whenever the state changes
                    st_ch_ix = np.where(np.hstack((0, np.diff(state_vec))) != 0)[0]  # event inds
                    if len(st_ch_ix) > 0:
                        ev_times = st_ch_ix
                        durs = np.asarray([None] * len(st_ch_ix))
                        # np.hstack((np.diff(st_ch_ix), len(state_vec) - st_ch_ix[-1]))
                        vals = np.char.mod('%d', state_vec[st_ch_ix])  # event val, string'd
                self._my_events.append([ev_times, durs, vals.astype('U')])
        return self._my_events


def parse_bci2000_header(filename):
    # Typically we want parameter values in Hz, seconds, or microvolts.
    scales_dict = {
        'hz': 1, 'khz': 1000, 'mhz': 1000000,
        'uv': 1, 'muv': 1, 'mv': 1000, 'v': 1000000,
        's': 1, 'us': 0.000001, 'mus': 0.000001, 'ms': 0.001, 'min': 60,
        'sec': 1, 'usec': 0.000001, 'musec': 0.000001, 'msec': 0.001
    }

    def rescale_value(param_val, data_type):
        unit_str = ''
        if param_val.lower().startswith('0x'):
            param_val = int(param_val, 16)
        elif data_type in ['int', 'float']:
            # Allow an optional sign, an optional fractional part, and a trailing unit string.
            matches = re.match(r'(-?\d+\.?\d*)(\w*)', param_val)
            if matches is not None:  # Can be None for % in def, min, max vals
                param_val, unit_str = matches.group(1), matches.group(2)
                param_val = int(float(param_val)) if data_type == 'int' else float(param_val)
                if len(unit_str) > 0:
                    param_val *= scales_dict.get(unit_str.lower(), 1)
        else:
            param_val = unquote(param_val)
        return param_val, unit_str
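
    # Illustrative examples of the conversion above (all values hypothetical):
    #     rescale_value('256Hz', 'int')  -> (256, 'Hz')
    #     rescale_value('4kHz', 'float') -> (4000.0, 'kHz')
    #     rescale_value('0x10', 'int')   -> (16, '')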

    def parse_dimensions(param_list):
        num_els = param_list.pop(0)
        # Sometimes the number of elements isn't given,
        # but the list of element labels is wrapped with {}.
        if num_els == '{':
            num_els = param_list.index('}')
            el_labels = [unquote(param_list.pop(0)) for x in range(num_els)]
            param_list.pop(0)  # Remove the '}'
        else:
            num_els = int(num_els)
            el_labels = [str(ix) for ix in range(num_els)]
        return num_els, el_labels
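
    # Illustrative examples (hypothetical token lists):
    #     parse_dimensions(['2', 'a', 'b'])          -> (2, ['0', '1'])
    #     parse_dimensions(['{', 'Ch1', 'Ch2', '}']) -> (2, ['Ch1', 'Ch2'])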

    with open(filename, 'rb') as fid:
        # Parse the file header (plain text).
        # The first line contains basic information which we store in a dictionary.
        temp = fid.readline().decode('utf8').split()
        keys = [k.rstrip('=') for k in temp[::2]]
        vals = temp[1::2]
        # Insert default version and format
        file_info = {'BCI2000V': 1.0, 'DataFormat': 'int16'}
        file_info.update(**dict(zip(keys, vals)))
        # From string to float/int
        file_info['BCI2000V'] = float(file_info['BCI2000V'])
        for k in ['HeaderLen', 'SourceCh', 'StatevectorLen']:
            if k in file_info:
                file_info[k] = int(file_info[k])
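
        # For reference, the first line of a version-1.1 file typically looks
        # something like (values made up):
        #     BCI2000V= 1.1 HeaderLen= 4481 SourceCh= 16 StatevectorLen= 4 DataFormat= float32
        # Version-1.0 files omit the BCI2000V and DataFormat fields.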

        # The next lines contain state vector definitions.
        temp = fid.readline().decode('utf8').strip()
        assert temp == '[ State Vector Definition ]', \
            "State definitions not found in header %s" % filename
        state_defs = []
        state_def_dtype = [('name', 'a64'),
                           ('length', int),
                           ('startVal', int),
                           ('bytePos', int),
                           ('bitPos', int)]
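        # Each state definition line has the form
        #     Name Length Value BytePos BitPos
        # e.g. (made-up values) 'SourceTime 16 0 1 0'.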
        while True:
            temp = fid.readline().decode('utf8').strip()
            if len(temp) == 0 or temp[0] == '[':
                # Presence of '[' signifies new section.
                break
            temp = temp.split()
            state_defs.append((temp[0], int(temp[1]), int(temp[2]), int(temp[3]), int(temp[4])))
        state_defs = np.array(state_defs, dtype=state_def_dtype)

        # The next lines contain parameter definitions.
        # There are many, and their formatting can be complicated.
        assert temp == '[ Parameter Definition ]', \
            "Parameter definitions not found in header %s" % filename
        param_defs = {}
        while True:
            temp = fid.readline().decode('utf8')
            if fid.tell() >= file_info['HeaderLen']:
                # End of header.
                break
            if len(temp.strip()) == 0:
                continue  # Skip empty lines

            # Everything after the '//' is a comment.
            temp = temp.strip().split('//', 1)
            param_def = {'comment': temp[1].strip() if len(temp) > 1 else ''}

            # Parse the parameter definition. Generally it is sec:cat:name dtype name param_value+
            temp = temp[0].split()
            param_def.update(
                {'section_category_name': [unquote(x) for x in temp.pop(0).split(':')]})
            dtype = temp.pop(0)
            param_name = unquote(temp.pop(0).rstrip('='))
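            # For example, a line such as (values made up)
            #     'Source int SourceCh= 16 16 1 % // number of digitized channels'
            # yields dtype 'int', param_name 'SourceCh', and remaining tokens
            # ['16', '16', '1', '%'] (value, default, min, max).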

            # Parse the rest. Parse method depends on the dtype.
            param_value, units = None, None
            if dtype in ('int', 'float'):
                param_value = temp.pop(0)
                if param_value == 'auto':
                    param_value = np.nan
                    units = ''
                else:
                    param_value, units = rescale_value(param_value, dtype)
            elif dtype in ('string', 'variant'):
                param_value = unquote(temp.pop(0))
            elif dtype.endswith('list'):  # e.g., intlist, stringlist, floatlist, list
                dtype = dtype[:-4]
                # The list parameter values will begin with either
                # an int to specify the number of elements
                # or a list of labels surrounded by { }.
                num_elements, element_labels = parse_dimensions(temp)  # This will pop off info.
                param_def.update({'element_labels': element_labels})
                pv_un = [rescale_value(pv, dtype) for pv in temp[:num_elements]]
                if len(pv_un) > 0:
                    param_value, units = zip(*pv_un)
                else:
                    param_value, units = np.nan, ''
                temp = temp[num_elements:]
                # Sometimes an element list will be a list of ints even though
                # the element_type is '' (str)...
                # This usually happens for known parameters, such as SourceChOffset,
                # that can be dealt with explicitly later.
            elif dtype.endswith('matrix'):
                dtype = dtype[:-6]
                # The parameter values will be preceded by two dimension descriptors,
                # first rows then columns. Each dimension might be described by an
                # int or a list of labels surrounded by {}.
                n_rows, row_labels = parse_dimensions(temp)
                n_cols, col_labels = parse_dimensions(temp)
                param_def.update({'row_labels': row_labels, 'col_labels': col_labels})

                param_value = []
                units = []
                for row_ix in range(n_rows):
                    cols = []
                    for col_ix in range(n_cols):
                        col_val, _units = rescale_value(temp[row_ix * n_cols + col_ix], dtype)
                        cols.append(col_val)
                        units.append(_units)
                    param_value.append(cols)
                temp = temp[n_rows * n_cols:]
            param_def.update({
                'value': param_value,
                'units': units,
                'dtype': dtype
            })

            # At the end of the parameter definition, we might get
            # default, min, max values for the parameter.
            temp.reverse()
            if len(temp):
                param_def.update({'max_val': rescale_value(temp.pop(0), dtype)})
            if len(temp):
                param_def.update({'min_val': rescale_value(temp.pop(0), dtype)})
            if len(temp):
                param_def.update({'default_val': rescale_value(temp.pop(0), dtype)})

            param_defs.update({param_name: param_def})
        # End parameter block
    # Outdent to close file

    return file_info, state_defs, param_defs