test_spade.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. """
  2. Unit tests for the spade module.
  3. :copyright: Copyright 2014-2016 by the Elephant team, see AUTHORS.txt.
  4. :license: Modified BSD, see LICENSE.txt for details.
  5. """
  6. from __future__ import division
  7. import unittest
  8. import neo
  9. import numpy as np
  10. from numpy.testing.utils import assert_array_equal
  11. import quantities as pq
  12. import elephant.spade as spade
  13. import elephant.conversion as conv
  14. import elephant.spike_train_generation as stg
  15. try:
  16. from elephant.spade_src import fim
  17. HAVE_FIM = True
  18. except ImportError:
  19. HAVE_FIM = False
  20. class SpadeTestCase(unittest.TestCase):
  21. def setUp(self):
  22. # Spade parameters
  23. self.binsize = 1 * pq.ms
  24. self.winlen = 10
  25. self.n_subset = 10
  26. self.n_surr = 10
  27. self.alpha = 0.05
  28. self.stability_thresh = [0.1, 0.1]
  29. self.psr_param = [0, 0, 0]
  30. self.min_occ = 4
  31. self.min_spikes = 4
  32. self.min_neu = 4
  33. # Test data parameters
  34. # CPP parameters
  35. self.n_neu = 100
  36. self.amplitude = [0] * self.n_neu + [1]
  37. self.cpp = stg.cpp(rate=3*pq.Hz, A=self.amplitude, t_stop=5*pq.s)
  38. # Number of patterns' occurrences
  39. self.n_occ1 = 10
  40. self.n_occ2 = 12
  41. self.n_occ3 = 15
  42. # Patterns lags
  43. self.lags1 = [2]
  44. self.lags2 = [1, 2]
  45. self.lags3 = [1, 2, 3, 4, 5]
  46. # Length of the spiketrain
  47. self.t_stop = 3000
  48. # Patterns times
  49. self.patt1_times = neo.SpikeTrain(
  50. np.arange(
  51. 0, 1000, 1000//self.n_occ1) *
  52. pq.ms, t_stop=self.t_stop*pq.ms)
  53. self.patt2_times = neo.SpikeTrain(
  54. np.arange(
  55. 1000, 2000, 1000 // self.n_occ2) *
  56. pq.ms, t_stop=self.t_stop * pq.ms)
  57. self.patt3_times = neo.SpikeTrain(
  58. np.arange(
  59. 2000, 3000, 1000 // self.n_occ3) *
  60. pq.ms, t_stop=self.t_stop * pq.ms)
  61. # Patterns
  62. self.patt1 = [self.patt1_times] + [neo.SpikeTrain(
  63. self.patt1_times.view(pq.Quantity)+l * pq.ms,
  64. t_stop=self.t_stop*pq.ms) for l in self.lags1]
  65. self.patt2 = [self.patt2_times] + [neo.SpikeTrain(
  66. self.patt2_times.view(pq.Quantity)+l * pq.ms,
  67. t_stop=self.t_stop*pq.ms) for l in self.lags2]
  68. self.patt3 = [self.patt3_times] + [neo.SpikeTrain(
  69. self.patt3_times.view(pq.Quantity)+l * pq.ms,
  70. t_stop=self.t_stop*pq.ms) for l in self.lags3]
  71. # Data
  72. self.msip = self.patt1 + self.patt2 + self.patt3
  73. # Expected results
  74. self.n_spk1 = len(self.lags1) + 1
  75. self.n_spk2 = len(self.lags2) + 1
  76. self.n_spk3 = len(self.lags3) + 1
  77. self.elements1 = list(range(self.n_spk1))
  78. self.elements2 = list(range(self.n_spk2))
  79. self.elements3 = list(range(self.n_spk3))
  80. self.elements_msip = [
  81. self.elements1, list(range(self.n_spk1, self.n_spk1 + self.n_spk2)),
  82. list(range(self.n_spk1 + self.n_spk2, self.n_spk1 +
  83. self.n_spk2 + self.n_spk3))]
  84. self.occ1 = np.unique(conv.BinnedSpikeTrain(
  85. self.patt1_times, self.binsize).spike_indices[0])
  86. self.occ2 = np.unique(conv.BinnedSpikeTrain(
  87. self.patt2_times, self.binsize).spike_indices[0])
  88. self.occ3 = np.unique(conv.BinnedSpikeTrain(
  89. self.patt3_times, self.binsize).spike_indices[0])
  90. self.occ_msip = [
  91. list(self.occ1), list(self.occ2), list(self.occ3)]
  92. self.lags_msip = [self.lags1, self.lags2, self.lags3]
  93. # Testing cpp
  94. def test_spade_cpp(self):
  95. output_cpp = spade.spade(self.cpp, self.binsize,
  96. 1,
  97. n_subsets=self.n_subset,
  98. stability_thresh=self.stability_thresh,
  99. n_surr=self.n_surr, alpha=self.alpha,
  100. psr_param=self.psr_param,
  101. output_format='patterns')['patterns']
  102. elements_cpp = []
  103. lags_cpp = []
  104. # collecting spade output
  105. for out in output_cpp:
  106. elements_cpp.append(sorted(out['neurons']))
  107. lags_cpp.append(list(out['lags'].magnitude))
  108. # check neurons in the patterns
  109. assert_array_equal(elements_cpp, [range(self.n_neu)])
  110. # check the lags
  111. assert_array_equal(lags_cpp, [np.array([0]*(self.n_neu - 1))])
  112. # Testing spectrum cpp
  113. def test_spade_cpp(self):
  114. # Computing Spectrum
  115. spectrum_cpp = spade.concepts_mining(self.cpp, self.binsize,
  116. 1,report='#')[0]
  117. # Check spectrum
  118. assert_array_equal(spectrum_cpp, [(len(self.cpp), len(self.cpp[0]), 1)])
  119. # Testing with multiple patterns input
  120. def test_spade_msip(self):
  121. output_msip = spade.spade(self.msip, self.binsize,
  122. self.winlen,
  123. n_subsets=self.n_subset,
  124. stability_thresh=self.stability_thresh,
  125. n_surr=self.n_surr, alpha=self.alpha,
  126. psr_param=self.psr_param,
  127. output_format='patterns')['patterns']
  128. elements_msip = []
  129. occ_msip = []
  130. lags_msip = []
  131. # collecting spade output
  132. for out in output_msip:
  133. elements_msip.append(out['neurons'])
  134. occ_msip.append(list(out['times'].magnitude))
  135. lags_msip.append(list(out['lags'].magnitude))
  136. elements_msip = sorted(elements_msip, key=lambda d: len(d))
  137. occ_msip = sorted(occ_msip, key=lambda d: len(d))
  138. lags_msip = sorted(lags_msip, key=lambda d: len(d))
  139. # check neurons in the patterns
  140. assert_array_equal(elements_msip, self.elements_msip)
  141. # check the occurrences time of the patters
  142. assert_array_equal(occ_msip, self.occ_msip)
  143. # check the lags
  144. assert_array_equal(lags_msip, self.lags_msip)
  145. # test under different configuration of parameters than the default one
  146. def test_parameters(self):
  147. # test min_spikes parameter
  148. output_msip_min_spikes = spade.spade(self.msip, self.binsize,
  149. self.winlen,
  150. n_subsets=self.n_subset,
  151. n_surr=self.n_surr, alpha=self.alpha,
  152. min_spikes=self.min_spikes,
  153. psr_param=self.psr_param,
  154. output_format='patterns')['patterns']
  155. # collecting spade output
  156. elements_msip_min_spikes= []
  157. for out in output_msip_min_spikes:
  158. elements_msip_min_spikes.append(out['neurons'])
  159. elements_msip_min_spikes = sorted(elements_msip_min_spikes, key=lambda d: len(d))
  160. lags_msip_min_spikes= []
  161. for out in output_msip_min_spikes:
  162. lags_msip_min_spikes.append(list(out['lags'].magnitude))
  163. lags_msip_min_spikes = sorted(lags_msip_min_spikes, key=lambda d: len(d))
  164. # check the lags
  165. assert_array_equal(lags_msip_min_spikes, [
  166. l for l in self.lags_msip if len(l)+1>=self.min_spikes])
  167. # check the neurons in the patterns
  168. assert_array_equal(elements_msip_min_spikes, [
  169. el for el in self.elements_msip if len(el)>=self.min_neu and len(
  170. el)>=self.min_spikes])
  171. # test min_occ parameter
  172. output_msip_min_occ = spade.spade(self.msip, self.binsize,
  173. self.winlen,
  174. n_subsets=self.n_subset,
  175. n_surr=self.n_surr, alpha=self.alpha,
  176. min_occ=self.min_occ,
  177. psr_param=self.psr_param,
  178. output_format='patterns')['patterns']
  179. # collect spade output
  180. occ_msip_min_occ= []
  181. for out in output_msip_min_occ:
  182. occ_msip_min_occ.append(list(out['times'].magnitude))
  183. occ_msip_min_occ = sorted(occ_msip_min_occ, key=lambda d: len(d))
  184. # test occurrences time
  185. assert_array_equal(occ_msip_min_occ, [
  186. occ for occ in self.occ_msip if len(occ)>=self.min_occ])
  187. # test to compare the python and the C implementation of FIM
  188. # skip this test if C code not available
  189. @unittest.skipIf(HAVE_FIM == False, 'Requires fim.so')
  190. def test_fpgrowth_fca(self):
  191. binary_matrix = conv.BinnedSpikeTrain(
  192. self.patt1, self.binsize).to_bool_array()
  193. context, transactions, rel_matrix = spade._build_context(
  194. binary_matrix, self.winlen)
  195. # mining the data with python fast_fca
  196. mining_results_fpg = spade._fpgrowth(
  197. transactions,
  198. rel_matrix=rel_matrix)
  199. # mining the data with C fim
  200. mining_results_ffca = spade._fast_fca(context)
  201. # testing that the outputs are identical
  202. assert_array_equal(sorted(mining_results_ffca[0][0]), sorted(
  203. mining_results_fpg[0][0]))
  204. assert_array_equal(sorted(mining_results_ffca[0][1]), sorted(
  205. mining_results_fpg[0][1]))
  206. # test the errors raised
  207. def test_spade_raise_error(self):
  208. self.assertRaises(TypeError, spade.spade, [[1,2,3],[3,4,5]], 1*pq.ms, 4)
  209. self.assertRaises(AttributeError, spade.spade, [neo.SpikeTrain(
  210. [1,2,3]*pq.s, t_stop=5*pq.s), neo.SpikeTrain(
  211. [3,4,5]*pq.s, t_stop=6*pq.s)], 1*pq.ms, 4)
  212. self.assertRaises(AttributeError, spade.spade, [neo.SpikeTrain(
  213. [1, 2, 3] * pq.s, t_stop=5 * pq.s), neo.SpikeTrain(
  214. [3, 4, 5] * pq.s, t_stop=5 * pq.s)], 1 * pq.ms, 4, min_neu=-3)
  215. self.assertRaises(AttributeError, spade.pvalue_spectrum, [
  216. neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=5 * pq.s), neo.SpikeTrain(
  217. [3, 4, 5] * pq.s, t_stop=5 * pq.s)], 1 * pq.ms, 4, 3*pq.ms,
  218. n_surr=-3)
  219. self.assertRaises(AttributeError, spade.test_signature_significance, (
  220. (2, 3, 0.2), (2, 4, 0.1)), 0.01, corr='try')
  221. self.assertRaises(AttributeError, spade.approximate_stability, (),
  222. np.array([]), n_subsets=-3)
  223. def suite():
  224. suite = unittest.makeSuite(SpadeTestCase, 'test')
  225. return suite
  226. if __name__ == "__main__":
  227. runner = unittest.TextTestRunner(verbosity=2)
  228. runner.run(suite())
  229. globals()