1
0

example.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import logging
  2. import sys
  3. import os.path
  4. from collections import OrderedDict
  5. import numpy as np
  6. from braindecode.datasets.bbci import BBCIDataset
  7. from braindecode.datautil.signalproc import highpass_cnt
  8. import torch.nn.functional as F
  9. import torch as th
  10. from torch import optim
  11. from braindecode.torch_ext.util import set_random_seeds
  12. from braindecode.models.deep4 import Deep4Net
  13. from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
  14. from braindecode.models.util import to_dense_prediction_model
  15. from braindecode.experiments.experiment import Experiment
  16. from braindecode.torch_ext.util import np_to_var
  17. from braindecode.datautil.iterators import CropsFromTrialsIterator
  18. from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
  19. from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
  20. from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
  21. RuntimeMonitor, CroppedTrialMisclassMonitor
  22. from braindecode.datautil.splitters import split_into_two_sets
  23. from braindecode.datautil.trial_segment import \
  24. create_signal_target_from_raw_mne
  25. from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
  26. from braindecode.datautil.signalproc import exponential_running_standardize
  27. log = logging.getLogger(__name__)
  28. log.setLevel('DEBUG')
  29. def load_bbci_data(filename, low_cut_hz, debug=False):
  30. load_sensor_names = None
  31. if debug:
  32. load_sensor_names = ['C3', 'C4', 'C2']
  33. # we loaded all sensors to always get same cleaning results independent of sensor selection
  34. # There is an inbuilt heuristic that tries to use only EEG channels and that definitely
  35. # works for datasets in our paper
  36. loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)
  37. log.info("Loading data...")
  38. cnt = loader.load()
  39. # Cleaning: First find all trials that have absolute microvolt values
  40. # larger than +- 800 inside them and remember them for removal later
  41. log.info("Cutting trials...")
  42. marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2],),
  43. ('Rest', [3]), ('Feet', [4])])
  44. clean_ival = [0, 4000]
  45. set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
  46. clean_ival)
  47. clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800
  48. log.info("Clean trials: {:3d} of {:3d} ({:5.1f}%)".format(
  49. np.sum(clean_trial_mask),
  50. len(set_for_cleaning.X),
  51. np.mean(clean_trial_mask) * 100))
  52. # now pick only sensors with C in their name
  53. # as they cover motor cortex
  54. C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
  55. 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
  56. 'C6',
  57. 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h',
  58. 'FCC5h',
  59. 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h',
  60. 'CPP5h',
  61. 'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h',
  62. 'CCP1h',
  63. 'CCP2h', 'CPP1h', 'CPP2h']
  64. if debug:
  65. C_sensors = load_sensor_names
  66. cnt = cnt.pick_channels(C_sensors)
  67. # Further preprocessings as descibed in paper
  68. log.info("Resampling...")
  69. cnt = resample_cnt(cnt, 250.0)
  70. log.info("Highpassing...")
  71. cnt = mne_apply(
  72. lambda a: highpass_cnt(
  73. a, low_cut_hz, cnt.info['sfreq'], filt_order=3, axis=1),
  74. cnt)
  75. log.info("Standardizing...")
  76. cnt = mne_apply(
  77. lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
  78. init_block_size=1000,
  79. eps=1e-4).T,
  80. cnt)
  81. # Trial interval, start at -500 already, since improved decoding for networks
  82. ival = [-500, 4000]
  83. dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
  84. dataset.X = dataset.X[clean_trial_mask]
  85. dataset.y = dataset.y[clean_trial_mask]
  86. return dataset
  87. def load_train_valid_test(
  88. train_filename, test_filename, low_cut_hz, debug=False):
  89. log.info("Loading train...")
  90. full_train_set = load_bbci_data(
  91. train_filename, low_cut_hz=low_cut_hz, debug=debug)
  92. log.info("Loading test...")
  93. test_set = load_bbci_data(
  94. test_filename, low_cut_hz=low_cut_hz, debug=debug)
  95. valid_set_fraction = 0.8
  96. train_set, valid_set = split_into_two_sets(full_train_set,
  97. valid_set_fraction)
  98. log.info("Train set with {:4d} trials".format(len(train_set.X)))
  99. if valid_set is not None:
  100. log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
  101. log.info("Test set with {:4d} trials".format(len(test_set.X)))
  102. return train_set, valid_set, test_set
  103. def run_exp_on_high_gamma_dataset(train_filename, test_filename,
  104. low_cut_hz, model_name,
  105. max_epochs, max_increase_epochs,
  106. np_th_seed,
  107. debug):
  108. input_time_length = 1000
  109. batch_size = 60
  110. lr = 1e-3
  111. weight_decay = 0
  112. train_set, valid_set, test_set = load_train_valid_test(
  113. train_filename=train_filename,
  114. test_filename=test_filename,
  115. low_cut_hz=low_cut_hz, debug=debug)
  116. if debug:
  117. max_epochs = 4
  118. set_random_seeds(np_th_seed, cuda=True)
  119. #torch.backends.cudnn.benchmark = True# sometimes crashes?
  120. n_classes = int(np.max(train_set.y) + 1)
  121. n_chans = int(train_set.X.shape[1])
  122. if model_name == 'deep':
  123. model = Deep4Net(n_chans, n_classes,
  124. input_time_length=input_time_length,
  125. final_conv_length=2).create_network()
  126. elif model_name == 'shallow':
  127. model = ShallowFBCSPNet(
  128. n_chans, n_classes, input_time_length=input_time_length,
  129. final_conv_length=30).create_network()
  130. to_dense_prediction_model(model)
  131. model.cuda()
  132. model.eval()
  133. out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
  134. n_preds_per_input = out.cpu().data.numpy().shape[2]
  135. optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay,
  136. lr=lr)
  137. iterator = CropsFromTrialsIterator(batch_size=batch_size,
  138. input_time_length=input_time_length,
  139. n_preds_per_input=n_preds_per_input,
  140. seed=np_th_seed)
  141. monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
  142. CroppedTrialMisclassMonitor(
  143. input_time_length=input_time_length), RuntimeMonitor()]
  144. model_constraint = MaxNormDefaultConstraint()
  145. loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2),
  146. targets)
  147. run_after_early_stop = True
  148. do_early_stop = True
  149. remember_best_column = 'valid_misclass'
  150. stop_criterion = Or([MaxEpochs(max_epochs),
  151. NoDecrease('valid_misclass', max_increase_epochs)])
  152. exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
  153. loss_function=loss_function, optimizer=optimizer,
  154. model_constraint=model_constraint,
  155. monitors=monitors,
  156. stop_criterion=stop_criterion,
  157. remember_best_column=remember_best_column,
  158. run_after_early_stop=run_after_early_stop, cuda=True,
  159. do_early_stop=do_early_stop)
  160. exp.run()
  161. return exp
  162. if __name__ == '__main__':
  163. logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
  164. level=logging.DEBUG, stream=sys.stdout)
  165. subject_id = 1
  166. # have to change the data_folder here to make it run.
  167. data_folder = '/data/schirrmr/schirrmr/HGD-public/reduced/'
  168. train_filename = os.path.join(
  169. data_folder, 'train/{:d}.mat'.format(subject_id))
  170. test_filename = os.path.join(
  171. data_folder, 'test/{:d}.mat'.format(subject_id))
  172. max_epochs = 800
  173. max_increase_epochs = 80
  174. model_name = 'deep' # or shallow
  175. low_cut_hz = 0 # or 4
  176. np_th_seed = 0 # random seed for numpy and pytorch
  177. debug = False
  178. exp = run_exp_on_high_gamma_dataset(train_filename, test_filename,
  179. low_cut_hz, model_name,
  180. max_epochs, max_increase_epochs,
  181. np_th_seed,
  182. debug)
  183. log.info("Last 10 epochs")
  184. log.info("\n" + str(exp.epochs_df.iloc[-10:]))