test_brainwaresrcio.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.brainwaresrcio
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division, print_function
  7. import logging
  8. import os.path
  9. import sys
  10. import unittest
  11. import numpy as np
  12. import quantities as pq
  13. from neo.core import (Block, Event,
  14. ChannelIndex, Segment, SpikeTrain, Unit)
  15. from neo.io import BrainwareSrcIO, brainwaresrcio
  16. from neo.test.iotest.common_io_test import BaseTestIO
  17. from neo.test.tools import (assert_same_sub_schema,
  18. assert_neo_object_is_compliant)
  19. from neo.test.iotest.tools import create_generic_reader
  20. PY_VER = sys.version_info[0]
  21. FILES_TO_TEST = ['block_300ms_4rep_1clust_part_ch1.src',
  22. 'block_500ms_5rep_empty_fullclust_ch1.src',
  23. 'block_500ms_5rep_empty_partclust_ch1.src',
  24. 'interleaved_500ms_5rep_ch2.src',
  25. 'interleaved_500ms_5rep_nospikes_ch1.src',
  26. 'interleaved_500ms_7rep_noclust_ch1.src',
  27. 'long_170s_1rep_1clust_ch2.src',
  28. 'multi_500ms_mulitrep_ch1.src',
  29. 'random_500ms_12rep_noclust_part_ch2.src',
  30. 'sequence_500ms_5rep_ch2.src']
  31. FILES_TO_COMPARE = ['block_300ms_4rep_1clust_part_ch1',
  32. 'block_500ms_5rep_empty_fullclust_ch1',
  33. 'block_500ms_5rep_empty_partclust_ch1',
  34. 'interleaved_500ms_5rep_ch2',
  35. 'interleaved_500ms_5rep_nospikes_ch1',
  36. 'interleaved_500ms_7rep_noclust_ch1',
  37. '',
  38. 'multi_500ms_mulitrep_ch1',
  39. 'random_500ms_12rep_noclust_part_ch2',
  40. 'sequence_500ms_5rep_ch2']
  41. def proc_src(filename):
  42. '''Load an src file that has already been processed by the official matlab
  43. file converter. That matlab data is saved to an m-file, which is then
  44. converted to a numpy '.npz' file. This numpy file is the file actually
  45. loaded. This function converts it to a neo block and returns the block.
  46. This block can be compared to the block produced by BrainwareSrcIO to
  47. make sure BrainwareSrcIO is working properly
  48. block = proc_src(filename)
  49. filename: The file name of the numpy file to load. It should end with
  50. '*_src_py?.npz'. This will be converted to a neo 'file_origin' property
  51. with the value '*.src', so the filename to compare should fit that pattern.
  52. 'py?' should be 'py2' for the python 2 version of the numpy file or 'py3'
  53. for the python 3 version of the numpy file.
  54. example: filename = 'file1_src_py2.npz'
  55. src file name = 'file1.src'
  56. '''
  57. with np.load(filename) as srcobj:
  58. srcfile = srcobj.items()[0][1]
  59. filename = os.path.basename(filename[:-12]+'.src')
  60. block = Block(file_origin=filename)
  61. NChannels = srcfile['NChannels'][0, 0][0, 0]
  62. side = str(srcfile['side'][0, 0][0])
  63. ADperiod = srcfile['ADperiod'][0, 0][0, 0]
  64. comm_seg = proc_src_comments(srcfile, filename)
  65. block.segments.append(comm_seg)
  66. chx = proc_src_units(srcfile, filename)
  67. chan_nums = np.arange(NChannels, dtype='int')
  68. chan_names = ['Chan{}'.format(i) for i in range(NChannels)]
  69. chx.index = chan_nums
  70. chx.channel_names = np.array(chan_names, dtype='string_')
  71. block.channel_indexes.append(chx)
  72. for rep in srcfile['sets'][0, 0].flatten():
  73. proc_src_condition(rep, filename, ADperiod, side, block)
  74. block.create_many_to_one_relationship()
  75. return block
  76. def proc_src_comments(srcfile, filename):
  77. '''Get the comments in an src file that has been#!N
  78. processed by the official
  79. matlab function. See proc_src for details'''
  80. comm_seg = Segment(name='Comments', file_origin=filename)
  81. commentarray = srcfile['comments'].flatten()[0]
  82. senders = [res[0] for res in commentarray['sender'].flatten()]
  83. texts = [res[0] for res in commentarray['text'].flatten()]
  84. timeStamps = [res[0, 0] for res in commentarray['timeStamp'].flatten()]
  85. timeStamps = np.array(timeStamps, dtype=np.float32)
  86. t_start = timeStamps.min()
  87. timeStamps = pq.Quantity(timeStamps-t_start, units=pq.d).rescale(pq.s)
  88. texts = np.array(texts, dtype='S')
  89. senders = np.array(senders, dtype='S')
  90. t_start = brainwaresrcio.convert_brainwaresrc_timestamp(t_start.tolist())
  91. comments = Event(times=timeStamps, labels=texts, senders=senders)
  92. comm_seg.events = [comments]
  93. comm_seg.rec_datetime = t_start
  94. return comm_seg
  95. def proc_src_units(srcfile, filename):
  96. '''Get the units in an src file that has been processed by the official
  97. matlab function. See proc_src for details'''
  98. chx = ChannelIndex(file_origin=filename,
  99. index=np.array([], dtype=int))
  100. un_unit = Unit(name='UnassignedSpikes', file_origin=filename,
  101. elliptic=[], boundaries=[], timestamp=[], max_valid=[])
  102. chx.units.append(un_unit)
  103. sortInfo = srcfile['sortInfo'][0, 0]
  104. timeslice = sortInfo['timeslice'][0, 0]
  105. maxValid = timeslice['maxValid'][0, 0]
  106. cluster = timeslice['cluster'][0, 0]
  107. if len(cluster):
  108. maxValid = maxValid[0, 0]
  109. elliptic = [res.flatten() for res in cluster['elliptic'].flatten()]
  110. boundaries = [res.flatten() for res in cluster['boundaries'].flatten()]
  111. fullclust = zip(elliptic, boundaries)
  112. for ielliptic, iboundaries in fullclust:
  113. unit = Unit(file_origin=filename,
  114. boundaries=[iboundaries],
  115. elliptic=[ielliptic], timeStamp=[],
  116. max_valid=[maxValid])
  117. chx.units.append(unit)
  118. return chx
  119. def proc_src_condition(rep, filename, ADperiod, side, block):
  120. '''Get the condition in a src file that has been processed by the official
  121. matlab function. See proc_src for details'''
  122. chx = block.channel_indexes[0]
  123. stim = rep['stim'].flatten()
  124. params = [str(res[0]) for res in stim['paramName'][0].flatten()]
  125. values = [res for res in stim['paramVal'][0].flatten()]
  126. stim = dict(zip(params, values))
  127. sweepLen = rep['sweepLen'][0, 0]
  128. if not len(rep):
  129. return
  130. unassignedSpikes = rep['unassignedSpikes'].flatten()
  131. if len(unassignedSpikes):
  132. damaIndexes = [res[0, 0] for res in unassignedSpikes['damaIndex']]
  133. timeStamps = [res[0, 0] for res in unassignedSpikes['timeStamp']]
  134. spikeunit = [res.flatten() for res in unassignedSpikes['spikes']]
  135. respWin = np.array([], dtype=np.int32)
  136. trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
  137. respWin, damaIndexes, timeStamps,
  138. filename)
  139. chx.units[0].spiketrains.extend(trains)
  140. atrains = [trains]
  141. else:
  142. damaIndexes = []
  143. timeStamps = []
  144. atrains = []
  145. clusters = rep['clusters'].flatten()
  146. if len(clusters):
  147. IdStrings = [res[0] for res in clusters['IdString']]
  148. sweepLens = [res[0, 0] for res in clusters['sweepLen']]
  149. respWins = [res.flatten() for res in clusters['respWin']]
  150. spikeunits = []
  151. for cluster in clusters['sweeps']:
  152. if len(cluster):
  153. spikes = [res.flatten() for res in
  154. cluster['spikes'].flatten()]
  155. else:
  156. spikes = []
  157. spikeunits.append(spikes)
  158. else:
  159. IdStrings = []
  160. sweepLens = []
  161. respWins = []
  162. spikeunits = []
  163. for unit, IdString in zip(chx.units[1:], IdStrings):
  164. unit.name = str(IdString)
  165. fullunit = zip(spikeunits, chx.units[1:], sweepLens, respWins)
  166. for spikeunit, unit, sweepLen, respWin in fullunit:
  167. trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
  168. respWin, damaIndexes, timeStamps,
  169. filename)
  170. atrains.append(trains)
  171. unit.spiketrains.extend(trains)
  172. atrains = zip(*atrains)
  173. for trains in atrains:
  174. segment = Segment(file_origin=filename, feature_type=-1,
  175. go_by_closest_unit_center=False,
  176. include_unit_bounds=False, **stim)
  177. block.segments.append(segment)
  178. segment.spiketrains = trains
  179. def proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod, respWin,
  180. damaIndexes, timeStamps, filename):
  181. '''Get the unit in a condition in a src file that has been processed by
  182. the official matlab function. See proc_src for details'''
  183. if not damaIndexes:
  184. damaIndexes = [0]*len(spikeunit)
  185. timeStamps = [0]*len(spikeunit)
  186. trains = []
  187. for sweep, damaIndex, timeStamp in zip(spikeunit, damaIndexes,
  188. timeStamps):
  189. timeStamp = brainwaresrcio.convert_brainwaresrc_timestamp(timeStamp)
  190. train = proc_src_condition_unit_repetition(sweep, damaIndex,
  191. timeStamp, sweepLen,
  192. side, ADperiod, respWin,
  193. filename)
  194. trains.append(train)
  195. return trains
  196. def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen,
  197. side, ADperiod, respWin, filename):
  198. '''Get the repetion for a unit in a condition in a src file that has been
  199. processed by the official matlab function. See proc_src for details'''
  200. damaIndex = damaIndex.astype('int32')
  201. if len(sweep):
  202. times = np.array([res[0, 0] for res in sweep['time']])
  203. shapes = np.concatenate([res.flatten()[np.newaxis][np.newaxis] for res
  204. in sweep['shape']], axis=0)
  205. trig2 = np.array([res[0, 0] for res in sweep['trig2']])
  206. else:
  207. times = np.array([])
  208. shapes = np.array([[[]]])
  209. trig2 = np.array([])
  210. times = pq.Quantity(times, units=pq.ms, dtype=np.float32)
  211. t_start = pq.Quantity(0, units=pq.ms, dtype=np.float32)
  212. t_stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32)
  213. trig2 = pq.Quantity(trig2, units=pq.ms, dtype=np.uint8)
  214. waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV)
  215. sampling_period = pq.Quantity(ADperiod, units=pq.us)
  216. train = SpikeTrain(times=times, t_start=t_start, t_stop=t_stop,
  217. trig2=trig2, dtype=np.float32, timestamp=timeStamp,
  218. dama_index=damaIndex, side=side, copy=True,
  219. respwin=respWin, waveforms=waveforms,
  220. file_origin=filename)
  221. train.annotations['side'] = side
  222. train.sampling_period = sampling_period
  223. return train
  224. class BrainwareSrcIOTestCase(BaseTestIO, unittest.TestCase):
  225. '''
  226. Unit test testcase for neo.io.BrainwareSrcIO
  227. '''
  228. ioclass = BrainwareSrcIO
  229. read_and_write_is_bijective = False
  230. # These are the files it tries to read and test for compliance
  231. files_to_test = FILES_TO_TEST
  232. # these are reference files to compare to
  233. files_to_compare = FILES_TO_COMPARE
  234. # add the appropriate suffix depending on the python version
  235. for i, fname in enumerate(files_to_compare):
  236. if fname:
  237. files_to_compare[i] += '_src_py%s.npz' % PY_VER
  238. # Will fetch from g-node if they don't already exist locally
  239. # How does it know to do this before any of the other tests?
  240. files_to_download = files_to_test + files_to_compare
  241. def setUp(self):
  242. super(BrainwareSrcIOTestCase, self).setUp()
  243. def test_reading_same(self):
  244. for ioobj, path in self.iter_io_objects(return_path=True):
  245. obj_reader_all = create_generic_reader(ioobj, readall=True)
  246. obj_reader_base = create_generic_reader(ioobj, target=False)
  247. obj_reader_next = create_generic_reader(ioobj, target='next_block')
  248. obj_reader_single = create_generic_reader(ioobj)
  249. obj_all = obj_reader_all()
  250. obj_base = obj_reader_base()
  251. obj_single = obj_reader_single()
  252. obj_next = [obj_reader_next()]
  253. while ioobj._isopen:
  254. obj_next.append(obj_reader_next())
  255. try:
  256. assert_same_sub_schema(obj_all[0], obj_base)
  257. assert_same_sub_schema(obj_all[0], obj_single)
  258. assert_same_sub_schema(obj_all, obj_next)
  259. except BaseException as exc:
  260. exc.args += ('from ' + os.path.basename(path),)
  261. raise
  262. self.assertEqual(len(obj_all), len(obj_next))
  263. def test_against_reference(self):
  264. for filename, refname in zip(self.files_to_test,
  265. self.files_to_compare):
  266. if not refname:
  267. continue
  268. obj = self.read_file(filename=filename, readall=True)[0]
  269. refobj = proc_src(self.get_filename_path(refname))
  270. try:
  271. assert_neo_object_is_compliant(obj)
  272. assert_neo_object_is_compliant(refobj)
  273. assert_same_sub_schema(obj, refobj)
  274. except BaseException as exc:
  275. exc.args += ('from ' + filename,)
  276. raise
  277. if __name__ == '__main__':
  278. logger = logging.getLogger(BrainwareSrcIO.__module__ +
  279. '.' +
  280. BrainwareSrcIO.__name__)
  281. logger.setLevel(100)
  282. unittest.main()