
  1. """
  2. ExampleRawIO is a class of a fake example.
  3. This is to be used when coding a new RawIO.
  4. Rules for creating a new class:
  5. 1. Step 1: Create the main class
  6. * Create a file in **neo/rawio/** that endith with "rawio.py"
  7. * Create the class that inherits BaseRawIO
  8. * copy/paste all methods that need to be implemented.
  9. See the end a neo.rawio.baserawio.BaseRawIO
  10. * code hard! The main difficulty **is _parse_header()**.
  11. In short you have a create a mandatory dict than
  12. contains channel informations::
  13. self.header = {}
  14. self.header['nb_block'] = 2
  15. self.header['nb_segment'] = [2, 3]
  16. self.header['signal_channels'] = sig_channels
  17. self.header['unit_channels'] = unit_channels
  18. self.header['event_channels'] = event_channels
  19. 2. Step 2: RawIO test:
  20. * create a file in neo/rawio/tests with the same name with "test_" prefix
  21. * copy paste neo/rawio/tests/test_examplerawio.py and do the same
  22. 3. Step 3 : Create the neo.io class with the wrapper
  23. * Create a file in neo/io/ that endith with "io.py"
  24. * Create a that inherits both your RawIO class and BaseFromRaw class
  25. * copy/paste from neo/io/exampleio.py
  26. 4.Step 4 : IO test
  27. * create a file in neo/test/iotest with the same previous name with "test_" prefix
  28. * copy/paste from neo/test/iotest/test_exampleio.py
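
  A minimal sketch of the Step 3 wrapper class, assuming the import paths
  shown below (the real reference is neo/io/exampleio.py)::

      from neo.io.basefromraw import BaseFromRaw
      from neo.rawio.examplerawio import ExampleRawIO

      class ExampleIO(ExampleRawIO, BaseFromRaw):
          name = 'example IO'
          description = "Fake IO for demonstration"

          def __init__(self, filename=''):
              ExampleRawIO.__init__(self, filename=filename)
              BaseFromRaw.__init__(self, filename)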
  29. """
from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype,
                        _event_channel_dtype)

import numpy as np


class ExampleRawIO(BaseRawIO):
  34. """
  35. Class for "reading" fake data from an imaginary file.
  36. For the user, it give acces to raw data (signals, event, spikes) as they
  37. are in the (fake) file int16 and int64.
  38. For a developer, it is just an example showing guidelines for someone who wants
  39. to develop a new IO module.
  40. Two rules for developers:
  41. * Respect the :ref:`neo_rawio_API`
  42. * Follow the :ref:`io_guiline`
  43. This fake IO:
  44. * have 2 blocks
  45. * blocks have 2 and 3 segments
  46. * have 16 signal_channel sample_rate = 10000
  47. * have 3 unit_channel
  48. * have 2 event channel: one have *type=event*, the other have
  49. *type=epoch*
  50. Usage:
  51. >>> import neo.rawio
  52. >>> r = neo.rawio.ExampleRawIO(filename='itisafake.nof')
  53. >>> r.parse_header()
  54. >>> print(r)
  55. >>> raw_chunk = r.get_analogsignal_chunk(block_index=0, seg_index=0,
  56. i_start=0, i_stop=1024, channel_names=channel_names)
  57. >>> float_chunk = reader.rescale_signal_raw_to_float(raw_chunk, dtype='float64',
  58. channel_indexes=[0, 3, 6])
  59. >>> spike_timestamp = reader.spike_timestamps(unit_index=0, t_start=None, t_stop=None)
  60. >>> spike_times = reader.rescale_spike_timestamp(spike_timestamp, 'float64')
  61. >>> ev_timestamps, _, ev_labels = reader.event_timestamps(event_channel_index=0)
  62. """
    extensions = ['fake']
    rawmode = 'one-file'

    def __init__(self, filename=''):
        BaseRawIO.__init__(self)
        # note that this filename is used in self._source_name
        self.filename = filename

    def _source_name(self):
        # this function is used by __repr__
        # for general cases self.filename is a good choice,
        # but for URLs you could mask some part of the URL and keep
        # only the main part.
        return self.filename
    def _parse_header(self):
        # This is the central part of a RawIO:
        # we need to collect from the original format all
        # the information needed for fast access later on,
        # at any place in the file.
        # In short, _parse_header() can be slow but
        # _get_analogsignal_chunk() needs to be as fast as possible

        # create signal channel information
        # This is mandatory!!!!
        # gain/offset/units are really important because
        # the scaling to real values will be done with them,
        # at the end: real_signal = (raw_signal * gain + offset) * pq.Quantity(units)
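        # Worked example with the numbers used below: with gain = 1000. / 2 ** 16
        # and offset = 0., a raw int16 value of 16384 rescales to
        # 16384 * (1000. / 65536) + 0. = 250.0 uV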
        sig_channels = []
        for c in range(16):
            ch_name = 'ch{}'.format(c)
            # our channel id is c+1 just for fun
            # Note that chan_id should be related to
            # the original channel id in the file format
            # so that the end user is not lost when reading datasets
            chan_id = c + 1
            sr = 10000.  # Hz
            dtype = 'int16'
            units = 'uV'
            gain = 1000. / 2 ** 16
            offset = 0.
            # group_id is only for special cases, e.g. when channels have
            # different sampling rates. See TdtIO for that.
            # Here this is the general case: all channels have the same characteristics
            group_id = 0
            sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, group_id))
        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)

        # creating unit channels
        # This is mandatory!!!!
        # Note that if there is no waveform at all in the file
        # then wf_units/wf_gain/wf_offset/wf_left_sweep/wf_sampling_rate
        # can be set to any value because _spike_raw_waveforms
        # will return None
        unit_channels = []
        for c in range(3):
            unit_name = 'unit{}'.format(c)
            unit_id = '#{}'.format(c)
            wf_units = 'uV'
            wf_gain = 1000. / 2 ** 16
            wf_offset = 0.
            wf_left_sweep = 20
            wf_sampling_rate = 10000.
            unit_channels.append((unit_name, unit_id, wf_units, wf_gain,
                                  wf_offset, wf_left_sweep, wf_sampling_rate))
        unit_channels = np.array(unit_channels, dtype=_unit_channel_dtype)

        # creating event/epoch channels
        # This is mandatory!!!!
        # In RawIO, epochs and events are dealt with the same way.
        event_channels = []
        event_channels.append(('Some events', 'ev_0', 'event'))
        event_channels.append(('Some epochs', 'ep_1', 'epoch'))
        event_channels = np.array(event_channels, dtype=_event_channel_dtype)
        # fill into the header dict
        # This is mandatory!!!!!
        self.header = {}
        self.header['nb_block'] = 2
        self.header['nb_segment'] = [2, 3]
        self.header['signal_channels'] = sig_channels
        self.header['unit_channels'] = unit_channels
        self.header['event_channels'] = event_channels

        # insert some annotations at some places
        # at the neo.io level, IOs are free to add annotations
        # to any object. To keep this functionality with the wrapper
        # BaseFromRaw you can add annotations in a nested dict.
        self._generate_minimal_annotations()
        # If you are a lazy dev you can stop here.
        for block_index in range(2):
            bl_ann = self.raw_annotations['blocks'][block_index]
            bl_ann['name'] = 'Block #{}'.format(block_index)
            bl_ann['block_extra_info'] = 'This is the block {}'.format(block_index)
            for seg_index in range([2, 3][block_index]):
                seg_ann = bl_ann['segments'][seg_index]
                seg_ann['name'] = 'Seg #{} Block #{}'.format(
                    seg_index, block_index)
                seg_ann['seg_extra_info'] = 'This is the seg {} of block {}'.format(
                    seg_index, block_index)
                for c in range(16):
                    anasig_an = seg_ann['signals'][c]
                    anasig_an['info'] = 'This is a good signal'
                for c in range(3):
                    spiketrain_an = seg_ann['units'][c]
                    spiketrain_an['quality'] = 'Good!!'
                for c in range(2):
                    event_an = seg_ann['events'][c]
                    if c == 0:
                        event_an['nickname'] = 'Miss Event 0'
                    elif c == 1:
                        event_an['nickname'] = 'MrEpoch 1'
    def _segment_t_start(self, block_index, seg_index):
        # this must return a float scaled in seconds
        # this t_start will be shared by all objects in the segment
        # except AnalogSignal
        all_starts = [[0., 15.], [0., 20., 60.]]
        return all_starts[block_index][seg_index]

    def _segment_t_stop(self, block_index, seg_index):
        # this must return a float scaled in seconds
        all_stops = [[10., 25.], [10., 30., 70.]]
        return all_stops[block_index][seg_index]
    def _get_signal_size(self, block_index, seg_index, channel_indexes=None):
        # we are lucky: signals in all segments have the same shape!! (10.0 seconds)
        # it is not always the case
        # this must return an int = the number of samples
        # Note that channel_indexes can be ignored in most cases,
        # except when there are several sampling rates.
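        # Here: 10.0 s of signal at 10000 Hz = 100000 samples per segment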
        return 100000
    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
        # This gives the t_start of the signals.
        # Very often this is equal to _segment_t_start, but not
        # always.
        # this must return a float scaled in seconds
        # Note that channel_indexes can be ignored in most cases,
        # except when there are several sampling rates.
        # Here it is the same as the segment t_start,
        # but this is not always the case
        return self._segment_t_start(block_index, seg_index)
    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes):
        # this must return a signal chunk limited by
        # i_start/i_stop (which can be None)
        # channel_indexes can be None (=all channels), a list or a numpy.array
        # This must return a 2D numpy array (even with one channel).
        # This must return the original dtype. No conversion here.
        # This must be as fast as possible.
        # Everything that can be done in _parse_header() must not be here.

        # Here we are lucky: our signals are always zeros!!
        # it is not always the case
        # internally signals are int16
        # conversion to real units is done with self.header['signal_channels']

        if i_start is None:
            i_start = 0
        if i_stop is None:
            i_stop = 100000

        assert i_start >= 0, "I don't like your jokes"
        assert i_stop <= 100000, "I don't like your jokes"

        if channel_indexes is None:
            nb_chan = 16
        else:
            nb_chan = len(channel_indexes)
        raw_signals = np.zeros((i_stop - i_start, nb_chan), dtype='int16')
        return raw_signals
    def _spike_count(self, block_index, seg_index, unit_index):
        # Must return the nb of spikes for a given (block_index, seg_index, unit_index)
        # we are lucky: our units all have the same nb of spikes!!
        # it is not always the case
        nb_spikes = 20
        return nb_spikes
    def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
        # In our IO, timestamps are internally coded as 'int64' and they
        # represent the sample index of the signals at 10kHz
        # we are lucky: spikes have the same discharge in all segments!!
        # incredible neuron!! This is not always the case

        # the same t_start/t_stop clipping must be used in _get_spike_raw_waveforms()
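        # e.g. with this 10 kHz clock, a timestamp of 5000 corresponds to
        # 5000 / 10000. = 0.5 s (see _rescale_spike_timestamp below)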
        ts_start = (self._segment_t_start(block_index, seg_index) * 10000)

        spike_timestamps = np.arange(0, 10000, 500) + ts_start

        if t_start is not None or t_stop is not None:
            # restrict spikes to the given limits (in seconds)
            lim0 = int(t_start * 10000)
            lim1 = int(t_stop * 10000)
            mask = (spike_timestamps >= lim0) & (spike_timestamps <= lim1)
            spike_timestamps = spike_timestamps[mask]

        return spike_timestamps
    def _rescale_spike_timestamp(self, spike_timestamps, dtype):
        # must rescale spike_timestamps to seconds
        # with a fixed dtype so the user can choose the precision they want.
        spike_times = spike_timestamps.astype(dtype)
        spike_times /= 10000.  # because 10kHz
        return spike_times
    def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
        # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
        # in the original dtype
        # this must be as fast as possible.
        # the same t_start/t_stop clipping must be used as in _get_spike_timestamps()

        # If there is no waveform supported by the
        # IO then _spike_raw_waveforms must return None

        # In our IO waveforms come from all channels
        # they are int16
        # conversion to real units is done with self.header['unit_channels']
        # Here we have a realistic case: all waveforms are only noise.
        # it is not always the case
        # we have 20 spikes with a sweep of 50 samples (5 ms at 10 kHz)

        # trick to get how many spikes are in the slice
        ts = self._get_spike_timestamps(block_index, seg_index, unit_index, t_start, t_stop)
        nb_spike = ts.size

        np.random.seed(2205)  # a magic number (my birthday)
        waveforms = np.random.randint(low=-2**4, high=2**4, size=nb_spike * 50, dtype='int16')
        waveforms = waveforms.reshape(nb_spike, 1, 50)
        return waveforms
    def _event_count(self, block_index, seg_index, event_channel_index):
        # event and spike are very similar
        # we have 2 event channels
        if event_channel_index == 0:
            # event channel
            return 6
        elif event_channel_index == 1:
            # epoch channel
            return 10
    def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
        # the main difference between spike channels and event channels
        # is that here we have 3 numpy arrays: timestamps, durations, labels
        # durations must be None for 'event'
        # labels must have dtype='U'
        # in our IO, events are directly coded in seconds
        seg_t_start = self._segment_t_start(block_index, seg_index)
        if event_channel_index == 0:
            timestamp = np.arange(0, 6, dtype='float64') + seg_t_start
            durations = None
            labels = np.array(['trigger_a', 'trigger_b'] * 3, dtype='U12')
        elif event_channel_index == 1:
            timestamp = np.arange(0, 10, dtype='float64') + .5 + seg_t_start
            durations = np.ones((10), dtype='float64') * .25
            labels = np.array(['zoneX'] * 5 + ['zoneZ'] * 5, dtype='U12')

        if t_start is not None:
            keep = timestamp >= t_start
            timestamp, labels = timestamp[keep], labels[keep]
            if durations is not None:
                durations = durations[keep]

        if t_stop is not None:
            keep = timestamp <= t_stop
            timestamp, labels = timestamp[keep], labels[keep]
            if durations is not None:
                durations = durations[keep]

        return timestamp, durations, labels
    def _rescale_event_timestamp(self, event_timestamps, dtype):
        # must rescale event_timestamps to seconds
        # with a fixed dtype so the user can choose the precision they want.
        # really easy here because in our case they are already in seconds
        event_times = event_timestamps.astype(dtype)
        return event_times

    def _rescale_epoch_duration(self, raw_duration, dtype):
        # really easy here because in our case they are already in seconds
        durations = raw_duration.astype(dtype)
        return durations