test_blackrockrawio.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """
Tests of neo.rawio.blackrockrawio
  3. """
  4. import unittest
  5. from neo.rawio.blackrockrawio import BlackrockRawIO
  6. from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
  7. import numpy as np
  8. from numpy.testing import assert_equal
  9. try:
  10. import scipy.io
  11. HAVE_SCIPY = True
  12. except ImportError:
  13. HAVE_SCIPY = False
  14. class TestBlackrockRawIO(BaseTestRawIO, unittest.TestCase, ):
  15. rawioclass = BlackrockRawIO
  16. entities_to_test = ['FileSpec2.3001',
  17. 'blackrock_2_1/l101210-001']
  18. files_to_download = [
  19. 'FileSpec2.3001.nev',
  20. 'FileSpec2.3001.ns5',
  21. 'FileSpec2.3001.ccf',
  22. 'FileSpec2.3001.mat',
  23. 'blackrock_2_1/l101210-001.mat',
  24. 'blackrock_2_1/l101210-001_nev-02_ns5.mat',
  25. 'blackrock_2_1/l101210-001.ns2',
  26. 'blackrock_2_1/l101210-001.ns5',
  27. 'blackrock_2_1/l101210-001.nev',
  28. 'blackrock_2_1/l101210-001-02.nev']
  29. @unittest.skipUnless(HAVE_SCIPY, "requires scipy")
  30. def test_compare_blackrockio_with_matlabloader(self):
  31. """
  32. This test compares the output of ReachGraspIO.read_block() with the
  33. output generated by a Matlab implementation of a Blackrock file reader
  34. provided by the company. The output for comparison is provided in a
  35. .mat file created by the script create_data_matlab_blackrock.m.
  36. The function tests LFPs, spike times, and digital events on channels
  37. 80-83 and spike waveforms on channel 82, unit 1.
  38. For details on the file contents, refer to FileSpec2.3.txt
  39. Ported to the rawio API by Samuel Garcia.
  40. """
  41. # Load data from Matlab generated files
  42. ml = scipy.io.loadmat(self.get_filename_path('FileSpec2.3001.mat'))
  43. lfp_ml = ml['lfp'] # (channel x time) LFP matrix
  44. ts_ml = ml['ts'] # spike time stamps
  45. elec_ml = ml['el'] # spike electrodes
  46. unit_ml = ml['un'] # spike unit IDs
  47. wf_ml = ml['wf'] # waveform unit 1 channel 1
  48. mts_ml = ml['mts'] # marker time stamps
  49. mid_ml = ml['mid'] # marker IDs
  50. # Load data in channels 1-3 from original data files using the Neo
  51. # BlackrockIO
  52. reader = BlackrockRawIO(filename=self.get_filename_path('FileSpec2.3001'))
  53. reader.parse_header()
  54. # Check if analog data on channels 1-8 are equal
  55. self.assertGreater(reader.signal_channels_count(), 0)
  56. for c in range(0, 8):
  57. raw_sigs = reader.get_analogsignal_chunk(channel_indexes=[c])
  58. raw_sigs = raw_sigs.flatten()
  59. assert_equal(raw_sigs[:-1], lfp_ml[c, :])
  60. # Check if spikes in channels are equal
  61. nb_unit = reader.unit_channels_count()
  62. for unit_index in range(nb_unit):
  63. unit_name = reader.header['unit_channels'][unit_index]['name']
  64. # name is chXX#YY where XX is channel_id and YY is unit_id
  65. channel_id, unit_id = unit_name.split('#')
  66. channel_id = int(channel_id.replace('ch', ''))
  67. unit_id = int(unit_id)
  68. matlab_spikes = ts_ml[(elec_ml == channel_id) & (unit_ml == unit_id)]
  69. io_spikes = reader.get_spike_timestamps(unit_index=unit_index)
  70. assert_equal(io_spikes, matlab_spikes)
  71. # Check waveforms of channel 1, unit 0
  72. if channel_id == 1 and unit_id == 0:
  73. io_waveforms = reader.get_spike_raw_waveforms(unit_index=unit_index)
  74. io_waveforms = io_waveforms[:, 0, :] # remove dim 1
  75. assert_equal(io_waveforms, wf_ml)
  76. # Check if digital input port events are equal
  77. nb_ev_chan = reader.event_channels_count()
  78. # ~ print(reader.header['event_channels'])
  79. for ev_chan in range(nb_ev_chan):
  80. name = reader.header['event_channels']['name'][ev_chan]
  81. # ~ print(name)
  82. all_timestamps, _, labels = reader.get_event_timestamps(
  83. event_channel_index=ev_chan)
  84. if name == 'digital_input_port':
  85. for label in np.unique(labels):
  86. python_digievents = all_timestamps[labels == label]
  87. matlab_digievents = mts_ml[mid_ml == int(label)]
  88. assert_equal(python_digievents, matlab_digievents)
  89. elif name == 'comments':
  90. pass
  91. # TODO: Save comments to Matlab file.
  92. @unittest.skipUnless(HAVE_SCIPY, "requires scipy")
  93. def test_compare_blackrockio_with_matlabloader_v21(self):
  94. """
  95. This test compares the output of ReachGraspIO.read_block() with the
  96. output generated by a Matlab implementation of a Blackrock file reader
  97. provided by the company. The output for comparison is provided in a
  98. .mat file created by the script create_data_matlab_blackrock.m.
  99. The function tests LFPs, spike times, and digital events.
  100. Ported to the rawio API by Samuel Garcia.
  101. """
  102. dirname = self.get_filename_path('blackrock_2_1/l101210-001')
  103. # First run with parameters for ns5, then run with correct parameters for ns2
  104. parameters = [('blackrock_2_1/l101210-001_nev-02_ns5.mat',
  105. {'nsx_to_load': 5, 'nev_override': '-'.join([dirname, '02'])}, 96),
  106. ('blackrock_2_1/l101210-001.mat', {'nsx_to_load': 2}, 6)]
  107. for param in parameters:
  108. # Load data from Matlab generated files
  109. ml = scipy.io.loadmat(self.get_filename_path(filename=param[0]))
  110. lfp_ml = ml['lfp'] # (channel x time) LFP matrix
  111. ts_ml = ml['ts'] # spike time stamps
  112. elec_ml = ml['el'] # spike electrodes
  113. unit_ml = ml['un'] # spike unit IDs
  114. wf_ml = ml['wf'] # waveforms
  115. mts_ml = ml['mts'] # marker time stamps
  116. mid_ml = ml['mid'] # marker IDs
  117. # Load data from original data files using the Neo BlackrockIO
  118. reader = BlackrockRawIO(dirname, **param[1])
  119. reader.parse_header()
  120. # Check if analog data are equal
  121. self.assertGreater(reader.signal_channels_count(), 0)
  122. for c in range(0, param[2]):
  123. raw_sigs = reader.get_analogsignal_chunk(channel_indexes=[c])
  124. raw_sigs = raw_sigs.flatten()
  125. assert_equal(raw_sigs[:], lfp_ml[c, :])
  126. # Check if spikes in channels are equal
  127. nb_unit = reader.unit_channels_count()
  128. for unit_index in range(nb_unit):
  129. unit_name = reader.header['unit_channels'][unit_index]['name']
  130. # name is chXX#YY where XX is channel_id and YY is unit_id
  131. channel_id, unit_id = unit_name.split('#')
  132. channel_id = int(channel_id.replace('ch', ''))
  133. unit_id = int(unit_id)
  134. matlab_spikes = ts_ml[(elec_ml == channel_id) & (unit_ml == unit_id)]
  135. io_spikes = reader.get_spike_timestamps(unit_index=unit_index)
  136. assert_equal(io_spikes, matlab_spikes)
  137. # Check all waveforms
  138. io_waveforms = reader.get_spike_raw_waveforms(unit_index=unit_index)
  139. io_waveforms = io_waveforms[:, 0, :] # remove dim 1
  140. matlab_wf = wf_ml[np.nonzero(
  141. np.logical_and(elec_ml == channel_id, unit_ml == unit_id)), :][0]
  142. assert_equal(io_waveforms, matlab_wf)
  143. # Check if digital input port events are equal
  144. nb_ev_chan = reader.event_channels_count()
  145. # ~ print(reader.header['event_channels'])
  146. for ev_chan in range(nb_ev_chan):
  147. name = reader.header['event_channels']['name'][ev_chan]
  148. # ~ print(name)
  149. if name == 'digital_input_port':
  150. all_timestamps, _, labels = reader.get_event_timestamps(
  151. event_channel_index=ev_chan)
  152. for label in np.unique(labels):
  153. python_digievents = all_timestamps[labels == label]
  154. matlab_digievents = mts_ml[mid_ml == int(label)]
  155. assert_equal(python_digievents, matlab_digievents)
# Run the tests with unittest's command-line runner when this file is
# executed directly as a script.
if __name__ == '__main__':
    unittest.main()