hdf5io.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # -*- coding: utf-8 -*-
  2. """
  3. """
  4. from __future__ import absolute_import
  5. import sys
  6. import logging
  7. import pickle
  8. import numpy as np
  9. import quantities as pq
  10. try:
  11. import h5py
  12. except ImportError as err:
  13. HAVE_H5PY = False
  14. else:
  15. HAVE_H5PY = True
  16. from neo.core import (objectlist, Block, Segment, AnalogSignal, SpikeTrain,
  17. Epoch, Event, IrregularlySampledSignal, ChannelIndex,
  18. Unit)
  19. from neo.io.baseio import BaseIO
  20. from neo.core.baseneo import MergeError
  21. logger = logging.getLogger('Neo')
  22. def disjoint_groups(groups):
  23. """`groups` should be a list of sets"""
  24. groups = groups[:] # copy, so as not to change original
  25. for group1 in groups:
  26. for group2 in groups:
  27. if group1 != group2:
  28. if group2.issubset(group1):
  29. groups.remove(group2)
  30. elif group1.issubset(group2):
  31. groups.remove(group1)
  32. return groups
  33. class NeoHdf5IO(BaseIO):
  34. """
  35. Class for reading HDF5 format files created by Neo version 0.4 or earlier.
  36. Writing to HDF5 is not supported by this IO; we recommend using NixIO for this.
  37. """
  38. supported_objects = objectlist
  39. readable_objects = objectlist
  40. name = 'NeoHdf5 IO'
  41. extensions = ['h5']
  42. mode = 'file'
  43. is_readable = True
  44. is_writable = False
  45. def __init__(self, filename):
  46. if not HAVE_H5PY:
  47. raise ImportError("h5py is not available")
  48. BaseIO.__init__(self, filename=filename)
  49. self._data = h5py.File(filename, 'r')
  50. self.object_refs = {}
  51. def read_all_blocks(self, lazy=False, merge_singles=True, **kargs):
  52. """
  53. Loads all blocks in the file that are attached to the root (which
  54. happens when they are saved with save() or write_block()).
  55. If `merge_singles` is True, then the IO will attempt to merge single channel
  56. `AnalogSignal` objects into multichannel objects, and similarly for single `Epoch`,
  57. `Event` and `IrregularlySampledSignal` objects.
  58. """
  59. assert not lazy, 'Do not support lazy'
  60. self.merge_singles = merge_singles
  61. blocks = []
  62. for name, node in self._data.items():
  63. if "Block" in name:
  64. blocks.append(self._read_block(node))
  65. return blocks
  66. def read_block(self, lazy=False, **kargs):
  67. """
  68. Load the first block in the file.
  69. """
  70. assert not lazy, 'Do not support lazy'
  71. return self.read_all_blocks(lazy=lazy)[0]
  72. def _read_block(self, node):
  73. attributes = self._get_standard_attributes(node)
  74. if "index" in attributes:
  75. attributes["index"] = int(attributes["index"])
  76. block = Block(**attributes)
  77. for name, child_node in node['segments'].items():
  78. if "Segment" in name:
  79. block.segments.append(self._read_segment(child_node, parent=block))
  80. if len(node['recordingchannelgroups']) > 0:
  81. for name, child_node in node['recordingchannelgroups'].items():
  82. if "RecordingChannelGroup" in name:
  83. block.channel_indexes.append(
  84. self._read_recordingchannelgroup(child_node, parent=block))
  85. self._resolve_channel_indexes(block)
  86. elif self.merge_singles:
  87. # if no RecordingChannelGroups are defined, merging
  88. # takes place here.
  89. for segment in block.segments:
  90. if hasattr(segment, 'unmerged_analogsignals'):
  91. segment.analogsignals.extend(
  92. self._merge_data_objects(segment.unmerged_analogsignals))
  93. del segment.unmerged_analogsignals
  94. if hasattr(segment, 'unmerged_irregularlysampledsignals'):
  95. segment.irregularlysampledsignals.extend(
  96. self._merge_data_objects(segment.unmerged_irregularlysampledsignals))
  97. del segment.unmerged_irregularlysampledsignals
  98. return block
  99. def _read_segment(self, node, parent):
  100. attributes = self._get_standard_attributes(node)
  101. segment = Segment(**attributes)
  102. signals = []
  103. for name, child_node in node['analogsignals'].items():
  104. if "AnalogSignal" in name:
  105. signals.append(self._read_analogsignal(child_node, parent=segment))
  106. if signals and self.merge_singles:
  107. segment.unmerged_analogsignals = signals # signals will be merged later
  108. signals = []
  109. for name, child_node in node['analogsignalarrays'].items():
  110. if "AnalogSignalArray" in name:
  111. signals.append(self._read_analogsignalarray(child_node, parent=segment))
  112. segment.analogsignals = signals
  113. irr_signals = []
  114. for name, child_node in node['irregularlysampledsignals'].items():
  115. if "IrregularlySampledSignal" in name:
  116. irr_signals.append(self._read_irregularlysampledsignal(child_node, parent=segment))
  117. if irr_signals and self.merge_singles:
  118. segment.unmerged_irregularlysampledsignals = irr_signals
  119. irr_signals = []
  120. segment.irregularlysampledsignals = irr_signals
  121. epochs = []
  122. for name, child_node in node['epochs'].items():
  123. if "Epoch" in name:
  124. epochs.append(self._read_epoch(child_node, parent=segment))
  125. if self.merge_singles:
  126. epochs = self._merge_data_objects(epochs)
  127. for name, child_node in node['epocharrays'].items():
  128. if "EpochArray" in name:
  129. epochs.append(self._read_epocharray(child_node, parent=segment))
  130. segment.epochs = epochs
  131. events = []
  132. for name, child_node in node['events'].items():
  133. if "Event" in name:
  134. events.append(self._read_event(child_node, parent=segment))
  135. if self.merge_singles:
  136. events = self._merge_data_objects(events)
  137. for name, child_node in node['eventarrays'].items():
  138. if "EventArray" in name:
  139. events.append(self._read_eventarray(child_node, parent=segment))
  140. segment.events = events
  141. spiketrains = []
  142. for name, child_node in node['spikes'].items():
  143. raise NotImplementedError('Spike objects not yet handled.')
  144. for name, child_node in node['spiketrains'].items():
  145. if "SpikeTrain" in name:
  146. spiketrains.append(self._read_spiketrain(child_node, parent=segment))
  147. segment.spiketrains = spiketrains
  148. segment.block = parent
  149. return segment
  150. def _read_analogsignalarray(self, node, parent):
  151. attributes = self._get_standard_attributes(node)
  152. # todo: handle channel_index
  153. sampling_rate = self._get_quantity(node["sampling_rate"])
  154. t_start = self._get_quantity(node["t_start"])
  155. signal = AnalogSignal(self._get_quantity(node["signal"]),
  156. sampling_rate=sampling_rate, t_start=t_start,
  157. **attributes)
  158. signal.segment = parent
  159. self.object_refs[node.attrs["object_ref"]] = signal
  160. return signal
  161. def _read_analogsignal(self, node, parent):
  162. return self._read_analogsignalarray(node, parent)
  163. def _read_irregularlysampledsignal(self, node, parent):
  164. attributes = self._get_standard_attributes(node)
  165. signal = IrregularlySampledSignal(times=self._get_quantity(node["times"]),
  166. signal=self._get_quantity(node["signal"]),
  167. **attributes)
  168. signal.segment = parent
  169. return signal
  170. def _read_spiketrain(self, node, parent):
  171. attributes = self._get_standard_attributes(node)
  172. t_start = self._get_quantity(node["t_start"])
  173. t_stop = self._get_quantity(node["t_stop"])
  174. # todo: handle sampling_rate, waveforms, left_sweep
  175. spiketrain = SpikeTrain(self._get_quantity(node["times"]),
  176. t_start=t_start, t_stop=t_stop,
  177. **attributes)
  178. spiketrain.segment = parent
  179. self.object_refs[node.attrs["object_ref"]] = spiketrain
  180. return spiketrain
  181. def _read_epocharray(self, node, parent):
  182. attributes = self._get_standard_attributes(node)
  183. times = self._get_quantity(node["times"])
  184. durations = self._get_quantity(node["durations"])
  185. labels = node["labels"].value
  186. epoch = Epoch(times=times, durations=durations, labels=labels, **attributes)
  187. epoch.segment = parent
  188. return epoch
  189. def _read_epoch(self, node, parent):
  190. return self._read_epocharray(node, parent)
  191. def _read_eventarray(self, node, parent):
  192. attributes = self._get_standard_attributes(node)
  193. times = self._get_quantity(node["times"])
  194. labels = node["labels"].value
  195. event = Event(times=times, labels=labels, **attributes)
  196. event.segment = parent
  197. return event
  198. def _read_event(self, node, parent):
  199. return self._read_eventarray(node, parent)
  200. def _read_recordingchannelgroup(self, node, parent):
  201. # todo: handle Units
  202. attributes = self._get_standard_attributes(node)
  203. channel_indexes = node["channel_indexes"].value
  204. channel_names = node["channel_names"].value
  205. if channel_indexes.size:
  206. if len(node['recordingchannels']):
  207. raise MergeError("Cannot handle a RecordingChannelGroup which both has a "
  208. "'channel_indexes' attribute and contains "
  209. "RecordingChannel objects")
  210. raise NotImplementedError("todo") # need to handle node['analogsignalarrays']
  211. else:
  212. channels = []
  213. for name, child_node in node['recordingchannels'].items():
  214. if "RecordingChannel" in name:
  215. channels.append(self._read_recordingchannel(child_node))
  216. channel_index = ChannelIndex(None, **attributes)
  217. channel_index._channels = channels
  218. # construction of the index is deferred until we have processed
  219. # all RecordingChannelGroup nodes
  220. units = []
  221. for name, child_node in node['units'].items():
  222. if "Unit" in name:
  223. units.append(self._read_unit(child_node, parent=channel_index))
  224. channel_index.units = units
  225. channel_index.block = parent
  226. return channel_index
  227. def _read_recordingchannel(self, node):
  228. attributes = self._get_standard_attributes(node)
  229. analogsignals = []
  230. irregsignals = []
  231. for name, child_node in node["analogsignals"].items():
  232. if "AnalogSignal" in name:
  233. obj_ref = child_node.attrs["object_ref"]
  234. analogsignals.append(obj_ref)
  235. for name, child_node in node["irregularlysampledsignals"].items():
  236. if "IrregularlySampledSignal" in name:
  237. obj_ref = child_node.attrs["object_ref"]
  238. irregsignals.append(obj_ref)
  239. return attributes['index'], analogsignals, irregsignals
  240. def _read_unit(self, node, parent):
  241. attributes = self._get_standard_attributes(node)
  242. spiketrains = []
  243. for name, child_node in node["spiketrains"].items():
  244. if "SpikeTrain" in name:
  245. obj_ref = child_node.attrs["object_ref"]
  246. spiketrains.append(self.object_refs[obj_ref])
  247. unit = Unit(**attributes)
  248. unit.channel_index = parent
  249. unit.spiketrains = spiketrains
  250. return unit
  251. def _merge_data_objects(self, objects):
  252. if len(objects) > 1:
  253. merged_objects = [objects.pop(0)]
  254. while objects:
  255. obj = objects.pop(0)
  256. try:
  257. combined_obj_ref = merged_objects[-1].annotations['object_ref']
  258. merged_objects[-1] = merged_objects[-1].merge(obj)
  259. merged_objects[-1].annotations['object_ref'] = combined_obj_ref + \
  260. "-" + obj.annotations[
  261. 'object_ref']
  262. except MergeError:
  263. merged_objects.append(obj)
  264. for obj in merged_objects:
  265. self.object_refs[obj.annotations['object_ref']] = obj
  266. return merged_objects
  267. else:
  268. return objects
  269. def _get_quantity(self, node):
  270. value = node.value
  271. unit_str = [x for x in node.attrs.keys() if "unit" in x][0].split("__")[1]
  272. units = getattr(pq, unit_str)
  273. return value * units
  274. def _get_standard_attributes(self, node):
  275. """Retrieve attributes"""
  276. attributes = {}
  277. for name in ('name', 'description', 'index', 'file_origin', 'object_ref'):
  278. if name in node.attrs:
  279. attributes[name] = node.attrs[name]
  280. for name in ('rec_datetime', 'file_datetime'):
  281. if name in node.attrs:
  282. if sys.version_info.major > 2:
  283. attributes[name] = pickle.loads(node.attrs[name], encoding='bytes')
  284. else: # Python 2 doesn't have the encoding argument
  285. attributes[name] = pickle.loads(node.attrs[name])
  286. if sys.version_info.major > 2:
  287. annotations = pickle.loads(node.attrs['annotations'], encoding='bytes')
  288. else:
  289. annotations = pickle.loads(node.attrs['annotations'])
  290. attributes.update(annotations)
  291. # avoid "dictionary changed size during iteration" error
  292. attribute_names = list(attributes.keys())
  293. if sys.version_info.major > 2:
  294. for name in attribute_names:
  295. if isinstance(attributes[name], (bytes, np.bytes_)):
  296. attributes[name] = attributes[name].decode('utf-8')
  297. if isinstance(name, bytes):
  298. attributes[name.decode('utf-8')] = attributes[name]
  299. attributes.pop(name)
  300. return attributes
  301. def _resolve_channel_indexes(self, block):
  302. def disjoint_channel_indexes(channel_indexes):
  303. channel_indexes = channel_indexes[:]
  304. for ci1 in channel_indexes:
  305. # this works only on analogsignals
  306. signal_group1 = set(tuple(x[1]) for x in ci1._channels)
  307. for ci2 in channel_indexes: # need to take irregularly sampled signals
  308. signal_group2 = set(tuple(x[1]) for x in ci2._channels) # into account too
  309. if signal_group1 != signal_group2:
  310. if signal_group2.issubset(signal_group1):
  311. channel_indexes.remove(ci2)
  312. elif signal_group1.issubset(signal_group2):
  313. channel_indexes.remove(ci1)
  314. return channel_indexes
  315. principal_indexes = disjoint_channel_indexes(block.channel_indexes)
  316. for ci in principal_indexes:
  317. ids = []
  318. by_segment = {}
  319. for (index, analogsignals, irregsignals) in ci._channels:
  320. # note that what was called "index" in Neo 0.3/0.4 is "id" in Neo 0.5
  321. ids.append(index)
  322. for signal_ref in analogsignals:
  323. signal = self.object_refs[signal_ref]
  324. segment_id = id(signal.segment)
  325. if segment_id in by_segment:
  326. by_segment[segment_id]['analogsignals'].append(signal)
  327. else:
  328. by_segment[segment_id] = {'analogsignals': [signal], 'irregsignals': []}
  329. for signal_ref in irregsignals:
  330. signal = self.object_refs[signal_ref]
  331. segment_id = id(signal.segment)
  332. if segment_id in by_segment:
  333. by_segment[segment_id]['irregsignals'].append(signal)
  334. else:
  335. by_segment[segment_id] = {'analogsignals': [], 'irregsignals': [signal]}
  336. assert len(ids) > 0
  337. if self.merge_singles:
  338. ci.channel_ids = np.array(ids)
  339. ci.index = np.arange(len(ids))
  340. for seg_id, segment_data in by_segment.items():
  341. # get the segment object
  342. segment = None
  343. for seg in ci.block.segments:
  344. if id(seg) == seg_id:
  345. segment = seg
  346. break
  347. assert segment is not None
  348. if segment_data['analogsignals']:
  349. merged_signals = self._merge_data_objects(segment_data['analogsignals'])
  350. assert len(merged_signals) == 1
  351. merged_signals[0].channel_index = ci
  352. merged_signals[0].annotations['object_ref'] = "-".join(
  353. obj.annotations['object_ref']
  354. for obj in segment_data['analogsignals'])
  355. segment.analogsignals.extend(merged_signals)
  356. ci.analogsignals = merged_signals
  357. if segment_data['irregsignals']:
  358. merged_signals = self._merge_data_objects(segment_data['irregsignals'])
  359. assert len(merged_signals) == 1
  360. merged_signals[0].channel_index = ci
  361. merged_signals[0].annotations['object_ref'] = "-".join(
  362. obj.annotations['object_ref']
  363. for obj in segment_data['irregsignals'])
  364. segment.irregularlysampledsignals.extend(merged_signals)
  365. ci.irregularlysampledsignals = merged_signals
  366. else:
  367. raise NotImplementedError() # will need to return multiple ChannelIndexes
  368. # handle non-principal channel indexes
  369. for ci in block.channel_indexes:
  370. if ci not in principal_indexes:
  371. ids = [c[0] for c in ci._channels]
  372. for cipr in principal_indexes:
  373. if ids[0] in cipr.channel_ids:
  374. break
  375. ci.analogsignals = cipr.analogsignals
  376. ci.channel_ids = np.array(ids)
  377. ci.index = np.where(np.in1d(cipr.channel_ids, ci.channel_ids))[0]