epoch.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
# -*- coding: utf-8 -*-
'''
This module defines :class:`Epoch`, an array of epochs.

:class:`Epoch` derives from :class:`DataObject`, from
:mod:`neo.core.dataobject`.
'''
  7. # needed for python 3 compatibility
  8. from __future__ import absolute_import, division, print_function
  9. import sys
  10. from copy import deepcopy
  11. import numpy as np
  12. import quantities as pq
  13. from neo.core.baseneo import BaseNeo, merge_annotations
  14. from neo.core.dataobject import DataObject, ArrayDict
  15. PY_VER = sys.version_info[0]
  16. def _new_epoch(cls, times=None, durations=None, labels=None, units=None, name=None,
  17. description=None, file_origin=None, array_annotations=None, annotations=None,
  18. segment=None):
  19. '''
  20. A function to map epoch.__new__ to function that
  21. does not do the unit checking. This is needed for pickle to work.
  22. '''
  23. e = Epoch(times=times, durations=durations, labels=labels, units=units, name=name,
  24. file_origin=file_origin, description=description,
  25. array_annotations=array_annotations, **annotations)
  26. e.segment = segment
  27. return e
  28. class Epoch(DataObject):
  29. '''
  30. Array of epochs.
  31. *Usage*::
  32. >>> from neo.core import Epoch
  33. >>> from quantities import s, ms
  34. >>> import numpy as np
  35. >>>
  36. >>> epc = Epoch(times=np.arange(0, 30, 10)*s,
  37. ... durations=[10, 5, 7]*ms,
  38. ... labels=np.array(['btn0', 'btn1', 'btn2'], dtype='S'))
  39. >>>
  40. >>> epc.times
  41. array([ 0., 10., 20.]) * s
  42. >>> epc.durations
  43. array([ 10., 5., 7.]) * ms
  44. >>> epc.labels
  45. array(['btn0', 'btn1', 'btn2'],
  46. dtype='|S4')
  47. *Required attributes/properties*:
  48. :times: (quantity array 1D) The start times of each time period.
  49. :durations: (quantity array 1D or quantity scalar) The length(s) of each time period.
  50. If a scalar, the same value is used for all time periods.
  51. :labels: (numpy.array 1D dtype='S') Names or labels for the time periods.
  52. *Recommended attributes/properties*:
  53. :name: (str) A label for the dataset,
  54. :description: (str) Text description,
  55. :file_origin: (str) Filesystem path or URL of the original data file.
  56. *Optional attributes/properties*:
  57. :array_annotations: (dict) Dict mapping strings to numpy arrays containing annotations \
  58. for all data points
  59. Note: Any other additional arguments are assumed to be user-specific
  60. metadata and stored in :attr:`annotations`,
  61. '''
  62. _single_parent_objects = ('Segment',)
  63. _quantity_attr = 'times'
  64. _necessary_attrs = (('times', pq.Quantity, 1), ('durations', pq.Quantity, 1),
  65. ('labels', np.ndarray, 1, np.dtype('S')))
  66. def __new__(cls, times=None, durations=None, labels=None, units=None, name=None,
  67. description=None, file_origin=None, array_annotations=None, **annotations):
  68. if times is None:
  69. times = np.array([]) * pq.s
  70. if durations is None:
  71. durations = np.array([]) * pq.s
  72. elif durations.size != times.size:
  73. if durations.size == 1:
  74. durations = durations * np.ones_like(times.magnitude)
  75. else:
  76. raise ValueError("Durations array has different length to times")
  77. if labels is None:
  78. labels = np.array([], dtype='S')
  79. elif len(labels) != times.size:
  80. raise ValueError("Labels array has different length to times")
  81. if units is None:
  82. # No keyword units, so get from `times`
  83. try:
  84. units = times.units
  85. dim = units.dimensionality
  86. except AttributeError:
  87. raise ValueError('you must specify units')
  88. else:
  89. if hasattr(units, 'dimensionality'):
  90. dim = units.dimensionality
  91. else:
  92. dim = pq.quantity.validate_dimensionality(units)
  93. # check to make sure the units are time
  94. # this approach is much faster than comparing the
  95. # reference dimensionality
  96. if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
  97. pq.UnitTime)):
  98. ValueError("Unit %s has dimensions %s, not [time]" % (units, dim.simplified))
  99. obj = pq.Quantity.__new__(cls, times, units=dim)
  100. obj.labels = labels
  101. obj.durations = durations
  102. obj.segment = None
  103. return obj
  104. def __init__(self, times=None, durations=None, labels=None, units=None, name=None,
  105. description=None, file_origin=None, array_annotations=None, **annotations):
  106. '''
  107. Initialize a new :class:`Epoch` instance.
  108. '''
  109. DataObject.__init__(self, name=name, file_origin=file_origin, description=description,
  110. array_annotations=array_annotations, **annotations)
  111. def __reduce__(self):
  112. '''
  113. Map the __new__ function onto _new_epoch, so that pickle
  114. works
  115. '''
  116. return _new_epoch, (self.__class__, self.times, self.durations, self.labels, self.units,
  117. self.name, self.file_origin, self.description, self.array_annotations,
  118. self.annotations, self.segment)
  119. def __array_finalize__(self, obj):
  120. super(Epoch, self).__array_finalize__(obj)
  121. self.annotations = getattr(obj, 'annotations', None)
  122. self.name = getattr(obj, 'name', None)
  123. self.file_origin = getattr(obj, 'file_origin', None)
  124. self.description = getattr(obj, 'description', None)
  125. self.segment = getattr(obj, 'segment', None)
  126. # Add empty array annotations, because they cannot always be copied,
  127. # but do not overwrite existing ones from slicing etc.
  128. # This ensures the attribute exists
  129. if not hasattr(self, 'array_annotations'):
  130. self.array_annotations = ArrayDict(self._get_arr_ann_length())
  131. def __repr__(self):
  132. '''
  133. Returns a string representing the :class:`Epoch`.
  134. '''
  135. # need to convert labels to unicode for python 3 or repr is messed up
  136. if PY_VER == 3:
  137. labels = self.labels.astype('U')
  138. else:
  139. labels = self.labels
  140. objs = ['%s@%s for %s' % (label, time, dur) for label, time, dur in
  141. zip(labels, self.times, self.durations)]
  142. return '<Epoch: %s>' % ', '.join(objs)
  143. def _repr_pretty_(self, pp, cycle):
  144. super(Epoch, self)._repr_pretty_(pp, cycle)
  145. def rescale(self, units):
  146. '''
  147. Return a copy of the :class:`Epoch` converted to the specified
  148. units
  149. '''
  150. obj = super(Epoch, self).rescale(units)
  151. obj.segment = self.segment
  152. return obj
  153. def __getitem__(self, i):
  154. '''
  155. Get the item or slice :attr:`i`.
  156. '''
  157. obj = Epoch(times=super(Epoch, self).__getitem__(i))
  158. obj._copy_data_complement(self)
  159. try:
  160. # Array annotations need to be sliced accordingly
  161. obj.array_annotate(**deepcopy(self.array_annotations_at_index(i)))
  162. except AttributeError: # If Quantity was returned, not Epoch
  163. pass
  164. return obj
  165. def __getslice__(self, i, j):
  166. '''
  167. Get a slice from :attr:`i` to :attr:`j`.attr[0]
  168. Doesn't get called in Python 3, :meth:`__getitem__` is called instead
  169. '''
  170. return self.__getitem__(slice(i, j))
  171. @property
  172. def times(self):
  173. return pq.Quantity(self)
  174. def merge(self, other):
  175. '''
  176. Merge the another :class:`Epoch` into this one.
  177. The :class:`Epoch` objects are concatenated horizontally
  178. (column-wise), :func:`np.hstack`).
  179. If the attributes of the two :class:`Epoch` are not
  180. compatible, and Exception is raised.
  181. '''
  182. othertimes = other.times.rescale(self.times.units)
  183. times = np.hstack([self.times, othertimes]) * self.times.units
  184. kwargs = {}
  185. for name in ("name", "description", "file_origin"):
  186. attr_self = getattr(self, name)
  187. attr_other = getattr(other, name)
  188. if attr_self == attr_other:
  189. kwargs[name] = attr_self
  190. else:
  191. kwargs[name] = "merge(%s, %s)" % (attr_self, attr_other)
  192. merged_annotations = merge_annotations(self.annotations, other.annotations)
  193. kwargs.update(merged_annotations)
  194. kwargs['array_annotations'] = self._merge_array_annotations(other)
  195. labels = kwargs['array_annotations']['labels']
  196. durations = kwargs['array_annotations']['durations']
  197. return Epoch(times=times, durations=durations, labels=labels, **kwargs)
  198. def _copy_data_complement(self, other):
  199. '''
  200. Copy the metadata from another :class:`Epoch`.
  201. Note: Array annotations can not be copied here because length of data can change
  202. '''
  203. # Note: Array annotations cannot be copied because length of data could be changed
  204. # here which would cause inconsistencies. This is instead done locally.
  205. for attr in ("name", "file_origin", "description", "annotations"):
  206. setattr(self, attr, getattr(other, attr, None))
  207. def __deepcopy__(self, memo):
  208. cls = self.__class__
  209. new_ep = cls(times=self.times, durations=self.durations, labels=self.labels,
  210. units=self.units, name=self.name, description=self.description,
  211. file_origin=self.file_origin)
  212. new_ep.__dict__.update(self.__dict__)
  213. memo[id(self)] = new_ep
  214. for k, v in self.__dict__.items():
  215. try:
  216. setattr(new_ep, k, deepcopy(v, memo))
  217. except TypeError:
  218. setattr(new_ep, k, v)
  219. return new_ep
  220. def duplicate_with_new_data(self, signal, units=None):
  221. '''
  222. Create a new :class:`Epoch` with the same metadata
  223. but different data (times, durations)
  224. Note: Array annotations can not be copied here because length of data can change
  225. '''
  226. if units is None:
  227. units = self.units
  228. else:
  229. units = pq.quantity.validate_dimensionality(units)
  230. new = self.__class__(times=signal, units=units)
  231. new._copy_data_complement(self)
  232. # Note: Array annotations can not be copied here because length of data can change
  233. return new
  234. def time_slice(self, t_start, t_stop):
  235. '''
  236. Creates a new :class:`Epoch` corresponding to the time slice of
  237. the original :class:`Epoch` between (and including) times
  238. :attr:`t_start` and :attr:`t_stop`. Either parameter can also be None
  239. to use infinite endpoints for the time interval.
  240. '''
  241. _t_start = t_start
  242. _t_stop = t_stop
  243. if t_start is None:
  244. _t_start = -np.inf
  245. if t_stop is None:
  246. _t_stop = np.inf
  247. indices = (self >= _t_start) & (self <= _t_stop)
  248. new_epc = self[indices]
  249. return new_epc
  250. def set_labels(self, labels):
  251. self.array_annotate(labels=labels)
  252. def get_labels(self):
  253. return self.array_annotations['labels']
  254. labels = property(get_labels, set_labels)
  255. def set_durations(self, durations):
  256. self.array_annotate(durations=durations)
  257. def get_durations(self):
  258. return self.array_annotations['durations']
  259. durations = property(get_durations, set_durations)