Browse Source

Fix bug in trial_id assignment and add test

Julia Sprenger 3 years ago
parent
commit
c44c93aa95
2 changed files with 29 additions and 22 deletions
  1. 28 21
      code/reachgraspio/reachgraspio.py
  2. 1 1
      code/reachgraspio/test_reachgraspio.py

+ 28 - 21
code/reachgraspio/reachgraspio.py

@@ -535,7 +535,13 @@ class ReachGraspIO(BlackrockIO):
             for signalname in ['GripForceSignals', 'DisplacementSignal']:
                 for analog_events in trial_sec['AnalogEvents'][signalname].properties:
 
+                    # skip invalid values
+                    if analog_events.values == [-1]:  # this was used as default time
+                        continue
+
                     time = analog_events.values * pq.CompoundUnit(analog_events.unit)
+                    time = time.rescale('ms')
+
                     if time >= t_start and time < t_stop:
                         event_name.append(analog_events.name)
                         event_time.append(time)
@@ -546,9 +552,7 @@ class ReachGraspIO(BlackrockIO):
 
         # Create event object with analog events
         analog_events = neo.Event(
-            times=pq.Quantity(
-                [_.magnitude for _ in event_time],
-                units=event_time[0].units).rescale('ms').flatten(),
+            times=pq.Quantity(event_time, 'ms').flatten(),
             labels=np.array(event_name),
             name='AnalogTrialEvents',
             description='Events extracted from analog signals')
@@ -579,16 +583,18 @@ class ReachGraspIO(BlackrockIO):
         events.name = "DigitalTrialEvents"
         events.description = "Trial " + events.description.lower()
 
+        events_rescaled = events.rescale(pq.CompoundUnit('1/30000*s'))
+
         # Extract beginning of first complete trial
         tson_label = self.event_labels_codes['TS-ON'][0]
-        if tson_label in events.labels:
-            first_TSon_idx = list(events.labels).index(tson_label)
+        if tson_label in events_rescaled.labels:
+            first_TSon_idx = list(events_rescaled.labels).index(tson_label)
         else:
-            first_TSon_idx = len(events.labels)
+            first_TSon_idx = len(events_rescaled.labels)
         # Extract end of last complete trial
         stop_label = self.event_labels_codes['STOP'][0]
-        if stop_label in events.labels:
-            last_WSoff_idx = len(events.labels) - list(events.labels[::-1]).index(stop_label) - 1
+        if stop_label in events_rescaled.labels:
+            last_WSoff_idx = len(events_rescaled.labels) - list(events_rescaled.labels[::-1]).index(stop_label) - 1
         else:
             last_WSoff_idx = -1
 
@@ -598,7 +604,7 @@ class ReachGraspIO(BlackrockIO):
         trial_timestamp_ID = []
         trialtypes = {-1: 'NONE'}
         trialsequence = {-1: 0}
-        for i, l in enumerate(events.labels):
+        for i, l in enumerate(events_rescaled.labels):
             if i < first_TSon_idx or i > last_WSoff_idx:
                 trial_event_labels.append('NONE')
                 trial_ID.append(-1)
@@ -607,9 +613,9 @@ class ReachGraspIO(BlackrockIO):
                 # interpretation of TS-ON
                 if self.event_labels_str[l] == 'TS-ON':
                     if i > 0:
-                        prev_ev = events.labels[i - 1]
+                        prev_ev = events_rescaled.labels[i - 1]
                         if self.event_labels_str[prev_ev] in ['STOP', 'TS-OFF/STOP']:
-                            timestamp_id = int(events.times[i].item())
+                            timestamp_id = int(round(events_rescaled.times[i].item()))
                             trial_timestamp_ID.append(timestamp_id)
                             trial_event_labels.append('TS-ON')
                             trialsequence[timestamp_id] = self.__set_bit(
@@ -619,7 +625,7 @@ class ReachGraspIO(BlackrockIO):
                             trial_timestamp_ID.append(timestamp_id)
                             trial_event_labels.append('TS-ON-ERROR')
                     else:
-                        timestamp_id = int(events.times[i].item())
+                        timestamp_id = int(events_rescaled.times[i].item())
                         trial_timestamp_ID.append(timestamp_id)
                         trial_event_labels.append('TS-ON')
                         trialsequence[timestamp_id] = self.__set_bit(
@@ -654,7 +660,7 @@ class ReachGraspIO(BlackrockIO):
                 elif self.event_labels_str[l] == 'TS-OFF/STOP':
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
-                    prev_ev = events.labels[i - 1]
+                    prev_ev = events_rescaled.labels[i - 1]
                     if self.event_labels_str[prev_ev] == 'TS-ON':
                         trial_event_labels.append('TS-OFF')
                     elif prev_ev in self.event_labels_codes['ERROR-FLASH-ON']:
@@ -671,7 +677,7 @@ class ReachGraspIO(BlackrockIO):
                 elif self.event_labels_str[l] == 'WS-ON/CUE-OFF':
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
-                    prev_ev = events.labels[i - 1]
+                    prev_ev = events_rescaled.labels[i - 1]
                     if self.event_labels_str[prev_ev] in ['TS-ON', 'TS-OFF/STOP']:
                         trial_event_labels.append('WS-ON')
                         trialsequence[timestamp_id] = self.__set_bit(
@@ -689,7 +695,7 @@ class ReachGraspIO(BlackrockIO):
                 elif l in self.event_labels_codes['CUE/GO']:
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
-                    prprev_ev = events.labels[i - 2]
+                    prprev_ev = events_rescaled.labels[i - 2]
                     if self.event_labels_str[prprev_ev] in ['TS-ON', 'TS-OFF/STOP']:
                         trial_event_labels.append('CUE-ON')
                         trialsequence[timestamp_id] = self.__set_bit(
@@ -708,7 +714,7 @@ class ReachGraspIO(BlackrockIO):
                 elif self.event_labels_str[l] == 'STOP':
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
-                    prev_ev = self.event_labels_str[events.labels[i - 1]]
+                    prev_ev = self.event_labels_str[events_rescaled.labels[i - 1]]
                     if prev_ev == 'ERROR-FLASH-ON':
                         trial_event_labels.append('ERROR-FLASH-OFF')
                     else:
@@ -720,7 +726,7 @@ class ReachGraspIO(BlackrockIO):
                 elif l in self.event_labels_codes['SR']:
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
-                    prev_ev = events.labels[i - 1]
+                    prev_ev = events_rescaled.labels[i - 1]
                     if prev_ev in self.event_labels_codes['SR']:
                         trial_event_labels.append('SR-REP')
                     elif prev_ev in self.event_labels_codes['RW-ON']:
@@ -730,11 +736,11 @@ class ReachGraspIO(BlackrockIO):
                         trialsequence[timestamp_id] = self.__set_bit(
                             trialsequence[timestamp_id],
                             self.trial_const_sequence_codes['SR'])
-                # interpretation of RW events
+                # interpretation of RW events_rescaled
                 elif l in self.event_labels_codes['RW-ON']:
                     trial_timestamp_ID.append(timestamp_id)
                     trial_ID.append(ID)
-                    prev_ev = events.labels[i - 1]
+                    prev_ev = events_rescaled.labels[i - 1]
                     if prev_ev in self.event_labels_codes['RW-ON']:
                         trial_event_labels.append('RW-ON-REP')
                     else:
@@ -1062,7 +1068,7 @@ class ReachGraspIO(BlackrockIO):
             if event.name in ['AnalogTrialEvents', 'DigitalTrialEvents']:
                 # Extract event times
                 if event_time is None:
-                    event_time = event.times.magnitude
+                    event_time = event.times.magnitude.flatten()
                     event_units = event.times.units
                 else:
                     event_time = np.concatenate((
@@ -1459,7 +1465,7 @@ class ReachGraspIO(BlackrockIO):
                 Event annotations:
                     The resulting Block contains three Event objects with the
                     following names:
-                    "DigitalTrialEvents' contains all digitally recorded events
+                    'DigitalTrialEvents' contains all digitally recorded events
                         returned by BlackrockIO, annotated with semantic labels
                         in accordance with the reach-to-grasp experiment (e.g.,
                         'TS-ON').
@@ -1671,6 +1677,7 @@ class ReachGraspIO(BlackrockIO):
         for st in seg.spiketrains:
             self.__annotate_electrode_rejections(st)
 
+
         for ev in seg.events:
             # Modify digital trial events to include semantic event information
             if ev.name == 'digital_input_port':

+ 1 - 1
code/reachgraspio/test_reachgraspio.py

@@ -1 +1 @@
-/annex/objects/MD5-s5074--bc01692822a636fa8d96277e22d1b3c8
+/annex/objects/MD5-s6310--4e5f07cf4d415caa7b32517d24c50d62