### Model training and testing

This script fits the neural encoding model for one monkey/ROI. Depending on its arguments it either grid-searches training hyperparameters across a set of candidate InceptionV1 layers, or trains the readout for a single named layer, saving the learned weights and validation correlations.

```python
import os
import pickle
import sys
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch

from lucent.util import set_seed
from lucent.modelzoo import inceptionv1
from helper import smoothing_laplacian_loss, sparsity_loss
from neural_model import Model

# command-line arguments
gen_path = sys.argv[1]
monkey_name = sys.argv[2]
roi_name = sys.argv[3]
layer = sys.argv[4]
# arguments beyond the sixth are the candidate layers for the grid search,
# passed as a comma-separated list
if len(sys.argv) <= 6:
    layers = []
else:
    layers = [item.replace(',', '').replace(']', '') for item in sys.argv[6:]]

# the fourth argument doubles as a mode switch: 'False' runs the grid search
# over all candidate layers, 'True' grid-searches and then trains the winner,
# and anything else names the single layer to train with stored parameters
if layer == 'False':
    grid_search = True
    train_single_layer = False
elif layer == 'True':
    grid_search = True
    train_single_layer = True
else:
    grid_search = False

# hyperparameters
seed = 0
nb_epochs = 1000
save_epochs = 100
grid_epochs = 120
grid_save_epochs = grid_epochs
batch_size = 100
lr_decay_gamma = 1/3
lr_decay_step = 3*save_epochs
backbone = 'inception_v1'
sparse_weights = [1e-8]
#learning_rates = [1e-2,1e-3,1e-4]
#smooth_weights = [1e-3,1e-2,0.1,0.2]
#weight_decays = [1e-4,1e-3,1e-2,1e-1]
learning_rates = [1e-6, 1e-5, 1e-4]
smooth_weights = [1e-6, 1e-5, 1e-4, 1e-3]
weight_decays = [1e-7, 1e-6, 1e-5, 1e-4]

# paths + filenames
data_filename = gen_path + monkey_name + '/data_THINGS_array_' + roi_name + '_v1.pkl'
grid_filename = gen_path + monkey_name + '/snapshots/grid_search_array_' + roi_name + '.pkl'
snapshot_path = gen_path + monkey_name + '/snapshots/array' + roi_name + '_neural_model.pt'
current_datetime = time.strftime("%Y-%m-%d_%H_%M_%S", time.gmtime())
loss_plot_path = f'./training_data/training_loss_classifier_{current_datetime}.png'

# load_snapshot = True
load_snapshot = False
GPU = torch.cuda.is_available()
if GPU:
    torch.cuda.set_device(0)
device = torch.device("cuda:0" if GPU else "cpu")

# load data
with open(data_filename, "rb") as f:
    cc = pickle.load(f)
train_data = cc['train_data']
val_data = cc['val_data']
img_data = cc['img_data']
val_img_data = cc['val_img_data']
n_neurons = train_data.shape[1]
```
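Both training loops below optimize only two tensors of the wrapped network, `net.module.w_s` and `net.module.w_f`. The `Model` class itself lives in `neural_model` and isn't reproduced here; as rough orientation, a minimal sketch of a factorized readout of this kind, with all names and shapes being assumptions rather than the project's actual implementation:

```python
import torch

class FactorizedReadout(torch.nn.Module):
    """Hypothetical stand-in for neural_model.Model: a per-neuron spatial
    mask (w_s) and feature-weight vector (w_f) read out of a fixed backbone
    activation. Shapes and names are assumptions for illustration."""

    def __init__(self, n_neurons, n_channels, fmap_size, device):
        super().__init__()
        self.w_s = torch.nn.Parameter(torch.empty(n_neurons, fmap_size, fmap_size, device=device))
        self.w_f = torch.nn.Parameter(torch.empty(n_neurons, n_channels, device=device))

    def initialize(self):
        # Xavier initialization, matching the 'Initialized using Xavier' prints
        torch.nn.init.xavier_normal_(self.w_s)
        torch.nn.init.xavier_normal_(self.w_f)

    def forward(self, feats):
        # feats: backbone activations at the tapped layer, (batch, C, H, W)
        pooled = torch.einsum('bchw,nhw->bnc', feats, self.w_s)  # spatial pooling per neuron
        return torch.einsum('bnc,nc->bn', pooled, self.w_f)      # feature weighting per neuron

readout = FactorizedReadout(n_neurons=4, n_channels=8, fmap_size=5, device='cpu')
readout.initialize()
print(readout(torch.randn(2, 8, 5, 5)).shape)  # -> torch.Size([2, 4])
```

This factorization is what makes the two regularizers in the loss meaningful: the smoothness penalty applies to the spatial masks, the sparsity penalty to the feature weights.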
If `grid_search` is set, every combination of candidate layer, learning rate, smoothness weight, sparsity weight, and weight decay is trained for `grid_epochs` epochs and scored by the median per-neuron correlation on the validation set; otherwise the best stored parameters for the requested layer are read back from an earlier grid run.

```python
######
# Grid search
######
iter1_done = False
if grid_search:
    params = []
    val_corrs = []
    for layer in layers:
        print('======================')
        print('Layer: ' + layer)
        for learning_rate in learning_rates:
            for smooth_weight in smooth_weights:
                for sparse_weight in sparse_weights:
                    for weight_decay in weight_decays:
                        set_seed(seed)

                        # free the previous configuration's model before building the next one
                        if iter1_done:
                            del pretrained_model, net, criterion, optimizer, scheduler
                        iter1_done = True

                        # model, wrapped in DataParallel and moved to the GPU when available
                        pretrained_model = inceptionv1(pretrained=True)
                        net = Model(pretrained_model, layer, n_neurons, device)
                        net.initialize()
                        print('Initialized using Xavier')
                        net = torch.nn.DataParallel(net)
                        if GPU:
                            net = net.cuda()
                        else:
                            print("Training on CPU")

                        # loss function
                        criterion = torch.nn.MSELoss()

                        # optimizer and lr scheduler over the two readout tensors
                        optimizer = torch.optim.Adam(
                            [net.module.w_s, net.module.w_f],
                            lr=learning_rate,
                            weight_decay=weight_decay)
                        scheduler = torch.optim.lr_scheduler.StepLR(
                            optimizer,
                            step_size=lr_decay_step,
                            gamma=lr_decay_gamma)

                        cum_loss = 0.0
                        for epoch in range(grid_epochs):
                            if GPU:
                                torch.cuda.empty_cache()

                            # sample a random training batch
                            batch_idx = np.random.choice(train_data.shape[0], size=batch_size, replace=False)
                            neural_batch = torch.tensor(train_data[batch_idx, :], device=device)
                            img_batch = img_data[batch_idx, :].to(device)

                            # forward + backward + optimize
                            optimizer.zero_grad()
                            outputs = net(img_batch).squeeze()
                            loss = criterion(outputs, neural_batch.float()) \
                                + smoothing_laplacian_loss(net.module.w_s, device, weight=smooth_weight) \
                                + sparse_weight * torch.norm(net.module.w_f, 1)
                            loss.backward()
                            optimizer.step()
                            scheduler.step()  # step the schedule after the optimizer (PyTorch >= 1.1 order)
                            cum_loss += loss.item()

                            # score this configuration on held-out validation data
                            if epoch % grid_save_epochs == grid_save_epochs - 1:
                                with torch.no_grad():
                                    val_outputs = net(val_img_data).squeeze()
                                corrs = [np.corrcoef(val_outputs[:, n].cpu().numpy(), val_data[:, n])[1, 0]
                                         for n in range(val_outputs.shape[1])]
                                val_corr = np.median(corrs)
                                mean_loss = cum_loss / float(grid_save_epochs)
                                cum_loss = 0.0

                                params.append([layer, learning_rate, smooth_weight, sparse_weight, weight_decay])
                                val_corrs.append(val_corr)

                                print('======')
                                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                                print(f'learning rate: {learning_rate}')
                                print(f'smooth weight: {smooth_weight}')
                                print(f'sparse weight: {sparse_weight}')
                                print(f'weight decay: {weight_decay}')
                                print(f'mean loss: {mean_loss:.3f}')
                                print(f'Validation corr: {val_corr:.3f}')

    # extract the winning parameters
    val_corrs = np.array(val_corrs)
    best_idx = int(np.argmax(val_corrs))
    layer, learning_rate, smooth_weight, sparse_weight, weight_decay = params[best_idx]

    # print the winning parameters
    print('======================')
    print('Best layer is: ' + layer)
    print('Best learning rate is: ' + str(learning_rate))
    print('Best smooth weight is: ' + str(smooth_weight))
    print('Best sparse weight is: ' + str(sparse_weight))
    print('Best weight decay is: ' + str(weight_decay))

    # save the grid results
    output = {'val_corrs': val_corrs, 'params': params}
    with open(grid_filename, "wb") as f:
        pickle.dump(output, f)
else:
    # no grid search: read back the best stored parameters for this layer
    print('======================')
    print('Layer: ' + layer)
    with open(grid_filename, "rb") as f:
        cc = pickle.load(f)
    val_corrs = np.array(cc['val_corrs'])
    params = cc['params']
    all_layers = np.asarray(params)[:, 0]
    best_corr = np.max(val_corrs[all_layers == layer])
    good_params = params[int(np.where(val_corrs == best_corr)[0][0])]
    learning_rate = good_params[1]
    smooth_weight = good_params[2]
    sparse_weight = good_params[3]
    weight_decay = good_params[4]
    train_single_layer = True
```
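Because the grid results are pickled as `{'val_corrs': ..., 'params': ...}` with each `params` entry laid out as `[layer, learning_rate, smooth_weight, sparse_weight, weight_decay]`, the winning configuration can also be recovered offline, e.g.:

```python
import pickle
import numpy as np

# grid_filename as defined above
with open(grid_filename, "rb") as f:
    grid = pickle.load(f)

corrs = np.asarray(grid['val_corrs'])
best_layer, best_lr, best_sw, best_spw, best_wd = grid['params'][int(np.argmax(corrs))]
print(f'best: {best_layer} (median validation corr {corrs.max():.3f})')
```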
With hyperparameters fixed, the readout for the selected layer is retrained from scratch (or from a snapshot) for the full `nb_epochs`:

```python
######
# Final training!!
######
if train_single_layer:
    snapshot_path = gen_path + monkey_name + '/snapshots/' + monkey_name + '_' + roi_name + '_' + layer + '_neural_model.pt'
    layer_filename = gen_path + 'val_corr_' + monkey_name + '_' + roi_name + '_' + layer + '.pkl'

    # model, wrapped in DataParallel and moved to the GPU when available
    set_seed(seed)
    pretrained_model = inceptionv1(pretrained=True)
    net = Model(pretrained_model, layer, n_neurons, device)
    if load_snapshot:
        net.load_state_dict(torch.load(
            snapshot_path,
            map_location=lambda storage, loc: storage
        ))
        print('Loaded snap ' + snapshot_path)
    else:
        net.initialize()
        print('Initialized using Xavier')
    net = torch.nn.DataParallel(net)
    if GPU:
        net = net.cuda()
        print("Training on {} GPU's".format(torch.cuda.device_count()))
    else:
        print("Training on CPU")

    # loss function
    criterion = torch.nn.MSELoss()

    # optimizer and lr scheduler
    optimizer = torch.optim.Adam(
        [net.module.w_s, net.module.w_f],
        lr=learning_rate,
        weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=lr_decay_step,
        gamma=lr_decay_gamma)

    # figure for the running training loss
    fig = plt.figure()
    axis = fig.add_subplot(111)
    axis.set_xlabel('epoch')
    axis.set_ylabel('loss')
    axis.set_yscale('log')
    plt_line, = axis.plot([], [])
    os.makedirs(os.path.dirname(loss_plot_path), exist_ok=True)
```
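Both loops minimize the same objective: MSE on the neural responses, plus a Laplacian smoothness penalty on `w_s` and an L1 sparsity penalty on `w_f`. `smoothing_laplacian_loss` comes from the project's `helper` module and is not reproduced here; a minimal sketch of what such a penalty typically computes, assuming `w_s` holds one (H, W) mask per neuron (an illustration, not the actual helper):

```python
import torch
import torch.nn.functional as F

def laplacian_smoothness(w_s, device, weight=1e-4):
    # discrete 2-D Laplacian: large responses mean neighbouring weights differ
    lap = torch.tensor([[0., 1., 0.],
                        [1., -4., 1.],
                        [0., 1., 0.]], device=device).view(1, 1, 3, 3)
    # treat each neuron's mask as a single-channel image
    x = w_s.view(-1, 1, w_s.shape[-2], w_s.shape[-1])
    return weight * torch.norm(F.conv2d(x, lap, padding=1))
```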
Each epoch samples a fresh random batch; every `save_epochs` epochs the model is scored on the validation set and the loss curve is appended to the figure. At the end, the per-neuron validation correlations and the trained weights are written to disk:

```python
    cum_time = 0.0
    cum_loss = 0.0
    for epoch in range(nb_epochs):
        # sample a random training batch
        batch_idx = np.random.choice(train_data.shape[0], size=batch_size, replace=False)
        neural_batch = torch.tensor(train_data[batch_idx, :], device=device)
        img_batch = img_data[batch_idx, :].to(device)

        # forward + backward + optimize
        tic = time.time()
        optimizer.zero_grad()
        outputs = net(img_batch).squeeze()
        loss = criterion(outputs, neural_batch.float()) \
            + smoothing_laplacian_loss(net.module.w_s, device, weight=smooth_weight) \
            + sparse_weight * torch.norm(net.module.w_f, 1)
        loss.backward()
        optimizer.step()
        scheduler.step()  # step the schedule after the optimizer (PyTorch >= 1.1 order)
        toc = time.time()

        cum_time += toc - tic
        cum_loss += loss.item()

        # output & test
        if epoch % save_epochs == save_epochs - 1:
            val_neural_data = torch.tensor(val_data, device=device).float()
            with torch.no_grad():
                val_outputs = net(val_img_data).squeeze()
                val_loss = criterion(val_outputs, val_neural_data).item()
            corrs = [np.corrcoef(val_outputs[:, n].cpu().numpy(), val_data[:, n])[1, 0]
                     for n in range(val_outputs.shape[1])]
            val_corr = np.median(corrs)
            tic_test = time.time()

            # print and plot time / loss
            print('======================')
            print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
            no_epoch = epoch / (save_epochs - 1)
            mean_time = cum_time / float(save_epochs)
            mean_loss = cum_loss / float(save_epochs)
            cum_time = 0.0
            cum_loss = 0.0
            print(f'epoch {int(epoch)}/{nb_epochs} mean time: {mean_time:.3f}s')
            print(f'epoch {int(epoch)}/{nb_epochs} mean loss: {mean_loss:.3f}')
            print(f'epoch {int(no_epoch)} validation loss: {val_loss:.3f}')
            print(f'epoch {int(no_epoch)} validation corr: {val_corr:.3f}')

            # append the running loss to the figure and save it
            plt_line.set_xdata(np.append(plt_line.get_xdata(), no_epoch))
            plt_line.set_ydata(np.append(plt_line.get_ydata(), mean_loss))
            axis.relim()
            axis.autoscale_view()
            fig.savefig(loss_plot_path)
            print('======================')
            print('Test time: ', time.time() - tic_test)

    # save the final per-neuron validation correlations
    output = {'val_corrs': corrs}
    with open(layer_filename, "wb") as f:
        pickle.dump(output, f)

    # save the weights, we're done!
    os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
    torch.save(net.module.state_dict(), snapshot_path)
```
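Since only `net.module.state_dict()` is saved, downstream analyses can restore the trained readout without the DataParallel wrapper; a sketch, reusing the `Model` constructor arguments from above (`layer`, `n_neurons`, and `snapshot_path` as defined in the script):

```python
import torch
from lucent.modelzoo import inceptionv1
from neural_model import Model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Model(inceptionv1(pretrained=True), layer, n_neurons, device)
net.load_state_dict(torch.load(snapshot_path, map_location=device))
net.eval()  # ready for testing / feature visualization
```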