test_klustakwikio.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.klustakwikio
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. import glob
  8. import os.path
  9. import sys
  10. import tempfile
  11. try:
  12. import unittest2 as unittest
  13. except ImportError:
  14. import unittest
  15. import numpy as np
  16. import quantities as pq
  17. import neo
  18. from neo.test.iotest.common_io_test import BaseTestIO
  19. from neo.test.tools import assert_arrays_almost_equal
  20. from neo.io.klustakwikio import KlustaKwikIO, HAVE_MLAB
  21. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  22. @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  23. class testFilenameParser(unittest.TestCase):
  24. """Tests that filenames can be loaded with or without basename.
  25. The test directory contains two basenames and some decoy files with
  26. malformed group numbers."""
  27. def setUp(self):
  28. self.dirname = os.path.join(tempfile.gettempdir(),
  29. 'files_for_testing_neo',
  30. 'klustakwik/test1')
  31. if not os.path.exists(self.dirname):
  32. raise unittest.SkipTest('data directory does not exist: ' +
  33. self.dirname)
  34. def test1(self):
  35. """Tests that files can be loaded by basename"""
  36. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename'))
  37. if not BaseTestIO.use_network:
  38. raise unittest.SkipTest("Requires download of data from the web")
  39. fetfiles = kio._fp.read_filenames('fet')
  40. self.assertEqual(len(fetfiles), 2)
  41. self.assertEqual(os.path.abspath(fetfiles[0]),
  42. os.path.abspath(os.path.join(self.dirname,
  43. 'basename.fet.0')))
  44. self.assertEqual(os.path.abspath(fetfiles[1]),
  45. os.path.abspath(os.path.join(self.dirname,
  46. 'basename.fet.1')))
  47. def test2(self):
  48. """Tests that files are loaded even without basename"""
  49. pass
  50. # this test is in flux, should probably have it default to
  51. # basename = os.path.split(dirname)[1] when dirname is a directory
  52. #~ dirname = os.path.normpath('./files_for_tests/klustakwik/test1')
  53. #~ kio = KlustaKwikIO(filename=dirname)
  54. #~ fetfiles = kio._fp.read_filenames('fet')
  55. #~ # It will just choose one of the two basenames, depending on which
  56. #~ # is first, so just assert that it did something without error.
  57. #~ self.assertNotEqual(len(fetfiles), 0)
  58. def test3(self):
  59. """Tests that files can be loaded by basename2"""
  60. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename2'))
  61. if not BaseTestIO.use_network:
  62. raise unittest.SkipTest("Requires download of data from the web")
  63. clufiles = kio._fp.read_filenames('clu')
  64. self.assertEqual(len(clufiles), 1)
  65. self.assertEqual(os.path.abspath(clufiles[1]),
  66. os.path.abspath(os.path.join(self.dirname,
  67. 'basename2.clu.1')))
  68. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  69. @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  70. class testRead(unittest.TestCase):
  71. """Tests that data can be read from KlustaKwik files"""
  72. def setUp(self):
  73. self.dirname = os.path.join(tempfile.gettempdir(),
  74. 'files_for_testing_neo',
  75. 'klustakwik/test2')
  76. if not os.path.exists(self.dirname):
  77. raise unittest.SkipTest('data directory does not exist: ' +
  78. self.dirname)
  79. def test1(self):
  80. """Tests that data and metadata are read correctly"""
  81. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base'),
  82. sampling_rate=1000.)
  83. block = kio.read()[0]
  84. seg = block.segments[0]
  85. self.assertEqual(len(seg.spiketrains), 4)
  86. for st in seg.spiketrains:
  87. self.assertEqual(st.units, np.array(1.0) * pq.s)
  88. self.assertEqual(st.t_start, 0.0)
  89. self.assertEqual(seg.spiketrains[0].name, 'unit 1 from group 0')
  90. self.assertEqual(seg.spiketrains[0].annotations['cluster'], 1)
  91. self.assertEqual(seg.spiketrains[0].annotations['group'], 0)
  92. self.assertTrue(np.all(seg.spiketrains[0].times == np.array([.100,
  93. .200])))
  94. self.assertEqual(seg.spiketrains[1].name, 'unit 2 from group 0')
  95. self.assertEqual(seg.spiketrains[1].annotations['cluster'], 2)
  96. self.assertEqual(seg.spiketrains[1].annotations['group'], 0)
  97. self.assertEqual(seg.spiketrains[1].t_start, 0.0)
  98. self.assertTrue(np.all(seg.spiketrains[1].times == np.array([.305])))
  99. self.assertEqual(seg.spiketrains[2].name, 'unit -1 from group 1')
  100. self.assertEqual(seg.spiketrains[2].annotations['cluster'], -1)
  101. self.assertEqual(seg.spiketrains[2].annotations['group'], 1)
  102. self.assertEqual(seg.spiketrains[2].t_start, 0.0)
  103. self.assertTrue(np.all(seg.spiketrains[2].times == np.array([.253])))
  104. self.assertEqual(seg.spiketrains[3].name, 'unit 2 from group 1')
  105. self.assertEqual(seg.spiketrains[3].annotations['cluster'], 2)
  106. self.assertEqual(seg.spiketrains[3].annotations['group'], 1)
  107. self.assertEqual(seg.spiketrains[3].t_start, 0.0)
  108. self.assertTrue(np.all(seg.spiketrains[3].times == np.array([.050,
  109. .152])))
  110. def test2(self):
  111. """Checks that cluster id autosets to 0 without clu file"""
  112. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
  113. sampling_rate=1000.)
  114. block = kio.read()[0]
  115. seg = block.segments[0]
  116. self.assertEqual(len(seg.spiketrains), 1)
  117. self.assertEqual(seg.spiketrains[0].name, 'unit 0 from group 5')
  118. self.assertEqual(seg.spiketrains[0].annotations['cluster'], 0)
  119. self.assertEqual(seg.spiketrains[0].annotations['group'], 5)
  120. self.assertEqual(seg.spiketrains[0].t_start, 0.0)
  121. self.assertTrue(np.all(seg.spiketrains[0].times == np.array([0.026,
  122. 0.122,
  123. 0.228])))
  124. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  125. @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  126. class testWrite(unittest.TestCase):
  127. def setUp(self):
  128. self.dirname = os.path.join(tempfile.gettempdir(),
  129. 'files_for_testing_neo',
  130. 'klustakwik/test3')
  131. if not os.path.exists(self.dirname):
  132. raise unittest.SkipTest('data directory does not exist: ' +
  133. self.dirname)
  134. def test1(self):
  135. """Create clu and fet files based on spiketrains in a block.
  136. Checks that
  137. Files are created
  138. Converted to samples correctly
  139. Missing sampling rate are taken from IO reader default
  140. Spiketrains without cluster info are assigned to cluster 0
  141. Spiketrains across segments are concatenated
  142. """
  143. block = neo.Block()
  144. segment = neo.Segment()
  145. segment2 = neo.Segment()
  146. block.segments.append(segment)
  147. block.segments.append(segment2)
  148. # Fake spiketrain 1, will be sorted
  149. st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
  150. st1.annotations['cluster'] = 0
  151. st1.annotations['group'] = 0
  152. segment.spiketrains.append(st1)
  153. # Fake spiketrain 1B, on another segment. No group specified,
  154. # default is 0.
  155. st1B = neo.SpikeTrain(times=[.106], units='s', t_stop=1.)
  156. st1B.annotations['cluster'] = 0
  157. segment2.spiketrains.append(st1B)
  158. # Fake spiketrain 2 on same group, no sampling rate specified
  159. st2 = neo.SpikeTrain(times=[.001, .003, .011], units='s', t_stop=1.)
  160. st2.annotations['cluster'] = 1
  161. st2.annotations['group'] = 0
  162. segment.spiketrains.append(st2)
  163. # Fake spiketrain 3 on new group, with different sampling rate
  164. st3 = neo.SpikeTrain(times=[.05, .09, .10], units='s', t_stop=1.)
  165. st3.annotations['cluster'] = -1
  166. st3.annotations['group'] = 1
  167. segment.spiketrains.append(st3)
  168. # Fake spiketrain 4 on new group, without cluster info
  169. st4 = neo.SpikeTrain(times=[.005, .009], units='s', t_stop=1.)
  170. st4.annotations['group'] = 2
  171. segment.spiketrains.append(st4)
  172. # Create empty directory for writing
  173. delete_test_session()
  174. # Create writer with default sampling rate
  175. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base1'),
  176. sampling_rate=1000.)
  177. kio.write_block(block)
  178. # Check files were created
  179. for fn in ['.fet.0', '.fet.1', '.clu.0', '.clu.1']:
  180. self.assertTrue(os.path.exists(os.path.join(self.dirname,
  181. 'base1' + fn)))
  182. # Check files contain correct content
  183. # Spike times on group 0
  184. data = file(os.path.join(self.dirname, 'base1.fet.0')).readlines()
  185. data = [int(d) for d in data]
  186. self.assertEqual(data, [0, 2, 4, 6, 1, 3, 11, 106])
  187. # Clusters on group 0
  188. data = file(os.path.join(self.dirname, 'base1.clu.0')).readlines()
  189. data = [int(d) for d in data]
  190. self.assertEqual(data, [2, 0, 0, 0, 1, 1, 1, 0])
  191. # Spike times on group 1
  192. data = file(os.path.join(self.dirname, 'base1.fet.1')).readlines()
  193. data = [int(d) for d in data]
  194. self.assertEqual(data, [0, 50, 90, 100])
  195. # Clusters on group 1
  196. data = file(os.path.join(self.dirname, 'base1.clu.1')).readlines()
  197. data = [int(d) for d in data]
  198. self.assertEqual(data, [1, -1, -1, -1])
  199. # Spike times on group 2
  200. data = file(os.path.join(self.dirname, 'base1.fet.2')).readlines()
  201. data = [int(d) for d in data]
  202. self.assertEqual(data, [0, 5, 9])
  203. # Clusters on group 2
  204. data = file(os.path.join(self.dirname, 'base1.clu.2')).readlines()
  205. data = [int(d) for d in data]
  206. self.assertEqual(data, [1, 0, 0])
  207. # Empty out test session again
  208. delete_test_session()
  209. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  210. @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  211. class testWriteWithFeatures(unittest.TestCase):
  212. def setUp(self):
  213. self.dirname = os.path.join(tempfile.gettempdir(),
  214. 'files_for_testing_neo',
  215. 'klustakwik/test4')
  216. if not os.path.exists(self.dirname):
  217. raise unittest.SkipTest('data directory does not exist: ' +
  218. self.dirname)
  219. def test1(self):
  220. """Create clu and fet files based on spiketrains in a block.
  221. Checks that
  222. Files are created
  223. Converted to samples correctly
  224. Missing sampling rate are taken from IO reader default
  225. Spiketrains without cluster info are assigned to cluster 0
  226. Spiketrains across segments are concatenated
  227. """
  228. block = neo.Block()
  229. segment = neo.Segment()
  230. segment2 = neo.Segment()
  231. block.segments.append(segment)
  232. block.segments.append(segment2)
  233. # Fake spiketrain 1
  234. st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
  235. st1.annotations['cluster'] = 0
  236. st1.annotations['group'] = 0
  237. wff = np.array([
  238. [11.3, 0.2],
  239. [-0.3, 12.3],
  240. [3.0, -2.5]])
  241. st1.annotations['waveform_features'] = wff
  242. segment.spiketrains.append(st1)
  243. # Create empty directory for writing
  244. if not os.path.exists(self.dirname):
  245. os.mkdir(self.dirname)
  246. delete_test_session(self.dirname)
  247. # Create writer
  248. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
  249. sampling_rate=1000.)
  250. kio.write_block(block)
  251. # Check files were created
  252. for fn in ['.fet.0', '.clu.0']:
  253. self.assertTrue(os.path.exists(os.path.join(self.dirname,
  254. 'base2' + fn)))
  255. # Check files contain correct content
  256. fi = file(os.path.join(self.dirname, 'base2.fet.0'))
  257. # first line is nbFeatures
  258. self.assertEqual(fi.readline(), '2\n')
  259. # Now check waveforms and times are same
  260. data = fi.readlines()
  261. new_wff = []
  262. new_times = []
  263. for line in data:
  264. line_split = line.split()
  265. new_wff.append([float(val) for val in line_split[:-1]])
  266. new_times.append(int(line_split[-1]))
  267. self.assertEqual(new_times, [2, 4, 6])
  268. assert_arrays_almost_equal(wff, np.array(new_wff), .00001)
  269. # Clusters on group 0
  270. data = file(os.path.join(self.dirname, 'base2.clu.0')).readlines()
  271. data = [int(d) for d in data]
  272. self.assertEqual(data, [1, 0, 0, 0])
  273. # Now read the features and test same
  274. block = kio.read_block()
  275. train = block.segments[0].spiketrains[0]
  276. assert_arrays_almost_equal(wff, train.annotations['waveform_features'],
  277. .00001)
  278. # Empty out test session again
  279. delete_test_session(self.dirname)
  280. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  281. @unittest.skipIf(sys.version_info[0] > 2, "not Python 3 compatible")
  282. class CommonTests(BaseTestIO, unittest.TestCase):
  283. ioclass = KlustaKwikIO
  284. # These are the files it tries to read and test for compliance
  285. files_to_test = [
  286. 'test2/base',
  287. 'test2/base2',
  288. ]
  289. # Will fetch from g-node if they don't already exist locally
  290. # How does it know to do this before any of the other tests?
  291. files_to_download = [
  292. 'test1/basename.clu.0',
  293. 'test1/basename.fet.-1',
  294. 'test1/basename.fet.0',
  295. 'test1/basename.fet.1',
  296. 'test1/basename.fet.1a',
  297. 'test1/basename.fet.a1',
  298. 'test1/basename2.clu.1',
  299. 'test1/basename2.fet.1',
  300. 'test1/basename2.fet.1a',
  301. 'test2/base2.fet.5',
  302. 'test2/base.clu.0',
  303. 'test2/base.clu.1',
  304. 'test2/base.fet.0',
  305. 'test2/base.fet.1',
  306. 'test3/base1.clu.0',
  307. 'test3/base1.clu.1',
  308. 'test3/base1.clu.2',
  309. 'test3/base1.fet.0',
  310. 'test3/base1.fet.1',
  311. 'test3/base1.fet.2'
  312. ]
  313. def delete_test_session(dirname=None):
  314. """Removes all file in directory so we can test writing to it"""
  315. if dirname is None:
  316. dirname = os.path.join(os.path.dirname(__file__),
  317. 'files_for_tests/klustakwik/test3')
  318. for fi in glob.glob(os.path.join(dirname, '*')):
  319. os.remove(fi)
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()