test_cell_assembly_detection.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
"""
Unit tests for elephant.cell_assembly_detection.
"""
  4. import unittest
  5. import numpy as np
  6. from numpy.testing.utils import assert_array_equal
  7. import neo
  8. import quantities as pq
  9. import elephant.conversion as conv
  10. import elephant.cell_assembly_detection as cad
  11. class CadTestCase(unittest.TestCase):
  12. def setUp(self):
  13. # Parameters
  14. self.binsize = 1*pq.ms
  15. self.alpha = 0.05
  16. self.size_chunks = 100
  17. self.maxlag = 10
  18. self.reference_lag = 2
  19. self.min_occ = 1
  20. self.max_spikes = np.inf
  21. self.significance_pruning = True
  22. self.subgroup_pruning = True
  23. self.flag_mypruning = False
  24. # Input parameters
  25. # Number of pattern occurrences
  26. self.n_occ1 = 150
  27. self.n_occ2 = 170
  28. self.n_occ3 = 210
  29. # Pattern lags
  30. self.lags1 = [0, 0.001]
  31. self.lags2 = [0, 0.002]
  32. self.lags3 = [0, 0.003]
  33. # Output pattern lags
  34. self.output_lags1 = [0, 1]
  35. self.output_lags2 = [0, 2]
  36. self.output_lags3 = [0, 3]
  37. # Length of the spiketrain
  38. self.t_start = 0
  39. self.t_stop = 1
  40. # Patterns times
  41. np.random.seed(1)
  42. self.patt1_times = neo.SpikeTrain(
  43. np.random.uniform(0, 1 - max(self.lags1), self.n_occ1) * pq.s,
  44. t_start=0*pq.s, t_stop=1*pq.s)
  45. self.patt2_times = neo.SpikeTrain(
  46. np.random.uniform(0, 1 - max(self.lags2), self.n_occ2) * pq.s,
  47. t_start=0*pq.s, t_stop=1*pq.s)
  48. self.patt3_times = neo.SpikeTrain(
  49. np.random.uniform(0, 1 - max(self.lags3), self.n_occ3) * pq.s,
  50. t_start=0*pq.s, t_stop=1*pq.s)
  51. # Patterns
  52. self.patt1 = [self.patt1_times] + [neo.SpikeTrain(
  53. self.patt1_times+l * pq.s, t_start=self.t_start * pq.s,
  54. t_stop=self.t_stop * pq.s) for l in self.lags1]
  55. self.patt2 = [self.patt2_times] + [neo.SpikeTrain(
  56. self.patt2_times+l * pq.s, t_start=self.t_start * pq.s,
  57. t_stop=self.t_stop * pq.s) for l in self.lags2]
  58. self.patt3 = [self.patt3_times] + [neo.SpikeTrain(
  59. self.patt3_times+l * pq.s, t_start=self.t_start * pq.s,
  60. t_stop=self.t_stop * pq.s) for l in self.lags3]
  61. # Binning spiketrains
  62. self.bin_patt1 = conv.BinnedSpikeTrain(self.patt1,
  63. binsize=self.binsize)
  64. # Data
  65. self.msip = self.patt1 + self.patt2 + self.patt3
  66. self.msip = conv.BinnedSpikeTrain(self.msip, binsize=self.binsize)
  67. # Expected results
  68. self.n_spk1 = len(self.lags1) + 1
  69. self.n_spk2 = len(self.lags2) + 1
  70. self.n_spk3 = len(self.lags3) + 1
  71. self.elements1 = range(self.n_spk1)
  72. self.elements2 = range(self.n_spk2)
  73. self.elements3 = range(self.n_spk3)
  74. self.elements_msip = [
  75. self.elements1, range(self.n_spk1, self.n_spk1 + self.n_spk2),
  76. range(self.n_spk1 + self.n_spk2,
  77. self.n_spk1 + self.n_spk2 + self.n_spk3)]
  78. self.occ1 = np.unique(conv.BinnedSpikeTrain(
  79. self.patt1_times, self.binsize).spike_indices[0])
  80. self.occ2 = np.unique(conv.BinnedSpikeTrain(
  81. self.patt2_times, self.binsize).spike_indices[0])
  82. self.occ3 = np.unique(conv.BinnedSpikeTrain(
  83. self.patt3_times, self.binsize).spike_indices[0])
  84. self.occ_msip = [list(self.occ1), list(self.occ2), list(self.occ3)]
  85. self.lags_msip = [self.output_lags1,
  86. self.output_lags2,
  87. self.output_lags3]
  88. # test for single pattern injection input
  89. def test_cad_single_sip(self):
  90. # collecting cad output
  91. output_single = cad.\
  92. cell_assembly_detection(data=self.bin_patt1, maxlag=self.maxlag)
  93. # check neurons in the pattern
  94. assert_array_equal(sorted(output_single[0]['neurons']),
  95. self.elements1)
  96. # check the occurrences time of the patter
  97. assert_array_equal(output_single[0]['times'],
  98. self.occ1)
  99. # check the lags
  100. assert_array_equal(sorted(output_single[0]['lags']),
  101. self.output_lags1)
  102. # test with multiple (3) patterns injected in the data
  103. def test_cad_msip(self):
  104. # collecting cad output
  105. output_msip = cad.\
  106. cell_assembly_detection(data=self.msip, maxlag=self.maxlag)
  107. elements_msip = []
  108. occ_msip = []
  109. lags_msip = []
  110. for out in output_msip:
  111. elements_msip.append(out['neurons'])
  112. occ_msip.append(out['times'])
  113. lags_msip.append(list(out['lags']))
  114. elements_msip = sorted(elements_msip, key=lambda d: len(d))
  115. occ_msip = sorted(occ_msip, key=lambda d: len(d))
  116. lags_msip = sorted(lags_msip, key=lambda d: len(d))
  117. elements_msip = [sorted(e) for e in elements_msip]
  118. # check neurons in the patterns
  119. assert_array_equal(elements_msip, self.elements_msip)
  120. # check the occurrences time of the patters
  121. assert_array_equal(occ_msip[0], self.occ_msip[0])
  122. assert_array_equal(occ_msip[1], self.occ_msip[1])
  123. assert_array_equal(occ_msip[2], self.occ_msip[2])
  124. lags_msip = [sorted(e) for e in lags_msip]
  125. # check the lags
  126. assert_array_equal(lags_msip, self.lags_msip)
  127. # test the errors raised
  128. def test_cad_raise_error(self):
  129. # test error data input format
  130. self.assertRaises(TypeError, cad.cell_assembly_detection,
  131. data=[[1, 2, 3], [3, 4, 5]],
  132. maxlag=self.maxlag)
  133. # test error significance level
  134. self.assertRaises(ValueError, cad.cell_assembly_detection,
  135. data=conv.BinnedSpikeTrain(
  136. [neo.SpikeTrain([1, 2, 3]*pq.s, t_stop=5*pq.s),
  137. neo.SpikeTrain([3, 4, 5]*pq.s, t_stop=5*pq.s)],
  138. binsize=self.binsize),
  139. maxlag=self.maxlag,
  140. alpha=-3)
  141. # test error minimum number of occurrences
  142. self.assertRaises(ValueError, cad.cell_assembly_detection,
  143. data=conv.BinnedSpikeTrain(
  144. [neo.SpikeTrain([1, 2, 3]*pq.s, t_stop=5*pq.s),
  145. neo.SpikeTrain([3, 4, 5]*pq.s, t_stop=5*pq.s)],
  146. binsize=self.binsize),
  147. maxlag=self.maxlag,
  148. min_occ=-1)
  149. # test error minimum number of spikes in a pattern
  150. self.assertRaises(ValueError, cad.cell_assembly_detection,
  151. data=conv.BinnedSpikeTrain(
  152. [neo.SpikeTrain([1, 2, 3]*pq.s, t_stop=5*pq.s),
  153. neo.SpikeTrain([3, 4, 5]*pq.s, t_stop=5*pq.s)],
  154. binsize=self.binsize),
  155. maxlag=self.maxlag,
  156. max_spikes=1)
  157. # test error chunk size for variance computation
  158. self.assertRaises(ValueError, cad.cell_assembly_detection,
  159. data=conv.BinnedSpikeTrain(
  160. [neo.SpikeTrain([1, 2, 3]*pq.s, t_stop=5*pq.s),
  161. neo.SpikeTrain([3, 4, 5]*pq.s, t_stop=5*pq.s)],
  162. binsize=self.binsize),
  163. maxlag=self.maxlag,
  164. size_chunks=1)
  165. # test error maximum lag
  166. self.assertRaises(ValueError, cad.cell_assembly_detection,
  167. data=conv.BinnedSpikeTrain(
  168. [neo.SpikeTrain([1, 2, 3]*pq.s, t_stop=5*pq.s),
  169. neo.SpikeTrain([3, 4, 5]*pq.s, t_stop=5*pq.s)],
  170. binsize=self.binsize),
  171. maxlag=1)
  172. # test error minimum length spike train
  173. self.assertRaises(ValueError, cad.cell_assembly_detection,
  174. data=conv.BinnedSpikeTrain(
  175. [neo.SpikeTrain([1, 2, 3]*pq.ms, t_stop=6*pq.ms),
  176. neo.SpikeTrain([3, 4, 5]*pq.ms,
  177. t_stop=6*pq.ms)],
  178. binsize=1*pq.ms),
  179. maxlag=self.maxlag)
  180. def suite():
  181. suite = unittest.makeSuite(CadTestCase, 'test')
  182. return suite
  183. if __name__ == "__main__":
  184. runner = unittest.TextTestRunner(verbosity=2)
  185. runner.run(suite())