Browse Source

example decoding

Robin Tibor Schirrmeister 5 years ago
parent
commit
ff6312dbce
2 changed files with 218 additions and 0 deletions
  1. 2 0
      README.md
  2. 216 0
      example.py

+ 2 - 0
README.md

@@ -22,6 +22,8 @@ For using the dataset for decoding, see the next section.
 
 ## Reproduction of our results
 The `example.py` code in this repository shows how to reproduce the decoding results from the paper above and can also be used as an example code for decoding.
+Please change the `data_folder` in the code to the folder where you downloaded the dataset to, see the code at the bottom of the file.
+
 
 
 ## Data format

+ 216 - 0
example.py

@@ -0,0 +1,216 @@
+import logging
+import sys
+import os.path
+from collections import OrderedDict
+import numpy as np
+
+from braindecode.datasets.bbci import  BBCIDataset
+from braindecode.datautil.signalproc import highpass_cnt
+import torch.nn.functional as F
+import torch as th
+from torch import optim
+from braindecode.torch_ext.util import set_random_seeds
+from braindecode.models.deep4 import Deep4Net
+from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
+from braindecode.models.util import to_dense_prediction_model
+from braindecode.experiments.experiment import Experiment
+from braindecode.torch_ext.util import np_to_var
+from braindecode.datautil.iterators import CropsFromTrialsIterator
+from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
+from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
+from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
+    RuntimeMonitor, CroppedTrialMisclassMonitor
+
+from braindecode.datautil.splitters import split_into_two_sets
+from braindecode.datautil.trial_segment import \
+    create_signal_target_from_raw_mne
+from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
+from braindecode.datautil.signalproc import exponential_running_standardize
+
+
+log = logging.getLogger(__name__)
+log.setLevel('DEBUG')
+
+
+def load_bbci_data(filename, low_cut_hz, debug=False):
+    load_sensor_names = None
+    if debug:
+        load_sensor_names = ['C3', 'C4', 'C2']
+    # we loaded all sensors to always get same cleaning results independent of sensor selection
+    # There is an inbuilt heuristic that tries to use only EEG channels and that definitely
+    # works for datasets in our paper
+    loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)
+
+    log.info("Loading data...")
+    cnt = loader.load()
+
+    # Cleaning: First find all trials that have absolute microvolt values
+    # larger than +- 800 inside them and remember them for removal later
+    log.info("Cutting trials...")
+
+    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2],),
+                              ('Rest', [3]), ('Feet', [4])])
+    clean_ival = [0, 4000]
+
+    set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
+                                                  clean_ival)
+
+    clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800
+
+    log.info("Clean trials: {:3d}  of {:3d} ({:5.1f}%)".format(
+        np.sum(clean_trial_mask),
+        len(set_for_cleaning.X),
+        np.mean(clean_trial_mask) * 100))
+
+    # now pick only sensors with C in their name
+    # as they cover motor cortex
+    C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
+                 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
+                 'C6',
+                 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
+                 'FCC5h',
+                 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
+                 'CPP5h',
+                 'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
+                 'CCP1h',
+                 'CCP2h', 'CPP1h', 'CPP2h']
+    if debug:
+        C_sensors = load_sensor_names
+    cnt = cnt.pick_channels(C_sensors)
+
+    # Further preprocessings as descibed in paper
+    log.info("Resampling...")
+    cnt = resample_cnt(cnt, 250.0)
+    log.info("Highpassing...")
+    cnt = mne_apply(
+        lambda a: highpass_cnt(
+            a, low_cut_hz, cnt.info['sfreq'], filt_order=3, axis=1),
+        cnt)
+    log.info("Standardizing...")
+    cnt = mne_apply(
+        lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
+                                                  init_block_size=1000,
+                                                  eps=1e-4).T,
+        cnt)
+
+    # Trial interval, start at -500 already, since improved decoding for networks
+    ival = [-500, 4000]
+
+    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
+    dataset.X = dataset.X[clean_trial_mask]
+    dataset.y = dataset.y[clean_trial_mask]
+    return dataset
+
+
+def load_train_valid_test(
+        train_filename, test_filename, low_cut_hz, debug=False):
+    log.info("Loading train...")
+    full_train_set = load_bbci_data(
+        train_filename, low_cut_hz=low_cut_hz, debug=debug)
+
+    log.info("Loading test...")
+    test_set = load_bbci_data(
+        test_filename, low_cut_hz=low_cut_hz, debug=debug)
+    valid_set_fraction = 0.8
+    train_set, valid_set = split_into_two_sets(full_train_set,
+                                               valid_set_fraction)
+
+    log.info("Train set with {:4d} trials".format(len(train_set.X)))
+    if valid_set is not None:
+        log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
+    log.info("Test set with  {:4d} trials".format(len(test_set.X)))
+
+    return train_set, valid_set, test_set
+
+
+def run_exp_on_high_gamma_dataset(train_filename, test_filename,
+                  low_cut_hz, model_name,
+                  max_epochs, max_increase_epochs,
+                  np_th_seed,
+                  debug):
+    train_set, valid_set, test_set = load_train_valid_test(
+        train_filename=train_filename,
+        test_filename=test_filename,
+        low_cut_hz=low_cut_hz, debug=debug)
+    if debug:
+        max_epochs = 4
+
+    set_random_seeds(np_th_seed, cuda=True)
+    #torch.backends.cudnn.benchmark = True# sometimes crashes?
+    n_classes = int(np.max(train_set.y) + 1)
+    n_chans = int(train_set.X.shape[1])
+    input_time_length = 1000
+    if model_name == 'deep':
+        model = Deep4Net(n_chans, n_classes,
+                         input_time_length=input_time_length,
+                         final_conv_length=2).create_network()
+    elif model_name == 'shallow':
+        model = ShallowFBCSPNet(
+            n_chans, n_classes, input_time_length=input_time_length,
+            final_conv_length=30).create_network()
+
+    to_dense_prediction_model(model)
+    model.cuda()
+    model.eval()
+
+    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
+
+    n_preds_per_input = out.cpu().data.numpy().shape[2]
+    optimizer = optim.Adam(model.parameters(), weight_decay=0,
+                           lr=1e-3)
+
+    iterator = CropsFromTrialsIterator(batch_size=60,
+                                       input_time_length=input_time_length,
+                                       n_preds_per_input=n_preds_per_input,
+                                       seed=np_th_seed)
+
+    monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
+                CroppedTrialMisclassMonitor(
+                    input_time_length=input_time_length), RuntimeMonitor()]
+
+    model_constraint = MaxNormDefaultConstraint()
+
+    loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
+                                                      targets)
+
+    run_after_early_stop = True
+    do_early_stop = True
+    remember_best_column = 'valid_misclass'
+    stop_criterion = Or([MaxEpochs(max_epochs),
+                         NoDecrease('valid_misclass', max_increase_epochs)])
+
+    exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
+                     loss_function=loss_function, optimizer=optimizer,
+                     model_constraint=model_constraint,
+                     monitors=monitors,
+                     stop_criterion=stop_criterion,
+                     remember_best_column=remember_best_column,
+                     run_after_early_stop=run_after_early_stop, cuda=True,
+                     do_early_stop=do_early_stop)
+    exp.run()
+    return exp
+
+
+if __name__ == '__main__':
+    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
+                     level=logging.DEBUG, stream=sys.stdout)
+    subject_id = 1
+    # have to change the data_folder here to make it run.
+    data_folder = '/data/schirrmr/schirrmr/HGD-public/reduced/'
+    train_filename =  os.path.join(
+        data_folder, 'train/{:d}.mat'.format(subject_id))
+    test_filename =  os.path.join(
+        data_folder, 'test/{:d}.mat'.format(subject_id))
+    max_epochs = 800
+    max_increase_epochs = 80
+    model_name = 'deep' # or shallow
+    low_cut_hz = 0 # or 4
+    np_th_seed = 0 # random seed for numpy and pytorch
+    debug = True
+    exp = run_exp_on_high_gamma_dataset(train_filename, test_filename,
+                                  low_cut_hz, model_name,
+                                  max_epochs, max_increase_epochs,
+                                  np_th_seed,
+                                  debug)
+    log.info("Last 10 epochs")
+    log.info("\n" + str(exp.epochs_df.iloc[-10:]))