test_klustakwikio.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.klustakwikio
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import
  7. import glob
  8. import os.path
  9. import sys
  10. import tempfile
  11. import unittest
  12. import numpy as np
  13. import quantities as pq
  14. import neo
  15. from neo.test.iotest.common_io_test import BaseTestIO
  16. from neo.test.tools import assert_arrays_almost_equal
  17. from neo.io.klustakwikio import KlustaKwikIO, HAVE_MLAB
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  class testFilenameParser(unittest.TestCase):
      """Tests that filenames can be loaded with or without basename.

      The test directory contains two basenames and some decoy files with
      malformed group numbers."""

      def setUp(self):
          # Test data lives under the system temp dir; skip when it is absent.
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test1')
          if not os.path.exists(self.dirname):
              raise unittest.SkipTest('data directory does not exist: ' +
                                      self.dirname)

      def test1(self):
          """Tests that files can be loaded by basename"""
          # Check the network gate before constructing the reader, so a
          # skipped run does no work at all.
          if not BaseTestIO.use_network:
              raise unittest.SkipTest("Requires download of data from the web")
          kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename'))
          fetfiles = kio._fp.read_filenames('fet')
          # Only basename.fet.0 and basename.fet.1 are valid; decoys such as
          # basename.fet.-1 / .1a / .a1 must be ignored.
          self.assertEqual(len(fetfiles), 2)
          self.assertEqual(os.path.abspath(fetfiles[0]),
                           os.path.abspath(os.path.join(self.dirname,
                                                        'basename.fet.0')))
          self.assertEqual(os.path.abspath(fetfiles[1]),
                           os.path.abspath(os.path.join(self.dirname,
                                                        'basename.fet.1')))

      def test2(self):
          """Tests that files are loaded even without basename"""
          # This behaviour is still in flux. The reader should probably
          # default to basename = os.path.split(dirname)[1] when dirname is a
          # directory. With two basenames present it would pick whichever
          # comes first, so the eventual check is only that some fet files
          # are found without error. Until that is settled, report the test
          # as skipped rather than as a vacuous pass.
          raise unittest.SkipTest('loading without a basename is not yet '
                                  'specified')

      def test3(self):
          """Tests that files can be loaded by basename2"""
          if not BaseTestIO.use_network:
              raise unittest.SkipTest("Requires download of data from the web")
          kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'basename2'))
          clufiles = kio._fp.read_filenames('clu')
          # basename2 only has group 1, so the single entry is looked up by
          # its group number (the result is presumably a mapping keyed by
          # group -- TODO confirm in FilenameParser.read_filenames).
          self.assertEqual(len(clufiles), 1)
          self.assertEqual(os.path.abspath(clufiles[1]),
                           os.path.abspath(os.path.join(self.dirname,
                                                        'basename2.clu.1')))
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  class testRead(unittest.TestCase):
      """Tests that data can be read from KlustaKwik files"""

      def setUp(self):
          data_dir = os.path.join(tempfile.gettempdir(),
                                  'files_for_testing_neo',
                                  'klustakwik/test2')
          self.dirname = data_dir
          if os.path.exists(data_dir):
              return
          raise unittest.SkipTest('data directory does not exist: ' + data_dir)

      def _check_train(self, train, name, cluster, group, times):
          """Assert the name, annotations, t_start and spike times of a train."""
          self.assertEqual(train.name, name)
          self.assertEqual(train.annotations['cluster'], cluster)
          self.assertEqual(train.annotations['group'], group)
          self.assertEqual(train.t_start, 0.0)
          self.assertTrue(np.all(train.times == np.array(times)))

      def test1(self):
          """Tests that data and metadata are read correctly"""
          reader = KlustaKwikIO(filename=os.path.join(self.dirname, 'base'),
                                sampling_rate=1000.)
          trains = reader.read()[0].segments[0].spiketrains
          self.assertEqual(len(trains), 4)
          for train in trains:
              self.assertEqual(train.units, np.array(1.0) * pq.s)
              self.assertEqual(train.t_start, 0.0)
          # (name, cluster, group, spike times in seconds) for each train,
          # in the order the reader returns them.
          expected = [
              ('unit 1 from group 0', 1, 0, [.100, .200]),
              ('unit 2 from group 0', 2, 0, [.305]),
              ('unit -1 from group 1', -1, 1, [.253]),
              ('unit 2 from group 1', 2, 1, [.050, .152]),
          ]
          for train, (name, cluster, group, times) in zip(trains, expected):
              self._check_train(train, name, cluster, group, times)

      def test2(self):
          """Checks that cluster id autosets to 0 without clu file"""
          reader = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
                                sampling_rate=1000.)
          trains = reader.read()[0].segments[0].spiketrains
          self.assertEqual(len(trains), 1)
          self._check_train(trains[0], 'unit 0 from group 5', 0, 5,
                            [0.026, 0.122, 0.228])
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  class testWrite(unittest.TestCase):
      """Tests that a block of spiketrains is written as clu and fet files."""

      def setUp(self):
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test3')
          if not os.path.exists(self.dirname):
              raise unittest.SkipTest('data directory does not exist: ' +
                                      self.dirname)

      def _read_ints(self, suffix):
          """Return every line of ``base1<suffix>`` in the test dir as an int."""
          with open(os.path.join(self.dirname, 'base1' + suffix),
                    mode='r') as f:
              return [int(line) for line in f]

      def test1(self):
          """Create clu and fet files based on spiketrains in a block.

          Checks that
          Files are created
          Converted to samples correctly
          Missing sampling rate is taken from IO reader default
          Spiketrains without cluster info are assigned to cluster 0
          Spiketrains across segments are concatenated
          """
          block = neo.Block()
          segment = neo.Segment()
          segment2 = neo.Segment()
          block.segments.append(segment)
          block.segments.append(segment2)
          # Fake spiketrain 1, will be sorted
          st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
          st1.annotations['cluster'] = 0
          st1.annotations['group'] = 0
          segment.spiketrains.append(st1)
          # Fake spiketrain 1B, on another segment. No group specified,
          # default is 0.
          st1B = neo.SpikeTrain(times=[.106], units='s', t_stop=1.)
          st1B.annotations['cluster'] = 0
          segment2.spiketrains.append(st1B)
          # Fake spiketrain 2 on same group, no sampling rate specified
          st2 = neo.SpikeTrain(times=[.001, .003, .011], units='s', t_stop=1.)
          st2.annotations['cluster'] = 1
          st2.annotations['group'] = 0
          segment.spiketrains.append(st2)
          # Fake spiketrain 3 on a new group (no sampling rate specified)
          st3 = neo.SpikeTrain(times=[.05, .09, .10], units='s', t_stop=1.)
          st3.annotations['cluster'] = -1
          st3.annotations['group'] = 1
          segment.spiketrains.append(st3)
          # Fake spiketrain 4 on new group, without cluster info
          st4 = neo.SpikeTrain(times=[.005, .009], units='s', t_stop=1.)
          st4.annotations['group'] = 2
          segment.spiketrains.append(st4)
          # Start from an empty output directory. Pass it explicitly so the
          # directory this test writes into is the one that gets cleared.
          # Register the same cleanup so it also runs if an assertion fails.
          delete_test_session(self.dirname)
          self.addCleanup(delete_test_session, self.dirname)
          # Create writer with default sampling rate
          kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base1'),
                             sampling_rate=1000.)
          kio.write_block(block)
          # Check files were created: one fet/clu pair for each group 0, 1, 2
          for fn in ['.fet.0', '.fet.1', '.fet.2',
                     '.clu.0', '.clu.1', '.clu.2']:
              self.assertTrue(os.path.exists(os.path.join(self.dirname,
                                                          'base1' + fn)))
          # Check files contain correct content. The leading value of each
          # fet file appears to be the feature count (0 here) and that of each
          # clu file the number of clusters -- TODO confirm against the IO.
          # Spike times on group 0
          self.assertEqual(self._read_ints('.fet.0'),
                           [0, 2, 4, 6, 1, 3, 11, 106])
          # Clusters on group 0
          self.assertEqual(self._read_ints('.clu.0'),
                           [2, 0, 0, 0, 1, 1, 1, 0])
          # Spike times on group 1
          self.assertEqual(self._read_ints('.fet.1'), [0, 50, 90, 100])
          # Clusters on group 1
          self.assertEqual(self._read_ints('.clu.1'), [1, -1, -1, -1])
          # Spike times on group 2
          self.assertEqual(self._read_ints('.fet.2'), [0, 5, 9])
          # Clusters on group 2
          self.assertEqual(self._read_ints('.clu.2'), [1, 0, 0])
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  class testWriteWithFeatures(unittest.TestCase):
      """Tests that waveform features survive a KlustaKwik write/read cycle."""

      def setUp(self):
          self.dirname = os.path.join(tempfile.gettempdir(),
                                      'files_for_testing_neo',
                                      'klustakwik/test4')
          # This test only writes files and needs no downloaded data, so
          # create its output directory instead of skipping. Previously the
          # test skipped here, which made its own mkdir unreachable.
          if not os.path.exists(self.dirname):
              os.makedirs(self.dirname)

      def test1(self):
          """Write a spiketrain carrying waveform features, then read it back.

          Checks that
          Files are created
          The fet file holds the feature count, then features and spike time
          (in samples) for each spike
          Cluster ids are written to the clu file
          Features read back from disk match those written
          """
          block = neo.Block()
          segment = neo.Segment()
          segment2 = neo.Segment()
          block.segments.append(segment)
          block.segments.append(segment2)
          # Fake spiketrain 1: three spikes, two features per spike
          st1 = neo.SpikeTrain(times=[.002, .004, .006], units='s', t_stop=1.)
          st1.annotations['cluster'] = 0
          st1.annotations['group'] = 0
          wff = np.array([
              [11.3, 0.2],
              [-0.3, 12.3],
              [3.0, -2.5]])
          st1.annotations['waveform_features'] = wff
          segment.spiketrains.append(st1)
          # Start from an empty directory and clean it up even on failure
          delete_test_session(self.dirname)
          self.addCleanup(delete_test_session, self.dirname)
          # Create writer
          kio = KlustaKwikIO(filename=os.path.join(self.dirname, 'base2'),
                             sampling_rate=1000.)
          kio.write_block(block)
          # Check files were created
          for fn in ['.fet.0', '.clu.0']:
              self.assertTrue(os.path.exists(os.path.join(self.dirname,
                                                          'base2' + fn)))
          # Check files contain correct content. Use open() in a with block:
          # the Python 2 builtin file() does not exist on Python 3 and the
          # handles must be closed.
          with open(os.path.join(self.dirname, 'base2.fet.0')) as fi:
              # first line is nbFeatures
              self.assertEqual(fi.readline(), '2\n')
              data = fi.readlines()
          # Each remaining line: feature values followed by the spike time
          new_wff = []
          new_times = []
          for line in data:
              values = line.split()
              new_wff.append([float(val) for val in values[:-1]])
              new_times.append(int(values[-1]))
          self.assertEqual(new_times, [2, 4, 6])
          assert_arrays_almost_equal(wff, np.array(new_wff), .00001)
          # Clusters on group 0
          with open(os.path.join(self.dirname, 'base2.clu.0')) as f:
              clusters = [int(d) for d in f]
          self.assertEqual(clusters, [1, 0, 0, 0])
          # Now read the features and test same
          block = kio.read_block()
          train = block.segments[0].spiketrains[0]
          assert_arrays_almost_equal(wff, train.annotations['waveform_features'],
                                     .00001)
  @unittest.skipUnless(HAVE_MLAB, "requires matplotlib")
  class CommonTests(BaseTestIO, unittest.TestCase):
      """Runs the shared BaseTestIO compliance checks against KlustaKwikIO."""

      ioclass = KlustaKwikIO

      # Basenames, relative to the local test-data directory, that the shared
      # BaseTestIO checks read and validate.
      files_to_test = ['test2/base', 'test2/base2']

      # Every data file used by this module's tests. BaseTestIO is expected
      # to fetch these from g-node when they are not already present locally
      # (presumably gated by BaseTestIO.use_network -- TODO confirm in
      # common_io_test).
      files_to_download = [
          # test1: two basenames plus decoys with malformed group numbers
          'test1/basename.clu.0',
          'test1/basename.fet.-1',
          'test1/basename.fet.0',
          'test1/basename.fet.1',
          'test1/basename.fet.1a',
          'test1/basename.fet.a1',
          'test1/basename2.clu.1',
          'test1/basename2.fet.1',
          'test1/basename2.fet.1a',
          # test2: read tests ('base' has clu files, 'base2' has none)
          'test2/base2.fet.5',
          'test2/base.clu.0',
          'test2/base.clu.1',
          'test2/base.fet.0',
          'test2/base.fet.1',
          # test3: directory used by the write test
          'test3/base1.clu.0',
          'test3/base1.clu.1',
          'test3/base1.clu.2',
          'test3/base1.fet.0',
          'test3/base1.fet.1',
          'test3/base1.fet.2',
      ]
  def delete_test_session(dirname=None):
      """Remove all files in a directory so writing can be tested from scratch.

      Parameters
      ----------
      dirname : str, optional
          Directory to empty. Defaults to the ``klustakwik/test3`` directory
          under the system temp dir, the same location the testWrite fixture
          writes to. (The old default pointed at a ``files_for_tests`` tree
          next to this module, which the fixtures no longer use.)

      Sub-directories are left in place. A missing directory is a no-op.
      """
      if dirname is None:
          dirname = os.path.join(tempfile.gettempdir(),
                                 'files_for_testing_neo',
                                 'klustakwik/test3')
      for path in glob.glob(os.path.join(dirname, '*')):
          # glob('*') also matches sub-directories, which os.remove cannot
          # delete; skip them rather than failing the whole cleanup.
          if os.path.isfile(path):
              os.remove(path)
  if __name__ == '__main__':
      # Running this file directly executes every test case defined above.
      unittest.main()