test_klustakwikio.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. """
  2. Tests of neo.io.klustakwikio
  3. """
  4. import glob
  5. import os.path
  6. import sys
  7. import tempfile
  8. import unittest
  9. import numpy as np
  10. import quantities as pq
  11. import neo
  12. from neo.test.iotest.common_io_test import BaseTestIO
  13. from neo.test.tools import assert_arrays_almost_equal
  14. from neo.io.klustakwikio import KlustaKwikIO, HAVE_MLAB
  15. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  16. class testFilenameParser(unittest.TestCase):
  17. """Tests that filenames can be loaded with or without basename.
  18. The test directory contains two basenames and some decoy files with
  19. malformed group numbers."""
  20. def setUp(self):
  21. self.dirname = os.path.join(tempfile.gettempdir(),
  22. 'files_for_testing_neo',
  23. 'klustakwik/test1')
  24. if not os.path.exists(self.dirname):
  25. raise unittest.SkipTest('data directory does not exist: ' +
  26. self.dirname)
  27. def test1(self):
  28. """Tests that files can be loaded by basename"""
  29. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename'))
  30. if not BaseTestIO.use_network:
  31. raise unittest.SkipTest("Requires download of data from the web")
  32. fetfiles = kio._fp.read_filenames('fet')
  33. self.assertEqual(len(fetfiles), 2)
  34. self.assertEqual(os.path.abspath(fetfiles[0]),
  35. os.path.abspath(os.path.join(self.dirname,
  36. 'basename.fet.0')))
  37. self.assertEqual(os.path.abspath(fetfiles[1]),
  38. os.path.abspath(os.path.join(self.dirname,
  39. 'basename.fet.1')))
  40. def test2(self):
  41. """Tests that files are loaded even without basename"""
  42. pass
  43. # this test is in flux, should probably have it default to
  44. # basename = os.path.split(dirname)[1] when dirname is a directory
  45. # dirname = os.path.normpath('./files_for_tests/klustakwik/test1')
  46. # kio = KlustaKwikIO(filename=dirname)
  47. # fetfiles = kio._fp.read_filenames('fet')
  48. # It will just choose one of the two basenames, depending on which
  49. # is first, so just assert that it did something without error.
  50. # self.assertNotEqual(len(fetfiles), 0)
  51. def test3(self):
  52. """Tests that files can be loaded by basename2"""
  53. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename2'))
  54. if not BaseTestIO.use_network:
  55. raise unittest.SkipTest("Requires download of data from the web")
  56. clufiles = kio._fp.read_filenames('clu')
  57. self.assertEqual(len(clufiles), 1)
  58. self.assertEqual(os.path.abspath(clufiles[1]),
  59. os.path.abspath(os.path.join(self.dirname,
  60. 'basename2.clu.1')))
  61. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  62. class testRead(unittest.TestCase):
  63. """Tests that data can be read from KlustaKwik files"""
  64. def setUp(self):
  65. self.dirname = os.path.join(tempfile.gettempdir(),
  66. 'files_for_testing_neo',
  67. 'klustakwik/test2')
  68. if not os.path.exists(self.dirname):
  69. raise unittest.SkipTest('data directory does not exist: ' +
  70. self.dirname)
  71. def test1(self):
  72. """Tests that data and metadata are read correctly"""
  73. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base'),
  74. sampling_rate=1000.)
  75. block = kio.read()[0]
  76. seg = block.segments[0]
  77. self.assertEqual(len(seg.spiketrains), 4)
  78. for st in seg.spiketrains:
  79. self.assertEqual(st.units, np.array(1.0) * pq.s)
  80. self.assertEqual(st.t_start, 0.0)
  81. self.assertEqual(seg.spiketrains[0].name, 'unit 1 from group 0')
  82. self.assertEqual(seg.spiketrains[0].annotations['cluster'], 1)
  83. self.assertEqual(seg.spiketrains[0].annotations['group'], 0)
  84. self.assertTrue(np.all(seg.spiketrains[0].times == np.array([.100,
  85. .200])))
  86. self.assertEqual(seg.spiketrains[1].name, 'unit 2 from group 0')
  87. self.assertEqual(seg.spiketrains[1].annotations['cluster'], 2)
  88. self.assertEqual(seg.spiketrains[1].annotations['group'], 0)
  89. self.assertEqual(seg.spiketrains[1].t_start, 0.0)
  90. self.assertTrue(np.all(seg.spiketrains[1].times == np.array([.305])))
  91. self.assertEqual(seg.spiketrains[2].name, 'unit -1 from group 1')
  92. self.assertEqual(seg.spiketrains[2].annotations['cluster'], -1)
  93. self.assertEqual(seg.spiketrains[2].annotations['group'], 1)
  94. self.assertEqual(seg.spiketrains[2].t_start, 0.0)
  95. self.assertTrue(np.all(seg.spiketrains[2].times == np.array([.253])))
  96. self.assertEqual(seg.spiketrains[3].name, 'unit 2 from group 1')
  97. self.assertEqual(seg.spiketrains[3].annotations['cluster'], 2)
  98. self.assertEqual(seg.spiketrains[3].annotations['group'], 1)
  99. self.assertEqual(seg.spiketrains[3].t_start, 0.0)
  100. self.assertTrue(np.all(seg.spiketrains[3].times == np.array([.050,
  101. .152])))
  102. def test2(self):
  103. """Checks that cluster id autosets to 0 without clu file"""
  104. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
  105. sampling_rate=1000.)
  106. block = kio.read()[0]
  107. seg = block.segments[0]
  108. self.assertEqual(len(seg.spiketrains), 1)
  109. self.assertEqual(seg.spiketrains[0].name, 'unit 0 from group 5')
  110. self.assertEqual(seg.spiketrains[0].annotations['cluster'], 0)
  111. self.assertEqual(seg.spiketrains[0].annotations['group'], 5)
  112. self.assertEqual(seg.spiketrains[0].t_start, 0.0)
  113. self.assertTrue(np.all(seg.spiketrains[0].times == np.array([0.026,
  114. 0.122,
  115. 0.228])))
  116. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  117. class testWrite(unittest.TestCase):
  118. def setUp(self):
  119. self.dirname = os.path.join(tempfile.gettempdir(),
  120. 'files_for_testing_neo',
  121. 'klustakwik/test3')
  122. if not os.path.exists(self.dirname):
  123. raise unittest.SkipTest('data directory does not exist: ' +
  124. self.dirname)
  125. def test1(self):
  126. """Create clu and fet files based on spiketrains in a block.
  127. Checks that
  128. Files are created
  129. Converted to samples correctly
  130. Missing sampling rate are taken from IO reader default
  131. Spiketrains without cluster info are assigned to cluster 0
  132. Spiketrains across segments are concatenated
  133. """
  134. block = neo.Block()
  135. segment = neo.Segment()
  136. segment2 = neo.Segment()
  137. block.segments.append(segment)
  138. block.segments.append(segment2)
  139. # Fake spiketrain 1, will be sorted
  140. st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
  141. st1.annotations['cluster'] = 0
  142. st1.annotations['group'] = 0
  143. segment.spiketrains.append(st1)
  144. # Fake spiketrain 1B, on another segment. No group specified,
  145. # default is 0.
  146. st1B = neo.SpikeTrain(times=[.106], units='s', t_stop=1.)
  147. st1B.annotations['cluster'] = 0
  148. segment2.spiketrains.append(st1B)
  149. # Fake spiketrain 2 on same group, no sampling rate specified
  150. st2 = neo.SpikeTrain(times=[.001, .003, .011], units='s', t_stop=1.)
  151. st2.annotations['cluster'] = 1
  152. st2.annotations['group'] = 0
  153. segment.spiketrains.append(st2)
  154. # Fake spiketrain 3 on new group, with different sampling rate
  155. st3 = neo.SpikeTrain(times=[.05, .09, .10], units='s', t_stop=1.)
  156. st3.annotations['cluster'] = -1
  157. st3.annotations['group'] = 1
  158. segment.spiketrains.append(st3)
  159. # Fake spiketrain 4 on new group, without cluster info
  160. st4 = neo.SpikeTrain(times=[.005, .009], units='s', t_stop=1.)
  161. st4.annotations['group'] = 2
  162. segment.spiketrains.append(st4)
  163. # Create empty directory for writing
  164. delete_test_session()
  165. # Create writer with default sampling rate
  166. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base1'),
  167. sampling_rate=1000.)
  168. kio.write_block(block)
  169. # Check files were created
  170. for fn in ['.fet.0', '.fet.1', '.clu.0', '.clu.1']:
  171. self.assertTrue(os.path.exists(os.path.join(self.dirname,
  172. 'base1' + fn)))
  173. # Check files contain correct content
  174. # Spike times on group 0
  175. with open(os.path.join(self.dirname, 'base1.fet.0'), mode='r') as f:
  176. data = f.readlines()
  177. data = [int(d) for d in data]
  178. self.assertEqual(data, [0, 2, 4, 6, 1, 3, 11, 106])
  179. # Clusters on group 0
  180. with open(os.path.join(self.dirname, 'base1.clu.0'), mode='r') as f:
  181. data = f.readlines()
  182. data = [int(d) for d in data]
  183. self.assertEqual(data, [2, 0, 0, 0, 1, 1, 1, 0])
  184. # Spike times on group 1
  185. with open(os.path.join(self.dirname, 'base1.fet.1'), mode='r') as f:
  186. data = f.readlines()
  187. data = [int(d) for d in data]
  188. self.assertEqual(data, [0, 50, 90, 100])
  189. # Clusters on group 1
  190. with open(os.path.join(self.dirname, 'base1.clu.1')) as f:
  191. data = f.readlines()
  192. data = [int(d) for d in data]
  193. self.assertEqual(data, [1, -1, -1, -1])
  194. # Spike times on group 2
  195. with open(os.path.join(self.dirname, 'base1.fet.2')) as f:
  196. data = f.readlines()
  197. data = [int(d) for d in data]
  198. self.assertEqual(data, [0, 5, 9])
  199. # Clusters on group 2
  200. with open(os.path.join(self.dirname, 'base1.clu.2')) as f:
  201. data = f.readlines()
  202. data = [int(d) for d in data]
  203. self.assertEqual(data, [1, 0, 0])
  204. # Empty out test session again
  205. delete_test_session()
  206. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  207. class testWriteWithFeatures(unittest.TestCase):
  208. def setUp(self):
  209. self.dirname = os.path.join(tempfile.gettempdir(),
  210. 'files_for_testing_neo',
  211. 'klustakwik/test4')
  212. if not os.path.exists(self.dirname):
  213. raise unittest.SkipTest('data directory does not exist: ' +
  214. self.dirname)
  215. def test1(self):
  216. """Create clu and fet files based on spiketrains in a block.
  217. Checks that
  218. Files are created
  219. Converted to samples correctly
  220. Missing sampling rate are taken from IO reader default
  221. Spiketrains without cluster info are assigned to cluster 0
  222. Spiketrains across segments are concatenated
  223. """
  224. block = neo.Block()
  225. segment = neo.Segment()
  226. segment2 = neo.Segment()
  227. block.segments.append(segment)
  228. block.segments.append(segment2)
  229. # Fake spiketrain 1
  230. st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
  231. st1.annotations['cluster'] = 0
  232. st1.annotations['group'] = 0
  233. wff = np.array([
  234. [11.3, 0.2],
  235. [-0.3, 12.3],
  236. [3.0, -2.5]])
  237. st1.annotations['waveform_features'] = wff
  238. segment.spiketrains.append(st1)
  239. # Create empty directory for writing
  240. if not os.path.exists(self.dirname):
  241. os.mkdir(self.dirname)
  242. delete_test_session(self.dirname)
  243. # Create writer
  244. kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
  245. sampling_rate=1000.)
  246. kio.write_block(block)
  247. # Check files were created
  248. for fn in ['.fet.0', '.clu.0']:
  249. self.assertTrue(os.path.exists(os.path.join(self.dirname,
  250. 'base2' + fn)))
  251. # Check files contain correct content
  252. fi = file(os.path.join(self.dirname, 'base2.fet.0'))
  253. # first line is nbFeatures
  254. self.assertEqual(fi.readline(), '2\n')
  255. # Now check waveforms and times are same
  256. data = fi.readlines()
  257. new_wff = []
  258. new_times = []
  259. for line in data:
  260. line_split = line.split()
  261. new_wff.append([float(val) for val in line_split[:-1]])
  262. new_times.append(int(line_split[-1]))
  263. self.assertEqual(new_times, [2, 4, 6])
  264. assert_arrays_almost_equal(wff, np.array(new_wff), .00001)
  265. # Clusters on group 0
  266. data = file(os.path.join(self.dirname, 'base2.clu.0')).readlines()
  267. data = [int(d) for d in data]
  268. self.assertEqual(data, [1, 0, 0, 0])
  269. # Now read the features and test same
  270. block = kio.read_block()
  271. train = block.segments[0].spiketrains[0]
  272. assert_arrays_almost_equal(wff, train.annotations['waveform_features'],
  273. .00001)
  274. # Empty out test session again
  275. delete_test_session(self.dirname)
  276. @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  277. class CommonTests(BaseTestIO, unittest.TestCase):
  278. ioclass = KlustaKwikIO
  279. # These are the files it tries to read and test for compliance
  280. files_to_test = [
  281. 'test2/base',
  282. 'test2/base2',
  283. ]
  284. # Will fetch from g-node if they don't already exist locally
  285. # How does it know to do this before any of the other tests?
  286. files_to_download = [
  287. 'test1/basename.clu.0',
  288. 'test1/basename.fet.-1',
  289. 'test1/basename.fet.0',
  290. 'test1/basename.fet.1',
  291. 'test1/basename.fet.1a',
  292. 'test1/basename.fet.a1',
  293. 'test1/basename2.clu.1',
  294. 'test1/basename2.fet.1',
  295. 'test1/basename2.fet.1a',
  296. 'test2/base2.fet.5',
  297. 'test2/base.clu.0',
  298. 'test2/base.clu.1',
  299. 'test2/base.fet.0',
  300. 'test2/base.fet.1',
  301. 'test3/base1.clu.0',
  302. 'test3/base1.clu.1',
  303. 'test3/base1.clu.2',
  304. 'test3/base1.fet.0',
  305. 'test3/base1.fet.1',
  306. 'test3/base1.fet.2'
  307. ]
  308. def delete_test_session(dirname=None):
  309. """Removes all file in directory so we can test writing to it"""
  310. if dirname is None:
  311. dirname = os.path.join(os.path.dirname(__file__),
  312. 'files_for_tests/klustakwik/test3')
  313. for fi in glob.glob(os.path.join(dirname, '*')):
  314. os.remove(fi)
  315. if __name__ == '__main__':
  316. unittest.main()