Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

test_brainwaresrcio.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. # -*- coding: utf-8 -*-
  2. """
  3. Tests of neo.io.brainwaresrcio
  4. """
  5. # needed for python 3 compatibility
  6. from __future__ import absolute_import, division, print_function
  7. import logging
  8. import os.path
  9. import sys
  10. try:
  11. import unittest2 as unittest
  12. except ImportError:
  13. import unittest
  14. import numpy as np
  15. import quantities as pq
  16. from neo.core import (Block, Event,
  17. ChannelIndex, Segment, SpikeTrain, Unit)
  18. from neo.io import BrainwareSrcIO, brainwaresrcio
  19. from neo.test.iotest.common_io_test import BaseTestIO
  20. from neo.test.tools import (assert_same_sub_schema,
  21. assert_neo_object_is_compliant)
  22. from neo.test.iotest.tools import create_generic_reader
  23. PY_VER = sys.version_info[0]
  24. FILES_TO_TEST = ['block_300ms_4rep_1clust_part_ch1.src',
  25. 'block_500ms_5rep_empty_fullclust_ch1.src',
  26. 'block_500ms_5rep_empty_partclust_ch1.src',
  27. 'interleaved_500ms_5rep_ch2.src',
  28. 'interleaved_500ms_5rep_nospikes_ch1.src',
  29. 'interleaved_500ms_7rep_noclust_ch1.src',
  30. 'long_170s_1rep_1clust_ch2.src',
  31. 'multi_500ms_mulitrep_ch1.src',
  32. 'random_500ms_12rep_noclust_part_ch2.src',
  33. 'sequence_500ms_5rep_ch2.src']
  34. FILES_TO_COMPARE = ['block_300ms_4rep_1clust_part_ch1',
  35. 'block_500ms_5rep_empty_fullclust_ch1',
  36. 'block_500ms_5rep_empty_partclust_ch1',
  37. 'interleaved_500ms_5rep_ch2',
  38. 'interleaved_500ms_5rep_nospikes_ch1',
  39. 'interleaved_500ms_7rep_noclust_ch1',
  40. '',
  41. 'multi_500ms_mulitrep_ch1',
  42. 'random_500ms_12rep_noclust_part_ch2',
  43. 'sequence_500ms_5rep_ch2']
  44. def proc_src(filename):
  45. '''Load an src file that has already been processed by the official matlab
  46. file converter. That matlab data is saved to an m-file, which is then
  47. converted to a numpy '.npz' file. This numpy file is the file actually
  48. loaded. This function converts it to a neo block and returns the block.
  49. This block can be compared to the block produced by BrainwareSrcIO to
  50. make sure BrainwareSrcIO is working properly
  51. block = proc_src(filename)
  52. filename: The file name of the numpy file to load. It should end with
  53. '*_src_py?.npz'. This will be converted to a neo 'file_origin' property
  54. with the value '*.src', so the filename to compare should fit that pattern.
  55. 'py?' should be 'py2' for the python 2 version of the numpy file or 'py3'
  56. for the python 3 version of the numpy file.
  57. example: filename = 'file1_src_py2.npz'
  58. src file name = 'file1.src'
  59. '''
  60. with np.load(filename) as srcobj:
  61. srcfile = srcobj.items()[0][1]
  62. filename = os.path.basename(filename[:-12]+'.src')
  63. block = Block(file_origin=filename)
  64. NChannels = srcfile['NChannels'][0, 0][0, 0]
  65. side = str(srcfile['side'][0, 0][0])
  66. ADperiod = srcfile['ADperiod'][0, 0][0, 0]
  67. comm_seg = proc_src_comments(srcfile, filename)
  68. block.segments.append(comm_seg)
  69. chx = proc_src_units(srcfile, filename)
  70. chan_nums = np.arange(NChannels, dtype='int')
  71. chan_names = ['Chan{}'.format(i) for i in range(NChannels)]
  72. chx.index = chan_nums
  73. chx.channel_names = np.array(chan_names, dtype='string_')
  74. block.channel_indexes.append(chx)
  75. for rep in srcfile['sets'][0, 0].flatten():
  76. proc_src_condition(rep, filename, ADperiod, side, block)
  77. block.create_many_to_one_relationship()
  78. return block
  79. def proc_src_comments(srcfile, filename):
  80. '''Get the comments in an src file that has been#!N
  81. processed by the official
  82. matlab function. See proc_src for details'''
  83. comm_seg = Segment(name='Comments', file_origin=filename)
  84. commentarray = srcfile['comments'].flatten()[0]
  85. senders = [res[0] for res in commentarray['sender'].flatten()]
  86. texts = [res[0] for res in commentarray['text'].flatten()]
  87. timeStamps = [res[0, 0] for res in commentarray['timeStamp'].flatten()]
  88. timeStamps = np.array(timeStamps, dtype=np.float32)
  89. t_start = timeStamps.min()
  90. timeStamps = pq.Quantity(timeStamps-t_start, units=pq.d).rescale(pq.s)
  91. texts = np.array(texts, dtype='S')
  92. senders = np.array(senders, dtype='S')
  93. t_start = brainwaresrcio.convert_brainwaresrc_timestamp(t_start.tolist())
  94. comments = Event(times=timeStamps, labels=texts, senders=senders)
  95. comm_seg.events = [comments]
  96. comm_seg.rec_datetime = t_start
  97. return comm_seg
  98. def proc_src_units(srcfile, filename):
  99. '''Get the units in an src file that has been processed by the official
  100. matlab function. See proc_src for details'''
  101. chx = ChannelIndex(file_origin=filename,
  102. index=np.array([], dtype=int))
  103. un_unit = Unit(name='UnassignedSpikes', file_origin=filename,
  104. elliptic=[], boundaries=[], timestamp=[], max_valid=[])
  105. chx.units.append(un_unit)
  106. sortInfo = srcfile['sortInfo'][0, 0]
  107. timeslice = sortInfo['timeslice'][0, 0]
  108. maxValid = timeslice['maxValid'][0, 0]
  109. cluster = timeslice['cluster'][0, 0]
  110. if len(cluster):
  111. maxValid = maxValid[0, 0]
  112. elliptic = [res.flatten() for res in cluster['elliptic'].flatten()]
  113. boundaries = [res.flatten() for res in cluster['boundaries'].flatten()]
  114. fullclust = zip(elliptic, boundaries)
  115. for ielliptic, iboundaries in fullclust:
  116. unit = Unit(file_origin=filename,
  117. boundaries=[iboundaries],
  118. elliptic=[ielliptic], timeStamp=[],
  119. max_valid=[maxValid])
  120. chx.units.append(unit)
  121. return chx
  122. def proc_src_condition(rep, filename, ADperiod, side, block):
  123. '''Get the condition in a src file that has been processed by the official
  124. matlab function. See proc_src for details'''
  125. chx = block.channel_indexes[0]
  126. stim = rep['stim'].flatten()
  127. params = [str(res[0]) for res in stim['paramName'][0].flatten()]
  128. values = [res for res in stim['paramVal'][0].flatten()]
  129. stim = dict(zip(params, values))
  130. sweepLen = rep['sweepLen'][0, 0]
  131. if not len(rep):
  132. return
  133. unassignedSpikes = rep['unassignedSpikes'].flatten()
  134. if len(unassignedSpikes):
  135. damaIndexes = [res[0, 0] for res in unassignedSpikes['damaIndex']]
  136. timeStamps = [res[0, 0] for res in unassignedSpikes['timeStamp']]
  137. spikeunit = [res.flatten() for res in unassignedSpikes['spikes']]
  138. respWin = np.array([], dtype=np.int32)
  139. trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
  140. respWin, damaIndexes, timeStamps,
  141. filename)
  142. chx.units[0].spiketrains.extend(trains)
  143. atrains = [trains]
  144. else:
  145. damaIndexes = []
  146. timeStamps = []
  147. atrains = []
  148. clusters = rep['clusters'].flatten()
  149. if len(clusters):
  150. IdStrings = [res[0] for res in clusters['IdString']]
  151. sweepLens = [res[0, 0] for res in clusters['sweepLen']]
  152. respWins = [res.flatten() for res in clusters['respWin']]
  153. spikeunits = []
  154. for cluster in clusters['sweeps']:
  155. if len(cluster):
  156. spikes = [res.flatten() for res in
  157. cluster['spikes'].flatten()]
  158. else:
  159. spikes = []
  160. spikeunits.append(spikes)
  161. else:
  162. IdStrings = []
  163. sweepLens = []
  164. respWins = []
  165. spikeunits = []
  166. for unit, IdString in zip(chx.units[1:], IdStrings):
  167. unit.name = str(IdString)
  168. fullunit = zip(spikeunits, chx.units[1:], sweepLens, respWins)
  169. for spikeunit, unit, sweepLen, respWin in fullunit:
  170. trains = proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod,
  171. respWin, damaIndexes, timeStamps,
  172. filename)
  173. atrains.append(trains)
  174. unit.spiketrains.extend(trains)
  175. atrains = zip(*atrains)
  176. for trains in atrains:
  177. segment = Segment(file_origin=filename, feature_type=-1,
  178. go_by_closest_unit_center=False,
  179. include_unit_bounds=False, **stim)
  180. block.segments.append(segment)
  181. segment.spiketrains = trains
  182. def proc_src_condition_unit(spikeunit, sweepLen, side, ADperiod, respWin,
  183. damaIndexes, timeStamps, filename):
  184. '''Get the unit in a condition in a src file that has been processed by
  185. the official matlab function. See proc_src for details'''
  186. if not damaIndexes:
  187. damaIndexes = [0]*len(spikeunit)
  188. timeStamps = [0]*len(spikeunit)
  189. trains = []
  190. for sweep, damaIndex, timeStamp in zip(spikeunit, damaIndexes,
  191. timeStamps):
  192. timeStamp = brainwaresrcio.convert_brainwaresrc_timestamp(timeStamp)
  193. train = proc_src_condition_unit_repetition(sweep, damaIndex,
  194. timeStamp, sweepLen,
  195. side, ADperiod, respWin,
  196. filename)
  197. trains.append(train)
  198. return trains
  199. def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen,
  200. side, ADperiod, respWin, filename):
  201. '''Get the repetion for a unit in a condition in a src file that has been
  202. processed by the official matlab function. See proc_src for details'''
  203. damaIndex = damaIndex.astype('int32')
  204. if len(sweep):
  205. times = np.array([res[0, 0] for res in sweep['time']])
  206. shapes = np.concatenate([res.flatten()[np.newaxis][np.newaxis] for res
  207. in sweep['shape']], axis=0)
  208. trig2 = np.array([res[0, 0] for res in sweep['trig2']])
  209. else:
  210. times = np.array([])
  211. shapes = np.array([[[]]])
  212. trig2 = np.array([])
  213. times = pq.Quantity(times, units=pq.ms, dtype=np.float32)
  214. t_start = pq.Quantity(0, units=pq.ms, dtype=np.float32)
  215. t_stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32)
  216. trig2 = pq.Quantity(trig2, units=pq.ms, dtype=np.uint8)
  217. waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV)
  218. sampling_period = pq.Quantity(ADperiod, units=pq.us)
  219. train = SpikeTrain(times=times, t_start=t_start, t_stop=t_stop,
  220. trig2=trig2, dtype=np.float32, timestamp=timeStamp,
  221. dama_index=damaIndex, side=side, copy=True,
  222. respwin=respWin, waveforms=waveforms,
  223. file_origin=filename)
  224. train.annotations['side'] = side
  225. train.sampling_period = sampling_period
  226. return train
  227. class BrainwareSrcIOTestCase(BaseTestIO, unittest.TestCase):
  228. '''
  229. Unit test testcase for neo.io.BrainwareSrcIO
  230. '''
  231. ioclass = BrainwareSrcIO
  232. read_and_write_is_bijective = False
  233. # These are the files it tries to read and test for compliance
  234. files_to_test = FILES_TO_TEST
  235. # these are reference files to compare to
  236. files_to_compare = FILES_TO_COMPARE
  237. # add the appropriate suffix depending on the python version
  238. for i, fname in enumerate(files_to_compare):
  239. if fname:
  240. files_to_compare[i] += '_src_py%s.npz' % PY_VER
  241. # Will fetch from g-node if they don't already exist locally
  242. # How does it know to do this before any of the other tests?
  243. files_to_download = files_to_test + files_to_compare
  244. def setUp(self):
  245. super(BrainwareSrcIOTestCase, self).setUp()
  246. def test_reading_same(self):
  247. for ioobj, path in self.iter_io_objects(return_path=True):
  248. obj_reader_all = create_generic_reader(ioobj, readall=True)
  249. obj_reader_base = create_generic_reader(ioobj, target=False)
  250. obj_reader_next = create_generic_reader(ioobj, target='next_block')
  251. obj_reader_single = create_generic_reader(ioobj)
  252. obj_all = obj_reader_all()
  253. obj_base = obj_reader_base()
  254. obj_single = obj_reader_single()
  255. obj_next = [obj_reader_next()]
  256. while ioobj._isopen:
  257. obj_next.append(obj_reader_next())
  258. try:
  259. assert_same_sub_schema(obj_all[0], obj_base)
  260. assert_same_sub_schema(obj_all[0], obj_single)
  261. assert_same_sub_schema(obj_all, obj_next)
  262. except BaseException as exc:
  263. exc.args += ('from ' + os.path.basename(path),)
  264. raise
  265. self.assertEqual(len(obj_all), len(obj_next))
  266. def test_against_reference(self):
  267. for filename, refname in zip(self.files_to_test,
  268. self.files_to_compare):
  269. if not refname:
  270. continue
  271. obj = self.read_file(filename=filename, readall=True)[0]
  272. refobj = proc_src(self.get_filename_path(refname))
  273. try:
  274. assert_neo_object_is_compliant(obj)
  275. assert_neo_object_is_compliant(refobj)
  276. assert_same_sub_schema(obj, refobj)
  277. except BaseException as exc:
  278. exc.args += ('from ' + filename,)
  279. raise
  280. if __name__ == '__main__':
  281. logger = logging.getLogger(BrainwareSrcIO.__module__ +
  282. '.' +
  283. BrainwareSrcIO.__name__)
  284. logger.setLevel(100)
  285. unittest.main()