basesignal.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # -*- coding: utf-8 -*-
  2. '''
  3. This module implements :class:`BaseSignal`, an array of signals.
  4. This is a parent class from which all signal objects inherit:
  5. :class:`AnalogSignal` and :class:`IrregularlySampledSignal`
  6. :class:`BaseSignal` inherits from :class:`quantities.Quantity`, which
  7. inherits from :class:`numpy.ndarray`.
  8. Subclassing :class:`numpy.ndarray` is explained here:
  9. https://numpy.org/doc/stable/user/basics.subclassing.html
  10. In brief:
  11. * Constructor :meth:`__new__` for :class:`BaseSignal` doesn't exist.
  12. Only child objects :class:`AnalogSignal` and :class:`IrregularlySampledSignal`
  13. can be created.
  14. '''
  15. # needed for Python 3 compatibility
  16. from __future__ import absolute_import, division, print_function
  17. import copy
  18. import logging
  19. import numpy as np
  20. import quantities as pq
  21. from neo.core.baseneo import BaseNeo, MergeError, merge_annotations
  22. from neo.core.dataobject import DataObject, ArrayDict
  23. from neo.core.channelindex import ChannelIndex
  24. logger = logging.getLogger("Neo")
  25. class BaseSignal(DataObject):
  26. '''
  27. This is the base class from which all signal objects inherit:
  28. :class:`AnalogSignal` and :class:`IrregularlySampledSignal`.
  29. This class contains all common methods of both child classes.
  30. It uses the following child class attributes:
  31. :_necessary_attrs: a list of the attributes that the class must have.
  32. :_recommended_attrs: a list of the attributes that the class may
  33. optionally have.
  34. '''
  35. def _array_finalize_spec(self, obj):
  36. '''
  37. Called by :meth:`__array_finalize__`, used to customize behaviour of sub-classes.
  38. '''
  39. return obj
    def __array_finalize__(self, obj):
        '''
        Initialise the attributes of a newly created signal from ``obj``.

        This is called every time a new signal is created, including by
        slicing or viewing, so it is the appropriate place to set default
        values for attributes of such signals. User-specified values are only
        relevant for construction from the constructor; these are set in
        ``__new__`` of the child class and merely copied over here from
        ``obj``. Default values for the attributes specific to subclasses
        (:class:`AnalogSignal` and :class:`IrregularlySampledSignal`) are set
        in :meth:`_array_finalize_spec`.

        :param obj: the object the new signal is created from; every
            attribute missing on it falls back to the default used below.
        '''
        super(BaseSignal, self).__array_finalize__(obj)
        # Subclass hook; it runs before the generic attributes below are set.
        self._array_finalize_spec(obj)

        # The additional arguments
        # NOTE(review): this assigns obj's annotations dict itself (no copy),
        # so a slice/view shares one annotations dict with ``obj``.
        self.annotations = getattr(obj, 'annotations', {})
        # Add empty array annotations, because they cannot always be copied,
        # but do not overwrite existing ones from slicing etc.
        # This ensures the attribute exists
        if not hasattr(self, 'array_annotations'):
            self.array_annotations = ArrayDict(self._get_arr_ann_length())

        # Globally recommended attributes
        self.name = getattr(obj, 'name', None)
        self.file_origin = getattr(obj, 'file_origin', None)
        self.description = getattr(obj, 'description', None)

        # Parent objects
        self.segment = getattr(obj, 'segment', None)
        self.channel_index = getattr(obj, 'channel_index', None)
  68. @classmethod
  69. def _rescale(self, signal, units=None):
  70. '''
  71. Check that units are present, and rescale the signal if necessary.
  72. This is called whenever a new signal is
  73. created from the constructor. See :meth:`__new__' in
  74. :class:`AnalogSignal` and :class:`IrregularlySampledSignal`
  75. '''
  76. if units is None:
  77. if not hasattr(signal, "units"):
  78. raise ValueError("Units must be specified")
  79. elif isinstance(signal, pq.Quantity):
  80. # This test always returns True, i.e. rescaling is always executed if one of the units
  81. # is a pq.CompoundUnit. This is fine because rescaling is correct anyway.
  82. if pq.quantity.validate_dimensionality(units) != signal.dimensionality:
  83. signal = signal.rescale(units)
  84. return signal
  85. def rescale(self, units):
  86. obj = super(BaseSignal, self).rescale(units)
  87. obj.channel_index = self.channel_index
  88. return obj
  89. def __getslice__(self, i, j):
  90. '''
  91. Get a slice from :attr:`i` to :attr:`j`.attr[0]
  92. Doesn't get called in Python 3, :meth:`__getitem__` is called instead
  93. '''
  94. return self.__getitem__(slice(i, j))
  95. def __ne__(self, other):
  96. '''
  97. Non-equality test (!=)
  98. '''
  99. return not self.__eq__(other)
  100. def _apply_operator(self, other, op, *args):
  101. '''
  102. Handle copying metadata to the new signal
  103. after a mathematical operation.
  104. '''
  105. self._check_consistency(other)
  106. f = getattr(super(BaseSignal, self), op)
  107. new_signal = f(other, *args)
  108. new_signal._copy_data_complement(self)
  109. # _copy_data_complement can't always copy array annotations,
  110. # so this needs to be done locally
  111. new_signal.array_annotations = copy.deepcopy(self.array_annotations)
  112. return new_signal
  113. def _get_required_attributes(self, signal, units):
  114. '''
  115. Return a list of the required attributes for a signal as a dictionary
  116. '''
  117. required_attributes = {}
  118. for attr in self._necessary_attrs:
  119. if 'signal' == attr[0]:
  120. required_attributes[str(attr[0])] = signal
  121. else:
  122. required_attributes[str(attr[0])] = getattr(self, attr[0], None)
  123. required_attributes['units'] = units
  124. return required_attributes
  125. def duplicate_with_new_data(self, signal, units=None):
  126. '''
  127. Create a new signal with the same metadata but different data.
  128. Required attributes of the signal are used.
  129. Note: Array annotations can not be copied here because length of data can change
  130. '''
  131. if units is None:
  132. units = self.units
  133. # else:
  134. # units = pq.quantity.validate_dimensionality(units)
  135. # signal is the new signal
  136. required_attributes = self._get_required_attributes(signal, units)
  137. new = self.__class__(**required_attributes)
  138. new._copy_data_complement(self)
  139. new.annotations.update(self.annotations)
  140. # Note: Array annotations are not copied here, because it is not ensured
  141. # that the same number of signals is used and they would possibly make no sense
  142. # when combined with another signal
  143. return new
  144. def _copy_data_complement(self, other):
  145. '''
  146. Copy the metadata from another signal.
  147. Required and recommended attributes of the signal are used.
  148. Note: Array annotations can not be copied here because length of data can change
  149. '''
  150. all_attr = {self._recommended_attrs, self._necessary_attrs}
  151. for sub_at in all_attr:
  152. for attr in sub_at:
  153. if attr[0] != 'signal':
  154. setattr(self, attr[0], getattr(other, attr[0], None))
  155. setattr(self, 'annotations', getattr(other, 'annotations', None))
  156. # Note: Array annotations cannot be copied because length of data can be changed # here
  157. # which would cause inconsistencies
  158. def __rsub__(self, other, *args):
  159. '''
  160. Backwards subtraction (other-self)
  161. '''
  162. return self.__mul__(-1, *args) + other
  163. def __add__(self, other, *args):
  164. '''
  165. Addition (+)
  166. '''
  167. return self._apply_operator(other, "__add__", *args)
  168. def __sub__(self, other, *args):
  169. '''
  170. Subtraction (-)
  171. '''
  172. return self._apply_operator(other, "__sub__", *args)
  173. def __mul__(self, other, *args):
  174. '''
  175. Multiplication (*)
  176. '''
  177. return self._apply_operator(other, "__mul__", *args)
  178. def __truediv__(self, other, *args):
  179. '''
  180. Float division (/)
  181. '''
  182. return self._apply_operator(other, "__truediv__", *args)
  183. def __div__(self, other, *args):
  184. '''
  185. Integer division (//)
  186. '''
  187. return self._apply_operator(other, "__div__", *args)
  188. __radd__ = __add__
  189. __rmul__ = __sub__
  190. def merge(self, other):
  191. '''
  192. Merge another signal into this one.
  193. The signal objects are concatenated horizontally
  194. (column-wise, :func:`np.hstack`).
  195. If the attributes of the two signal are not
  196. compatible, an Exception is raised.
  197. Required attributes of the signal are used.
  198. '''
  199. for attr in self._necessary_attrs:
  200. if 'signal' != attr[0]:
  201. if getattr(self, attr[0], None) != getattr(other, attr[0], None):
  202. raise MergeError("Cannot merge these two signals as the %s differ." % attr[0])
  203. if self.segment != other.segment:
  204. raise MergeError(
  205. "Cannot merge these two signals as they belong to different segments.")
  206. if hasattr(self, "lazy_shape"):
  207. if hasattr(other, "lazy_shape"):
  208. if self.lazy_shape[0] != other.lazy_shape[0]:
  209. raise MergeError("Cannot merge signals of different length.")
  210. merged_lazy_shape = (self.lazy_shape[0], self.lazy_shape[1] + other.lazy_shape[1])
  211. else:
  212. raise MergeError("Cannot merge a lazy object with a real object.")
  213. if other.units != self.units:
  214. other = other.rescale(self.units)
  215. stack = np.hstack(map(np.array, (self, other)))
  216. kwargs = {}
  217. for name in ("name", "description", "file_origin"):
  218. attr_self = getattr(self, name)
  219. attr_other = getattr(other, name)
  220. if attr_self == attr_other:
  221. kwargs[name] = attr_self
  222. else:
  223. kwargs[name] = "merge(%s, %s)" % (attr_self, attr_other)
  224. merged_annotations = merge_annotations(self.annotations, other.annotations)
  225. kwargs.update(merged_annotations)
  226. kwargs['array_annotations'] = self._merge_array_annotations(other)
  227. signal = self.__class__(stack, units=self.units, dtype=self.dtype, copy=False,
  228. t_start=self.t_start, sampling_rate=self.sampling_rate, **kwargs)
  229. signal.segment = self.segment
  230. if hasattr(self, "lazy_shape"):
  231. signal.lazy_shape = merged_lazy_shape
  232. # merge channel_index (move to ChannelIndex.merge()?)
  233. if self.channel_index and other.channel_index:
  234. signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]),
  235. channel_ids=np.hstack(
  236. [self.channel_index.channel_ids, other.channel_index.channel_ids]),
  237. channel_names=np.hstack(
  238. [self.channel_index.channel_names, other.channel_index.channel_names]))
  239. else:
  240. signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]))
  241. return signal