data_overview_2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # -*- coding: utf-8 -*-
  2. """
  3. Code for generating the second data figure in the manuscript.
  4. Authors: Julia Sprenger, Lyuba Zehl, Michael Denker
  5. Copyright (c) 2017, Institute of Neuroscience and Medicine (INM-6),
  6. Forschungszentrum Juelich, Germany
  7. All rights reserved.
  8. Redistribution and use in source and binary forms, with or without
  9. modification, are permitted provided that the following conditions are met:
  10. * Redistributions of source code must retain the above copyright notice, this
  11. list of conditions and the following disclaimer.
  12. * Redistributions in binary form must reproduce the above copyright notice,
  13. this list of conditions and the following disclaimer in the documentation
  14. and/or other materials provided with the distribution.
  15. * Neither the names of the copyright holders nor the names of the contributors
  16. may be used to endorse or promote products derived from this software without
  17. specific prior written permission.
  18. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  19. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  20. WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  21. DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  22. FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  23. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  24. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  25. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  26. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  27. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  28. """
  29. # This loads the Neo and odML libraries shipped with this code. For production
  30. # use, please use the newest releases of odML and Neo.
  31. import load_local_neo_odml_elephant
  32. import os
  33. import matplotlib.pyplot as plt
  34. from matplotlib import gridspec, transforms
  35. import quantities as pq
  36. import numpy as np
  37. from neo import (AnalogSignal, SpikeTrain)
  38. from neo.utils import *
  39. from reachgraspio import reachgraspio
  40. from neo_utils import load_segment
  41. # =============================================================================
  42. # Define data and metadata directories and general settings
  43. # =============================================================================
  44. def get_monkey_datafile(monkey):
  45. if monkey == "Lilou":
  46. return "l101210-001" # ns2 (behavior) and ns5 present
  47. elif monkey == "Nikos2":
  48. return "i140703-001" # ns2 and ns6 present
  49. else:
  50. return ""
  51. # Enter your dataset directory here
  52. datasetdir = os.path.join('..', 'datasets')
  53. chosen_els = {'Lilou': range(3, 97, 7), 'Nikos2': range(1, 97, 7)}
  54. chosen_el = {
  55. 'Lilou': chosen_els['Lilou'][0],
  56. 'Nikos2': chosen_els['Nikos2'][0]}
  57. chosen_unit = 1
  58. trial_indexes = range(14)
  59. trial_index = trial_indexes[0]
  60. chosen_events = ['TS-ON', 'WS-ON', 'CUE-ON', 'CUE-OFF', 'GO-ON', 'SR-ON',
  61. 'RW-ON', 'WS-OFF'] # , 'RW-OFF'
  62. # =============================================================================
  63. # Load data and metadata for a monkey
  64. # =============================================================================
  65. # CHANGE this parameter to load data of the different monkeys
  66. # monkey = 'Nikos2'
  67. monkey = 'Lilou'
  68. datafile = get_monkey_datafile(monkey)
  69. session = reachgraspio.ReachGraspIO(
  70. filename=os.path.join(datasetdir, datafile),
  71. odml_directory=datasetdir,
  72. verbose=False)
  73. bl = session.read_block(lazy=True)
  74. # channels=chosen_els[monkey],
  75. # units=[1], # loading only unit_id 1
  76. # load_waveforms=False,
  77. # load_events=True,
  78. # scaling='voltage')
  79. seg = bl.segments[0]
  80. # get start and stop events of trials
  81. start_events = get_events(
  82. seg, **{
  83. 'name': 'TrialEvents',
  84. 'trial_event_labels': 'TS-ON',
  85. 'performance_in_trial': session.performance_codes['correct_trial']})
  86. stop_events = get_events(
  87. seg, **{
  88. 'name': 'TrialEvents',
  89. 'trial_event_labels': 'RW-ON',
  90. 'performance_in_trial': session.performance_codes['correct_trial']})
  91. # there should only be one event object for these conditions
  92. assert len(start_events) == 1
  93. assert len(stop_events) == 1
  94. # insert epochs between 10ms before TS to 50ms after RW corresponding to trails
  95. ep = add_epoch(seg,
  96. start_events[0],
  97. stop_events[0],
  98. pre=-250 * pq.ms,
  99. post=500 * pq.ms,
  100. segment_type='complete_trials')
  101. ep.array_annotate(trialtype=start_events[0].array_annotations['belongs_to_trialtype'])
  102. # access single epoch of this data_segment
  103. epochs = get_epochs(seg, **{'segment_type': 'complete_trials'})
  104. assert len(epochs) == 1
  105. # remove spiketrains not belonging to chosen_electrode
  106. seg.spiketrains = seg.filter(targdict={'unit_id': chosen_unit},
  107. recursive=True, objects='SpikeTrainProxy')
  108. # remove all non-neural signals
  109. seg.analogsignals = seg.filter(targdict={'neural_signal': True},
  110. objects='AnalogSignalProxy')
  111. # use most raw data if multiple versions are present
  112. raw_signal = seg.analogsignals[0]
  113. for sig in seg.analogsignals:
  114. if sig.sampling_rate > raw_signal.sampling_rate:
  115. raw_signal = sig
  116. seg.analogsignals = [raw_signal]
  117. # replacing the segment with a new segment containing all data
  118. # to speed up cutting of segments
  119. seg = load_segment(seg, load_wavefroms=True)
  120. # only keep the chosen electrode signal in the AnalogSignal object
  121. mask = np.isin(seg.analogsignals[0].array_annotations['channel_ids'], chosen_els[monkey])
  122. seg.analogsignals[0] = seg.analogsignals[0][:, mask]
  123. # cut segments according to inserted 'complete_trials' epochs and reset trial
  124. # times
  125. cut_segments = cut_segment_by_epoch(seg, epochs[0], reset_time=True)
  126. # explicitly adding trial type annotations to cut segments
  127. for i, cut_seg in enumerate(cut_segments):
  128. cut_seg.annotate(trialtype=epochs[0].array_annotations['trialtype'][i])
  129. # =============================================================================
# Define figure and subplot axes for the second data overview
  131. # =============================================================================
  132. fig = plt.figure(facecolor='w')
  133. fig.set_size_inches(7.0, 9.9) # (w, h) in inches
  134. # #(7.0, 9.9) corresponds to A4 portrait ratio
  135. gs = gridspec.GridSpec(
  136. nrows=2,
  137. ncols=2,
  138. left=0.1,
  139. bottom=0.05,
  140. right=0.9,
  141. top=0.975,
  142. wspace=0.1,
  143. hspace=0.1,
  144. width_ratios=None,
  145. height_ratios=[2, 1])
  146. ax1 = plt.subplot(gs[0, 0]) # top left
  147. ax2 = plt.subplot(gs[0, 1], sharex=ax1) # top right
  148. ax3 = plt.subplot(gs[1, 0], sharex=ax1) # bottom left
  149. ax4 = plt.subplot(gs[1, 1], sharex=ax1) # bottom right
  150. fontdict_titles = {'fontsize': 9, 'fontweight': 'bold'}
  151. fontdict_axis = {'fontsize': 10, 'fontweight': 'bold'}
  152. # the x coords of the event labels are data, and the y coord are axes
  153. event_label_transform = transforms.blended_transform_factory(ax1.transData,
  154. ax1.transAxes)
  155. trialtype_colors = {
  156. 'SGHF': 'MediumBlue', 'SGLF': 'Turquoise',
  157. 'PGHF': 'DarkGreen', 'PGLF': 'YellowGreen',
  158. 'LFSG': 'Orange', 'LFPG': 'Yellow',
  159. 'HFSG': 'DarkRed', 'HFPG': 'OrangeRed',
  160. 'SGSG': 'SteelBlue', 'PGPG': 'LimeGreen',
  161. None: 'black'}
  162. event_colors = {
  163. 'TS-ON': 'indigo', 'TS-OFF': 'indigo',
  164. 'WS-ON': 'purple', 'WS-OFF': 'purple',
  165. 'CUE-ON': 'crimson', 'CUE-OFF': 'crimson',
  166. 'GO-ON': 'orangered', 'GO-OFF': 'orangered',
  167. 'SR-ON': 'darkorange',
  168. 'RW-ON': 'orange', 'RW-OFF': 'orange'}
  169. electrode_cmap = plt.get_cmap('bone')
  170. electrode_colors = [electrode_cmap(x) for x in
  171. np.tile(np.array([0.3, 0.7]), int(len(chosen_els[monkey]) / 2))]
  172. time_unit = 'ms'
  173. lfp_unit = 'uV'
  174. # define scaling factors for analogsignals
  175. anasig_std = np.mean(np.std(cut_segments[trial_index].analogsignals[0].rescale(lfp_unit), axis=0))
  176. anasig_offset = 3 * anasig_std
  177. # =============================================================================
  178. # SUPPLEMENTARY PLOTTING functions
  179. # =============================================================================
  180. def add_scalebar(ax, std):
  181. # the x coords of the scale bar are axis, and the y coord are data
  182. scalebar_transform = transforms.blended_transform_factory(ax.transAxes,
  183. ax.transData)
  184. # adding scalebar
  185. yscalebar = max(int(std.rescale(lfp_unit)), 1) * getattr(pq, lfp_unit) * 2
  186. scalebar_offset = -2 * std
  187. ax.vlines(x=0.4,
  188. ymin=(scalebar_offset - yscalebar).magnitude,
  189. ymax=scalebar_offset.magnitude,
  190. color='k',
  191. linewidth=4,
  192. transform=scalebar_transform)
  193. ax.text(0.4, (scalebar_offset - 0.5 * yscalebar).magnitude,
  194. ' %i %s' % (yscalebar.magnitude, lfp_unit),
  195. ha="left", va="center", rotation=0, color='k',
  196. size=8, transform=scalebar_transform)
  197. # =============================================================================
  198. # PLOT DATA OF SINGLE TRIAL (left plots)
  199. # =============================================================================
  200. # get data of selected trial
  201. selected_trial = cut_segments[trial_index]
  202. # PLOT DATA FOR EACH CHOSEN ELECTRODE
  203. for el_idx, electrode_id in enumerate(chosen_els[monkey]):
  204. # PLOT ANALOGSIGNALS in upper plot
  205. chosen_el_idx = np.where(cut_segments[0].analogsignals[0].array_annotations['channel_ids'] == electrode_id)[0][0]
  206. anasig = selected_trial.analogsignals[0][:, chosen_el_idx]
  207. ax1.plot(anasig.times.rescale(time_unit),
  208. np.asarray(anasig.rescale(lfp_unit))
  209. + anasig_offset.magnitude * el_idx,
  210. color=electrode_colors[el_idx])
  211. # PLOT SPIKETRAINS in lower plot
  212. spiketrains = selected_trial.filter(
  213. channel_id=electrode_id, objects=SpikeTrain)
  214. for spiketrain in spiketrains:
  215. ax3.plot(spiketrain.times.rescale(time_unit),
  216. np.zeros(len(spiketrain.times)) + el_idx, 'k|')
  217. # PLOT EVENTS in both plots
  218. for event_type in chosen_events:
  219. # get events of each chosen event type
  220. event_data = get_events(selected_trial,
  221. **{'trial_event_labels': event_type})
  222. for event in event_data:
  223. event_color = event_colors[event.array_annotations['trial_event_labels'][0]]
  224. # adding lines
  225. for ax in [ax1, ax3]:
  226. ax.axvline(event.times.rescale(time_unit),
  227. color=event_color,
  228. zorder=0.5)
  229. # adding labels
  230. ax1.text(event.times.rescale(time_unit), 0,
  231. event.array_annotations['trial_event_labels'][0],
  232. ha="center", va="top", rotation=45, color=event_color,
  233. size=8, transform=event_label_transform)
  234. # SUBPLOT ADJUSTMENTS
  235. ax1.set_title('single trial', fontdict=fontdict_titles)
  236. ax1.set_ylabel('electrode id', fontdict=fontdict_axis)
  237. ax1.set_yticks(np.arange(len(chosen_els[monkey])) * anasig_offset)
  238. ax1.set_yticklabels(chosen_els[monkey])
  239. ax1.autoscale(enable=True, axis='y')
  240. plt.setp(ax1.get_xticklabels(), visible=False) # show no xticklabels
  241. ax3.set_ylabel('electrode id', fontdict=fontdict_axis)
  242. ax3.set_yticks(range(0, len(chosen_els[monkey])))
  243. ax3.set_yticklabels(np.asarray(chosen_els[monkey]))
  244. ax3.set_ylim(-1, len(chosen_els[monkey]))
  245. ax3.set_xlabel('time [%s]' % time_unit, fontdict=fontdict_axis)
  246. # ax3.autoscale(axis='y')
  247. # =============================================================================
  248. # PLOT DATA OF SINGLE ELECTRODE
  249. # =============================================================================
  250. # plot data for each chosen trial
  251. chosen_el_idx = np.where(cut_segments[0].analogsignals[0].array_annotations['channel_ids'] == chosen_el[monkey])[0][0]
  252. for trial_idx, trial_id in enumerate(trial_indexes):
  253. trial_spikes = cut_segments[trial_id].filter(channel_id=chosen_el[monkey], objects='SpikeTrain')
  254. trial_type = cut_segments[trial_id].annotations['trialtype']
  255. trial_color = trialtype_colors[trial_type]
  256. t_signal = cut_segments[trial_id].analogsignals[0][:, chosen_el_idx]
  257. # PLOT ANALOGSIGNALS in upper plot
  258. ax2.plot(t_signal.times.rescale(time_unit),
  259. np.asarray(t_signal.rescale(lfp_unit))
  260. + anasig_offset.magnitude * trial_idx,
  261. color=trial_color, zorder=1)
  262. for t_data in trial_spikes:
  263. # PLOT SPIKETRAINS in lower plot
  264. ax4.plot(t_data.times.rescale(time_unit),
  265. np.ones(len(t_data.times)) + trial_idx, 'k|')
  266. # PLOT EVENTS in both plots
  267. for event_type in chosen_events:
  268. # get events of each chosen event type
  269. event_data = get_events(cut_segments[trial_id], **{'trial_event_labels': event_type})
  270. for event in event_data:
  271. color = event_colors[event.array_annotations['trial_event_labels'][0]]
  272. ax2.vlines(x=event.times.rescale(time_unit),
  273. ymin=(trial_idx - 0.5) * anasig_offset,
  274. ymax=(trial_idx + 0.5) * anasig_offset,
  275. color=color,
  276. zorder=2)
  277. ax4.vlines(x=event.times.rescale(time_unit),
  278. ymin=trial_idx + 1 - 0.4,
  279. ymax=trial_idx + 1 + 0.4,
  280. color=color,
  281. zorder=0.5)
  282. # SUBPLOT ADJUSTMENTS
  283. ax2.set_title('single electrode', fontdict=fontdict_titles)
  284. ax2.set_ylabel('trial id', fontdict=fontdict_axis)
  285. ax2.set_yticks(np.asarray(trial_indexes) * anasig_offset)
  286. ax2.set_yticklabels(
  287. [epochs[0].array_annotations['trial_id'][_] for _ in trial_indexes])
  288. ax2.yaxis.set_label_position("right")
  289. ax2.tick_params(direction='in', length=3, labelleft='off', labelright='on')
  290. ax2.autoscale(enable=True, axis='y')
  291. add_scalebar(ax2, anasig_std)
  292. plt.setp(ax2.get_xticklabels(), visible=False) # show no xticklabels
  293. ax4.set_ylabel('trial id', fontdict=fontdict_axis)
  294. ax4.set_xlabel('time [%s]' % time_unit, fontdict=fontdict_axis)
  295. start, end = ax4.get_xlim()
  296. ax4.xaxis.set_ticks(np.arange(start, end, 1000))
  297. ax4.xaxis.set_ticks(np.arange(start, end, 500), minor=True)
  298. ax4.set_yticks(range(1, len(trial_indexes) + 1))
  299. ax4.set_yticklabels(np.asarray(
  300. [epochs[0].array_annotations['trial_id'][_] for _ in trial_indexes]))
  301. ax4.yaxis.set_label_position("right")
  302. ax4.tick_params(direction='in', length=3, labelleft='off', labelright='on')
  303. ax4.autoscale(enable=True, axis='y')
  304. # GENERAL PLOT ADJUSTMENTS
  305. # adjust font sizes of ticks
  306. for ax in [ax4.yaxis, ax4.xaxis, ax3.xaxis, ax3.yaxis]:
  307. for tick in ax.get_major_ticks():
  308. tick.label.set_fontsize(10)
  309. # adjust time range on x axis
  310. t_min = np.min([cut_segments[tid].t_start.rescale(time_unit)
  311. for tid in trial_indexes])
  312. t_max = np.max([cut_segments[tid].t_stop.rescale(time_unit)
  313. for tid in trial_indexes])
  314. ax1.set_xlim(t_min, t_max)
  315. add_scalebar(ax1, anasig_std)
  316. # =============================================================================
  317. # SAVE FIGURE
  318. # =============================================================================
  319. fname = 'data_overview_2_%s' % monkey
  320. for file_format in ['eps', 'pdf', 'png']:
  321. fig.savefig(fname + '.%s' % file_format, dpi=400, format=file_format)