test_phase_analysis.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # -*- coding: utf-8 -*-
  2. """
  3. Unit tests for the phase analysis module.
  4. :copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division, print_function
  8. import unittest
  9. from neo import SpikeTrain, AnalogSignal
  10. import numpy as np
  11. import quantities as pq
  12. import elephant.phase_analysis
  13. from numpy.ma.testutils import assert_allclose
  14. class SpikeTriggeredPhaseTestCase(unittest.TestCase):
  15. def setUp(self):
  16. tlen0 = 100 * pq.s
  17. f0 = 20. * pq.Hz
  18. fs0 = 1 * pq.ms
  19. t0 = np.arange(
  20. 0, tlen0.rescale(pq.s).magnitude,
  21. fs0.rescale(pq.s).magnitude) * pq.s
  22. self.anasig0 = AnalogSignal(
  23. np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
  24. units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
  25. self.st0 = SpikeTrain(
  26. np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms,
  27. t_start=0 * pq.ms, t_stop=tlen0)
  28. self.st1 = SpikeTrain(
  29. [100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms,
  30. t_start=0 * pq.ms, t_stop=tlen0)
  31. def test_perfect_locking_one_spiketrain_one_signal(self):
  32. phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
  33. elephant.signal_processing.hilbert(self.anasig0),
  34. self.st0,
  35. interpolate=True)
  36. assert_allclose(phases[0], - np.pi / 2.)
  37. assert_allclose(amps[0], 1, atol=0.1)
  38. assert_allclose(times[0].magnitude, self.st0.magnitude)
  39. self.assertEqual(len(phases[0]), len(self.st0))
  40. self.assertEqual(len(amps[0]), len(self.st0))
  41. self.assertEqual(len(times[0]), len(self.st0))
  42. def test_perfect_locking_many_spiketrains_many_signals(self):
  43. phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
  44. [
  45. elephant.signal_processing.hilbert(self.anasig0),
  46. elephant.signal_processing.hilbert(self.anasig0)],
  47. [self.st0, self.st0],
  48. interpolate=True)
  49. assert_allclose(phases[0], -np.pi / 2.)
  50. assert_allclose(amps[0], 1, atol=0.1)
  51. assert_allclose(times[0].magnitude, self.st0.magnitude)
  52. self.assertEqual(len(phases[0]), len(self.st0))
  53. self.assertEqual(len(amps[0]), len(self.st0))
  54. self.assertEqual(len(times[0]), len(self.st0))
  55. def test_perfect_locking_one_spiketrains_many_signals(self):
  56. phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
  57. [
  58. elephant.signal_processing.hilbert(self.anasig0),
  59. elephant.signal_processing.hilbert(self.anasig0)],
  60. [self.st0],
  61. interpolate=True)
  62. assert_allclose(phases[0], -np.pi / 2.)
  63. assert_allclose(amps[0], 1, atol=0.1)
  64. assert_allclose(times[0].magnitude, self.st0.magnitude)
  65. self.assertEqual(len(phases[0]), len(self.st0))
  66. self.assertEqual(len(amps[0]), len(self.st0))
  67. self.assertEqual(len(times[0]), len(self.st0))
  68. def test_perfect_locking_many_spiketrains_one_signal(self):
  69. phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
  70. elephant.signal_processing.hilbert(self.anasig0),
  71. [self.st0, self.st0],
  72. interpolate=True)
  73. assert_allclose(phases[0], -np.pi / 2.)
  74. assert_allclose(amps[0], 1, atol=0.1)
  75. assert_allclose(times[0].magnitude, self.st0.magnitude)
  76. self.assertEqual(len(phases[0]), len(self.st0))
  77. self.assertEqual(len(amps[0]), len(self.st0))
  78. self.assertEqual(len(times[0]), len(self.st0))
  79. def test_interpolate(self):
  80. phases_int, _, _ = elephant.phase_analysis.spike_triggered_phase(
  81. elephant.signal_processing.hilbert(self.anasig0),
  82. self.st1,
  83. interpolate=True)
  84. self.assertLess(phases_int[0][0], phases_int[0][1])
  85. self.assertLess(phases_int[0][1], phases_int[0][2])
  86. self.assertLess(phases_int[0][2], phases_int[0][3])
  87. self.assertLess(phases_int[0][3], phases_int[0][4])
  88. self.assertLess(phases_int[0][4], phases_int[0][5])
  89. phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
  90. elephant.signal_processing.hilbert(self.anasig0),
  91. self.st1,
  92. interpolate=False)
  93. self.assertEqual(phases_noint[0][0], phases_noint[0][1])
  94. self.assertEqual(phases_noint[0][1], phases_noint[0][2])
  95. self.assertEqual(phases_noint[0][2], phases_noint[0][3])
  96. self.assertEqual(phases_noint[0][3], phases_noint[0][4])
  97. self.assertNotEqual(phases_noint[0][4], phases_noint[0][5])
  98. # Verify that when using interpolation and the spike sits on the sample
  99. # of the Hilbert transform, this is the same result as when not using
  100. # interpolation with a spike slightly to the right
  101. self.assertEqual(phases_noint[0][2], phases_int[0][0])
  102. self.assertEqual(phases_noint[0][4], phases_int[0][0])
  103. def test_inconsistent_numbers_spiketrains_hilbert(self):
  104. self.assertRaises(
  105. ValueError, elephant.phase_analysis.spike_triggered_phase,
  106. [
  107. elephant.signal_processing.hilbert(self.anasig0),
  108. elephant.signal_processing.hilbert(self.anasig0)],
  109. [self.st0, self.st0, self.st0], False)
  110. self.assertRaises(
  111. ValueError, elephant.phase_analysis.spike_triggered_phase,
  112. [
  113. elephant.signal_processing.hilbert(self.anasig0),
  114. elephant.signal_processing.hilbert(self.anasig0)],
  115. [self.st0, self.st0, self.st0], False)
  116. def test_spike_earlier_than_hilbert(self):
  117. # This is a spike clearly outside the bounds
  118. st = SpikeTrain(
  119. [-50, 50],
  120. units='s', t_start=-100*pq.s, t_stop=100*pq.s)
  121. phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
  122. elephant.signal_processing.hilbert(self.anasig0),
  123. st,
  124. interpolate=False)
  125. self.assertEqual(len(phases_noint[0]), 1)
  126. # This is a spike right on the border (start of the signal is at 0s,
  127. # spike sits at t=0s). By definition of intervals in
  128. # Elephant (left borders inclusive, right borders exclusive), this
  129. # spike is to be considered.
  130. st = SpikeTrain(
  131. [0, 50],
  132. units='s', t_start=-100*pq.s, t_stop=100*pq.s)
  133. phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
  134. elephant.signal_processing.hilbert(self.anasig0),
  135. st,
  136. interpolate=False)
  137. self.assertEqual(len(phases_noint[0]), 2)
  138. def test_spike_later_than_hilbert(self):
  139. # This is a spike clearly outside the bounds
  140. st = SpikeTrain(
  141. [1, 250],
  142. units='s', t_start=-1*pq.s, t_stop=300*pq.s)
  143. phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
  144. elephant.signal_processing.hilbert(self.anasig0),
  145. st,
  146. interpolate=False)
  147. self.assertEqual(len(phases_noint[0]), 1)
  148. # This is a spike right on the border (length of the signal is 100s,
  149. # spike sits at t=100s). However, by definition of intervals in
  150. # Elephant (left borders inclusive, right borders exclusive), this
  151. # spike is not to be considered.
  152. st = SpikeTrain(
  153. [1, 100],
  154. units='s', t_start=-1*pq.s, t_stop=200*pq.s)
  155. phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
  156. elephant.signal_processing.hilbert(self.anasig0),
  157. st,
  158. interpolate=False)
  159. self.assertEqual(len(phases_noint[0]), 1)
  160. if __name__ == '__main__':
  161. unittest.main()