1
1

test_change_point_detection.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # -*- coding: utf-8 -*-
  2. import neo
  3. import numpy as np
  4. import quantities as pq
  5. import unittest
  6. import elephant.change_point_detection as mft
  7. from numpy.testing.utils import assert_array_almost_equal, assert_allclose
  8. #np.random.seed(13)
  9. class FilterTestCase(unittest.TestCase):
  10. def setUp(self):
  11. self.test_array = [0.4, 0.5, 0.65, 0.7, 0.9, 1.15, 1.2, 1.9]
  12. '''
  13. spks_ri = [0.9, 1.15, 1.2]
  14. spk_le = [0.4, 0.5, 0.65, 0.7]
  15. '''
  16. mu_ri = (0.25 + 0.05) / 2
  17. mu_le = (0.1 + 0.15 + 0.05) / 3
  18. sigma_ri = ((0.25 - 0.15) ** 2 + (0.05 - 0.15) ** 2) / 2
  19. sigma_le = ((0.1 - 0.1) ** 2 + (0.15 - 0.1) ** 2 + (
  20. 0.05 - 0.1) ** 2) / 3
  21. self.targ_t08_h025 = 0
  22. self.targ_t08_h05 = (3 - 4) / np.sqrt(
  23. (sigma_ri / mu_ri ** (3)) * 0.5 + (sigma_le / mu_le ** (3)) * 0.5)
  24. # Window Large #
  25. def test_filter_with_spiketrain_h05(self):
  26. st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.0)
  27. target = self.targ_t08_h05
  28. res = mft._filter(0.8 * pq.s, 0.5 * pq.s, st)
  29. assert_array_almost_equal(res, target, decimal=9)
  30. self.assertRaises(ValueError, mft._filter, 0.8, 0.5 * pq.s, st)
  31. self.assertRaises(ValueError, mft._filter, 0.8 * pq.s, 0.5, st)
  32. self.assertRaises(ValueError, mft._filter, 0.8 * pq.s, 0.5 * pq.s,
  33. self.test_array)
  34. # Window Small #
  35. def test_filter_with_spiketrain_h025(self):
  36. st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.0)
  37. target = self.targ_t08_h025
  38. res = mft._filter(0.8 * pq.s, 0.25 * pq.s, st)
  39. assert_array_almost_equal(res, target, decimal=9)
  40. def test_filter_with_quantities_h025(self):
  41. st = pq.Quantity(self.test_array, units='s')
  42. target = self.targ_t08_h025
  43. res = mft._filter(0.8 * pq.s, 0.25 * pq.s, st)
  44. assert_array_almost_equal(res, target, decimal=9)
  45. def test_filter_with_plain_array_h025(self):
  46. st = self.test_array
  47. target = self.targ_t08_h025
  48. res = mft._filter(0.8 * pq.s, 0.25 * pq.s, st * pq.s)
  49. assert_array_almost_equal(res, target, decimal=9)
  50. def test_isi_with_quantities_h05(self):
  51. st = pq.Quantity(self.test_array, units='s')
  52. target = self.targ_t08_h05
  53. res = mft._filter(0.8 * pq.s, 0.5 * pq.s, st)
  54. assert_array_almost_equal(res, target, decimal=9)
  55. def test_isi_with_plain_array_h05(self):
  56. st = self.test_array
  57. target = self.targ_t08_h05
  58. res = mft._filter(0.8 * pq.s, 0.5 * pq.s, st * pq.s)
  59. assert_array_almost_equal(res, target, decimal=9)
  60. class FilterProcessTestCase(unittest.TestCase):
  61. def setUp(self):
  62. self.test_array = [1.1, 1.2, 1.4, 1.6, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95]
  63. x = (7 - 3) / np.sqrt(
  64. (0.0025 / 0.15 ** 3) * 0.5 + (0.0003472 / 0.05833 ** 3) * 0.5)
  65. self.targ_h05 = [[0.5, 1, 1.5],
  66. [(0 - 1.7) / np.sqrt(0.4), (0 - 1.7) / np.sqrt(0.4),
  67. (x - 1.7) / np.sqrt(0.4)]]
  68. def test_filter_process_with_spiketrain_h05(self):
  69. st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.1)
  70. target = self.targ_h05
  71. res = mft._filter_process(0.5 * pq.s, 0.5 * pq.s, st, 2.01 * pq.s,
  72. np.array([[0.5], [1.7], [0.4]]))
  73. assert_array_almost_equal(res[1], target[1], decimal=3)
  74. self.assertRaises(ValueError, mft._filter_process, 0.5 , 0.5 * pq.s,
  75. st, 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]]))
  76. self.assertRaises(ValueError, mft._filter_process, 0.5 * pq.s, 0.5,
  77. st, 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]]))
  78. self.assertRaises(ValueError, mft._filter_process, 0.5 * pq.s,
  79. 0.5 * pq.s, self.test_array, 2.01 * pq.s,
  80. np.array([[0.5], [1.7], [0.4]]))
  81. def test_filter_proces_with_quantities_h05(self):
  82. st = pq.Quantity(self.test_array, units='s')
  83. target = self.targ_h05
  84. res = mft._filter_process(0.5 * pq.s, 0.5 * pq.s, st, 2.01 * pq.s,
  85. np.array([[0.5], [1.7], [0.4]]))
  86. assert_array_almost_equal(res[0], target[0], decimal=3)
  87. def test_filter_proces_with_plain_array_h05(self):
  88. st = self.test_array
  89. target = self.targ_h05
  90. res = mft._filter_process(0.5 * pq.s, 0.5 * pq.s, st * pq.s,
  91. 2.01 * pq.s, np.array([[0.5], [1.7], [0.4]]))
  92. self.assertNotIsInstance(res, pq.Quantity)
  93. assert_array_almost_equal(res, target, decimal=3)
  94. class MultipleFilterAlgorithmTestCase(unittest.TestCase):
  95. def setUp(self):
  96. self.test_array = [1.1, 1.2, 1.4, 1.6, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95]
  97. self.targ_h05_dt05 = [1.5 * pq.s]
  98. # to speed up the test, the following `test_param` and `test_quantile`
  99. # paramters have been calculated offline using the function:
  100. # empirical_parameters([10, 25, 50, 75, 100, 125, 150]*pq.s,700*pq.s,5,
  101. # 10000)
  102. # the user should do the same, if the metohd has to be applied to several
  103. # spike trains of the same length `T` and with the same set of window.
  104. self.test_param = np.array([[10., 25., 50., 75., 100., 125., 150.],
  105. [3.167, 2.955, 2.721, 2.548, 2.412, 2.293, 2.180],
  106. [0.150, 0.185, 0.224, 0.249, 0.269, 0.288, 0.301]])
  107. self.test_quantile = 2.75
  108. def test_MultipleFilterAlgorithm_with_spiketrain_h05(self):
  109. st = neo.SpikeTrain(self.test_array, units='s', t_stop=2.1)
  110. target = [self.targ_h05_dt05]
  111. res = mft.multiple_filter_test([0.5] * pq.s, st, 2.1 * pq.s, 5, 100,
  112. dt=0.1 * pq.s)
  113. assert_array_almost_equal(res, target, decimal=9)
  114. def test_MultipleFilterAlgorithm_with_quantities_h05(self):
  115. st = pq.Quantity(self.test_array, units='s')
  116. target = [self.targ_h05_dt05]
  117. res = mft.multiple_filter_test([0.5] * pq.s, st, 2.1 * pq.s, 5, 100,
  118. dt=0.5 * pq.s)
  119. assert_array_almost_equal(res, target, decimal=9)
  120. def test_MultipleFilterAlgorithm_with_plain_array_h05(self):
  121. st = self.test_array
  122. target = [self.targ_h05_dt05]
  123. res = mft.multiple_filter_test([0.5] * pq.s, st * pq.s, 2.1 * pq.s, 5,
  124. 100, dt=0.5 * pq.s)
  125. self.assertNotIsInstance(res, pq.Quantity)
  126. assert_array_almost_equal(res, target, decimal=9)
  127. def test_MultipleFilterAlgorithm_with_longdata(self):
  128. def gamma_train(k, teta, tmax):
  129. x = np.random.gamma(k, teta, int(tmax * (k * teta) ** (-1) * 3))
  130. s = np.cumsum(x)
  131. idx = np.where(s < tmax)
  132. s = s[idx] # gamma process
  133. return s
  134. def alternative_hypothesis(k1, teta1, c1, k2, teta2, c2, k3, teta3, c3,
  135. k4, teta4, T):
  136. s1 = gamma_train(k1, teta1, c1)
  137. s2 = gamma_train(k2, teta2, c2) + c1
  138. s3 = gamma_train(k3, teta3, c3) + c1 + c2
  139. s4 = gamma_train(k4, teta4, T) + c1 + c2 + c3
  140. return np.concatenate((s1, s2, s3, s4)), [s1[-1], s2[-1], s3[-1]]
  141. st = self.h1 = alternative_hypothesis(1, 1 / 4., 150, 2, 1 / 26., 30,
  142. 1, 1 / 36., 320,
  143. 2, 1 / 33., 200)[0]
  144. window_size = [10, 25, 50, 75, 100, 125, 150] * pq.s
  145. self.target_points = [150, 180, 500]
  146. target = self.target_points
  147. result = mft.multiple_filter_test(window_size, st * pq.s, 700 * pq.s, 5,
  148. 10000, test_quantile=self.test_quantile, test_param=self.test_param,
  149. dt=1 * pq.s)
  150. self.assertNotIsInstance(result, pq.Quantity)
  151. result_concatenated = []
  152. for i in result:
  153. result_concatenated = np.hstack([result_concatenated, i])
  154. result_concatenated = np.sort(result_concatenated)
  155. assert_allclose(result_concatenated[:3], target[:3], rtol=0,
  156. atol=5)
  157. print('detected {0} cps: {1}'.format(len(result_concatenated),
  158. result_concatenated))
  159. if __name__ == '__main__':
  160. unittest.main()