# dataobject.py
"""
This module defines :class:`DataObject`, the abstract base class
used by all :mod:`neo.core` classes that can contain data (i.e. are not container classes).
It contains basic functionality that is shared among all those data objects.
"""
from copy import deepcopy
import warnings

import numpy as np
import quantities as pq

from neo.core.baseneo import BaseNeo, _check_annotations
  11. def _normalize_array_annotations(value, length):
  12. """Check consistency of array annotations
  13. Recursively check that value is either an array or list containing only "simple" types
  14. (number, string, date/time) or is a dict of those.
  15. Args:
  16. :value: (np.ndarray, list or dict) value to be checked for consistency
  17. :length: (int) required length of the array annotation
  18. Returns:
  19. np.ndarray The array_annotations from value in correct form
  20. Raises:
  21. ValueError: In case value is not accepted as array_annotation(s)
  22. """
  23. # First stage, resolve dict of annotations into single annotations
  24. if isinstance(value, dict):
  25. for key in value.keys():
  26. if isinstance(value[key], dict):
  27. raise ValueError("Nested dicts are not allowed as array annotations")
  28. value[key] = _normalize_array_annotations(value[key], length)
  29. elif value is None:
  30. raise ValueError("Array annotations must not be None")
  31. # If not array annotation, pass on to regular check and make it a list, that is checked again
  32. # This covers array annotations with length 1
  33. elif not isinstance(value, (list, np.ndarray)) or (
  34. isinstance(value, pq.Quantity) and value.shape == ()):
  35. _check_annotations(value)
  36. value = _normalize_array_annotations(np.array([value]), length)
  37. # If array annotation, check for correct length, only single dimension and allowed data
  38. else:
  39. # Get length that is required for array annotations, which is equal to the length
  40. # of the object's data
  41. own_length = length
  42. # Escape check if empty array or list and just annotate an empty array (length 0)
  43. # This enables the user to easily create dummy array annotations that will be filled
  44. # with data later on
  45. if len(value) == 0:
  46. if not isinstance(value, np.ndarray):
  47. value = np.ndarray((0,))
  48. val_length = own_length
  49. else:
  50. # Note: len(o) also works for np.ndarray, it then uses the first dimension,
  51. # which is exactly the desired behaviour here
  52. val_length = len(value)
  53. if not own_length == val_length:
  54. raise ValueError(
  55. "Incorrect length of array annotation: {} != {}".format(val_length, own_length))
  56. # Local function used to check single elements of a list or an array
  57. # They must not be lists or arrays and fit the usual annotation data types
  58. def _check_single_elem(element):
  59. # Nested array annotations not allowed currently
  60. # If element is a list or a np.ndarray, it's not conform except if it's a quantity of
  61. # length 1
  62. if isinstance(element, list) or (isinstance(element, np.ndarray) and not (
  63. isinstance(element, pq.Quantity) and (
  64. element.shape == () or element.shape == (1,)))):
  65. raise ValueError("Array annotations should only be 1-dimensional")
  66. if isinstance(element, dict):
  67. raise ValueError("Dictionaries are not supported as array annotations")
  68. # Perform regular check for elements of array or list
  69. _check_annotations(element)
  70. # Arrays only need testing of single element to make sure the others are the same
  71. if isinstance(value, np.ndarray):
  72. # Type of first element is representative for all others
  73. # Thus just performing a check on the first element is enough
  74. # Even if it's a pq.Quantity, which can be scalar or array, this is still true
  75. # Because a np.ndarray cannot contain scalars and sequences simultaneously
  76. # If length of data is 0, then nothing needs to be checked
  77. if len(value):
  78. # Perform check on first element
  79. _check_single_elem(value[0])
  80. return value
  81. # In case of list, it needs to be ensured that all data are of the same type
  82. else:
  83. # Conversion to numpy array makes all elements same type
  84. # Converts elements to most general type
  85. try:
  86. value = np.array(value)
  87. # Except when scalar and non-scalar values are mixed, this causes conversion to fail
  88. except ValueError as e:
  89. msg = str(e)
  90. if "setting an array element with a sequence." in msg:
  91. raise ValueError("Scalar values and arrays/lists cannot be "
  92. "combined into a single array annotation")
  93. else:
  94. raise e
  95. # If most specialized data type that possibly fits all elements is object,
  96. # raise an Error with a telling error message, because this means the elements
  97. # are not compatible
  98. if value.dtype == object:
  99. raise ValueError("Cannot convert list of incompatible types into a single"
  100. " array annotation")
  101. # Check the first element for correctness
  102. # If its type is correct for annotations, all others are correct as well
  103. # Note: Emtpy lists cannot reach this point
  104. _check_single_elem(value[0])
  105. return value
  106. class DataObject(BaseNeo, pq.Quantity):
  107. '''
  108. This is the base class from which all objects containing data inherit
  109. It contains common functionality for all those objects and handles array_annotations.
  110. Common functionality that is not included in BaseNeo includes:
  111. - duplicating with new data
  112. - rescaling the object
  113. - copying the object
  114. - returning it as pq.Quantity or np.ndarray
  115. - handling of array_annotations
  116. Array_annotations are a kind of annotation that contains metadata for every data point,
  117. i.e. per timestamp (in SpikeTrain, Event and Epoch) or signal channel (in AnalogSignal
  118. and IrregularlySampledSignal).
  119. They can contain the same data types as regular annotations, but are always represented
  120. as numpy arrays of the same length as the number of data points of the annotated neo object.
  121. Args:
  122. name (str, optional): Name of the Neo object
  123. description (str, optional): Human readable string description of the Neo object
  124. file_origin (str, optional): Origin of the data contained in this Neo object
  125. array_annotations (dict, optional): Dictionary containing arrays / lists which annotate
  126. individual data points of the Neo object.
  127. kwargs: regular annotations stored in a separate annotation dictionary
  128. '''
  129. def __init__(self, name=None, description=None, file_origin=None, array_annotations=None,
  130. **annotations):
  131. """
  132. This method is called by each data object and initializes the newly created object by
  133. adding array annotations and calling __init__ of the super class, where more annotations
  134. and attributes are processed.
  135. """
  136. if not hasattr(self, 'array_annotations') or not self.array_annotations:
  137. self.array_annotations = ArrayDict(self._get_arr_ann_length())
  138. if array_annotations is not None:
  139. self.array_annotate(**array_annotations)
  140. BaseNeo.__init__(self, name=name, description=description, file_origin=file_origin,
  141. **annotations)
  142. def array_annotate(self, **array_annotations):
  143. """
  144. Add array annotations (annotations for individual data points) as arrays to a Neo data
  145. object.
  146. Example:
  147. >>> obj.array_annotate(code=['a', 'b', 'a'], category=[2, 1, 1])
  148. >>> obj.array_annotations['code'][1]
  149. 'b'
  150. """
  151. self.array_annotations.update(array_annotations)
  152. def array_annotations_at_index(self, index):
  153. """
  154. Return dictionary of array annotations at a given index or list of indices
  155. :param index: int, list, numpy array: The index (indices) from which the annotations
  156. are extracted
  157. :return: dictionary of values or numpy arrays containing all array annotations
  158. for given index/indices
  159. Example:
  160. >>> obj.array_annotate(code=['a', 'b', 'a'], category=[2, 1, 1])
  161. >>> obj.array_annotations_at_index(1)
  162. {code='b', category=1}
  163. """
  164. # Taking only a part of the array annotations
  165. # Thus not using ArrayDict here, because checks for length are not needed
  166. index_annotations = {}
  167. # Use what is given as an index to determine the corresponding annotations,
  168. # if not possible, numpy raises an Error
  169. for ann in self.array_annotations.keys():
  170. # NO deepcopy, because someone might want to alter the actual object using this
  171. try:
  172. index_annotations[ann] = self.array_annotations[ann][index]
  173. except IndexError as e:
  174. # IndexError caused by 'dummy' array annotations should not result in failure
  175. # Taking a slice from nothing results in nothing
  176. if len(self.array_annotations[ann]) == 0 and not self._get_arr_ann_length() == 0:
  177. index_annotations[ann] = self.array_annotations[ann]
  178. else:
  179. raise e
  180. return index_annotations
  181. def _merge_array_annotations(self, other):
  182. '''
  183. Merges array annotations of 2 different objects.
  184. The merge happens in such a way that the result fits the merged data
  185. In general this means concatenating the arrays from the 2 objects.
  186. If an annotation is only present in one of the objects, it will be omitted
  187. :return Merged array_annotations
  188. '''
  189. merged_array_annotations = {}
  190. omitted_keys_self = []
  191. # Concatenating arrays for each key
  192. for key in self.array_annotations:
  193. try:
  194. value = deepcopy(self.array_annotations[key])
  195. other_value = deepcopy(other.array_annotations[key])
  196. # Quantities need to be rescaled to common unit
  197. if isinstance(value, pq.Quantity):
  198. try:
  199. other_value = other_value.rescale(value.units)
  200. except ValueError:
  201. raise ValueError("Could not merge array annotations "
  202. "due to different units")
  203. merged_array_annotations[key] = np.append(value, other_value) * value.units
  204. else:
  205. merged_array_annotations[key] = np.append(value, other_value)
  206. except KeyError:
  207. # Save the omitted keys to be able to print them
  208. omitted_keys_self.append(key)
  209. continue
  210. # Also save omitted keys from 'other'
  211. omitted_keys_other = [key for key in other.array_annotations if
  212. key not in self.array_annotations]
  213. # Warn if keys were omitted
  214. if omitted_keys_other or omitted_keys_self:
  215. warnings.warn("The following array annotations were omitted, because they were only "
  216. "present in one of the merged objects: {} from the one that was merged "
  217. "into and {} from the one that was merged into the other"
  218. "".format(omitted_keys_self, omitted_keys_other), UserWarning)
  219. # Return the merged array_annotations
  220. return merged_array_annotations
  221. def rescale(self, units):
  222. '''
  223. Return a copy of the object converted to the specified
  224. units
  225. :return: Copy of self with specified units
  226. '''
  227. # Use simpler functionality, if nothing will be changed
  228. dim = pq.quantity.validate_dimensionality(units)
  229. if self.dimensionality == dim:
  230. return self.copy()
  231. # Rescale the object into a new object
  232. obj = self.duplicate_with_new_data(signal=self.view(pq.Quantity).rescale(dim), units=units)
  233. # Expected behavior is deepcopy, so deepcopying array_annotations
  234. obj.array_annotations = deepcopy(self.array_annotations)
  235. obj.segment = self.segment
  236. return obj
  237. # Needed to implement this so array annotations are copied as well, ONLY WHEN copying 1:1
  238. def copy(self, **kwargs):
  239. '''
  240. Returns a shallow copy of the object
  241. :return: Copy of self
  242. '''
  243. obj = super().copy(**kwargs)
  244. obj.array_annotations = self.array_annotations
  245. return obj
  246. def as_array(self, units=None):
  247. """
  248. Return the object's data as a plain NumPy array.
  249. If `units` is specified, first rescale to those units.
  250. """
  251. if units:
  252. return self.rescale(units).magnitude
  253. else:
  254. return self.magnitude
  255. def as_quantity(self):
  256. """
  257. Return the object's data as a quantities array.
  258. """
  259. return self.view(pq.Quantity)
  260. def _get_arr_ann_length(self):
  261. """
  262. Return the length of the object's data as required for array annotations
  263. This is the last dimension of every object.
  264. :return Required length of array annotations for this object
  265. """
  266. # Number of items is last dimension in of data object
  267. # This method should be overridden in case this changes
  268. try:
  269. length = self.shape[-1]
  270. # Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
  271. # To be removed if __getitem__[int] is changed
  272. except IndexError:
  273. length = 1
  274. return length
  275. def __deepcopy__(self, memo):
  276. """
  277. Create a deep copy of the data object.
  278. All attributes and annotations are also deep copied.
  279. References to parent objects are not kept, they are set to None.
  280. :param memo: (dict) Objects that have been deep copied already
  281. :return: (DataObject) Deep copy of the input DataObject
  282. """
  283. cls = self.__class__
  284. necessary_attrs = {}
  285. # Units need to be specified explicitly for analogsignals/irregularlysampledsignals
  286. for k in self._necessary_attrs + (('units',),):
  287. necessary_attrs[k[0]] = getattr(self, k[0], self)
  288. # Create object using constructor with necessary attributes
  289. new_obj = cls(**necessary_attrs)
  290. # Add all attributes
  291. new_obj.__dict__.update(self.__dict__)
  292. memo[id(self)] = new_obj
  293. for k, v in self.__dict__.items():
  294. # Single parent objects should not be deepcopied, because this is not expected behavior
  295. # and leads to a lot of stuff being copied (e.g. all other children of the parent as well),
  296. # thus creating a lot of overhead
  297. # But keeping the reference to the same parent is not desired either, because this would be unidirectional
  298. # When deepcopying top-down, e.g. a whole block, the links will be handled by the parent
  299. if k in self._single_parent_attrs:
  300. setattr(new_obj, k, None)
  301. continue
  302. try:
  303. setattr(new_obj, k, deepcopy(v, memo))
  304. except TypeError:
  305. setattr(new_obj, k, v)
  306. return new_obj
  307. class ArrayDict(dict):
  308. """Dictionary subclass to handle array annotations
  309. When setting `obj.array_annotations[key]=value`, checks for consistency
  310. should not be bypassed.
  311. This class overrides __setitem__ from dict to perform these checks every time.
  312. The method used for these checks is given as an argument for __init__.
  313. """
  314. def __init__(self, length, check_function=_normalize_array_annotations, *args, **kwargs):
  315. super().__init__(*args, **kwargs)
  316. self.check_function = check_function
  317. self.length = length
  318. def __setitem__(self, key, value):
  319. # Directly call the defined function
  320. # Need to wrap key and value in a dict in order to make sure
  321. # that nested dicts are detected
  322. value = self.check_function({key: value}, self.length)[key]
  323. super().__setitem__(key, value)
  324. # Updating the dict also needs to perform checks, so rerouting this to __setitem__
  325. def update(self, *args, **kwargs):
  326. if args:
  327. if len(args) > 1:
  328. raise TypeError("update expected at most 1 arguments, "
  329. "got %d" % len(args))
  330. other = dict(args[0])
  331. for key in other:
  332. self[key] = other[key]
  333. for key in kwargs:
  334. self[key] = kwargs[key]
  335. def __reduce__(self):
  336. return super().__reduce__()