# -*- coding: utf-8 -*-
"""
This module implements the OpenEphys format.

Author: Samuel Garcia
"""
from __future__ import unicode_literals, print_function, division, absolute_import

import os

import numpy as np

from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
                        _event_channel_dtype)

RECORD_SIZE = 1024
HEADER_SIZE = 1024


class OpenEphysRawIO(BaseRawIO):
    """
    OpenEphys GUI software offers several data formats, see
    https://open-ephys.atlassian.net/wiki/spaces/OEW/pages/491632/Data+format

    This class implements the legacy OpenEphys format described here:
    https://open-ephys.atlassian.net/wiki/spaces/OEW/pages/65667092/Open+Ephys+format

    The OpenEphys group already provides some tools here:
    https://github.com/open-ephys/analysis-tools/blob/master/OpenEphys.py
    but (i) there is no package on PyPI and (ii) those tools read everything into memory.

    The format is directory based, with several file types:
      * .continuous
      * .events
      * .spikes

    This implementation is based on:
      * this code https://github.com/open-ephys/analysis-tools/blob/master/Python3/OpenEphys.py
        written by Dan Denman and Josh Siegle
      * a previous PR by Cristian Tatarau at Charité Berlin

    In contrast to previous code for reading this format, here all data use memmap,
    so it should be fast and light compared to the legacy code.

    When the acquisition is stopped and restarted, files are named *_2, *_3, ...
    In that case this class creates a new Segment. Note that timestamps are reset
    in this situation.

    Limitations:
      * This works only if all continuous channels have the same sampling rate,
        which is a reasonable hypothesis.
      * When the recording is stopped and restarted, all continuous files will
        contain gaps. Ideally this would lead to a new Segment, but this use case
        is not implemented due to its complexity; instead an error is raised.

    Special cases:
      * Normally all continuous files have the same first timestamp and length.
        When this is not the case, all files are clipped to the smallest one so
        that they are all aligned, and a warning is emitted.
    """
    extensions = []
    rawmode = 'one-dir'
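
    # A minimal usage sketch (the folder name is hypothetical), relying on the
    # public rawio API inherited from BaseRawIO:
    #
    #     io = OpenEphysRawIO(dirname='2019-01-01_12-00-00')
    #     io.parse_header()
    #     raw_chunk = io.get_analogsignal_chunk(block_index=0, seg_index=0,
    #                                           i_start=0, i_stop=1024,
    #                                           channel_indexes=None)
    #     float_chunk = io.rescale_signal_raw_to_float(raw_chunk, dtype='float64')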

    def __init__(self, dirname=''):
        BaseRawIO.__init__(self)
        self.dirname = dirname

    def _source_name(self):
        return self.dirname

    def _parse_header(self):
        info = self._info = explore_folder(self.dirname)
        nb_segment = info['nb_segment']

        # scan for continuous files
        self._sigs_memmap = {}
        self._sig_length = {}
        self._sig_timestamp0 = {}
        sig_channels = []
        for seg_index in range(nb_segment):
            self._sigs_memmap[seg_index] = {}

            all_sigs_length = []
            all_first_timestamps = []
            all_last_timestamps = []
            all_samplerate = []
            for continuous_filename in info['continuous'][seg_index]:
                fullname = os.path.join(self.dirname, continuous_filename)
                chan_info = read_file_header(fullname)

                s = continuous_filename.replace('.continuous', '').split('_')
                processor_id, ch_name = s[0], s[1]
                chan_id = int(ch_name.replace('CH', ''))

                filesize = os.stat(fullname).st_size
                size = (filesize - HEADER_SIZE) // np.dtype(continuous_dtype).itemsize
                data_chan = np.memmap(fullname, mode='r', offset=HEADER_SIZE,
                                      dtype=continuous_dtype, shape=(size, ))
                self._sigs_memmap[seg_index][chan_id] = data_chan

                all_sigs_length.append(data_chan.size * RECORD_SIZE)
                all_first_timestamps.append(data_chan[0]['timestamp'])
                all_last_timestamps.append(data_chan[-1]['timestamp'])
                all_samplerate.append(chan_info['sampleRate'])

                # check for continuity (no gaps)
                diff = np.diff(data_chan['timestamp'])
                assert np.all(diff == RECORD_SIZE), \
                    'Not continuous timestamps for {}. ' \
                    'Maybe because recording was paused/stopped.'.format(continuous_filename)

                if seg_index == 0:
                    # add to the channel list
                    sig_channels.append((ch_name, chan_id, chan_info['sampleRate'],
                                         'int16', 'V', chan_info['bitVolts'], 0.,
                                         int(processor_id)))

            # In some cases continuous files do not have the same length because
            # one record block is missing when the "OE GUI is freezing".
            # So we need to clip all files to the smallest one.
            if not all(all_sigs_length[0] == e for e in all_sigs_length) or\
                    not all(all_first_timestamps[0] == e for e in all_first_timestamps):

                self.logger.warning('Continuous files do not have aligned timestamps; '
                                    'clipping to make them aligned.')

                first, last = -np.inf, np.inf
                for chan_id in self._sigs_memmap[seg_index]:
                    data_chan = self._sigs_memmap[seg_index][chan_id]
                    if data_chan[0]['timestamp'] > first:
                        first = data_chan[0]['timestamp']
                    if data_chan[-1]['timestamp'] < last:
                        last = data_chan[-1]['timestamp']

                all_sigs_length = []
                all_first_timestamps = []
                all_last_timestamps = []
                for chan_id in self._sigs_memmap[seg_index]:
                    data_chan = self._sigs_memmap[seg_index][chan_id]
                    keep = (data_chan['timestamp'] >= first) & (data_chan['timestamp'] <= last)
                    data_chan = data_chan[keep]
                    self._sigs_memmap[seg_index][chan_id] = data_chan
                    all_sigs_length.append(data_chan.size * RECORD_SIZE)
                    all_first_timestamps.append(data_chan[0]['timestamp'])
                    all_last_timestamps.append(data_chan[-1]['timestamp'])

            # check that all signals have the same length and timestamp0 for this segment
            assert all(all_sigs_length[0] == e for e in all_sigs_length),\
                'All signals do not have the same length'
            assert all(all_first_timestamps[0] == e for e in all_first_timestamps),\
                'All signals do not have the same first timestamp'
            assert all(all_samplerate[0] == e for e in all_samplerate),\
                'All signals do not have the same sample rate'

            self._sig_length[seg_index] = all_sigs_length[0]
            self._sig_timestamp0[seg_index] = all_first_timestamps[0]

        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
        self._sig_sampling_rate = sig_channels['sampling_rate'][0]  # unique across channels

        # scan for spike files
        unit_channels = []
        if len(info['spikes']) > 0:
            self._spikes_memmap = {}
            for seg_index in range(nb_segment):
                self._spikes_memmap[seg_index] = {}
                for spike_filename in info['spikes'][seg_index]:
                    fullname = os.path.join(self.dirname, spike_filename)
                    spike_info = read_file_header(fullname)
                    spikes_dtype = make_spikes_dtype(fullname)

                    # "STp106.0n0_2.spikes" to "STp106.0n0"
                    name = spike_filename.replace('.spikes', '')
                    if seg_index > 0:
                        name = name.replace('_' + str(seg_index + 1), '')

                    data_spike = np.memmap(fullname, mode='r', offset=HEADER_SIZE,
                                           dtype=spikes_dtype)
                    self._spikes_memmap[seg_index][name] = data_spike

            # In each file, 'sorted_id' indicates the cluster, i.e. the unit,
            # so we need to scan the files of all segments to collect the units.
            self._spike_sampling_rate = None
            for spike_filename_seg0 in info['spikes'][0]:
                name = spike_filename_seg0.replace('.spikes', '')
                fullname = os.path.join(self.dirname, spike_filename_seg0)
                spike_info = read_file_header(fullname)

                if self._spike_sampling_rate is None:
                    self._spike_sampling_rate = spike_info['sampleRate']
                else:
                    assert self._spike_sampling_rate == spike_info['sampleRate'],\
                        'mismatch in spike sampling rate'

                # scan all segments to detect all unique sorted_ids
                all_sorted_ids = []
                for seg_index in range(nb_segment):
                    data_spike = self._spikes_memmap[seg_index][name]
                    all_sorted_ids += np.unique(data_spike['sorted_id']).tolist()
                all_sorted_ids = np.unique(all_sorted_ids)

                # suppose all channels have the same gain
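                # Note: waveform samples are stored as unsigned 16-bit values
                # with a 2 ** 15 offset, so wf_offset below recenters them
                # around zero once the gain is applied.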
                wf_units = 'uV'
                wf_gain = 1000. / data_spike[0]['gains'][0]
                wf_offset = - (2 ** 15) * wf_gain
                wf_left_sweep = 0
                wf_sampling_rate = spike_info['sampleRate']

                # each sorted_id is one unit
                for sorted_id in all_sorted_ids:
                    unit_name = "{}#{}".format(name, sorted_id)
                    unit_id = "{}#{}".format(name, sorted_id)
                    unit_channels.append((unit_name, unit_id, wf_units,
                                          wf_gain, wf_offset, wf_left_sweep,
                                          wf_sampling_rate))

        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)

        # event files are:
        #   * all_channels.events (header + binary) --> event channel 0
        #   * message.events (text based) --> event channel 1, not implemented yet
        event_channels = []
        self._events_memmap = {}
        for seg_index in range(nb_segment):
            if seg_index == 0:
                event_filename = 'all_channels.events'
            else:
                event_filename = 'all_channels_{}.events'.format(seg_index + 1)

            fullname = os.path.join(self.dirname, event_filename)
            event_info = read_file_header(fullname)
            self._event_sampling_rate = event_info['sampleRate']
            data_event = np.memmap(fullname, mode='r', offset=HEADER_SIZE,
                                   dtype=events_dtype)
            self._events_memmap[seg_index] = data_event

        event_channels.append(('all_channels', '', 'event'))
        # event_channels.append(('message', '', 'event'))  # not implemented
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # main header
        self.header = {}
        self.header['nb_block'] = 1
        self.header['nb_segment'] = [nb_segment]
        self.header['signal_channels'] = sig_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels

        # Annotate some objects from continuous files
        self._generate_minimal_annotations()
        bl_ann = self.raw_annotations['blocks'][0]
        for seg_index in range(nb_segment):
            seg_ann = bl_ann['segments'][seg_index]
            if len(info['continuous']) > 0:
                fullname = os.path.join(self.dirname, info['continuous'][seg_index][0])
                chan_info = read_file_header(fullname)
                seg_ann['openephys_version'] = chan_info['version']
                bl_ann['openephys_version'] = chan_info['version']
                seg_ann['date_created'] = chan_info['date_created']

    def _segment_t_start(self, block_index, seg_index):
        # segment start/stop are defined by the continuous channels
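        # Timestamps are expressed in samples; dividing by the sampling rate
        # converts them to seconds, e.g. timestamp0 = 320000 at 32 kHz --> 10.0 s.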
        return self._sig_timestamp0[seg_index] / self._sig_sampling_rate

    def _segment_t_stop(self, block_index, seg_index):
        return (self._sig_timestamp0[seg_index] + self._sig_length[seg_index])\
            / self._sig_sampling_rate

    def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
        return self._sig_length[seg_index]

    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        return self._sig_timestamp0[seg_index] / self._sig_sampling_rate

    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
        if i_start is None:
            i_start = 0
        if i_stop is None:
            i_stop = self._sig_length[seg_index]

        block_start = i_start // RECORD_SIZE
        block_stop = i_stop // RECORD_SIZE + 1
        sl0 = i_start % RECORD_SIZE
        sl1 = sl0 + (i_stop - i_start)
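
        # Worked example of this arithmetic (illustration only): with
        # RECORD_SIZE = 1024, i_start = 1500 and i_stop = 3000 give
        # block_start = 1, block_stop = 3, sl0 = 476 and sl1 = 1976; the
        # flattened samples of records 1 and 2 hold 2048 values, and the
        # slice [476:1976] returns exactly the 1500 samples requested.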

        if channel_indexes is None:
            channel_indexes = slice(None)
        channel_ids = self.header['signal_channels'][channel_indexes]['id']

        sigs_chunk = np.zeros((i_stop - i_start, len(channel_ids)), dtype='int16')
        for i, chan_id in enumerate(channel_ids):
            data = self._sigs_memmap[seg_index][chan_id]
            sub = data[block_start:block_stop]
            sigs_chunk[:, i] = sub['samples'].flatten()[sl0:sl1]

        return sigs_chunk

    def _get_spike_slice(self, seg_index, unit_index, t_start, t_stop):
        name, sorted_id = self.header['unit_channels'][unit_index]['name'].split('#')
        sorted_id = int(sorted_id)

        data_spike = self._spikes_memmap[seg_index][name]

        if t_start is None:
            t_start = self._segment_t_start(0, seg_index)
        if t_stop is None:
            t_stop = self._segment_t_stop(0, seg_index)
        ts0 = int(t_start * self._spike_sampling_rate)
        ts1 = int(t_stop * self._spike_sampling_rate)

        ts = data_spike['timestamp']
        keep = (data_spike['sorted_id'] == sorted_id) & (ts >= ts0) & (ts <= ts1)
        return data_spike, keep

    def _spike_count(self, block_index, seg_index, unit_index):
        data_spike, keep = self._get_spike_slice(seg_index, unit_index, None, None)
        return np.sum(keep)

    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        data_spike, keep = self._get_spike_slice(seg_index, unit_index, t_start, t_stop)
        return data_spike['timestamp'][keep]

    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        spike_times = spike_timestamps.astype(dtype) / self._spike_sampling_rate
        return spike_times

    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        data_spike, keep = self._get_spike_slice(seg_index, unit_index, t_start, t_stop)
        nb_chan = data_spike[0]['nb_channel']
        nb = np.sum(keep)
        waveforms = data_spike[keep]['samples'].flatten()
        waveforms = waveforms.reshape(nb, nb_chan, -1)
        return waveforms

    def _event_count(self, block_index, seg_index, event_channel_index):
        # assert event_channel_index == 0
        return self._events_memmap[seg_index].size

    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        # assert event_channel_index == 0
        if t_start is None:
            t_start = self._segment_t_start(block_index, seg_index)
        if t_stop is None:
            t_stop = self._segment_t_stop(block_index, seg_index)
        ts0 = int(t_start * self._event_sampling_rate)
        ts1 = int(t_stop * self._event_sampling_rate)
        ts = self._events_memmap[seg_index]['timestamp']
        keep = (ts >= ts0) & (ts <= ts1)

        subdata = self._events_memmap[seg_index][keep]
        timestamps = subdata['timestamp']
        # the format does not define an explicit label, so we build one as a
        # combination of event_type, processor_id and chan_id
        labels = np.array(['{}#{}#{}'.format(int(d['event_type']),
                                             int(d['processor_id']),
                                             int(d['chan_id'])) for d in subdata])
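        # e.g. an entry with event_type=3, processor_id=100 and chan_id=2
        # (hypothetical values) yields the label '3#100#2'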
        durations = None
        return timestamps, durations, labels

    def _rescale_event_timestamp(self, event_timestamps, dtype):
        event_times = event_timestamps.astype(dtype) / self._event_sampling_rate
        return event_times

    def _rescale_epoch_duration(self, raw_duration, dtype):
        return None


continuous_dtype = [('timestamp', 'int64'), ('nb_sample', 'uint16'),
                    ('rec_num', 'uint16'), ('samples', 'int16', RECORD_SIZE),
                    ('markers', 'uint8', 10)]
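
# Derived from the dtype above: one continuous record occupies
# 8 + 2 + 2 + 2 * RECORD_SIZE + 10 = 2070 bytes, which is why _parse_header
# computes the number of records as (filesize - HEADER_SIZE) // itemsize.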

events_dtype = [('timestamp', 'int64'), ('sample_pos', 'int16'),
                ('event_type', 'uint8'), ('processor_id', 'uint8'),
                ('event_id', 'uint8'), ('chan_id', 'uint8'),
                ('record_num', 'uint16')]

# the dtype is dynamic and depends on nb_channel and nb_sample;
# the None placeholders are filled in by make_spikes_dtype()
_base_spikes_dtype = [('event_type', 'uint8'), ('timestamp', 'int64'),
                      ('software_timestamp', 'int64'), ('source_id', 'uint16'),
                      ('nb_channel', 'uint16'), ('nb_sample', 'uint16'),
                      ('sorted_id', 'uint16'), ('electrode_id', 'uint16'),
                      ('within_chan_index', 'uint16'), ('color', 'uint8', 3),
                      ('pca', 'float32', 2), ('sampling_rate', 'uint16'),
                      ('samples', 'uint16', None), ('gains', 'float32', None),
                      ('thresholds', 'uint16', None), ('rec_num', 'uint16')]


def make_spikes_dtype(filename):
    """
    Given a spike file, build the appropriate dtype, which depends on:
      * N - number of channels
      * M - number of samples per spike
    See the documentation of the file format.
    """
    # Strangely, the header does not contain the spike size.
    # So this does not work (too bad):
    #   spike_info = read_file_header(filename)
    #   N = spike_info['num_channels']
    #   M = ????
    # Instead we read N and M from the very first spike,
    # but this fails when the file contains 0 spikes (too bad).
    filesize = os.stat(filename).st_size
    if filesize >= (HEADER_SIZE + 23):
        with open(filename, mode='rb') as f:
            # N and M are at offset 1024 + 19 bytes
            f.seek(HEADER_SIZE + 19)
            N = np.fromfile(f, np.dtype('<u2'), 1)[0]
            M = np.fromfile(f, np.dtype('<u2'), 1)[0]
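            # The 19-byte offset comes from the fields that precede nb_channel
            # in _base_spikes_dtype: event_type (1) + timestamp (8) +
            # software_timestamp (8) + source_id (2) = 19 bytes; N and M then
            # follow as two little-endian uint16, hence the HEADER_SIZE + 23
            # minimum file size checked above.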
    else:
        spike_info = read_file_header(filename)
        N = spike_info['num_channels']
        M = 40  # this value is hard-coded in the original code from OpenEphys

    # make a copy
    spikes_dtype = [e for e in _base_spikes_dtype]
    spikes_dtype[12] = ('samples', 'uint16', N * M)
    spikes_dtype[13] = ('gains', 'float32', N)
    spikes_dtype[14] = ('thresholds', 'uint16', N)

    return spikes_dtype


def explore_folder(dirname):
    """
    This explores a folder and dispatches continuous, event and spike
    files by segment (aka recording session).

    The number of segments is detected with these rules:
        "100_CH0.continuous" ---> seg_index 0
        "100_CH0_2.continuous" ---> seg_index 1
        "100_CH0_N.continuous" ---> seg_index N-1
    """
    filenames = os.listdir(dirname)

    info = {}
    info['nb_segment'] = 0
    info['continuous'] = {}
    info['spikes'] = {}
    for filename in filenames:
        if filename.endswith('.continuous'):
            s = filename.replace('.continuous', '').split('_')
            if len(s) == 2:
                seg_index = 0
            else:
                seg_index = int(s[2]) - 1
            if seg_index not in info['continuous'].keys():
                info['continuous'][seg_index] = []
            info['continuous'][seg_index].append(filename)
            # os.listdir order is arbitrary, so track the highest segment index
            info['nb_segment'] = max(info['nb_segment'], seg_index + 1)
        elif filename.endswith('.spikes'):
            s = filename.replace('.spikes', '').split('_')
            if len(s) == 1:
                seg_index = 0
            else:
                seg_index = int(s[1]) - 1
            if seg_index not in info['spikes'].keys():
                info['spikes'][seg_index] = []
            info['spikes'][seg_index].append(filename)
            info['nb_segment'] = max(info['nb_segment'], seg_index + 1)

    # order continuous files by channel number within each segment
    for seg_index, continuous_filenames in info['continuous'].items():
        channel_ids = []
        for continuous_filename in continuous_filenames:
            s = continuous_filename.replace('.continuous', '').split('_')
            processor_id, ch_name = s[0], s[1]
            chan_id = int(ch_name.replace('CH', ''))
            channel_ids.append(chan_id)
        order = np.argsort(channel_ids)
        continuous_filenames = [continuous_filenames[i] for i in order]
        info['continuous'][seg_index] = continuous_filenames

    # order spike files within each segment
    for seg_index, spike_filenames in info['spikes'].items():
        names = []
        for spike_filename in spike_filenames:
            name = spike_filename.replace('.spikes', '')
            if seg_index > 0:
                name = name.replace('_' + str(seg_index + 1), '')
            names.append(name)
        order = np.argsort(names)
        spike_filenames = [spike_filenames[i] for i in order]
        info['spikes'][seg_index] = spike_filenames

    return info


def read_file_header(filename):
    """Read header information from the first 1024 bytes of an OpenEphys file.

    See the format documentation.
    """
    header = {}
    with open(filename, mode='rb') as f:
        # Read the data as a string.
        # Remove newlines and redundant "header." prefixes.
        # The result should be a series of "key = value" strings, separated
        # by semicolons.
        header_string = f.read(HEADER_SIZE).replace(b'\n', b'').replace(b'header.', b'')

        # Parse each "key = value" string separately
        for pair in header_string.split(b';'):
            if b'=' in pair:
                key, value = pair.split(b' = ')
                key = key.strip().decode('ascii')
                value = value.strip()

                # Convert some values to numeric
                if key in ['bitVolts', 'sampleRate']:
                    header[key] = float(value)
                elif key in ['blockLength', 'bufferSize', 'header_bytes', 'num_channels']:
                    header[key] = int(value)
                else:
                    # Keep as a string
                    header[key] = value.decode('ascii')

    return header