neo_tools.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tools to manipulate Neo objects.
  4. :copyright: Copyright 2014-2016 by the Elephant team, see `doc/authors.rst`.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division, print_function, unicode_literals
  8. import warnings
  9. from itertools import chain
  10. from neo.core.container import unique_objs
  11. from elephant.utils import deprecated_alias
  12. __all__ = [
  13. "extract_neo_attributes",
  14. "get_all_spiketrains",
  15. "get_all_events",
  16. "get_all_epochs"
  17. ]
  18. @deprecated_alias(obj='neo_object')
  19. def extract_neo_attributes(neo_object, parents=True, child_first=True,
  20. skip_array=False, skip_none=False):
  21. """
  22. Given a Neo object, return a dictionary of attributes and annotations.
  23. Parameters
  24. ----------
  25. neo_object : neo.BaseNeo
  26. Object to get attributes and annotations.
  27. parents : bool, optional
  28. If True, also include attributes and annotations from parent Neo
  29. objects (if any).
  30. Default: True.
  31. child_first : bool, optional
  32. If True, values of child attributes are used over parent attributes in
  33. the event of a name conflict.
  34. If False, parent attributes are used.
  35. This parameter does nothing if `parents` is False.
  36. Default: True.
  37. skip_array : bool, optional
  38. If True, skip attributes that store non-scalar array values.
  39. Default: False.
  40. skip_none : bool, optional
  41. If True, skip annotations and attributes that have a value of None.
  42. Default: False.
  43. Returns
  44. -------
  45. dict
  46. A dictionary where the keys are annotations or attribute names and
  47. the values are the corresponding annotation or attribute value.
  48. """
  49. attrs = neo_object.annotations.copy()
  50. if not skip_array and hasattr(neo_object, "array_annotations"):
  51. # Exclude labels and durations, and any other fields that should not
  52. # be a part of array_annotation.
  53. required_keys = set(neo_object.array_annotations).difference(
  54. dir(neo_object))
  55. for a in required_keys:
  56. if "array_annotations" not in attrs:
  57. attrs["array_annotations"] = {}
  58. attrs["array_annotations"][a] = \
  59. neo_object.array_annotations[a].copy()
  60. for attr in neo_object._necessary_attrs + neo_object._recommended_attrs:
  61. if skip_array and len(attr) >= 3 and attr[2]:
  62. continue
  63. attr = attr[0]
  64. if attr == getattr(neo_object, '_quantity_attr', None):
  65. continue
  66. attrs[attr] = getattr(neo_object, attr, None)
  67. if skip_none:
  68. for attr, value in attrs.copy().items():
  69. if value is None:
  70. del attrs[attr]
  71. if not parents:
  72. return attrs
  73. for parent in getattr(neo_object, 'parents', []):
  74. if parent is None:
  75. continue
  76. newattr = extract_neo_attributes(parent, parents=True,
  77. child_first=child_first,
  78. skip_array=skip_array,
  79. skip_none=skip_none)
  80. if child_first:
  81. newattr.update(attrs)
  82. attrs = newattr
  83. else:
  84. attrs.update(newattr)
  85. return attrs
  86. def extract_neo_attrs(*args, **kwargs):
  87. warnings.warn("'extract_neo_attrs' function is deprecated; "
  88. "use 'extract_neo_attributes'", DeprecationWarning)
  89. return extract_neo_attributes(*args, **kwargs)
  90. def _get_all_objs(container, class_name):
  91. """
  92. Get all Neo objects of a given type from a container.
  93. The objects can be any list, dict, or other iterable or mapping containing
  94. Neo objects of a particular class, as well as any Neo object that can hold
  95. the object.
  96. Objects are searched recursively, so the objects can be nested (such as a
  97. list of blocks).
  98. Parameters
  99. ----------
  100. container : list, tuple, iterable, dict, neo.Container
  101. The container for the Neo objects.
  102. class_name : str
  103. The name of the class, with proper capitalization
  104. (i.e., 'SpikeTrain', not 'Spiketrain' or 'spiketrain').
  105. Returns
  106. -------
  107. list
  108. A list of unique Neo objects.
  109. Raises
  110. ------
  111. ValueError
  112. If can not handle containers of the type passed in `container`.
  113. """
  114. if container.__class__.__name__ == class_name:
  115. return [container]
  116. classholder = class_name.lower() + 's'
  117. if hasattr(container, classholder):
  118. vals = getattr(container, classholder)
  119. elif hasattr(container, 'list_children_by_class'):
  120. vals = container.list_children_by_class(class_name)
  121. elif hasattr(container, 'values') and not hasattr(container, 'ndim'):
  122. vals = container.values()
  123. elif hasattr(container, '__iter__') and not hasattr(container, 'ndim'):
  124. vals = container
  125. else:
  126. raise ValueError('Cannot handle object of type %s' % type(container))
  127. res = list(chain.from_iterable(_get_all_objs(obj, class_name)
  128. for obj in vals))
  129. return unique_objs(res)
  130. def get_all_spiketrains(container):
  131. """
  132. Get all `neo.Spiketrain` objects from a container.
  133. The objects can be any list, dict, or other iterable or mapping containing
  134. spiketrains, as well as any Neo object that can hold spiketrains:
  135. `neo.Block`, `neo.ChannelIndex`, `neo.Unit`, and `neo.Segment`.
  136. Containers are searched recursively, so the objects can be nested
  137. (such as a list of blocks).
  138. Parameters
  139. ----------
  140. container : list, tuple, iterable, dict, neo.Block, neo.Segment, neo.Unit,
  141. neo.ChannelIndex
  142. The container for the spiketrains.
  143. Returns
  144. -------
  145. list
  146. A list of the unique `neo.SpikeTrain` objects in `container`.
  147. """
  148. return _get_all_objs(container, 'SpikeTrain')
  149. def get_all_events(container):
  150. """
  151. Get all `neo.Event` objects from a container.
  152. The objects can be any list, dict, or other iterable or mapping containing
  153. events, as well as any neo object that can hold events:
  154. `neo.Block` and `neo.Segment`.
  155. Containers are searched recursively, so the objects can be nested
  156. (such as a list of blocks).
  157. Parameters
  158. ----------
  159. container : list, tuple, iterable, dict, neo.Block, neo.Segment
  160. The container for the events.
  161. Returns
  162. -------
  163. list
  164. A list of the unique `neo.Event` objects in `container`.
  165. """
  166. return _get_all_objs(container, 'Event')
  167. def get_all_epochs(container):
  168. """
  169. Get all `neo.Epoch` objects from a container.
  170. The objects can be any list, dict, or other iterable or mapping containing
  171. epochs, as well as any neo object that can hold epochs:
  172. `neo.Block` and `neo.Segment`.
  173. Containers are searched recursively, so the objects can be nested
  174. (such as a list of blocks).
  175. Parameters
  176. ----------
  177. container : list, tuple, iterable, dict, neo.Block, neo.Segment
  178. The container for the epochs.
  179. Returns
  180. -------
  181. list
  182. A list of the unique `neo.Epoch` objects in `container`.
  183. """
  184. return _get_all_objs(container, 'Epoch')