hdf5io.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
"""
Reader for HDF5 files created by Neo version 0.4 or earlier.

This module defines the `disjoint_groups` helper and the read-only
`NeoHdf5IO` class. `h5py` is treated as an optional dependency: whether it
could be imported is recorded in `HAVE_H5PY` instead of failing at import
time.
"""

import logging
from distutils.version import LooseVersion
import pickle
from warnings import warn

import numpy as np
import quantities as pq

# h5py is optional: remember whether the import succeeded so that the error
# can be raised later, when an IO object is actually constructed.
try:
    import h5py
except ImportError as err:
    HAVE_H5PY = False
else:
    HAVE_H5PY = True

from neo.core import (objectlist, Block, Segment, AnalogSignal, SpikeTrain,
                      Epoch, Event, IrregularlySampledSignal, ChannelIndex,
                      Unit)
from neo.io.baseio import BaseIO
from neo.core.baseneo import MergeError

logger = logging.getLogger('Neo')

# Oldest h5py release accepted by NeoHdf5IO.__init__.
min_h5py_version = LooseVersion('2.6.0')
  22. def disjoint_groups(groups):
  23. """`groups` should be a list of sets"""
  24. groups = groups[:] # copy, so as not to change original
  25. for group1 in groups:
  26. for group2 in groups:
  27. if group1 != group2:
  28. if group2.issubset(group1):
  29. groups.remove(group2)
  30. elif group1.issubset(group2):
  31. groups.remove(group1)
  32. return groups
  33. class NeoHdf5IO(BaseIO):
  34. """
  35. Class for reading HDF5 format files created by Neo version 0.4 or earlier.
  36. Writing to HDF5 is not supported by this IO; we recommend using NixIO for this.
  37. """
  38. supported_objects = objectlist
  39. readable_objects = objectlist
  40. name = 'NeoHdf5 IO'
  41. extensions = ['h5']
  42. mode = 'file'
  43. is_readable = True
  44. is_writable = False
  45. def __init__(self, filename):
  46. warning_msg = (
  47. "NeoHdf5IO will be removed in the next release of Neo. "
  48. "If you still have data in this format, we recommend saving it using NixIO "
  49. "which is also based on HDF5."
  50. )
  51. warn(warning_msg, FutureWarning)
  52. if not HAVE_H5PY:
  53. raise ImportError("h5py is not available")
  54. if HAVE_H5PY:
  55. if LooseVersion(h5py.__version__) < min_h5py_version:
  56. raise ImportError('h5py version {} is too old. Minimal required version is {}'
  57. ''.format(h5py.__version__, min_h5py_version))
  58. BaseIO.__init__(self, filename=filename)
  59. self._data = h5py.File(filename, 'r')
  60. self.object_refs = {}
  61. def read_all_blocks(self, lazy=False, merge_singles=True, **kargs):
  62. """
  63. Loads all blocks in the file that are attached to the root (which
  64. happens when they are saved with save() or write_block()).
  65. If `merge_singles` is True, then the IO will attempt to merge single channel
  66. `AnalogSignal` objects into multichannel objects, and similarly for single `Epoch`,
  67. `Event` and `IrregularlySampledSignal` objects.
  68. """
  69. assert not lazy, 'Do not support lazy'
  70. self.merge_singles = merge_singles
  71. blocks = []
  72. for name, node in self._data.items():
  73. if "Block" in name:
  74. blocks.append(self._read_block(node))
  75. return blocks
  76. def read_block(self, lazy=False, **kargs):
  77. """
  78. Load the first block in the file.
  79. """
  80. assert not lazy, 'Do not support lazy'
  81. return self.read_all_blocks(lazy=lazy)[0]
  82. def _read_block(self, node):
  83. attributes = self._get_standard_attributes(node)
  84. if "index" in attributes:
  85. attributes["index"] = int(attributes["index"])
  86. block = Block(**attributes)
  87. for name, child_node in node['segments'].items():
  88. if "Segment" in name:
  89. block.segments.append(self._read_segment(child_node, parent=block))
  90. if len(node['recordingchannelgroups']) > 0:
  91. for name, child_node in node['recordingchannelgroups'].items():
  92. if "RecordingChannelGroup" in name:
  93. block.channel_indexes.append(
  94. self._read_recordingchannelgroup(child_node, parent=block))
  95. self._resolve_channel_indexes(block)
  96. elif self.merge_singles:
  97. # if no RecordingChannelGroups are defined, merging
  98. # takes place here.
  99. for segment in block.segments:
  100. if hasattr(segment, 'unmerged_analogsignals'):
  101. segment.analogsignals.extend(
  102. self._merge_data_objects(segment.unmerged_analogsignals))
  103. del segment.unmerged_analogsignals
  104. if hasattr(segment, 'unmerged_irregularlysampledsignals'):
  105. segment.irregularlysampledsignals.extend(
  106. self._merge_data_objects(segment.unmerged_irregularlysampledsignals))
  107. del segment.unmerged_irregularlysampledsignals
  108. return block
  109. def _read_segment(self, node, parent):
  110. attributes = self._get_standard_attributes(node)
  111. segment = Segment(**attributes)
  112. signals = []
  113. for name, child_node in node['analogsignals'].items():
  114. if "AnalogSignal" in name:
  115. signals.append(self._read_analogsignal(child_node, parent=segment))
  116. if signals and self.merge_singles:
  117. segment.unmerged_analogsignals = signals # signals will be merged later
  118. signals = []
  119. for name, child_node in node['analogsignalarrays'].items():
  120. if "AnalogSignalArray" in name:
  121. signals.append(self._read_analogsignalarray(child_node, parent=segment))
  122. segment.analogsignals = signals
  123. irr_signals = []
  124. for name, child_node in node['irregularlysampledsignals'].items():
  125. if "IrregularlySampledSignal" in name:
  126. irr_signals.append(self._read_irregularlysampledsignal(child_node, parent=segment))
  127. if irr_signals and self.merge_singles:
  128. segment.unmerged_irregularlysampledsignals = irr_signals
  129. irr_signals = []
  130. segment.irregularlysampledsignals = irr_signals
  131. epochs = []
  132. for name, child_node in node['epochs'].items():
  133. if "Epoch" in name:
  134. epochs.append(self._read_epoch(child_node, parent=segment))
  135. if self.merge_singles:
  136. epochs = self._merge_data_objects(epochs)
  137. for name, child_node in node['epocharrays'].items():
  138. if "EpochArray" in name:
  139. epochs.append(self._read_epocharray(child_node, parent=segment))
  140. segment.epochs = epochs
  141. events = []
  142. for name, child_node in node['events'].items():
  143. if "Event" in name:
  144. events.append(self._read_event(child_node, parent=segment))
  145. if self.merge_singles:
  146. events = self._merge_data_objects(events)
  147. for name, child_node in node['eventarrays'].items():
  148. if "EventArray" in name:
  149. events.append(self._read_eventarray(child_node, parent=segment))
  150. segment.events = events
  151. spiketrains = []
  152. for name, child_node in node['spikes'].items():
  153. raise NotImplementedError('Spike objects not yet handled.')
  154. for name, child_node in node['spiketrains'].items():
  155. if "SpikeTrain" in name:
  156. spiketrains.append(self._read_spiketrain(child_node, parent=segment))
  157. segment.spiketrains = spiketrains
  158. segment.block = parent
  159. return segment
  160. def _read_analogsignalarray(self, node, parent):
  161. attributes = self._get_standard_attributes(node)
  162. # todo: handle channel_index
  163. sampling_rate = self._get_quantity(node["sampling_rate"])
  164. t_start = self._get_quantity(node["t_start"])
  165. signal = AnalogSignal(self._get_quantity(node["signal"]),
  166. sampling_rate=sampling_rate, t_start=t_start,
  167. **attributes)
  168. signal.segment = parent
  169. self.object_refs[node.attrs["object_ref"]] = signal
  170. return signal
  171. def _read_analogsignal(self, node, parent):
  172. return self._read_analogsignalarray(node, parent)
  173. def _read_irregularlysampledsignal(self, node, parent):
  174. attributes = self._get_standard_attributes(node)
  175. signal = IrregularlySampledSignal(times=self._get_quantity(node["times"]),
  176. signal=self._get_quantity(node["signal"]),
  177. **attributes)
  178. signal.segment = parent
  179. return signal
  180. def _read_spiketrain(self, node, parent):
  181. attributes = self._get_standard_attributes(node)
  182. t_start = self._get_quantity(node["t_start"])
  183. t_stop = self._get_quantity(node["t_stop"])
  184. # todo: handle sampling_rate, waveforms, left_sweep
  185. spiketrain = SpikeTrain(self._get_quantity(node["times"]),
  186. t_start=t_start, t_stop=t_stop,
  187. **attributes)
  188. spiketrain.segment = parent
  189. self.object_refs[node.attrs["object_ref"]] = spiketrain
  190. return spiketrain
  191. def _read_epocharray(self, node, parent):
  192. attributes = self._get_standard_attributes(node)
  193. times = self._get_quantity(node["times"])
  194. durations = self._get_quantity(node["durations"])
  195. labels = node["labels"][()].astype('U')
  196. epoch = Epoch(times=times, durations=durations, labels=labels, **attributes)
  197. epoch.segment = parent
  198. return epoch
  199. def _read_epoch(self, node, parent):
  200. return self._read_epocharray(node, parent)
  201. def _read_eventarray(self, node, parent):
  202. attributes = self._get_standard_attributes(node)
  203. times = self._get_quantity(node["times"])
  204. labels = node["labels"][()].astype('U')
  205. event = Event(times=times, labels=labels, **attributes)
  206. event.segment = parent
  207. return event
  208. def _read_event(self, node, parent):
  209. return self._read_eventarray(node, parent)
  210. def _read_recordingchannelgroup(self, node, parent):
  211. # todo: handle Units
  212. attributes = self._get_standard_attributes(node)
  213. channel_indexes = node["channel_indexes"][()]
  214. channel_names = node["channel_names"][()]
  215. if channel_indexes.size:
  216. if len(node['recordingchannels']):
  217. raise MergeError("Cannot handle a RecordingChannelGroup which both has a "
  218. "'channel_indexes' attribute and contains "
  219. "RecordingChannel objects")
  220. raise NotImplementedError("todo") # need to handle node['analogsignalarrays']
  221. else:
  222. channels = []
  223. for name, child_node in node['recordingchannels'].items():
  224. if "RecordingChannel" in name:
  225. channels.append(self._read_recordingchannel(child_node))
  226. channel_index = ChannelIndex(None, **attributes)
  227. channel_index._channels = channels
  228. # construction of the index is deferred until we have processed
  229. # all RecordingChannelGroup nodes
  230. units = []
  231. for name, child_node in node['units'].items():
  232. if "Unit" in name:
  233. units.append(self._read_unit(child_node, parent=channel_index))
  234. channel_index.units = units
  235. channel_index.block = parent
  236. return channel_index
  237. def _read_recordingchannel(self, node):
  238. attributes = self._get_standard_attributes(node)
  239. analogsignals = []
  240. irregsignals = []
  241. for name, child_node in node["analogsignals"].items():
  242. if "AnalogSignal" in name:
  243. obj_ref = child_node.attrs["object_ref"]
  244. analogsignals.append(obj_ref)
  245. for name, child_node in node["irregularlysampledsignals"].items():
  246. if "IrregularlySampledSignal" in name:
  247. obj_ref = child_node.attrs["object_ref"]
  248. irregsignals.append(obj_ref)
  249. return attributes['index'], analogsignals, irregsignals
  250. def _read_unit(self, node, parent):
  251. attributes = self._get_standard_attributes(node)
  252. spiketrains = []
  253. for name, child_node in node["spiketrains"].items():
  254. if "SpikeTrain" in name:
  255. obj_ref = child_node.attrs["object_ref"]
  256. spiketrains.append(self.object_refs[obj_ref])
  257. unit = Unit(**attributes)
  258. unit.channel_index = parent
  259. unit.spiketrains = spiketrains
  260. return unit
  261. def _merge_data_objects(self, objects):
  262. if len(objects) > 1:
  263. merged_objects = [objects.pop(0)]
  264. while objects:
  265. obj = objects.pop(0)
  266. try:
  267. combined_obj_ref = merged_objects[-1].annotations['object_ref']
  268. merged_objects[-1] = merged_objects[-1].merge(obj)
  269. merged_objects[-1].annotations['object_ref'] = combined_obj_ref + \
  270. "-" + obj.annotations[
  271. 'object_ref']
  272. except MergeError:
  273. merged_objects.append(obj)
  274. for obj in merged_objects:
  275. self.object_refs[obj.annotations['object_ref']] = obj
  276. return merged_objects
  277. else:
  278. return objects
  279. def _get_quantity(self, node):
  280. value = node[()]
  281. unit_str = [x for x in node.attrs.keys() if "unit" in x][0].split("__")[1]
  282. units = getattr(pq, unit_str)
  283. return value * units
  284. def _get_standard_attributes(self, node):
  285. """Retrieve attributes"""
  286. attributes = {}
  287. for name in ('name', 'description', 'index', 'file_origin', 'object_ref'):
  288. if name in node.attrs:
  289. attributes[name] = node.attrs[name]
  290. for name in ('rec_datetime', 'file_datetime'):
  291. if name in node.attrs:
  292. attributes[name] = pickle.loads(node.attrs[name], encoding='bytes')
  293. annotations = pickle.loads(node.attrs['annotations'], encoding='bytes')
  294. attributes.update(annotations)
  295. # avoid "dictionary changed size during iteration" error
  296. attribute_names = list(attributes.keys())
  297. for name in attribute_names:
  298. if isinstance(attributes[name], (bytes, np.bytes_)):
  299. attributes[name] = attributes[name].decode('utf-8')
  300. if isinstance(name, bytes):
  301. attributes[name.decode('utf-8')] = attributes[name]
  302. attributes.pop(name)
  303. return attributes
  304. def _resolve_channel_indexes(self, block):
  305. def disjoint_channel_indexes(channel_indexes):
  306. channel_indexes = channel_indexes[:]
  307. for ci1 in channel_indexes:
  308. # this works only on analogsignals
  309. signal_group1 = {tuple(x[1]) for x in ci1._channels}
  310. for ci2 in channel_indexes: # need to take irregularly sampled signals
  311. signal_group2 = {tuple(x[1]) for x in ci2._channels} # into account too
  312. if signal_group1 != signal_group2:
  313. if signal_group2.issubset(signal_group1):
  314. channel_indexes.remove(ci2)
  315. elif signal_group1.issubset(signal_group2):
  316. channel_indexes.remove(ci1)
  317. return channel_indexes
  318. principal_indexes = disjoint_channel_indexes(block.channel_indexes)
  319. for ci in principal_indexes:
  320. ids = []
  321. by_segment = {}
  322. for (index, analogsignals, irregsignals) in ci._channels:
  323. # note that what was called "index" in Neo 0.3/0.4 is "id" in Neo 0.5
  324. ids.append(index)
  325. for signal_ref in analogsignals:
  326. signal = self.object_refs[signal_ref]
  327. segment_id = id(signal.segment)
  328. if segment_id in by_segment:
  329. by_segment[segment_id]['analogsignals'].append(signal)
  330. else:
  331. by_segment[segment_id] = {'analogsignals': [signal], 'irregsignals': []}
  332. for signal_ref in irregsignals:
  333. signal = self.object_refs[signal_ref]
  334. segment_id = id(signal.segment)
  335. if segment_id in by_segment:
  336. by_segment[segment_id]['irregsignals'].append(signal)
  337. else:
  338. by_segment[segment_id] = {'analogsignals': [], 'irregsignals': [signal]}
  339. assert len(ids) > 0
  340. if self.merge_singles:
  341. ci.channel_ids = np.array(ids)
  342. ci.index = np.arange(len(ids))
  343. for seg_id, segment_data in by_segment.items():
  344. # get the segment object
  345. segment = None
  346. for seg in ci.block.segments:
  347. if id(seg) == seg_id:
  348. segment = seg
  349. break
  350. assert segment is not None
  351. if segment_data['analogsignals']:
  352. merged_signals = self._merge_data_objects(segment_data['analogsignals'])
  353. assert len(merged_signals) == 1
  354. merged_signals[0].channel_index = ci
  355. merged_signals[0].annotations['object_ref'] = "-".join(
  356. obj.annotations['object_ref']
  357. for obj in segment_data['analogsignals'])
  358. segment.analogsignals.extend(merged_signals)
  359. ci.analogsignals = merged_signals
  360. if segment_data['irregsignals']:
  361. merged_signals = self._merge_data_objects(segment_data['irregsignals'])
  362. assert len(merged_signals) == 1
  363. merged_signals[0].channel_index = ci
  364. merged_signals[0].annotations['object_ref'] = "-".join(
  365. obj.annotations['object_ref']
  366. for obj in segment_data['irregsignals'])
  367. segment.irregularlysampledsignals.extend(merged_signals)
  368. ci.irregularlysampledsignals = merged_signals
  369. else:
  370. raise NotImplementedError() # will need to return multiple ChannelIndexes
  371. # handle non-principal channel indexes
  372. for ci in block.channel_indexes:
  373. if ci not in principal_indexes:
  374. ids = [c[0] for c in ci._channels]
  375. for cipr in principal_indexes:
  376. if ids[0] in cipr.channel_ids:
  377. break
  378. ci.analogsignals = cipr.analogsignals
  379. ci.channel_ids = np.array(ids)
  380. ci.index = np.where(np.in1d(cipr.channel_ids, ci.channel_ids))[0]