|
@@ -128,6 +128,10 @@ def run_exp_on_high_gamma_dataset(train_filename, test_filename,
|
|
max_epochs, max_increase_epochs,
|
|
max_epochs, max_increase_epochs,
|
|
np_th_seed,
|
|
np_th_seed,
|
|
debug):
|
|
debug):
|
|
|
|
+ input_time_length = 1000
|
|
|
|
+ batch_size = 60
|
|
|
|
+ lr = 1e-3
|
|
|
|
+ weight_decay = 0
|
|
train_set, valid_set, test_set = load_train_valid_test(
|
|
train_set, valid_set, test_set = load_train_valid_test(
|
|
train_filename=train_filename,
|
|
train_filename=train_filename,
|
|
test_filename=test_filename,
|
|
test_filename=test_filename,
|
|
@@ -139,7 +143,6 @@ def run_exp_on_high_gamma_dataset(train_filename, test_filename,
|
|
#torch.backends.cudnn.benchmark = True# sometimes crashes?
|
|
#torch.backends.cudnn.benchmark = True# sometimes crashes?
|
|
n_classes = int(np.max(train_set.y) + 1)
|
|
n_classes = int(np.max(train_set.y) + 1)
|
|
n_chans = int(train_set.X.shape[1])
|
|
n_chans = int(train_set.X.shape[1])
|
|
- input_time_length = 1000
|
|
|
|
if model_name == 'deep':
|
|
if model_name == 'deep':
|
|
model = Deep4Net(n_chans, n_classes,
|
|
model = Deep4Net(n_chans, n_classes,
|
|
input_time_length=input_time_length,
|
|
input_time_length=input_time_length,
|
|
@@ -156,10 +159,10 @@ def run_exp_on_high_gamma_dataset(train_filename, test_filename,
|
|
out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
|
|
out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
|
|
|
|
|
|
n_preds_per_input = out.cpu().data.numpy().shape[2]
|
|
n_preds_per_input = out.cpu().data.numpy().shape[2]
|
|
- optimizer = optim.Adam(model.parameters(), weight_decay=0,
|
|
|
|
- lr=1e-3)
|
|
|
|
|
|
+ optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay,
|
|
|
|
+ lr=lr)
|
|
|
|
|
|
- iterator = CropsFromTrialsIterator(batch_size=60,
|
|
|
|
|
|
+ iterator = CropsFromTrialsIterator(batch_size=batch_size,
|
|
input_time_length=input_time_length,
|
|
input_time_length=input_time_length,
|
|
n_preds_per_input=n_preds_per_input,
|
|
n_preds_per_input=n_preds_per_input,
|
|
seed=np_th_seed)
|
|
seed=np_th_seed)
|