@@ -0,0 +1,216 @@
+import logging
+import sys
+import os.path
+from collections import OrderedDict
+
+import numpy as np
+import torch as th
+import torch.nn.functional as F
+from torch import optim
+
+from braindecode.datasets.bbci import BBCIDataset
+from braindecode.datautil.signalproc import (highpass_cnt,
+                                             exponential_running_standardize)
+from braindecode.datautil.iterators import CropsFromTrialsIterator
+from braindecode.datautil.splitters import split_into_two_sets
+from braindecode.datautil.trial_segment import \
+    create_signal_target_from_raw_mne
+from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
+from braindecode.models.deep4 import Deep4Net
+from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
+from braindecode.models.util import to_dense_prediction_model
+from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
+from braindecode.torch_ext.util import np_to_var, set_random_seeds
+from braindecode.experiments.experiment import Experiment
+from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
+    RuntimeMonitor, CroppedTrialMisclassMonitor
+from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
+
+log = logging.getLogger(__name__)
+log.setLevel('DEBUG')
+
+
+def load_bbci_data(filename, low_cut_hz, debug=False):
+    load_sensor_names = None
+    if debug:
+        load_sensor_names = ['C3', 'C4', 'C2']
+    # We load all sensors to always get the same cleaning results,
+    # independent of the sensor selection.
+    # There is an inbuilt heuristic that tries to use only EEG channels and
+    # that definitely works for the datasets in our paper.
+    loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)
+
+    log.info("Loading data...")
+    cnt = loader.load()
+
+    # Cleaning: First find all trials that have absolute microvolt values
+    # larger than +-800 inside them and remember them for removal later
+    log.info("Cutting trials...")
+
+    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2]),
+                              ('Rest', [3]), ('Feet', [4])])
+    clean_ival = [0, 4000]
+
+    set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
+                                                         clean_ival)
+
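+    # set_for_cleaning.X has shape (trials, channels, timepoints); a trial
+    # is kept only if its maximal absolute value over all channels and
+    # timepoints stays below 800 uV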
+    clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800
+
+    log.info("Clean trials: {:3d} of {:3d} ({:5.1f}%)".format(
+        np.sum(clean_trial_mask),
+        len(set_for_cleaning.X),
+        np.mean(clean_trial_mask) * 100))
+
+    # Now pick only sensors with C in their name,
+    # as they cover the motor cortex
+    C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
+                 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
+                 'C6', 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h',
+                 'FFC6h', 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h',
+                 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h', 'CPP3h', 'CPP4h',
+                 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
+                 'CCP2h', 'CPP1h', 'CPP2h']
+    if debug:
+        C_sensors = load_sensor_names
+    cnt = cnt.pick_channels(C_sensors)
+
+    # Further preprocessing as described in the paper
+    log.info("Resampling...")
+    cnt = resample_cnt(cnt, 250.0)
+    log.info("Highpassing...")
+    cnt = mne_apply(
+        lambda a: highpass_cnt(
+            a, low_cut_hz, cnt.info['sfreq'], filt_order=3, axis=1),
+        cnt)
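+    # exponential_running_standardize expects time as the first axis, while
+    # the MNE data array is (channels, timepoints), hence the transposes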
+    log.info("Standardizing...")
+    cnt = mne_apply(
+        lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
+                                                  init_block_size=1000,
+                                                  eps=1e-4).T,
+        cnt)
+
+    # Trial interval: start at -500 ms already, since this improved decoding
+    # for the networks
+    ival = [-500, 4000]
+
+    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
+    dataset.X = dataset.X[clean_trial_mask]
+    dataset.y = dataset.y[clean_trial_mask]
+    return dataset
+
+
+def load_train_valid_test(
+        train_filename, test_filename, low_cut_hz, debug=False):
+    log.info("Loading train...")
+    full_train_set = load_bbci_data(
+        train_filename, low_cut_hz=low_cut_hz, debug=debug)
+
+    log.info("Loading test...")
+    test_set = load_bbci_data(
+        test_filename, low_cut_hz=low_cut_hz, debug=debug)
+
+    valid_set_fraction = 0.8
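+    # Split by trial order: the first 80% of the training trials become the
+    # final training set, the last 20% the validation set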
+    train_set, valid_set = split_into_two_sets(full_train_set,
+                                               valid_set_fraction)
+
+    log.info("Train set with {:4d} trials".format(len(train_set.X)))
+    if valid_set is not None:
+        log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
+    log.info("Test set with {:4d} trials".format(len(test_set.X)))
+
+    return train_set, valid_set, test_set
+
+
+def run_exp_on_high_gamma_dataset(train_filename, test_filename, low_cut_hz,
+                                  model_name, max_epochs,
+                                  max_increase_epochs, np_th_seed, debug):
+    train_set, valid_set, test_set = load_train_valid_test(
+        train_filename=train_filename,
+        test_filename=test_filename,
+        low_cut_hz=low_cut_hz, debug=debug)
+    if debug:
+        max_epochs = 4
+
+    set_random_seeds(np_th_seed, cuda=True)
+    # th.backends.cudnn.benchmark = True  # sometimes crashes?
+    n_classes = int(np.max(train_set.y) + 1)
+    n_chans = int(train_set.X.shape[1])
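+    # 1000 samples at the resampled 250 Hz = 4 seconds of input per window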
+    input_time_length = 1000
+    if model_name == 'deep':
+        model = Deep4Net(n_chans, n_classes,
+                         input_time_length=input_time_length,
+                         final_conv_length=2).create_network()
+    elif model_name == 'shallow':
+        model = ShallowFBCSPNet(
+            n_chans, n_classes, input_time_length=input_time_length,
+            final_conv_length=30).create_network()
+
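+    # For cropped decoding: replace strides in the network by dilations so
+    # it outputs one prediction per input sample instead of one per window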
+    to_dense_prediction_model(model)
+    model.cuda()
+    model.eval()
+
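+    # Dummy forward pass to determine how many predictions the dense model
+    # makes per input window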
+    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
+    n_preds_per_input = out.cpu().data.numpy().shape[2]
+    optimizer = optim.Adam(model.parameters(), weight_decay=0, lr=1e-3)
+
+    iterator = CropsFromTrialsIterator(batch_size=60,
+                                       input_time_length=input_time_length,
+                                       n_preds_per_input=n_preds_per_input,
+                                       seed=np_th_seed)
+
+    monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
+                CroppedTrialMisclassMonitor(
+                    input_time_length=input_time_length), RuntimeMonitor()]
+
+    model_constraint = MaxNormDefaultConstraint()
+
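+    # Average the dense per-sample predictions over time (dim=2) within each
+    # window before computing the negative log-likelihood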
+    loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
+                                                      targets)
+
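+    # Stop when either max_epochs is reached or the validation
+    # misclassification has not decreased for max_increase_epochs epochs;
+    # run_after_early_stop then continues training on train+valid from the
+    # best checkpoint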
+    run_after_early_stop = True
+    do_early_stop = True
+    remember_best_column = 'valid_misclass'
+    stop_criterion = Or([MaxEpochs(max_epochs),
+                         NoDecrease('valid_misclass', max_increase_epochs)])
+
+    exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
+                     loss_function=loss_function, optimizer=optimizer,
+                     model_constraint=model_constraint,
+                     monitors=monitors,
+                     stop_criterion=stop_criterion,
+                     remember_best_column=remember_best_column,
+                     run_after_early_stop=run_after_early_stop, cuda=True,
+                     do_early_stop=do_early_stop)
+    exp.run()
+    return exp
+
+
+if __name__ == '__main__':
+    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
+                        level=logging.DEBUG, stream=sys.stdout)
+    subject_id = 1
+    # You have to change data_folder here to make it run
+    data_folder = '/data/schirrmr/schirrmr/HGD-public/reduced/'
+    train_filename = os.path.join(
+        data_folder, 'train/{:d}.mat'.format(subject_id))
+    test_filename = os.path.join(
+        data_folder, 'test/{:d}.mat'.format(subject_id))
+    max_epochs = 800
+    max_increase_epochs = 80
+    model_name = 'deep'  # or 'shallow'
+    low_cut_hz = 0  # or 4
+    np_th_seed = 0  # random seed for numpy and pytorch
+    debug = True  # if True, load only 3 sensors and train for 4 epochs
+    exp = run_exp_on_high_gamma_dataset(train_filename, test_filename,
+                                        low_cut_hz, model_name,
+                                        max_epochs, max_increase_epochs,
+                                        np_th_seed, debug)
+    log.info("Last 10 epochs")
+    log.info("\n" + str(exp.epochs_df.iloc[-10:]))