test_dataobject.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import copy
  2. import numpy as np
  3. import unittest
  4. from neo.core.dataobject import DataObject, _normalize_array_annotations, ArrayDict
  5. class Test_DataObject(unittest.TestCase):
  6. def test(self):
  7. pass
  8. class Test_array_annotations(unittest.TestCase):
  9. def test_check_arr_ann(self):
  10. # DataObject instance that handles checking
  11. datobj = DataObject([1, 2]) # Inherits from Quantity, so some data is required
  12. # Correct annotations
  13. arr1 = np.asarray(["ABC", "DEF"])
  14. arr2 = np.asarray([3, 6])
  15. corr_ann = {'anno1': arr1, 'anno2': arr2}
  16. corr_ann_copy = copy.deepcopy(corr_ann)
  17. # Checking correct annotations should work fine
  18. corr_ann = _normalize_array_annotations(corr_ann, datobj._get_arr_ann_length())
  19. # Make sure the annotations have not been altered
  20. self.assertSequenceEqual(corr_ann.keys(), corr_ann_copy.keys())
  21. self.assertTrue((corr_ann['anno1'] == corr_ann_copy['anno1']).all())
  22. self.assertTrue((corr_ann['anno2'] == corr_ann_copy['anno2']).all())
  23. # Now creating incorrect inputs:
  24. # Nested dict
  25. nested_ann = {'anno1': {'val1': arr1}, 'anno2': {'val2': arr2}}
  26. with self.assertRaises(ValueError):
  27. nested_ann = _normalize_array_annotations(nested_ann, datobj._get_arr_ann_length())
  28. # Containing None
  29. none_ann = corr_ann_copy
  30. # noinspection PyTypeChecker
  31. none_ann['anno2'] = None
  32. with self.assertRaises(ValueError):
  33. none_ann = _normalize_array_annotations(none_ann, datobj._get_arr_ann_length())
  34. # Multi-dimensional arrays in annotations
  35. multi_dim_ann = copy.deepcopy(corr_ann)
  36. multi_dim_ann['anno2'] = multi_dim_ann['anno2'].reshape(1, 2)
  37. with self.assertRaises(ValueError):
  38. multi_dim_ann = _normalize_array_annotations(multi_dim_ann,
  39. datobj._get_arr_ann_length())
  40. # Wrong length of annotations
  41. len_ann = corr_ann
  42. len_ann['anno1'] = np.asarray(['ABC', 'DEF', 'GHI'])
  43. with self.assertRaises(ValueError):
  44. len_ann = _normalize_array_annotations(len_ann, datobj._get_arr_ann_length())
  45. # Scalar as array annotation raises Error if len(datobj)!=1
  46. scalar_ann = copy.deepcopy(corr_ann)
  47. # noinspection PyTypeChecker
  48. scalar_ann['anno2'] = 3
  49. with self.assertRaises(ValueError):
  50. scalar_ann = _normalize_array_annotations(scalar_ann, datobj._get_arr_ann_length())
  51. # But not if len(datobj) == 1, then it's wrapped into an array
  52. # noinspection PyTypeChecker
  53. scalar_ann['anno1'] = 'ABC'
  54. datobj2 = DataObject([1])
  55. scalar_ann = _normalize_array_annotations(scalar_ann, datobj2._get_arr_ann_length())
  56. self.assertIsInstance(scalar_ann['anno1'], np.ndarray)
  57. self.assertIsInstance(scalar_ann['anno2'], np.ndarray)
  58. # Lists are also made to np.ndarrays
  59. list_ann = {'anno1': [3, 6], 'anno2': ['ABC', 'DEF']}
  60. list_ann = _normalize_array_annotations(list_ann, datobj._get_arr_ann_length())
  61. self.assertIsInstance(list_ann['anno1'], np.ndarray)
  62. self.assertIsInstance(list_ann['anno2'], np.ndarray)
  63. def test_implicit_dict_check(self):
  64. # DataObject instance that handles checking
  65. datobj = DataObject([1, 2]) # Inherits from Quantity, so some data is required
  66. # Correct annotations
  67. arr1 = np.asarray(["ABC", "DEF"])
  68. arr2 = np.asarray([3, 6])
  69. corr_ann = {'anno1': arr1, 'anno2': arr2}
  70. corr_ann_copy = copy.deepcopy(corr_ann)
  71. # Implicit checks when setting item in dict directly
  72. # Checking correct annotations should work fine
  73. datobj.array_annotations['anno1'] = arr1
  74. datobj.array_annotations.update({'anno2': arr2})
  75. # Make sure the annotations have not been altered
  76. self.assertTrue((datobj.array_annotations['anno1'] == corr_ann_copy['anno1']).all())
  77. self.assertTrue((datobj.array_annotations['anno2'] == corr_ann_copy['anno2']).all())
  78. # Now creating incorrect inputs:
  79. # Nested dict
  80. nested_ann = {'anno1': {'val1': arr1}, 'anno2': {'val2': arr2}}
  81. with self.assertRaises(ValueError):
  82. datobj.array_annotations['anno1'] = {'val1': arr1}
  83. # Containing None
  84. none_ann = corr_ann_copy
  85. # noinspection PyTypeChecker
  86. none_ann['anno2'] = None
  87. with self.assertRaises(ValueError):
  88. datobj.array_annotations['anno1'] = None
  89. # Multi-dimensional arrays in annotations
  90. multi_dim_ann = copy.deepcopy(corr_ann)
  91. multi_dim_ann['anno2'] = multi_dim_ann['anno2'].reshape(1, 2)
  92. with self.assertRaises(ValueError):
  93. datobj.array_annotations.update(multi_dim_ann)
  94. # Wrong length of annotations
  95. len_ann = corr_ann
  96. len_ann['anno1'] = np.asarray(['ABC', 'DEF', 'GHI'])
  97. with self.assertRaises(ValueError):
  98. datobj.array_annotations.update(len_ann)
  99. # Scalar as array annotation raises Error if len(datobj)!=1
  100. scalar_ann = copy.deepcopy(corr_ann)
  101. # noinspection PyTypeChecker
  102. scalar_ann['anno2'] = 3
  103. with self.assertRaises(ValueError):
  104. datobj.array_annotations.update(scalar_ann)
  105. # But not if len(datobj) == 1, then it's wrapped into an array
  106. # noinspection PyTypeChecker
  107. scalar_ann['anno1'] = 'ABC'
  108. datobj2 = DataObject([1])
  109. datobj2.array_annotations.update(scalar_ann)
  110. self.assertIsInstance(datobj2.array_annotations['anno1'], np.ndarray)
  111. self.assertIsInstance(datobj2.array_annotations['anno2'], np.ndarray)
  112. # Lists are also made to np.ndarrays
  113. list_ann = {'anno1': [3, 6], 'anno2': ['ABC', 'DEF']}
  114. datobj.array_annotations.update(list_ann)
  115. self.assertIsInstance(datobj.array_annotations['anno1'], np.ndarray)
  116. self.assertIsInstance(datobj.array_annotations['anno2'], np.ndarray)
  117. def test_array_annotate(self):
  118. # Calls _check_array_annotations, so no need to test for these Errors here
  119. datobj = DataObject([2, 3, 4])
  120. arr_ann = {'anno1': [3, 4, 5], 'anno2': ['ABC', 'DEF', 'GHI']}
  121. # Pass annotations
  122. datobj.array_annotate(**arr_ann)
  123. # Make sure they are correct
  124. self.assertTrue((datobj.array_annotations['anno1'] == np.array([3, 4, 5])).all())
  125. self.assertTrue(
  126. (datobj.array_annotations['anno2'] == np.array(['ABC', 'DEF', 'GHI'])).all())
  127. self.assertIsInstance(datobj.array_annotations, ArrayDict)
  128. def test_arr_anns_at_index(self):
  129. # Get them, test for desired type and size, content
  130. datobj = DataObject([1, 2, 3, 4])
  131. arr_ann = {'anno1': [3, 4, 5, 6], 'anno2': ['ABC', 'DEF', 'GHI', 'JKL']}
  132. datobj.array_annotate(**arr_ann)
  133. # Integer as index
  134. ann_int = datobj.array_annotations_at_index(1)
  135. self.assertEqual(ann_int, {'anno1': 4, 'anno2': 'DEF'})
  136. # Negative integer as index
  137. ann_int_back = datobj.array_annotations_at_index(-2)
  138. self.assertEqual(ann_int_back, {'anno1': 5, 'anno2': 'GHI'})
  139. # Slice as index
  140. ann_slice = datobj.array_annotations_at_index(slice(1, 3))
  141. self.assert_((ann_slice['anno1'] == np.array([4, 5])).all())
  142. self.assert_((ann_slice['anno2'] == np.array(['DEF', 'GHI'])).all())
  143. # Slice from beginning to end
  144. ann_slice_all = datobj.array_annotations_at_index(slice(0, None))
  145. self.assert_((ann_slice_all['anno1'] == np.array([3, 4, 5, 6])).all())
  146. self.assert_((ann_slice_all['anno2'] == np.array(['ABC', 'DEF', 'GHI', 'JKL'])).all())
  147. # Make sure that original object is edited when editing extracted array_annotations
  148. ann_slice_all['anno1'][2] = 10
  149. self.assertEqual(datobj.array_annotations_at_index(2)['anno1'], 10)