rawio_compliance.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # -*- coding: utf-8 -*-
  2. """
Here is the list of rules for testing neo.rawio API compliance.
  4. This is called automatically by `BaseTestRawIO`
  5. All rules are listed as function so it should be easier to:
  6. * identify the rawio API
  7. * debug
  8. * discuss rules
  9. """
  10. import time
  11. if not hasattr(time, 'perf_counter'):
  12. time.perf_counter = time.time
  13. import logging
  14. import numpy as np
  15. from neo.rawio.baserawio import (_signal_channel_dtype, _unit_channel_dtype,
  16. _event_channel_dtype, _common_sig_characteristics)
  17. def print_class(reader):
  18. return reader.__class__.__name__
  19. def header_is_total(reader):
  20. """
  21. Test if hedaer contains:
  22. * 'signal_channels'
  23. * 'unit_channels'
  24. * 'event_channels'
  25. """
  26. h = reader.header
  27. assert 'signal_channels' in h, 'signal_channels missing in header'
  28. if h['signal_channels'] is not None:
  29. dt = h['signal_channels'].dtype
  30. for k, _ in _signal_channel_dtype:
  31. assert k in dt.fields, '%s not in signal_channels.dtype' % k
  32. assert 'unit_channels' in h, 'unit_channels missing in header'
  33. if h['unit_channels'] is not None:
  34. dt = h['unit_channels'].dtype
  35. for k, _ in _unit_channel_dtype:
  36. assert k in dt.fields, '%s not in unit_channels.dtype' % k
  37. assert 'event_channels' in h, 'event_channels missing in header'
  38. if h['event_channels'] is not None:
  39. dt = h['event_channels'].dtype
  40. for k, _ in _event_channel_dtype:
  41. assert k in dt.fields, '%s not in event_channels.dtype' % k
  42. def count_element(reader):
  43. """
  44. Count block/segment/signals/spike/events
  45. """
  46. nb_sig = reader.signal_channels_count()
  47. nb_unit = reader.unit_channels_count()
  48. nb_event_channel = reader.event_channels_count()
  49. nb_block = reader.block_count()
  50. assert nb_block > 0, '{} have {} block'.format(print_class(reader), nb_block)
  51. for block_index in range(nb_block):
  52. nb_seg = reader.segment_count(block_index)
  53. for seg_index in range(nb_seg):
  54. t_start = reader.segment_t_start(block_index=block_index, seg_index=seg_index)
  55. t_stop = reader.segment_t_stop(block_index=block_index, seg_index=seg_index)
  56. assert t_stop > t_start
  57. if nb_sig > 0:
  58. if reader._several_channel_groups:
  59. channel_indexes_list = reader.get_group_channel_indexes()
  60. for channel_indexes in channel_indexes_list:
  61. sig_size = reader.get_signal_size(block_index, seg_index,
  62. channel_indexes=channel_indexes)
  63. else:
  64. sig_size = reader.get_signal_size(block_index, seg_index,
  65. channel_indexes=None)
  66. for unit_index in range(nb_unit):
  67. nb_spike = reader.spike_count(block_index=block_index, seg_index=seg_index,
  68. unit_index=unit_index)
  69. for event_channel_index in range(nb_event_channel):
  70. nb_event = reader.event_count(block_index=block_index, seg_index=seg_index,
  71. event_channel_index=event_channel_index)
  72. def iter_over_sig_chunks(reader, channel_indexes, chunksize=1024):
  73. if channel_indexes is None:
  74. nb_sig = reader.signal_channels_count()
  75. else:
  76. nb_sig = len(channel_indexes)
  77. if nb_sig == 0:
  78. return
  79. nb_block = reader.block_count()
  80. # read all chunk in RAW data
  81. chunksize = 1024
  82. for block_index in range(nb_block):
  83. nb_seg = reader.segment_count(block_index)
  84. for seg_index in range(nb_seg):
  85. sig_size = reader.get_signal_size(block_index, seg_index, channel_indexes)
  86. nb = sig_size // chunksize + 1
  87. for i in range(nb):
  88. i_start = i * chunksize
  89. i_stop = min((i + 1) * chunksize, sig_size)
  90. raw_chunk = reader.get_analogsignal_chunk(block_index=block_index,
  91. seg_index=seg_index,
  92. i_start=i_start, i_stop=i_stop,
  93. channel_indexes=channel_indexes)
  94. yield raw_chunk
def read_analogsignals(reader):
    """
    Read and convert some signal chunks through every access path.

    Exercises:
    * chunked reading over all blocks/segments/channel groups
    * get_signal_sampling_rate returning a plain Python float
    * equivalence of chunk access by channel indexes, names and ids
      (names/ids paths only when they are unique across channels)
    * rescaling raw chunks to float32 and float64

    AKA _need_chan_index_check: covers the special case where
    signal_channels do not all share the same sampling_rate.
    """
    nb_sig = reader.signal_channels_count()
    if nb_sig == 0:
        return

    # one channel-index list per homogeneous group, or [None] for "all channels"
    if reader._several_channel_groups:
        channel_indexes_list = reader.get_group_channel_indexes()
    else:
        channel_indexes_list = [None]

    # read all chunks for all channels, all blocks, all segments
    for channel_indexes in channel_indexes_list:
        for raw_chunk in iter_over_sig_chunks(reader, channel_indexes, chunksize=1024):
            assert raw_chunk.ndim == 2

    # the API contract is an exact builtin float, not any float-like type
    for channel_indexes in channel_indexes_list:
        sr = reader.get_signal_sampling_rate(channel_indexes=channel_indexes)
        assert type(sr) == float, 'Type of sampling is {} should float'.format(type(sr))

    # remaining checks are done on the first chunk of the first
    # segment of the first block only
    block_index = 0
    seg_index = 0
    for channel_indexes in channel_indexes_list:
        i_start = 0
        sig_size = reader.get_signal_size(block_index, seg_index,
                                          channel_indexes=channel_indexes)
        i_stop = min(1024, sig_size)

        # None means "all channels": materialize explicit indexes
        if channel_indexes is None:
            nb_sig = reader.header['signal_channels'].size
            channel_indexes = np.arange(nb_sig, dtype=int)

        all_signal_channels = reader.header['signal_channels']
        signal_names = all_signal_channels['name'][channel_indexes]
        signal_ids = all_signal_channels['id'][channel_indexes]

        # NOTE(review): the unique count of the SELECTED channels is compared
        # against the size of the FULL channel table, so the name/id access
        # paths are only exercised when the group spans all channels and
        # names/ids are globally unique — confirm this is intended.
        unique_chan_name = (np.unique(signal_names).size == all_signal_channels.size)
        unique_chan_id = (np.unique(signal_ids).size == all_signal_channels.size)

        # access by channel indexes/ids/names should give the same chunk;
        # every other channel is taken to exercise non-contiguous selection
        channel_indexes2 = channel_indexes[::2]
        channel_names2 = signal_names[::2]
        channel_ids2 = signal_ids[::2]

        raw_chunk0 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index,
                                                   i_start=i_start, i_stop=i_stop,
                                                   channel_indexes=channel_indexes2)
        assert raw_chunk0.ndim == 2
        assert raw_chunk0.shape[0] == i_stop
        assert raw_chunk0.shape[1] == len(channel_indexes2)

        if unique_chan_name:
            raw_chunk1 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index,
                                                       i_start=i_start, i_stop=i_stop,
                                                       channel_names=channel_names2)
            np.testing.assert_array_equal(raw_chunk0, raw_chunk1)

        if unique_chan_id:
            raw_chunk2 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index,
                                                       i_start=i_start, i_stop=i_stop,
                                                       channel_ids=channel_ids2)
            np.testing.assert_array_equal(raw_chunk0, raw_chunk2)

        # rescaling to float must preserve values regardless of access path
        for dt in ('float32', 'float64'):
            float_chunk0 = reader.rescale_signal_raw_to_float(raw_chunk0, dtype=dt,
                                                              channel_indexes=channel_indexes2)
            if unique_chan_name:
                float_chunk1 = reader.rescale_signal_raw_to_float(raw_chunk1, dtype=dt,
                                                                  channel_names=channel_names2)
            if unique_chan_id:
                float_chunk2 = reader.rescale_signal_raw_to_float(raw_chunk2, dtype=dt,
                                                                  channel_ids=channel_ids2)

            assert float_chunk0.dtype == dt
            if unique_chan_name:
                np.testing.assert_array_equal(float_chunk0, float_chunk1)
            if unique_chan_id:
                np.testing.assert_array_equal(float_chunk0, float_chunk2)
  167. def benchmark_speed_read_signals(reader):
  168. """
  169. A very basic speed measurement that read all signal
  170. in a file.
  171. """
  172. if reader._several_channel_groups:
  173. channel_indexes_list = reader.get_group_channel_indexes()
  174. else:
  175. channel_indexes_list = [None]
  176. for channel_indexes in channel_indexes_list:
  177. if channel_indexes is None:
  178. nb_sig = reader.signal_channels_count()
  179. else:
  180. nb_sig = len(channel_indexes)
  181. if nb_sig == 0:
  182. continue
  183. nb_samples = 0
  184. t0 = time.perf_counter()
  185. for raw_chunk in iter_over_sig_chunks(reader, channel_indexes, chunksize=1024):
  186. nb_samples += raw_chunk.shape[0]
  187. t1 = time.perf_counter()
  188. speed = (nb_samples * nb_sig) / (t1 - t0) / 1e6
  189. logging.info(
  190. '{} read ({}signals x {}samples) in {:0.3f} s so speed {:0.3f} MSPS from {}'.format(
  191. print_class(reader),
  192. nb_sig, nb_samples, t1 - t0, speed, reader.source_name()))
  193. def read_spike_times(reader):
  194. """
  195. Read and convert all spike times.
  196. """
  197. nb_block = reader.block_count()
  198. nb_unit = reader.unit_channels_count()
  199. for block_index in range(nb_block):
  200. nb_seg = reader.segment_count(block_index)
  201. for seg_index in range(nb_seg):
  202. for unit_index in range(nb_unit):
  203. nb_spike = reader.spike_count(block_index=block_index,
  204. seg_index=seg_index, unit_index=unit_index)
  205. if nb_spike == 0:
  206. continue
  207. spike_timestamp = reader.get_spike_timestamps(block_index=block_index,
  208. seg_index=seg_index,
  209. unit_index=unit_index, t_start=None,
  210. t_stop=None)
  211. assert spike_timestamp.shape[0] == nb_spike, 'nb_spike {} != {}'.format(
  212. spike_timestamp.shape[0], nb_spike)
  213. spike_times = reader.rescale_spike_timestamp(spike_timestamp, 'float64')
  214. assert spike_times.dtype == 'float64'
  215. if spike_times.size > 3:
  216. # load only one spike by forcing limits
  217. t_start = spike_times[1] - 0.001
  218. t_stop = spike_times[1] + 0.001
  219. spike_timestamp2 = reader.get_spike_timestamps(block_index=block_index,
  220. seg_index=seg_index,
  221. unit_index=unit_index,
  222. t_start=t_start, t_stop=t_stop)
  223. assert spike_timestamp2.shape[0] == 1
  224. spike_times2 = reader.rescale_spike_timestamp(spike_timestamp2, 'float64')
  225. assert spike_times2[0] == spike_times[1]
  226. def read_spike_waveforms(reader):
  227. """
  228. Read and convert some all waveforms.
  229. """
  230. nb_block = reader.block_count()
  231. nb_unit = reader.unit_channels_count()
  232. for block_index in range(nb_block):
  233. nb_seg = reader.segment_count(block_index)
  234. for seg_index in range(nb_seg):
  235. for unit_index in range(nb_unit):
  236. nb_spike = reader.spike_count(block_index=block_index,
  237. seg_index=seg_index, unit_index=unit_index)
  238. if nb_spike == 0:
  239. continue
  240. raw_waveforms = reader.get_spike_raw_waveforms(block_index=block_index,
  241. seg_index=seg_index,
  242. unit_index=unit_index,
  243. t_start=None, t_stop=None)
  244. if raw_waveforms is None:
  245. continue
  246. assert raw_waveforms.shape[0] == nb_spike
  247. assert raw_waveforms.ndim == 3
  248. for dt in ('float32', 'float64'):
  249. float_waveforms = reader.rescale_waveforms_to_float(
  250. raw_waveforms, dtype=dt, unit_index=unit_index)
  251. assert float_waveforms.dtype == dt
  252. assert float_waveforms.shape == raw_waveforms.shape
  253. def read_events(reader):
  254. """
  255. Read and convert some event or epoch.
  256. """
  257. nb_block = reader.block_count()
  258. nb_event_channel = reader.event_channels_count()
  259. for block_index in range(nb_block):
  260. nb_seg = reader.segment_count(block_index)
  261. for seg_index in range(nb_seg):
  262. for ev_chan in range(nb_event_channel):
  263. nb_event = reader.event_count(block_index=block_index, seg_index=seg_index,
  264. event_channel_index=ev_chan)
  265. if nb_event == 0:
  266. continue
  267. ev_timestamps, ev_durations, ev_labels = reader.get_event_timestamps(
  268. block_index=block_index, seg_index=seg_index,
  269. event_channel_index=ev_chan)
  270. assert ev_timestamps.shape[0] == nb_event, 'Wrong shape {}, {}'.format(
  271. ev_timestamps.shape[0], nb_event)
  272. if ev_durations is not None:
  273. assert ev_durations.shape[0] == nb_event
  274. assert ev_labels.shape[0] == nb_event
  275. ev_times = reader.rescale_event_timestamp(ev_timestamps, dtype='float64')
  276. assert ev_times.dtype == 'float64'
  277. def has_annotations(reader):
  278. assert hasattr(reader, 'raw_annotations'), 'raw_annotation are not set'