test_converter.py 4.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
"""
Tests of the neo.converter module.
"""
  4. import unittest
  5. import copy
  6. import numpy as np
  7. from neo.io.proxyobjects import (AnalogSignalProxy, SpikeTrainProxy,
  8. EventProxy, EpochProxy)
  9. from neo.core import (Epoch, Event, SpikeTrain)
  10. from neo.core.basesignal import BaseSignal
  11. from neo.test.tools import (assert_arrays_equal, assert_same_attributes)
  12. from neo.test.generate_datasets import fake_neo
  13. from neo.converter import convert_channelindex_to_view_group
  14. class ConversionTest(unittest.TestCase):
  def setUp(self):
      """Create a fake block, snapshot it, then convert it to the view/group layout.

      ``self.old_block`` is a deep copy taken before conversion, so the tests
      can compare the original ChannelIndex-based structure with
      ``self.new_block``, the value returned by
      ``convert_channelindex_to_view_group``.
      """
      fake_block = fake_neo(n=3)
      # Snapshot before converting: presumably the converter may modify its
      # argument in place, so the pre-conversion structure is kept separately.
      self.old_block = copy.deepcopy(fake_block)
      self.new_block = convert_channelindex_to_view_group(fake_block)
  def test_no_deprecated_attributes(self):
      """Check that conversion removed the deprecated ChannelIndex/Unit links.

      The converted block must not expose ``channel_indexes``. Its signals
      (``BaseSignal`` instances) must not expose ``channel_index``, and its
      spike trains must not expose ``unit``. Events and epochs are accepted
      without further checks. Any other data object type raises ``TypeError``.
      """
      self.assertFalse(hasattr(self.new_block, 'channel_indexes'))

      # Flatten every data container of every segment, keeping segment order
      # and, within a segment, the container order below.
      data_objects = [
          obj
          for segment in self.new_block.segments
          for container in (segment.analogsignals,
                            segment.irregularlysampledsignals,
                            segment.events,
                            segment.epochs,
                            segment.spiketrains,
                            segment.imagesequences)
          for obj in container
      ]

      for obj in data_objects:
          if isinstance(obj, BaseSignal):
              deprecated_attr = 'channel_index'
          elif isinstance(obj, SpikeTrain):
              deprecated_attr = 'unit'
          elif isinstance(obj, (Event, Epoch)):
              # no deprecated link is checked for events and epochs
              continue
          else:
              raise TypeError(f'Unexpected data type object {type(obj)}')
          self.assertFalse(hasattr(obj, deprecated_attr))
  39. def test_block_conversion(self):
  40. # verify that all previous data is present in new structure
  41. groups = self.new_block.groups
  42. for channel_index in self.old_block.channel_indexes:
  43. # check existence of objects and attributes
  44. self.assertIn(channel_index.name, [g.name for g in groups])
  45. group = groups[[g.name for g in groups].index(channel_index.name)]
  46. # comparing group attributes to channel_index attributes
  47. assert_same_attributes(group, channel_index)
  48. self.assertDictEqual(channel_index.annotations, group.annotations)
  49. # comparing views and their attributes
  50. view_names = np.asarray([v.name for v in group.channelviews])
  51. matching_views = np.asarray(group.channelviews)[view_names == channel_index.name]
  52. for view in matching_views:
  53. self.assertIn('channel_ids', view.array_annotations)
  54. self.assertIn('channel_names', view.array_annotations)
  55. self.assertIn('coordinates_dim0', view.array_annotations)
  56. self.assertIn('coordinates_dim1', view.array_annotations)
  57. # check content of attributes
  58. assert_arrays_equal(channel_index.index, view.index)
  59. assert_arrays_equal(channel_index.channel_ids, view.array_annotations['channel_ids'])
  60. assert_arrays_equal(channel_index.channel_names,
  61. view.array_annotations['channel_names'])
  62. view_coordinates = np.vstack((view.array_annotations['coordinates_dim0'],
  63. view.array_annotations['coordinates_dim1'])).T
  64. # readd unit lost during stacking of arrays
  65. units = view.array_annotations['coordinates_dim0'].units
  66. view_coordinates = view_coordinates.magnitude * units
  67. assert_arrays_equal(channel_index.coordinates, view_coordinates)
  68. self.assertDictEqual(channel_index.annotations, view.annotations)
  69. # check linking between objects
  70. self.assertEqual(len(channel_index.data_children), len(matching_views))
  71. # check linking between objects
  72. for child in channel_index.data_children:
  73. # comparing names instead of objects as attributes differ
  74. self.assertIn(child.name, [v.obj.name for v in matching_views])
  75. group_names = np.asarray([g.name for g in group.groups])
  76. for unit in channel_index.units:
  77. self.assertIn(unit.name, group_names)
  78. unit_names = np.asarray([u.name for u in channel_index.units])
  79. matching_groups = np.isin(group_names, unit_names)
  80. self.assertEqual(len(channel_index.units), len(matching_groups))