view.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. """
  2. This module implements :class:`ChannelView`, which represents a subset of the
  3. channels in an :class:`AnalogSignal` or :class:`IrregularlySampledSignal`.
  4. It replaces the indexing function of the former :class:`ChannelIndex`.
  5. """
  6. import numpy as np
  7. from .baseneo import BaseNeo
  8. from .basesignal import BaseSignal
  9. from .dataobject import ArrayDict
  10. class ChannelView(BaseNeo):
  11. """
  12. A tool for indexing a subset of the channels within an :class:`AnalogSignal`
  13. or :class:`IrregularlySampledSignal`;
  14. *Required attributes/properties*:
  15. :obj: (AnalogSignal or IrregularlySampledSignal) The signal being indexed.
  16. :index: (list/1D-array) boolean or integer mask to select the channels of interest.
  17. *Recommended attributes/properties*:
  18. :name: (str) A label for the view.
  19. :description: (str) Text description.
  20. :file_origin: (str) Filesystem path or URL of the original data file.
  21. :array_annotations: (dict) Dict mapping strings to numpy arrays containing annotations
  22. for all data points
  23. Note: Any other additional arguments are assumed to be user-specific
  24. metadata and stored in :attr:`annotations`.
  25. """
  26. _single_parent_objects = ('Segment',)
  27. _single_parent_attrs = ('segment',)
  28. _necessary_attrs = (
  29. ('index', np.ndarray, 1, np.dtype('i')),
  30. ('obj', ('AnalogSignal', 'IrregularlySampledSignal'), 1)
  31. )
  32. # "mask" would be an alternative name, proposing "index" for
  33. # backwards-compatibility with ChannelIndex
  34. def __init__(self, obj, index, name=None, description=None, file_origin=None,
  35. array_annotations=None, **annotations):
  36. super().__init__(name=name, description=description,
  37. file_origin=file_origin, **annotations)
  38. if not (isinstance(obj, BaseSignal) or (
  39. hasattr(obj, "proxy_for") and issubclass(obj.proxy_for, BaseSignal))):
  40. raise ValueError("Can only take a ChannelView of an AnalogSignal "
  41. "or an IrregularlySampledSignal")
  42. self.obj = obj
  43. # check type and dtype of index and convert index to a common form
  44. # (accept list or array of bool or int, convert to int array)
  45. self.index = np.array(index)
  46. if len(self.index.shape) != 1:
  47. raise ValueError("index must be a 1D array")
  48. if self.index.dtype == np.bool: # convert boolean mask to integer index
  49. if self.index.size != self.obj.shape[-1]:
  50. raise ValueError("index size does not match number of channels in signal")
  51. self.index, = np.nonzero(self.index)
  52. # allow any type of integer representation
  53. elif self.index.dtype.char not in np.typecodes['AllInteger']:
  54. raise ValueError("index must be of a list or array of data type boolean or integer")
  55. if not hasattr(self, 'array_annotations') or not self.array_annotations:
  56. self.array_annotations = ArrayDict(self._get_arr_ann_length())
  57. if array_annotations is not None:
  58. self.array_annotate(**array_annotations)
  59. @property
  60. def shape(self):
  61. return (self.obj.shape[0], self.index.size)
  62. def _get_arr_ann_length(self):
  63. return self.shape[-1]
  64. def array_annotate(self, **array_annotations):
  65. self.array_annotations.update(array_annotations)
  66. def resolve(self):
  67. """
  68. Return a copy of the underlying object containing just the subset of channels
  69. defined by the index.
  70. """
  71. return self.obj[:, self.index]