plexonrawio.py 18 KB


  1. # -*- coding: utf-8 -*-
  2. """
  3. Class for reading the old data format from Plexon
  4. acquisition system (.plx)
  5. Note that Plexon now uses a new format, PL2, which is NOT
  6. supported by this IO.
  7. Compatible with versions 100 to 106.
  8. Other versions have not been tested.
  9. This IO is developed thanks to the header file downloadable from:
  10. http://www.plexon.com/software-downloads
  11. This IO was rewritten in 2017 and this was a huge pain because
  12. the underlying file format is really inefficient.
  13. The rewrite is now based on numpy dtype and not on Python struct.
  14. This should be faster.
  15. If one day somebody uses it, consider offering me a beer.
  16. Author: Samuel Garcia
  17. """
  18. from __future__ import print_function, division, absolute_import
  19. # from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3
  20. from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
  21. _event_channel_dtype)
  22. import numpy as np
  23. from collections import OrderedDict
  24. import datetime
  25. class PlexonRawIO(BaseRawIO):
  26. extensions = ['plx']
  27. rawmode = 'one-file'
  28. def __init__(self, filename=''):
  29. BaseRawIO.__init__(self)
  30. self.filename = filename
  31. def _source_name(self):
  32. return self.filename
  33. def _parse_header(self):
  34. # global header
  35. with open(self.filename, 'rb') as fid:
  36. offset0 = 0
  37. global_header = read_as_dict(fid, GlobalHeader, offset=offset0)
  38. rec_datetime = datetime.datetime(global_header['Year'],
  39. global_header['Month'],
  40. global_header['Day'],
  41. global_header['Hour'],
  42. global_header['Minute'],
  43. global_header['Second'])
  44. # dsp channels header = spikes and waveforms
  45. nb_unit_chan = global_header['NumDSPChannels']
  46. offset1 = np.dtype(GlobalHeader).itemsize
  47. dspChannelHeaders = np.memmap(self.filename, dtype=DspChannelHeader, mode='r',
  48. offset=offset1, shape=(nb_unit_chan,))
  49. # event channel header
  50. nb_event_chan = global_header['NumEventChannels']
  51. offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan
  52. eventHeaders = np.memmap(self.filename, dtype=EventChannelHeader, mode='r',
  53. offset=offset2, shape=(nb_event_chan,))
  54. # slow channel header = signal
  55. nb_sig_chan = global_header['NumSlowChannels']
  56. offset3 = offset2 + np.dtype(EventChannelHeader).itemsize * nb_event_chan
  57. slowChannelHeaders = np.memmap(self.filename, dtype=SlowChannelHeader, mode='r',
  58. offset=offset3, shape=(nb_sig_chan,))
  59. offset4 = offset3 + np.dtype(SlowChannelHeader).itemsize * nb_sig_chan
  60. # loop over data blocks and put them by type and channel
  61. block_headers = {1: {c: [] for c in dspChannelHeaders['Channel']},
  62. 4: {c: [] for c in eventHeaders['Channel']},
  63. 5: {c: [] for c in slowChannelHeaders['Channel']},
  64. }
  65. block_pos = {1: {c: [] for c in dspChannelHeaders['Channel']},
  66. 4: {c: [] for c in eventHeaders['Channel']},
  67. 5: {c: [] for c in slowChannelHeaders['Channel']},
  68. }
  69. data = self._memmap = np.memmap(self.filename, dtype='u1', offset=0, mode='r')
  70. pos = offset4
  71. while pos < data.size:
  72. bl_header = data[pos:pos + 16].view(DataBlockHeader)[0]
  73. length = bl_header['NumberOfWaveforms'] * bl_header['NumberOfWordsInWaveform'] * 2 + 16
  74. bl_type = int(bl_header['Type'])
  75. chan_id = int(bl_header['Channel'])
  76. block_headers[bl_type][chan_id].append(bl_header)
  77. block_pos[bl_type][chan_id].append(pos)
  78. pos += length
  79. self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
  80. 2 ** 32 + bl_header['TimeStamp']
  81. # ... and finalize them in self._data_blocks
  82. # for a faster acces depending on type (1, 4, 5)
  83. self._data_blocks = {}
  84. dt_base = [('pos', 'int64'), ('timestamp', 'int64'), ('size', 'int64')]
  85. dtype_by_bltype = {
  86. # Spikes and waveforms
  87. 1: np.dtype(dt_base + [('unit_id', 'uint16'), ('n1', 'uint16'), ('n2', 'uint16'), ]),
  88. # Events
  89. 4: np.dtype(dt_base + [('label', 'uint16'), ]),
  90. # Signals
  91. 5: np.dtype(dt_base + [('cumsum', 'int64'), ]),
  92. }
  93. for bl_type in block_headers:
  94. self._data_blocks[bl_type] = {}
  95. for chan_id in block_headers[bl_type]:
  96. bl_header = np.array(block_headers[bl_type][chan_id], dtype=DataBlockHeader)
  97. bl_pos = np.array(block_pos[bl_type][chan_id], dtype='int64')
  98. timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
  99. 2 ** 32 + bl_header['TimeStamp']
  100. n1 = bl_header['NumberOfWaveforms']
  101. n2 = bl_header['NumberOfWordsInWaveform']
  102. dt = dtype_by_bltype[bl_type]
  103. data_block = np.empty(bl_pos.size, dtype=dt)
  104. data_block['pos'] = bl_pos + 16
  105. data_block['timestamp'] = timestamps
  106. data_block['size'] = n1 * n2 * 2
  107. if bl_type == 1: # Spikes and waveforms
  108. data_block['unit_id'] = bl_header['Unit']
  109. data_block['n1'] = n1
  110. data_block['n2'] = n2
  111. elif bl_type == 4: # Events
  112. data_block['label'] = bl_header['Unit']
  113. elif bl_type == 5: # Signals
  114. if data_block.size > 0:
  115. # cumulative some of sample index for fast acces to chunks
  116. data_block['cumsum'][0] = 0
  117. data_block['cumsum'][1:] = np.cumsum(data_block['size'][:-1]) // 2
  118. self._data_blocks[bl_type][chan_id] = data_block
  119. # signals channels
  120. sig_channels = []
  121. all_sig_length = []
  122. for chan_index in range(nb_sig_chan):
  123. h = slowChannelHeaders[chan_index]
  124. name = h['Name'].decode('utf8')
  125. chan_id = h['Channel']
  126. length = self._data_blocks[5][chan_id]['size'].sum() // 2
  127. if length == 0:
  128. continue # channel not added
  129. all_sig_length.append(length)
  130. sampling_rate = float(h['ADFreq'])
  131. sig_dtype = 'int16'
  132. units = '' # I dont't knwon units
  133. if global_header['Version'] in [100, 101]:
  134. gain = 5000. / (2048 * h['Gain'] * 1000.)
  135. elif global_header['Version'] in [102]:
  136. gain = 5000. / (2048 * h['Gain'] * h['PreampGain'])
  137. elif global_header['Version'] >= 103:
  138. gain = global_header['SlowMaxMagnitudeMV'] / (
  139. .5 * (2 ** global_header['BitsPerSpikeSample']) *
  140. h['Gain'] * h['PreampGain'])
  141. offset = 0.
  142. group_id = 0
  143. sig_channels.append((name, chan_id, sampling_rate, sig_dtype,
  144. units, gain, offset, group_id))
  145. if len(all_sig_length) > 0:
  146. self._signal_length = min(all_sig_length)
  147. sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
  148. self._global_ssampling_rate = global_header['ADFrequency']
  149. if slowChannelHeaders.size > 0:
  150. assert np.unique(slowChannelHeaders['ADFreq']
  151. ).size == 1, 'Signal do not have the same sampling rate'
  152. self._sig_sampling_rate = float(slowChannelHeaders['ADFreq'][0])
  153. # Determine number of units per channels
  154. self.internal_unit_ids = []
  155. for chan_id, data_clock in self._data_blocks[1].items():
  156. unit_ids = np.unique(data_clock['unit_id'])
  157. for unit_id in unit_ids:
  158. self.internal_unit_ids.append((chan_id, unit_id))
  159. # Spikes channels
  160. unit_channels = []
  161. for unit_index, (chan_id, unit_id) in enumerate(self.internal_unit_ids):
  162. c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0]
  163. h = dspChannelHeaders[c]
  164. name = h['Name'].decode('utf8')
  165. _id = 'ch{}#{}'.format(chan_id, unit_id)
  166. wf_units = ''
  167. if global_header['Version'] < 103:
  168. wf_gain = 3000. / (2048 * h['Gain'] * 1000.)
  169. elif 103 <= global_header['Version'] < 105:
  170. wf_gain = global_header['SpikeMaxMagnitudeMV'] / (
  171. .5 * 2. ** (global_header['BitsPerSpikeSample']) *
  172. h['Gain'] * 1000.)
  173. elif global_header['Version'] >= 105:
  174. wf_gain = global_header['SpikeMaxMagnitudeMV'] / (
  175. .5 * 2. ** (global_header['BitsPerSpikeSample']) *
  176. h['Gain'] * global_header['SpikePreAmpGain'])
  177. wf_offset = 0.
  178. wf_left_sweep = -1 # DONT KNOWN
  179. wf_sampling_rate = global_header['WaveformFreq']
  180. unit_channels.append((name, _id, wf_units, wf_gain, wf_offset,
  181. wf_left_sweep, wf_sampling_rate))
  182. unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
  183. # Event channels
  184. event_channels = []
  185. for chan_index in range(nb_event_chan):
  186. h = eventHeaders[chan_index]
  187. chan_id = h['Channel']
  188. name = h['Name'].decode('utf8')
  189. _id = h['Channel']
  190. event_channels.append((name, _id, 'event'))
  191. event_channels = np.array(event_channels, dtype=_event_channel_dtype)
  192. # fille into header dict
  193. self.header = {}
  194. self.header['nb_block'] = 1
  195. self.header['nb_segment'] = [1]
  196. self.header['signal_channels'] = sig_channels
  197. self.header['unit_channels'] = unit_channels
  198. self.header['event_channels'] = event_channels
  199. # Annotations
  200. self._generate_minimal_annotations()
  201. bl_annotations = self.raw_annotations['blocks'][0]
  202. seg_annotations = bl_annotations['segments'][0]
  203. for d in (bl_annotations, seg_annotations):
  204. d['rec_datetime'] = rec_datetime
  205. d['plexon_version'] = global_header['Version']
  206. def _segment_t_start(self, block_index, seg_index):
  207. return 0.
  208. def _segment_t_stop(self, block_index, seg_index):
  209. t_stop1 = float(self._last_timestamps) / self._global_ssampling_rate
  210. if hasattr(self, '_signal_length'):
  211. t_stop2 = self._signal_length / self._sig_sampling_rate
  212. return max(t_stop1, t_stop2)
  213. else:
  214. return t_stop1
  215. def _get_signal_size(self, block_index, seg_index, channel_indexes):
  216. return self._signal_length
  217. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  218. return 0.
  219. def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
  220. if i_start is None:
  221. i_start = 0
  222. if i_stop is None:
  223. i_stop = self._signal_length
  224. if channel_indexes is None:
  225. channel_indexes = np.arange(self.header['signal_channels'].size)
  226. raw_signals = np.zeros((i_stop - i_start, len(channel_indexes)), dtype='int16')
  227. for c, channel_index in enumerate(channel_indexes):
  228. chan_header = self.header['signal_channels'][channel_index]
  229. chan_id = chan_header['id']
  230. data_blocks = self._data_blocks[5][chan_id]
  231. # loop over data blocks and get chunks
  232. bl0 = np.searchsorted(data_blocks['cumsum'], i_start, side='left')
  233. bl1 = np.searchsorted(data_blocks['cumsum'], i_stop, side='left')
  234. ind = 0
  235. for bl in range(bl0, bl1):
  236. ind0 = data_blocks[bl]['pos']
  237. ind1 = data_blocks[bl]['size'] + ind0
  238. data = self._memmap[ind0:ind1].view('int16')
  239. if bl == bl1 - 1:
  240. # right border
  241. # be carfull that bl could be both bl0 and bl1!!
  242. border = data.size - (i_stop - data_blocks[bl]['cumsum'])
  243. data = data[:-border]
  244. if bl == bl0:
  245. # left border
  246. border = i_start - data_blocks[bl]['cumsum']
  247. data = data[border:]
  248. raw_signals[ind:data.size + ind, c] = data
  249. ind += data.size
  250. return raw_signals
  251. def _get_internal_mask(self, data_block, t_start, t_stop):
  252. timestamps = data_block['timestamp']
  253. if t_start is None:
  254. lim0 = 0
  255. else:
  256. lim0 = int(t_start * self._global_ssampling_rate)
  257. if t_stop is None:
  258. lim1 = self._last_timestamps
  259. else:
  260. lim1 = int(t_stop * self._global_ssampling_rate)
  261. keep = (timestamps >= lim0) & (timestamps <= lim1)
  262. return keep
  263. def _spike_count(self, block_index, seg_index, unit_index):
  264. chan_id, unit_id = self.internal_unit_ids[unit_index]
  265. data_block = self._data_blocks[1][chan_id]
  266. nb_spike = np.sum(data_block['unit_id'] == unit_id)
  267. return nb_spike
  268. def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
  269. chan_id, unit_id = self.internal_unit_ids[unit_index]
  270. data_block = self._data_blocks[1][chan_id]
  271. keep = self._get_internal_mask(data_block, t_start, t_stop)
  272. keep &= data_block['unit_id'] == unit_id
  273. spike_timestamps = data_block[keep]['timestamp']
  274. return spike_timestamps
  275. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  276. spike_times = spike_timestamps.astype(dtype)
  277. spike_times /= self._global_ssampling_rate
  278. return spike_times
  279. def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
  280. chan_id, unit_id = self.internal_unit_ids[unit_index]
  281. data_block = self._data_blocks[1][chan_id]
  282. n1 = data_block['n1'][0]
  283. n2 = data_block['n2'][0]
  284. keep = self._get_internal_mask(data_block, t_start, t_stop)
  285. keep &= data_block['unit_id'] == unit_id
  286. data_block = data_block[keep]
  287. nb_spike = data_block.size
  288. waveforms = np.zeros((nb_spike, n1, n2), dtype='int16')
  289. for i, db in enumerate(data_block):
  290. ind0 = db['pos']
  291. ind1 = db['size'] + ind0
  292. data = self._memmap[ind0:ind1].view('int16').reshape(n1, n2)
  293. waveforms[i, :, :] = data
  294. return waveforms
  295. def _event_count(self, block_index, seg_index, event_channel_index):
  296. chan_id = int(self.header['event_channels'][event_channel_index]['id'])
  297. nb_event = self._data_blocks[4][chan_id].size
  298. return nb_event
  299. def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
  300. chan_id = int(self.header['event_channels'][event_channel_index]['id'])
  301. data_block = self._data_blocks[4][chan_id]
  302. keep = self._get_internal_mask(data_block, t_start, t_stop)
  303. db = data_block[keep]
  304. timestamps = db['timestamp']
  305. labels = db['label'].astype('U')
  306. durations = None
  307. return timestamps, durations, labels
  308. def _rescale_event_timestamp(self, event_timestamps, dtype):
  309. event_times = event_timestamps.astype(dtype)
  310. event_times /= self._global_ssampling_rate
  311. return event_times
  312. def read_as_dict(fid, dtype, offset=None):
  313. """
  314. Given a file descriptor
  315. and a numpy.dtype of the binary struct return a dict.
  316. Make conversion for strings.
  317. """
  318. if offset is not None:
  319. fid.seek(offset)
  320. dt = np.dtype(dtype)
  321. h = np.frombuffer(fid.read(dt.itemsize), dt)[0]
  322. info = OrderedDict()
  323. for k in dt.names:
  324. v = h[k]
  325. if dt[k].kind == 'S':
  326. v = v.decode('utf8')
  327. v = v.replace('\x03', '')
  328. v = v.replace('\x00', '')
  329. info[k] = v
  330. return info
  331. GlobalHeader = [
  332. ('MagicNumber', 'uint32'),
  333. ('Version', 'int32'),
  334. ('Comment', 'S128'),
  335. ('ADFrequency', 'int32'),
  336. ('NumDSPChannels', 'int32'),
  337. ('NumEventChannels', 'int32'),
  338. ('NumSlowChannels', 'int32'),
  339. ('NumPointsWave', 'int32'),
  340. ('NumPointsPreThr', 'int32'),
  341. ('Year', 'int32'),
  342. ('Month', 'int32'),
  343. ('Day', 'int32'),
  344. ('Hour', 'int32'),
  345. ('Minute', 'int32'),
  346. ('Second', 'int32'),
  347. ('FastRead', 'int32'),
  348. ('WaveformFreq', 'int32'),
  349. ('LastTimestamp', 'float64'),
  350. # version >103
  351. ('Trodalness', 'uint8'),
  352. ('DataTrodalness', 'uint8'),
  353. ('BitsPerSpikeSample', 'uint8'),
  354. ('BitsPerSlowSample', 'uint8'),
  355. ('SpikeMaxMagnitudeMV', 'uint16'),
  356. ('SlowMaxMagnitudeMV', 'uint16'),
  357. # version 105
  358. ('SpikePreAmpGain', 'uint16'),
  359. # version 106
  360. ('AcquiringSoftware', 'S18'),
  361. ('ProcessingSoftware', 'S18'),
  362. ('Padding', 'S10'),
  363. # all version
  364. ('TSCounts', 'int32', (650,)),
  365. ('WFCounts', 'int32', (650,)),
  366. ('EVCounts', 'int32', (512,)),
  367. ]
  368. DspChannelHeader = [
  369. ('Name', 'S32'),
  370. ('SIGName', 'S32'),
  371. ('Channel', 'int32'),
  372. ('WFRate', 'int32'),
  373. ('SIG', 'int32'),
  374. ('Ref', 'int32'),
  375. ('Gain', 'int32'),
  376. ('Filter', 'int32'),
  377. ('Threshold', 'int32'),
  378. ('Method', 'int32'),
  379. ('NUnits', 'int32'),
  380. ('Template', 'uint16', (320,)),
  381. ('Fit', 'int32', (5,)),
  382. ('SortWidth', 'int32'),
  383. ('Boxes', 'uint16', (40,)),
  384. ('SortBeg', 'int32'),
  385. # version 105
  386. ('Comment', 'S128'),
  387. # version 106
  388. ('SrcId', 'uint8'),
  389. ('reserved', 'uint8'),
  390. ('ChanId', 'uint16'),
  391. ('Padding', 'int32', (10,)),
  392. ]
  393. EventChannelHeader = [
  394. ('Name', 'S32'),
  395. ('Channel', 'int32'),
  396. # version 105
  397. ('Comment', 'S128'),
  398. # version 106
  399. ('SrcId', 'uint8'),
  400. ('reserved', 'uint8'),
  401. ('ChanId', 'uint16'),
  402. ('Padding', 'int32', (32,)),
  403. ]
  404. SlowChannelHeader = [
  405. ('Name', 'S32'),
  406. ('Channel', 'int32'),
  407. ('ADFreq', 'int32'),
  408. ('Gain', 'int32'),
  409. ('Enabled', 'int32'),
  410. ('PreampGain', 'int32'),
  411. # version 104
  412. ('SpikeChannel', 'int32'),
  413. # version 105
  414. ('Comment', 'S128'),
  415. # version 106
  416. ('SrcId', 'uint8'),
  417. ('reserved', 'uint8'),
  418. ('ChanId', 'uint16'),
  419. ('Padding', 'int32', (27,)),
  420. ]
  421. DataBlockHeader = [
  422. ('Type', 'uint16'),
  423. ('UpperByteOf5ByteTimestamp', 'uint16'),
  424. ('TimeStamp', 'int32'),
  425. ('Channel', 'uint16'),
  426. ('Unit', 'uint16'),
  427. ('NumberOfWaveforms', 'uint16'),
  428. ('NumberOfWordsInWaveform', 'uint16'),
  429. ] # 16 bytes