tools.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. '''
  2. Tools for use with neo tests.
  3. '''
  4. import hashlib
  5. import os
  6. import numpy as np
  7. import quantities as pq
  8. import neo
  9. from neo.core import objectlist
  10. from neo.core.baseneo import _reference_name, _container_name
  11. from neo.core.container import Container
  12. from neo.io.basefromrawio import proxyobjectlist, EventProxy, EpochProxy
  def assert_arrays_equal(a, b, dtype=False):
      '''
      Assert that two numpy arrays have identical shapes and element values.

      When ``dtype`` is True (it defaults to False) the two arrays must also
      share the same dtype.
      '''
      assert isinstance(a, np.ndarray), "a is a %s" % type(a)
      assert isinstance(b, np.ndarray), "b is a %s" % type(b)
      assert a.shape == b.shape, "{} != {}".format(a, b)

      # The element-wise comparison is attempted up to three times: for some
      # inputs ``==`` may not produce an array, so calling ``.all()`` on the
      # result raises instead of returning a boolean.
      try:
          flat_equal = a.flatten() == b.flatten()
          assert flat_equal.all(), "{} != {}".format(a, b)
      except (AttributeError, ValueError):
          try:
              a_copy = np.array(a)
              b_copy = np.array(b)
              assert (a_copy.flatten() == b_copy.flatten()).all(), \
                  "{} != {}".format(a_copy, b_copy)
          except (AttributeError, ValueError):
              assert np.all(a.flatten() == b.flatten()), "{} != {}".format(a, b)

      if not dtype:
          return
      assert a.dtype == b.dtype, "{} and {} not same dtype {} and {}".format(
          a, b, a.dtype, b.dtype)
  def assert_arrays_almost_equal(a, b, threshold, dtype=False):
      '''
      Assert that two numpy arrays have the same shape and element values that
      differ by strictly less than ``threshold`` (``abs(a - b) < threshold``).

      Values are only compared for numeric arrays, i.e. when ``a.dtype.kind``
      is float ('f'), complex ('c'), signed ('i') or unsigned ('u') integer.
      For any other dtype kind only the shape (and, optionally, the dtype) is
      checked.

      If ``threshold`` is None an exact comparison is made instead, by
      delegating to ``assert_arrays_equal``.

      If ``dtype`` is True (default False) the arrays must also share the same
      dtype.
      '''
      if threshold is None:
          return assert_arrays_equal(a, b, dtype=dtype)

      assert isinstance(a, np.ndarray), "a is a %s" % type(a)
      assert isinstance(b, np.ndarray), "b is a %s" % type(b)
      assert a.shape == b.shape, "{} != {}".format(a, b)

      if a.dtype.kind in ('f', 'c', 'i', 'u'):
          if a.dtype.kind == 'u':
              # abs(a - b) wraps around for unsigned integers (e.g. uint8
              # 1 - 2 == 255), so subtract the smaller value from the larger.
              diff = np.maximum(a, b) - np.minimum(a, b)
          else:
              diff = abs(a - b)
          assert (diff < threshold).all(), \
              "abs(%s - %s) max(|a - b|) = %s threshold:%s" % (a, b, diff.max(), threshold)

      if dtype:
          assert a.dtype == b.dtype, "{} and {} not same dtype {} and {}".format(
              a, b, a.dtype, b.dtype)
  def file_digest(filename, chunk_size=65536):
      '''
      Return the sha1 hex digest of the file with the given filename.

      The file is read in binary chunks of ``chunk_size`` bytes, so large data
      files are never held in memory in full.

      Raises:
          ValueError: if ``chunk_size`` is not positive (a zero-size read would
              end the loop immediately and hash nothing).
      '''
      if chunk_size <= 0:
          raise ValueError("chunk_size must be positive, got %r" % (chunk_size,))
      digest = hashlib.sha1()
      with open(filename, 'rb') as fobj:
          for chunk in iter(lambda: fobj.read(chunk_size), b''):
              digest.update(chunk)
      return digest.hexdigest()
  def assert_file_contents_equal(a, b):
      '''
      Assert that the files at paths ``a`` and ``b`` have identical contents,
      judged by comparing their sha1 digests.

      On a mismatch the assertion message says whether the sizes also differ.
      '''
      def describe_mismatch(path_a, path_b):
          '''
          Build the assertion message from the sizes of the two files.
          '''
          sizes = (os.stat(path_a).st_size, os.stat(path_b).st_size)
          if sizes[0] != sizes[1]:
              return "Files have different sizes: a:%d b: %d" % sizes
          return "Files have the same size but different contents"

      assert file_digest(a) == file_digest(b), describe_mismatch(a, b)
  def assert_neo_object_is_compliant(ob, check_type=True):
      '''
      Assert that ``ob`` and, recursively, its children comply with the neo
      object model (one-to-many relationships only).

      Checks performed:
        * ``type(ob)`` is in ``objectlist`` (only when ``check_type`` is True).
        * every attribute named in ``ob._necessary_attrs`` is present; this
          check is skipped for objects that define ``_quantity_attr``.
        * every attribute in ``ob._all_attrs`` that is set to a non-None value
          is an instance of the declared type; Quantity / numpy.ndarray
          attributes must also have the declared ``ndim`` and ndarray
          attributes the declared ``dtype.kind``. For the attribute named by
          ``ob._quantity_attr`` these checks are applied to the object itself.
        * every child in ``ob._single_child_containers`` refers back to ``ob``
          (skipped for ``Group``, whose members keep no such reference).
        * every object in ``ob.children`` is itself compliant.

      Arguments:
          ob: the object to check.
          check_type: set to False to skip the neo type check, e.g. when
              testing proxy objects. The value is passed on to the recursive
              checks of the children.

      Raises:
          AssertionError: on the first violation. Errors raised while checking
              a child get ``'from <child class> <index> of <class>'`` appended
              to their ``args``.
      '''
      if check_type:
          assert type(ob) in objectlist, \
              '%s is not a neo object' % (type(ob))
      classname = ob.__class__.__name__
      has_quantity_attr = hasattr(ob, '_quantity_attr')

      # test presence of necessary attributes
      for ioattr in ob._necessary_attrs:
          attrname = ioattr[0]
          if not has_quantity_attr:
              assert hasattr(ob, attrname), '{} neo object does not have {}'.format(
                  classname, attrname)

      # test attributes types
      for ioattr in ob._all_attrs:
          attrname, attrtype = ioattr[0], ioattr[1]
          if (has_quantity_attr and ob._quantity_attr == attrname and (
                  attrtype == pq.Quantity or attrtype == np.ndarray)):
              # object inherits from Quantity (AnalogSignal, SpikeTrain, ...):
              # the attribute checks apply to the object itself
              ndim = ioattr[2]
              assert ob.ndim == ndim, '%s dimension is %d should be %d' % (
                  classname, ob.ndim, ndim)
              if attrtype == np.ndarray:
                  dtp = ioattr[3]
                  assert ob.dtype.kind == dtp.kind, '%s dtype.kind is %s should be %s' % (
                      classname, ob.dtype.kind, dtp.kind)
          elif hasattr(ob, attrname):
              obattr = getattr(ob, attrname)
              if obattr is None:
                  # None values are not type-checked
                  continue
              assert issubclass(type(obattr), attrtype), '%s in %s is %s should be %s' % (
                  attrname, classname, type(obattr), attrtype)
              if attrtype == pq.Quantity or attrtype == np.ndarray:
                  ndim = ioattr[2]
                  assert obattr.ndim == ndim, '%s.%s dimension is %d should be %d' % (
                      classname, attrname, obattr.ndim, ndim)
              if attrtype == np.ndarray:
                  dtp = ioattr[3]
                  assert obattr.dtype.kind == dtp.kind, \
                      '%s.%s dtype.kind is %s should be %s' % (
                          classname, attrname, obattr.dtype.kind, dtp.kind)

      # test bijectivity: each child must refer back to its parent.
      # Objects in a Group do not keep a reference to the group.
      if classname != "Group":
          for container in getattr(ob, '_single_child_containers', []):
              refname = _reference_name(classname)
              for i, child in enumerate(getattr(ob, container, [])):
                  assert hasattr(child, refname), \
                      '%s should have %s attribute (2 way relationship)' % (container, refname)
                  assert getattr(child, refname) == ob, \
                      '%s.%s %s is not symmetric with %s.%s' % (
                          container, refname, i, classname, container)

      # recursive on one to many relationships; check_type is propagated so
      # that containers holding proxy objects can be checked as well
      for i, child in enumerate(getattr(ob, 'children', [])):
          try:
              assert_neo_object_is_compliant(child, check_type=check_type)
          # intercept exceptions and add more information
          except BaseException as exc:
              exc.args += ('from {} {} of {}'.format(child.__class__.__name__, i, classname),)
              raise
  def assert_same_sub_schema(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
      '''
      Assert that ob1 and ob2 have the same sub schema.

      Lists are compared element by element. For neo objects every container
      listed in ``_single_child_containers`` is explored recursively, then the
      attributes of the two objects are compared with
      ``assert_same_attributes``. Many-to-many relationships are not tested
      because they would lead to infinite recursion.

      Arguments:
          equal_almost: if False do a strict arrays_equal, if True do
              arrays_almost_equal with ``threshold``.
          threshold: tolerance used when ``equal_almost`` is True.
          exclude: a list of container and attribute names to ignore in the
              comparison.

      Raises:
          AssertionError: on the first difference. Errors raised while
              comparing nested items get their location appended to ``args``.
      '''
      assert type(ob1) is type(ob2), 'type({}) != type({})'.format(type(ob1), type(ob2))
      classname = ob1.__class__.__name__
      if exclude is None:
          exclude = []

      if isinstance(ob1, list):
          assert len(ob1) == len(ob2), 'lens %s and %s not equal for %s and %s' % (
              len(ob1), len(ob2), ob1, ob2)
          for i, (sub1, sub2) in enumerate(zip(ob1, ob2)):
              try:
                  assert_same_sub_schema(sub1, sub2, equal_almost=equal_almost,
                                         threshold=threshold, exclude=exclude)
              # intercept exceptions and add more information
              except BaseException as exc:
                  exc.args += ('{}[{}]'.format(classname, i),)
                  raise
          return

      # test parent/child relationship
      for container in getattr(ob1, '_single_child_containers', []):
          if container in exclude:
              continue
          if not hasattr(ob1, container):
              assert not hasattr(ob2, container), '%s 2 does have %s but not %s 1' % (
                  classname, container, classname)
              continue
          assert hasattr(ob2, container), '{} 1 has {} but not {} 2'.format(
              classname, container, classname)

          sub1 = getattr(ob1, container)
          sub2 = getattr(ob2, container)
          assert len(sub1) == len(sub2), \
              'these two %s do not have the same %s number: %s and %s' % (
                  classname, container, len(sub1), len(sub2))
          for i, (child1, child2) in enumerate(zip(sub1, sub2)):
              try:
                  assert_same_sub_schema(child1, child2, equal_almost=equal_almost,
                                         threshold=threshold, exclude=exclude)
              # intercept exceptions and add more information
              except BaseException as exc:
                  exc.args += ('from {}[{}] of {}'.format(container, i, classname),)
                  raise

      assert_same_attributes(ob1, ob2, equal_almost=equal_almost, threshold=threshold,
                             exclude=exclude)
  def assert_same_attributes(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
      '''
      Assert that ob1 and ob2 hold the same value for every attribute listed in
      ``ob1._all_attrs``.

      Arguments:
          equal_almost: when True, arrays are compared with
              ``assert_arrays_almost_equal`` using ``threshold`` and dtypes are
              ignored; when False, arrays must match exactly, dtype included.
          exclude: names of attributes to skip.
      '''
      exclude = [] if exclude is None else exclude
      if equal_almost:
          dtype = False
      else:
          threshold, dtype = None, True
      classname = ob1.__class__.__name__

      for ioattr in ob1._all_attrs:
          attrname = ioattr[0]
          if attrname in exclude:
              continue
          attrtype = ioattr[1]

          if hasattr(ob1, '_quantity_attr') and ob1._quantity_attr == attrname:
              # the object itself is the quantity (AnalogSignal, SpikeTrain, ...)
              try:
                  assert_arrays_almost_equal(ob1.magnitude, ob2.magnitude,
                                             threshold=threshold, dtype=dtype)
              # intercept exceptions and add more information
              except BaseException as exc:
                  exc.args += ('from {} {}'.format(classname, attrname),)
                  raise
              units1 = ob1.dimensionality.string
              units2 = ob2.dimensionality.string
              assert units1 == units2, 'Units of %s %s are not the same: %s and %s' % (
                  classname, attrname, units1, units2)
              continue

          if not hasattr(ob1, attrname):
              assert not hasattr(ob2, attrname), '%s 2 does have %s but not %s 1' % (
                  classname, attrname, classname)
              continue
          assert hasattr(ob2, attrname), '%s 1 has %s but not %s 2' % (
              classname, attrname, classname)

          value1 = getattr(ob1, attrname)
          value2 = getattr(ob2, attrname)
          if value1 is None or value2 is None:
              # None on one side only is a mismatch
              assert value1 is None and value2 is None, 'In %s.%s %s and %s differed' % (
                  classname, attrname, value1, value2)
              continue

          if attrtype == pq.Quantity:
              try:
                  assert_arrays_almost_equal(value1.magnitude, value2.magnitude,
                                             threshold=threshold, dtype=dtype)
              # intercept exceptions and add more information
              except BaseException as exc:
                  exc.args += ('from {} of {}'.format(attrname, classname),)
                  raise
              # units are compared in simplified form; the message shows them
              # as written
              dims1 = value1.dimensionality
              dims2 = value2.dimensionality
              assert dims1.simplified == dims2.simplified, \
                  'Attribute %s of %s are not the same: %s != %s' % (
                      attrname, classname, dims1.string, dims2.string)
          elif attrtype == np.ndarray:
              try:
                  assert_arrays_almost_equal(value1, value2,
                                             threshold=threshold, dtype=dtype)
              # intercept exceptions and add more information
              except BaseException as exc:
                  exc.args += ('from {} of {}'.format(attrname, classname),)
                  raise
          else:
              assert value1 == value2, 'Attribute %s.%s are not the same %s %s %s %s' % (
                  classname, attrname, type(value1), value1, type(value2), value2)
  def assert_same_annotations(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
      '''
      Assert that ob1 and ob2 have the same annotations.

      Both objects must have the same set of annotation keys (ignoring the keys
      in ``exclude``). numpy array values are compared with
      ``assert_arrays_almost_equal`` (dtypes are ignored); any other value is
      compared with ``==``.

      Arguments:
          equal_almost: if False array values must match exactly, if True
              they may differ by less than ``threshold``.
          threshold: tolerance used when ``equal_almost`` is True.
          exclude: a list of annotation keys to ignore in the comparison.
      '''
      if exclude is None:
          exclude = []
      if not equal_almost:
          # exact comparison of array values
          threshold = None
      classname = ob1.__class__.__name__
      annotations1 = ob1.annotations
      annotations2 = ob2.annotations

      for key in annotations2:
          if key in exclude:
              continue
          assert key in annotations1, '%s 2 has annotation %r but not %s 1' % (
              classname, key, classname)

      for key, value1 in annotations1.items():
          if key in exclude:
              continue
          assert key in annotations2, '%s 1 has annotation %r but not %s 2' % (
              classname, key, classname)
          value2 = annotations2[key]
          if isinstance(value1, np.ndarray):
              assert isinstance(value2, np.ndarray), \
                  'Annotation %r of %s is an array in 1 but of type %s in 2' % (
                      key, classname, type(value2).__name__)
              try:
                  assert_arrays_almost_equal(value1, value2, threshold=threshold, dtype=False)
              # intercept exceptions and add more information
              except BaseException as exc:
                  exc.args += ('from annotation {} of {}'.format(key, classname),)
                  raise
          else:
              assert value1 == value2, 'Annotation %r of %s differs: %r != %r' % (
                  key, classname, value1, value2)
  def assert_same_array_annotations(ob1, ob2, equal_almost=True, threshold=1e-10, exclude=None):
      '''
      Assert that ob1 and ob2 have the same array annotations.

      Both objects must have the same set of array annotation keys (ignoring
      the keys in ``exclude``). Each pair of values is first compared exactly
      with ``assert_arrays_equal``. If that fails, ``equal_almost`` is True
      and the values are float, complex or signed integer arrays, they are
      compared again with ``assert_arrays_almost_equal`` using ``threshold``
      (dtypes ignored).

      Arguments:
          equal_almost: if False values must match exactly, if True numeric
              values may differ by less than ``threshold``.
          threshold: tolerance used when ``equal_almost`` is True.
          exclude: a list of array annotation keys to ignore in the comparison.
      '''
      if exclude is None:
          exclude = []
      if not equal_almost:
          threshold = None
      classname = ob1.__class__.__name__
      annotations1 = ob1.array_annotations
      annotations2 = ob2.array_annotations

      for key in annotations2:
          if key in exclude:
              continue
          assert key in annotations1, '%s 2 has array annotation %r but not %s 1' % (
              classname, key, classname)

      for key, value1 in annotations1.items():
          if key in exclude:
              continue
          assert key in annotations2, '%s 1 has array annotation %r but not %s 2' % (
              classname, key, classname)
          value2 = annotations2[key]
          try:
              assert_arrays_equal(value1, value2)
          except (AssertionError, ValueError):
              # Fall back to a tolerance-based comparison only when one was
              # requested and the data is numeric: assert_arrays_almost_equal
              # does not compare the values of non-numeric arrays, so falling
              # back for them would hide real differences.
              numeric = (isinstance(value1, np.ndarray)
                         and value1.dtype.kind in ('f', 'c', 'i'))
              if threshold is None or not numeric:
                  raise
              assert_arrays_almost_equal(value1, value2, threshold=threshold, dtype=False)
  def assert_sub_schema_is_lazy_loaded(ob):
      '''
      Assert that every data object reachable from ``ob`` is lazily loaded.

      ``Container`` instances are explored recursively through their
      ``_single_child_containers``. Any other object must be an instance of a
      class in ``proxyobjectlist``. It is then loaded, and the loaded object
      is checked for neo compliance and for having the same annotations and
      array annotations as the proxy. ``labels`` are not compared for event
      proxies; ``labels`` and ``durations`` are not compared for epoch proxies.

      Raises:
          AssertionError: on the first failing check. Errors raised for a child
              get ``'from <container> <index> of <class>'`` appended to their
              ``args``.
      '''
      classname = ob.__class__.__name__
      if isinstance(ob, Container):
          for container in getattr(ob, '_single_child_containers', []):
              if not hasattr(ob, container):
                  continue
              for i, child in enumerate(getattr(ob, container)):
                  try:
                      assert_sub_schema_is_lazy_loaded(child)
                  # intercept exceptions and add more information
                  except BaseException as exc:
                      exc.args += ('from {} {} of {}'.format(container, i, classname),)
                      raise
          return

      # '%s' (not a bare '%') so a failure reports the class name instead of
      # raising "ValueError: incomplete format"
      assert ob.__class__ in proxyobjectlist, 'Data object must be lazy %s' % classname
      loaded_ob = ob.load()
      assert_neo_object_is_compliant(loaded_ob)
      assert_same_annotations(ob, loaded_ob)

      if isinstance(ob, EventProxy):
          exclude = ['labels']
      elif isinstance(ob, EpochProxy):
          exclude = ['labels', 'durations']
      else:
          exclude = []
      assert_same_array_annotations(ob, loaded_ob, exclude=exclude)
  def assert_objects_equivalent(obj1, obj2):
      '''
      Compare two neo objects attribute by attribute and annotation by
      annotation. No relationships are involved.

      Attributes in ``_necessary_attrs``, plus those in ``_recommended_attrs``
      that either object has, are compared through their md5 digests. Their
      values must therefore be bytes-like objects (e.g. numpy arrays).
      Every annotation of ``obj1`` must exist in ``obj2`` with an equal value.
      '''
      def assert_attr(obj1, obj2, attr_name):
          '''
          Assert that both objects have ``attr_name`` and that the md5 digests
          of the two values are the same.
          '''
          assert hasattr(obj1, attr_name)
          attr1 = hashlib.md5(getattr(obj1, attr_name)).hexdigest()
          assert hasattr(obj2, attr_name)
          attr2 = hashlib.md5(getattr(obj2, attr_name)).hexdigest()
          assert attr1 == attr2, "Attribute %s for class %s is not equal." \
                                 "" % (attr_name, obj1.__class__.__name__)

      obj_type = obj1.__class__.__name__
      assert obj_type == obj2.__class__.__name__

      for ioattr in obj1._necessary_attrs:
          assert_attr(obj1, obj2, ioattr[0])
      for ioattr in obj1._recommended_attrs:
          if hasattr(obj1, ioattr[0]) or hasattr(obj2, ioattr[0]):
              assert_attr(obj1, obj2, ioattr[0])

      if hasattr(obj1, "annotations"):
          assert hasattr(obj2, "annotations")
          # annotations are a dict: iterate its items and test key membership
          # (iterating the dict directly yields keys only, and hasattr() on a
          # dict does not look at its keys)
          for key, value in obj1.annotations.items():
              assert key in obj2.annotations, \
                  'Annotation %r of %s is missing from the second object' % (key, obj_type)
              assert obj2.annotations[key] == value, \
                  'Annotation %r of %s differs: %r != %r' % (
                      key, obj_type, value, obj2.annotations[key])
  def assert_children_empty(obj, parent):
      '''
      Assert that a neo object has no children. Used to check that the
      cascade is implemented properly (``cascade=False`` must return an
      object without children).

      Arguments:
          obj: the object to check; objects without a ``children`` attribute
              pass trivially.
          parent: an object (presumably the reader class) whose ``__name__``
              is used in the error message. It is only accessed on failure.
      '''
      if hasattr(obj, 'children'):
          assert not obj.children, \
              '%s reader with cascade=False should return empty children' % parent.__name__