Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but we recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

hdf5io.py 18 KB

(line numbers 1–426)
  1. # -*- coding: utf-8 -*-
  2. """
  3. """
  4. from __future__ import absolute_import
  5. import logging
  6. import pickle
  7. import numpy as np
  8. import quantities as pq
  9. try:
  10. import h5py
  11. except ImportError as err:
  12. HAVE_H5PY = False
  13. else:
  14. HAVE_H5PY = True
  15. from neo.core import (objectlist, Block, Segment, AnalogSignal, SpikeTrain,
  16. Epoch, Event, IrregularlySampledSignal, ChannelIndex,
  17. Unit)
  18. from neo.io.baseio import BaseIO
  19. from neo.core.baseneo import MergeError
  20. logger = logging.getLogger('Neo')
  21. def disjoint_groups(groups):
  22. """`groups` should be a list of sets"""
  23. groups = groups[:] # copy, so as not to change original
  24. for group1 in groups:
  25. for group2 in groups:
  26. if group1 != group2:
  27. if group2.issubset(group1):
  28. groups.remove(group2)
  29. elif group1.issubset(group2):
  30. groups.remove(group1)
  31. return groups
  32. class NeoHdf5IO(BaseIO):
  33. """
  34. Class for reading HDF5 format files created by Neo version 0.4 or earlier.
  35. Writing to HDF5 is not supported by this IO; we recommend using NixIO for this.
  36. """
  37. supported_objects = objectlist
  38. readable_objects = objectlist
  39. name = 'NeoHdf5 IO'
  40. extensions = ['h5']
  41. mode = 'file'
  42. is_readable = True
  43. is_writable = False
  44. def __init__(self, filename):
  45. if not HAVE_H5PY:
  46. raise ImportError("h5py is not available")
  47. BaseIO.__init__(self, filename=filename)
  48. self._data = h5py.File(filename, 'r')
  49. self.object_refs = {}
  50. self._lazy = False
  51. def read_all_blocks(self, lazy=False, cascade=True, merge_singles=True, **kargs):
  52. """
  53. Loads all blocks in the file that are attached to the root (which
  54. happens when they are saved with save() or write_block()).
  55. If `merge_singles` is True, then the IO will attempt to merge single channel
  56. `AnalogSignal` objects into multichannel objects, and similarly for single `Epoch`,
  57. `Event` and `IrregularlySampledSignal` objects.
  58. """
  59. self._lazy = lazy
  60. self._cascade = cascade
  61. self.merge_singles = merge_singles
  62. blocks = []
  63. for name, node in self._data.items():
  64. if "Block" in name:
  65. blocks.append(self._read_block(node))
  66. return blocks
  67. def read_block(self, lazy=False, cascade=True, **kargs):
  68. """
  69. Load the first block in the file.
  70. """
  71. return self.read_all_blocks(lazy=lazy, cascade=cascade)[0]
  72. def _read_block(self, node):
  73. attributes = self._get_standard_attributes(node)
  74. block = Block(**attributes)
  75. if self._cascade:
  76. for name, child_node in node['segments'].items():
  77. if "Segment" in name:
  78. block.segments.append(self._read_segment(child_node, parent=block))
  79. if len(node['recordingchannelgroups']) > 0:
  80. for name, child_node in node['recordingchannelgroups'].items():
  81. if "RecordingChannelGroup" in name:
  82. block.channel_indexes.append(self._read_recordingchannelgroup(child_node, parent=block))
  83. self._resolve_channel_indexes(block)
  84. elif self.merge_singles:
  85. # if no RecordingChannelGroups are defined, merging
  86. # takes place here.
  87. for segment in block.segments:
  88. if hasattr(segment, 'unmerged_analogsignals'):
  89. segment.analogsignals.extend(
  90. self._merge_data_objects(segment.unmerged_analogsignals))
  91. del segment.unmerged_analogsignals
  92. if hasattr(segment, 'unmerged_irregularlysampledsignals'):
  93. segment.irregularlysampledsignals.extend(
  94. self._merge_data_objects(segment.unmerged_irregularlysampledsignals))
  95. del segment.unmerged_irregularlysampledsignals
  96. return block
  97. def _read_segment(self, node, parent):
  98. attributes = self._get_standard_attributes(node)
  99. segment = Segment(**attributes)
  100. signals = []
  101. for name, child_node in node['analogsignals'].items():
  102. if "AnalogSignal" in name:
  103. signals.append(self._read_analogsignal(child_node, parent=segment))
  104. if signals and self.merge_singles:
  105. segment.unmerged_analogsignals = signals # signals will be merged later
  106. signals = []
  107. for name, child_node in node['analogsignalarrays'].items():
  108. if "AnalogSignalArray" in name:
  109. signals.append(self._read_analogsignalarray(child_node, parent=segment))
  110. segment.analogsignals = signals
  111. irr_signals = []
  112. for name, child_node in node['irregularlysampledsignals'].items():
  113. if "IrregularlySampledSignal" in name:
  114. irr_signals.append(self._read_irregularlysampledsignal(child_node, parent=segment))
  115. if irr_signals and self.merge_singles:
  116. segment.unmerged_irregularlysampledsignals = irr_signals
  117. irr_signals = []
  118. segment.irregularlysampledsignals = irr_signals
  119. epochs = []
  120. for name, child_node in node['epochs'].items():
  121. if "Epoch" in name:
  122. epochs.append(self._read_epoch(child_node, parent=segment))
  123. if self.merge_singles:
  124. epochs = self._merge_data_objects(epochs)
  125. for name, child_node in node['epocharrays'].items():
  126. if "EpochArray" in name:
  127. epochs.append(self._read_epocharray(child_node, parent=segment))
  128. segment.epochs = epochs
  129. events = []
  130. for name, child_node in node['events'].items():
  131. if "Event" in name:
  132. events.append(self._read_event(child_node, parent=segment))
  133. if self.merge_singles:
  134. events = self._merge_data_objects(events)
  135. for name, child_node in node['eventarrays'].items():
  136. if "EventArray" in name:
  137. events.append(self._read_eventarray(child_node, parent=segment))
  138. segment.events = events
  139. spiketrains = []
  140. for name, child_node in node['spikes'].items():
  141. raise NotImplementedError('Spike objects not yet handled.')
  142. for name, child_node in node['spiketrains'].items():
  143. if "SpikeTrain" in name:
  144. spiketrains.append(self._read_spiketrain(child_node, parent=segment))
  145. segment.spiketrains = spiketrains
  146. segment.block = parent
  147. return segment
  148. def _read_analogsignalarray(self, node, parent):
  149. attributes = self._get_standard_attributes(node)
  150. # todo: handle channel_index
  151. sampling_rate = self._get_quantity(node["sampling_rate"])
  152. t_start = self._get_quantity(node["t_start"])
  153. signal = AnalogSignal(self._get_quantity(node["signal"]),
  154. sampling_rate=sampling_rate, t_start=t_start,
  155. **attributes)
  156. if self._lazy:
  157. signal.lazy_shape = node["signal"].shape
  158. if len(signal.lazy_shape) == 1:
  159. signal.lazy_shape = (signal.lazy_shape[0], 1)
  160. signal.segment = parent
  161. self.object_refs[node.attrs["object_ref"]] = signal
  162. return signal
  163. def _read_analogsignal(self, node, parent):
  164. return self._read_analogsignalarray(node, parent)
  165. def _read_irregularlysampledsignal(self, node, parent):
  166. attributes = self._get_standard_attributes(node)
  167. signal = IrregularlySampledSignal(times=self._get_quantity(node["times"]),
  168. signal=self._get_quantity(node["signal"]),
  169. **attributes)
  170. signal.segment = parent
  171. if self._lazy:
  172. signal.lazy_shape = node["signal"].shape
  173. if len(signal.lazy_shape) == 1:
  174. signal.lazy_shape = (signal.lazy_shape[0], 1)
  175. return signal
  176. def _read_spiketrain(self, node, parent):
  177. attributes = self._get_standard_attributes(node)
  178. t_start = self._get_quantity(node["t_start"])
  179. t_stop = self._get_quantity(node["t_stop"])
  180. # todo: handle sampling_rate, waveforms, left_sweep
  181. spiketrain = SpikeTrain(self._get_quantity(node["times"]),
  182. t_start=t_start, t_stop=t_stop,
  183. **attributes)
  184. spiketrain.segment = parent
  185. if self._lazy:
  186. spiketrain.lazy_shape = node["times"].shape
  187. self.object_refs[node.attrs["object_ref"]] = spiketrain
  188. return spiketrain
  189. def _read_epocharray(self, node, parent):
  190. attributes = self._get_standard_attributes(node)
  191. times = self._get_quantity(node["times"])
  192. durations = self._get_quantity(node["durations"])
  193. if self._lazy:
  194. labels = np.array((), dtype=node["labels"].dtype)
  195. else:
  196. labels = node["labels"].value
  197. epoch = Epoch(times=times, durations=durations, labels=labels, **attributes)
  198. epoch.segment = parent
  199. if self._lazy:
  200. epoch.lazy_shape = node["times"].shape
  201. return epoch
  202. def _read_epoch(self, node, parent):
  203. return self._read_epocharray(node, parent)
  204. def _read_eventarray(self, node, parent):
  205. attributes = self._get_standard_attributes(node)
  206. times = self._get_quantity(node["times"])
  207. if self._lazy:
  208. labels = np.array((), dtype=node["labels"].dtype)
  209. else:
  210. labels = node["labels"].value
  211. event = Event(times=times, labels=labels, **attributes)
  212. event.segment = parent
  213. if self._lazy:
  214. event.lazy_shape = node["times"].shape
  215. return event
  216. def _read_event(self, node, parent):
  217. return self._read_eventarray(node, parent)
  218. def _read_recordingchannelgroup(self, node, parent):
  219. # todo: handle Units
  220. attributes = self._get_standard_attributes(node)
  221. channel_indexes = node["channel_indexes"].value
  222. channel_names = node["channel_names"].value
  223. if channel_indexes.size:
  224. if len(node['recordingchannels']) :
  225. raise MergeError("Cannot handle a RecordingChannelGroup which both has a "
  226. "'channel_indexes' attribute and contains "
  227. "RecordingChannel objects")
  228. raise NotImplementedError("todo") # need to handle node['analogsignalarrays']
  229. else:
  230. channels = []
  231. for name, child_node in node['recordingchannels'].items():
  232. if "RecordingChannel" in name:
  233. channels.append(self._read_recordingchannel(child_node))
  234. channel_index = ChannelIndex(None, **attributes)
  235. channel_index._channels = channels
  236. # construction of the index is deferred until we have processed
  237. # all RecordingChannelGroup nodes
  238. units = []
  239. for name, child_node in node['units'].items():
  240. if "Unit" in name:
  241. units.append(self._read_unit(child_node, parent=channel_index))
  242. channel_index.units = units
  243. channel_index.block = parent
  244. return channel_index
  245. def _read_recordingchannel(self, node):
  246. attributes = self._get_standard_attributes(node)
  247. analogsignals = []
  248. irregsignals = []
  249. for name, child_node in node["analogsignals"].items():
  250. if "AnalogSignal" in name:
  251. obj_ref = child_node.attrs["object_ref"]
  252. analogsignals.append(obj_ref)
  253. for name, child_node in node["irregularlysampledsignals"].items():
  254. if "IrregularlySampledSignal" in name:
  255. obj_ref = child_node.attrs["object_ref"]
  256. irregsignals.append(obj_ref)
  257. return attributes['index'], analogsignals, irregsignals
  258. def _read_unit(self, node, parent):
  259. attributes = self._get_standard_attributes(node)
  260. spiketrains = []
  261. for name, child_node in node["spiketrains"].items():
  262. if "SpikeTrain" in name:
  263. obj_ref = child_node.attrs["object_ref"]
  264. spiketrains.append(self.object_refs[obj_ref])
  265. unit = Unit(**attributes)
  266. unit.channel_index = parent
  267. unit.spiketrains = spiketrains
  268. return unit
  269. def _merge_data_objects(self, objects):
  270. if len(objects) > 1:
  271. merged_objects = [objects.pop(0)]
  272. while objects:
  273. obj = objects.pop(0)
  274. try:
  275. combined_obj_ref = merged_objects[-1].annotations['object_ref']
  276. merged_objects[-1] = merged_objects[-1].merge(obj)
  277. merged_objects[-1].annotations['object_ref'] = combined_obj_ref + "-" + obj.annotations['object_ref']
  278. except MergeError:
  279. merged_objects.append(obj)
  280. for obj in merged_objects:
  281. self.object_refs[obj.annotations['object_ref']] = obj
  282. return merged_objects
  283. else:
  284. return objects
  285. def _get_quantity(self, node):
  286. if self._lazy and len(node.shape) > 0:
  287. value = np.array((), dtype=node.dtype)
  288. else:
  289. value = node.value
  290. unit_str = [x for x in node.attrs.keys() if "unit" in x][0].split("__")[1]
  291. units = getattr(pq, unit_str)
  292. return value * units
  293. def _get_standard_attributes(self, node):
  294. """Retrieve attributes"""
  295. attributes = {}
  296. for name in ('name', 'description', 'index', 'file_origin', 'object_ref'):
  297. if name in node.attrs:
  298. attributes[name] = node.attrs[name]
  299. for name in ('rec_datetime', 'file_datetime'):
  300. if name in node.attrs:
  301. attributes[name] = pickle.loads(node.attrs[name])
  302. attributes.update(pickle.loads(node.attrs['annotations']))
  303. return attributes
  304. def _resolve_channel_indexes(self, block):
  305. def disjoint_channel_indexes(channel_indexes):
  306. channel_indexes = channel_indexes[:]
  307. for ci1 in channel_indexes:
  308. signal_group1 = set(tuple(x[1]) for x in ci1._channels) # this works only on analogsignals
  309. for ci2 in channel_indexes: # need to take irregularly sampled signals
  310. signal_group2 = set(tuple(x[1]) for x in ci2._channels) # into account too
  311. if signal_group1 != signal_group2:
  312. if signal_group2.issubset(signal_group1):
  313. channel_indexes.remove(ci2)
  314. elif signal_group1.issubset(signal_group2):
  315. channel_indexes.remove(ci1)
  316. return channel_indexes
  317. principal_indexes = disjoint_channel_indexes(block.channel_indexes)
  318. for ci in principal_indexes:
  319. ids = []
  320. by_segment = {}
  321. for (index, analogsignals, irregsignals) in ci._channels:
  322. ids.append(index) # note that what was called "index" in Neo 0.3/0.4 is "id" in Neo 0.5
  323. for signal_ref in analogsignals:
  324. signal = self.object_refs[signal_ref]
  325. segment_id = id(signal.segment)
  326. if segment_id in by_segment:
  327. by_segment[segment_id]['analogsignals'].append(signal)
  328. else:
  329. by_segment[segment_id] = {'analogsignals': [signal], 'irregsignals': []}
  330. for signal_ref in irregsignals:
  331. signal = self.object_refs[signal_ref]
  332. segment_id = id(signal.segment)
  333. if segment_id in by_segment:
  334. by_segment[segment_id]['irregsignals'].append(signal)
  335. else:
  336. by_segment[segment_id] = {'analogsignals': [], 'irregsignals': [signal]}
  337. assert len(ids) > 0
  338. if self.merge_singles:
  339. ci.channel_ids = np.array(ids)
  340. ci.index = np.arange(len(ids))
  341. for seg_id, segment_data in by_segment.items():
  342. # get the segment object
  343. segment = None
  344. for seg in ci.block.segments:
  345. if id(seg) == seg_id:
  346. segment = seg
  347. break
  348. assert segment is not None
  349. if segment_data['analogsignals']:
  350. merged_signals = self._merge_data_objects(segment_data['analogsignals'])
  351. assert len(merged_signals) == 1
  352. merged_signals[0].channel_index = ci
  353. merged_signals[0].annotations['object_ref'] = "-".join(obj.annotations['object_ref']
  354. for obj in segment_data['analogsignals'])
  355. segment.analogsignals.extend(merged_signals)
  356. ci.analogsignals = merged_signals
  357. if segment_data['irregsignals']:
  358. merged_signals = self._merge_data_objects(segment_data['irregsignals'])
  359. assert len(merged_signals) == 1
  360. merged_signals[0].channel_index = ci
  361. merged_signals[0].annotations['object_ref'] = "-".join(obj.annotations['object_ref']
  362. for obj in segment_data['irregsignals'])
  363. segment.irregularlysampledsignals.extend(merged_signals)
  364. ci.irregularlysampledsignals = merged_signals
  365. else:
  366. raise NotImplementedError() # will need to return multiple ChannelIndexes
  367. # handle non-principal channel indexes
  368. for ci in block.channel_indexes:
  369. if ci not in principal_indexes:
  370. ids = [c[0] for c in ci._channels]
  371. for cipr in principal_indexes:
  372. if ids[0] in cipr.channel_ids:
  373. break
  374. ci.analogsignals = cipr.analogsignals
  375. ci.channel_ids = np.array(ids)
  376. ci.index = np.where(np.in1d(cipr.channel_ids, ci.channel_ids))[0]