basesignal.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
'''
This module implements :class:`BaseSignal`, an array of signals and the
common parent class from which all signal objects inherit:
    :class:`AnalogSignal` and :class:`IrregularlySampledSignal`.

:class:`BaseSignal` inherits (via :class:`neo.core.dataobject.DataObject`)
from :class:`quantities.Quantity`, which in turn inherits from
:class:`numpy.ndarray`.
Subclassing :class:`numpy.ndarray` is explained here:
https://numpy.org/doc/stable/user/basics.subclassing.html

In brief:
    * :class:`BaseSignal` does not define its own constructor
      (:meth:`__new__`). Only the child classes :class:`AnalogSignal` and
      :class:`IrregularlySampledSignal` can be created directly.
    * Metadata of signals created by slicing or viewing is set in
      :meth:`BaseSignal.__array_finalize__`.
'''
  14. import copy
  15. import logging
  16. from copy import deepcopy
  17. import numpy as np
  18. import quantities as pq
  19. from neo.core.baseneo import MergeError, merge_annotations
  20. from neo.core.dataobject import DataObject, ArrayDict
  21. from neo.core.channelindex import ChannelIndex
  22. logger = logging.getLogger("Neo")
  23. class BaseSignal(DataObject):
  24. '''
  25. This is the base class from which all signal objects inherit:
  26. :class:`AnalogSignal` and :class:`IrregularlySampledSignal`.
  27. This class contains all common methods of both child classes.
  28. It uses the following child class attributes:
  29. :_necessary_attrs: a list of the attributes that the class must have.
  30. :_recommended_attrs: a list of the attributes that the class may
  31. optionally have.
  32. '''
  33. def _array_finalize_spec(self, obj):
  34. '''
  35. Called by :meth:`__array_finalize__`, used to customize behaviour of sub-classes.
  36. '''
  37. return obj
  38. def __array_finalize__(self, obj):
  39. '''
  40. This is called every time a new signal is created.
  41. It is the appropriate place to set default values for attributes
  42. for a signal constructed by slicing or viewing.
  43. User-specified values are only relevant for construction from
  44. constructor, and these are set in __new__ in the child object.
  45. Then they are just copied over here. Default values for the
  46. specific attributes for subclasses (:class:`AnalogSignal`
  47. and :class:`IrregularlySampledSignal`) are set in
  48. :meth:`_array_finalize_spec`
  49. '''
  50. super().__array_finalize__(obj)
  51. self._array_finalize_spec(obj)
  52. # The additional arguments
  53. self.annotations = getattr(obj, 'annotations', {})
  54. # Add empty array annotations, because they cannot always be copied,
  55. # but do not overwrite existing ones from slicing etc.
  56. # This ensures the attribute exists
  57. if not hasattr(self, 'array_annotations'):
  58. self.array_annotations = ArrayDict(self._get_arr_ann_length())
  59. # Globally recommended attributes
  60. self.name = getattr(obj, 'name', None)
  61. self.file_origin = getattr(obj, 'file_origin', None)
  62. self.description = getattr(obj, 'description', None)
  63. # Parent objects
  64. self.segment = getattr(obj, 'segment', None)
  65. self.channel_index = getattr(obj, 'channel_index', None)
  66. @classmethod
  67. def _rescale(self, signal, units=None):
  68. '''
  69. Check that units are present, and rescale the signal if necessary.
  70. This is called whenever a new signal is
  71. created from the constructor. See :meth:`__new__' in
  72. :class:`AnalogSignal` and :class:`IrregularlySampledSignal`
  73. '''
  74. if units is None:
  75. if not hasattr(signal, "units"):
  76. raise ValueError("Units must be specified")
  77. elif isinstance(signal, pq.Quantity):
  78. # This test always returns True, i.e. rescaling is always executed if one of the units
  79. # is a pq.CompoundUnit. This is fine because rescaling is correct anyway.
  80. if pq.quantity.validate_dimensionality(units) != signal.dimensionality:
  81. signal = signal.rescale(units)
  82. return signal
  83. def rescale(self, units):
  84. obj = super().rescale(units)
  85. obj.channel_index = self.channel_index
  86. return obj
  87. def __getslice__(self, i, j):
  88. '''
  89. Get a slice from :attr:`i` to :attr:`j`.attr[0]
  90. Doesn't get called in Python 3, :meth:`__getitem__` is called instead
  91. '''
  92. return self.__getitem__(slice(i, j))
  93. def __ne__(self, other):
  94. '''
  95. Non-equality test (!=)
  96. '''
  97. return not self.__eq__(other)
  98. def _apply_operator(self, other, op, *args):
  99. '''
  100. Handle copying metadata to the new signal
  101. after a mathematical operation.
  102. '''
  103. self._check_consistency(other)
  104. f = getattr(super(), op)
  105. new_signal = f(other, *args)
  106. new_signal._copy_data_complement(self)
  107. # _copy_data_complement can't always copy array annotations,
  108. # so this needs to be done locally
  109. new_signal.array_annotations = copy.deepcopy(self.array_annotations)
  110. return new_signal
  111. def _get_required_attributes(self, signal, units):
  112. '''
  113. Return a list of the required attributes for a signal as a dictionary
  114. '''
  115. required_attributes = {}
  116. for attr in self._necessary_attrs:
  117. if attr[0] == "signal":
  118. required_attributes["signal"] = signal
  119. elif attr[0] == "image_data":
  120. required_attributes["image_data"] = signal
  121. elif attr[0] == "t_start":
  122. required_attributes["t_start"] = getattr(self, "t_start", 0.0 * pq.ms)
  123. else:
  124. required_attributes[str(attr[0])] = getattr(self, attr[0], None)
  125. required_attributes['units'] = units
  126. return required_attributes
  127. def duplicate_with_new_data(self, signal, units=None):
  128. '''
  129. Create a new signal with the same metadata but different data.
  130. Required attributes of the signal are used.
  131. Note: Array annotations can not be copied here because length of data can change
  132. '''
  133. if units is None:
  134. units = self.units
  135. # else:
  136. # units = pq.quantity.validate_dimensionality(units)
  137. # signal is the new signal
  138. required_attributes = self._get_required_attributes(signal, units)
  139. new = self.__class__(**required_attributes)
  140. new._copy_data_complement(self)
  141. new.annotations.update(self.annotations)
  142. # Note: Array annotations are not copied here, because it is not ensured
  143. # that the same number of signals is used and they would possibly make no sense
  144. # when combined with another signal
  145. return new
  146. def _copy_data_complement(self, other):
  147. '''
  148. Copy the metadata from another signal.
  149. Required and recommended attributes of the signal are used.
  150. Note: Array annotations can not be copied here because length of data can change
  151. '''
  152. all_attr = {self._recommended_attrs, self._necessary_attrs}
  153. for sub_at in all_attr:
  154. for attr in sub_at:
  155. if attr[0] == "t_start":
  156. setattr(self, attr[0], deepcopy(getattr(other, attr[0], 0.0 * pq.ms)))
  157. elif attr[0] != 'signal':
  158. setattr(self, attr[0], deepcopy(getattr(other, attr[0], None)))
  159. setattr(self, 'annotations', deepcopy(getattr(other, 'annotations', None)))
  160. # Note: Array annotations cannot be copied because length of data can be changed # here
  161. # which would cause inconsistencies
  162. def __rsub__(self, other, *args):
  163. '''
  164. Backwards subtraction (other-self)
  165. '''
  166. return self.__mul__(-1, *args) + other
  167. def __add__(self, other, *args):
  168. '''
  169. Addition (+)
  170. '''
  171. return self._apply_operator(other, "__add__", *args)
  172. def __sub__(self, other, *args):
  173. '''
  174. Subtraction (-)
  175. '''
  176. return self._apply_operator(other, "__sub__", *args)
  177. def __mul__(self, other, *args):
  178. '''
  179. Multiplication (*)
  180. '''
  181. return self._apply_operator(other, "__mul__", *args)
  182. def __truediv__(self, other, *args):
  183. '''
  184. Float division (/)
  185. '''
  186. return self._apply_operator(other, "__truediv__", *args)
  187. def __div__(self, other, *args):
  188. '''
  189. Integer division (//)
  190. '''
  191. return self._apply_operator(other, "__div__", *args)
  192. __radd__ = __add__
  193. __rmul__ = __sub__
  194. def merge(self, other):
  195. '''
  196. Merge another signal into this one.
  197. The signal objects are concatenated horizontally
  198. (column-wise, :func:`np.hstack`).
  199. If the attributes of the two signal are not
  200. compatible, an Exception is raised.
  201. Required attributes of the signal are used.
  202. '''
  203. for attr in self._necessary_attrs:
  204. if 'signal' != attr[0]:
  205. if getattr(self, attr[0], None) != getattr(other, attr[0], None):
  206. raise MergeError("Cannot merge these two signals as the %s differ." % attr[0])
  207. if self.segment != other.segment:
  208. raise MergeError(
  209. "Cannot merge these two signals as they belong to different segments.")
  210. if hasattr(self, "lazy_shape"):
  211. if hasattr(other, "lazy_shape"):
  212. if self.lazy_shape[0] != other.lazy_shape[0]:
  213. raise MergeError("Cannot merge signals of different length.")
  214. merged_lazy_shape = (self.lazy_shape[0], self.lazy_shape[1] + other.lazy_shape[1])
  215. else:
  216. raise MergeError("Cannot merge a lazy object with a real object.")
  217. if other.units != self.units:
  218. other = other.rescale(self.units)
  219. stack = np.hstack((self.magnitude, other.magnitude))
  220. kwargs = {}
  221. for name in ("name", "description", "file_origin"):
  222. attr_self = getattr(self, name)
  223. attr_other = getattr(other, name)
  224. if attr_self == attr_other:
  225. kwargs[name] = attr_self
  226. else:
  227. kwargs[name] = "merge({}, {})".format(attr_self, attr_other)
  228. merged_annotations = merge_annotations(self.annotations, other.annotations)
  229. kwargs.update(merged_annotations)
  230. kwargs['array_annotations'] = self._merge_array_annotations(other)
  231. signal = self.__class__(stack, units=self.units, dtype=self.dtype, copy=False,
  232. t_start=self.t_start, sampling_rate=self.sampling_rate, **kwargs)
  233. signal.segment = self.segment
  234. if hasattr(self, "lazy_shape"):
  235. signal.lazy_shape = merged_lazy_shape
  236. # merge channel_index (move to ChannelIndex.merge()?)
  237. if self.channel_index and other.channel_index:
  238. signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]),
  239. channel_ids=np.hstack(
  240. [self.channel_index.channel_ids,
  241. other.channel_index.channel_ids]),
  242. channel_names=np.hstack(
  243. [self.channel_index.channel_names,
  244. other.channel_index.channel_names]))
  245. else:
  246. signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]))
  247. return signal
  248. def time_slice(self, t_start, t_stop):
  249. '''
  250. Creates a new AnalogSignal corresponding to the time slice of the
  251. original Signal between times t_start, t_stop.
  252. '''
  253. NotImplementedError('Needs to be implemented for subclasses.')
  254. def concatenate(self, *signals):
  255. '''
  256. Concatenate multiple signals across time.
  257. The signal objects are concatenated vertically
  258. (row-wise, :func:`np.vstack`). Concatenation can be
  259. used to combine signals across segments.
  260. Note: Only (array) annotations common to
  261. both signals are attached to the concatenated signal.
  262. If the attributes of the signals are not
  263. compatible, an Exception is raised.
  264. Parameters
  265. ----------
  266. signals : multiple neo.BaseSignal objects
  267. The objects that is concatenated with this one.
  268. Returns
  269. -------
  270. signal : neo.BaseSignal
  271. Signal containing all non-overlapping samples of
  272. the source signals.
  273. Raises
  274. ------
  275. MergeError
  276. If `other` object has incompatible attributes.
  277. '''
  278. NotImplementedError('Patching need to be implemented in sublcasses')