test_kcsd.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
# -*- coding: utf-8 -*-
"""
Unit tests for the kCSD methods.

This was written by:
Chaitanya Chintaluri,
Laboratory of Neuroinformatics,
Nencki Institute of Experimental Biology, Warsaw.

:license: Modified BSD, see LICENSE.txt for details.
"""
  10. import unittest
  11. import neo
  12. import numpy as np
  13. import quantities as pq
  14. from elephant import current_source_density as CSD
  15. import elephant.current_source_density_src.utility_functions as utils
  16. class KCSD1D_TestCase(unittest.TestCase):
  17. def setUp(self):
  18. self.ele_pos = utils.generate_electrodes(dim=1).reshape(5, 1)
  19. self.csd_profile = utils.gauss_1d_dipole
  20. pots = CSD.generate_lfp(self.csd_profile, self.ele_pos)
  21. self.pots = np.reshape(pots, (-1, 1))
  22. self.test_method = 'KCSD1D'
  23. self.test_params = {'h': 50.}
  24. temp_signals = []
  25. for ii in range(len(self.pots)):
  26. temp_signals.append(self.pots[ii])
  27. self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV,
  28. sampling_rate=1000 * pq.Hz)
  29. chidx = neo.ChannelIndex(range(len(self.pots)))
  30. chidx.analogsignals.append(self.an_sigs)
  31. chidx.coordinates = self.ele_pos * pq.mm
  32. chidx.create_relationship()
  33. def test_kcsd1d_estimate(self, cv_params={}):
  34. self.test_params.update(cv_params)
  35. result = CSD.estimate_csd(self.an_sigs, method=self.test_method,
  36. **self.test_params)
  37. self.assertEqual(result.t_start, 0.0 * pq.s)
  38. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  39. self.assertEqual(result.times, [0.] * pq.s)
  40. self.assertEqual(len(result.annotations.keys()), 1)
  41. true_csd = self.csd_profile(result.annotations['x_coords'])
  42. rms = np.linalg.norm(np.array(result[0, :]) - true_csd)
  43. rms /= np.linalg.norm(true_csd)
  44. self.assertLess(rms, 0.5, msg='RMS between trueCSD and estimate > 0.5')
  45. def test_valid_inputs(self):
  46. self.test_method = 'InvalidMethodName'
  47. self.assertRaises(ValueError, self.test_kcsd1d_estimate)
  48. self.test_method = 'KCSD1D'
  49. self.test_params = {'src_type': 22}
  50. self.assertRaises(KeyError, self.test_kcsd1d_estimate)
  51. self.test_method = 'KCSD1D'
  52. self.test_params = {'InvalidKwarg': 21}
  53. self.assertRaises(TypeError, self.test_kcsd1d_estimate)
  54. cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))}
  55. self.assertRaises(TypeError, self.test_kcsd1d_estimate, cv_params)
  56. class KCSD2D_TestCase(unittest.TestCase):
  57. def setUp(self):
  58. xx_ele, yy_ele = utils.generate_electrodes(dim=2, res=9,
  59. xlims=[0.05, 0.95],
  60. ylims=[0.05, 0.95])
  61. self.ele_pos = np.vstack((xx_ele, yy_ele)).T
  62. self.csd_profile = utils.large_source_2D
  63. pots = CSD.generate_lfp(self.csd_profile, xx_ele, yy_ele, res=100)
  64. self.pots = np.reshape(pots, (-1, 1))
  65. self.test_method = 'KCSD2D'
  66. self.test_params = {'gdx': 0.25, 'gdy': 0.25, 'R_init': 0.08,
  67. 'h': 50., 'xmin': 0., 'xmax': 1.,
  68. 'ymin': 0., 'ymax': 1.}
  69. temp_signals = []
  70. for ii in range(len(self.pots)):
  71. temp_signals.append(self.pots[ii])
  72. self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV,
  73. sampling_rate=1000 * pq.Hz)
  74. chidx = neo.ChannelIndex(range(len(self.pots)))
  75. chidx.analogsignals.append(self.an_sigs)
  76. chidx.coordinates = self.ele_pos * pq.mm
  77. chidx.create_relationship()
  78. def test_kcsd2d_estimate(self, cv_params={}):
  79. self.test_params.update(cv_params)
  80. result = CSD.estimate_csd(self.an_sigs, method=self.test_method,
  81. **self.test_params)
  82. self.assertEqual(result.t_start, 0.0 * pq.s)
  83. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  84. self.assertEqual(result.times, [0.] * pq.s)
  85. self.assertEqual(len(result.annotations.keys()), 2)
  86. true_csd = self.csd_profile(result.annotations['x_coords'],
  87. result.annotations['y_coords'])
  88. rms = np.linalg.norm(np.array(result[0, :]) - true_csd)
  89. rms /= np.linalg.norm(true_csd)
  90. self.assertLess(rms, 0.5, msg='RMS ' + str(rms) +
  91. 'between trueCSD and estimate > 0.5')
  92. def test_moi_estimate(self):
  93. result = CSD.estimate_csd(self.an_sigs, method='MoIKCSD',
  94. MoI_iters=10, lambd=0.0,
  95. gdx=0.2, gdy=0.2)
  96. self.assertEqual(result.t_start, 0.0 * pq.s)
  97. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  98. self.assertEqual(result.times, [0.] * pq.s)
  99. self.assertEqual(len(result.annotations.keys()), 2)
  100. def test_valid_inputs(self):
  101. self.test_method = 'InvalidMethodName'
  102. self.assertRaises(ValueError, self.test_kcsd2d_estimate)
  103. self.test_method = 'KCSD2D'
  104. self.test_params = {'src_type': 22}
  105. self.assertRaises(KeyError, self.test_kcsd2d_estimate)
  106. self.test_params = {'InvalidKwarg': 21}
  107. self.assertRaises(TypeError, self.test_kcsd2d_estimate)
  108. cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))}
  109. self.assertRaises(TypeError, self.test_kcsd2d_estimate, cv_params)
  110. class KCSD3D_TestCase(unittest.TestCase):
  111. def setUp(self):
  112. xx_ele, yy_ele, zz_ele = utils.generate_electrodes(dim=3, res=5,
  113. xlims=[0.15, 0.85],
  114. ylims=[0.15, 0.85],
  115. zlims=[0.15, 0.85])
  116. self.ele_pos = np.vstack((xx_ele, yy_ele, zz_ele)).T
  117. self.csd_profile = utils.gauss_3d_dipole
  118. pots = CSD.generate_lfp(self.csd_profile, xx_ele, yy_ele, zz_ele)
  119. self.pots = np.reshape(pots, (-1, 1))
  120. self.test_method = 'KCSD3D'
  121. self.test_params = {'gdx': 0.05, 'gdy': 0.05, 'gdz': 0.05,
  122. 'lambd': 5.10896977451e-19, 'src_type': 'step',
  123. 'R_init': 0.31, 'xmin': 0., 'xmax': 1., 'ymin': 0.,
  124. 'ymax': 1., 'zmin': 0., 'zmax': 1.}
  125. temp_signals = []
  126. for ii in range(len(self.pots)):
  127. temp_signals.append(self.pots[ii])
  128. self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV,
  129. sampling_rate=1000 * pq.Hz)
  130. chidx = neo.ChannelIndex(range(len(self.pots)))
  131. chidx.analogsignals.append(self.an_sigs)
  132. chidx.coordinates = self.ele_pos * pq.mm
  133. chidx.create_relationship()
  134. def test_kcsd3d_estimate(self, cv_params={}):
  135. self.test_params.update(cv_params)
  136. result = CSD.estimate_csd(self.an_sigs, method=self.test_method,
  137. **self.test_params)
  138. self.assertEqual(result.t_start, 0.0 * pq.s)
  139. self.assertEqual(result.sampling_rate, 1000 * pq.Hz)
  140. self.assertEqual(result.times, [0.] * pq.s)
  141. self.assertEqual(len(result.annotations.keys()), 3)
  142. true_csd = self.csd_profile(result.annotations['x_coords'],
  143. result.annotations['y_coords'],
  144. result.annotations['z_coords'])
  145. rms = np.linalg.norm(np.array(result[0, :]) - true_csd)
  146. rms /= np.linalg.norm(true_csd)
  147. self.assertLess(rms, 0.5, msg='RMS ' + str(rms) +
  148. ' between trueCSD and estimate > 0.5')
  149. def test_valid_inputs(self):
  150. self.test_method = 'InvalidMethodName'
  151. self.assertRaises(ValueError, self.test_kcsd3d_estimate)
  152. self.test_method = 'KCSD3D'
  153. self.test_params = {'src_type': 22}
  154. self.assertRaises(KeyError, self.test_kcsd3d_estimate)
  155. self.test_params = {'InvalidKwarg': 21}
  156. self.assertRaises(TypeError, self.test_kcsd3d_estimate)
  157. cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))}
  158. self.assertRaises(TypeError, self.test_kcsd3d_estimate, cv_params)
  159. if __name__ == '__main__':
  160. unittest.main()