# test_brainwaresrcio.py
"""
Tests of neo.io.brainwaresrcio
"""

import logging
import os.path
import unittest

import numpy as np
import quantities as pq

from neo.core import (Block, Event,
                      Group, Segment, SpikeTrain)
from neo.io import BrainwareSrcIO, brainwaresrcio
from neo.test.iotest.common_io_test import BaseTestIO
from neo.test.tools import (assert_same_sub_schema,
                            assert_neo_object_is_compliant)
from neo.test.iotest.tools import create_generic_reader
  16. FILES_TO_TEST = ['block_300ms_4rep_1clust_part_ch1.src',
  17. 'block_500ms_5rep_empty_fullclust_ch1.src',
  18. 'block_500ms_5rep_empty_partclust_ch1.src',
  19. 'interleaved_500ms_5rep_ch2.src',
  20. 'interleaved_500ms_5rep_nospikes_ch1.src',
  21. 'interleaved_500ms_7rep_noclust_ch1.src',
  22. 'long_170s_1rep_1clust_ch2.src',
  23. 'multi_500ms_mulitrep_ch1.src',
  24. 'random_500ms_12rep_noclust_part_ch2.src',
  25. 'sequence_500ms_5rep_ch2.src']
  26. FILES_TO_COMPARE = ['block_300ms_4rep_1clust_part_ch1',
  27. 'block_500ms_5rep_empty_fullclust_ch1',
  28. 'block_500ms_5rep_empty_partclust_ch1',
  29. 'interleaved_500ms_5rep_ch2',
  30. 'interleaved_500ms_5rep_nospikes_ch1',
  31. 'interleaved_500ms_7rep_noclust_ch1',
  32. '',
  33. 'multi_500ms_mulitrep_ch1',
  34. 'random_500ms_12rep_noclust_part_ch2',
  35. 'sequence_500ms_5rep_ch2']
def proc_src(filename):
    '''Load an src file that has already been processed by the official matlab
    file converter.  That matlab data is saved to an m-file, which is then
    converted to a numpy '.npz' file.  This numpy file is the file actually
    loaded.  This function converts it to a neo block and returns the block.
    This block can be compared to the block produced by BrainwareSrcIO to
    make sure BrainwareSrcIO is working properly

    block = proc_src(filename)

    filename: The file name of the numpy file to load.  It should end with
    '*_src_py?.npz'.  This will be converted to a neo 'file_origin' property
    with the value '*.src', so the filename to compare should fit that
    pattern.  'py?' should be 'py2' for the python 2 version of the numpy
    file or 'py3' for the python 3 version of the numpy file.

    example: filename = 'file1_src_py2.npz'
             src file name = 'file1.src'
    '''
    # the .npz holds a single entry: the converted matlab struct
    with np.load(filename, allow_pickle=True) as srcobj:
        srcfile = list(srcobj.items())[0][1]

    # strip the 12-character '_src_py?.npz' suffix and restore '.src'
    filename = os.path.basename(filename[:-12] + '.src')
    block = Block(file_origin=filename)

    # scalar metadata from the matlab struct (NChannels is read but unused)
    NChannels = srcfile['NChannels'][0, 0][0, 0]
    side = str(srcfile['side'][0, 0][0])
    ADperiod = srcfile['ADperiod'][0, 0][0, 0]

    # comments go into their own segment
    comm_seg = proc_src_comments(srcfile, filename)
    block.segments.append(comm_seg)

    # one Group per sorted unit, plus one for unassigned spikes
    all_units = proc_src_units(srcfile, filename)
    block.groups.extend(all_units)

    # each 'set' is one stimulus condition; processing mutates block in place
    for rep in srcfile['sets'][0, 0].flatten():
        proc_src_condition(rep, filename, ADperiod, side, block)

    block.create_many_to_one_relationship()

    return block
def proc_src_comments(srcfile, filename):
    '''Get the comments in an src file that has been processed by the
    official matlab function.  See proc_src for details'''
    comm_seg = Segment(name='Comments', file_origin=filename)
    commentarray = srcfile['comments'].flatten()[0]
    senders = [res[0] for res in commentarray['sender'].flatten()]
    texts = [res[0] for res in commentarray['text'].flatten()]
    timeStamps = [res[0, 0] for res in commentarray['timeStamp'].flatten()]

    # timestamps are stored in days (presumably matlab datenum values --
    # confirm against the converter); make them relative to the earliest
    # comment and rescale to seconds
    timeStamps = np.array(timeStamps, dtype=np.float32)
    t_start = timeStamps.min()
    timeStamps = pq.Quantity(timeStamps - t_start, units=pq.d).rescale(pq.s)
    texts = np.array(texts, dtype='U')
    senders = np.array(senders, dtype='S')

    # the earliest timestamp becomes the segment's recording datetime
    t_start = brainwaresrcio.convert_brainwaresrc_timestamp(t_start.tolist())

    comments = Event(times=timeStamps, labels=texts, senders=senders)
    comm_seg.events = [comments]
    comm_seg.rec_datetime = t_start
    return comm_seg
  86. def proc_src_units(srcfile, filename):
  87. '''Get the units in an src file that has been processed by the official
  88. matlab function. See proc_src for details'''
  89. all_units = []
  90. un_unit = Group(name='UnassignedSpikes', file_origin=filename,
  91. elliptic=[], boundaries=[], timestamp=[], max_valid=[])
  92. all_units.append(un_unit)
  93. sortInfo = srcfile['sortInfo'][0, 0]
  94. timeslice = sortInfo['timeslice'][0, 0]
  95. maxValid = timeslice['maxValid'][0, 0]
  96. cluster = timeslice['cluster'][0, 0]
  97. if len(cluster):
  98. maxValid = maxValid[0, 0]
  99. elliptic = [res.flatten() for res in cluster['elliptic'].flatten()]
  100. boundaries = [res.flatten() for res in cluster['boundaries'].flatten()]
  101. fullclust = zip(elliptic, boundaries)
  102. for ielliptic, iboundaries in fullclust:
  103. unit = Group(file_origin=filename,
  104. boundaries=[iboundaries],
  105. elliptic=[ielliptic], timeStamp=[],
  106. max_valid=[maxValid])
  107. all_units.append(unit)
  108. return all_units
def proc_src_condition(rep, filename, ADperiod, side, block):
    '''Get the condition in a src file that has been processed by the official
    matlab function.  See proc_src for details.

    Mutates `block` in place: extends the groups' spiketrains and appends
    one Segment per repetition.
    '''
    # stimulus parameters become a name -> value dict used as segment
    # annotations
    stim = rep['stim'].flatten()
    params = [str(res[0]) for res in stim['paramName'][0].flatten()]
    values = [res for res in stim['paramVal'][0].flatten()]
    stim = dict(zip(params, values))
    sweepLen = rep['sweepLen'][0, 0]

    if not len(rep):
        return

    # spikes not assigned to any cluster go to the first group
    unassignedSpikes = rep['unassignedSpikes'].flatten()
    if len(unassignedSpikes):
        damaIndexes = [res[0, 0] for res in unassignedSpikes['damaIndex']]
        timeStamps = [res[0, 0] for res in unassignedSpikes['timeStamp']]
        spikeunit = [res.flatten() for res in unassignedSpikes['spikes']]
        respWin = np.array([], dtype=np.int32)
        trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
                                         respWin, damaIndexes, timeStamps,
                                         filename)
        block.groups[0].spiketrains.extend(trains)
        atrains = [trains]
    else:
        damaIndexes = []
        timeStamps = []
        atrains = []

    # clustered spikes: one entry per sorted unit
    clusters = rep['clusters'].flatten()
    if len(clusters):
        IdStrings = [res[0] for res in clusters['IdString']]
        sweepLens = [res[0, 0] for res in clusters['sweepLen']]
        respWins = [res.flatten() for res in clusters['respWin']]
        spikeunits = []
        for cluster in clusters['sweeps']:
            if len(cluster):
                spikes = [res.flatten() for res in
                          cluster['spikes'].flatten()]
            else:
                spikes = []
            spikeunits.append(spikes)
    else:
        IdStrings = []
        sweepLens = []
        respWins = []
        spikeunits = []

    # groups[1:] are the sorted units (groups[0] is UnassignedSpikes);
    # name them after the cluster id strings
    for unit, IdString in zip(block.groups[1:], IdStrings):
        unit.name = str(IdString)

    fullunit = zip(spikeunits, block.groups[1:], sweepLens, respWins)
    for spikeunit, unit, sweepLen, respWin in fullunit:
        trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
                                         respWin, damaIndexes, timeStamps,
                                         filename)
        atrains.append(trains)
        unit.spiketrains.extend(trains)

    # transpose unit-major into repetition-major: one Segment per
    # repetition, containing one train per unit
    atrains = zip(*atrains)
    for trains in atrains:
        segment = Segment(file_origin=filename, feature_type=-1,
                          go_by_closest_unit_center=False,
                          include_unit_bounds=False, **stim)
        block.segments.append(segment)
        segment.spiketrains = trains
  168. def proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod, respWin,
  169. damaIndexes, timeStamps, filename):
  170. '''Get the unit in a condition in a src file that has been processed by
  171. the official matlab function. See proc_src for details'''
  172. if not damaIndexes:
  173. damaIndexes = [0] * len(spikeunit)
  174. timeStamps = [0] * len(spikeunit)
  175. trains = []
  176. for sweep, damaIndex, timeStamp in zip(spikeunit, damaIndexes,
  177. timeStamps):
  178. timeStamp = brainwaresrcio.convert_brainwaresrc_timestamp(timeStamp)
  179. train = proc_src_condition_unit_repetition(sweep, damaIndex,
  180. timeStamp, sweepLen,
  181. side, ADperiod, respWin,
  182. filename)
  183. trains.append(train)
  184. return trains
def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen,
                                       side, ADperiod, respWin, filename):
    '''Get the repetition for a unit in a condition in a src file that has
    been processed by the official matlab function.  See proc_src for
    details.  Returns a single SpikeTrain.
    '''
    # NOTE(review): assumes damaIndex has an .astype method (numpy scalar);
    # the plain-int padding path in proc_src_condition_unit would not --
    # confirm it only occurs with empty sweeps
    damaIndex = damaIndex.astype('int32')
    if len(sweep):
        # per-spike scalar time, 2D waveform shape, and trig2 value
        times = np.array([res[0, 0] for res in sweep['time']])
        shapes = np.concatenate([res.flatten()[np.newaxis][np.newaxis] for res
                                 in sweep['shape']], axis=0)
        trig2 = np.array([res[0, 0] for res in sweep['trig2']])
    else:
        # no spikes in this sweep: empty arrays with matching ranks
        times = np.array([])
        shapes = np.array([[[]]])
        trig2 = np.array([])

    # times are in milliseconds; the sweep spans [0, sweepLen] ms
    times = pq.Quantity(times, units=pq.ms, dtype=np.float32)
    t_start = pq.Quantity(0, units=pq.ms, dtype=np.float32)
    t_stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32)
    trig2 = pq.Quantity(trig2, units=pq.ms, dtype=np.uint8)
    waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV)
    # ADperiod is in microseconds
    sampling_period = pq.Quantity(ADperiod, units=pq.us)

    train = SpikeTrain(times=times, t_start=t_start, t_stop=t_stop,
                       trig2=trig2, dtype=np.float32, timestamp=timeStamp,
                       dama_index=damaIndex, side=side, copy=True,
                       respwin=respWin, waveforms=waveforms,
                       file_origin=filename)
    train.annotations['side'] = side
    train.sampling_period = sampling_period
    return train
  213. class BrainwareSrcIOTestCase(BaseTestIO, unittest.TestCase):
  214. '''
  215. Unit test testcase for neo.io.BrainwareSrcIO
  216. '''
  217. ioclass = BrainwareSrcIO
  218. read_and_write_is_bijective = False
  219. # These are the files it tries to read and test for compliance
  220. files_to_test = FILES_TO_TEST
  221. # these are reference files to compare to
  222. files_to_compare = FILES_TO_COMPARE
  223. # add the suffix
  224. for i, fname in enumerate(files_to_compare):
  225. if fname:
  226. files_to_compare[i] += '_src_py3.npz'
  227. # Will fetch from g-node if they don't already exist locally
  228. # How does it know to do this before any of the other tests?
  229. files_to_download = files_to_test + files_to_compare
  230. def setUp(self):
  231. super().setUp()
  232. def test_reading_same(self):
  233. for ioobj, path in self.iter_io_objects(return_path=True):
  234. obj_reader_all = create_generic_reader(ioobj, readall=True)
  235. obj_reader_base = create_generic_reader(ioobj, target=False)
  236. obj_reader_next = create_generic_reader(ioobj, target='next_block')
  237. obj_reader_single = create_generic_reader(ioobj)
  238. obj_all = obj_reader_all()
  239. obj_base = obj_reader_base()
  240. obj_single = obj_reader_single()
  241. obj_next = [obj_reader_next()]
  242. while ioobj._isopen:
  243. obj_next.append(obj_reader_next())
  244. try:
  245. assert_same_sub_schema(obj_all, obj_base)
  246. assert_same_sub_schema(obj_all[0], obj_single)
  247. assert_same_sub_schema(obj_all, obj_next)
  248. except BaseException as exc:
  249. exc.args += ('from ' + os.path.basename(path),)
  250. raise
  251. self.assertEqual(len(obj_all), len(obj_next))
  252. def test_against_reference(self):
  253. for filename, refname in zip(self.files_to_test,
  254. self.files_to_compare):
  255. if not refname:
  256. continue
  257. obj = self.read_file(filename=filename, readall=True)[0]
  258. refobj = proc_src(self.get_filename_path(refname))
  259. try:
  260. assert_neo_object_is_compliant(obj)
  261. assert_neo_object_is_compliant(refobj)
  262. #assert_same_sub_schema(obj, refobj) # commented out until IO is adapted to use Group
  263. except BaseException as exc:
  264. exc.args += ('from ' + filename,)
  265. raise
  266. if __name__ == '__main__':
  267. logger = logging.getLogger(BrainwareSrcIO.__module__ +
  268. '.' +
  269. BrainwareSrcIO.__name__)
  270. logger.setLevel(100)
  271. unittest.main()