Browse Source

First running version with neo 0.8dev

Julia Sprenger 3 years ago
parent
commit
16b7734e7f
1 changed files with 127 additions and 97 deletions
  1. 127 97
      code/reachgraspio/reachgraspio.py

+ 127 - 97
code/reachgraspio/reachgraspio.py

@@ -53,6 +53,7 @@ import quantities as pq
 
 import neo
 from neo.io.blackrockio import BlackrockIO
+from neo.io.proxyobjects import SpikeTrainProxy, AnalogSignalProxy
 
 
 class ReachGraspIO(BlackrockIO):
@@ -437,20 +438,17 @@ class ReachGraspIO(BlackrockIO):
 
         # extract available neuronal ids
         self.avail_electrode_ids = None
+        self.connector_aligned_map = {}
         if self.odmldoc:
             self.avail_electrode_ids = []
             secs = self.odmldoc['UtahArray']['Array'].sections
-            for i in range(1, 101):
-                elidx = [s.properties['ID'].values for s in secs if
-                         s.name.startswith('Electrode') and
-                         s.properties['ConnectorAlignedID'].values[0] == i]
-                if len(elidx) == 0:
-                    self.avail_electrode_ids.append(-1)
-                elif len(elidx) == 1:
-                    self.avail_electrode_ids.append(elidx[0])
-                else:
-                    raise ValueError("Electrode IDs in odML file are corrupt. "
-                                     "ID %i occurs %i times" % (i, len(elidx)))
+            for sec in secs:
+                if not sec.name.startswith('Electrode_'):
+                    continue
+                id = sec.properties['ID'].values[0]
+                ca_id = sec.properties['ConnectorAlignedID'].values[0]
+                self.avail_electrode_ids.append(id)
+                self.connector_aligned_map[id] = ca_id
 
     def __is_set(self, flag, pos):
         """
@@ -591,8 +589,7 @@ class ReachGraspIO(BlackrockIO):
         # Extract end of last complete trial
         stop_label = self.event_labels_codes['STOP'][0]
         if stop_label in events.labels:
-            last_WSoff_idx = len(events.labels) - \
-                             list(events.labels[::-1]).index(stop_label) - 1
+            last_WSoff_idx = len(events.labels) - list(events.labels[::-1]).index(stop_label) - 1
         else:
             last_WSoff_idx = -1
 
@@ -636,9 +633,7 @@ class ReachGraspIO(BlackrockIO):
                         ff = lambda x: x.name.startswith('Trial_')
                         tr_secs = sec.itersections(filter_func=ff)
                         for trial_sec in tr_secs:
-                            if trial_sec.properties[
-                                    'TrialTimestampID'].values[0] == \
-                                    timestamp_id:
+                            if trial_sec.properties['TrialTimestampID'].values[0] == timestamp_id:
                                 ID = trial_sec.properties['TrialID'].values[0]
                     trial_ID.append(ID)
                 # interpretation of GO/RW-OFF
@@ -678,8 +673,7 @@ class ReachGraspIO(BlackrockIO):
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
                     prev_ev = events.labels[i - 1]
-                    if self.event_labels_str[prev_ev] in \
-                            ['TS-ON', 'TS-OFF/STOP']:
+                    if self.event_labels_str[prev_ev] in ['TS-ON', 'TS-OFF/STOP']:
                         trial_event_labels.append('WS-ON')
                         trialsequence[timestamp_id] = self.__set_bit(
                             trialsequence[timestamp_id],
@@ -697,8 +691,7 @@ class ReachGraspIO(BlackrockIO):
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
                     prprev_ev = events.labels[i - 2]
-                    if self.event_labels_str[prprev_ev] in \
-                            ['TS-ON', 'TS-OFF/STOP']:
+                    if self.event_labels_str[prprev_ev] in ['TS-ON', 'TS-OFF/STOP']:
                         trial_event_labels.append('CUE-ON')
                         trialsequence[timestamp_id] = self.__set_bit(
                             trialsequence[timestamp_id],
@@ -709,8 +702,7 @@ class ReachGraspIO(BlackrockIO):
                         trialsequence[timestamp_id] = self.__set_bit(
                             trialsequence[timestamp_id],
                             self.trial_const_sequence_codes['GO-ON'])
-                        trialtypes[timestamp_id] += \
-                            self.event_labels_str[l][:2]
+                        trialtypes[timestamp_id] += self.event_labels_str[l][:2]
                     else:
                         raise ValueError("Unknown trial event sequence.")
                 # interpretation of WS-OFF
@@ -781,11 +773,54 @@ class ReachGraspIO(BlackrockIO):
         events.array_annotate(performance_in_trial=performance_in_trial)
         events.array_annotate(performance_in_trial_str=performance_in_trial_str)
 
-    def __annotate_units_with_odml(self, units):
+    def __create_unit_groups(self, block, view_dict=None):
+        unit_dict = {}
+        for seg in block.segments:
+            for st in seg.spiketrains:
+                chid = st.annotations['channel_id']
+                unit_id = st.annotations['unit_id']
+                if chid not in unit_dict:
+                    unit_dict[chid] = {}
+                if unit_id not in unit_dict[chid]:
+                    group = neo.Group(name='Unit {} on channel {}'.format(unit_id, chid),
+                                      description='Group for neuronal data related to unit {} on '
+                                                  'channel {}'.format(unit_id, chid),
+                                      group_type='unit',
+                                      allowed_types=[neo.SpikeTrain, SpikeTrainProxy,
+                                                     neo.AnalogSignal, AnalogSignalProxy,
+                                                     neo.ChannelView],
+                                      channel_id=chid,
+                                      unit_id=unit_id)
+                    block.groups.append(group)
+                    unit_dict[chid][unit_id] = group
+
+                unit_dict[chid][unit_id].add(st)
+
+        # if views are already created, link them to unit groups
+        if view_dict:
+            for chid, channel_dict in unit_dict.items():
+                for unit_id, group in channel_dict.items():
+                    group.add(view_dict[chid])
+
+    def __create_channel_views(self, block):
+        view_dict = {}
+        for seg in block.segments:
+            for anasig in seg.analogsignals:
+                for chidx, chid in enumerate(anasig.array_annotations['channel_ids']):
+                    if chid not in view_dict:
+                        view = neo.ChannelView(anasig, [chidx],
+                                               name='Channel {} of {}'.format(chid,anasig.name),
+                                               channel_id=chid)
+                        view_dict[chid] = view
+
+        return view_dict
+
+    def __annotate_units_with_odml(self, groups):
         """
         Annotates units with metadata from odml file.
         """
-        # Can the spike sorting info from the odML be matched with the odML?
+        units = [g for g in groups if
+                 'group_type' in g.annotations and g.annotations['group_type'] == 'unit']
         if not self._load_spikesorting_info:
             return
 
@@ -849,8 +884,8 @@ class ReachGraspIO(BlackrockIO):
 
                 # Annotate filter settings from odML
                 nchan = asig.shape[-1]
-                sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][
-                    'Filter_ns%i' % asig.array_annotations['nsx'][0]]
+                filter = 'Filter_ns%i' % asig.array_annotations['nsx'][0]
+                sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter]
                 props = sec.properties
                 hi_pass_freq = np.full((nchan), pq.Quantity(props['HighPassFreq'].values[0],
                                                             props['HighPassFreq'].unit))
@@ -860,8 +895,8 @@ class ReachGraspIO(BlackrockIO):
                 lo_pass_order = np.zeros_like(lo_pass_freq)
                 filter_type= np.empty((nchan), np.str)
                 for chidx in range(nchan):
-                    filter_name = 'Filter_ns%i' % asig.array_annotations['nsx'][chidx]
-                    sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter_name]
+                    filter = 'Filter_ns%i' % asig.array_annotations['nsx'][chidx]
+                    sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter]
                     hi_pass_freq[chidx] = pq.Quantity(
                         sec.properties['HighPassFreq'].values[0],
                         sec.properties['HighPassFreq'].unit)
@@ -880,65 +915,62 @@ class ReachGraspIO(BlackrockIO):
                     filter_type=filter_type
                 ))
 
-    def __annotate_channelindex_with_odml(self, chidx):
-        """
-        Annotates channelindex with metadata from odml file.
-        """
+                # Get rejection bands
+                sec = self.odmldoc['PreProcessing']
+                bands = sec.properties['LFPBands'].values
 
-        if self.odmldoc:
-            # Get rejection bands
-            sec = self.odmldoc['PreProcessing']
-            bands = sec.properties['LFPBands'].values
+                if hasattr(bands, '__iter__'):
+                    for band in bands:
+                        sec = self.odmldoc['PreProcessing'][band]
 
-            if hasattr(bands, '__iter__'):
-                for band in bands:
-                    sec = self.odmldoc['PreProcessing'][band]
-
-                    if type(sec.properties['RejElectrodes'].values) is list:
-                        rej_electrodes = [int(_) for _ in sec.properties[
-                            'RejElectrodes'].values]
-                        rej = chidx.channel_ids[0] in rej_electrodes
-                    elif sec.properties['RejElectrodes'].values == -1:
-                        rej = False
-                    elif sec.properties['RejElectrodes'].values >= 0:
-                        rej_electrodes = sec.properties[
-                            'RejElectrodes'].values
-                        rej = (chidx.channel_ids[0] == rej_electrodes)
-                    else:
-                        raise ValueError(
-                            "Invalid entry %s in odML for rejected electrodes "
-                            "in LFP band %s." % (
-                                sec.properties['RejElectrodes'].values,
-                                band))
-
-                    rej_dict = {str('electrode_reject_' + band): rej}
-
-                    # Annotate ChannelIndex and all children for convenience
-                    chidx.annotate(**rej_dict)
-                    for asig in chidx.analogsignals:
-                        asig.annotate(**rej_dict)
-                    for unit in chidx.units:
-                        unit.annotate(**rej_dict)
-                        for st in unit.spiketrains:
-                            st.annotate(**rej_dict)
-
-            # Annotate connector aligned ID to channel
-            if chidx.channel_ids[0] in chidx.block.annotations['avail_electrode_ids']:
-                ca_dict = {
-                    'connector_aligned_id': chidx.block.annotations[
-                        'avail_electrode_ids'].index(chidx.channel_ids[0])+1}
-                chidx.coordinates = pq.Quantity(np.array([
-                    np.mod(ca_dict['connector_aligned_id']-1, 10)*.4,
-                    (ca_dict['connector_aligned_id']-1)/10*.4]),
-                    units=pq.mm)
-
-                chidx.annotate(**ca_dict)
-                for asig in chidx.analogsignals:
-                    asig.annotate(**ca_dict)
-                for unit in chidx.units:
-                    unit.annotate(**ca_dict)
-                    for st in unit.spiketrains:
-                        st.annotate(**ca_dict)
+                        # default: No rejection information present
+                        rej = np.full((asig.shape[-1]), None)
+
+                        if sec.properties['RejElectrodes'].values:
+                            rej_els = np.asarray(sec.properties['RejElectrodes'].values, dtype=int)
+                            rej = np.isin(asig.array_annotations['channel_ids'], rej_els)
+                        else:
+                            raise ValueError(
+                                "Invalid entry %s in odML for rejected electrodes "
+                                "in LFP band %s." % (sec.properties['RejElectrodes'].values, band))
+
+                        asig.array_annotations.update({str('electrode_reject_' + band): rej})
+
+    def __convert_chids_and_coordinates(self, channel_ids):
+        nchan = len(channel_ids)
+        ca_ids = np.full(nchan, fill_value=None)
+        # use negative infinity for invalid coordinates as None is incompatible with pq.mm
+        coordinates_x = np.full(nchan, fill_value=-np.inf) * pq.mm
+        coordinates_y = np.full(nchan, fill_value=-np.inf) * pq.mm
+
+        for i, channel_id in enumerate(channel_ids):
+            if channel_id not in self.connector_aligned_map:
+                continue
+            ca_ids[i] = self.connector_aligned_map[channel_id]
+            coordinates_x[i] = np.mod(ca_ids[i] - 1, 10) * 0.4 * pq.mm
+            coordinates_y[i] = ((ca_ids[i] - 1) / 10) * 0.4 * pq.mm
+
+        return ca_ids, coordinates_x, coordinates_y
+
+    def __annotate_channel_infos(self, block):
+        if self.odmldoc:
+            objs = block.groups
+            for seg in block.segments:
+                objs.extend(seg.analogsignals + seg.spiketrains)
+
+            for obj in objs:
+                if hasattr(obj, 'array_annotations') and ('channel_ids' in obj.array_annotations):
+                    chids = obj.array_annotations['channel_ids']
+                    ca_ids, *coordinates = self.__convert_chids_and_coordinates(chids)
+                    obj.array_annotations.update(dict(connector_aligned_ids=ca_ids,
+                                                      coordinates_x=coordinates[0],
+                                                      coordinates_y=coordinates[1]))
+                elif 'channel_id' in obj.annotations:
+                    chid = obj.annotations['channel_id']
+                    ca_id, *coordinates = self.__convert_chids_and_coordinates([chid])
+                    obj.annotate(connector_aligned_id=ca_id[0],
+                                 coordinate_x=coordinates[0][0],
+                                 coordinate_y=coordinates[1][0])
 
     def __annotate_block_with_odml(self, bl):
         """
@@ -993,7 +1025,7 @@ class ReachGraspIO(BlackrockIO):
             if len(np.unique(asig.array_annotations['nsx'])) > 1:
                 raise ValueError('Multiple nsx file origins (%s) in single AnalogSignal'
                                  ''.format(asig.array_annotations['nsx']))
-            
+
             # Get and correct for shifts
             filter_name = 'Filter_ns%i' % asig.array_annotations['nsx'][0] # use nsx of 1st signal
             sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter_name]
@@ -1113,7 +1145,8 @@ class ReachGraspIO(BlackrockIO):
                 individually for each channel (keys), e.g. {1: 5, 2: 'all'}
                 loads unit 5 from channel 1 and all units from channel 2.
             load_waveforms (boolean):
-                If True, waveforms are attached to all loaded spiketrains.
+                 Control SpikeTrains.waveforms is None or not.
+                 Default: False
             load_events (boolean): DEPRECATED
                 If True, all recorded events are loaded.
             scaling (str): DEPRECATED
@@ -1495,15 +1528,13 @@ class ReachGraspIO(BlackrockIO):
             if 'condition' in list(seg.annotations):
                 bl.annotations['conditions'].append(seg.annotations['condition'])
 
+        ch_dict = self.__create_channel_views(bl)
+        self.__create_unit_groups(bl, ch_dict)
+
         if self.odmldoc:
             self.__annotate_block_with_odml(bl)
-            for chidx in bl.channel_indexes:
-                self.__annotate_channelindex_with_odml(chidx)
-                self.__annotate_units_with_odml(chidx.units)
-
-        for chidx in bl.channel_indexes:
-            if isinstance(chidx.index, int):
-                chidx.index = [chidx.index]
+            self.__annotate_channel_infos(bl)
+            self.__annotate_units_with_odml(bl.groups)
 
         return bl
 
@@ -1614,7 +1645,7 @@ class ReachGraspIO(BlackrockIO):
             seg.name = name
         if description is not None:
             seg.description = description
-            
+
         # load data of all events and epochs
         for ev_idx, event in enumerate(seg.events):
             seg.events[ev_idx] = event.load()
@@ -1629,8 +1660,7 @@ class ReachGraspIO(BlackrockIO):
                 self.__correct_filter_shifts(asig)
 
         for ev in seg.events:
-            # Modify digital trial events to include semantic event
-            # informations
+            # Modify digital trial events to include semantic event information
             if ev.name == 'digital_input_port':
                 self.__annotate_dig_trial_events(ev)
                 self.__add_rejection_to_event(ev)