Scheduled service maintenance on November 22

On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

File size: 8.5 KB

  1. import logging
  2. import sys
  3. import os.path
  4. from collections import OrderedDict
  5. import numpy as np
  6. from braindecode.datasets.bbci import BBCIDataset
  7. from braindecode.datautil.signalproc import highpass_cnt
  8. import torch.nn.functional as F
  9. import torch as th
  10. from torch import optim
  11. from braindecode.torch_ext.util import set_random_seeds
  12. from braindecode.models.deep4 import Deep4Net
  13. from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
  14. from braindecode.models.util import to_dense_prediction_model
  15. from braindecode.experiments.experiment import Experiment
  16. from braindecode.torch_ext.util import np_to_var
  17. from braindecode.datautil.iterators import CropsFromTrialsIterator
  18. from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
  19. from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
  20. from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
  21. RuntimeMonitor, CroppedTrialMisclassMonitor
  22. from braindecode.datautil.splitters import split_into_two_sets
  23. from braindecode.datautil.trial_segment import \
  24. create_signal_target_from_raw_mne
  25. from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
  26. from braindecode.datautil.signalproc import exponential_running_standardize
  27. log = logging.getLogger(__name__)
  28. log.setLevel('DEBUG')
  29. def load_bbci_data(filename, low_cut_hz, debug=False):
  30. load_sensor_names = None
  31. if debug:
  32. load_sensor_names = ['C3', 'C4', 'C2']
  33. # we loaded all sensors to always get same cleaning results independent of sensor selection
  34. # There is an inbuilt heuristic that tries to use only EEG channels and that definitely
  35. # works for datasets in our paper
  36. loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)
  37."Loading data...")
  38. cnt = loader.load()
  39. # Cleaning: First find all trials that have absolute microvolt values
  40. # larger than +- 800 inside them and remember them for removal later
  41."Cutting trials...")
  42. marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2],),
  43. ('Rest', [3]), ('Feet', [4])])
  44. clean_ival = [0, 4000]
  45. set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
  46. clean_ival)
  47. clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800
  48."Clean trials: {:3d} of {:3d} ({:5.1f}%)".format(
  49. np.sum(clean_trial_mask),
  50. len(set_for_cleaning.X),
  51. np.mean(clean_trial_mask) * 100))
  52. # now pick only sensors with C in their name
  53. # as they cover motor cortex
  54. C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
  55. 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
  56. 'C6',
  57. 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
  58. 'FCC5h',
  59. 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
  60. 'CPP5h',
  61. 'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
  62. 'CCP1h',
  63. 'CCP2h', 'CPP1h', 'CPP2h']
  64. if debug:
  65. C_sensors = load_sensor_names
  66. cnt = cnt.pick_channels(C_sensors)
  67. # Further preprocessings as descibed in paper
  69. cnt = resample_cnt(cnt, 250.0)
  71. cnt = mne_apply(
  72. lambda a: highpass_cnt(
  73. a, low_cut_hz,['sfreq'], filt_order=3, axis=1),
  74. cnt)
  76. cnt = mne_apply(
  77. lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
  78. init_block_size=1000,
  79. eps=1e-4).T,
  80. cnt)
  81. # Trial interval, start at -500 already, since improved decoding for networks
  82. ival = [-500, 4000]
  83. dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
  84. dataset.X = dataset.X[clean_trial_mask]
  85. dataset.y = dataset.y[clean_trial_mask]
  86. return dataset
  87. def load_train_valid_test(
  88. train_filename, test_filename, low_cut_hz, debug=False):
  89."Loading train...")
  90. full_train_set = load_bbci_data(
  91. train_filename, low_cut_hz=low_cut_hz, debug=debug)
  92."Loading test...")
  93. test_set = load_bbci_data(
  94. test_filename, low_cut_hz=low_cut_hz, debug=debug)
  95. valid_set_fraction = 0.8
  96. train_set, valid_set = split_into_two_sets(full_train_set,
  97. valid_set_fraction)
  98."Train set with {:4d} trials".format(len(train_set.X)))
  99. if valid_set is not None:
  100."Valid set with {:4d} trials".format(len(valid_set.X)))
  101."Test set with {:4d} trials".format(len(test_set.X)))
  102. return train_set, valid_set, test_set
  103. def run_exp_on_high_gamma_dataset(train_filename, test_filename,
  104. low_cut_hz, model_name,
  105. max_epochs, max_increase_epochs,
  106. np_th_seed,
  107. debug):
  108. input_time_length = 1000
  109. batch_size = 60
  110. lr = 1e-3
  111. weight_decay = 0
  112. train_set, valid_set, test_set = load_train_valid_test(
  113. train_filename=train_filename,
  114. test_filename=test_filename,
  115. low_cut_hz=low_cut_hz, debug=debug)
  116. if debug:
  117. max_epochs = 4
  118. set_random_seeds(np_th_seed, cuda=True)
  119. #torch.backends.cudnn.benchmark = True# sometimes crashes?
  120. n_classes = int(np.max(train_set.y) + 1)
  121. n_chans = int(train_set.X.shape[1])
  122. if model_name == 'deep':
  123. model = Deep4Net(n_chans, n_classes,
  124. input_time_length=input_time_length,
  125. final_conv_length=2).create_network()
  126. elif model_name == 'shallow':
  127. model = ShallowFBCSPNet(
  128. n_chans, n_classes, input_time_length=input_time_length,
  129. final_conv_length=30).create_network()
  130. to_dense_prediction_model(model)
  131. model.cuda()
  132. model.eval()
  133. out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
  134. n_preds_per_input = out.cpu().data.numpy().shape[2]
  135. optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay,
  136. lr=lr)
  137. iterator = CropsFromTrialsIterator(batch_size=batch_size,
  138. input_time_length=input_time_length,
  139. n_preds_per_input=n_preds_per_input,
  140. seed=np_th_seed)
  141. monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
  142. CroppedTrialMisclassMonitor(
  143. input_time_length=input_time_length), RuntimeMonitor()]
  144. model_constraint = MaxNormDefaultConstraint()
  145. loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
  146. targets)
  147. run_after_early_stop = True
  148. do_early_stop = True
  149. remember_best_column = 'valid_misclass'
  150. stop_criterion = Or([MaxEpochs(max_epochs),
  151. NoDecrease('valid_misclass', max_increase_epochs)])
  152. exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
  153. loss_function=loss_function, optimizer=optimizer,
  154. model_constraint=model_constraint,
  155. monitors=monitors,
  156. stop_criterion=stop_criterion,
  157. remember_best_column=remember_best_column,
  158. run_after_early_stop=run_after_early_stop, cuda=True,
  159. do_early_stop=do_early_stop)
  161. return exp
  162. if __name__ == '__main__':
  163. logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
  164. level=logging.DEBUG, stream=sys.stdout)
  165. subject_id = 1
  166. # have to change the data_folder here to make it run.
  167. data_folder = '/data/schirrmr/schirrmr/HGD-public/reduced/'
  168. train_filename = os.path.join(
  169. data_folder, 'train/{:d}.mat'.format(subject_id))
  170. test_filename = os.path.join(
  171. data_folder, 'test/{:d}.mat'.format(subject_id))
  172. max_epochs = 800
  173. max_increase_epochs = 80
  174. model_name = 'deep' # or shallow
  175. low_cut_hz = 0 # or 4
  176. np_th_seed = 0 # random seed for numpy and pytorch
  177. debug = False
  178. exp = run_exp_on_high_gamma_dataset(train_filename, test_filename,
  179. low_cut_hz, model_name,
  180. max_epochs, max_increase_epochs,
  181. np_th_seed,
  182. debug)
  183."Last 10 epochs")
  184."\n" + str(exp.epochs_df.iloc[-10:]))