basefromrawio.py

# -*- coding: utf-8 -*-
"""
BaseFromRaw
===========

BaseFromRaw implements a bridge between the new neo.rawio API
and the legacy neo.io API that returns neo.core objects.

The neo.rawio API is more restricted and limited and does not cover tricky
cases with asymmetrical trees of neo objects.
But if a format is implemented in neo.rawio, the corresponding neo.io comes
for free by inheriting from this class.
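
Typical concrete IOs combine a RawIO class with BaseFromRaw through multiple
inheritance, roughly like this (ExampleRawIO/ExampleIO are illustrative names;
see neo.io.exampleio for the real pattern)::

    class ExampleIO(ExampleRawIO, BaseFromRaw):
        _prefered_signal_group_mode = 'group-by-same-units'

        def __init__(self, filename=''):
            ExampleRawIO.__init__(self, filename=filename)
            BaseFromRaw.__init__(self, filename)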
  11. """
# needed for python 3 compatibility
from __future__ import print_function, division, absolute_import
# from __future__ import unicode_literals is not compatible with numpy.dtype both py2 py3

import warnings
import collections
import logging

import numpy as np

from neo import logging_handler
from neo.core import (AnalogSignal, Block,
                      Epoch, Event,
                      IrregularlySampledSignal,
                      ChannelIndex,
                      Segment, SpikeTrain, Unit)
from neo.io.baseio import BaseIO

import quantities as pq


class BaseFromRaw(BaseIO):
    """
    This implements a generic reader on top of a RawIO reader.

    Arguments depend on `mode` (dir or file).

    File case::

        reader = BlackRockIO(filename='FileSpec2.3001.nev')

    Dir case::

        reader = NeuralynxIO(dirname='Cheetah_v5.7.4/original_data')

    Other arguments are IO specific.
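
    Once a reader is constructed, neo objects come from the generic read
    methods implemented here. A rough sketch, reusing the dir-case reader
    above (argument values are examples only)::

        block = reader.read_block(signal_group_mode='group-by-same-units')
        for seg in block.segments:
            print(seg.analogsignals, seg.spiketrains)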
  36. """

    is_readable = True
    is_writable = False

    supported_objects = [Block, Segment, AnalogSignal,
                         SpikeTrain, Unit, ChannelIndex, Event, Epoch]
    readable_objects = [Block, Segment]
    writeable_objects = []

    support_lazy = True

    name = 'BaseIO'
    description = ''
    extensions = []
    mode = 'file'

    _prefered_signal_group_mode = 'split-all'  # 'group-by-same-units'
    _prefered_units_group_mode = 'split-all'  # 'all-in-one'

    def __init__(self, *args, **kargs):
        BaseIO.__init__(self, *args, **kargs)
        self.parse_header()

    def read_block(self, block_index=0, lazy=False, signal_group_mode=None,
                   units_group_mode=None, load_waveforms=False):
        """
        :param block_index: int default 0. In case of several blocks, block_index can be specified.
        :param lazy: False by default.
        :param signal_group_mode: 'split-all' or 'group-by-same-units' (default depends on the IO):
            This controls the behavior for grouping channels in AnalogSignal.
            * 'split-all': each channel will give an AnalogSignal
            * 'group-by-same-units': all channels sharing the same quantity units are grouped into
              a 2D AnalogSignal
        :param units_group_mode: 'split-all' or 'all-in-one' (default depends on the IO)
            This controls the behavior for grouping Units in ChannelIndex:
            * 'split-all': each neo.Unit is assigned to a new neo.ChannelIndex
            * 'all-in-one': all neo.Unit are grouped in the same neo.ChannelIndex
              (global spike sorting for instance)
        :param load_waveforms: False by default. Controls whether SpikeTrain.waveforms is None or not.
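
        Illustrative call (``reader`` stands for any BaseFromRaw subclass
        instance; parameter values are examples only)::

            bl = reader.read_block(block_index=0,
                                   signal_group_mode='group-by-same-units',
                                   units_group_mode='all-in-one',
                                   load_waveforms=True)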
  69. """
        if lazy:
            warnings.warn(
                "Lazy is deprecated and will be replaced by ProxyObject functionality.",
                DeprecationWarning)

        if signal_group_mode is None:
            signal_group_mode = self._prefered_signal_group_mode

        if units_group_mode is None:
            units_group_mode = self._prefered_units_group_mode

        # annotations
        bl_annotations = dict(self.raw_annotations['blocks'][block_index])
        bl_annotations.pop('segments')
        bl_annotations = check_annotations(bl_annotations)

        bl = Block(**bl_annotations)

        # ChannelIndex are split in 2 parts:
        #  * some for AnalogSignals
        #  * some for Units

        # ChannelIndex for AnalogSignals
        all_channels = self.header['signal_channels']
        channel_indexes_list = self.get_group_channel_indexes()
        for channel_index in channel_indexes_list:
            for i, (ind_within, ind_abs) in self._make_signal_channel_subgroups(
                    channel_index, signal_group_mode=signal_group_mode).items():
                chidx_annotations = {}
                if signal_group_mode == "split-all":
                    chidx_annotations = self.raw_annotations['signal_channels'][i]
                elif signal_group_mode == "group-by-same-units":
                    for key in list(self.raw_annotations['signal_channels'][i].keys()):
                        chidx_annotations[key] = []
                    for j in ind_abs:
                        for key in list(self.raw_annotations['signal_channels'][i].keys()):
                            chidx_annotations[key].append(
                                self.raw_annotations['signal_channels'][j][key])
                if 'name' in list(chidx_annotations.keys()):
                    chidx_annotations.pop('name')
                chidx_annotations = check_annotations(chidx_annotations)
                ch_names = all_channels[ind_abs]['name'].astype('S')
                neo_channel_index = ChannelIndex(index=ind_within,
                                                 channel_names=ch_names,
                                                 channel_ids=all_channels[ind_abs]['id'],
                                                 name='Channel group {}'.format(i),
                                                 **chidx_annotations)
                bl.channel_indexes.append(neo_channel_index)

        # ChannelIndex and Unit
        # 2 cases are possible in neo; different IOs have chosen one or the other:
        #  * All units are grouped in the same ChannelIndex and indexes are all channels:
        #    'all-in-one'
        #  * Each unit is assigned to one ChannelIndex: 'split-all'
        # This is kept for compatibility
        unit_channels = self.header['unit_channels']
        if units_group_mode == 'all-in-one':
            if unit_channels.size > 0:
                channel_index = ChannelIndex(index=np.array([], dtype='i'),
                                             name='ChannelIndex for all Unit')
                bl.channel_indexes.append(channel_index)
            for c in range(unit_channels.size):
                unit_annotations = self.raw_annotations['unit_channels'][c]
                unit_annotations = check_annotations(unit_annotations)
                unit = Unit(**unit_annotations)
                channel_index.units.append(unit)
        elif units_group_mode == 'split-all':
            for c in range(len(unit_channels)):
                unit_annotations = self.raw_annotations['unit_channels'][c]
                unit_annotations = check_annotations(unit_annotations)
                unit = Unit(**unit_annotations)
                channel_index = ChannelIndex(index=np.array([], dtype='i'),
                                             name='ChannelIndex for Unit')
                channel_index.units.append(unit)
                bl.channel_indexes.append(channel_index)

        # Read all segments
        for seg_index in range(self.segment_count(block_index)):
            seg = self.read_segment(block_index=block_index, seg_index=seg_index,
                                    lazy=lazy, signal_group_mode=signal_group_mode,
                                    load_waveforms=load_waveforms)
            bl.segments.append(seg)

        # create links to the other containers ChannelIndex and Unit
        for seg in bl.segments:
            for c, anasig in enumerate(seg.analogsignals):
                bl.channel_indexes[c].analogsignals.append(anasig)

            nsig = len(seg.analogsignals)
            for c, sptr in enumerate(seg.spiketrains):
                if units_group_mode == 'all-in-one':
                    bl.channel_indexes[nsig].units[c].spiketrains.append(sptr)
                elif units_group_mode == 'split-all':
                    bl.channel_indexes[nsig + c].units[0].spiketrains.append(sptr)

        bl.create_many_to_one_relationship()

        return bl

    def read_segment(self, block_index=0, seg_index=0, lazy=False,
                     signal_group_mode=None, load_waveforms=False, time_slice=None):
        """
        :param block_index: int default 0. In case of several blocks, block_index can be specified.
        :param seg_index: int default 0. Index of the segment.
        :param lazy: False by default.
        :param signal_group_mode: 'split-all' or 'group-by-same-units' (default depends on the IO):
            This controls the behavior for grouping channels in AnalogSignal.
            * 'split-all': each channel will give an AnalogSignal
            * 'group-by-same-units': all channels sharing the same quantity units are grouped into
              a 2D AnalogSignal
        :param load_waveforms: False by default. Controls whether SpikeTrain.waveforms is None or not.
        :param time_slice: None by default means no limit.
            A time slice is (t_start, t_stop); both are quantities.
            All objects (AnalogSignal, SpikeTrain, Event, Epoch) will be loaded only within this slice.
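
        Illustrative call (``reader`` stands for any BaseFromRaw subclass
        instance; the slice bounds are examples only)::

            import quantities as pq
            seg = reader.read_segment(block_index=0, seg_index=0,
                                      time_slice=(10. * pq.s, 20. * pq.s))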
  171. """
        if lazy:
            warnings.warn(
                "Lazy is deprecated and will be replaced by ProxyObject functionality.",
                DeprecationWarning)

        if signal_group_mode is None:
            signal_group_mode = self._prefered_signal_group_mode

        # annotations
        seg_annotations = dict(self.raw_annotations['blocks'][block_index]['segments'][seg_index])
        for k in ('signals', 'units', 'events'):
            seg_annotations.pop(k)
        seg_annotations = check_annotations(seg_annotations)

        seg = Segment(index=seg_index, **seg_annotations)

        seg_t_start = self.segment_t_start(block_index, seg_index) * pq.s
        seg_t_stop = self.segment_t_stop(block_index, seg_index) * pq.s

        # get only a slice of objects limited by t_start and t_stop: time_slice = (t_start, t_stop)
        if time_slice is None:
            t_start, t_stop = None, None
            t_start_, t_stop_ = None, None
        else:
            assert not lazy, 'time slice only works when not lazy'
            t_start, t_stop = time_slice

            t_start = ensure_second(t_start)
            t_stop = ensure_second(t_stop)

            # check limits
            if t_start < seg_t_start:
                t_start = seg_t_start
            if t_stop > seg_t_stop:
                t_stop = seg_t_stop

            # in float format in seconds (for rawio clip)
            t_start_, t_stop_ = float(t_start.magnitude), float(t_stop.magnitude)

            # new spiketrain limits
            seg_t_start = t_start
            seg_t_stop = t_stop

        # AnalogSignal
        signal_channels = self.header['signal_channels']

        if signal_channels.size > 0:
            channel_indexes_list = self.get_group_channel_indexes()
            for channel_indexes in channel_indexes_list:
                sr = self.get_signal_sampling_rate(channel_indexes) * pq.Hz
                sig_t_start = self.get_signal_t_start(
                    block_index, seg_index, channel_indexes) * pq.s
                sig_size = self.get_signal_size(block_index=block_index, seg_index=seg_index,
                                                channel_indexes=channel_indexes)
                if not lazy:
                    # in case of time_slice: get i_start, i_stop and the new sig_t_start
                    if t_stop is not None:
                        i_stop = int((t_stop - sig_t_start).magnitude * sr.magnitude)
                        if i_stop > sig_size:
                            i_stop = sig_size
                    else:
                        i_stop = None
                    if t_start is not None:
                        i_start = int((t_start - sig_t_start).magnitude * sr.magnitude)
                        if i_start < 0:
                            i_start = 0
                        sig_t_start += (i_start / sr).rescale('s')
                    else:
                        i_start = None

                    raw_signal = self.get_analogsignal_chunk(block_index=block_index,
                                                             seg_index=seg_index, i_start=i_start,
                                                             i_stop=i_stop,
                                                             channel_indexes=channel_indexes)
                    float_signal = self.rescale_signal_raw_to_float(
                        raw_signal,
                        dtype='float32',
                        channel_indexes=channel_indexes)

                for i, (ind_within, ind_abs) in self._make_signal_channel_subgroups(
                        channel_indexes,
                        signal_group_mode=signal_group_mode).items():
                    units = np.unique(signal_channels[ind_abs]['units'])
                    assert len(units) == 1
                    units = ensure_signal_units(units[0])

                    if signal_group_mode == 'split-all':
                        # in that case per-channel annotations are OK
                        chan_index = ind_abs[0]
                        d = self.raw_annotations['blocks'][block_index]['segments'][seg_index][
                            'signals'][chan_index]
                        annotations = dict(d)
                        if 'name' not in annotations:
                            annotations['name'] = signal_channels['name'][chan_index]
                    else:
                        # when channels are grouped by same units,
                        # annotations have channel_names and channel_ids arrays;
                        # this will be moved into array annotations soon
                        annotations = {}
                        annotations['name'] = 'Channel bundle ({}) '.format(
                            ','.join(signal_channels[ind_abs]['name']))
                        annotations['channel_names'] = signal_channels[ind_abs]['name']
                        annotations['channel_ids'] = signal_channels[ind_abs]['id']
                    annotations = check_annotations(annotations)

                    if lazy:
                        anasig = AnalogSignal(np.array([]), units=units, copy=False,
                                              sampling_rate=sr, t_start=sig_t_start, **annotations)
                        anasig.lazy_shape = (sig_size, len(ind_within))
                    else:
                        anasig = AnalogSignal(float_signal[:, ind_within], units=units, copy=False,
                                              sampling_rate=sr, t_start=sig_t_start, **annotations)
                    seg.analogsignals.append(anasig)

        # SpikeTrain and waveforms (optional)
        unit_channels = self.header['unit_channels']
        for unit_index in range(len(unit_channels)):
            if not lazy and load_waveforms:
                raw_waveforms = self.get_spike_raw_waveforms(block_index=block_index,
                                                             seg_index=seg_index,
                                                             unit_index=unit_index,
                                                             t_start=t_start_, t_stop=t_stop_)
                float_waveforms = self.rescale_waveforms_to_float(raw_waveforms, dtype='float32',
                                                                  unit_index=unit_index)
                wf_units = ensure_signal_units(unit_channels['wf_units'][unit_index])
                waveforms = pq.Quantity(float_waveforms, units=wf_units,
                                        dtype='float32', copy=False)
                wf_sampling_rate = unit_channels['wf_sampling_rate'][unit_index]
                wf_left_sweep = unit_channels['wf_left_sweep'][unit_index]
                if wf_left_sweep > 0:
                    wf_left_sweep = float(wf_left_sweep) / wf_sampling_rate * pq.s
                else:
                    wf_left_sweep = None
                wf_sampling_rate = wf_sampling_rate * pq.Hz
            else:
                waveforms = None
                wf_left_sweep = None
                wf_sampling_rate = None

            d = self.raw_annotations['blocks'][block_index]['segments'][seg_index]['units'][
                unit_index]
            annotations = dict(d)
            if 'name' not in annotations:
                annotations['name'] = unit_channels['name'][unit_index]
            annotations = check_annotations(annotations)

            if not lazy:
                spike_timestamp = self.get_spike_timestamps(block_index=block_index,
                                                            seg_index=seg_index,
                                                            unit_index=unit_index,
                                                            t_start=t_start_, t_stop=t_stop_)
                spike_times = self.rescale_spike_timestamp(spike_timestamp, 'float64')
                sptr = SpikeTrain(spike_times, units='s', copy=False,
                                  t_start=seg_t_start, t_stop=seg_t_stop,
                                  waveforms=waveforms, left_sweep=wf_left_sweep,
                                  sampling_rate=wf_sampling_rate, **annotations)
            else:
                nb = self.spike_count(block_index=block_index, seg_index=seg_index,
                                      unit_index=unit_index)
                sptr = SpikeTrain(np.array([]), units='s', copy=False, t_start=seg_t_start,
                                  t_stop=seg_t_stop, **annotations)
                sptr.lazy_shape = (nb,)

            seg.spiketrains.append(sptr)

        # Events/Epoch
        event_channels = self.header['event_channels']
        for chan_ind in range(len(event_channels)):
            if not lazy:
                ev_timestamp, ev_raw_durations, ev_labels = self.get_event_timestamps(
                    block_index=block_index,
                    seg_index=seg_index, event_channel_index=chan_ind,
                    t_start=t_start_, t_stop=t_stop_)
                ev_times = self.rescale_event_timestamp(ev_timestamp, 'float64') * pq.s
                if ev_raw_durations is None:
                    ev_durations = None
                else:
                    ev_durations = self.rescale_epoch_duration(ev_raw_durations, 'float64') * pq.s
                ev_labels = ev_labels.astype('S')
            else:
                nb = self.event_count(block_index=block_index, seg_index=seg_index,
                                      event_channel_index=chan_ind)
                lazy_shape = (nb,)
                ev_times = np.array([]) * pq.s
                ev_labels = np.array([], dtype='S')
                ev_durations = np.array([]) * pq.s

            d = self.raw_annotations['blocks'][block_index]['segments'][seg_index]['events'][
                chan_ind]
            annotations = dict(d)
            if 'name' not in annotations:
                annotations['name'] = event_channels['name'][chan_ind]
            annotations = check_annotations(annotations)

            if event_channels['type'][chan_ind] == b'event':
                e = Event(times=ev_times, labels=ev_labels, units='s', copy=False, **annotations)
                e.segment = seg
                seg.events.append(e)
            elif event_channels['type'][chan_ind] == b'epoch':
                e = Epoch(times=ev_times, durations=ev_durations, labels=ev_labels,
                          units='s', copy=False, **annotations)
                e.segment = seg
                seg.epochs.append(e)

            if lazy:
                e.lazy_shape = lazy_shape

        seg.create_many_to_one_relationship()
        return seg

    def _make_signal_channel_subgroups(self, channel_indexes,
                                       signal_group_mode='group-by-same-units'):
        """
        For some RawIOs, channels are already split into groups.
        But in any case, channels need to be split again into sub-groups
        because they do not all have the same units.

        They can also be split one by one to match the previous behavior of
        some IOs in older versions of neo (<=0.5).

        This method aggregates signal channels with the same units or splits them all.
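
        Sketch of the returned mapping (the groupings are illustrative)::

            {0: (ind_within_0, ind_abs_0),   # first sub-group
             1: (ind_within_1, ind_abs_1)}   # second sub-group

        where each ``ind_within`` holds positions inside the given
        ``channel_indexes`` and each ``ind_abs`` holds absolute positions in
        ``header['signal_channels']``.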
  366. """
        all_channels = self.header['signal_channels']
        if channel_indexes is None:
            channel_indexes = np.arange(all_channels.size, dtype=int)
        channels = all_channels[channel_indexes]

        groups = collections.OrderedDict()
        if signal_group_mode == 'group-by-same-units':
            all_units = np.unique(channels['units'])

            for i, unit in enumerate(all_units):
                ind_within, = np.nonzero(channels['units'] == unit)
                ind_abs = channel_indexes[ind_within]
                groups[i] = (ind_within, ind_abs)

        elif signal_group_mode == 'split-all':
            for i, chan_index in enumerate(channel_indexes):
                ind_within = [i]
                ind_abs = channel_indexes[ind_within]
                groups[i] = (ind_within, ind_abs)
        else:
            raise NotImplementedError()
        return groups


unit_convert = {'Volts': 'V', 'volts': 'V', 'Volt': 'V',
                'volt': 'V', ' Volt': 'V', 'microV': 'uV'}


def ensure_signal_units(units):
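    """
    Convert a unit string coming from a raw header into a quantity,
    falling back to dimensionless when the string cannot be parsed.

    Rough sketch (return values shown informally)::

        ensure_signal_units('Volts')   # -> pq.Quantity(1, 'V')
        ensure_signal_units('uV')      # -> pq.Quantity(1, 'uV')
        ensure_signal_units('foobar')  # -> '' (with a warning)
    """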
    # test units
    units = units.replace(' ', '')
    if units in unit_convert:
        units = unit_convert[units]
    try:
        units = pq.Quantity(1, units)
    except Exception:
        logging.warning('Units "{}" can not be converted to a quantity. Using dimensionless '
                        'instead'.format(units))
        units = ''
    return units


def check_annotations(annotations):
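    """
    Force selected annotation values to plain ``str`` and drop keys that are
    not standardized yet.

    Rough sketch (values are illustrative)::

        check_annotations({'name': 123, 'coordinates': (0.1, 0.2)})
        # -> {'name': '123'}
    """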
    # force type to str for some keys
    # imposed for tests
    for k in ('name', 'description', 'file_origin'):
        if k in annotations:
            annotations[k] = str(annotations[k])

    if 'coordinates' in annotations:
        # some rawio expose coordinates in annotations, but this is not standardized
        # ((x, y, z) or polar...); for the moment it is more reasonable to remove them
        annotations.pop('coordinates')

    return annotations


def ensure_second(v):
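    """
    Coerce a time value (float, int or pq.Quantity) into seconds as a
    pq.Quantity.

    Rough sketch (return values shown informally)::

        ensure_second(2.5)          # -> 2.5 * pq.s
        ensure_second(3)            # -> 3.0 * pq.s
        ensure_second(500 * pq.ms)  # -> 0.5 * pq.s
    """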
    if isinstance(v, float):
        return v * pq.s
    elif isinstance(v, pq.Quantity):
        return v.rescale('s')
    elif isinstance(v, int):
        return float(v) * pq.s