|
@@ -53,6 +53,7 @@ import quantities as pq
|
|
|
|
|
|
import neo
|
|
|
from neo.io.blackrockio import BlackrockIO
|
|
|
+from neo.io.proxyobjects import SpikeTrainProxy, AnalogSignalProxy
|
|
|
|
|
|
|
|
|
class ReachGraspIO(BlackrockIO):
|
|
@@ -437,20 +438,17 @@ class ReachGraspIO(BlackrockIO):
|
|
|
|
|
|
# extract available neuronal ids
|
|
|
self.avail_electrode_ids = None
|
|
|
+ self.connector_aligned_map = {}
|
|
|
if self.odmldoc:
|
|
|
self.avail_electrode_ids = []
|
|
|
secs = self.odmldoc['UtahArray']['Array'].sections
|
|
|
- for i in range(1, 101):
|
|
|
- elidx = [s.properties['ID'].values for s in secs if
|
|
|
- s.name.startswith('Electrode') and
|
|
|
- s.properties['ConnectorAlignedID'].values[0] == i]
|
|
|
- if len(elidx) == 0:
|
|
|
- self.avail_electrode_ids.append(-1)
|
|
|
- elif len(elidx) == 1:
|
|
|
- self.avail_electrode_ids.append(elidx[0])
|
|
|
- else:
|
|
|
- raise ValueError("Electrode IDs in odML file are corrupt. "
|
|
|
- "ID %i occurs %i times" % (i, len(elidx)))
|
|
|
+ for sec in secs:
|
|
|
+ if not sec.name.startswith('Electrode_'):
|
|
|
+ continue
|
|
|
+ id = sec.properties['ID'].values[0]
|
|
|
+ ca_id = sec.properties['ConnectorAlignedID'].values[0]
|
|
|
+ self.avail_electrode_ids.append(id)
|
|
|
+ self.connector_aligned_map[id] = ca_id
|
|
|
|
|
|
def __is_set(self, flag, pos):
|
|
|
"""
|
|
@@ -591,8 +589,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
# Extract end of last complete trial
|
|
|
stop_label = self.event_labels_codes['STOP'][0]
|
|
|
if stop_label in events.labels:
|
|
|
- last_WSoff_idx = len(events.labels) - \
|
|
|
- list(events.labels[::-1]).index(stop_label) - 1
|
|
|
+ last_WSoff_idx = len(events.labels) - list(events.labels[::-1]).index(stop_label) - 1
|
|
|
else:
|
|
|
last_WSoff_idx = -1
|
|
|
|
|
@@ -636,9 +633,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
ff = lambda x: x.name.startswith('Trial_')
|
|
|
tr_secs = sec.itersections(filter_func=ff)
|
|
|
for trial_sec in tr_secs:
|
|
|
- if trial_sec.properties[
|
|
|
- 'TrialTimestampID'].values[0] == \
|
|
|
- timestamp_id:
|
|
|
+ if trial_sec.properties['TrialTimestampID'].values[0] == timestamp_id:
|
|
|
ID = trial_sec.properties['TrialID'].values[0]
|
|
|
trial_ID.append(ID)
|
|
|
# interpretation of GO/RW-OFF
|
|
@@ -678,8 +673,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
trial_timestamp_ID.append(timestamp_id)
|
|
|
trial_ID.append(ID)
|
|
|
prev_ev = events.labels[i - 1]
|
|
|
- if self.event_labels_str[prev_ev] in \
|
|
|
- ['TS-ON', 'TS-OFF/STOP']:
|
|
|
+ if self.event_labels_str[prev_ev] in ['TS-ON', 'TS-OFF/STOP']:
|
|
|
trial_event_labels.append('WS-ON')
|
|
|
trialsequence[timestamp_id] = self.__set_bit(
|
|
|
trialsequence[timestamp_id],
|
|
@@ -697,8 +691,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
trial_timestamp_ID.append(timestamp_id)
|
|
|
trial_ID.append(ID)
|
|
|
prprev_ev = events.labels[i - 2]
|
|
|
- if self.event_labels_str[prprev_ev] in \
|
|
|
- ['TS-ON', 'TS-OFF/STOP']:
|
|
|
+ if self.event_labels_str[prprev_ev] in ['TS-ON', 'TS-OFF/STOP']:
|
|
|
trial_event_labels.append('CUE-ON')
|
|
|
trialsequence[timestamp_id] = self.__set_bit(
|
|
|
trialsequence[timestamp_id],
|
|
@@ -709,8 +702,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
trialsequence[timestamp_id] = self.__set_bit(
|
|
|
trialsequence[timestamp_id],
|
|
|
self.trial_const_sequence_codes['GO-ON'])
|
|
|
- trialtypes[timestamp_id] += \
|
|
|
- self.event_labels_str[l][:2]
|
|
|
+ trialtypes[timestamp_id] += self.event_labels_str[l][:2]
|
|
|
else:
|
|
|
raise ValueError("Unknown trial event sequence.")
|
|
|
# interpretation of WS-OFF
|
|
@@ -781,11 +773,54 @@ class ReachGraspIO(BlackrockIO):
|
|
|
events.array_annotate(performance_in_trial=performance_in_trial)
|
|
|
events.array_annotate(performance_in_trial_str=performance_in_trial_str)
|
|
|
|
|
|
- def __annotate_units_with_odml(self, units):
|
|
|
+ def __create_unit_groups(self, block, view_dict=None):
|
|
|
+ unit_dict = {}
|
|
|
+ for seg in block.segments:
|
|
|
+ for st in seg.spiketrains:
|
|
|
+ chid = st.annotations['channel_id']
|
|
|
+ unit_id = st.annotations['unit_id']
|
|
|
+ if chid not in unit_dict:
|
|
|
+ unit_dict[chid] = {}
|
|
|
+ if unit_id not in unit_dict[chid]:
|
|
|
+ group = neo.Group(name='Unit {} on channel {}'.format(unit_id, chid),
|
|
|
+ description='Group for neuronal data related to unit {} on '
|
|
|
+ 'channel {}'.format(unit_id, chid),
|
|
|
+ group_type='unit',
|
|
|
+ allowed_types=[neo.SpikeTrain, SpikeTrainProxy,
|
|
|
+ neo.AnalogSignal, AnalogSignalProxy,
|
|
|
+ neo.ChannelView],
|
|
|
+ channel_id=chid,
|
|
|
+ unit_id=unit_id)
|
|
|
+ block.groups.append(group)
|
|
|
+ unit_dict[chid][unit_id] = group
|
|
|
+
|
|
|
+ unit_dict[chid][unit_id].add(st)
|
|
|
+
|
|
|
+ # if views are already created, link them to unit groups
|
|
|
+ if view_dict:
|
|
|
+ for chid, channel_dict in unit_dict.items():
|
|
|
+ for unit_id, group in channel_dict.items():
|
|
|
+ group.add(view_dict[chid])
|
|
|
+
|
|
|
+ def __create_channel_views(self, block):
|
|
|
+ view_dict = {}
|
|
|
+ for seg in block.segments:
|
|
|
+ for anasig in seg.analogsignals:
|
|
|
+ for chidx, chid in enumerate(anasig.array_annotations['channel_ids']):
|
|
|
+ if chid not in view_dict:
|
|
|
+ view = neo.ChannelView(anasig, [chidx],
|
|
|
+ name='Channel {} of {}'.format(chid,anasig.name),
|
|
|
+ channel_id=chid)
|
|
|
+ view_dict[chid] = view
|
|
|
+
|
|
|
+ return view_dict
|
|
|
+
|
|
|
+ def __annotate_units_with_odml(self, groups):
|
|
|
"""
|
|
|
Annotates units with metadata from odml file.
|
|
|
"""
|
|
|
- # Can the spike sorting info from the odML be matched with the odML?
|
|
|
+ units = [g for g in groups if
|
|
|
+ 'group_type' in g.annotations and g.annotations['group_type'] == 'unit']
|
|
|
if not self._load_spikesorting_info:
|
|
|
return
|
|
|
|
|
@@ -849,8 +884,8 @@ class ReachGraspIO(BlackrockIO):
|
|
|
|
|
|
# Annotate filter settings from odML
|
|
|
nchan = asig.shape[-1]
|
|
|
- sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][
|
|
|
- 'Filter_ns%i' % asig.array_annotations['nsx'][0]]
|
|
|
+ filter = 'Filter_ns%i' % asig.array_annotations['nsx'][0]
|
|
|
+ sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter]
|
|
|
props = sec.properties
|
|
|
hi_pass_freq = np.full((nchan), pq.Quantity(props['HighPassFreq'].values[0],
|
|
|
props['HighPassFreq'].unit))
|
|
@@ -860,8 +895,8 @@ class ReachGraspIO(BlackrockIO):
|
|
|
lo_pass_order = np.zeros_like(lo_pass_freq)
|
|
|
filter_type= np.empty((nchan), np.str)
|
|
|
for chidx in range(nchan):
|
|
|
- filter_name = 'Filter_ns%i' % asig.array_annotations['nsx'][chidx]
|
|
|
- sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter_name]
|
|
|
+ filter = 'Filter_ns%i' % asig.array_annotations['nsx'][chidx]
|
|
|
+ sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter]
|
|
|
hi_pass_freq[chidx] = pq.Quantity(
|
|
|
sec.properties['HighPassFreq'].values[0],
|
|
|
sec.properties['HighPassFreq'].unit)
|
|
@@ -880,65 +915,62 @@ class ReachGraspIO(BlackrockIO):
|
|
|
filter_type=filter_type
|
|
|
))
|
|
|
|
|
|
- def __annotate_channelindex_with_odml(self, chidx):
|
|
|
- """
|
|
|
- Annotates channelindex with metadata from odml file.
|
|
|
- """
|
|
|
+ # Get rejection bands
|
|
|
+ sec = self.odmldoc['PreProcessing']
|
|
|
+ bands = sec.properties['LFPBands'].values
|
|
|
|
|
|
- if self.odmldoc:
|
|
|
- # Get rejection bands
|
|
|
- sec = self.odmldoc['PreProcessing']
|
|
|
- bands = sec.properties['LFPBands'].values
|
|
|
+ if hasattr(bands, '__iter__'):
|
|
|
+ for band in bands:
|
|
|
+ sec = self.odmldoc['PreProcessing'][band]
|
|
|
|
|
|
- if hasattr(bands, '__iter__'):
|
|
|
- for band in bands:
|
|
|
- sec = self.odmldoc['PreProcessing'][band]
|
|
|
-
|
|
|
- if type(sec.properties['RejElectrodes'].values) is list:
|
|
|
- rej_electrodes = [int(_) for _ in sec.properties[
|
|
|
- 'RejElectrodes'].values]
|
|
|
- rej = chidx.channel_ids[0] in rej_electrodes
|
|
|
- elif sec.properties['RejElectrodes'].values == -1:
|
|
|
- rej = False
|
|
|
- elif sec.properties['RejElectrodes'].values >= 0:
|
|
|
- rej_electrodes = sec.properties[
|
|
|
- 'RejElectrodes'].values
|
|
|
- rej = (chidx.channel_ids[0] == rej_electrodes)
|
|
|
- else:
|
|
|
- raise ValueError(
|
|
|
- "Invalid entry %s in odML for rejected electrodes "
|
|
|
- "in LFP band %s." % (
|
|
|
- sec.properties['RejElectrodes'].values,
|
|
|
- band))
|
|
|
-
|
|
|
- rej_dict = {str('electrode_reject_' + band): rej}
|
|
|
-
|
|
|
- # Annotate ChannelIndex and all children for convenience
|
|
|
- chidx.annotate(**rej_dict)
|
|
|
- for asig in chidx.analogsignals:
|
|
|
- asig.annotate(**rej_dict)
|
|
|
- for unit in chidx.units:
|
|
|
- unit.annotate(**rej_dict)
|
|
|
- for st in unit.spiketrains:
|
|
|
- st.annotate(**rej_dict)
|
|
|
-
|
|
|
- # Annotate connector aligned ID to channel
|
|
|
- if chidx.channel_ids[0] in chidx.block.annotations['avail_electrode_ids']:
|
|
|
- ca_dict = {
|
|
|
- 'connector_aligned_id': chidx.block.annotations[
|
|
|
- 'avail_electrode_ids'].index(chidx.channel_ids[0])+1}
|
|
|
- chidx.coordinates = pq.Quantity(np.array([
|
|
|
- np.mod(ca_dict['connector_aligned_id']-1, 10)*.4,
|
|
|
- (ca_dict['connector_aligned_id']-1)/10*.4]),
|
|
|
- units=pq.mm)
|
|
|
-
|
|
|
- chidx.annotate(**ca_dict)
|
|
|
- for asig in chidx.analogsignals:
|
|
|
- asig.annotate(**ca_dict)
|
|
|
- for unit in chidx.units:
|
|
|
- unit.annotate(**ca_dict)
|
|
|
- for st in unit.spiketrains:
|
|
|
- st.annotate(**ca_dict)
|
|
|
+ # default: No rejection information present
|
|
|
+ rej = np.full((asig.shape[-1]), None)
|
|
|
+
|
|
|
+ if sec.properties['RejElectrodes'].values:
|
|
|
+ rej_els = np.asarray(sec.properties['RejElectrodes'].values, dtype=int)
|
|
|
+ rej = np.isin(asig.array_annotations['channel_ids'], rej_els)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ "Invalid entry %s in odML for rejected electrodes "
|
|
|
+ "in LFP band %s." % (sec.properties['RejElectrodes'].values, band))
|
|
|
+
|
|
|
+ asig.array_annotations.update({str('electrode_reject_' + band): rej})
|
|
|
+
|
|
|
+ def __convert_chids_and_coordinates(self, channel_ids):
|
|
|
+ nchan = len(channel_ids)
|
|
|
+ ca_ids = np.full(nchan, fill_value=None)
|
|
|
+ # use negative infinity for invalid coordinates as None is incompatible with pq.mm
|
|
|
+ coordinates_x = np.full(nchan, fill_value=-np.inf) * pq.mm
|
|
|
+ coordinates_y = np.full(nchan, fill_value=-np.inf) * pq.mm
|
|
|
+
|
|
|
+ for i, channel_id in enumerate(channel_ids):
|
|
|
+ if channel_id not in self.connector_aligned_map:
|
|
|
+ continue
|
|
|
+ ca_ids[i] = self.connector_aligned_map[channel_id]
|
|
|
+ coordinates_x[i] = np.mod(ca_ids[i] - 1, 10) * 0.4 * pq.mm
|
|
|
+ coordinates_y[i] = ((ca_ids[i] - 1) / 10) * 0.4 * pq.mm
|
|
|
+
|
|
|
+ return ca_ids, coordinates_x, coordinates_y
|
|
|
+
|
|
|
+ def __annotate_channel_infos(self, block):
|
|
|
+ if self.odmldoc:
|
|
|
+ objs = block.groups
|
|
|
+ for seg in block.segments:
|
|
|
+ objs.extend(seg.analogsignals + seg.spiketrains)
|
|
|
+
|
|
|
+ for obj in objs:
|
|
|
+ if hasattr(obj, 'array_annotations') and ('channel_ids' in obj.array_annotations):
|
|
|
+ chids = obj.array_annotations['channel_ids']
|
|
|
+ ca_ids, *coordinates = self.__convert_chids_and_coordinates(chids)
|
|
|
+ obj.array_annotations.update(dict(connector_aligned_ids=ca_ids,
|
|
|
+ coordinates_x=coordinates[0],
|
|
|
+ coordinates_y=coordinates[1]))
|
|
|
+ elif 'channel_id' in obj.annotations:
|
|
|
+ chid = obj.annotations['channel_id']
|
|
|
+ ca_id, *coordinates = self.__convert_chids_and_coordinates([chid])
|
|
|
+ obj.annotate(connector_aligned_id=ca_id[0],
|
|
|
+ coordinate_x=coordinates[0][0],
|
|
|
+ coordinate_y=coordinates[1][0])
|
|
|
|
|
|
def __annotate_block_with_odml(self, bl):
|
|
|
"""
|
|
@@ -993,7 +1025,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
if len(np.unique(asig.array_annotations['nsx'])) > 1:
|
|
|
raise ValueError('Multiple nsx file origins (%s) in single AnalogSignal'
|
|
|
''.format(asig.array_annotations['nsx']))
|
|
|
-
|
|
|
+
|
|
|
# Get and correct for shifts
|
|
|
filter_name = 'Filter_ns%i' % asig.array_annotations['nsx'][0] # use nsx of 1st signal
|
|
|
sec = self.odmldoc['Cerebus']['NeuralSignalProcessor']['NeuralSignals'][filter_name]
|
|
@@ -1113,7 +1145,8 @@ class ReachGraspIO(BlackrockIO):
|
|
|
individually for each channel (keys), e.g. {1: 5, 2: 'all'}
|
|
|
loads unit 5 from channel 1 and all units from channel 2.
|
|
|
load_waveforms (boolean):
|
|
|
- If True, waveforms are attached to all loaded spiketrains.
|
|
|
+ Control SpikeTrains.waveforms is None or not.
|
|
|
+ Default: False
|
|
|
load_events (boolean): DEPRECATED
|
|
|
If True, all recorded events are loaded.
|
|
|
scaling (str): DEPRECATED
|
|
@@ -1495,15 +1528,13 @@ class ReachGraspIO(BlackrockIO):
|
|
|
if 'condition' in list(seg.annotations):
|
|
|
bl.annotations['conditions'].append(seg.annotations['condition'])
|
|
|
|
|
|
+ ch_dict = self.__create_channel_views(bl)
|
|
|
+ self.__create_unit_groups(bl, ch_dict)
|
|
|
+
|
|
|
if self.odmldoc:
|
|
|
self.__annotate_block_with_odml(bl)
|
|
|
- for chidx in bl.channel_indexes:
|
|
|
- self.__annotate_channelindex_with_odml(chidx)
|
|
|
- self.__annotate_units_with_odml(chidx.units)
|
|
|
-
|
|
|
- for chidx in bl.channel_indexes:
|
|
|
- if isinstance(chidx.index, int):
|
|
|
- chidx.index = [chidx.index]
|
|
|
+ self.__annotate_channel_infos(bl)
|
|
|
+ self.__annotate_units_with_odml(bl.groups)
|
|
|
|
|
|
return bl
|
|
|
|
|
@@ -1614,7 +1645,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
seg.name = name
|
|
|
if description is not None:
|
|
|
seg.description = description
|
|
|
-
|
|
|
+
|
|
|
# load data of all events and epochs
|
|
|
for ev_idx, event in enumerate(seg.events):
|
|
|
seg.events[ev_idx] = event.load()
|
|
@@ -1629,8 +1660,7 @@ class ReachGraspIO(BlackrockIO):
|
|
|
self.__correct_filter_shifts(asig)
|
|
|
|
|
|
for ev in seg.events:
|
|
|
- # Modify digital trial events to include semantic event
|
|
|
- # informations
|
|
|
+ # Modify digital trial events to include semantic event information
|
|
|
if ev.name == 'digital_input_port':
|
|
|
self.__annotate_dig_trial_events(ev)
|
|
|
self.__add_rejection_to_event(ev)
|