# openephysrawio.py

"""
This module implements the OpenEphys format.
Author: Samuel Garcia
"""
import os
import re

import numpy as np

from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
                        _event_channel_dtype)

RECORD_SIZE = 1024
HEADER_SIZE = 1024

class OpenEphysRawIO(BaseRawIO):
    """
    OpenEphys GUI software offers several data formats, see
    https://open-ephys.atlassian.net/wiki/spaces/OEW/pages/491632/Data+format

    This class implements the legacy OpenEphys format here
    https://open-ephys.atlassian.net/wiki/spaces/OEW/pages/65667092/Open+Ephys+format

    The OpenEphys group already proposes some tools here:
    https://github.com/open-ephys/analysis-tools/blob/master/OpenEphys.py
    but (i) there is no package on PyPI and (ii) those tools read everything into memory.

    The format is directory based with several files:
      * .continuous
      * .events
      * .spikes

    This implementation is based on:
      * this code https://github.com/open-ephys/analysis-tools/blob/master/Python3/OpenEphys.py
        written by Dan Denman and Josh Siegle
      * a previous PR by Cristian Tatarau at Charité Berlin

    In contrast to previous code for reading this format, here all data use memmap, so it
    should be fast and light compared to the legacy code.

    When the acquisition is stopped and restarted, files are named ``*_2``, ``*_3``, and so
    on. In that case this class creates a new Segment. Note that timestamps are reset in
    this situation.

    Limitations:
      * Works only if all continuous channels have the same sampling rate, which is a
        reasonable hypothesis.
      * When the recording is stopped and restarted, all continuous files will contain gaps.
        Ideally this would lead to a new Segment, but this use case is not implemented due
        to its complexity. Instead it raises an error.

    Special cases:
      * Normally all continuous files have the same first timestamp and length. When this
        is not the case, all files are clipped to the smallest one so that they are all
        aligned, and a warning is emitted.
    """
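
    # A minimal usage sketch (illustrative; 'my_recording' is a hypothetical directory
    # containing legacy .continuous/.events/.spikes files). It relies on the public
    # BaseRawIO API (parse_header, get_analogsignal_chunk, rescale_signal_raw_to_float):
    #
    #   reader = OpenEphysRawIO(dirname='my_recording')
    #   reader.parse_header()
    #   raw = reader.get_analogsignal_chunk(block_index=0, seg_index=0,
    #                                       i_start=0, i_stop=30000)
    #   sigs = reader.rescale_signal_raw_to_float(raw, dtype='float64')
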
    extensions = []
    rawmode = 'one-dir'

    def __init__(self, dirname=''):
        BaseRawIO.__init__(self)
        self.dirname = dirname

    def _source_name(self):
        return self.dirname
    def _parse_header(self):
        info = self._info = explore_folder(self.dirname)
        nb_segment = info['nb_segment']

        # scan for continuous files
        self._sigs_memmap = {}
        self._sig_length = {}
        self._sig_timestamp0 = {}
        sig_channels = []
        oe_indices = sorted(list(info['continuous'].keys()))
        for seg_index, oe_index in enumerate(oe_indices):
            self._sigs_memmap[seg_index] = {}
            all_sigs_length = []
            all_first_timestamps = []
            all_last_timestamps = []
            all_samplerate = []
            for chan_index, continuous_filename in enumerate(info['continuous'][oe_index]):
                fullname = os.path.join(self.dirname, continuous_filename)
                chan_info = read_file_header(fullname)

                s = continuous_filename.replace('.continuous', '').split('_')
                processor_id, ch_name = s[0], s[1]
                chan_str = re.split(r'(\d+)', s[1])[0]
                # note that chan_id is not unique in case of CH + AUX
                chan_id = int(ch_name.replace(chan_str, ''))

                filesize = os.stat(fullname).st_size
                size = (filesize - HEADER_SIZE) // np.dtype(continuous_dtype).itemsize
                data_chan = np.memmap(fullname, mode='r', offset=HEADER_SIZE,
                                      dtype=continuous_dtype, shape=(size, ))
                self._sigs_memmap[seg_index][chan_index] = data_chan

                all_sigs_length.append(data_chan.size * RECORD_SIZE)
                all_first_timestamps.append(data_chan[0]['timestamp'])
                all_last_timestamps.append(data_chan[-1]['timestamp'])
                all_samplerate.append(chan_info['sampleRate'])

                # check for continuity (no gaps)
                diff = np.diff(data_chan['timestamp'])
                assert np.all(diff == RECORD_SIZE), \
                    'Not continuous timestamps for {}. ' \
                    'Maybe because recording was paused/stopped.'.format(continuous_filename)

                if seg_index == 0:
                    # add in channel list
                    sig_channels.append((ch_name, chan_id, chan_info['sampleRate'],
                                         'int16', 'V', chan_info['bitVolts'], 0.,
                                         int(processor_id)))
            # In some cases continuous files do not have the same length, because
            # one record block is missing when the OE GUI freezes,
            # so we need to clip them to the smallest file.
            if not all(all_sigs_length[0] == e for e in all_sigs_length) or\
                    not all(all_first_timestamps[0] == e for e in all_first_timestamps):

                self.logger.warning('Continuous files do not have aligned timestamps; '
                                    'clipping to make them aligned.')

                first, last = -np.inf, np.inf
                for chan_index in self._sigs_memmap[seg_index]:
                    data_chan = self._sigs_memmap[seg_index][chan_index]
                    if data_chan[0]['timestamp'] > first:
                        first = data_chan[0]['timestamp']
                    if data_chan[-1]['timestamp'] < last:
                        last = data_chan[-1]['timestamp']

                all_sigs_length = []
                all_first_timestamps = []
                all_last_timestamps = []
                for chan_index in self._sigs_memmap[seg_index]:
                    data_chan = self._sigs_memmap[seg_index][chan_index]
                    keep = (data_chan['timestamp'] >= first) & (data_chan['timestamp'] <= last)
                    data_chan = data_chan[keep]
                    self._sigs_memmap[seg_index][chan_index] = data_chan
                    all_sigs_length.append(data_chan.size * RECORD_SIZE)
                    all_first_timestamps.append(data_chan[0]['timestamp'])
                    all_last_timestamps.append(data_chan[-1]['timestamp'])

            # check that all signals have the same length and timestamp0 for this segment
            assert all(all_sigs_length[0] == e for e in all_sigs_length),\
                'All signals do not have the same length'
            assert all(all_first_timestamps[0] == e for e in all_first_timestamps),\
                'All signals do not have the same first timestamp'
            assert all(all_samplerate[0] == e for e in all_samplerate),\
                'All signals do not have the same sample rate'

            self._sig_length[seg_index] = all_sigs_length[0]
            self._sig_timestamp0[seg_index] = all_first_timestamps[0]

        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
        self._sig_sampling_rate = sig_channels['sampling_rate'][0]  # unique across channels
        # scan for spikes files
        unit_channels = []
        if len(info['spikes']) > 0:
            self._spikes_memmap = {}
            for seg_index, oe_index in enumerate(oe_indices):
                self._spikes_memmap[seg_index] = {}
                for spike_filename in info['spikes'][oe_index]:
                    fullname = os.path.join(self.dirname, spike_filename)
                    spike_info = read_file_header(fullname)
                    spikes_dtype = make_spikes_dtype(fullname)

                    # "STp106.0n0_2.spikes" to "STp106.0n0"
                    name = spike_filename.replace('.spikes', '')
                    if seg_index > 0:
                        name = name.replace('_' + str(seg_index + 1), '')

                    data_spike = np.memmap(fullname, mode='r', offset=HEADER_SIZE,
                                           dtype=spikes_dtype)
                    self._spikes_memmap[seg_index][name] = data_spike

            # In each file, 'sorted_id' indicates the cluster, and so the number of units.
            # We therefore need to scan the files of all segments to collect the units.
            self._spike_sampling_rate = None
            for spike_filename_seg0 in info['spikes'][0]:
                name = spike_filename_seg0.replace('.spikes', '')
                fullname = os.path.join(self.dirname, spike_filename_seg0)
                spike_info = read_file_header(fullname)
                if self._spike_sampling_rate is None:
                    self._spike_sampling_rate = spike_info['sampleRate']
                else:
                    assert self._spike_sampling_rate == spike_info['sampleRate'],\
                        'mismatch in spike sampling rate'

                # scan all segments to detect the unique sorted_ids
                all_sorted_ids = []
                for seg_index in range(nb_segment):
                    data_spike = self._spikes_memmap[seg_index][name]
                    all_sorted_ids += np.unique(data_spike['sorted_id']).tolist()
                all_sorted_ids = np.unique(all_sorted_ids)

                # suppose that all channels have the same gain
                wf_units = 'uV'
                wf_gain = 1000. / data_spike[0]['gains'][0]
                wf_offset = - (2 ** 15) * wf_gain
                wf_left_sweep = 0
                wf_sampling_rate = spike_info['sampleRate']

                # each sorted_id is one unit channel
                for sorted_id in all_sorted_ids:
                    unit_name = "{}#{}".format(name, sorted_id)
                    unit_id = "{}#{}".format(name, sorted_id)
                    unit_channels.append((unit_name, unit_id, wf_units,
                                          wf_gain, wf_offset, wf_left_sweep,
                                          wf_sampling_rate))
        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
        # event files are:
        #   * all_channels.events (header + binary)  --> event channel 0
        #   * message.events (text based)            --> event channel 1 (not implemented yet)
        event_channels = []
        self._events_memmap = {}
        for seg_index, oe_index in enumerate(oe_indices):
            if oe_index == 0:
                event_filename = 'all_channels.events'
            else:
                event_filename = 'all_channels_{}.events'.format(oe_index + 1)

            fullname = os.path.join(self.dirname, event_filename)
            event_info = read_file_header(fullname)
            self._event_sampling_rate = event_info['sampleRate']
            data_event = np.memmap(fullname, mode='r', offset=HEADER_SIZE,
                                   dtype=events_dtype)
            self._events_memmap[seg_index] = data_event

        event_channels.append(('all_channels', '', 'event'))
        # event_channels.append(('message', '', 'event'))  # not implemented
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # main header
        self.header = {}
        self.header['nb_block'] = 1
        self.header['nb_segment'] = [nb_segment]
        self.header['signal_channels'] = sig_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels

        # Annotate some objects from continuous files
        self._generate_minimal_annotations()
        bl_ann = self.raw_annotations['blocks'][0]
        for seg_index, oe_index in enumerate(oe_indices):
            seg_ann = bl_ann['segments'][seg_index]
            if len(info['continuous']) > 0:
                fullname = os.path.join(self.dirname, info['continuous'][oe_index][0])
                chan_info = read_file_header(fullname)
                seg_ann['openephys_version'] = chan_info['version']
                bl_ann['openephys_version'] = chan_info['version']
                seg_ann['date_created'] = chan_info['date_created']
            seg_ann['openephys_segment_index'] = oe_index + 1
    def _segment_t_start(self, block_index, seg_index):
        # segment start/stop are defined by the continuous channels
        return self._sig_timestamp0[seg_index] / self._sig_sampling_rate

    def _segment_t_stop(self, block_index, seg_index):
        return (self._sig_timestamp0[seg_index] + self._sig_length[seg_index])\
            / self._sig_sampling_rate

    def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
        return self._sig_length[seg_index]

    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        return self._sig_timestamp0[seg_index] / self._sig_sampling_rate

    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
                                channel_indexes):
        if i_start is None:
            i_start = 0
        if i_stop is None:
            i_stop = self._sig_length[seg_index]

        block_start = i_start // RECORD_SIZE
        block_stop = i_stop // RECORD_SIZE + 1
        sl0 = i_start % RECORD_SIZE
        sl1 = sl0 + (i_stop - i_start)

        if channel_indexes is None:
            channel_indexes = slice(None)
        channel_indexes = np.arange(self.header['signal_channels'].size)[channel_indexes]

        sigs_chunk = np.zeros((i_stop - i_start, len(channel_indexes)), dtype='int16')
        for i, chan_index in enumerate(channel_indexes):
            data = self._sigs_memmap[seg_index][chan_index]
            sub = data[block_start:block_stop]
            sigs_chunk[:, i] = sub['samples'].flatten()[sl0:sl1]
        return sigs_chunk
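
    # Worked example (illustrative): with RECORD_SIZE = 1024, a request for
    # i_start=1000, i_stop=3000 gives block_start=0, block_stop=3, sl0=1000 and
    # sl1=3000; the 3 records are flattened to 3072 samples and the slice
    # [1000:3000] yields the 2000 requested samples.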

    def _get_spike_slice(self, seg_index, unit_index, t_start, t_stop):
        name, sorted_id = self.header['unit_channels'][unit_index]['name'].split('#')
        sorted_id = int(sorted_id)
        data_spike = self._spikes_memmap[seg_index][name]

        if t_start is None:
            t_start = self._segment_t_start(0, seg_index)
        if t_stop is None:
            t_stop = self._segment_t_stop(0, seg_index)
        ts0 = int(t_start * self._spike_sampling_rate)
        ts1 = int(t_stop * self._spike_sampling_rate)

        ts = data_spike['timestamp']
        keep = (data_spike['sorted_id'] == sorted_id) & (ts >= ts0) & (ts <= ts1)
        return data_spike, keep

    def _spike_count(self, block_index, seg_index, unit_index):
        data_spike, keep = self._get_spike_slice(seg_index, unit_index, None, None)
        return np.sum(keep)

    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        data_spike, keep = self._get_spike_slice(seg_index, unit_index, t_start, t_stop)
        return data_spike['timestamp'][keep]

    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        spike_times = spike_timestamps.astype(dtype) / self._spike_sampling_rate
        return spike_times

    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        data_spike, keep = self._get_spike_slice(seg_index, unit_index, t_start, t_stop)
        nb_chan = data_spike[0]['nb_channel']
        nb = np.sum(keep)
        waveforms = data_spike[keep]['samples'].flatten()
        waveforms = waveforms.reshape(nb, nb_chan, -1)
        return waveforms

    def _event_count(self, block_index, seg_index, event_channel_index):
        # assert event_channel_index == 0
        return self._events_memmap[seg_index].size

    def _get_event_timestamps(self, block_index, seg_index, event_channel_index,
                              t_start, t_stop):
        # assert event_channel_index == 0
        if t_start is None:
            t_start = self._segment_t_start(block_index, seg_index)
        if t_stop is None:
            t_stop = self._segment_t_stop(block_index, seg_index)
        ts0 = int(t_start * self._event_sampling_rate)
        ts1 = int(t_stop * self._event_sampling_rate)
        ts = self._events_memmap[seg_index]['timestamp']
        keep = (ts >= ts0) & (ts <= ts1)

        subdata = self._events_memmap[seg_index][keep]
        timestamps = subdata['timestamp']
        # the file format does not define an explicit label, so build one as a
        # combination of event_type, processor_id and chan_id
        labels = np.array(['{}#{}#{}'.format(int(d['event_type']),
                                             int(d['processor_id']),
                                             int(d['chan_id']))
                           for d in subdata],
                          dtype='U')
        durations = None
        return timestamps, durations, labels

    def _rescale_event_timestamp(self, event_timestamps, dtype):
        event_times = event_timestamps.astype(dtype) / self._event_sampling_rate
        return event_times

    def _rescale_epoch_duration(self, raw_duration, dtype):
        return None


continuous_dtype = [('timestamp', 'int64'), ('nb_sample', 'uint16'),
                    ('rec_num', 'uint16'), ('samples', '>i2', RECORD_SIZE),
                    ('markers', 'uint8', 10)]
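# Each continuous record is thus 8 + 2 + 2 + 2 * RECORD_SIZE + 10 = 2070 bytes:
# an int64 timestamp, a uint16 sample count, a uint16 recording number,
# RECORD_SIZE big-endian int16 samples and 10 marker bytes.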

events_dtype = [('timestamp', 'int64'), ('sample_pos', 'int16'),
                ('event_type', 'uint8'), ('processor_id', 'uint8'),
                ('event_id', 'uint8'), ('chan_id', 'uint8'),
                ('record_num', 'uint16')]
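# Each event record is thus 8 + 2 + 1 + 1 + 1 + 1 + 2 = 16 bytes.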

# the dtype is dynamic and depends on nb_channel and nb_sample
_base_spikes_dtype = [('event_type', 'uint8'), ('timestamp', 'int64'),
                      ('software_timestamp', 'int64'), ('source_id', 'uint16'),
                      ('nb_channel', 'uint16'), ('nb_sample', 'uint16'),
                      ('sorted_id', 'uint16'), ('electrode_id', 'uint16'),
                      ('within_chan_index', 'uint16'), ('color', 'uint8', 3),
                      ('pca', 'float32', 2), ('sampling_rate', 'uint16'),
                      ('samples', 'uint16', None), ('gains', 'float32', None),
                      ('thresholds', 'uint16', None), ('rec_num', 'uint16')]
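# The fixed prefix before 'nb_channel' is 1 + 8 + 8 + 2 = 19 bytes
# (event_type, timestamp, software_timestamp, source_id), which is why
# make_spikes_dtype() below seeks to HEADER_SIZE + 19 to read N and M.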


def make_spikes_dtype(filename):
    """
    Given a spike file, make the appropriate dtype, which depends on:
      * N - number of channels
      * M - samples per spike
    See the documentation of the file format.
    """
    # Strangely, the header does not contain the sample size,
    # so this does not work (too bad):
    #   spike_info = read_file_header(filename)
    #   N = spike_info['num_channels']
    #   M = ????
    # Instead we need to read the very first spike,
    # but this fails when there are 0 spikes (too bad).
    filesize = os.stat(filename).st_size
    if filesize >= (HEADER_SIZE + 23):
        with open(filename, mode='rb') as f:
            # N and M are at HEADER_SIZE + 19 bytes
            f.seek(HEADER_SIZE + 19)
            N = np.fromfile(f, np.dtype('<u2'), 1)[0]
            M = np.fromfile(f, np.dtype('<u2'), 1)[0]
    else:
        spike_info = read_file_header(filename)
        N = spike_info['num_channels']
        M = 40  # this is in the original code from openephys

    # make a copy
    spikes_dtype = [e for e in _base_spikes_dtype]
    spikes_dtype[12] = ('samples', 'uint16', N * M)
    spikes_dtype[13] = ('gains', 'float32', N)
    spikes_dtype[14] = ('thresholds', 'uint16', N)
    return spikes_dtype
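
# Illustrative example: for a tetrode file with N=4 channels and M=40 samples
# per spike, make_spikes_dtype() yields ('samples', 'uint16', 160),
# ('gains', 'float32', 4) and ('thresholds', 'uint16', 4).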


def explore_folder(dirname):
    """
    This explores a folder and dispatches continuous, event and spike
    files by segment (aka recording session).
    The number of segments is determined with these rules:
      "100_CH0.continuous"   ---> seg_index 0
      "100_CH0_2.continuous" ---> seg_index 1
      "100_CH0_N.continuous" ---> seg_index N-1
    """
    filenames = os.listdir(dirname)
    filenames.sort()

    info = {}
    info['nb_segment'] = 0
    info['continuous'] = {}
    info['spikes'] = {}
    for filename in filenames:
        if filename.endswith('.continuous'):
            s = filename.replace('.continuous', '').split('_')
            if len(s) == 2:
                seg_index = 0
            else:
                seg_index = int(s[2]) - 1
            if seg_index not in info['continuous'].keys():
                info['continuous'][seg_index] = []
            info['continuous'][seg_index].append(filename)
            if (seg_index + 1) > info['nb_segment']:
                info['nb_segment'] += 1
        elif filename.endswith('.spikes'):
            s = filename.replace('.spikes', '').split('_')
            if len(s) == 1:
                seg_index = 0
            else:
                seg_index = int(s[1]) - 1
            if seg_index not in info['spikes'].keys():
                info['spikes'][seg_index] = []
            info['spikes'][seg_index].append(filename)
            if (seg_index + 1) > info['nb_segment']:
                info['nb_segment'] += 1

    # order continuous files by channel number within each segment
    for seg_index, continuous_filenames in info['continuous'].items():
        chan_ids = {}
        for continuous_filename in continuous_filenames:
            s = continuous_filename.replace('.continuous', '').split('_')
            processor_id, ch_name = s[0], s[1]
            chan_str = re.split(r'(\d+)', s[1])[0]
            chan_id = int(ch_name.replace(chan_str, ''))
            if chan_str in chan_ids.keys():
                chan_ids[chan_str].append(chan_id)
            else:
                chan_ids[chan_str] = [chan_id]
        order = []
        for chan_str in chan_ids.keys():
            order.append(np.argsort(chan_ids[chan_str]))
        order = [arr.tolist() for arr in order]
        # offset the ordering of each channel type (e.g. CH, AUX) so that
        # types are concatenated one after the other
        for i, sublist in enumerate(order):
            if i > 0:
                order[i] = [x + max(order[i - 1]) + 1 for x in order[i]]
        order = [item for sublist in order for item in sublist]
        continuous_filenames = [continuous_filenames[i] for i in order]
        info['continuous'][seg_index] = continuous_filenames

    # order spike files within each segment
    for seg_index, spike_filenames in info['spikes'].items():
        names = []
        for spike_filename in spike_filenames:
            name = spike_filename.replace('.spikes', '')
            if seg_index > 0:
                name = name.replace('_' + str(seg_index + 1), '')
            names.append(name)
        order = np.argsort(names)
        spike_filenames = [spike_filenames[i] for i in order]
        info['spikes'][seg_index] = spike_filenames

    return info
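
# Illustrative shape of the returned dict for a hypothetical recording that was
# stopped and restarted once (two segments):
#   {'nb_segment': 2,
#    'continuous': {0: ['100_CH1.continuous', '100_CH2.continuous'],
#                   1: ['100_CH1_2.continuous', '100_CH2_2.continuous']},
#    'spikes': {0: ['STp106.0n0.spikes'], 1: ['STp106.0n0_2.spikes']}}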


def read_file_header(filename):
    """Read header information from the first 1024 bytes of an OpenEphys file.

    See docs.
    """
    header = {}
    with open(filename, mode='rb') as f:
        # Read the data as a string.
        # Remove newlines and redundant "header." prefixes.
        # The result should be a series of "key = value" strings, separated
        # by semicolons.
        header_string = f.read(HEADER_SIZE).replace(b'\n', b'').replace(b'header.', b'')

        # Parse each "key = value" string separately
        for pair in header_string.split(b';'):
            if b'=' in pair:
                key, value = pair.split(b' = ')
                key = key.strip().decode('ascii')
                value = value.strip()

                # Convert some values to numeric
                if key in ['bitVolts', 'sampleRate']:
                    header[key] = float(value)
                elif key in ['blockLength', 'bufferSize', 'header_bytes', 'num_channels']:
                    header[key] = int(value)
                else:
                    # Keep as string
                    header[key] = value.decode('ascii')
    return header
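
# Illustrative parsed header (values are typical of an Open Ephys recording,
# not guaranteed):
#   {'format': 'Open Ephys Data Format', 'version': '0.4', 'header_bytes': 1024,
#    'blockLength': 1024, 'sampleRate': 30000.0, 'bitVolts': 0.195, ...}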