# -*- coding: utf-8 -*-
"""
Class for reading data from CED Spike2 files (.smr).

This code is based on:
 - sonpy, written by Antonio Gonzalez <Antonio.Gonzalez@cantab.net>
   Available here:
   http://www.neuro.ki.se/broberger/

sonpy itself comes from:
 - SON Library 2.0 for MATLAB, written by Malcolm Lidierth at
   King's College London.
   See http://www.kcl.ac.uk/depsta/biomedical/cfnr/lidierth.html

This IO supports both old (<v6) and new (>v7) Spike2 files.

Depends on:

Supported: Read

Author: sgarcia
"""

import os
import sys

import numpy as np
import quantities as pq

from neo.io.baseio import BaseIO
from neo.core import Segment, AnalogSignal, SpikeTrain, Event

PY3K = (sys.version_info[0] == 3)


class Spike2IO(BaseIO):
    """
    Class for reading data from CED Spike2.

    Usage:
        >>> from neo import io
        >>> r = io.Spike2IO(filename='File_spike2_1.smr')
        >>> seg = r.read_segment(lazy=False, cascade=True)
        >>> print(seg.analogsignals)
        >>> print(seg.spiketrains)
        >>> print(seg.events)
    """

    is_readable = True
    is_writable = False

    supported_objects = [Segment, AnalogSignal, Event, SpikeTrain]
    readable_objects = [Segment]
    writeable_objects = []

    has_header = False
    is_streameable = False

    read_params = {Segment: [('take_ideal_sampling_rate', {'value': False})]}
    write_params = None

    name = 'Spike 2 CED'
    extensions = ['smr']

    mode = 'file'

    ced_units = False

    def __init__(self, filename=None, ced_units=False):
        """
        This class reads an smr file.

        Arguments:
            filename : the filename
            ced_units: whether spike trains should be added for each unit
                as determined by Spike2's spike sorting (True), or if a spike
                channel should be considered a single unit, ignoring
                Spike2's spike sorting (False). Defaults to False.
        """
        BaseIO.__init__(self)
        self.filename = filename
        self.ced_units = ced_units
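
    # Example (sketch, with a placeholder filename): read per-unit spike
    # trains as sorted by Spike2.
    #   r = Spike2IO(filename='my_file.smr', ced_units=True)
    #   seg = r.read_segment()
    #   for st in seg.spiketrains:
    #       print(st.annotations['ced_unit'], len(st))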

    def read_segment(self, take_ideal_sampling_rate=False,
                     lazy=False, cascade=True):
        """
        Arguments:
            take_ideal_sampling_rate: if True, use the channel's ideal_rate
                field instead of the rate computed from the clock settings.
            lazy: if True, return empty data objects annotated with their
                lazy_shape instead of loading the data.
            cascade: if False, return an empty Segment without reading the
                channels.
        """
        header = self.read_header(filename=self.filename)

        #~ print header
        fid = open(self.filename, 'rb')

        seg = Segment(
            file_origin=os.path.basename(self.filename),
            ced_version=str(header.system_id),
        )

        if not cascade:
            fid.close()
            return seg

        def addannotations(ob, channelHeader):
            ob.annotate(title=channelHeader.title)
            ob.annotate(physical_channel_index=channelHeader.phy_chan)
            ob.annotate(comment=channelHeader.comment)
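
        # Channel 'kind' determines the neo object built below (see dict_kind
        # at the end of this file):
        #   1 (Adc) and 9 (RealWave)                -> AnalogSignal
        #   2-5 (events, markers) and 8 (TextMark)  -> Event
        #   6 (AdcMark) and 7 (RealMark)            -> SpikeTrain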
        for i in range(header.channels):
            channelHeader = header.channelHeaders[i]

            #~ print 'channel' , i , 'kind' , channelHeader.kind

            if channelHeader.kind != 0:
                #~ print '####'
                #~ print 'channel' , i, 'kind' , channelHeader.kind , \
                #~     channelHeader.type , channelHeader.phy_chan
                #~ print channelHeader
                pass

            if channelHeader.kind in [1, 9]:
                #~ print 'analogChanel'
                ana_sigs = self.read_one_channel_continuous(
                    fid, i, header, take_ideal_sampling_rate, lazy=lazy)
                #~ print 'nb sigs', len(anaSigs) , ' sizes : ',
                for anaSig in ana_sigs:
                    addannotations(anaSig, channelHeader)
                    anaSig.name = str(anaSig.annotations['title'])
                    seg.analogsignals.append(anaSig)
                    #~ print sig.signal.size,
                #~ print ''
            elif channelHeader.kind in [2, 3, 4, 5, 8]:
                ea = self.read_one_channel_event_or_spike(
                    fid, i, header, lazy=lazy)
                if ea is not None:
                    addannotations(ea, channelHeader)
                    seg.events.append(ea)
            elif channelHeader.kind in [6, 7]:
                sptrs = self.read_one_channel_event_or_spike(
                    fid, i, header, lazy=lazy)
                if sptrs is not None:
                    for sptr in sptrs:
                        addannotations(sptr, channelHeader)
                        seg.spiketrains.append(sptr)

        fid.close()

        seg.create_many_to_one_relationship()
        return seg
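
    # Example (sketch, placeholder filename): lazy reading keeps data on
    # disk; the returned containers carry a 'lazy_shape' attribute with the
    # number of items a full read would load.
    #   seg = Spike2IO(filename='my_file.smr').read_segment(lazy=True)
    #   print(seg.analogsignals[0].lazy_shape)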

    def read_header(self, filename=''):
        fid = open(filename, 'rb')
        header = HeaderReader(fid, np.dtype(headerDescription))
        #~ print 'chan_size' , header.chan_size

        if header.system_id < 6:
            header.dtime_base = 1e-6
            header.datetime_detail = 0
            header.datetime_year = 0

        channelHeaders = []
        for i in range(header.channels):
            # read global channel header
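            # The 512-byte file header is followed by one 140-byte header per
            # channel, hence the offset 512 + 140 * i.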
            fid.seek(512 + 140 * i)  # TODO: check whether the index should be i or i - 1
            channelHeader = HeaderReader(fid,
                                         np.dtype(channelHeaderDesciption1))
            if channelHeader.kind in [1, 6]:
                dt = [('scale', 'f4'),
                      ('offset', 'f4'),
                      ('unit', 'S6'), ]
                channelHeader += HeaderReader(fid, np.dtype(dt))
                if header.system_id < 6:
                    channelHeader += HeaderReader(fid,
                                                  np.dtype([('divide', 'i2')]))
                else:
                    channelHeader += HeaderReader(fid,
                                                  np.dtype([('interleave', 'i2')]))
            if channelHeader.kind in [7, 9]:
                dt = [('min', 'f4'),
                      ('max', 'f4'),
                      ('unit', 'S6'), ]
                channelHeader += HeaderReader(fid, np.dtype(dt))
                if header.system_id < 6:
                    channelHeader += HeaderReader(fid,
                                                  np.dtype([('divide', 'i2')]))
                else:
                    channelHeader += HeaderReader(fid,
                                                  np.dtype([('interleave', 'i2')]))
            if channelHeader.kind in [4]:
                dt = [('init_low', 'u1'),
                      ('next_low', 'u1'), ]
                channelHeader += HeaderReader(fid, np.dtype(dt))

            channelHeader.type = dict_kind[channelHeader.kind]
            #~ print i, channelHeader
            channelHeaders.append(channelHeader)

        header.channelHeaders = channelHeaders

        fid.close()
        return header

    def read_one_channel_continuous(self, fid, channel_num, header,
                                    take_ideal_sampling_rate, lazy=True):
        # read AnalogSignal
        channelHeader = header.channelHeaders[channel_num]

        # data type
        if channelHeader.kind == 1:
            dt = np.dtype('i2')
        elif channelHeader.kind == 9:
            dt = np.dtype('f4')

        # sample rate
        if take_ideal_sampling_rate:
            sampling_rate = channelHeader.ideal_rate * pq.Hz
        else:
            if header.system_id in [1, 2, 3, 4, 5]:  # Before version 5
                #~ print channel_num, channelHeader.divide, \
                #~     header.us_per_time, header.time_per_adc
                sample_interval = (channelHeader.divide * header.us_per_time *
                                   header.time_per_adc) * 1e-6
            else:
                sample_interval = (channelHeader.l_chan_dvd *
                                   header.us_per_time * header.dtime_base)
            sampling_rate = (1. / sample_interval) * pq.Hz

        # read blocks header to preallocate memory by jumping block to block
        if channelHeader.blocks == 0:
            return []
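
        # Data blocks form a linked list via succ_block. If the gap between
        # two consecutive blocks exceeds one sample interval, a new
        # AnalogSignal is started, so a single channel may yield several
        # signals.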
        fid.seek(channelHeader.firstblock)
        blocksize = [0]
        starttimes = []
        for b in range(channelHeader.blocks):
            blockHeader = HeaderReader(fid, np.dtype(blockHeaderDesciption))
            if len(blocksize) > len(starttimes):
                starttimes.append(blockHeader.start_time)
            blocksize[-1] += blockHeader.items

            if blockHeader.succ_block > 0:
                # ugly but CED does not guarantee continuity in AnalogSignal
                fid.seek(blockHeader.succ_block)
                nextBlockHeader = HeaderReader(fid,
                                               np.dtype(blockHeaderDesciption))
                sample_interval = (blockHeader.end_time -
                                   blockHeader.start_time) / \
                                  (blockHeader.items - 1)
                interval_with_next = nextBlockHeader.start_time - \
                    blockHeader.end_time
                if interval_with_next > sample_interval:
                    blocksize.append(0)
                fid.seek(blockHeader.succ_block)

        ana_sigs = []
        if channelHeader.unit in unit_convert:
            unit = pq.Quantity(1, unit_convert[channelHeader.unit])
        else:
            # print channelHeader.unit
            try:
                unit = pq.Quantity(1, channelHeader.unit)
            except:
                unit = pq.Quantity(1, '')

        for b, bs in enumerate(blocksize):
            if lazy:
                signal = [] * unit
            else:
                signal = pq.Quantity(np.empty(bs, dtype='f4'), units=unit)
            ana_sig = AnalogSignal(
                signal, sampling_rate=sampling_rate,
                t_start=(starttimes[b] * header.us_per_time *
                         header.dtime_base * pq.s),
                channel_index=channel_num)
            ana_sigs.append(ana_sig)

        if lazy:
            for s, ana_sig in enumerate(ana_sigs):
                ana_sig.lazy_shape = blocksize[s]
        else:
            # read data by jumping block to block
            fid.seek(channelHeader.firstblock)
            pos = 0
            numblock = 0
            for b in range(channelHeader.blocks):
                blockHeader = HeaderReader(
                    fid, np.dtype(blockHeaderDesciption))
                # read data
                sig = np.fromstring(fid.read(blockHeader.items * dt.itemsize),
                                    dtype=dt)
                ana_sigs[numblock][pos:pos + sig.size] = \
                    sig.reshape(-1, 1).astype('f4') * unit
                pos += sig.size
                if pos >= blocksize[numblock]:
                    numblock += 1
                    pos = 0
                # jump to next block
                if blockHeader.succ_block > 0:
                    fid.seek(blockHeader.succ_block)
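
        # SON stores Adc data as raw int16 counts; the physical value is
        # counts * scale / 6553.6 + offset. The same scale/offset convention
        # is applied to AdcMark waveforms further below.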
        # convert for int16
        if dt.kind == 'i':
            for ana_sig in ana_sigs:
                ana_sig *= channelHeader.scale / 6553.6
                ana_sig += channelHeader.offset * unit

        return ana_sigs

    def read_one_channel_event_or_spike(self, fid, channel_num, header,
                                        lazy=True):
        # return a SpikeTrain or an Event
        channelHeader = header.channelHeaders[channel_num]
        if channelHeader.firstblock < 0:
            return
        if channelHeader.kind not in [2, 3, 4, 5, 6, 7, 8]:
            return
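
        # Every item starts with a 'tick' timestamp in file clock ticks;
        # ticks are converted to seconds below using
        # us_per_time * dtime_base.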
        # Step 1 : type of blocks
        if channelHeader.kind in [2, 3, 4]:
            # Event data
            fmt = [('tick', 'i4')]
        elif channelHeader.kind in [5]:
            # Marker data
            fmt = [('tick', 'i4'), ('marker', 'i4')]
        elif channelHeader.kind in [6]:
            # AdcMark data
            fmt = [('tick', 'i4'), ('marker', 'i4'),
                   ('adc', 'S%d' % channelHeader.n_extra)]
        elif channelHeader.kind in [7]:
            # RealMark data
            fmt = [('tick', 'i4'), ('marker', 'i4'),
                   ('real', 'S%d' % channelHeader.n_extra)]
        elif channelHeader.kind in [8]:
            # TextMark data
            fmt = [('tick', 'i4'), ('marker', 'i4'),
                   ('label', 'S%d' % channelHeader.n_extra)]
        dt = np.dtype(fmt)

        # Step 2 : first read for allocating mem
        fid.seek(channelHeader.firstblock)
        totalitems = 0
        for _ in range(channelHeader.blocks):
            blockHeader = HeaderReader(fid, np.dtype(blockHeaderDesciption))
            totalitems += blockHeader.items
            if blockHeader.succ_block > 0:
                fid.seek(blockHeader.succ_block)
        #~ print 'totalitems' , totalitems

        if lazy:
            if channelHeader.kind in [2, 3, 4, 5, 8]:
                ea = Event()
                ea.annotate(channel_index=channel_num)
                ea.lazy_shape = totalitems
                return ea
            elif channelHeader.kind in [6, 7]:
                # correct value for t_stop to be put in later
                sptr = SpikeTrain([] * pq.s, t_stop=1e99)
                sptr.annotate(channel_index=channel_num, ced_unit=0)
                sptr.lazy_shape = totalitems
                return sptr
        else:
            alltrigs = np.zeros(totalitems, dtype=dt)
            # Step 3 : read
            fid.seek(channelHeader.firstblock)
            pos = 0
            for _ in range(channelHeader.blocks):
                blockHeader = HeaderReader(
                    fid, np.dtype(blockHeaderDesciption))
                # read all events in block
                trigs = np.fromstring(
                    fid.read(blockHeader.items * dt.itemsize), dtype=dt)
                alltrigs[pos:pos + trigs.size] = trigs
                pos += trigs.size
                if blockHeader.succ_block > 0:
                    fid.seek(blockHeader.succ_block)

            # Step 4 : convert to neo standard classes: Event or SpikeTrain
            alltimes = alltrigs['tick'].astype(
                'f') * header.us_per_time * header.dtime_base * pq.s

            if channelHeader.kind in [2, 3, 4, 5, 8]:
                # events
                ea = Event(alltimes)
                ea.annotate(channel_index=channel_num)
                if channelHeader.kind >= 5:
                    # the Spike2 marker is closest to neo's notion of a label
                    ea.labels = alltrigs['marker'].astype('S32')
                if channelHeader.kind == 8:
                    ea.annotate(extra_labels=alltrigs['label'])
                return ea
            elif channelHeader.kind in [6, 7]:
                # spiketrains

                # waveforms
                if channelHeader.kind == 6:
                    waveforms = np.fromstring(alltrigs['adc'].tostring(),
                                              dtype='i2')
                    waveforms = waveforms.astype(
                        'f4') * channelHeader.scale / 6553.6 + \
                        channelHeader.offset
                elif channelHeader.kind == 7:
                    waveforms = np.fromstring(alltrigs['real'].tostring(),
                                              dtype='f4')

                if header.system_id >= 6 and channelHeader.interleave > 1:
                    waveforms = waveforms.reshape(
                        (alltimes.size, -1, channelHeader.interleave))
                    waveforms = waveforms.swapaxes(1, 2)
                else:
                    waveforms = waveforms.reshape((alltimes.size, 1, -1))

                if header.system_id in [1, 2, 3, 4, 5]:
                    sample_interval = (channelHeader.divide *
                                       header.us_per_time *
                                       header.time_per_adc) * 1e-6
                else:
                    sample_interval = (channelHeader.l_chan_dvd *
                                       header.us_per_time *
                                       header.dtime_base)

                if channelHeader.unit in unit_convert:
                    unit = pq.Quantity(1, unit_convert[channelHeader.unit])
                else:
                    # print channelHeader.unit
                    try:
                        unit = pq.Quantity(1, channelHeader.unit)
                    except:
                        unit = pq.Quantity(1, '')

                if len(alltimes) > 0:
                    # can get better value from associated AnalogSignal(s) ?
                    t_stop = alltimes.max()
                else:
                    t_stop = 0.0

                if not self.ced_units:
                    sptr = SpikeTrain(alltimes,
                                      waveforms=waveforms * unit,
                                      sampling_rate=(1. / sample_interval) * pq.Hz,
                                      t_stop=t_stop)
                    sptr.annotate(channel_index=channel_num, ced_unit=0)
                    return [sptr]
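
                # With ced_units=True, spikes are split by the low byte of
                # the marker, which holds the unit code assigned by Spike2's
                # spike sorting.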
                sptrs = []
                for i in set(alltrigs['marker'] & 255):
                    sptr = SpikeTrain(alltimes[alltrigs['marker'] == i],
                                      waveforms=waveforms[alltrigs['marker'] == i] * unit,
                                      sampling_rate=(1. / sample_interval) * pq.Hz,
                                      t_stop=t_stop)
                    sptr.annotate(channel_index=channel_num, ced_unit=i)
                    sptrs.append(sptr)

                return sptrs


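# HeaderReader maps a numpy structured dtype onto bytes read from the file
# and exposes the fields as attributes. String fields are stored
# Pascal-style (first byte is the length), which __getattr__ strips off.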
class HeaderReader(object):
    def __init__(self, fid, dtype):
        if fid is not None:
            array = np.fromstring(fid.read(dtype.itemsize), dtype)[0]
        else:
            array = np.zeros(1, dtype=dtype)[0]
        super(HeaderReader, self).__setattr__('dtype', dtype)
        super(HeaderReader, self).__setattr__('array', array)

    def __setattr__(self, name, val):
        if name in self.dtype.names:
            self.array[name] = val
        else:
            super(HeaderReader, self).__setattr__(name, val)

    def __getattr__(self, name):
        #~ print name
        if name in self.dtype.names:
            if self.dtype[name].kind == 'S':
                if PY3K:
                    l = np.fromstring(
                        self.array[name].decode('iso-8859-1')[0], 'u1')
                else:
                    l = np.fromstring(self.array[name][0], 'u1')
                l = l[0]
                return self.array[name][1:l + 1]
            else:
                return self.array[name]

    def names(self):
        return self.array.dtype.names

    def __repr__(self):
        s = 'HEADER'
        for name in self.dtype.names:
            #~ if self.dtype[name].kind != 'S':
            #~     s += name + self.__getattr__(name)
            s += '{}: {}\n'.format(name, getattr(self, name))
        return s

    def __add__(self, header2):
        #~ print 'add' , self.dtype, header2.dtype
        newdtype = []
        for name in self.dtype.names:
            newdtype.append((name, self.dtype[name].str))
        for name in header2.dtype.names:
            newdtype.append((name, header2.dtype[name].str))
        newdtype = np.dtype(newdtype)
        newHeader = HeaderReader(None, newdtype)
        newHeader.array = np.fromstring(
            self.array.tostring() + header2.array.tostring(), newdtype)[0]
        return newHeader


# header structures:
headerDescription = [
    ('system_id', 'i2'),
    ('copyright', 'S10'),
    ('creator', 'S8'),
    ('us_per_time', 'i2'),
    ('time_per_adc', 'i2'),
    ('filestate', 'i2'),
    ('first_data', 'i4'),  # i8
    ('channels', 'i2'),
    ('chan_size', 'i2'),
    ('extra_data', 'i2'),
    ('buffersize', 'i2'),
    ('os_format', 'i2'),
    ('max_ftime', 'i4'),  # i8
    ('dtime_base', 'f8'),
    ('datetime_detail', 'u1'),
    ('datetime_year', 'i2'),
    ('pad', 'S52'),
    ('comment1', 'S80'),
    ('comment2', 'S80'),
    ('comment3', 'S80'),
    ('comment4', 'S80'),
    ('comment5', 'S80'),
]
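
# Example (sketch, placeholder filename): the raw file header can be
# inspected directly with HeaderReader and the dtype above.
#   with open('my_file.smr', 'rb') as fid:
#       h = HeaderReader(fid, np.dtype(headerDescription))
#       print(h.system_id, h.channels)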

channelHeaderDesciption1 = [
    ('del_size', 'i2'),
    ('next_del_block', 'i4'),  # i8
    ('firstblock', 'i4'),  # i8
    ('lastblock', 'i4'),  # i8
    ('blocks', 'i2'),
    ('n_extra', 'i2'),
    ('pre_trig', 'i2'),
    ('free0', 'i2'),
    ('py_sz', 'i2'),
    ('max_data', 'i2'),
    ('comment', 'S72'),
    ('max_chan_time', 'i4'),  # i8
    ('l_chan_dvd', 'i4'),  # i8
    ('phy_chan', 'i2'),
    ('title', 'S10'),
    ('ideal_rate', 'f4'),
    ('kind', 'u1'),
    ('unused1', 'i1'),
]

dict_kind = {
    0: 'empty',
    1: 'Adc',
    2: 'EventFall',
    3: 'EventRise',
    4: 'EventBoth',
    5: 'Marker',
    6: 'AdcMark',
    7: 'RealMark',
    8: 'TextMark',
    9: 'RealWave',
}

blockHeaderDesciption = [
    ('pred_block', 'i4'),  # i8
    ('succ_block', 'i4'),  # i8
    ('start_time', 'i4'),  # i8
    ('end_time', 'i4'),  # i8
    ('channel_num', 'i2'),
    ('items', 'i2'),
]

unit_convert = {
    'Volts': 'V',
}