example.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # -*- coding: utf-8 -*-
  2. """
  3. Example code for loading and processing of a recording of the reach-
  4. to-grasp experiments conducted at the Institut de Neurosciences de la Timone
  5. by Thomas Brochier and Alexa Riehle.
  6. Authors: Julia Sprenger, Lyuba Zehl, Michael Denker
  7. Copyright (c) 2017, Institute of Neuroscience and Medicine (INM-6),
  8. Forschungszentrum Juelich, Germany
  9. All rights reserved.
  10. Redistribution and use in source and binary forms, with or without
  11. modification, are permitted provided that the following conditions are met:
  12. * Redistributions of source code must retain the above copyright notice, this
  13. list of conditions and the following disclaimer.
  14. * Redistributions in binary form must reproduce the above copyright notice,
  15. this list of conditions and the following disclaimer in the documentation
  16. and/or other materials provided with the distribution.
  17. * Neither the names of the copyright holders nor the names of the contributors
  18. may be used to endorse or promote products derived from this software without
  19. specific prior written permission.
  20. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  21. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  22. WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  23. DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  24. FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  25. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  26. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  27. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  28. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  29. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. """
  31. # This loads the Neo and odML libraries shipped with this code. For production
  32. # use, please use the newest releases of odML and Neo.
  33. import load_local_neo_odml_elephant
  34. import os
  35. import numpy as np
  36. import matplotlib.pyplot as plt
  37. import quantities as pq
  38. from neo import Block, Segment
  39. from elephant.signal_processing import butter
  40. from reachgraspio import reachgraspio
  41. from neo.utils import cut_segment_by_epoch, add_epoch, get_events
  42. # from neo.utils import add_epoch, cut_segment_by_epoch, get_events
  43. # =============================================================================
  44. # Load data
  45. #
  46. # As a first step, we partially load the data file into memory as a Neo object.
  47. # =============================================================================
  48. # Specify the path to the recording session to load, eg,
  49. # '/home/user/l101210-001'
  50. session_name = os.path.join('..', 'datasets', 'i140703-001')
  51. # session_name = os.path.join('..', 'datasets', 'l101210-001')
  52. odml_dir = os.path.join('..', 'datasets')
  53. # Open the session for reading
  54. session = reachgraspio.ReachGraspIO(session_name, odml_directory=odml_dir)
  55. # Read the complete dataset in lazy mode. Neo object will be created, but
  56. # data are not loaded in to memory.
  57. data_block = session.read_block(correct_filter_shifts=True, lazy=True)
  58. # Access the single Segment of the data block
  59. assert len(data_block.segments) == 1
  60. data_segment = data_block.segments[0]
  61. # =============================================================================
  62. # Create offline filtered LFP
  63. #
  64. # Here, we construct one offline filtered LFP from each ns5 (monkey L) or ns6
  65. # (monkey N) raw recording trace. For monkey N, this filtered LFP can be
  66. # compared to the LFPs in the ns2 file (note that monkey L contains only
  67. # behavioral signals in the ns2 file). Also, we assign telling names to each
  68. # Neo AnalogSignal, which is used for plotting later on in this script.
  69. # =============================================================================
  70. # Iterate through all analog signals and replace these lazy object by new
  71. # analog signals containing only data of channel 62 (target_channel_id) and
  72. # provide human readable name for the analog signal (LFP / raw signal type)
  73. target_channel_id = 62
  74. nsx_to_anasig_name = {2: 'LFP signal (online filtered)',
  75. 5: 'raw signal',
  76. 6: 'raw signal'}
  77. idx = 0
  78. while idx < len(data_segment.analogsignals):
  79. # remove analog signals, that don't contain target channel
  80. channel_ids = data_segment.analogsignals[idx].array_annotations['channel_ids']
  81. if target_channel_id not in channel_ids:
  82. data_segment.analogsignals.pop(idx)
  83. continue
  84. # replace analog signal with analog signal containing data
  85. target_channel_index = np.where(channel_ids == target_channel_id)[0][0]
  86. anasig = data_segment.analogsignals[idx].load(
  87. channel_indexes=[target_channel_index])
  88. data_segment.analogsignals[idx] = anasig
  89. idx += 1
  90. # replace name by label of contained signal type
  91. anasig.name = nsx_to_anasig_name[anasig.array_annotations['nsx'][0]]
  92. # load spiketrains of same channel
  93. channel_spiketrains = data_segment.filter({'channel_id': target_channel_id})
  94. data_segment.spiketrains = [st.load(load_waveforms=True) for st in channel_spiketrains]
  95. # The LFP is not present in the data fils of both recording. Here, we
  96. # generate the LFP signal from the raw signal if it's not present already.
  97. if not data_segment.filter({'name': 'LFP signal (online filtered)'}):
  98. raw_signal = data_segment.filter({'name': 'raw signal'})[0]
  99. # Use the Elephant library to filter the raw analog signal
  100. f_anasig = butter(raw_signal, highpass_freq=None, lowpass_freq=250 * pq.Hz, order=4)
  101. print('filtering done.')
  102. f_anasig.name = 'LFP signal (offline filtered)'
  103. # Attach offline filtered LFP to the segment of data
  104. data_segment.analogsignals.extend(f_anasig)
  105. # =============================================================================
  106. # Construct analysis epochs
  107. #
  108. # In this step we extract and cut the data into time segments (termed analysis
  109. # epochs) that we wish to analyze. We contrast these analysis epochs to the
  110. # behavioral trials that are defined by the experiment as occurrence of a Trial
  111. # Start (TS-ON) event in the experiment. Concretely, here our analysis epochs
  112. # are constructed as a cutout of 25ms of data around the TS-ON event of all
  113. # successful behavioral trials.
  114. # =============================================================================
  115. # Get Trial Start (TS-ON) events of all successful behavioral trials
  116. # (corresponds to performance code 255, which is accessed for convenience and
  117. # better legibility in the dictionary attribute performance_codes of the
  118. # ReachGraspIO class).
  119. #
  120. # To this end, we filter all event objects of the loaded data to match the name
  121. # "TrialEvents", which is the Event object containing all Events available (see
  122. # documentation of ReachGraspIO). From this Event object we extract only events
  123. # matching "TS-ON" and the desired trial performance code (which are
  124. # annotations of the Event object).
  125. start_events = get_events(
  126. data_segment,
  127. name='TrialEvents',
  128. trial_event_labels='TS-ON',
  129. performance_in_trial=session.performance_codes['correct_trial'])
  130. print('got start events.')
  131. # Extract single Neo Event object containing all TS-ON triggers
  132. assert len(start_events) == 1
  133. start_event = start_events[0]
  134. # Construct analysis epochs from 10ms before the TS-ON of a successful
  135. # behavioral trial to 15ms after TS-ON. The name "analysis_epochs" is given to
  136. # the resulting Neo Epoch object. The object is not attached to the Neo
  137. # Segment. The parameter event2 of add_epoch() is left empty, since we are
  138. # cutting around a single event, as opposed to cutting between two events.
  139. pre = -10 * pq.ms
  140. post = 15 * pq.ms
  141. epoch = add_epoch(
  142. data_segment,
  143. event1=start_event, event2=None,
  144. pre=pre, post=post,
  145. attach_result=False,
  146. name='analysis_epochs',
  147. array_annotations=start_event.array_annotations)
  148. print('added epoch.')
  149. # Create new segments of data cut according to the analysis epochs of the
  150. # 'analysis_epochs' Neo Epoch object. The time axes of all segments are aligned
  151. # such that each segment starts at time 0 (parameter reset_times); annotations
  152. # describing the analysis epoch are carried over to the segments. A new Neo
  153. # Block named "data_cut_to_analysis_epochs" is created to capture all cut
  154. # analysis epochs. For execution time reason, we are only considering the
  155. # first 10 epochs here.
  156. cut_trial_block = Block(name="data_cut_to_analysis_epochs")
  157. cut_trial_block.segments = cut_segment_by_epoch(
  158. data_segment, epoch[:10], reset_time=True)
  159. # =============================================================================
  160. # Plot data
  161. # =============================================================================
  162. # Determine the first existing trial ID i from the Event object containing all
  163. # start events. Then, by calling the filter() function of the Neo Block
  164. # "data_cut_to_analysis_epochs" containing the data cut into the analysis
  165. # epochs, we ask to return all Segments annotated by the behavioral trial ID i.
  166. # In this case this call should return one matching analysis epoch around TS-ON
  167. # belonging to behavioral trial ID i. For monkey N, this is trial ID 1, for
  168. # monkey L this is trial ID 2 since trial ID 1 is not a correct trial.
  169. trial_id = int(np.min(start_event.array_annotations['trial_id']))
  170. trial_segments = cut_trial_block.filter(
  171. targdict={"trial_id": trial_id}, objects=Segment)
  172. assert len(trial_segments) == 1
  173. trial_segment = trial_segments[0]
  174. # Create figure
  175. fig = plt.figure(facecolor='w')
  176. time_unit = pq.CompoundUnit('1./30000*s')
  177. amplitude_unit = pq.microvolt
  178. nsx_colors = ['b', 'k', 'r']
  179. # Loop through all analog signals and plot the signal in a color corresponding
  180. # to its sampling frequency (i.e., originating from the ns2/ns5 or ns2/ns6).
  181. for i, anasig in enumerate(trial_segment.analogsignals):
  182. plt.plot(
  183. anasig.times.rescale(time_unit),
  184. anasig.squeeze().rescale(amplitude_unit),
  185. label=anasig.name,
  186. color=nsx_colors[i])
  187. # Loop through all spike trains and plot the spike time, and overlapping the
  188. # wave form of the spike used for spike sorting stored separately in the nev
  189. # file.
  190. for st in trial_segment.spiketrains:
  191. color = np.random.rand(3,)
  192. for spike_id, spike in enumerate(st):
  193. # Plot spike times
  194. plt.axvline(
  195. spike.rescale(time_unit).magnitude,
  196. color=color,
  197. label='Unit ID %i' % st.annotations['unit_id'])
  198. # Plot waveforms
  199. waveform = st.waveforms[spike_id, 0, :]
  200. waveform_times = np.arange(len(waveform))*time_unit + spike
  201. plt.plot(
  202. waveform_times.rescale(time_unit).magnitude,
  203. waveform.rescale(amplitude_unit),
  204. '--',
  205. linewidth=2,
  206. color=color,
  207. zorder=0)
  208. # Loop through all events
  209. for event in trial_segment.events:
  210. if event.name == 'TrialEvents':
  211. for ev_id, ev in enumerate(event):
  212. plt.axvline(
  213. ev.rescale(time_unit),
  214. alpha=0.2,
  215. linewidth=3,
  216. linestyle='dashed',
  217. label='event ' + event.array_annotations[
  218. 'trial_event_labels'][ev_id])
  219. # Finishing touches on the plot
  220. plt.autoscale(enable=True, axis='x', tight=True)
  221. plt.xlabel(time_unit.name)
  222. plt.ylabel(amplitude_unit.name)
  223. plt.legend(loc=4, fontsize=10)
  224. # Save plot
  225. fname = 'example_plot'
  226. for file_format in ['eps', 'png', 'pdf']:
  227. fig.savefig(fname + '.%s' % file_format, dpi=400, format=file_format)