examplerawio.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. # -*- coding: utf-8 -*-
"""
ExampleRawIO is a fake example class.
It is meant to be used as a template when coding a new RawIO.

Rules for creating a new class:

  1. Step 1: Create the main class
    * Create a file in **neo/rawio/** whose name ends with "rawio.py"
    * Create the class that inherits from BaseRawIO
    * Copy/paste all the methods that need to be implemented.
      See the end of neo.rawio.baserawio.BaseRawIO
    * Code hard! The main difficulty **is _parse_header()**.
      In short, you have to create a mandatory dict that
      contains the channel information::

            self.header = {}
            self.header['nb_block'] = 2
            self.header['nb_segment'] = [2, 3]
            self.header['signal_channels'] = sig_channels
            self.header['unit_channels'] = unit_channels
            self.header['event_channels'] = event_channels

  2. Step 2: RawIO test
    * Create a file in neo/rawio/tests with the same name plus a "test_" prefix
    * Copy/paste neo/rawio/tests/test_examplerawio.py and adapt it

  3. Step 3: Create the neo.io class with the wrapper
    * Create a file in neo/io/ whose name ends with "io.py"
    * Create a class that inherits from both your RawIO class and the BaseFromRaw class
    * Copy/paste from neo/io/exampleio.py

  4. Step 4: IO test
    * Create a file in neo/test/iotest with the same previous name plus a "test_" prefix
    * Copy/paste from neo/test/iotest/test_exampleio.py
"""
  31. from __future__ import unicode_literals, print_function, division, absolute_import
  32. from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
  33. _event_channel_dtype)
  34. import numpy as np
  35. class ExampleRawIO(BaseRawIO):
  36. """
  37. Class for "reading" fake data from an imaginary file.
  38. For the user, it give acces to raw data (signals, event, spikes) as they
  39. are in the (fake) file int16 and int64.
  40. For a developer, it is just an example showing guidelines for someone who wants
  41. to develop a new IO module.
  42. Two rules for developers:
  43. * Respect the Neo RawIO API (:ref:`_neo_rawio_API`)
  44. * Follow :ref:`_io_guiline`
  45. This fake IO:
  46. * have 2 blocks
  47. * blocks have 2 and 3 segments
  48. * have 16 signal_channel sample_rate = 10000
  49. * have 3 unit_channel
  50. * have 2 event channel: one have *type=event*, the other have
  51. *type=epoch*
  52. Usage:
  53. >>> import neo.rawio
  54. >>> r = neo.rawio.ExampleRawIO(filename='itisafake.nof')
  55. >>> r.parse_header()
  56. >>> print(r)
  57. >>> raw_chunk = r.get_analogsignal_chunk(block_index=0, seg_index=0,
  58. i_start=0, i_stop=1024, channel_names=channel_names)
  59. >>> float_chunk = reader.rescale_signal_raw_to_float(raw_chunk, dtype='float64',
  60. channel_indexes=[0, 3, 6])
  61. >>> spike_timestamp = reader.spike_timestamps(unit_index=0, t_start=None, t_stop=None)
  62. >>> spike_times = reader.rescale_spike_timestamp(spike_timestamp, 'float64')
  63. >>> ev_timestamps, _, ev_labels = reader.event_timestamps(event_channel_index=0)
  64. """
  65. extensions = ['fake']
  66. rawmode = 'one-file'
  67. def __init__(self, filename=''):
  68. BaseRawIO.__init__(self)
  69. # note that this filename is ued in self._source_name
  70. self.filename = filename
  71. def _source_name(self):
  72. # this function is used by __repr__
  73. # for general cases self.filename is good
  74. # But for URL you could mask some part of the URL to keep
  75. # the main part.
  76. return self.filename
  77. def _parse_header(self):
  78. # This is the central of a RawIO
  79. # we need to collect in the original format all
  80. # informations needed for further fast acces
  81. # at any place in the file
  82. # In short _parse_header can be slow but
  83. # _get_analogsignal_chunk need to be as fast as possible
  84. # create signals channels information
  85. # This is mandatory!!!!
  86. # gain/offset/units are really important because
  87. # the scaling to real value will be done with that
  88. # at the end real_signal = (raw_signal* gain + offset) * pq.Quantity(units)
  89. sig_channels = []
  90. for c in range(16):
  91. ch_name = 'ch{}'.format(c)
  92. # our channel id is c+1 just for fun
  93. # Note that chan_id should be realated to
  94. # original channel id in the file format
  95. # so that the end user should not be lost when reading datasets
  96. chan_id = c + 1
  97. sr = 10000. # Hz
  98. dtype = 'int16'
  99. units = 'uV'
  100. gain = 1000. / 2 ** 16
  101. offset = 0.
  102. # group_id isonly for special cases when channel have diferents
  103. # sampling rate for instance. See TdtIO for that.
  104. # Here this is the general case :all channel have the same characteritics
  105. group_id = 0
  106. sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, group_id))
  107. sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
  108. # creating units channels
  109. # This is mandatory!!!!
  110. # Note that if there is no waveform at all in the file
  111. # then wf_units/wf_gain/wf_offset/wf_left_sweep/wf_sampling_rate
  112. # can be set to any value because _spike_raw_waveforms
  113. # will return None
  114. unit_channels = []
  115. for c in range(3):
  116. unit_name = 'unit{}'.format(c)
  117. unit_id = '#{}'.format(c)
  118. wf_units = 'uV'
  119. wf_gain = 1000. / 2 ** 16
  120. wf_offset = 0.
  121. wf_left_sweep = 20
  122. wf_sampling_rate = 10000.
  123. unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
  124. wf_offset, wf_left_sweep, wf_sampling_rate))
  125. unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)
  126. # creating event/epoch channel
  127. # This is mandatory!!!!
  128. # In RawIO epoch and event they are dealt the same way.
  129. event_channels = []
  130. event_channels.append(('Some events', 'ev_0', 'event'))
  131. event_channels.append(('Some epochs', 'ep_1', 'epoch'))
  132. event_channels = np.array(event_channels, dtype=_event_channel_dtype)
  133. # fille into header dict
  134. # This is mandatory!!!!!
  135. self.header = {}
  136. self.header['nb_block'] = 2
  137. self.header['nb_segment'] = [2, 3]
  138. self.header['signal_channels'] = sig_channels
  139. self.header['unit_channels'] = unit_channels
  140. self.header['event_channels'] = event_channels
  141. # insert some annotation at some place
  142. # at neo.io level IO are free to add some annoations
  143. # to any object. To keep this functionality with the wrapper
  144. # BaseFromRaw you can add annoations in a nested dict.
  145. self._generate_minimal_annotations()
  146. # If you are a lazy dev you can stop here.
  147. for block_index in range(2):
  148. bl_ann = self.raw_annotations['blocks'][block_index]
  149. bl_ann['name'] = 'Block #{}'.format(block_index)
  150. bl_ann['block_extra_info'] = 'This is the block {}'.format(block_index)
  151. for seg_index in range([2, 3][block_index]):
  152. seg_ann = bl_ann['segments'][seg_index]
  153. seg_ann['name'] = 'Seg #{} Block #{}'.format(
  154. seg_index, block_index)
  155. seg_ann['seg_extra_info'] = 'This is the seg {} of block {}'.format(
  156. seg_index, block_index)
  157. for c in range(16):
  158. anasig_an = seg_ann['signals'][c]
  159. anasig_an['info'] = 'This is a good signals'
  160. for c in range(3):
  161. spiketrain_an = seg_ann['units'][c]
  162. spiketrain_an['quality'] = 'Good!!'
  163. for c in range(2):
  164. event_an = seg_ann['events'][c]
  165. if c == 0:
  166. event_an['nickname'] = 'Miss Event 0'
  167. elif c == 1:
  168. event_an['nickname'] = 'MrEpoch 1'
  169. def _segment_t_start(self, block_index, seg_index):
  170. # this must return an float scale in second
  171. # this t_start will be shared by all object in the segment
  172. # except AnalogSignal
  173. all_starts = [[0., 15.], [0., 20., 60.]]
  174. return all_starts[block_index][seg_index]
  175. def _segment_t_stop(self, block_index, seg_index):
  176. # this must return an float scale in second
  177. all_stops = [[10., 25.], [10., 30., 70.]]
  178. return all_stops[block_index][seg_index]
  179. def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
  180. # we are lucky: signals in all segment have the same shape!! (10.0 seconds)
  181. # it is not always the case
  182. # this must return an int = the number of sample
  183. # Note that channel_indexes can be ignored for most cases
  184. # except for several sampling rate.
  185. return 100000
  186. def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
  187. # This give the t_start of signals.
  188. # Very often this equal to _segment_t_start but not
  189. # always.
  190. # this must return an float scale in second
  191. # Note that channel_indexes can be ignored for most cases
  192. # except for several sampling rate.
  193. # Here this is the same.
  194. # this is not always the case
  195. return self._segment_t_start(block_index, seg_index)
  196. def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
  197. # this must return a signal chunk limited with
  198. # i_start/i_stop (can be None)
  199. # channel_indexes can be None (=all channel) or a list or numpy.array
  200. # This must return a numpy array 2D (even with one channel).
  201. # This must return the orignal dtype. No conversion here.
  202. # This must as fast as possible.
  203. # Everything that can be done in _parse_header() must not be here.
  204. # Here we are lucky: our signals is always zeros!!
  205. # it is not always the case
  206. # internally signals are int16
  207. # convertion to real units is done with self.header['signal_channels']
  208. if i_start is None:
  209. i_start = 0
  210. if i_stop is None:
  211. i_stop = 100000
  212. assert i_start >= 0, "I don't like your jokes"
  213. assert i_stop <= 100000, "I don't like your jokes"
  214. if channel_indexes is None:
  215. nb_chan = 16
  216. else:
  217. nb_chan = len(channel_indexes)
  218. raw_signals = np.zeros((i_stop - i_start, nb_chan), dtype='int16')
  219. return raw_signals
  220. def _spike_count(self, block_index, seg_index, unit_index):
  221. # Must return the nb of spike for given (block_index, seg_index, unit_index)
  222. # we are lucky: our units have all the same nb of spikes!!
  223. # it is not always the case
  224. nb_spikes = 20
  225. return nb_spikes
  226. def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
  227. # In our IO, timstamp are internally coded 'int64' and they
  228. # represent the index of the signals 10kHz
  229. # we are lucky: spikes have the same discharge in all segments!!
  230. # incredible neuron!! This is not always the case
  231. # the same clip t_start/t_start must be used in _spike_raw_waveforms()
  232. ts_start = (self._segment_t_start(block_index, seg_index) * 10000)
  233. spike_timestamps = np.arange(0, 10000, 500) + ts_start
  234. if t_start is not None or t_stop is not None:
  235. # restricte spikes to given limits (in seconds)
  236. lim0 = int(t_start * 10000)
  237. lim1 = int(t_stop * 10000)
  238. mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
  239. spike_timestamps = spike_timestamps[mask]
  240. return spike_timestamps
  241. def _rescale_spike_timestamp(self, spike_timestamps, dtype):
  242. # must rescale to second a particular spike_timestamps
  243. # with a fixed dtype so the user can choose the precisino he want.
  244. spike_times = spike_timestamps.astype(dtype)
  245. spike_times /= 10000. # because 10kHz
  246. return spike_times
  247. def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
  248. # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
  249. # in the original dtype
  250. # this must be as fast as possible.
  251. # the same clip t_start/t_start must be used in _spike_timestamps()
  252. # If there there is no waveform supported in the
  253. # IO them _spike_raw_waveforms must return None
  254. # In our IO waveforms come from all channels
  255. # they are int16
  256. # convertion to real units is done with self.header['unit_channels']
  257. # Here, we have a realistic case: all waveforms are only noise.
  258. # it is not always the case
  259. # we 20 spikes with a sweep of 50 (5ms)
  260. np.random.seed(2205) # a magic number (my birthday)
  261. waveforms = np.random.randint(low=-2 ** 4, high=2 ** 4, size=20 * 50, dtype='int16')
  262. waveforms = waveforms.reshape(20, 1, 50)
  263. return waveforms
  264. def _event_count(self, block_index, seg_index, event_channel_index):
  265. # event and spike are very similar
  266. # we have 2 event channels
  267. if event_channel_index == 0:
  268. # event channel
  269. return 6
  270. elif event_channel_index == 1:
  271. # epoch channel
  272. return 10
  273. def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
  274. # the main difference between spike channel and event channel
  275. # is that for here we have 3 numpy array timestamp, durations, labels
  276. # durations must be None for 'event'
  277. # label must a dtype ='U'
  278. # in our IO event are directly coded in seconds
  279. seg_t_start = self._segment_t_start(block_index, seg_index)
  280. if event_channel_index == 0:
  281. timestamp = np.arange(0, 6, dtype='float64') + seg_t_start
  282. durations = None
  283. labels = np.array(['trigger_a', 'trigger_b'] * 3, dtype='U12')
  284. elif event_channel_index == 1:
  285. timestamp = np.arange(0, 10, dtype='float64') + .5 + seg_t_start
  286. durations = np.ones((10), dtype='float64') * .25
  287. labels = np.array(['zoneX'] * 5 + ['zoneZ'] * 5, dtype='U12')
  288. if t_start is not None:
  289. keep = timestamp >= t_start
  290. timestamp, labels = timestamp[keep], labels[keep]
  291. if durations is not None:
  292. durations = durations[keep]
  293. if t_stop is not None:
  294. keep = timestamp <= t_stop
  295. timestamp, labels = timestamp[keep], labels[keep]
  296. if durations is not None:
  297. durations = durations[keep]
  298. return timestamp, durations, labels
  299. def _rescale_event_timestamp(self, event_timestamps, dtype):
  300. # must rescale to second a particular event_timestamps
  301. # with a fixed dtype so the user can choose the precisino he want.
  302. # really easy here because in our case it is already seconds
  303. event_times = event_timestamps.astype(dtype)
  304. return event_times
  305. def _rescale_epoch_duration(self, raw_duration, dtype):
  306. # really easy here because in our case it is already seconds
  307. durations = raw_duration.astype(dtype)
  308. return durations