# rawio_compliance.py
"""
Here is a list of rules for testing neo.rawio API compliance.
This is called automatically by `BaseTestRawIO`.
All rules are listed as functions so it should be easier to:
* identify the rawio API
* debug
* discuss rules
"""
import time
# Backport shim: Python < 3.3 has no time.perf_counter; fall back to the
# lower-resolution time.time so the benchmark code below keeps working.
if not hasattr(time, 'perf_counter'):
    time.perf_counter = time.time
  12. import logging
  13. import numpy as np
  14. from neo.rawio.baserawio import (_signal_channel_dtype, _unit_channel_dtype,
  15. _event_channel_dtype, _common_sig_characteristics)
  16. def print_class(reader):
  17. return reader.__class__.__name__
  18. def header_is_total(reader):
  19. """
  20. Test if hedaer contains:
  21. * 'signal_channels'
  22. * 'unit_channels'
  23. * 'event_channels'
  24. """
  25. h = reader.header
  26. assert 'signal_channels' in h, 'signal_channels missing in header'
  27. if h['signal_channels'] is not None:
  28. dt = h['signal_channels'].dtype
  29. for k, _ in _signal_channel_dtype:
  30. assert k in dt.fields, '%s not in signal_channels.dtype' % k
  31. assert 'unit_channels' in h, 'unit_channels missing in header'
  32. if h['unit_channels'] is not None:
  33. dt = h['unit_channels'].dtype
  34. for k, _ in _unit_channel_dtype:
  35. assert k in dt.fields, '%s not in unit_channels.dtype' % k
  36. assert 'event_channels' in h, 'event_channels missing in header'
  37. if h['event_channels'] is not None:
  38. dt = h['event_channels'].dtype
  39. for k, _ in _event_channel_dtype:
  40. assert k in dt.fields, '%s not in event_channels.dtype' % k
  41. def count_element(reader):
  42. """
  43. Count block/segment/signals/spike/events
  44. """
  45. nb_sig = reader.signal_channels_count()
  46. nb_unit = reader.unit_channels_count()
  47. nb_event_channel = reader.event_channels_count()
  48. nb_block = reader.block_count()
  49. assert nb_block > 0, '{} have {} block'.format(print_class(reader), nb_block)
  50. for block_index in range(nb_block):
  51. nb_seg = reader.segment_count(block_index)
  52. for seg_index in range(nb_seg):
  53. t_start = reader.segment_t_start(block_index=block_index, seg_index=seg_index)
  54. t_stop = reader.segment_t_stop(block_index=block_index, seg_index=seg_index)
  55. assert t_stop > t_start
  56. if nb_sig > 0:
  57. if reader._several_channel_groups:
  58. channel_indexes_list = reader.get_group_signal_channel_indexes()
  59. for channel_indexes in channel_indexes_list:
  60. sig_size = reader.get_signal_size(block_index, seg_index,
  61. channel_indexes=channel_indexes)
  62. else:
  63. sig_size = reader.get_signal_size(block_index, seg_index,
  64. channel_indexes=None)
  65. for unit_index in range(nb_unit):
  66. nb_spike = reader.spike_count(block_index=block_index, seg_index=seg_index,
  67. unit_index=unit_index)
  68. for event_channel_index in range(nb_event_channel):
  69. nb_event = reader.event_count(block_index=block_index, seg_index=seg_index,
  70. event_channel_index=event_channel_index)
  71. def iter_over_sig_chunks(reader, channel_indexes, chunksize=1024):
  72. if channel_indexes is None:
  73. nb_sig = reader.signal_channels_count()
  74. else:
  75. nb_sig = len(channel_indexes)
  76. if nb_sig == 0:
  77. return
  78. nb_block = reader.block_count()
  79. # read all chunk in RAW data
  80. chunksize = 1024
  81. for block_index in range(nb_block):
  82. nb_seg = reader.segment_count(block_index)
  83. for seg_index in range(nb_seg):
  84. sig_size = reader.get_signal_size(block_index, seg_index, channel_indexes)
  85. nb = sig_size // chunksize + 1
  86. for i in range(nb):
  87. i_start = i * chunksize
  88. i_stop = min((i + 1) * chunksize, sig_size)
  89. raw_chunk = reader.get_analogsignal_chunk(block_index=block_index,
  90. seg_index=seg_index,
  91. i_start=i_start, i_stop=i_stop,
  92. channel_indexes=channel_indexes)
  93. yield raw_chunk
def read_analogsignals(reader):
    """
    Read and convert some signals chunks.

    Also tests the special case when signal_channels do not have the same
    sampling_rate (AKA _need_chan_index_check): channels are then split
    into groups and each group is read separately.
    """
    nb_sig = reader.signal_channels_count()
    if nb_sig == 0:
        return

    # one list of channel indexes per group, or a single None meaning
    # "all channels"
    if reader._several_channel_groups:
        channel_indexes_list = reader.get_group_signal_channel_indexes()
    else:
        channel_indexes_list = [None]

    # read all chunks for all channels, all blocks, all segments and
    # check dimensionality only
    for channel_indexes in channel_indexes_list:
        for raw_chunk in iter_over_sig_chunks(reader, channel_indexes, chunksize=1024):
            assert raw_chunk.ndim == 2

    # the sampling rate must be a plain python float (type() check, so
    # numpy scalars are rejected)
    for channel_indexes in channel_indexes_list:
        sr = reader.get_signal_sampling_rate(channel_indexes=channel_indexes)
        assert type(sr) == float, 'Type of sampling is {} should float'.format(type(sr))

    # make the remaining tests on the first chunk of the first segment of
    # the first block only
    block_index = 0
    seg_index = 0
    for channel_indexes in channel_indexes_list:
        i_start = 0
        sig_size = reader.get_signal_size(block_index, seg_index,
                                          channel_indexes=channel_indexes)
        i_stop = min(1024, sig_size)

        # materialize "all channels" as an explicit index array
        if channel_indexes is None:
            nb_sig = reader.header['signal_channels'].size
            channel_indexes = np.arange(nb_sig, dtype=int)

        all_signal_channels = reader.header['signal_channels']
        signal_names = all_signal_channels['name'][channel_indexes]
        signal_ids = all_signal_channels['id'][channel_indexes]

        # access by name/id is only testable when they are globally unique
        unique_chan_name = (np.unique(signal_names).size == all_signal_channels.size)
        unique_chan_id = (np.unique(signal_ids).size == all_signal_channels.size)

        # access by channel indexes/ids/names should give the same chunk;
        # take every other channel to exercise sub-selection
        channel_indexes2 = channel_indexes[::2]
        channel_names2 = signal_names[::2]
        channel_ids2 = signal_ids[::2]

        raw_chunk0 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index,
                                                   i_start=i_start, i_stop=i_stop,
                                                   channel_indexes=channel_indexes2)
        assert raw_chunk0.ndim == 2
        assert raw_chunk0.shape[0] == i_stop
        assert raw_chunk0.shape[1] == len(channel_indexes2)

        if unique_chan_name:
            raw_chunk1 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index,
                                                       i_start=i_start, i_stop=i_stop,
                                                       channel_names=channel_names2)
            np.testing.assert_array_equal(raw_chunk0, raw_chunk1)

        if unique_chan_id:
            raw_chunk2 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index,
                                                       i_start=i_start, i_stop=i_stop,
                                                       channel_ids=channel_ids2)
            np.testing.assert_array_equal(raw_chunk0, raw_chunk2)

        # convert to float32/float64: rescaling must give the same values
        # whatever the channel selection mechanism used
        for dt in ('float32', 'float64'):
            float_chunk0 = reader.rescale_signal_raw_to_float(raw_chunk0, dtype=dt,
                                                              channel_indexes=channel_indexes2)
            if unique_chan_name:
                float_chunk1 = reader.rescale_signal_raw_to_float(raw_chunk1, dtype=dt,
                                                                  channel_names=channel_names2)
            if unique_chan_id:
                float_chunk2 = reader.rescale_signal_raw_to_float(raw_chunk2, dtype=dt,
                                                                  channel_ids=channel_ids2)

            assert float_chunk0.dtype == dt
            if unique_chan_name:
                np.testing.assert_array_equal(float_chunk0, float_chunk1)
            if unique_chan_id:
                np.testing.assert_array_equal(float_chunk0, float_chunk2)

        # read 500 ms with several chunksizes and check the concatenation
        # equals a single big read.
        # NOTE(review): channel_indexes may have been replaced above by
        # np.arange(nb_sig) -- presumably equivalent to None here; confirm
        sr = reader.get_signal_sampling_rate(channel_indexes=channel_indexes)
        lenght_to_read = int(.5 * sr)
        if lenght_to_read < sig_size:
            ref_raw_sigs = reader.get_analogsignal_chunk(block_index=block_index,
                                                         seg_index=seg_index, i_start=0,
                                                         i_stop=lenght_to_read,
                                                         channel_indexes=channel_indexes)
            for chunksize in (511, 512, 513, 1023, 1024, 1025):
                i_start = 0
                chunks = []
                while i_start < lenght_to_read:
                    i_stop = min(i_start + chunksize, lenght_to_read)
                    raw_chunk = reader.get_analogsignal_chunk(block_index=block_index,
                                                              seg_index=seg_index, i_start=i_start,
                                                              i_stop=i_stop,
                                                              channel_indexes=channel_indexes)
                    chunks.append(raw_chunk)
                    i_start += chunksize
                chunk_raw_sigs = np.concatenate(chunks, axis=0)
                np.testing.assert_array_equal(ref_raw_sigs, chunk_raw_sigs)
  187. def benchmark_speed_read_signals(reader):
  188. """
  189. A very basic speed measurement that read all signal
  190. in a file.
  191. """
  192. if reader._several_channel_groups:
  193. channel_indexes_list = reader.get_group_signal_channel_indexes()
  194. else:
  195. channel_indexes_list = [None]
  196. for channel_indexes in channel_indexes_list:
  197. if channel_indexes is None:
  198. nb_sig = reader.signal_channels_count()
  199. else:
  200. nb_sig = len(channel_indexes)
  201. if nb_sig == 0:
  202. continue
  203. nb_samples = 0
  204. t0 = time.perf_counter()
  205. for raw_chunk in iter_over_sig_chunks(reader, channel_indexes, chunksize=1024):
  206. nb_samples += raw_chunk.shape[0]
  207. t1 = time.perf_counter()
  208. if t0 != t1:
  209. speed = (nb_samples * nb_sig) / (t1 - t0) / 1e6
  210. else:
  211. speed = np.inf
  212. logging.info(
  213. '{} read ({}signals x {}samples) in {:0.3f} s so speed {:0.3f} MSPS from {}'.format(
  214. print_class(reader),
  215. nb_sig, nb_samples, t1 - t0, speed, reader.source_name()))
  216. def read_spike_times(reader):
  217. """
  218. Read and convert all spike times.
  219. """
  220. nb_block = reader.block_count()
  221. nb_unit = reader.unit_channels_count()
  222. for block_index in range(nb_block):
  223. nb_seg = reader.segment_count(block_index)
  224. for seg_index in range(nb_seg):
  225. for unit_index in range(nb_unit):
  226. nb_spike = reader.spike_count(block_index=block_index,
  227. seg_index=seg_index, unit_index=unit_index)
  228. if nb_spike == 0:
  229. continue
  230. spike_timestamp = reader.get_spike_timestamps(block_index=block_index,
  231. seg_index=seg_index,
  232. unit_index=unit_index, t_start=None,
  233. t_stop=None)
  234. assert spike_timestamp.shape[0] == nb_spike, 'nb_spike {} != {}'.format(
  235. spike_timestamp.shape[0], nb_spike)
  236. spike_times = reader.rescale_spike_timestamp(spike_timestamp, 'float64')
  237. assert spike_times.dtype == 'float64'
  238. if spike_times.size > 3:
  239. # load only one spike by forcing limits
  240. t_start = spike_times[1] - 0.001
  241. t_stop = spike_times[1] + 0.001
  242. spike_timestamp2 = reader.get_spike_timestamps(block_index=block_index,
  243. seg_index=seg_index,
  244. unit_index=unit_index,
  245. t_start=t_start, t_stop=t_stop)
  246. assert spike_timestamp2.shape[0] == 1
  247. spike_times2 = reader.rescale_spike_timestamp(spike_timestamp2, 'float64')
  248. assert spike_times2[0] == spike_times[1]
  249. def read_spike_waveforms(reader):
  250. """
  251. Read and convert some all waveforms.
  252. """
  253. nb_block = reader.block_count()
  254. nb_unit = reader.unit_channels_count()
  255. for block_index in range(nb_block):
  256. nb_seg = reader.segment_count(block_index)
  257. for seg_index in range(nb_seg):
  258. for unit_index in range(nb_unit):
  259. nb_spike = reader.spike_count(block_index=block_index,
  260. seg_index=seg_index, unit_index=unit_index)
  261. if nb_spike == 0:
  262. continue
  263. raw_waveforms = reader.get_spike_raw_waveforms(block_index=block_index,
  264. seg_index=seg_index,
  265. unit_index=unit_index,
  266. t_start=None, t_stop=None)
  267. if raw_waveforms is None:
  268. continue
  269. assert raw_waveforms.shape[0] == nb_spike
  270. assert raw_waveforms.ndim == 3
  271. for dt in ('float32', 'float64'):
  272. float_waveforms = reader.rescale_waveforms_to_float(
  273. raw_waveforms, dtype=dt, unit_index=unit_index)
  274. assert float_waveforms.dtype == dt
  275. assert float_waveforms.shape == raw_waveforms.shape
  276. def read_events(reader):
  277. """
  278. Read and convert some event or epoch.
  279. """
  280. nb_block = reader.block_count()
  281. nb_event_channel = reader.event_channels_count()
  282. for block_index in range(nb_block):
  283. nb_seg = reader.segment_count(block_index)
  284. for seg_index in range(nb_seg):
  285. for ev_chan in range(nb_event_channel):
  286. nb_event = reader.event_count(block_index=block_index, seg_index=seg_index,
  287. event_channel_index=ev_chan)
  288. if nb_event == 0:
  289. continue
  290. ev_timestamps, ev_durations, ev_labels = reader.get_event_timestamps(
  291. block_index=block_index, seg_index=seg_index,
  292. event_channel_index=ev_chan)
  293. assert ev_timestamps.shape[0] == nb_event, 'Wrong shape {}, {}'.format(
  294. ev_timestamps.shape[0], nb_event)
  295. if ev_durations is not None:
  296. assert ev_durations.shape[0] == nb_event
  297. assert ev_labels.shape[0] == nb_event
  298. ev_times = reader.rescale_event_timestamp(ev_timestamps, dtype='float64')
  299. assert ev_times.dtype == 'float64'
  300. def has_annotations(reader):
  301. assert hasattr(reader, 'raw_annotations'), 'raw_annotation are not set'