
train_neural_model.py 15 KB

### Model training and testing
import os
import pickle
import sys
import time
import matplotlib
import matplotlib.pyplot as plt
import h5py
import numpy as np
import torch
from torchvision import datasets, transforms
import torchvision.models as models
from torch import sigmoid
from lucent.util import set_seed
from lucent.modelzoo import inceptionv1, util
from helper import smoothing_laplacian_loss, sparsity_loss
from neural_model import Model
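# Command-line usage (inferred from the argument parsing below):
#   python train_neural_model.py <gen_path> <monkey_name> <roi_name> <layer> [layer list ...]
# <layer> is either 'False' (run the grid search only), 'True' (grid search
# followed by final training), or the name of a backbone layer (final training
# with parameters loaded from an earlier grid search). Candidate layers for the
# grid search are read from sys.argv[6:], with commas and brackets stripped.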
gen_path = sys.argv[1]
monkey_name = sys.argv[2]
roi_name = sys.argv[3]
layer = sys.argv[4]
if len(sys.argv) == 4:
    layers = []
else:
    layers = [item.replace(',', '').replace(']', '') for item in sys.argv[6:]]
if layer == 'False':
    grid_search = True
    train_single_layer = False
elif layer == 'True':
    train_single_layer = True
    grid_search = True
else:
    grid_search = False
# hyperparameters
seed = 0
nb_epochs = 1000
save_epochs = 100
grid_epochs = 120
grid_save_epochs = grid_epochs
batch_size = 100
lr_decay_gamma = 1/3
lr_decay_step = 3*save_epochs
backbone = 'inception_v1'
sparse_weights = [1e-8]
#learning_rates = [1e-2,1e-3,1e-4]
#smooth_weights = [1e-3,1e-2,0.1,0.2]
#weight_decays = [1e-4,1e-3,1e-2,1e-1]
learning_rates = [1e-6,1e-5,1e-4]
smooth_weights = [1e-6,1e-5,1e-4,1e-3]
weight_decays = [1e-7,1e-6,1e-5,1e-4]
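# The grid above spans 3 learning rates x 4 smoothness weights x 1 sparsity
# weight x 4 weight decays = 48 hyperparameter combinations per candidate layer.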
# paths + filenames
data_filename = gen_path + monkey_name + '/data_THINGS_array_' + roi_name + '_v1.pkl'
grid_filename = gen_path + monkey_name + '/snapshots/grid_search_array_' + roi_name + '.pkl'
snapshot_path = gen_path + monkey_name + '/snapshots/array' + roi_name + '_neural_model.pt'
current_datetime = time.strftime("%Y-%m-%d_%H_%M_%S", time.gmtime())
loss_plot_path = f'./training_data/training_loss_classifier_{current_datetime}.png'
# load_snapshot = True
load_snapshot = False
GPU = torch.cuda.is_available()
if GPU:
    torch.cuda.set_device(0)
snapshot_pattern = gen_path + monkey_name + f'/snapshots/neural_model_{backbone}_{layer}.pt'  # f-string prefix was missing
# load data
with open(data_filename, "rb") as f:
    cc = pickle.load(f)
train_data = cc['train_data']
val_data = cc['val_data']
img_data = cc['img_data']
val_img_data = cc['val_img_data']
n_neurons = train_data.shape[1]
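# Assumed data layout (inferred from the indexing in this script): train_data
# and val_data are stimulus-by-neuron response arrays, img_data and
# val_img_data are the matching image tensors, so axis 1 of train_data gives
# the number of recorded neurons.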
######
# Grid search:
######
iter1_done = False
if grid_search:
    params = []
    val_corrs = []
    for layer in layers:
        print('======================')
        print('Backbone: ' + layer)
        for learning_rate in learning_rates:
            for smooth_weight in smooth_weights:
                for sparse_weight in sparse_weights:
                    for weight_decay in weight_decays:
                        set_seed(seed)
                        # free the previous combination's model and optimizer state
                        if iter1_done:
                            del pretrained_model
                            del net
                            del criterion
                            del optimizer
                            del scheduler
                        iter1_done = True
                        # model, wrapped in DataParallel and moved to GPU
                        pretrained_model = inceptionv1(pretrained=True)
                        if GPU:
                            net = Model(pretrained_model, layer, n_neurons,
                                        torch.device("cuda:0" if GPU else "cpu"))
                            net.initialize()
                            net = torch.nn.DataParallel(net)
                            net = net.cuda()
                        else:
                            net = Model(pretrained_model, layer, n_neurons,
                                        torch.device("cuda:0" if GPU else "cpu"))
                            net.initialize()
                            print('Initialized using Xavier')
                            net = torch.nn.DataParallel(net)
                            print("Training on CPU")
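                        # Model (from the local neural_model module) appears to be
                        # a factorized readout on top of the chosen InceptionV1
                        # layer, with w_s as spatial weights and w_f as feature
                        # weights -- an inference from the parameters optimized
                        # and the penalties applied below, not from its source.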
                        # loss function
                        criterion = torch.nn.MSELoss()
                        # optimizer and lr scheduler; only the readout weights are optimized
                        optimizer = torch.optim.Adam(
                            [net.module.w_s, net.module.w_f],
                            lr=learning_rate,
                            weight_decay=weight_decay)
                        scheduler = torch.optim.lr_scheduler.StepLR(
                            optimizer,
                            step_size=lr_decay_step,
                            gamma=lr_decay_gamma)
                        cum_time = 0.0
                        cum_time_data = 0.0
                        cum_loss = 0.0
                        # dummy step so the scheduler.step() at the top of the first
                        # epoch is not issued before any optimizer.step()
                        optimizer.step()
                        for epoch in range(grid_epochs):
                            # adjust learning rate
                            scheduler.step()
                            torch.cuda.empty_cache()
                            # get the inputs & wrap them in tensor
                            batch_idx = np.random.choice(np.arange(train_data.shape[0]),
                                                         size=batch_size, replace=False).astype('int')
                            if GPU:
                                neural_batch = torch.tensor(train_data[batch_idx, :]).cuda()
                                img_batch = img_data[batch_idx, :].cuda()
                            else:
                                neural_batch = torch.tensor(train_data[batch_idx, :])
                                img_batch = img_data[batch_idx, :]
                            # forward + backward + optimize
                            tic = time.time()
                            optimizer.zero_grad()
                            outputs = net(img_batch).squeeze()
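                            # Objective: MSE between predicted and recorded
                            # responses, plus a Laplacian smoothness penalty on
                            # the spatial weights w_s and an L1 sparsity penalty
                            # on the feature weights w_f.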
                            loss = criterion(outputs, neural_batch.float()) \
                                + smoothing_laplacian_loss(net.module.w_s,
                                                           torch.device("cuda:0" if GPU else "cpu"),
                                                           weight=smooth_weight) \
                                + sparse_weight * torch.norm(net.module.w_f, 1)
                            loss.backward()
                            optimizer.step()
                            toc = time.time()
                            cum_time += toc - tic
                            cum_loss += loss.data.cpu()
                            # output & test (correlation on the current training batch)
                            if epoch % grid_save_epochs == grid_save_epochs - 1:
                                torch.cuda.empty_cache()
                                if GPU:
                                    neural_batch = torch.tensor(train_data[batch_idx, :]).cuda()
                                    img_batch = img_data[batch_idx, :].cuda()
                                else:
                                    neural_batch = torch.tensor(train_data[batch_idx, :])
                                    img_batch = img_data[batch_idx, :]
                                val_outputs = net(img_batch).squeeze()
                                corrs = []
                                for n in range(val_outputs.shape[1]):
                                    corrs.append(np.corrcoef(val_outputs[:, n].cpu().detach().numpy(),
                                                             neural_batch[:, n].squeeze().cpu().detach().numpy())[1, 0])
                                val_corr = np.median(corrs)
                                # print and plot time / loss
                                print('======')
                                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                                no_epoch = epoch / (grid_save_epochs - 1)
                                mean_loss = cum_loss / float(grid_save_epochs)
                                if mean_loss.is_cuda:
                                    mean_loss = mean_loss.data.cpu()
                                cum_time = 0.0
                                cum_loss = 0.0
                                params.append([layer, learning_rate, smooth_weight, sparse_weight, weight_decay])
                                val_corrs.append(val_corr)
                                #print('======================')
                                print(f'learning rate: {learning_rate}')
                                print(f'smooth weight: {smooth_weight}')
                                print(f'sparse weight: {sparse_weight}')
                                print(f'weight decay: {weight_decay}')
                                print(f'Validation corr: {val_corr:.3f}')
                                #print('======')
    # extract winning params
    val_corrs = np.array(val_corrs)
    best_idx = int(np.argmax(val_corrs))
    layer, learning_rate, smooth_weight, sparse_weight, weight_decay = params[best_idx]
    # print winning params
    print('======================')
    print('Best backbone is: ' + layer)
    print('Best learning rate is: ' + str(learning_rate))
    print('Best smooth weight is: ' + str(smooth_weight))
    print('Best sparse weight is: ' + str(sparse_weight))
    print('Best weight decay is: ' + str(weight_decay))
    # save params
    output = {'val_corrs': val_corrs, 'params': params}
    with open(grid_filename, "wb") as f:
        pickle.dump(output, f)
else:
    print('======================')
    print('Backbone: ' + layer)
    with open(grid_filename, "rb") as f:
        cc = pickle.load(f)
    val_corrs = cc['val_corrs']
    params = cc['params']
    all_layers = np.asarray(params)[:, 0]
    val_corrs = np.array(val_corrs)
    # pick the best-scoring grid-search run for the requested layer
    max_layer = np.max(val_corrs[all_layers == layer])
    good_params = params[int(np.where(val_corrs == max_layer)[0][0])]
    learning_rate = good_params[1]
    smooth_weight = good_params[2]
    sparse_weight = good_params[3]
    weight_decay = good_params[4]
    train_single_layer = True
######
# Final training!!
######
if train_single_layer:
    snapshot_path = gen_path + monkey_name + '/snapshots/' + monkey_name + '_' + roi_name + '_' + layer + '_neural_model.pt'
    layer_filename = gen_path + 'val_corr_' + monkey_name + '_' + roi_name + '_' + layer + '.pkl'
    # model, wrapped in DataParallel and moved to GPU
    set_seed(seed)
    pretrained_model = inceptionv1(pretrained=True)
    if GPU:
        net = Model(pretrained_model, layer, n_neurons,
                    torch.device("cuda:0" if GPU else "cpu"))
        if load_snapshot:
            net.load_state_dict(torch.load(
                snapshot_path,
                map_location=lambda storage, loc: storage
            ))
            print('Loaded snap ' + snapshot_path)
        else:
            net.initialize()
            print('Initialized using Xavier')
        net = torch.nn.DataParallel(net)
        net = net.cuda()
        print("Training on {} GPUs".format(torch.cuda.device_count()))
    else:
        net = Model(pretrained_model, layer, n_neurons,
                    torch.device("cuda:0" if GPU else "cpu"))
        if load_snapshot:
            net.load_state_dict(torch.load(
                snapshot_path,
                map_location=lambda storage, loc: storage
            ))
            print('Loaded snap ' + snapshot_path)
        else:
            net.initialize()
            print('Initialized using Xavier')
        net = torch.nn.DataParallel(net)
        print("Training on CPU")
    # loss function
    criterion = torch.nn.MSELoss()
    # optimizer and lr scheduler
    optimizer = torch.optim.Adam(
        [net.module.w_s, net.module.w_f],
        lr=learning_rate,
        weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=lr_decay_step,
        gamma=lr_decay_gamma)
    # figure for loss function
    fig = plt.figure()
    axis = fig.add_subplot(111)
    axis.set_xlabel('epoch')
    axis.set_ylabel('loss')
    axis.set_yscale('log')
    plt_line, = axis.plot([], [])
    cum_time = 0.0
    cum_time_data = 0.0
    cum_loss = 0.0
    # dummy step, as in the grid search above
    optimizer.step()
    for epoch in range(nb_epochs):
        # adjust learning rate
        scheduler.step()
        # get the inputs & wrap them in tensor
        batch_idx = np.random.choice(np.arange(train_data.shape[0]),
                                     size=batch_size, replace=False).astype('int')
        if GPU:
            neural_batch = torch.tensor(train_data[batch_idx, :]).cuda()
            val_neural_data = torch.tensor(val_data).cuda()
            img_batch = img_data[batch_idx, :].cuda()
        else:
            neural_batch = torch.tensor(train_data[batch_idx, :])
            val_neural_data = torch.tensor(val_data)
            img_batch = img_data[batch_idx, :]
        # forward + backward + optimize
        tic = time.time()
        optimizer.zero_grad()
        outputs = net(img_batch).squeeze()
        loss = criterion(outputs, neural_batch.float()) \
            + smoothing_laplacian_loss(net.module.w_s,
                                       torch.device("cuda:0" if GPU else "cpu"),
                                       weight=smooth_weight) \
            + sparse_weight * torch.norm(net.module.w_f, 1)
        loss.backward()
        optimizer.step()
        toc = time.time()
        cum_time += toc - tic
        cum_loss += loss.data.cpu()
        # output & test
        if epoch % save_epochs == save_epochs - 1:
            val_outputs = net(val_img_data).squeeze()
            val_loss = criterion(val_outputs, val_neural_data.float())
            corrs = []
            for n in range(val_outputs.shape[1]):
                corrs.append(np.corrcoef(val_outputs[:, n].cpu().detach().numpy(),
                                         val_data[:, n])[1, 0])
            val_corr = np.median(corrs)
            tic_test = time.time()
            # print and plot time / loss
            print('======================')
            print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
            no_epoch = epoch / (save_epochs - 1)
            mean_time = cum_time / float(save_epochs)
            mean_loss = cum_loss / float(save_epochs)
            if mean_loss.is_cuda:
                mean_loss = mean_loss.data.cpu()
            cum_time = 0.0
            cum_loss = 0.0
            print(f'epoch {int(epoch)}/{nb_epochs} mean time: {mean_time:.3f}s')
            print(f'epoch {int(epoch)}/{nb_epochs} mean loss: {mean_loss:.3f}')
            print(f'epoch {int(no_epoch)} validation loss: {val_loss:.3f}')
            print(f'epoch {int(no_epoch)} validation corr: {val_corr:.3f}')
            plt_line.set_xdata(np.append(plt_line.get_xdata(), no_epoch))
            plt_line.set_ydata(np.append(plt_line.get_ydata(), mean_loss))
            axis.relim()
            axis.autoscale_view()
            fig.savefig(loss_plot_path)
            print('======================')
            print('Test time: ', time.time() - tic_test)
    # save final val corr
    output = {'val_corrs': corrs}
    with open(layer_filename, "wb") as f:
        pickle.dump(output, f)
    # save the weights, we're done!
    os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
    torch.save(net.module.state_dict(), snapshot_path)