test_csd.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # -*- coding: utf-8 -*-
  2. """
Unit tests for the current source density (CSD) estimation methods
(standard CSD, iCSD and kCSD).
Written by:
Chaitanya Chintaluri,
Laboratory of Neuroinformatics,
Nencki Institute of Experimental Biology, Warsaw.
  9. """
  10. import unittest
  11. import numpy as np
  12. import quantities as pq
  13. from elephant import current_source_density as csd
  14. import elephant.current_source_density_src.utility_functions as utils
  15. available_1d = ['StandardCSD', 'DeltaiCSD', 'StepiCSD', 'SplineiCSD', 'KCSD1D']
  16. available_2d = ['KCSD2D', 'MoIKCSD']
  17. available_3d = ['KCSD3D']
  18. kernel_methods = ['KCSD1D', 'KCSD2D', 'KCSD3D', 'MoIKCSD']
  19. icsd_methods = ['DeltaiCSD', 'StepiCSD', 'SplineiCSD']
  20. py_iCSD_toolbox = ['StandardCSD', 'DeltaiCSD', 'StepiCSD', 'SplineiCSD']
  21. class LFP_TestCase(unittest.TestCase):
  22. def test_lfp1d_electrodes(self):
  23. ele_pos = utils.generate_electrodes(dim=1).reshape(5, 1)
  24. lfp = csd.generate_lfp(utils.gauss_1d_dipole, ele_pos)
  25. self.assertEqual(ele_pos.shape[1], 1)
  26. self.assertEqual(ele_pos.shape[0], len(lfp))
  27. def test_lfp2d_electrodes(self):
  28. ele_pos = utils.generate_electrodes(dim=2)
  29. xx_ele, yy_ele = ele_pos
  30. lfp = csd.generate_lfp(utils.large_source_2D, xx_ele, yy_ele)
  31. self.assertEqual(len(ele_pos), 2)
  32. self.assertEqual(xx_ele.shape[0], len(lfp))
  33. def test_lfp3d_electrodes(self):
  34. ele_pos = utils.generate_electrodes(dim=3, res=3)
  35. xx_ele, yy_ele, zz_ele = ele_pos
  36. lfp = csd.generate_lfp(utils.gauss_3d_dipole, xx_ele, yy_ele, zz_ele)
  37. self.assertEqual(len(ele_pos), 3)
  38. self.assertEqual(xx_ele.shape[0], len(lfp))
  39. class CSD1D_TestCase(unittest.TestCase):
  40. def setUp(self):
  41. self.ele_pos = utils.generate_electrodes(dim=1).reshape(5, 1)
  42. self.lfp = csd.generate_lfp(utils.gauss_1d_dipole, self.ele_pos)
  43. self.csd_method = csd.estimate_csd
  44. self.params = {} # Input dictionaries for each method
  45. self.params['DeltaiCSD'] = {'sigma_top': 0. * pq.S / pq.m,
  46. 'diam': 500E-6 * pq.m}
  47. self.params['StepiCSD'] = {'sigma_top': 0. * pq.S / pq.m, 'tol': 1E-12,
  48. 'diam': 500E-6 * pq.m}
  49. self.params['SplineiCSD'] = {'sigma_top': 0. * pq.S / pq.m,
  50. 'num_steps': 201, 'tol': 1E-12,
  51. 'diam': 500E-6 * pq.m}
  52. self.params['StandardCSD'] = {}
  53. self.params['KCSD1D'] = {'h': 50., 'Rs': np.array((0.1, 0.25, 0.5))}
  54. def test_validate_inputs(self):
  55. self.assertRaises(TypeError, self.csd_method, lfp=[[1], [2], [3]])
  56. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  57. coords=self.ele_pos * pq.mm)
  58. # inconsistent number of electrodes
  59. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  60. coords=[1, 2, 3, 4] * pq.mm, method='StandardCSD')
  61. # bad method name
  62. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  63. method='InvalidMethodName')
  64. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  65. method='KCSD2D')
  66. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  67. method='KCSD3D')
  68. def test_inputs_standardcsd(self):
  69. method = 'StandardCSD'
  70. result = self.csd_method(self.lfp, method=method)
  71. self.assertEqual(result.t_start, 0.0 * pq.s)
  72. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  73. self.assertEqual(len(result.times), 1)
  74. def test_inputs_deltasplineicsd(self):
  75. methods = ['DeltaiCSD', 'SplineiCSD']
  76. for method in methods:
  77. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  78. method=method)
  79. result = self.csd_method(self.lfp, method=method,
  80. **self.params[method])
  81. self.assertEqual(result.t_start, 0.0 * pq.s)
  82. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  83. self.assertEqual(len(result.times), 1)
  84. def test_inputs_stepicsd(self):
  85. method = 'StepiCSD'
  86. self.assertRaises(ValueError, self.csd_method, lfp=self.lfp,
  87. method=method)
  88. self.assertRaises(AssertionError, self.csd_method, lfp=self.lfp,
  89. method=method, **self.params[method])
  90. self.params['StepiCSD'].update({'h': np.ones(5) * 100E-6 * pq.m})
  91. result = self.csd_method(self.lfp, method=method,
  92. **self.params[method])
  93. self.assertEqual(result.t_start, 0.0 * pq.s)
  94. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  95. self.assertEqual(len(result.times), 1)
  96. def test_inuts_kcsd(self):
  97. method = 'KCSD1D'
  98. result = self.csd_method(self.lfp, method=method,
  99. **self.params[method])
  100. self.assertEqual(result.t_start, 0.0 * pq.s)
  101. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  102. self.assertEqual(len(result.times), 1)
  103. class CSD2D_TestCase(unittest.TestCase):
  104. def setUp(self):
  105. xx_ele, yy_ele = utils.generate_electrodes(dim=2)
  106. self.lfp = csd.generate_lfp(utils.large_source_2D, xx_ele, yy_ele)
  107. self.params = {} # Input dictionaries for each method
  108. self.params['KCSD2D'] = {'sigma': 1., 'Rs': np.array((0.1, 0.25, 0.5))}
  109. def test_kcsd2d_init(self):
  110. method = 'KCSD2D'
  111. result = csd.estimate_csd(lfp=self.lfp, method=method,
  112. **self.params[method])
  113. self.assertEqual(result.t_start, 0.0 * pq.s)
  114. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  115. self.assertEqual(len(result.times), 1)
  116. class CSD3D_TestCase(unittest.TestCase):
  117. def setUp(self):
  118. xx_ele, yy_ele, zz_ele = utils.generate_electrodes(dim=3)
  119. self.lfp = csd.generate_lfp(utils.gauss_3d_dipole,
  120. xx_ele, yy_ele, zz_ele)
  121. self.params = {}
  122. self.params['KCSD3D'] = {'gdx': 0.1, 'gdy': 0.1, 'gdz': 0.1,
  123. 'src_type': 'step',
  124. 'Rs': np.array((0.1, 0.25, 0.5))}
  125. def test_kcsd2d_init(self):
  126. method = 'KCSD3D'
  127. result = csd.estimate_csd(lfp=self.lfp, method=method,
  128. **self.params[method])
  129. self.assertEqual(result.t_start, 0.0 * pq.s)
  130. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  131. self.assertEqual(len(result.times), 1)
  132. if __name__ == '__main__':
  133. unittest.main()