epoch.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # -*- coding: utf-8 -*-
'''
This module defines :class:`Epoch`, an array of epochs.

:class:`Epoch` derives from :class:`BaseNeo`, from
:mod:`neo.core.baseneo`, and from :class:`quantities.Quantity`.
'''
  7. # needed for python 3 compatibility
  8. from __future__ import absolute_import, division, print_function
  9. import sys
  10. import numpy as np
  11. import quantities as pq
  12. from neo.core.baseneo import BaseNeo, merge_annotations
# Major version number of the running Python interpreter (first field of
# ``sys.version_info``); compared against 3 where py2/py3 behaviour differs.
PY_VER = sys.version_info[0]
  14. def _new_epoch(cls, times=None, durations=None, labels=None, units=None,
  15. name=None, description=None, file_origin=None, annotations = None, segment=None):
  16. '''
  17. A function to map epoch.__new__ to function that
  18. does not do the unit checking. This is needed for pickle to work.
  19. '''
  20. e = Epoch( times=times, durations=durations, labels=labels, units=units, name=name, file_origin=file_origin,
  21. description=description, **annotations)
  22. e.segment = segment
  23. return e
  24. class Epoch(BaseNeo, pq.Quantity):
  25. '''
  26. Array of epochs.
  27. *Usage*::
  28. >>> from neo.core import Epoch
  29. >>> from quantities import s, ms
  30. >>> import numpy as np
  31. >>>
  32. >>> epc = Epoch(times=np.arange(0, 30, 10)*s,
  33. ... durations=[10, 5, 7]*ms,
  34. ... labels=np.array(['btn0', 'btn1', 'btn2'], dtype='S'))
  35. >>>
  36. >>> epc.times
  37. array([ 0., 10., 20.]) * s
  38. >>> epc.durations
  39. array([ 10., 5., 7.]) * ms
  40. >>> epc.labels
  41. array(['btn0', 'btn1', 'btn2'],
  42. dtype='|S4')
  43. *Required attributes/properties*:
  44. :times: (quantity array 1D) The starts of the time periods.
  45. :durations: (quantity array 1D) The length of the time period.
  46. :labels: (numpy.array 1D dtype='S') Names or labels for the
  47. time periods.
  48. *Recommended attributes/properties*:
  49. :name: (str) A label for the dataset,
  50. :description: (str) Text description,
  51. :file_origin: (str) Filesystem path or URL of the original data file.
  52. Note: Any other additional arguments are assumed to be user-specific
  53. metadata and stored in :attr:`annotations`,
  54. '''
  55. _single_parent_objects = ('Segment',)
  56. _quantity_attr = 'times'
  57. _necessary_attrs = (('times', pq.Quantity, 1),
  58. ('durations', pq.Quantity, 1),
  59. ('labels', np.ndarray, 1, np.dtype('S')))
  60. def __new__(cls, times=None, durations=None, labels=None, units=None,
  61. name=None, description=None, file_origin=None, **annotations):
  62. if times is None:
  63. times = np.array([]) * pq.s
  64. if durations is None:
  65. durations = np.array([]) * pq.s
  66. if labels is None:
  67. labels = np.array([], dtype='S')
  68. if units is None:
  69. # No keyword units, so get from `times`
  70. try:
  71. units = times.units
  72. dim = units.dimensionality
  73. except AttributeError:
  74. raise ValueError('you must specify units')
  75. else:
  76. if hasattr(units, 'dimensionality'):
  77. dim = units.dimensionality
  78. else:
  79. dim = pq.quantity.validate_dimensionality(units)
  80. # check to make sure the units are time
  81. # this approach is much faster than comparing the
  82. # reference dimensionality
  83. if (len(dim) != 1 or list(dim.values())[0] != 1 or
  84. not isinstance(list(dim.keys())[0], pq.UnitTime)):
  85. ValueError("Unit %s has dimensions %s, not [time]" %
  86. (units, dim.simplified))
  87. obj = pq.Quantity.__new__(cls, times, units=dim)
  88. obj.durations = durations
  89. obj.labels = labels
  90. obj.segment = None
  91. return obj
  92. def __init__(self, times=None, durations=None, labels=None, units=None,
  93. name=None, description=None, file_origin=None, **annotations):
  94. '''
  95. Initialize a new :class:`Epoch` instance.
  96. '''
  97. BaseNeo.__init__(self, name=name, file_origin=file_origin,
  98. description=description, **annotations)
  99. def __reduce__(self):
  100. '''
  101. Map the __new__ function onto _new_BaseAnalogSignal, so that pickle
  102. works
  103. '''
  104. return _new_epoch, (self.__class__, self.times, self.durations, self.labels, self.units,
  105. self.name, self.file_origin, self.description,
  106. self.annotations, self.segment)
  107. def __array_finalize__(self, obj):
  108. super(Epoch, self).__array_finalize__(obj)
  109. self.durations = getattr(obj, 'durations', None)
  110. self.labels = getattr(obj, 'labels', None)
  111. self.annotations = getattr(obj, 'annotations', None)
  112. self.name = getattr(obj, 'name', None)
  113. self.file_origin = getattr(obj, 'file_origin', None)
  114. self.description = getattr(obj, 'description', None)
  115. self.segment = getattr(obj, 'segment', None)
  116. def __repr__(self):
  117. '''
  118. Returns a string representing the :class:`Epoch`.
  119. '''
  120. # need to convert labels to unicode for python 3 or repr is messed up
  121. if PY_VER == 3:
  122. labels = self.labels.astype('U')
  123. else:
  124. labels = self.labels
  125. objs = ['%s@%s for %s' % (label, time, dur) for
  126. label, time, dur in zip(labels, self.times, self.durations)]
  127. return '<Epoch: %s>' % ', '.join(objs)
  128. @property
  129. def times(self):
  130. return pq.Quantity(self)
  131. def merge(self, other):
  132. '''
  133. Merge the another :class:`Epoch` into this one.
  134. The :class:`Epoch` objects are concatenated horizontally
  135. (column-wise), :func:`np.hstack`).
  136. If the attributes of the two :class:`Epoch` are not
  137. compatible, and Exception is raised.
  138. '''
  139. othertimes = other.times.rescale(self.times.units)
  140. otherdurations = other.durations.rescale(self.durations.units)
  141. times = np.hstack([self.times, othertimes]) * self.times.units
  142. durations = np.hstack([self.durations,
  143. otherdurations]) * self.durations.units
  144. labels = np.hstack([self.labels, other.labels])
  145. kwargs = {}
  146. for name in ("name", "description", "file_origin"):
  147. attr_self = getattr(self, name)
  148. attr_other = getattr(other, name)
  149. if attr_self == attr_other:
  150. kwargs[name] = attr_self
  151. else:
  152. kwargs[name] = "merge(%s, %s)" % (attr_self, attr_other)
  153. merged_annotations = merge_annotations(self.annotations,
  154. other.annotations)
  155. kwargs.update(merged_annotations)
  156. return Epoch(times=times, durations=durations, labels=labels, **kwargs)
  157. def _copy_data_complement(self, other):
  158. '''
  159. Copy the metadata from another :class:`Epoch`.
  160. '''
  161. for attr in ("labels", "durations", "name", "file_origin",
  162. "description", "annotations"):
  163. setattr(self, attr, getattr(other, attr, None))
  164. def duplicate_with_new_data(self, signal):
  165. '''
  166. Create a new :class:`Epoch` with the same metadata
  167. but different data (times, durations)
  168. '''
  169. new = self.__class__(times=signal)
  170. new._copy_data_complement(self)
  171. return new
  172. def time_slice(self, t_start, t_stop):
  173. '''
  174. Creates a new :class:`Epoch` corresponding to the time slice of
  175. the original :class:`Epoch` between (and including) times
  176. :attr:`t_start` and :attr:`t_stop`. Either parameter can also be None
  177. to use infinite endpoints for the time interval.
  178. '''
  179. _t_start = t_start
  180. _t_stop = t_stop
  181. if t_start is None:
  182. _t_start = -np.inf
  183. if t_stop is None:
  184. _t_stop = np.inf
  185. indices = (self >= _t_start) & (self <= _t_stop)
  186. new_epc = self[indices]
  187. new_epc.durations = self.durations[indices]
  188. new_epc.labels = self.labels[indices]
  189. return new_epc
  190. def as_array(self, units=None):
  191. """
  192. Return the epoch start times as a plain NumPy array.
  193. If `units` is specified, first rescale to those units.
  194. """
  195. if units:
  196. return self.rescale(units).magnitude
  197. else:
  198. return self.magnitude
  199. def as_quantity(self):
  200. """
  201. Return the epoch start times as a quantities array.
  202. """
  203. return self.view(pq.Quantity)