import logging
import sys
import os.path
from collections import OrderedDict

import numpy as np
import torch as th
import torch.nn.functional as F
from torch import optim

from braindecode.datasets.bbci import BBCIDataset
from braindecode.datautil.signalproc import (highpass_cnt,
                                             exponential_running_standardize)
from braindecode.torch_ext.util import set_random_seeds, np_to_var
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.experiments.monitors import (LossMonitor, MisclassMonitor,
                                              RuntimeMonitor,
                                              CroppedTrialMisclassMonitor)
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.datautil.trial_segment import \
    create_signal_target_from_raw_mne
from braindecode.mne_ext.signalproc import mne_apply, resample_cnt

log = logging.getLogger(__name__)
log.setLevel('DEBUG')


def load_bbci_data(filename, low_cut_hz, debug=False):
    load_sensor_names = None
    if debug:
        load_sensor_names = ['C3', 'C4', 'C2']
    # We load all sensors to always get the same cleaning results,
    # independent of the sensor selection. There is an inbuilt heuristic
    # that tries to use only EEG channels and that definitely works for
    # the datasets in our paper.
    loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)

    log.info("Loading data...")
    cnt = loader.load()

    # Cleaning: first find all trials that contain absolute microvolt
    # values larger than +-800 and remember them for removal later.
    log.info("Cutting trials...")
    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2]),
                              ('Rest', [3]), ('Feet', [4])])
    clean_ival = [0, 4000]

    set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
                                                         clean_ival)

    clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800

    log.info("Clean trials: {:3d} of {:3d} ({:5.1f}%)".format(
        np.sum(clean_trial_mask),
        len(set_for_cleaning.X),
        np.mean(clean_trial_mask) * 100))

    # Now pick only sensors with C in their name,
    # as they cover the motor cortex.
    C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5', 'CP1',
                 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
                 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
                 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h',
                 'CCP4h', 'CCP6h', 'CPP5h', 'CPP3h', 'CPP4h', 'CPP6h',
                 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h', 'CCP2h',
                 'CPP1h', 'CPP2h']
    if debug:
        C_sensors = load_sensor_names

    cnt = cnt.pick_channels(C_sensors)

    # Further preprocessing as described in the paper.
    log.info("Resampling...")
    cnt = resample_cnt(cnt, 250.0)
    log.info("Highpassing...")
    cnt = mne_apply(
        lambda a: highpass_cnt(
            a, low_cut_hz, cnt.info['sfreq'], filt_order=3, axis=1),
        cnt)
    log.info("Standardizing...")
    cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
                                                  init_block_size=1000,
                                                  eps=1e-4).T,
        cnt)

    # Trial interval: start already at -500 ms before the trial,
    # since this improved decoding for the networks.
    ival = [-500, 4000]

    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
    dataset.X = dataset.X[clean_trial_mask]
    dataset.y = dataset.y[clean_trial_mask]
    return dataset


def load_train_valid_test(
        train_filename, test_filename,
        low_cut_hz, debug=False):
    log.info("Loading train...")
    full_train_set = load_bbci_data(
        train_filename, low_cut_hz=low_cut_hz, debug=debug)

    log.info("Loading test...")
    test_set = load_bbci_data(
        test_filename, low_cut_hz=low_cut_hz, debug=debug)

    # Use the first 80% of the training trials for training and the
    # remaining 20% for validation (split_into_two_sets takes the
    # fraction of the *first* returned set).
    train_set_fraction = 0.8
    train_set, valid_set = split_into_two_sets(full_train_set,
                                               train_set_fraction)

    log.info("Train set with {:4d} trials".format(len(train_set.X)))
    if valid_set is not None:
        log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
    log.info("Test set with {:4d} trials".format(len(test_set.X)))

    return train_set, valid_set, test_set


def run_exp_on_high_gamma_dataset(train_filename, test_filename, low_cut_hz,
                                  model_name, max_epochs,
                                  max_increase_epochs, np_th_seed, debug):
    input_time_length = 1000
    batch_size = 60
    lr = 1e-3
    weight_decay = 0

    train_set, valid_set, test_set = load_train_valid_test(
        train_filename=train_filename, test_filename=test_filename,
        low_cut_hz=low_cut_hz, debug=debug)
    if debug:
        max_epochs = 4

    set_random_seeds(np_th_seed, cuda=True)
    # th.backends.cudnn.benchmark = True  # sometimes crashes?
    n_classes = int(np.max(train_set.y) + 1)
    n_chans = int(train_set.X.shape[1])
    if model_name == 'deep':
        model = Deep4Net(n_chans, n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()
    elif model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans, n_classes, input_time_length=input_time_length,
            final_conv_length=30).create_network()
    else:
        raise ValueError("Unknown model name {:s}".format(model_name))

    # Convert to a dense-prediction model for cropped decoding and run a
    # dummy forward pass to determine how many predictions the model makes
    # per input window.
    to_dense_prediction_model(model)
    model.cuda()
    model.eval()

    out = model(np_to_var(train_set.X[:1, :, :input_time_length,
                                      None]).cuda())
    n_preds_per_input = out.cpu().data.numpy().shape[2]

    optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay,
                           lr=lr)

    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)

    monitors = [LossMonitor(),
                MisclassMonitor(col_suffix='sample_misclass'),
                CroppedTrialMisclassMonitor(
                    input_time_length=input_time_length),
                RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    # Average the per-crop predictions over time before computing the loss.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2), targets)

    run_after_early_stop = True
    do_early_stop = True
    remember_best_column = 'valid_misclass'
    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])

    exp = Experiment(model, train_set, valid_set, test_set,
                     iterator=iterator,
                     loss_function=loss_function, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=run_after_early_stop, cuda=True,
                     do_early_stop=do_early_stop)
    exp.run()
    return exp


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    subject_id = 1
    # You have to change data_folder here to make it run.
    data_folder = '/data/schirrmr/schirrmr/HGD-public/reduced/'
    train_filename = os.path.join(
        data_folder, 'train/{:d}.mat'.format(subject_id))
    test_filename = os.path.join(
        data_folder, 'test/{:d}.mat'.format(subject_id))
    max_epochs = 800
    max_increase_epochs = 80
    model_name = 'deep'  # or 'shallow'
    low_cut_hz = 0  # or 4
    np_th_seed = 0  # random seed for numpy and pytorch
    debug = False
    exp = run_exp_on_high_gamma_dataset(train_filename, test_filename,
                                        low_cut_hz, model_name,
                                        max_epochs, max_increase_epochs,
                                        np_th_seed, debug)
    log.info("Last 10 epochs")
    log.info("\n" + str(exp.epochs_df.iloc[-10:]))
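
    # Optionally persist the per-epoch metrics for later analysis;
    # exp.epochs_df is a pandas DataFrame, and the filename below is only
    # an example, not part of the original experiment:
    # exp.epochs_df.to_csv('high_gamma_subject_{:d}_{:s}.csv'.format(
    #     subject_id, model_name))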