epoch.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
'''
This module defines :class:`Epoch`, an array of epochs.

:class:`Epoch` derives from :class:`DataObject`, from
:mod:`neo.core.dataobject`.
'''
  6. from copy import deepcopy, copy
  7. import numpy as np
  8. import quantities as pq
  9. from neo.core.baseneo import BaseNeo, merge_annotations
  10. from neo.core.dataobject import DataObject, ArrayDict
  11. def _new_epoch(cls, times=None, durations=None, labels=None, units=None, name=None,
  12. description=None, file_origin=None, array_annotations=None, annotations=None,
  13. segment=None):
  14. '''
  15. A function to map epoch.__new__ to function that
  16. does not do the unit checking. This is needed for pickle to work.
  17. '''
  18. e = Epoch(times=times, durations=durations, labels=labels, units=units, name=name,
  19. file_origin=file_origin, description=description,
  20. array_annotations=array_annotations, **annotations)
  21. e.segment = segment
  22. return e
  23. class Epoch(DataObject):
  24. '''
  25. Array of epochs.
  26. *Usage*::
  27. >>> from neo.core import Epoch
  28. >>> from quantities import s, ms
  29. >>> import numpy as np
  30. >>>
  31. >>> epc = Epoch(times=np.arange(0, 30, 10)*s,
  32. ... durations=[10, 5, 7]*ms,
  33. ... labels=np.array(['btn0', 'btn1', 'btn2'], dtype='U'))
  34. >>>
  35. >>> epc.times
  36. array([ 0., 10., 20.]) * s
  37. >>> epc.durations
  38. array([ 10., 5., 7.]) * ms
  39. >>> epc.labels
  40. array(['btn0', 'btn1', 'btn2'],
  41. dtype='<U4')
  42. *Required attributes/properties*:
  43. :times: (quantity array 1D, numpy array 1D or list) The start times
  44. of each time period.
  45. :durations: (quantity array 1D, numpy array 1D, list, or quantity scalar)
  46. The length(s) of each time period.
  47. If a scalar, the same value is used for all time periods.
  48. :labels: (numpy.array 1D dtype='U' or list) Names or labels for the time periods.
  49. *Recommended attributes/properties*:
  50. :name: (str) A label for the dataset,
  51. :description: (str) Text description,
  52. :file_origin: (str) Filesystem path or URL of the original data file.
  53. *Optional attributes/properties*:
  54. :array_annotations: (dict) Dict mapping strings to numpy arrays containing annotations \
  55. for all data points
  56. Note: Any other additional arguments are assumed to be user-specific
  57. metadata and stored in :attr:`annotations`,
  58. '''
  59. _single_parent_objects = ('Segment',)
  60. _single_parent_attrs = ('segment',)
  61. _quantity_attr = 'times'
  62. _necessary_attrs = (('times', pq.Quantity, 1), ('durations', pq.Quantity, 1),
  63. ('labels', np.ndarray, 1, np.dtype('U')))
  64. def __new__(cls, times=None, durations=None, labels=None, units=None, name=None,
  65. description=None, file_origin=None, array_annotations=None, **annotations):
  66. if times is None:
  67. times = np.array([]) * pq.s
  68. elif isinstance(times, (list, tuple)):
  69. times = np.array(times)
  70. if len(times.shape) > 1:
  71. raise ValueError("Times array has more than 1 dimension")
  72. if isinstance(durations, (list, tuple)):
  73. durations = np.array(durations)
  74. if durations is None:
  75. durations = np.array([]) * pq.s
  76. elif durations.size != times.size:
  77. if durations.size == 1:
  78. durations = durations * np.ones_like(times.magnitude)
  79. else:
  80. raise ValueError("Durations array has different length to times")
  81. if labels is None:
  82. labels = np.array([], dtype='U')
  83. else:
  84. labels = np.array(labels)
  85. if labels.size != times.size and labels.size:
  86. raise ValueError("Labels array has different length to times")
  87. if units is None:
  88. # No keyword units, so get from `times`
  89. try:
  90. units = times.units
  91. dim = units.dimensionality
  92. except AttributeError:
  93. raise ValueError('you must specify units')
  94. else:
  95. if hasattr(units, 'dimensionality'):
  96. dim = units.dimensionality
  97. else:
  98. dim = pq.quantity.validate_dimensionality(units)
  99. if not hasattr(durations, "dimensionality"):
  100. durations = pq.Quantity(durations, dim)
  101. # check to make sure the units are time
  102. # this approach is much faster than comparing the
  103. # reference dimensionality
  104. if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
  105. pq.UnitTime)):
  106. ValueError("Unit %s has dimensions %s, not [time]" % (units, dim.simplified))
  107. obj = pq.Quantity.__new__(cls, times, units=dim)
  108. obj._labels = labels
  109. obj._durations = durations
  110. obj.segment = None
  111. return obj
  112. def __init__(self, times=None, durations=None, labels=None, units=None, name=None,
  113. description=None, file_origin=None, array_annotations=None, **annotations):
  114. '''
  115. Initialize a new :class:`Epoch` instance.
  116. '''
  117. DataObject.__init__(self, name=name, file_origin=file_origin, description=description,
  118. array_annotations=array_annotations, **annotations)
  119. def __reduce__(self):
  120. '''
  121. Map the __new__ function onto _new_epoch, so that pickle
  122. works
  123. '''
  124. return _new_epoch, (self.__class__, self.times, self.durations, self.labels, self.units,
  125. self.name, self.file_origin, self.description, self.array_annotations,
  126. self.annotations, self.segment)
  127. def __array_finalize__(self, obj):
  128. super().__array_finalize__(obj)
  129. self._durations = getattr(obj, 'durations', None)
  130. self._labels = getattr(obj, 'labels', None)
  131. self.annotations = getattr(obj, 'annotations', None)
  132. self.name = getattr(obj, 'name', None)
  133. self.file_origin = getattr(obj, 'file_origin', None)
  134. self.description = getattr(obj, 'description', None)
  135. self.segment = getattr(obj, 'segment', None)
  136. # Add empty array annotations, because they cannot always be copied,
  137. # but do not overwrite existing ones from slicing etc.
  138. # This ensures the attribute exists
  139. if not hasattr(self, 'array_annotations'):
  140. self.array_annotations = ArrayDict(self._get_arr_ann_length())
  141. def __repr__(self):
  142. '''
  143. Returns a string representing the :class:`Epoch`.
  144. '''
  145. objs = ['%s@%s for %s' % (label, str(time), str(dur)) for label, time, dur in
  146. zip(self.labels, self.times, self.durations)]
  147. return '<Epoch: %s>' % ', '.join(objs)
  148. def _repr_pretty_(self, pp, cycle):
  149. super()._repr_pretty_(pp, cycle)
  150. def rescale(self, units):
  151. '''
  152. Return a copy of the :class:`Epoch` converted to the specified
  153. units
  154. :return: Copy of self with specified units
  155. '''
  156. # Use simpler functionality, if nothing will be changed
  157. dim = pq.quantity.validate_dimensionality(units)
  158. if self.dimensionality == dim:
  159. return self.copy()
  160. # Rescale the object into a new object
  161. obj = self.duplicate_with_new_data(
  162. times=self.view(pq.Quantity).rescale(dim),
  163. durations=self.durations.rescale(dim),
  164. labels=self.labels,
  165. units=units)
  166. # Expected behavior is deepcopy, so deepcopying array_annotations
  167. obj.array_annotations = deepcopy(self.array_annotations)
  168. obj.segment = self.segment
  169. return obj
  170. def __getitem__(self, i):
  171. '''
  172. Get the item or slice :attr:`i`.
  173. '''
  174. obj = super().__getitem__(i)
  175. obj._durations = self.durations[i]
  176. if self._labels is not None and self._labels.size > 0:
  177. obj._labels = self.labels[i]
  178. else:
  179. obj._labels = self.labels
  180. try:
  181. # Array annotations need to be sliced accordingly
  182. obj.array_annotate(**deepcopy(self.array_annotations_at_index(i)))
  183. obj._copy_data_complement(self)
  184. except AttributeError: # If Quantity was returned, not Epoch
  185. obj.times = obj
  186. obj.durations = obj._durations
  187. obj.labels = obj._labels
  188. return obj
  189. def __getslice__(self, i, j):
  190. '''
  191. Get a slice from :attr:`i` to :attr:`j`.attr[0]
  192. Doesn't get called in Python 3, :meth:`__getitem__` is called instead
  193. '''
  194. return self.__getitem__(slice(i, j))
  195. @property
  196. def times(self):
  197. return pq.Quantity(self)
  198. def merge(self, other):
  199. '''
  200. Merge the another :class:`Epoch` into this one.
  201. The :class:`Epoch` objects are concatenated horizontally
  202. (column-wise), :func:`np.hstack`).
  203. If the attributes of the two :class:`Epoch` are not
  204. compatible, and Exception is raised.
  205. '''
  206. othertimes = other.times.rescale(self.times.units)
  207. otherdurations = other.durations.rescale(self.durations.units)
  208. times = np.hstack([self.times, othertimes]) * self.times.units
  209. durations = np.hstack([self.durations,
  210. otherdurations]) * self.durations.units
  211. labels = np.hstack([self.labels, other.labels])
  212. kwargs = {}
  213. for name in ("name", "description", "file_origin"):
  214. attr_self = getattr(self, name)
  215. attr_other = getattr(other, name)
  216. if attr_self == attr_other:
  217. kwargs[name] = attr_self
  218. else:
  219. kwargs[name] = "merge({}, {})".format(attr_self, attr_other)
  220. merged_annotations = merge_annotations(self.annotations, other.annotations)
  221. kwargs.update(merged_annotations)
  222. kwargs['array_annotations'] = self._merge_array_annotations(other)
  223. return Epoch(times=times, durations=durations, labels=labels, **kwargs)
  224. def _copy_data_complement(self, other):
  225. '''
  226. Copy the metadata from another :class:`Epoch`.
  227. Note: Array annotations can not be copied here because length of data can change
  228. '''
  229. # Note: Array annotations cannot be copied because length of data could be changed
  230. # here which would cause inconsistencies. This is instead done locally.
  231. for attr in ("name", "file_origin", "description"):
  232. setattr(self, attr, deepcopy(getattr(other, attr, None)))
  233. self._copy_annotations(other)
  234. def _copy_annotations(self, other):
  235. self.annotations = deepcopy(other.annotations)
  236. def duplicate_with_new_data(self, times, durations, labels, units=None):
  237. '''
  238. Create a new :class:`Epoch` with the same metadata
  239. but different data (times, durations)
  240. Note: Array annotations can not be copied here because length of data can change
  241. '''
  242. if units is None:
  243. units = self.units
  244. else:
  245. units = pq.quantity.validate_dimensionality(units)
  246. new = self.__class__(times=times, durations=durations, labels=labels, units=units)
  247. new._copy_data_complement(self)
  248. new._labels = labels
  249. new._durations = durations
  250. # Note: Array annotations can not be copied here because length of data can change
  251. return new
  252. def time_slice(self, t_start, t_stop):
  253. '''
  254. Creates a new :class:`Epoch` corresponding to the time slice of
  255. the original :class:`Epoch` between (and including) times
  256. :attr:`t_start` and :attr:`t_stop`. Either parameter can also be None
  257. to use infinite endpoints for the time interval.
  258. '''
  259. _t_start = t_start
  260. _t_stop = t_stop
  261. if t_start is None:
  262. _t_start = -np.inf
  263. if t_stop is None:
  264. _t_stop = np.inf
  265. indices = (self >= _t_start) & (self <= _t_stop)
  266. # Time slicing should create a deep copy of the object
  267. new_epc = deepcopy(self[indices])
  268. return new_epc
  269. def time_shift(self, t_shift):
  270. """
  271. Shifts an :class:`Epoch` by an amount of time.
  272. Parameters:
  273. -----------
  274. t_shift: Quantity (time)
  275. Amount of time by which to shift the :class:`Epoch`.
  276. Returns:
  277. --------
  278. epoch: :class:`Epoch`
  279. New instance of an :class:`Epoch` object starting at t_shift later than the
  280. original :class:`Epoch` (the original :class:`Epoch` is not modified).
  281. """
  282. new_epc = self.duplicate_with_new_data(times=self.times + t_shift,
  283. durations=self.durations,
  284. labels=self.labels)
  285. # Here we can safely copy the array annotations since we know that
  286. # the length of the Epoch does not change.
  287. new_epc.array_annotate(**self.array_annotations)
  288. return new_epc
  289. def set_labels(self, labels):
  290. if self.labels is not None and self.labels.size > 0 and len(labels) != self.size:
  291. raise ValueError("Labels array has different length to times ({} != {})"
  292. .format(len(labels), self.size))
  293. self._labels = np.array(labels)
  294. def get_labels(self):
  295. return self._labels
  296. labels = property(get_labels, set_labels)
  297. def set_durations(self, durations):
  298. if self.durations is not None and self.durations.size > 0 and len(durations) != self.size:
  299. raise ValueError("Durations array has different length to times ({} != {})"
  300. .format(len(durations), self.size))
  301. self._durations = durations
  302. def get_durations(self):
  303. return self._durations
  304. durations = property(get_durations, set_durations)