Browse Source

gin commit from nit-ope-de04

Modified files: 1
sprenger 3 years ago
parent
commit
a20c091f4b
1 changed files with 112 additions and 100 deletions
  1. 112 100
      code/data_overview_1.py

+ 112 - 100
code/data_overview_1.py

@@ -50,7 +50,7 @@ from reachgraspio import reachgraspio
 
 import odml.tools
 
-import neo_utils
+from neo import utils as neo_utils
 import odml_utils
 
 
@@ -102,22 +102,24 @@ def force_aspect(ax, aspect=1):
         (ax.get_ylim()[1] - ax.get_ylim()[0])) / aspect)
 
 
-def get_arraygrid(blackrock_elid_list, chosen_el, rej_el=None):
-    if rej_el is None:
-        rej_el = []
-    array_grid = np.zeros((10, 10))
-    for m in range(10):
-        for n in range(10):
-            idx = (9 - m) * 10 + n
-            bl_id = blackrock_elid_list[idx]
-            if bl_id == -1:
-                array_grid[m, n] = 0.7
-            elif bl_id == chosen_el:
-                array_grid[m, n] = -0.7
-            elif bl_id in rej_el:
-                array_grid[m, n] = -0.35
-            else:
-                array_grid[m, n] = 0
+def get_arraygrid(signals, chosen_el):
+    array_grid = np.ones((10, 10)) * 0.7
+
+    rejections = np.logical_or(signals.array_annotations['electrode_reject_HFC'],
+                               signals.array_annotations['electrode_reject_LFC'],
+                               signals.array_annotations['electrode_reject_IFC'])
+
+    for sig_idx in range(signals.shape[-1]):
+        connector_aligned_id = signals.array_annotations['connector_aligned_ids'][sig_idx]
+        x, y = int((connector_aligned_id -1)// 10), int((connector_aligned_id - 1) % 10)
+
+        if signals.array_annotations['channel_ids'][sig_idx] == chosen_el:
+            array_grid[x, y] = -0.7
+        elif rejections[sig_idx]:
+            array_grid[x, y] = -0.35
+        else:
+            array_grid[x, y] = 0
+
     return np.ma.array(array_grid, mask=np.isnan(array_grid))
 
 
@@ -138,14 +140,8 @@ session = reachgraspio.ReachGraspIO(
     odml_directory=datasetdir,
     verbose=False)
 
-# loads only ns2 data of all channels an chosen units
-#bl_lfp = session.read_block(
-
-# loads raw data of chosen electrode and chosen units
-#bl_raw = session.read_block(
-
 block = session.read_block(lazy=True)
-segment = block.segment[0]
+segment = block.segments[0]
 
 # Displaying loaded data structure as string output
 print("\nBlock")
@@ -160,16 +156,16 @@ for x in segment.events:
     print('\t\tAttributes ', x.__dict__.keys())
     print('\t\tAnnotation keys', x.annotations.keys())
     print('\t\ttimes', x.times[:20])
-    for anno_key in ['trial_id', 'trial_timestamp_id', 'trial_event_labels',
-                     'trial_reject_IFC']:
-        print('\t\t'+anno_key, x.annotations[anno_key][:20])
+    if x.name == 'TrialEvents':
+        for anno_key in ['trial_id', 'trial_timestamp_id', 'trial_event_labels',
+                         'trial_reject_IFC']:
+            print('\t\t'+anno_key, x.array_annotations[anno_key][:20])
 
 print("\nGroups")
 for x in block.groups:
     print('\tGroup with name', x.name)
     print('\t\tAttributes ', x.__dict__.keys())
     print('\t\tAnnotations', x.annotations)
-    # TODO: Add more here
 
 print("\nSpikeTrains")
 for x in segment.spiketrains:
@@ -180,52 +176,58 @@ for x in segment.spiketrains:
     print('\t\tunit_id', x.annotations['unit_id'])
     print('\t\tis sua', x.annotations['sua'])
     print('\t\tis mua', x.annotations['mua'])
-    print('\t\tspike times', x.times[0:20])
+
 print("\nAnalogSignals")
 for x in segment.analogsignals:
     print('\tAnalogSignal with name', x.name)
     print('\t\tAttributes ', x.__dict__.keys())
     print('\t\tAnnotations', x.annotations)
-    print('\t\tchannel_ids', x.annotations['channel_ids'])
+    print('\t\tchannel_ids', x.array_annotations['channel_ids'])
 
 # get start and stop events of trials
 start_events = neo_utils.get_events(
     segment,
-    properties={
-        'name': 'TrialEvents',
-        'trial_event_labels': 'TS-ON',
-        'performance_in_trial': 255})
+    **{
+    'name': 'TrialEvents',
+    'trial_event_labels': 'TS-ON',
+    'performance_in_trial': 255})
 stop_events = neo_utils.get_events(
     segment,
-    properties={
-        'name': 'TrialEvents',
-        'trial_event_labels': 'STOP',
-        'performance_in_trial': 255})
+    **{
+    'name': 'TrialEvents',
+    'trial_event_labels': 'STOP',
+    'performance_in_trial': 255})
 
 # there should only be one event object for these conditions
 assert len(start_events) == 1
 assert len(stop_events) == 1
 
 # insert epochs between 10ms before TS to 50ms after RW corresponding to trails
-neo_utils.add_epoch(
+ep = neo_utils.add_epoch(
     segment,
     start_events[0],
     stop_events[0],
     pre=-250 * pq.ms,
     post=500 * pq.ms,
-    trial_status='complete_trials',
-    trial_type=start_events[0].annotations['belongs_to_trialtype'],
-    trial_performance=start_events[0].annotations['performance_in_trial'])
+    trial_status='complete_trials')
+ep.array_annotate(trial_type=start_events[0].array_annotations['belongs_to_trialtype'],
+                  trial_performance=start_events[0].array_annotations['performance_in_trial'])
 
 # access single epoch of this data_segment
-epochs = neo_utils.get_epochs(segment,
-                              properties={'trial_status': 'complete_trials'})
+epochs = neo_utils.get_epochs(segment, **{'trial_status': 'complete_trials'})
 assert len(epochs) == 1
 
-# cut segments according to inserted 'complete_trials' epochs and reset trial
-#  times
-cut_segments = neo_utils.cut_segment_by_epoch(
-    segment, epochs[0], reset_time=True)
+# remove spiketrains not belonging to chosen_electrode
+segment.spiketrains = segment.filter(targdict={'channel_id': chosen_el[monkey]},
+                                     recursive=True, objects='SpikeTrainProxy')
+segment.spiketrains = [st for st in segment.spiketrains if st.annotations['unit_id'] in range(1, 5)]
+# replacing the segment with a new segment containing all data
+# to speed up cutting of segments
+from neo_utils import load_segment
+segment = load_segment(segment, load_wavefroms=True, channel_indexes=[chosen_el[monkey]])
+
+# cut segments according to inserted 'complete_trials' epochs and reset trial times
+cut_segments = neo_utils.cut_segment_by_epoch(segment, epochs[0], reset_time=True)
 
 # =============================================================================
 # Define data for overview plots
@@ -238,9 +240,9 @@ blackrock_elid_list = block.annotations['avail_electrode_ids']
 
 # get 'TrialEvents'
 event = trial_segment.events[2]
-start = event.annotations['trial_event_labels'].index('TS-ON')
-trialx_trty = event.annotations['belongs_to_trialtype'][start]
-trialx_trtimeid = event.annotations['trial_timestamp_id'][start]
+start = np.where(event.array_annotations['trial_event_labels'] == 'TS-ON')[0][0]
+trialx_trty = event.array_annotations['belongs_to_trialtype'][start]
+trialx_trtimeid = event.array_annotations['trial_timestamp_id'][start]
 trialx_color = trialtype_colors[trialx_trty]
 
 # find trial index for next trial with opposite force type (for ax5b plot)
@@ -251,11 +253,11 @@ else:
 
 for i, tr in enumerate(cut_segments):
     eventz = tr.events[2]
-    nextft = eventz.annotations['trial_event_labels'].index('TS-ON')
-    if eventz.annotations['belongs_to_trialtype'][nextft] == trialz_trty:
-        trialz_trtimeid = eventz.annotations['trial_timestamp_id'][nextft]
+    nextft = np.where(eventz.array_annotations['trial_event_labels'] == 'TS-ON')[0][0]
+    if eventz.array_annotations['belongs_to_trialtype'][nextft] == trialz_trty:
+        trialz_trtimeid = eventz.array_annotations['trial_timestamp_id'][nextft]
         trialz_color = trialtype_colors[trialz_trty]
-        trialz_seg_lfp = tr
+        trialz_seg = tr
         break
 
 
@@ -304,7 +306,7 @@ behav_signal_unit = pq.V
 # =============================================================================
 
 # load complete metadata collection
-odmldoc = odml.tools.xmlparser.load(datasetdir + datafile + '.odml')
+odmldoc = odml.load(datasetdir + datafile + '.odml')
 
 # get total trial number
 trno_tot = odml_utils.get_TrialCount(odmldoc)
@@ -382,7 +384,7 @@ leg = ax1.legend(
 leg.draw_frame(False)
 
 # adjust x and y axis
-xticks = [i for i in range(1, 101, 10)] + [100]
+xticks = list(range(1, 101, 10)) + [100]
 ax1.set_xticks(xticks)
 ax1.set_xticklabels([str(int(t)) for t in xticks], size='xx-small')
 ax1.set_xlabel('trial ID', size='x-small')
@@ -392,7 +394,7 @@ ax1.set_ylim(0, 3)
 ax1.spines['top'].set_visible(False)
 ax1.spines['left'].set_visible(False)
 ax1.spines['right'].set_visible(False)
-ax1.tick_params(direction='out', top='off')
+ax1.tick_params(direction='out', top=False, left=False, right=False)
 ax1.set_title('sequence of the first 100 trials', fontdict_titles, y=2)
 ax1.set_aspect('equal')
 
@@ -400,17 +402,22 @@ ax1.set_aspect('equal')
 # =============================================================================
 # PLOT ELECTRODE POSITION of chosen electrode
 # =============================================================================
-arraygrid = get_arraygrid(blackrock_elid_list, chosen_el[monkey])
+neural_signals = [sig for sig in trial_segment.analogsignals if sig.annotations['neural_signal']]
+assert len(neural_signals) == 1
+neural_signals = neural_signals[0]
+
+arraygrid = get_arraygrid(neural_signals, chosen_el[monkey])
 cmap = plt.cm.RdGy
 
 ax2a.pcolormesh(
-    np.flipud(arraygrid), vmin=-1, vmax=1, lw=1, cmap=cmap, edgecolors='k',
-    shading='faceted')
+    arraygrid, vmin=-1, vmax=1, lw=1, cmap=cmap, edgecolors='k',
+    #shading='faceted'
+    )
 
 force_aspect(ax2a, aspect=1)
 ax2a.tick_params(
-    bottom='off', top='off', left='off', right='off',
-    labelbottom='off', labeltop='off', labelleft='off', labelright='off')
+    bottom=False, top=False, left=False, right=False,
+    labelbottom=False, labeltop=False, labelleft=False, labelright=False)
 ax2a.set_title('electrode pos.', fontdict_titles)
 
 
@@ -429,8 +436,10 @@ for spiketrain in trial_segment.spiketrains:
         unit_type[unit_id] = 'SUA'
     elif spiketrain.annotations['mua']:
         unit_type[unit_id] = 'MUA'
+    elif unit_id in [0, 255]:
+        continue
     else:
-        pass
+        raise ValueError(f'Found unit with id {unit_id}, that is not SUA or MUA.')
     # get correct ax
     ax = unit_ax_translator[unit_id]
     # get wf sampling time before threshold crossing
@@ -454,7 +463,7 @@ for unit_id, ax in unit_ax_translator.items():
     ax.set_title('unit %i (%s)' % (unit_id, unit_type[unit_id]),
                  fontdict_titles)
     ax.tick_params(direction='in', length=3, labelsize='xx-small',
-                   labelleft='off', labelright='off')
+                   labelleft=False, labelright=False)
     ax.set_xlabel(wf_time_unit.dimensionality.latex, fontdict_axis)
     xticklocator = ticker.MaxNLocator(nbins=5)
     ax.xaxis.set_major_locator(xticklocator)
@@ -462,7 +471,7 @@ for unit_id, ax in unit_ax_translator.items():
     force_aspect(ax, aspect=1)
 
 # adding ylabel
-ax2d.tick_params(labelsize='xx-small', labelright='on')
+ax2d.tick_params(labelsize='xx-small', labelright=True)
 ax2d.set_ylabel(wf_signal_unit.dimensionality.latex, fontdict_axis)
 ax2d.yaxis.set_label_position("right")
 
@@ -480,13 +489,13 @@ for st in trial_segment.spiketrains:
              np.zeros(len(st.times)) + unit_id,
              'k|')
 
-# setting layout of spiktrain plot
+# setting layout of spiketrain plot
 ax3.set_ylim(min(plotted_unit_ids) - 0.5, max(plotted_unit_ids) + 0.5)
 ax3.set_ylabel(r'unit ID', fontdict_axis)
 ax3.yaxis.set_major_locator(ticker.MultipleLocator(base=1))
 ax3.yaxis.set_label_position("right")
 ax3.tick_params(axis='y', direction='in', length=3, labelsize='xx-small',
-                labelleft='off', labelright='on')
+                labelleft=False, labelright=True)
 ax3.invert_yaxis()
 ax3.set_title('spiketrains', fontdict_titles)
 
@@ -494,19 +503,21 @@ ax3.set_title('spiketrains', fontdict_titles)
 # PLOT "raw" SIGNAL of chosen trial of chosen electrode
 # =============================================================================
 # get "raw" data from chosen electrode
-assert len(trial_segment.analogsignals) == 1
-el_raw_sig = trial_segment.analogsignals[0]
+el_raw_sig = [a for a in trial_segment.analogsignals if a.annotations['neural_signal']]
+assert len(el_raw_sig) == 1
+el_raw_sig = el_raw_sig[0]
 
-# plotting raw signal trace
+# plotting raw signal trace of chosen electrode
+chosen_el_idx = np.where(el_raw_sig.array_annotations['channel_ids'] == chosen_el[monkey])[0][0]
 ax4.plot(el_raw_sig.times.rescale(plotting_time_unit),
-         el_raw_sig.squeeze().rescale(raw_signal_unit),
+         el_raw_sig[:, chosen_el_idx].squeeze().rescale(raw_signal_unit),
          color='k')
 
 # setting layout of raw signal plot
 ax4.set_ylabel(raw_signal_unit.units.dimensionality.latex, fontdict_axis)
 ax4.yaxis.set_label_position("right")
 ax4.tick_params(axis='y', direction='in', length=3, labelsize='xx-small',
-                labelleft='off', labelright='on')
+                labelleft=False, labelright=True)
 ax4.set_title('"raw" signal', fontdict_titles)
 
 ax4.set_xlim(trial_segment.t_start.rescale(plotting_time_unit),
@@ -518,15 +529,14 @@ ax4.xaxis.set_major_locator(ticker.MultipleLocator(base=1))
 # PLOT EVENTS across ax3 and ax4 and add time bar
 # =============================================================================
 # find trial relevant events
-startidx = event.annotations['trial_event_labels'].index('TS-ON')
-stopidx = event.annotations['trial_event_labels'][startidx:].index('STOP') + \
-    startidx + 1
+startidx = np.where(event.array_annotations['trial_event_labels'] == 'TS-ON')[0][0]
+stopidx = np.where(event.array_annotations['trial_event_labels'][startidx:] == 'STOP')[0][0] + startidx + 1
 
 for ax in [ax3, ax4]:
     xticks = []
     xticklabels = []
     for ev_id, ev in enumerate(event[startidx:stopidx]):
-        ev_labels = event.annotations['trial_event_labels'][startidx:stopidx]
+        ev_labels = event.array_annotations['trial_event_labels'][startidx:stopidx]
         if ev_labels[ev_id] in event_colors.keys():
             ev_color = event_colors[ev_labels[ev_id]]
             ax.axvline(
@@ -542,7 +552,7 @@ for ax in [ax3, ax4]:
     ax.set_xticks(xticks)
     ax.set_xticklabels(xticklabels)
     ax.tick_params(axis='x', direction='out', length=3, labelsize='xx-small',
-                   labeltop='off', top='off')
+                   labeltop=False, top=False)
 
 timebar_ypos = ax4.get_ylim()[0] + np.diff(ax4.get_ylim())[0] / 10
 timebar_labeloffset = np.diff(ax4.get_ylim())[0] * 0.01
@@ -559,12 +569,13 @@ ax4.text(timebar_xmin + 0.25 * pq.s, timebar_ypos + timebar_labeloffset,
 # PLOT BEHAVIORAL SIGNALS of chosen trial
 # =============================================================================
 # get behavioral signals
-# TODO: Adjust this to merged analogsignals
 ainp_signals = [nsig for nsig in trial_segment.analogsignals if
-                nsig.annotations['channel_id'] > 96]
+                not nsig.annotations['neural_signal']][0]
 
-ainp_trialz = [nsig for nsig in trialz_seg_lfp.analogsignals if
-               nsig.annotations['channel_id'] == 141][0]
+force_channel_idx = np.where(ainp_signals.array_annotations['channel_ids'] == 141)[0][0]
+ainp_trialz_signals = [a for a in trialz_seg.analogsignals if not a.annotations['neural_signal']]
+assert len(ainp_trialz_signals)
+ainp_trialz = ainp_trialz_signals[0][:, force_channel_idx]
 
 # find out what signal to use
 trialx_sec = odmldoc['Recording']['TaskSettings']['Trial_%03i' % trialx_trid]
@@ -582,31 +593,32 @@ else:
 
 
 # define time epoch
-startidx = event.annotations['trial_event_labels'].index('SR')
-stopidx = event.annotations['trial_event_labels'].index('OBB')
+startidx = np.where(event.array_annotations['trial_event_labels'] == 'SR')[0][0]
+stopidx = np.where(event.array_annotations['trial_event_labels'] == 'OBB')[0][0]
 sr = event[startidx].rescale(plotting_time_unit)
 stop = event[stopidx].rescale(plotting_time_unit) + 0.050 * pq.s
-startidx = event.annotations['trial_event_labels'].index('FSRplat-ON')
-stopidx = event.annotations['trial_event_labels'].index('FSRplat-OFF')
+startidx = np.where(event.array_annotations['trial_event_labels'] == 'FSRplat-ON')[0][0]
+stopidx = np.where(event.array_annotations['trial_event_labels'] == 'FSRplat-OFF')[0][0]
 fplon = event[startidx].rescale(plotting_time_unit)
 fploff = event[stopidx].rescale(plotting_time_unit)
 
 # define time epoch trialz
-startidx = eventz.annotations['trial_event_labels'].index('FSRplat-ON')
-stopidx = eventz.annotations['trial_event_labels'].index('FSRplat-OFF')
+startidx = np.where(eventz.array_annotations['trial_event_labels'] == 'FSRplat-ON')[0][0]
+stopidx = np.where(eventz.array_annotations['trial_event_labels'] == 'FSRplat-OFF')[0][0]
 fplon_trz = eventz[startidx].rescale(plotting_time_unit)
 fploff_trz = eventz[stopidx].rescale(plotting_time_unit)
 
 # plotting grip force and object displacement
 ai_legend = []
 ai_legend_txt = []
-for ainp in ainp_signals:
-    if ainp.annotations['channel_id'] in trialx_chids:
+for chidx, chid in enumerate(ainp_signals.array_annotations['channel_ids']):
+    ainp = ainp_signals[:, chidx]
+    if ainp.array_annotations['channel_ids'][0] in trialx_chids:
         ainp_times = ainp.times.rescale(plotting_time_unit)
         mask = (ainp_times > sr) & (ainp_times < stop)
         ainp_ampli = stats.zscore(ainp.magnitude[mask])
 
-        if ainp.annotations['channel_id'] != 143:
+        if ainp.array_annotations['channel_ids'][0] != 143:
             color = 'gray'
             ai_legend_txt.append('grip force')
         else:
@@ -616,7 +628,7 @@ for ainp in ainp_signals:
             ax5a.plot(ainp_times[mask], ainp_ampli, color=color)[0])
 
     # get force load of this trial for next plot
-    elif ainp.annotations['channel_id'] == 141:
+    elif ainp.array_annotations['channel_ids'][0] == 141:
         ainp_times = ainp.times.rescale(plotting_time_unit)
         mask = (ainp_times > fplon) & (ainp_times < fploff)
         force_av_01 = np.mean(ainp.rescale(behav_signal_unit).magnitude[mask])
@@ -625,7 +637,7 @@ for ainp in ainp_signals:
 ax5a.set_title('grip force and object displacement', fontdict_titles)
 ax5a.yaxis.set_label_position("left")
 ax5a.tick_params(direction='in', length=3, labelsize='xx-small',
-                 labelleft='off', labelright='on')
+                 labelleft=False, labelright=True)
 ax5a.set_ylabel('zscore', fontdict_axis)
 ax5a.legend(
     ai_legend, ai_legend_txt,
@@ -645,22 +657,22 @@ ax5b.bar([0, 0.6], [force_av_01, force_av_02], bar_width, color=color)
 ax5b.set_title('load/pull force', fontdict_titles)
 ax5b.set_ylabel(behav_signal_unit.units.dimensionality.latex, fontdict_axis)
 ax5b.set_xticks([0, 0.6])
-ax5b.set_xticklabels([trialx_trty, trialz_trty], fontdict_axis)
+ax5b.set_xticklabels([trialx_trty, trialz_trty], fontdict=fontdict_axis)
 ax5b.yaxis.set_label_position("right")
 ax5b.tick_params(direction='in', length=3, labelsize='xx-small',
-                 labelleft='off', labelright='on')
+                 labelleft=False, labelright=True)
 
 # =============================================================================
 # PLOT EVENTS across ax5a and add time bar
 # =============================================================================
 # find trial relevant events
-startidx = event.annotations['trial_event_labels'].index('SR')
-stopidx = event.annotations['trial_event_labels'].index('OBB')
+startidx = np.where(event.array_annotations['trial_event_labels'] == 'SR')[0][0]
+stopidx = np.where(event.array_annotations['trial_event_labels'] == 'OBB')[0][0]
 
 xticks = []
 xticklabels = []
 for ev_id, ev in enumerate(event[startidx:stopidx]):
-    ev_labels = event.annotations['trial_event_labels'][startidx:stopidx + 1]
+    ev_labels = event.array_annotations['trial_event_labels'][startidx:stopidx + 1]
     if ev_labels[ev_id] in ['RW-ON']:
         ax5a.axvline(ev.rescale(plotting_time_unit), color='k', zorder=0.5)
         xticks.append(ev.rescale(plotting_time_unit))
@@ -678,9 +690,9 @@ for ev_id, ev in enumerate(event[startidx:stopidx]):
             ev.rescale(plotting_time_unit), color='k', ls='-.', zorder=0.5)
 
 ax5a.set_xticks(xticks)
-ax5a.set_xticklabels(xticklabels, fontdict_axis, rotation=90)
+ax5a.set_xticklabels(xticklabels, fontdict=fontdict_axis, rotation=90)
 ax5a.tick_params(axis='x', direction='out', length=3, labelsize='xx-small',
-                 labeltop='off', top='off')
+                 labeltop=False, top=False)
 ax5a.set_ylim([-2.0, 2.0])
 
 timebar_ypos = ax5a.get_ylim()[0] + np.diff(ax5a.get_ylim())[0] / 10