- ### Model training and testing
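- # Fits a linear readout (spatial weights w_s, feature weights w_f) on top of a
- # frozen InceptionV1 backbone to predict neural responses, with an optional
- # grid search over the readout layer and regularization settings.
- # Hypothetical invocation (script name assumed; arguments follow the sys.argv
- # parsing below, with candidate layers passed as a bracketed list):
- #   python model_training.py /path/to/data/ monkeyA V4 False [ conv2, mixed3a, mixed3b]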
- import os
- import pickle
- import sys
- import time
- import matplotlib
- import matplotlib.pyplot as plt
- import h5py
- import numpy as np
- import torch
- from torchvision import datasets, transforms
- import torchvision.models as models
- from torch import sigmoid
- from lucent.util import set_seed
- from lucent.modelzoo import inceptionv1, util
- from helper import smoothing_laplacian_loss, sparsity_loss
- from neural_model import Model
- # command-line arguments
- gen_path = sys.argv[1]
- monkey_name = sys.argv[2]
- roi_name = sys.argv[3]
- layer = sys.argv[4]
- if len(sys.argv) == 5:
-     layers = []
- else:
-     # candidate layers arrive as a bracketed list, e.g. "[ conv2, mixed3a]",
-     # so the opening bracket is argv[5] and the layer names start at argv[6]
-     layers = [item.replace(',', '').replace(']', '') for item in sys.argv[6:]]
- # layer == 'False': run the grid search only; layer == 'True': run the grid
- # search and then retrain the winning configuration; any other value is taken
- # as a layer name to train with previously saved grid-search parameters
- if layer == 'False':
-     grid_search = True
-     train_single_layer = False
- elif layer == 'True':
-     grid_search = True
-     train_single_layer = True
- else:
-     grid_search = False
-
- # hyperparameters
- seed = 0
- nb_epochs = 1000
- save_epochs = 100
- grid_epochs = 120
- grid_save_epochs = grid_epochs
- batch_size = 100
- lr_decay_gamma = 1/3
- lr_decay_step = 3*save_epochs
- backbone = 'inception_v1'
- sparse_weights = [1e-8]
- #learning_rates = [1e-2,1e-3,1e-4]
- #smooth_weights = [1e-3,1e-2,0.1,0.2]
- #weight_decays = [1e-4,1e-3,1e-2,1e-1]
- learning_rates = [1e-6,1e-5,1e-4]
- smooth_weights = [1e-6,1e-5,1e-4,1e-3]
- weight_decays = [1e-7,1e-6,1e-5,1e-4]
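- # full grid: 3 learning rates x 4 smooth weights x 1 sparse weight x 4 weight
- # decays = 48 configurations per candidate layer, each trained for grid_epochs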
- # paths + filenames
- data_filename = gen_path + monkey_name + '/data_THINGS_array_'+ roi_name +'_v1.pkl'
- grid_filename = gen_path + monkey_name + '/snapshots/grid_search_array_'+ roi_name +'.pkl'
- snapshot_path = gen_path + monkey_name + '/snapshots/array'+ roi_name +'_neural_model.pt'
- current_datetime = time.strftime("%Y-%m-%d_%H_%M_%S", time.gmtime())
- loss_plot_path = f'./training_data/training_loss_classifier_{current_datetime}.png'
- # load_snapshot = True
- load_snapshot = False
- GPU = torch.cuda.is_available()
- if GPU:
-     torch.cuda.set_device(0)
- snapshot_pattern = gen_path + monkey_name + '/snapshots/neural_model_{backbone}_{layer}.pt'  # filename template, filled via .format()
- # load data
- with open(data_filename, "rb") as f:
-     cc = pickle.load(f)
- train_data = cc['train_data']
- val_data = cc['val_data']
- img_data = cc['img_data']
- val_img_data = cc['val_img_data']
- n_neurons = train_data.shape[1]
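- # Assumed shapes (inferred from the indexing below): train_data and val_data
- # are (n_stimuli, n_neurons) arrays; img_data and val_img_data are torch
- # tensors of preprocessed images indexed along the first dimension.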
- ######
- # Grid search:
- ######
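- # For every (layer, learning rate, smooth weight, sparse weight, weight decay)
- # combination, a fresh readout is trained for grid_epochs and scored by median
- # validation correlation; the best combination is kept for the final run.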
- iter1_done = False
- if grid_search:
-     params = []
-     val_corrs = []
-     for layer in layers:
-         print('======================')
-         print('Layer: ' + layer)
-         for learning_rate in learning_rates:
-             for smooth_weight in smooth_weights:
-                 for sparse_weight in sparse_weights:
-                     for weight_decay in weight_decays:
-
-                         set_seed(seed)
-                         # free the previous configuration's model before building a new one
-                         if iter1_done:
-                             del pretrained_model, net, criterion, optimizer, scheduler
-                         iter1_done = True
-                         # model, wrapped in DataParallel (and moved to GPU if available)
-                         device = torch.device("cuda:0" if GPU else "cpu")
-                         pretrained_model = inceptionv1(pretrained=True)
-                         net = Model(pretrained_model, layer, n_neurons, device)
-                         net.initialize()
-                         print('Initialized using Xavier')
-                         net = torch.nn.DataParallel(net)
-                         if GPU:
-                             net = net.cuda()
-                         # loss function
-                         criterion = torch.nn.MSELoss()
-                         # optimizer and lr scheduler
-                         optimizer = torch.optim.Adam(
-                             [net.module.w_s, net.module.w_f],
-                             lr=learning_rate,
-                             weight_decay=weight_decay)
-                         scheduler = torch.optim.lr_scheduler.StepLR(
-                             optimizer,
-                             step_size=lr_decay_step,
-                             gamma=lr_decay_gamma)
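-                         # note: only the readout weights w_s (spatial) and w_f
-                         # (feature) are optimized; the InceptionV1 backbone stays frozen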
-                         cum_loss = 0.0
-                         for epoch in range(grid_epochs):
-                             # sample a random training batch
-                             batch_idx = np.random.choice(train_data.shape[0],
-                                                          size=batch_size, replace=False)
-                             neural_batch = torch.tensor(train_data[batch_idx, :])
-                             img_batch = img_data[batch_idx, :]
-                             if GPU:
-                                 torch.cuda.empty_cache()
-                                 neural_batch = neural_batch.cuda()
-                                 img_batch = img_batch.cuda()
-                             # forward + backward + optimize
-                             optimizer.zero_grad()
-                             outputs = net(img_batch).squeeze()
-                             loss = (criterion(outputs, neural_batch.float())
-                                     + smoothing_laplacian_loss(net.module.w_s, device,
-                                                                weight=smooth_weight)
-                                     + sparse_weight * torch.norm(net.module.w_f, 1))
-                             loss.backward()
-                             optimizer.step()
-                             scheduler.step()  # decay lr every lr_decay_step epochs
-                             cum_loss += loss.data.cpu()
-                             # evaluate on the held-out validation set at the end
-                             # of the run (grid_save_epochs == grid_epochs)
-                             if epoch % grid_save_epochs == grid_save_epochs - 1:
-                                 with torch.no_grad():
-                                     val_outputs = net(val_img_data).squeeze()
-                                 # median Pearson correlation across neurons
-                                 corrs = []
-                                 for n in range(val_outputs.shape[1]):
-                                     corrs.append(np.corrcoef(val_outputs[:, n].cpu().numpy(),
-                                                              val_data[:, n])[1, 0])
-                                 val_corr = np.median(corrs)
-                                 mean_loss = float(cum_loss) / grid_save_epochs
-                                 cum_loss = 0.0
-                                 params.append([layer, learning_rate, smooth_weight,
-                                                sparse_weight, weight_decay])
-                                 val_corrs.append(val_corr)
-                                 print('======')
-                                 print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
-                                 print(f'learning rate: {learning_rate}')
-                                 print(f'smooth weight: {smooth_weight}')
-                                 print(f'sparse weight: {sparse_weight}')
-                                 print(f'weight decay: {weight_decay}')
-                                 print(f'mean training loss: {mean_loss:.3f}')
-                                 print(f'Validation corr: {val_corr:.3f}')
-     # extract winning params
-     val_corrs = np.array(val_corrs)
-     best_idx = int(np.argmax(val_corrs))
-     layer, learning_rate, smooth_weight, sparse_weight, weight_decay = params[best_idx]
-     # print winning params
-     print('======================')
-     print('Best layer is: ' + layer)
-     print('Best learning rate is: ' + str(learning_rate))
-     print('Best smooth weight is: ' + str(smooth_weight))
-     print('Best sparse weight is: ' + str(sparse_weight))
-     print('Best weight decay is: ' + str(weight_decay))
-
-     # save params
-     output = {'val_corrs': val_corrs, 'params': params}
-     with open(grid_filename, "wb") as f:
-         pickle.dump(output, f)
- else:
-     print('======================')
-     print('Layer: ' + layer)
-     # load the grid-search results and pick the best configuration
-     # for the requested layer
-     with open(grid_filename, "rb") as f:
-         cc = pickle.load(f)
-     val_corrs = np.array(cc['val_corrs'])
-     params = cc['params']
-     all_layers = np.asarray(params)[:, 0]
-     best_idx = int(np.where(val_corrs == val_corrs[all_layers == layer].max())[0][0])
-     learning_rate = params[best_idx][1]
-     smooth_weight = params[best_idx][2]
-     sparse_weight = params[best_idx][3]
-     weight_decay = params[best_idx][4]
-     train_single_layer = True
-
- ######
- # Final training!!
- ######
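- # Retrains the selected layer's readout for nb_epochs with the chosen
- # hyperparameters, logging validation performance every save_epochs epochs.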
- if train_single_layer:
-     snapshot_path = gen_path + monkey_name + '/snapshots/' + monkey_name + '_' + roi_name + '_' + layer + '_neural_model.pt'
-     layer_filename = gen_path + 'val_corr_' + monkey_name + '_' + roi_name + '_' + layer + '.pkl'
-     # model, wrapped in DataParallel (and moved to GPU if available)
-     set_seed(seed)
-     device = torch.device("cuda:0" if GPU else "cpu")
-     pretrained_model = inceptionv1(pretrained=True)
-     net = Model(pretrained_model, layer, n_neurons, device)
-     if load_snapshot:
-         net.load_state_dict(torch.load(
-             snapshot_path,
-             map_location=lambda storage, loc: storage
-         ))
-         print('Loaded snap ' + snapshot_path)
-     else:
-         net.initialize()
-         print('Initialized using Xavier')
-     net = torch.nn.DataParallel(net)
-     if GPU:
-         net = net.cuda()
-         print("Training on {} GPUs".format(torch.cuda.device_count()))
-     else:
-         print("Training on CPU")
-     # loss function
-     criterion = torch.nn.MSELoss()
-     # optimizer and lr scheduler (readout weights only, as in the grid search)
-     optimizer = torch.optim.Adam(
-         [net.module.w_s, net.module.w_f],
-         lr=learning_rate,
-         weight_decay=weight_decay)
-     scheduler = torch.optim.lr_scheduler.StepLR(
-         optimizer,
-         step_size=lr_decay_step,
-         gamma=lr_decay_gamma)
-     # figure for the running training-loss curve
-     os.makedirs(os.path.dirname(loss_plot_path), exist_ok=True)
-     fig = plt.figure()
-     axis = fig.add_subplot(111)
-     axis.set_xlabel('epoch')
-     axis.set_ylabel('loss')
-     axis.set_yscale('log')
-     plt_line, = axis.plot([], [])
-     cum_time = 0.0
-     cum_loss = 0.0
-     # the validation tensor is fixed, so build it once outside the loop
-     val_neural_data = torch.tensor(val_data).float()
-     if GPU:
-         val_neural_data = val_neural_data.cuda()
-     for epoch in range(nb_epochs):
-         # sample a random training batch
-         batch_idx = np.random.choice(train_data.shape[0],
-                                      size=batch_size, replace=False)
-         neural_batch = torch.tensor(train_data[batch_idx, :])
-         img_batch = img_data[batch_idx, :]
-         if GPU:
-             neural_batch = neural_batch.cuda()
-             img_batch = img_batch.cuda()
-         # forward + backward + optimize
-         tic = time.time()
-         optimizer.zero_grad()
-         outputs = net(img_batch).squeeze()
-         loss = (criterion(outputs, neural_batch.float())
-                 + smoothing_laplacian_loss(net.module.w_s, device,
-                                            weight=smooth_weight)
-                 + sparse_weight * torch.norm(net.module.w_f, 1))
-         loss.backward()
-         optimizer.step()
-         scheduler.step()  # decay lr every lr_decay_step epochs
-         toc = time.time()
-         cum_time += toc - tic
-         cum_loss += loss.data.cpu()
-         # periodic validation + logging
-         if epoch % save_epochs == save_epochs - 1:
-             tic_test = time.time()
-             with torch.no_grad():
-                 val_outputs = net(val_img_data).squeeze()
-                 val_loss = criterion(val_outputs, val_neural_data)
-             # median Pearson correlation across neurons
-             corrs = []
-             for n in range(val_outputs.shape[1]):
-                 corrs.append(np.corrcoef(val_outputs[:, n].cpu().numpy(),
-                                          val_data[:, n])[1, 0])
-             val_corr = np.median(corrs)
-             # print and plot time / loss
-             print('======================')
-             print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
-             no_epoch = (epoch + 1) // save_epochs
-             mean_time = cum_time / float(save_epochs)
-             mean_loss = float(cum_loss) / save_epochs
-             cum_time = 0.0
-             cum_loss = 0.0
-             print(f'epoch {epoch + 1}/{nb_epochs} mean time: {mean_time:.3f}s')
-             print(f'epoch {epoch + 1}/{nb_epochs} mean loss: {mean_loss:.3f}')
-             print(f'checkpoint {no_epoch} validation loss: {val_loss.item():.3f}')
-             print(f'checkpoint {no_epoch} validation corr: {val_corr:.3f}')
-             plt_line.set_xdata(np.append(plt_line.get_xdata(), no_epoch))
-             plt_line.set_ydata(np.append(plt_line.get_ydata(), mean_loss))
-             axis.relim()
-             axis.autoscale_view()
-             fig.savefig(loss_plot_path)
-             print('======================')
-             print('Validation time: ', time.time() - tic_test)
-     # save final val corr
-     output = {'val_corrs': corrs}
-     with open(layer_filename, "wb") as f:
-         pickle.dump(output, f)
-     # save the weights, we're done!
-     os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
-     torch.save(net.module.state_dict(), snapshot_path)
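- # To reuse the trained readout elsewhere, mirror the load_snapshot branch above
- # (a sketch, assuming the same Model class, layer, and n_neurons used in training):
- #   net = Model(inceptionv1(pretrained=True), layer, n_neurons, torch.device('cpu'))
- #   net.load_state_dict(torch.load(snapshot_path, map_location='cpu'))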