dataobject.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # -*- coding: utf-8 -*-
  2. """
  3. This module defines :class:`DataObject`, the abstract base class
  4. used by all :mod:`neo.core` classes that can contain data (i.e. are not container classes).
  5. It contains basic functionality that is shared among all those data objects.
  6. """
  7. import copy
  8. import warnings
  9. import quantities as pq
  10. import numpy as np
  11. from neo.core.baseneo import BaseNeo, _check_annotations
  12. def _normalize_array_annotations(value, length):
  13. """Check consistency of array annotations
  14. Recursively check that value is either an array or list containing only "simple" types
  15. (number, string, date/time) or is a dict of those.
  16. Args:
  17. :value: (np.ndarray, list or dict) value to be checked for consistency
  18. :length: (int) required length of the array annotation
  19. Returns:
  20. np.ndarray The array_annotations from value in correct form
  21. Raises:
  22. ValueError: In case value is not accepted as array_annotation(s)
  23. """
  24. # First stage, resolve dict of annotations into single annotations
  25. if isinstance(value, dict):
  26. for key in value.keys():
  27. if isinstance(value[key], dict):
  28. raise ValueError("Nested dicts are not allowed as array annotations")
  29. value[key] = _normalize_array_annotations(value[key], length)
  30. elif value is None:
  31. raise ValueError("Array annotations must not be None")
  32. # If not array annotation, pass on to regular check and make it a list, that is checked again
  33. # This covers array annotations with length 1
  34. elif not isinstance(value, (list, np.ndarray)) or (
  35. isinstance(value, pq.Quantity) and value.shape == ()):
  36. _check_annotations(value)
  37. value = _normalize_array_annotations(np.array([value]), length)
  38. # If array annotation, check for correct length, only single dimension and allowed data
  39. else:
  40. # Get length that is required for array annotations, which is equal to the length
  41. # of the object's data
  42. own_length = length
  43. # Escape check if empty array or list and just annotate an empty array (length 0)
  44. # This enables the user to easily create dummy array annotations that will be filled
  45. # with data later on
  46. if len(value) == 0:
  47. if not isinstance(value, np.ndarray):
  48. value = np.ndarray((0,))
  49. val_length = own_length
  50. else:
  51. # Note: len(o) also works for np.ndarray, it then uses the first dimension,
  52. # which is exactly the desired behaviour here
  53. val_length = len(value)
  54. if not own_length == val_length:
  55. raise ValueError(
  56. "Incorrect length of array annotation: {} != {}".format(val_length, own_length))
  57. # Local function used to check single elements of a list or an array
  58. # They must not be lists or arrays and fit the usual annotation data types
  59. def _check_single_elem(element):
  60. # Nested array annotations not allowed currently
  61. # If element is a list or a np.ndarray, it's not conform except if it's a quantity of
  62. # length 1
  63. if isinstance(element, list) or (isinstance(element, np.ndarray) and not (
  64. isinstance(element, pq.Quantity) and (
  65. element.shape == () or element.shape == (1,)))):
  66. raise ValueError("Array annotations should only be 1-dimensional")
  67. if isinstance(element, dict):
  68. raise ValueError("Dictionaries are not supported as array annotations")
  69. # Perform regular check for elements of array or list
  70. _check_annotations(element)
  71. # Arrays only need testing of single element to make sure the others are the same
  72. if isinstance(value, np.ndarray):
  73. # Type of first element is representative for all others
  74. # Thus just performing a check on the first element is enough
  75. # Even if it's a pq.Quantity, which can be scalar or array, this is still true
  76. # Because a np.ndarray cannot contain scalars and sequences simultaneously
  77. # If length of data is 0, then nothing needs to be checked
  78. if len(value):
  79. # Perform check on first element
  80. _check_single_elem(value[0])
  81. return value
  82. # In case of list, it needs to be ensured that all data are of the same type
  83. else:
  84. # Conversion to numpy array makes all elements same type
  85. # Converts elements to most general type
  86. try:
  87. value = np.array(value)
  88. # Except when scalar and non-scalar values are mixed, this causes conversion to fail
  89. except ValueError as e:
  90. msg = str(e)
  91. if "setting an array element with a sequence." in msg:
  92. raise ValueError("Scalar values and arrays/lists cannot be "
  93. "combined into a single array annotation")
  94. else:
  95. raise e
  96. # If most specialized data type that possibly fits all elements is object,
  97. # raise an Error with a telling error message, because this means the elements
  98. # are not compatible
  99. if value.dtype == object:
  100. raise ValueError("Cannot convert list of incompatible types into a single"
  101. " array annotation")
  102. # Check the first element for correctness
  103. # If its type is correct for annotations, all others are correct as well
  104. # Note: Emtpy lists cannot reach this point
  105. _check_single_elem(value[0])
  106. return value
  107. class DataObject(BaseNeo, pq.Quantity):
  108. '''
  109. This is the base class from which all objects containing data inherit
  110. It contains common functionality for all those objects and handles array_annotations.
  111. Common functionality that is not included in BaseNeo includes:
  112. - duplicating with new data
  113. - rescaling the object
  114. - copying the object
  115. - returning it as pq.Quantity or np.ndarray
  116. - handling of array_annotations
  117. Array_annotations are a kind of annotation that contains metadata for every data point,
  118. i.e. per timestamp (in SpikeTrain, Event and Epoch) or signal channel (in AnalogSignal
  119. and IrregularlySampledSignal).
  120. They can contain the same data types as regular annotations, but are always represented
  121. as numpy arrays of the same length as the number of data points of the annotated neo object.
  122. Args:
  123. name (str, optional): Name of the Neo object
  124. description (str, optional): Human readable string description of the Neo object
  125. file_origin (str, optional): Origin of the data contained in this Neo object
  126. array_annotations (dict, optional): Dictionary containing arrays / lists which annotate
  127. individual data points of the Neo object.
  128. kwargs: regular annotations stored in a separate annotation dictionary
  129. '''
  130. def __init__(self, name=None, description=None, file_origin=None, array_annotations=None,
  131. **annotations):
  132. """
  133. This method is called by each data object and initializes the newly created object by
  134. adding array annotations and calling __init__ of the super class, where more annotations
  135. and attributes are processed.
  136. """
  137. if not hasattr(self, 'array_annotations') or not self.array_annotations:
  138. self.array_annotations = ArrayDict(self._get_arr_ann_length())
  139. if array_annotations is not None:
  140. self.array_annotate(**array_annotations)
  141. BaseNeo.__init__(self, name=name, description=description, file_origin=file_origin,
  142. **annotations)
  143. def array_annotate(self, **array_annotations):
  144. """
  145. Add array annotations (annotations for individual data points) as arrays to a Neo data
  146. object.
  147. Example:
  148. >>> obj.array_annotate(code=['a', 'b', 'a'], category=[2, 1, 1])
  149. >>> obj.array_annotations['code'][1]
  150. 'b'
  151. """
  152. self.array_annotations.update(array_annotations)
  153. def array_annotations_at_index(self, index):
  154. """
  155. Return dictionary of array annotations at a given index or list of indices
  156. :param index: int, list, numpy array: The index (indices) from which the annotations
  157. are extracted
  158. :return: dictionary of values or numpy arrays containing all array annotations
  159. for given index/indices
  160. Example:
  161. >>> obj.array_annotate(code=['a', 'b', 'a'], category=[2, 1, 1])
  162. >>> obj.array_annotations_at_index(1)
  163. {code='b', category=1}
  164. """
  165. # Taking only a part of the array annotations
  166. # Thus not using ArrayDict here, because checks for length are not needed
  167. index_annotations = {}
  168. # Use what is given as an index to determine the corresponding annotations,
  169. # if not possible, numpy raises an Error
  170. for ann in self.array_annotations.keys():
  171. # NO deepcopy, because someone might want to alter the actual object using this
  172. try:
  173. index_annotations[ann] = self.array_annotations[ann][index]
  174. except IndexError as e:
  175. # IndexError caused by 'dummy' array annotations should not result in failure
  176. # Taking a slice from nothing results in nothing
  177. if len(self.array_annotations[ann]) == 0 and not self._get_arr_ann_length() == 0:
  178. index_annotations[ann] = self.array_annotations[ann]
  179. else:
  180. raise e
  181. return index_annotations
  182. def _merge_array_annotations(self, other):
  183. '''
  184. Merges array annotations of 2 different objects.
  185. The merge happens in such a way that the result fits the merged data
  186. In general this means concatenating the arrays from the 2 objects.
  187. If an annotation is only present in one of the objects, it will be omitted
  188. :return Merged array_annotations
  189. '''
  190. merged_array_annotations = {}
  191. omitted_keys_self = []
  192. # Concatenating arrays for each key
  193. for key in self.array_annotations:
  194. try:
  195. value = copy.deepcopy(self.array_annotations[key])
  196. other_value = copy.deepcopy(other.array_annotations[key])
  197. # Quantities need to be rescaled to common unit
  198. if isinstance(value, pq.Quantity):
  199. try:
  200. other_value = other_value.rescale(value.units)
  201. except ValueError:
  202. raise ValueError("Could not merge array annotations "
  203. "due to different units")
  204. merged_array_annotations[key] = np.append(value, other_value) * value.units
  205. else:
  206. merged_array_annotations[key] = np.append(value, other_value)
  207. except KeyError:
  208. # Save the omitted keys to be able to print them
  209. omitted_keys_self.append(key)
  210. continue
  211. # Also save omitted keys from 'other'
  212. omitted_keys_other = [key for key in other.array_annotations if
  213. key not in self.array_annotations]
  214. # Warn if keys were omitted
  215. if omitted_keys_other or omitted_keys_self:
  216. warnings.warn("The following array annotations were omitted, because they were only "
  217. "present in one of the merged objects: {} from the one that was merged "
  218. "into and {} from the one that was merged into the other"
  219. "".format(omitted_keys_self, omitted_keys_other), UserWarning)
  220. # Return the merged array_annotations
  221. return merged_array_annotations
  222. def rescale(self, units):
  223. '''
  224. Return a copy of the object converted to the specified
  225. units
  226. :return: Copy of self with specified units
  227. '''
  228. # Use simpler functionality, if nothing will be changed
  229. dim = pq.quantity.validate_dimensionality(units)
  230. if self.dimensionality == dim:
  231. return self.copy()
  232. # Rescale the object into a new object
  233. obj = self.duplicate_with_new_data(signal=self.view(pq.Quantity).rescale(dim), units=units)
  234. # Expected behavior is deepcopy, so deepcopying array_annotations
  235. obj.array_annotations = copy.deepcopy(self.array_annotations)
  236. obj.segment = self.segment
  237. return obj
  238. # Needed to implement this so array annotations are copied as well, ONLY WHEN copying 1:1
  239. def copy(self, **kwargs):
  240. '''
  241. Returns a copy of the object
  242. :return: Copy of self
  243. '''
  244. obj = super(DataObject, self).copy(**kwargs)
  245. obj.array_annotations = self.array_annotations
  246. return obj
  247. def as_array(self, units=None):
  248. """
  249. Return the object's data as a plain NumPy array.
  250. If `units` is specified, first rescale to those units.
  251. """
  252. if units:
  253. return self.rescale(units).magnitude
  254. else:
  255. return self.magnitude
  256. def as_quantity(self):
  257. """
  258. Return the object's data as a quantities array.
  259. """
  260. return self.view(pq.Quantity)
  261. def _get_arr_ann_length(self):
  262. """
  263. Return the length of the object's data as required for array annotations
  264. This is the last dimension of every object.
  265. :return Required length of array annotations for this object
  266. """
  267. # Number of items is last dimension in of data object
  268. # This method should be overridden in case this changes
  269. try:
  270. length = self.shape[-1]
  271. # Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
  272. # To be removed if __getitem__[int] is changed
  273. except IndexError:
  274. length = 1
  275. return length
  276. def duplicate_with_new_array(self, signal, units=None):
  277. warnings.warn("Use of the `duplicate_with_new_array function is deprecated. "
  278. "Please use `duplicate_with_new_data` instead.",
  279. DeprecationWarning)
  280. return self.duplicate_with_new_data(signal, units=units)
  281. class ArrayDict(dict):
  282. """Dictionary subclass to handle array annotations
  283. When setting `obj.array_annotations[key]=value`, checks for consistency
  284. should not be bypassed.
  285. This class overrides __setitem__ from dict to perform these checks every time.
  286. The method used for these checks is given as an argument for __init__.
  287. """
  288. def __init__(self, length, check_function=_normalize_array_annotations, *args, **kwargs):
  289. super(ArrayDict, self).__init__(*args, **kwargs)
  290. self.check_function = check_function
  291. self.length = length
  292. def __setitem__(self, key, value):
  293. # Directly call the defined function
  294. # Need to wrap key and value in a dict in order to make sure
  295. # that nested dicts are detected
  296. value = self.check_function({key: value}, self.length)[key]
  297. super(ArrayDict, self).__setitem__(key, value)
  298. # Updating the dict also needs to perform checks, so rerouting this to __setitem__
  299. def update(self, *args, **kwargs):
  300. if args:
  301. if len(args) > 1:
  302. raise TypeError("update expected at most 1 arguments, "
  303. "got %d" % len(args))
  304. other = dict(args[0])
  305. for key in other:
  306. self[key] = other[key]
  307. for key in kwargs:
  308. self[key] = kwargs[key]
  309. def __reduce__(self):
  310. return super(ArrayDict, self).__reduce__()