Browse Source

moved hyperparameters into one place

Robin Tibor Schirrmeister 3 years ago
parent
commit
56c7b3a60b
1 changed files with 7 additions and 4 deletions
  1. 7 4
      example.py

+ 7 - 4
example.py

@@ -128,6 +128,10 @@ def run_exp_on_high_gamma_dataset(train_filename, test_filename,
                   max_epochs, max_increase_epochs,
                   np_th_seed,
                   debug):
+    input_time_length = 1000
+    batch_size = 60
+    lr = 1e-3
+    weight_decay = 0
     train_set, valid_set, test_set = load_train_valid_test(
         train_filename=train_filename,
         test_filename=test_filename,
@@ -139,7 +143,6 @@ def run_exp_on_high_gamma_dataset(train_filename, test_filename,
     #torch.backends.cudnn.benchmark = True# sometimes crashes?
     n_classes = int(np.max(train_set.y) + 1)
     n_chans = int(train_set.X.shape[1])
-    input_time_length = 1000
     if model_name == 'deep':
         model = Deep4Net(n_chans, n_classes,
                          input_time_length=input_time_length,
@@ -156,10 +159,10 @@ def run_exp_on_high_gamma_dataset(train_filename, test_filename,
     out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
 
     n_preds_per_input = out.cpu().data.numpy().shape[2]
-    optimizer = optim.Adam(model.parameters(), weight_decay=0,
-                           lr=1e-3)
+    optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay,
+                           lr=lr)
 
-    iterator = CropsFromTrialsIterator(batch_size=60,
+    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                        input_time_length=input_time_length,
                                        n_preds_per_input=n_preds_per_input,
                                        seed=np_th_seed)