Browse Source

gin commit from nit-ope-de04

Modified files: 1
sprenger 3 years ago
parent
commit
8aee311657
1 changed files with 71 additions and 69 deletions
  1. 71 69
      code/data_overview_2.py

+ 71 - 69
code/data_overview_2.py

@@ -46,9 +46,11 @@ import quantities as pq
 import numpy as np
 
 from neo import (AnalogSignal, SpikeTrain)
+from neo.utils import *
 from reachgraspio import reachgraspio
 
-import neo_utils
+
+from neo_utils import load_segment
 
 # =============================================================================
 # Define data and metadata directories and general settings
@@ -67,12 +69,11 @@ def get_monkey_datafile(monkey):
 # Enter your dataset directory here
 datasetdir = os.path.join('..', 'datasets')
 
-nsx_none = {'Lilou': None, 'Nikos2': None}
-nsx_lfp = {'Lilou': 5, 'Nikos2': 2}
 chosen_els = {'Lilou': range(3, 97, 7), 'Nikos2': range(1, 97, 7)}
 chosen_el = {
     'Lilou': chosen_els['Lilou'][0],
     'Nikos2': chosen_els['Nikos2'][0]}
+chosen_unit = 1
 trial_indexes = range(14)
 trial_index = trial_indexes[0]
 chosen_events = ['TS-ON', 'WS-ON', 'CUE-ON', 'CUE-OFF', 'GO-ON', 'SR-ON',
@@ -92,31 +93,23 @@ session = reachgraspio.ReachGraspIO(
     odml_directory=datasetdir,
     verbose=False)
 
-bl = session.read_block(
-    index=None,
-    name=None,
-    description=None,
-    nsx_to_load=nsx_lfp[monkey],
-    n_starts=None,
-    n_stops=None,
-    channels=chosen_els[monkey],
-    units=[1],  # loading only unit_id 1
-    load_waveforms=False,
-    load_events=True,
-    scaling='voltage',
-    lazy=False,
-    cascade=True)
+bl = session.read_block(lazy=True)
+    # channels=chosen_els[monkey],
+    # units=[1],  # loading only unit_id 1
+    # load_waveforms=False,
+    # load_events=True,
+    # scaling='voltage')
 
 seg = bl.segments[0]
 
 # get start and stop events of trials
-start_events = neo_utils.get_events(
-    seg, properties={
+start_events = get_events(
+    seg, **{
         'name': 'TrialEvents',
         'trial_event_labels': 'TS-ON',
         'performance_in_trial': session.performance_codes['correct_trial']})
-stop_events = neo_utils.get_events(
-    seg, properties={
+stop_events = get_events(
+    seg, **{
         'name': 'TrialEvents',
         'trial_event_labels': 'RW-ON',
         'performance_in_trial': session.performance_codes['correct_trial']})
@@ -126,30 +119,41 @@ assert len(start_events) == 1
 assert len(stop_events) == 1
 
 # insert epochs between 10ms before TS to 50ms after RW corresponding to trails
-neo_utils.add_epoch(
-    seg,
-    start_events[0],
-    stop_events[0],
-    pre=-250 * pq.ms,
-    post=500 * pq.ms,
-    segment_type='complete_trials',
-    trialtype=start_events[0].annotations[
-        'belongs_to_trialtype'])
+ep = add_epoch(seg,
+               start_events[0],
+               stop_events[0],
+               pre=-250 * pq.ms,
+               post=500 * pq.ms,
+               segment_type='complete_trials')
+ep.array_annotate(trialtype=start_events[0].array_annotations['belongs_to_trialtype'])
 
 # access single epoch of this data_segment
-epochs = neo_utils.get_epochs(seg,
-                              properties={'segment_type': 'complete_trials'})
+epochs = get_epochs(seg, **{'segment_type': 'complete_trials'})
 assert len(epochs) == 1
 
+# remove spiketrains not belonging to chosen_electrode
+seg.spiketrains = seg.filter(targdict={'unit_id': chosen_unit},
+                             recursive=True, objects='SpikeTrainProxy')
+# remove all non-neural signals
+seg.analogsignals = seg.filter(targdict={'neural_signal': True},
+                               objects='AnalogSignalProxy')
+# replacing the segment with a new segment containing all data
+# to speed up cutting of segments
+seg = load_segment(seg, load_wavefroms=True)
+assert len(seg.analogsignals) == 1
+
+# only keep the chosen electrode signal in the AnalogSignal object
+mask = np.isin(seg.analogsignals[0].array_annotations['channel_ids'], chosen_els[monkey])
+
+seg.analogsignals[0] = seg.analogsignals[0][:, mask]
+
 # cut segments according to inserted 'complete_trials' epochs and reset trial
 # times
-cut_segments = neo_utils.cut_segment_by_epoch(seg,
-                                              epochs[0],
-                                              reset_time=True)
+cut_segments = cut_segment_by_epoch(seg, epochs[0], reset_time=True)
 
-# explicitely adding trial type annotations to cut segments
+# explicitly adding trial type annotations to cut segments
 for i, cut_seg in enumerate(cut_segments):
-    cut_seg.annotate(trialtype=epochs[0].annotations['trialtype'][i])
+    cut_seg.annotate(trialtype=epochs[0].array_annotations['trialtype'][i])
 
 # =============================================================================
 # Define figure and subplot axis for first data overview
@@ -206,14 +210,13 @@ time_unit = 'ms'
 lfp_unit = 'uV'
 
 # define scaling factors for analogsignals
-anasig_std = np.mean([np.std(anasig.rescale(lfp_unit)) for anasig in
-                      cut_segments[trial_index].analogsignals]) \
-    * getattr(pq, lfp_unit)
+
+anasig_std = np.mean(np.std(cut_segments[trial_index].analogsignals[0].rescale(lfp_unit), axis=0))
 anasig_offset = 3 * anasig_std
 
 
 # =============================================================================
-# SUPPLEMENTORY PLOTTING functions
+# SUPPLEMENTARY PLOTTING functions
 # =============================================================================
 
 def add_scalebar(ax, std):
@@ -246,13 +249,12 @@ selected_trial = cut_segments[trial_index]
 for el_idx, electrode_id in enumerate(chosen_els[monkey]):
 
     # PLOT ANALOGSIGNALS in upper plot
-    anasigs = selected_trial.filter(
-        channel_id=electrode_id, objects=AnalogSignal)
-    for anasig in anasigs:
-        ax1.plot(anasig.times.rescale(time_unit),
-                 np.asarray(anasig.rescale(lfp_unit))
-                 + anasig_offset.magnitude * el_idx,
-                 color=electrode_colors[el_idx])
+    chosen_el_idx = np.where(cut_segments[0].analogsignals[0].array_annotations['channel_ids'] == electrode_id)[0][0]
+    anasig = selected_trial.analogsignals[0][:, chosen_el_idx]
+    ax1.plot(anasig.times.rescale(time_unit),
+             np.asarray(anasig.rescale(lfp_unit))
+             + anasig_offset.magnitude * el_idx,
+             color=electrode_colors[el_idx])
 
     # PLOT SPIKETRAINS in lower plot
     spiketrains = selected_trial.filter(
@@ -264,10 +266,10 @@ for el_idx, electrode_id in enumerate(chosen_els[monkey]):
 # PLOT EVENTS in both plots
 for event_type in chosen_events:
     # get events of each chosen event type
-    event_data = neo_utils.get_events(selected_trial,
-                                      {'trial_event_labels': event_type})
+    event_data = get_events(selected_trial,
+                                      **{'trial_event_labels': event_type})
     for event in event_data:
-        event_color = event_colors[event.annotations['trial_event_labels'][0]]
+        event_color = event_colors[event.array_annotations['trial_event_labels'][0]]
         # adding lines
         for ax in [ax1, ax3]:
             ax.axvline(event.times.rescale(time_unit),
@@ -275,7 +277,7 @@ for event_type in chosen_events:
                        zorder=0.5)
         # adding labels
         ax1.text(event.times.rescale(time_unit), 0,
-                 event.annotations['trial_event_labels'][0],
+                 event.array_annotations['trial_event_labels'][0],
                  ha="center", va="top", rotation=45, color=event_color,
                  size=8, transform=event_label_transform)
 
@@ -298,32 +300,32 @@ ax3.set_xlabel('time [%s]' % time_unit, fontdict=fontdict_axis)
 # PLOT DATA OF SINGLE ELECTRODE
 # =============================================================================
 
+
 # plot data for each chosen trial
+chosen_el_idx = np.where(cut_segments[0].analogsignals[0].array_annotations['channel_ids'] == chosen_el[monkey])[0][0]
 for trial_idx, trial_id in enumerate(trial_indexes):
-    trial_data = cut_segments[trial_id].filter(channel_id=chosen_el[monkey])
-    trial_type = trial_data[0].parents[0].annotations['trialtype']
+    trial_spikes = cut_segments[trial_id].filter(channel_id=chosen_el[monkey], objects='SpikeTrain')
+    trial_type = cut_segments[trial_id].annotations['trialtype']
     trial_color = trialtype_colors[trial_type]
-    for t_data in trial_data:
 
-        # PLOT ANALOGSIGNALS in upper plot
-        if isinstance(t_data, AnalogSignal):
-            ax2.plot(t_data.times.rescale(time_unit),
-                     np.asarray(t_data.rescale(lfp_unit))
-                     + anasig_offset.magnitude * trial_idx,
-                     color=trial_color, zorder=1)
+    t_signal = cut_segments[trial_id].analogsignals[0][:, chosen_el_idx]
+    # PLOT ANALOGSIGNALS in upper plot
+    ax2.plot(t_signal.times.rescale(time_unit),
+             np.asarray(t_signal.rescale(lfp_unit))
+             + anasig_offset.magnitude * trial_idx,
+             color=trial_color, zorder=1)
 
+    for t_data in trial_spikes:
         # PLOT SPIKETRAINS in lower plot
-        elif isinstance(t_data, SpikeTrain):
-            ax4.plot(t_data.times.rescale(time_unit),
-                     np.ones(len(t_data.times)) + trial_idx, 'k|')
+        ax4.plot(t_data.times.rescale(time_unit),
+                 np.ones(len(t_data.times)) + trial_idx, 'k|')
 
     # PLOT EVENTS in both plots
     for event_type in chosen_events:
         # get events of each chosen event type
-        event_data = neo_utils.get_events(cut_segments[trial_id],
-                                          {'trial_event_labels': event_type})
+        event_data = get_events(cut_segments[trial_id], **{'trial_event_labels': event_type})
         for event in event_data:
-            color = event_colors[event.annotations['trial_event_labels'][0]]
+            color = event_colors[event.array_annotations['trial_event_labels'][0]]
             ax2.vlines(x=event.times.rescale(time_unit),
                        ymin=(trial_idx - 0.5) * anasig_offset,
                        ymax=(trial_idx + 0.5) * anasig_offset,
@@ -340,7 +342,7 @@ ax2.set_title('single electrode', fontdict=fontdict_titles)
 ax2.set_ylabel('trial id', fontdict=fontdict_axis)
 ax2.set_yticks(np.asarray(trial_indexes) * anasig_offset)
 ax2.set_yticklabels(
-    [epochs[0].annotations['trial_id'][_] for _ in trial_indexes])
+    [epochs[0].array_annotations['trial_id'][_] for _ in trial_indexes])
 ax2.yaxis.set_label_position("right")
 ax2.tick_params(direction='in', length=3, labelleft='off', labelright='on')
 ax2.autoscale(enable=True, axis='y')
@@ -355,7 +357,7 @@ ax4.xaxis.set_ticks(np.arange(start, end, 1000))
 ax4.xaxis.set_ticks(np.arange(start, end, 500), minor=True)
 ax4.set_yticks(range(1, len(trial_indexes) + 1))
 ax4.set_yticklabels(np.asarray(
-    [epochs[0].annotations['trial_id'][_] for _ in trial_indexes]))
+    [epochs[0].array_annotations['trial_id'][_] for _ in trial_indexes]))
 ax4.yaxis.set_label_position("right")
 ax4.tick_params(direction='in', length=3, labelleft='off', labelright='on')
 ax4.autoscale(enable=True, axis='y')