nixrawio.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. """
  2. RawIO Class for NIX files
  3. The RawIO assumes all segments and all blocks have the same structure.
  4. It supports all kinds of NEO objects.
  5. Author: Chek Yin Choi
  6. """
  7. from .baserawio import (BaseRawIO, _signal_channel_dtype,
  8. _unit_channel_dtype, _event_channel_dtype)
  9. from ..io.nixio import NixIO
  10. from ..io.nixio import check_nix_version
  11. import numpy as np
  12. try:
  13. import nixio as nix
  14. HAVE_NIX = True
  15. except ImportError:
  16. HAVE_NIX = False
  17. nix = None
# When reading metadata properties, the following keys are ignored since they
# are used to store Neo object properties.
# This dictionary is used in the _filter_properties() method.
# Keys are lower-case Neo type names; values are the NIX property names that
# encode constructor arguments of that Neo type (not user annotations).
neo_attributes = {
    "segment": ["index"],
    "analogsignal": ["units", "copy", "sampling_rate", "t_start"],
    "spiketrain": ["units", "copy", "sampling_rate", "t_start", "t_stop",
                   "waveforms", "left_sweep"],
    "event": ["times", "labels", "units", "durations", "copy"]
}
class NIXRawIO(BaseRawIO):
    """RawIO for reading NIX files written by neo's NixIO.

    Assumes all segments and all blocks share the same structure, so the
    first block/segment is used to define the channel layout.
    """
    # File extension(s) this IO recognizes.
    extensions = ['nix']
    # The whole dataset lives in a single NIX file.
    rawmode = 'one-file'
    def __init__(self, filename=''):
        """Record the target filename; the file is opened in _parse_header().

        Raises (via check_nix_version) if the nixio package is missing or
        too old.
        """
        check_nix_version()
        BaseRawIO.__init__(self)
        self.filename = filename
    def _source_name(self):
        """Return the path of the underlying NIX file."""
        return self.filename
    def _parse_header(self):
        """Open the NIX file and build the BaseRawIO ``self.header`` dict.

        Channel definitions (signal, unit and event channels) are read from
        the FIRST block / FIRST segment only — the file is assumed to have
        the same structure everywhere.  Per-segment data references are then
        cached in ``self.da_list`` (analog signals) and ``self.unit_list``
        (spiketrains/waveforms), and NIX metadata is copied into
        ``self.raw_annotations``.
        """
        self.file = nix.File.open(self.filename, nix.FileMode.ReadOnly)

        # --- signal channels (from the first segment of the first block) ---
        sig_channels = []
        size_list = []
        for bl in self.file.blocks:
            for seg in bl.groups:
                for da_idx, da in enumerate(seg.data_arrays):
                    if da.type == "neo.analogsignal":
                        # channel id is the position within ALL data arrays
                        chan_id = da_idx
                        ch_name = da.metadata['neo_name']
                        units = str(da.unit)
                        dtype = str(da.dtype)
                        # first dimension is a SampledDimension in seconds
                        sr = 1 / da.dimensions[0].sampling_interval
                        da_leng = da.size
                        if da_leng not in size_list:
                            size_list.append(da_leng)
                        group_id = 0
                        for sid, li_leng in enumerate(size_list):
                            if li_leng == da_leng:
                                group_id = sid
                        # group_id groups channels of equal signal length;
                        # it is used only to separate different-length signals
                        gain = 1
                        offset = 0.
                        sig_channels.append((ch_name, chan_id, sr, dtype,
                                             units, gain, offset, group_id))
                break  # only the first segment defines the channels
            break  # only the first block defines the channels
        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

        # --- unit channels (spiketrains of the first segment/block) ---
        unit_channels = []
        unit_name = ""
        unit_id = ""
        for bl in self.file.blocks:
            for seg in bl.groups:
                for mt in seg.multi_tags:
                    if mt.type == "neo.spiketrain":
                        unit_name = mt.metadata['neo_name']
                        unit_id = mt.id
                        wf_left_sweep = 0
                        wf_units = None
                        wf_sampling_rate = 0
                        if mt.features:
                            # first feature holds the waveforms DataArray
                            wf = mt.features[0].data
                            wf_units = wf.unit
                            # third dimension samples time within a waveform
                            dim = wf.dimensions[2]
                            interval = dim.sampling_interval
                            wf_sampling_rate = 1 / interval
                            if wf.metadata:
                                wf_left_sweep = wf.metadata["left_sweep"]
                        wf_gain = 1
                        wf_offset = 0.
                        unit_channels.append(
                            (unit_name, unit_id, wf_units, wf_gain,
                             wf_offset, wf_left_sweep, wf_sampling_rate)
                        )
                break  # only the first segment defines the units
            break  # only the first block defines the units
        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)

        # --- event/epoch channels (first segment/block only) ---
        event_channels = []
        event_count = 0
        epoch_count = 0
        for bl in self.file.blocks:
            for seg in bl.groups:
                for mt in seg.multi_tags:
                    if mt.type == "neo.event":
                        ev_name = mt.metadata['neo_name']
                        ev_id = event_count
                        event_count += 1
                        ev_type = "event"
                        event_channels.append((ev_name, ev_id, ev_type))
                    if mt.type == "neo.epoch":
                        ep_name = mt.metadata['neo_name']
                        ep_id = epoch_count
                        epoch_count += 1
                        ep_type = "epoch"
                        event_channels.append((ep_name, ep_id, ep_type))
                break  # only the first segment defines the event channels
            break  # only the first block defines the event channels
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)

        # --- cache analogsignal DataArrays per block/segment ---
        # da_list['blocks'][b]['segments'][s] holds parallel lists
        # 'data_size', 'data' and 'ch_name' (analogsignal arrays only).
        self.da_list = {'blocks': []}
        for block_index, blk in enumerate(self.file.blocks):
            d = {'segments': []}
            self.da_list['blocks'].append(d)
            for seg_index, seg in enumerate(blk.groups):
                d = {'signals': []}
                self.da_list['blocks'][block_index]['segments'].append(d)
                size_list = []
                data_list = []
                da_name_list = []
                for da in seg.data_arrays:
                    if da.type == 'neo.analogsignal':
                        size_list.append(da.size)
                        data_list.append(da)
                        da_name_list.append(da.metadata['neo_name'])
                block = self.da_list['blocks'][block_index]
                segment = block['segments'][seg_index]
                segment['data_size'] = size_list
                segment['data'] = data_list
                segment['ch_name'] = da_name_list

        # --- cache spiketrain positions / waveforms per block/segment ---
        self.unit_list = {'blocks': []}
        for block_index, blk in enumerate(self.file.blocks):
            d = {'segments': []}
            self.unit_list['blocks'].append(d)
            for seg_index, seg in enumerate(blk.groups):
                d = {'spiketrains': [],
                     'spiketrains_id': [],
                     'spiketrains_unit': []}
                self.unit_list['blocks'][block_index]['segments'].append(d)
                st_idx = 0
                for st in seg.multi_tags:
                    # NOTE(review): a 'spiketrains_unit' entry is appended for
                    # EVERY multi-tag, but st_idx only advances for
                    # spiketrains — indices can diverge when other multi-tag
                    # types precede spiketrains; confirm ordering assumption.
                    d = {'waveforms': []}
                    block = self.unit_list['blocks'][block_index]
                    segment = block['segments'][seg_index]
                    segment['spiketrains_unit'].append(d)
                    if st.type == 'neo.spiketrain':
                        segment['spiketrains'].append(st.positions)
                        segment['spiketrains_id'].append(st.id)
                        wftypestr = "neo.waveforms"
                        if (st.features
                                and st.features[0].data.type == wftypestr):
                            waveforms = st.features[0].data
                            stdict = segment['spiketrains_unit'][st_idx]
                            if waveforms:
                                stdict['waveforms'] = waveforms
                            else:
                                stdict['waveforms'] = None
                            # assume one spiketrain has one waveform feature
                        st_idx += 1

        # --- assemble the BaseRawIO header ---
        self.header = {}
        self.header['nb_block'] = len(self.file.blocks)
        self.header['nb_segment'] = [len(bl.groups) for bl in self.file.blocks]
        self.header['signal_channels'] = sig_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels

        # --- copy NIX metadata into raw_annotations ---
        self._generate_minimal_annotations()
        for blk_idx, blk in enumerate(self.file.blocks):
            bl_ann = self.raw_annotations['blocks'][blk_idx]
            props = blk.metadata.inherited_properties()
            bl_ann.update(self._filter_properties(props, "block"))
            for grp_idx, grp in enumerate(blk.groups):
                seg_ann = bl_ann['segments'][grp_idx]
                props = grp.metadata.inherited_properties()
                seg_ann.update(self._filter_properties(props, "segment"))
                sig_idx = 0
                # collect and group DataArrays belonging to one signal
                groupdas = NixIO._group_signals(grp.data_arrays)
                for nix_name, signals in groupdas.items():
                    da = signals[0]
                    if da.type == 'neo.analogsignal' and seg_ann['signals']:
                        sig_ann = seg_ann['signals'][sig_idx]
                        sig_chan_ann = self.raw_annotations['signal_channels'][sig_idx]
                        props = da.metadata.inherited_properties()
                        sig_ann.update(self._filter_properties(props, 'analogsignal'))
                        sig_chan_ann.update(self._filter_properties(props, 'analogsignal'))
                        sig_idx += 1
                sp_idx = 0
                ev_idx = 0
                for mt in grp.multi_tags:
                    if mt.type == 'neo.spiketrain' and seg_ann['units']:
                        st_ann = seg_ann['units'][sp_idx]
                        props = mt.metadata.inherited_properties()
                        st_ann.update(self._filter_properties(props, 'spiketrain'))
                        sp_idx += 1
                    # if order is preserved, the annotations go to the right
                    # place; NOTE(review): this ordering assumption is untested
                    if mt.type == "neo.event" or mt.type == "neo.epoch":
                        if seg_ann['events'] != []:
                            event_ann = seg_ann['events'][ev_idx]
                            props = mt.metadata.inherited_properties()
                            event_ann.update(self._filter_properties(props, 'event'))
                            ev_idx += 1
            # populate ChannelIndex annotations from the block's sources
            for srcidx, source in enumerate(blk.sources):
                chx_ann = self.raw_annotations["signal_channels"][srcidx]
                props = source.metadata.inherited_properties()
                chx_ann.update(self._filter_properties(props, "channelindex"))
  214. def _segment_t_start(self, block_index, seg_index):
  215. t_start = 0
  216. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  217. if mt.type == "neo.spiketrain":
  218. t_start = mt.metadata['t_start']
  219. return t_start
  220. def _segment_t_stop(self, block_index, seg_index):
  221. t_stop = 0
  222. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  223. if mt.type == "neo.spiketrain":
  224. t_stop = mt.metadata['t_stop']
  225. return t_stop
  226. def _get_signal_size(self, block_index, seg_index, channel_indexes):
  227. if channel_indexes is None:
  228. channel_indexes = list(range(self.header['signal_channels'].size))
  229. ch_idx = channel_indexes[0]
  230. block = self.da_list['blocks'][block_index]
  231. segment = block['segments'][seg_index]
  232. size = segment['data_size'][ch_idx]
  233. return size # size is per signal, not the sum of all channel_indexes
  234. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  235. if channel_indexes is None:
  236. channel_indexes = list(range(self.header['signal_channels'].size))
  237. ch_idx = channel_indexes[0]
  238. block = self.file.blocks[block_index]
  239. das = [da for da in block.groups[seg_index].data_arrays]
  240. da = das[ch_idx]
  241. sig_t_start = float(da.metadata['t_start'])
  242. return sig_t_start # assume same group_id always same t_start
  243. def _get_analogsignal_chunk(self, block_index, seg_index,
  244. i_start, i_stop, channel_indexes):
  245. if channel_indexes is None:
  246. channel_indexes = list(range(self.header['signal_channels'].size))
  247. if i_start is None:
  248. i_start = 0
  249. if i_stop is None:
  250. block = self.da_list['blocks'][block_index]
  251. segment = block['segments'][seg_index]
  252. for c in channel_indexes:
  253. i_stop = segment['data_size'][c]
  254. break
  255. raw_signals_list = []
  256. da_list = self.da_list['blocks'][block_index]['segments'][seg_index]
  257. for idx in channel_indexes:
  258. da = da_list['data'][idx]
  259. raw_signals_list.append(da[i_start:i_stop])
  260. raw_signals = np.array(raw_signals_list)
  261. raw_signals = np.transpose(raw_signals)
  262. return raw_signals
  263. def _spike_count(self, block_index, seg_index, unit_index):
  264. count = 0
  265. head_id = self.header['unit_channels'][unit_index][1]
  266. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  267. for src in mt.sources:
  268. if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
  269. if head_id == src.id:
  270. return len(mt.positions)
  271. return count
  272. def _get_spike_timestamps(self, block_index, seg_index, unit_index,
  273. t_start, t_stop):
  274. block = self.unit_list['blocks'][block_index]
  275. segment = block['segments'][seg_index]
  276. spike_dict = segment['spiketrains']
  277. spike_timestamps = spike_dict[unit_index]
  278. spike_timestamps = np.transpose(spike_timestamps)
  279. if t_start is not None or t_stop is not None:
  280. lim0 = t_start
  281. lim1 = t_stop
  282. mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
  283. spike_timestamps = spike_timestamps[mask]
  284. return spike_timestamps
  285. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  286. spike_times = spike_timestamps.astype(dtype)
  287. return spike_times
    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
                                 t_start, t_stop):
        """Return waveforms of one unit, or None if it has none.

        Must return a 3D numpy array (nb_spike, nb_channel, nb_sample).
        """
        seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
        waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
        if not waveforms:
            return None
        raw_waveforms = np.array(waveforms)
        if t_start is not None:
            lim0 = t_start
            # NOTE(review): this masks waveform *values* (amplitudes) against
            # a *time* bound — looks suspicious; confirm the intended
            # semantics before relying on t_start/t_stop here.
            mask = (raw_waveforms >= lim0)
            # use nan to keep the 3D shape instead of dropping entries
            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
        if t_stop is not None:
            lim1 = t_stop
            mask = (raw_waveforms <= lim1)
            raw_waveforms = np.where(mask, raw_waveforms, np.nan)
        return raw_waveforms
  306. def _event_count(self, block_index, seg_index, event_channel_index):
  307. event_count = 0
  308. segment = self.file.blocks[block_index].groups[seg_index]
  309. for event in segment.multi_tags:
  310. if event.type == 'neo.event' or event.type == 'neo.epoch':
  311. if event_count == event_channel_index:
  312. return len(event.positions)
  313. else:
  314. event_count += 1
  315. return event_count
  316. def _get_event_timestamps(self, block_index, seg_index,
  317. event_channel_index, t_start, t_stop):
  318. timestamp = []
  319. labels = []
  320. durations = None
  321. if event_channel_index is None:
  322. raise IndexError
  323. for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
  324. if mt.type == "neo.event" or mt.type == "neo.epoch":
  325. labels.append(mt.positions.dimensions[0].labels)
  326. po = mt.positions
  327. if (po.type == "neo.event.times"
  328. or po.type == "neo.epoch.times"):
  329. timestamp.append(po)
  330. channel = self.header['event_channels'][event_channel_index]
  331. if channel['type'] == b'epoch' and mt.extents:
  332. if mt.extents.type == 'neo.epoch.durations':
  333. durations = np.array(mt.extents)
  334. break
  335. timestamp = timestamp[event_channel_index][:]
  336. timestamp = np.array(timestamp, dtype="float")
  337. labels = labels[event_channel_index][:]
  338. labels = np.array(labels, dtype='U')
  339. if t_start is not None:
  340. keep = timestamp >= t_start
  341. timestamp, labels = timestamp[keep], labels[keep]
  342. if t_stop is not None:
  343. keep = timestamp <= t_stop
  344. timestamp, labels = timestamp[keep], labels[keep]
  345. return timestamp, durations, labels # only the first fits in rescale
  346. def _rescale_event_timestamp(self, event_timestamps, dtype='float64'):
  347. ev_unit = ''
  348. for mt in self.file.blocks[0].groups[0].multi_tags:
  349. if mt.type == "neo.event":
  350. ev_unit = mt.positions.unit
  351. break
  352. if ev_unit == 'ms':
  353. event_timestamps /= 1000
  354. event_times = event_timestamps.astype(dtype)
  355. # supposing unit is second, other possibilities maybe mS microS...
  356. return event_times # return in seconds
  357. def _rescale_epoch_duration(self, raw_duration, dtype='float64'):
  358. ep_unit = ''
  359. for mt in self.file.blocks[0].groups[0].multi_tags:
  360. if mt.type == "neo.epoch":
  361. ep_unit = mt.positions.unit
  362. break
  363. if ep_unit == 'ms':
  364. raw_duration /= 1000
  365. durations = raw_duration.astype(dtype)
  366. # supposing unit is second, other possibilities maybe mS microS...
  367. return durations # return in seconds
  368. def _filter_properties(self, properties, neo_type):
  369. """
  370. Takes a collection of NIX metadata properties and the name of a Neo
  371. type and returns a dictionary representing the Neo object annotations.
  372. Properties that represent the attributes of the Neo object type are
  373. filtered, based on the global 'neo_attributes' dictionary.
  374. """
  375. annotations = dict()
  376. attrs = neo_attributes.get(neo_type, list())
  377. for prop in properties:
  378. # filter neo_name explicitly
  379. if not (prop.name in attrs or prop.name == "neo_name"):
  380. values = prop.values
  381. if len(values) == 1:
  382. values = values[0]
  383. annotations[str(prop.name)] = values
  384. return annotations