hdf5io.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # -*- coding: utf-8 -*-
  2. """
  3. """
  4. from __future__ import absolute_import
  5. import sys
  6. import logging
  7. import pickle
  8. import numpy as np
  9. import quantities as pq
  10. try:
  11. import h5py
  12. except ImportError as err:
  13. HAVE_H5PY = False
  14. else:
  15. HAVE_H5PY = True
  16. from neo.core import (objectlist, Block, Segment, AnalogSignal, SpikeTrain,
  17. Epoch, Event, IrregularlySampledSignal, ChannelIndex,
  18. Unit)
  19. from neo.io.baseio import BaseIO
  20. from neo.core.baseneo import MergeError
  21. logger = logging.getLogger('Neo')
  22. def disjoint_groups(groups):
  23. """`groups` should be a list of sets"""
  24. groups = groups[:] # copy, so as not to change original
  25. for group1 in groups:
  26. for group2 in groups:
  27. if group1 != group2:
  28. if group2.issubset(group1):
  29. groups.remove(group2)
  30. elif group1.issubset(group2):
  31. groups.remove(group1)
  32. return groups
class NeoHdf5IO(BaseIO):
    """
    Class for reading HDF5 format files created by Neo version 0.4 or earlier.

    Writing to HDF5 is not supported by this IO; we recommend using NixIO for this.
    """
    supported_objects = objectlist
    readable_objects = objectlist
    name = 'NeoHdf5 IO'
    extensions = ['h5']
    mode = 'file'
    is_readable = True
    is_writable = False

    def __init__(self, filename):
        """Open `filename` read-only. Raises ImportError if h5py is missing."""
        if not HAVE_H5PY:
            raise ImportError("h5py is not available")
        BaseIO.__init__(self, filename=filename)
        self._data = h5py.File(filename, 'r')
        # maps the "object_ref" stored in the file to the Neo object built
        # from it; used later to resolve cross-references (units, channels)
        self.object_refs = {}
        self._lazy = False

    def read_all_blocks(self, lazy=False, cascade=True, merge_singles=True, **kargs):
        """
        Loads all blocks in the file that are attached to the root (which
        happens when they are saved with save() or write_block()).

        If `merge_singles` is True, then the IO will attempt to merge single channel
        `AnalogSignal` objects into multichannel objects, and similarly for single `Epoch`,
        `Event` and `IrregularlySampledSignal` objects.
        """
        self._lazy = lazy
        self._cascade = cascade
        self.merge_singles = merge_singles
        blocks = []
        # top-level group names are matched by substring, e.g. "Block_0"
        for name, node in self._data.items():
            if "Block" in name:
                blocks.append(self._read_block(node))
        return blocks

    def read_block(self, lazy=False, cascade=True, **kargs):
        """
        Load the first block in the file.
        """
        # NOTE(review): reads *all* blocks then discards the rest;
        # raises IndexError if the file contains no Block nodes.
        return self.read_all_blocks(lazy=lazy, cascade=cascade)[0]

    def _read_block(self, node):
        """Build a Block from an HDF5 group, optionally cascading to children."""
        attributes = self._get_standard_attributes(node)
        if "index" in attributes:
            # stored index may be a numpy integer; Block expects a plain int
            attributes["index"] = int(attributes["index"])
        block = Block(**attributes)
        if self._cascade:
            for name, child_node in node['segments'].items():
                if "Segment" in name:
                    block.segments.append(self._read_segment(child_node, parent=block))
            if len(node['recordingchannelgroups']) > 0:
                for name, child_node in node['recordingchannelgroups'].items():
                    if "RecordingChannelGroup" in name:
                        block.channel_indexes.append(self._read_recordingchannelgroup(child_node, parent=block))
                # index construction is deferred until all groups are read
                self._resolve_channel_indexes(block)
            elif self.merge_singles:
                # if no RecordingChannelGroups are defined, merging
                # takes place here.
                for segment in block.segments:
                    if hasattr(segment, 'unmerged_analogsignals'):
                        segment.analogsignals.extend(
                            self._merge_data_objects(segment.unmerged_analogsignals))
                        del segment.unmerged_analogsignals
                    if hasattr(segment, 'unmerged_irregularlysampledsignals'):
                        segment.irregularlysampledsignals.extend(
                            self._merge_data_objects(segment.unmerged_irregularlysampledsignals))
                        del segment.unmerged_irregularlysampledsignals
        return block

    def _read_segment(self, node, parent):
        """Build a Segment and all of its data children from an HDF5 group."""
        attributes = self._get_standard_attributes(node)
        segment = Segment(**attributes)

        # Neo 0.3/0.4 stored single-channel AnalogSignals separately from
        # AnalogSignalArrays; single signals are parked on the segment in
        # `unmerged_analogsignals` for later merging when merge_singles is set.
        signals = []
        for name, child_node in node['analogsignals'].items():
            if "AnalogSignal" in name:
                signals.append(self._read_analogsignal(child_node, parent=segment))
        if signals and self.merge_singles:
            segment.unmerged_analogsignals = signals  # signals will be merged later
            signals = []
        for name, child_node in node['analogsignalarrays'].items():
            if "AnalogSignalArray" in name:
                signals.append(self._read_analogsignalarray(child_node, parent=segment))
        segment.analogsignals = signals

        # same single-vs-array split for irregularly sampled signals
        irr_signals = []
        for name, child_node in node['irregularlysampledsignals'].items():
            if "IrregularlySampledSignal" in name:
                irr_signals.append(self._read_irregularlysampledsignal(child_node, parent=segment))
        if irr_signals and self.merge_singles:
            segment.unmerged_irregularlysampledsignals = irr_signals
            irr_signals = []
        segment.irregularlysampledsignals = irr_signals

        # single Epochs are merged immediately; EpochArrays are appended as-is
        epochs = []
        for name, child_node in node['epochs'].items():
            if "Epoch" in name:
                epochs.append(self._read_epoch(child_node, parent=segment))
        if self.merge_singles:
            epochs = self._merge_data_objects(epochs)
        for name, child_node in node['epocharrays'].items():
            if "EpochArray" in name:
                epochs.append(self._read_epocharray(child_node, parent=segment))
        segment.epochs = epochs

        # same pattern for Events / EventArrays
        events = []
        for name, child_node in node['events'].items():
            if "Event" in name:
                events.append(self._read_event(child_node, parent=segment))
        if self.merge_singles:
            events = self._merge_data_objects(events)
        for name, child_node in node['eventarrays'].items():
            if "EventArray" in name:
                events.append(self._read_eventarray(child_node, parent=segment))
        segment.events = events

        spiketrains = []
        # legacy Spike objects are not supported: raises on the first entry
        for name, child_node in node['spikes'].items():
            raise NotImplementedError('Spike objects not yet handled.')
        for name, child_node in node['spiketrains'].items():
            if "SpikeTrain" in name:
                spiketrains.append(self._read_spiketrain(child_node, parent=segment))
        segment.spiketrains = spiketrains

        segment.block = parent
        return segment

    def _read_analogsignalarray(self, node, parent):
        """Build an AnalogSignal from an HDF5 group and register its object_ref."""
        attributes = self._get_standard_attributes(node)
        # todo: handle channel_index
        sampling_rate = self._get_quantity(node["sampling_rate"])
        t_start = self._get_quantity(node["t_start"])
        signal = AnalogSignal(self._get_quantity(node["signal"]),
                              sampling_rate=sampling_rate, t_start=t_start,
                              **attributes)
        if self._lazy:
            signal.lazy_shape = node["signal"].shape
            # single-channel signals get a 2D lazy shape (n_samples, 1)
            if len(signal.lazy_shape) == 1:
                signal.lazy_shape = (signal.lazy_shape[0], 1)
        signal.segment = parent
        self.object_refs[node.attrs["object_ref"]] = signal
        return signal

    def _read_analogsignal(self, node, parent):
        """Single-channel AnalogSignal: same on-disk layout as the array form."""
        return self._read_analogsignalarray(node, parent)

    def _read_irregularlysampledsignal(self, node, parent):
        """Build an IrregularlySampledSignal from an HDF5 group."""
        attributes = self._get_standard_attributes(node)
        signal = IrregularlySampledSignal(times=self._get_quantity(node["times"]),
                                          signal=self._get_quantity(node["signal"]),
                                          **attributes)
        signal.segment = parent
        if self._lazy:
            signal.lazy_shape = node["signal"].shape
            if len(signal.lazy_shape) == 1:
                signal.lazy_shape = (signal.lazy_shape[0], 1)
        return signal

    def _read_spiketrain(self, node, parent):
        """Build a SpikeTrain from an HDF5 group and register its object_ref."""
        attributes = self._get_standard_attributes(node)
        t_start = self._get_quantity(node["t_start"])
        t_stop = self._get_quantity(node["t_stop"])
        # todo: handle sampling_rate, waveforms, left_sweep
        spiketrain = SpikeTrain(self._get_quantity(node["times"]),
                                t_start=t_start, t_stop=t_stop,
                                **attributes)
        spiketrain.segment = parent
        if self._lazy:
            spiketrain.lazy_shape = node["times"].shape
        self.object_refs[node.attrs["object_ref"]] = spiketrain
        return spiketrain

    def _read_epocharray(self, node, parent):
        """Build an Epoch (times, durations, labels) from an HDF5 group."""
        attributes = self._get_standard_attributes(node)
        times = self._get_quantity(node["times"])
        durations = self._get_quantity(node["durations"])
        if self._lazy:
            # lazy mode keeps only the dtype, with an empty labels array
            labels = np.array((), dtype=node["labels"].dtype)
        else:
            # NOTE(review): Dataset.value was removed in h5py 3.0 — this IO
            # appears to require h5py < 3; confirm the pinned h5py version.
            labels = node["labels"].value
        epoch = Epoch(times=times, durations=durations, labels=labels, **attributes)
        epoch.segment = parent
        if self._lazy:
            epoch.lazy_shape = node["times"].shape
        return epoch

    def _read_epoch(self, node, parent):
        """Single Epoch: same on-disk layout as the array form."""
        return self._read_epocharray(node, parent)

    def _read_eventarray(self, node, parent):
        """Build an Event (times, labels) from an HDF5 group."""
        attributes = self._get_standard_attributes(node)
        times = self._get_quantity(node["times"])
        if self._lazy:
            labels = np.array((), dtype=node["labels"].dtype)
        else:
            labels = node["labels"].value
        event = Event(times=times, labels=labels, **attributes)
        event.segment = parent
        if self._lazy:
            event.lazy_shape = node["times"].shape
        return event

    def _read_event(self, node, parent):
        """Single Event: same on-disk layout as the array form."""
        return self._read_eventarray(node, parent)

    def _read_recordingchannelgroup(self, node, parent):
        """Map a legacy RecordingChannelGroup onto a ChannelIndex.

        The actual index arrays are not built here; `_channels` holds the
        raw per-channel data until `_resolve_channel_indexes` runs.
        """
        # todo: handle Units
        attributes = self._get_standard_attributes(node)
        channel_indexes = node["channel_indexes"].value
        channel_names = node["channel_names"].value

        if channel_indexes.size:
            if len(node['recordingchannels']):
                raise MergeError("Cannot handle a RecordingChannelGroup which both has a "
                                 "'channel_indexes' attribute and contains "
                                 "RecordingChannel objects")
            raise NotImplementedError("todo")  # need to handle node['analogsignalarrays']
        else:
            channels = []
            for name, child_node in node['recordingchannels'].items():
                if "RecordingChannel" in name:
                    channels.append(self._read_recordingchannel(child_node))
            channel_index = ChannelIndex(None, **attributes)
            channel_index._channels = channels
            # construction of the index is deferred until we have processed
            # all RecordingChannelGroup nodes
            units = []
            for name, child_node in node['units'].items():
                if "Unit" in name:
                    units.append(self._read_unit(child_node, parent=channel_index))
            channel_index.units = units
        channel_index.block = parent
        return channel_index

    def _read_recordingchannel(self, node):
        """Return (channel index, analogsignal refs, irregular signal refs).

        Only object_ref strings are collected here; the referenced signals
        are looked up in `self.object_refs` during index resolution.
        """
        attributes = self._get_standard_attributes(node)
        analogsignals = []
        irregsignals = []
        for name, child_node in node["analogsignals"].items():
            if "AnalogSignal" in name:
                obj_ref = child_node.attrs["object_ref"]
                analogsignals.append(obj_ref)
        for name, child_node in node["irregularlysampledsignals"].items():
            if "IrregularlySampledSignal" in name:
                obj_ref = child_node.attrs["object_ref"]
                irregsignals.append(obj_ref)
        return attributes['index'], analogsignals, irregsignals

    def _read_unit(self, node, parent):
        """Build a Unit, resolving its SpikeTrains through object_refs.

        Assumes the spiketrains were already read (segments are read before
        recordingchannelgroups in `_read_block`), so the refs resolve.
        """
        attributes = self._get_standard_attributes(node)
        spiketrains = []
        for name, child_node in node["spiketrains"].items():
            if "SpikeTrain" in name:
                obj_ref = child_node.attrs["object_ref"]
                spiketrains.append(self.object_refs[obj_ref])
        unit = Unit(**attributes)
        unit.channel_index = parent
        unit.spiketrains = spiketrains
        return unit

    def _merge_data_objects(self, objects):
        """Merge a list of Neo data objects into as few objects as possible.

        Consumes `objects` (pops from it). Objects that refuse to merge
        (raise MergeError) start a new entry in the result. The combined
        object_ref annotations are "-"-joined and re-registered in
        `self.object_refs`.
        """
        if len(objects) > 1:
            merged_objects = [objects.pop(0)]
            while objects:
                obj = objects.pop(0)
                try:
                    # merge() may return a new object, so re-annotate after
                    combined_obj_ref = merged_objects[-1].annotations['object_ref']
                    merged_objects[-1] = merged_objects[-1].merge(obj)
                    merged_objects[-1].annotations['object_ref'] = combined_obj_ref + "-" + obj.annotations['object_ref']
                except MergeError:
                    merged_objects.append(obj)
            for obj in merged_objects:
                self.object_refs[obj.annotations['object_ref']] = obj
            return merged_objects
        else:
            return objects

    def _get_quantity(self, node):
        """Read a dataset as a quantities array.

        The unit is taken from the dataset attribute whose name contains
        "unit" (formatted like "...__<unit_name>") and resolved against
        the `quantities` package. In lazy mode, non-scalar datasets are
        returned as empty arrays of the correct dtype.
        """
        if self._lazy and len(node.shape) > 0:
            value = np.array((), dtype=node.dtype)
        else:
            value = node.value
        unit_str = [x for x in node.attrs.keys() if "unit" in x][0].split("__")[1]
        units = getattr(pq, unit_str)
        return value * units

    def _get_standard_attributes(self, node):
        """Retrieve attributes

        Collects the standard Neo attributes plus pickled datetimes and
        annotations from `node.attrs`, and on Python 3 decodes bytes
        keys/values (the files were written by Python 2) to str.
        """
        attributes = {}
        for name in ('name', 'description', 'index', 'file_origin', 'object_ref'):
            if name in node.attrs:
                attributes[name] = node.attrs[name]
        for name in ('rec_datetime', 'file_datetime'):
            if name in node.attrs:
                if sys.version_info.major > 2:
                    # encoding='bytes' is needed to unpickle Python-2 pickles
                    attributes[name] = pickle.loads(node.attrs[name], encoding='bytes')
                else:  # Python 2 doesn't have the encoding argument
                    attributes[name] = pickle.loads(node.attrs[name])
        if sys.version_info.major > 2:
            annotations = pickle.loads(node.attrs['annotations'], encoding='bytes')
        else:
            annotations = pickle.loads(node.attrs['annotations'])
        attributes.update(annotations)
        attribute_names = list(attributes.keys())  # avoid "dictionary changed size during iteration" error
        if sys.version_info.major > 2:
            for name in attribute_names:
                if isinstance(attributes[name], (bytes, np.bytes_)):
                    attributes[name] = attributes[name].decode('utf-8')
                if isinstance(name, bytes):
                    # re-key bytes keys as str
                    attributes[name.decode('utf-8')] = attributes[name]
                    attributes.pop(name)
        return attributes

    def _resolve_channel_indexes(self, block):
        """Turn the deferred `_channels` data into real ChannelIndex arrays.

        First finds the "principal" (maximal) channel indexes, then for each
        one merges the per-channel signals segment by segment; remaining
        (subset) channel indexes are pointed at the principal one's signals.
        """

        def disjoint_channel_indexes(channel_indexes):
            # keep only channel indexes whose signal groups are not a subset
            # of another's.
            # NOTE(review): this removes items from the list being iterated,
            # which can skip elements — same pitfall as module-level
            # disjoint_groups(); confirm against files with >2 overlapping groups.
            channel_indexes = channel_indexes[:]
            for ci1 in channel_indexes:
                signal_group1 = set(tuple(x[1]) for x in ci1._channels)  # this works only on analogsignals
                for ci2 in channel_indexes:  # need to take irregularly sampled signals
                    signal_group2 = set(tuple(x[1]) for x in ci2._channels)  # into account too
                    if signal_group1 != signal_group2:
                        if signal_group2.issubset(signal_group1):
                            channel_indexes.remove(ci2)
                        elif signal_group1.issubset(signal_group2):
                            channel_indexes.remove(ci1)
            return channel_indexes

        principal_indexes = disjoint_channel_indexes(block.channel_indexes)
        for ci in principal_indexes:
            ids = []
            # group this channel index's signals by their parent segment
            # (keyed by id() since Segment is not hashable by content)
            by_segment = {}
            for (index, analogsignals, irregsignals) in ci._channels:
                ids.append(index)  # note that what was called "index" in Neo 0.3/0.4 is "id" in Neo 0.5
                for signal_ref in analogsignals:
                    signal = self.object_refs[signal_ref]
                    segment_id = id(signal.segment)
                    if segment_id in by_segment:
                        by_segment[segment_id]['analogsignals'].append(signal)
                    else:
                        by_segment[segment_id] = {'analogsignals': [signal], 'irregsignals': []}
                for signal_ref in irregsignals:
                    signal = self.object_refs[signal_ref]
                    segment_id = id(signal.segment)
                    if segment_id in by_segment:
                        by_segment[segment_id]['irregsignals'].append(signal)
                    else:
                        by_segment[segment_id] = {'analogsignals': [], 'irregsignals': [signal]}
            assert len(ids) > 0
            if self.merge_singles:
                ci.channel_ids = np.array(ids)
                ci.index = np.arange(len(ids))
                for seg_id, segment_data in by_segment.items():
                    # get the segment object
                    segment = None
                    for seg in ci.block.segments:
                        if id(seg) == seg_id:
                            segment = seg
                            break
                    assert segment is not None
                    if segment_data['analogsignals']:
                        merged_signals = self._merge_data_objects(segment_data['analogsignals'])
                        assert len(merged_signals) == 1
                        merged_signals[0].channel_index = ci
                        merged_signals[0].annotations['object_ref'] = "-".join(obj.annotations['object_ref']
                                                                               for obj in segment_data['analogsignals'])
                        segment.analogsignals.extend(merged_signals)
                        ci.analogsignals = merged_signals
                    if segment_data['irregsignals']:
                        merged_signals = self._merge_data_objects(segment_data['irregsignals'])
                        assert len(merged_signals) == 1
                        merged_signals[0].channel_index = ci
                        merged_signals[0].annotations['object_ref'] = "-".join(obj.annotations['object_ref']
                                                                               for obj in segment_data['irregsignals'])
                        segment.irregularlysampledsignals.extend(merged_signals)
                        ci.irregularlysampledsignals = merged_signals
            else:
                raise NotImplementedError()  # will need to return multiple ChannelIndexes

        # handle non-principal channel indexes: they share the principal
        # index's signals and select their channels by position
        for ci in block.channel_indexes:
            if ci not in principal_indexes:
                ids = [c[0] for c in ci._channels]
                for cipr in principal_indexes:
                    if ids[0] in cipr.channel_ids:
                        break
                ci.analogsignals = cipr.analogsignals
                ci.channel_ids = np.array(ids)
                ci.index = np.where(np.in1d(cipr.channel_ids, ci.channel_ids))[0]