Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

npc.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. ### Functions for plotting, storing and fitting MEIs
  2. import torch
  3. import torchvision.models as models
  4. import torchvision.transforms.functional as F
  5. import torch.nn.functional as nnf
  6. from torchvision import datasets, transforms
  7. from lucent.optvis import render, param, transform, objectives
  8. from lucent.modelzoo import inceptionv1, util, inceptionv1_avgPool
  9. import numpy as np
  10. import pickle
  11. from scipy.io import savemat,loadmat
  12. from pathlib import Path
  13. from os import listdir
  14. from os.path import isfile, join
  15. import matplotlib.pyplot as plt
  16. from matplotlib.patches import Ellipse
  17. import shapely.affinity
  18. import shapely.geometry
  19. from skimage.draw import polygon
  20. from skimage.morphology import convex_hull_image
  21. from skimage.filters import gaussian as smoothing
  22. from PIL import Image
  23. import math
  24. from tqdm import tqdm
  25. from helper import gaussian, moments, fitgaussian, load, gabor_patch, occlude_pic
  26. from neural_model import Model
  27. def load_trained_model(gen_path,name,roi,layer = None):
  28. if layer is None:
  29. data_filename = gen_path + name + '/snapshots/grid_search_array'+ roi + '.pkl'
  30. f = open(data_filename,"rb")
  31. cc = pickle.load(f)
  32. val_corrs = cc['val_corrs']
  33. params = cc['params']
  34. val_corrs = np.array(val_corrs)
  35. layer = params[np.where(val_corrs==val_corrs.max())[0][0].astype('int')][0]
  36. n_neurons = 0
  37. layer_filename = gen_path + 'val_corr_'+ name + roi + '_' + layer + '.pkl'
  38. if isfile(layer_filename):
  39. f = open(layer_filename,"rb")
  40. cc = pickle.load(f)
  41. val_corrs = cc['val_corrs']
  42. n_neurons = len(val_corrs)
  43. if n_neurons == 0:
  44. data_filename = gen_path + name + '/data_THINGS_array'+ roi +'_v1.pkl'
  45. f = open(data_filename,"rb")
  46. cc = pickle.load(f)
  47. val_data = cc['val_data']
  48. n_neurons = val_data.shape[1]
  49. pretrained_model = inceptionv1(pretrained=True)
  50. trained_model = Model(pretrained_model,layer,n_neurons,device='cpu')
  51. trained_model.load_state_dict(torch.load(gen_path + name + '/snapshots/'+ name + roi + '_' + layer + '_neural_model.pt',map_location=torch.device('cpu')))
  52. return trained_model,n_neurons#,good_neurons
  53. def plot_layer_corrs(gen_path,name,layers,roi):
  54. all_corrs = []
  55. for layer in layers:
  56. layer_filename = gen_path + 'val_corr_'+ name + roi + '_' + layer + '.pkl'
  57. f = open(layer_filename,"rb")
  58. cc = pickle.load(f)
  59. val_corrs = cc['val_corrs']
  60. all_corrs.append(val_corrs)
  61. all_corrs = np.array(all_corrs)
  62. fig1, ax = plt.subplots()
  63. c = '#3399ff'
  64. ax.set_title('Layers:')
  65. ax.boxplot(all_corrs.transpose(), labels=layers, notch=True,
  66. widths=.5,showcaps=False, whis=[2.5,97.5],patch_artist=True,
  67. showfliers=False,boxprops=dict(facecolor=c, color=c),
  68. capprops=dict(color=c),whiskerprops=dict(color=c),
  69. flierprops=dict(color=c, markeredgecolor=c),medianprops=dict(color='w',linewidth=2))
  70. ax.set_ylabel('Cross-validated Pearson r')
  71. ax.set_ylim([0,1])
  72. ax.spines['top'].set_visible(False)
  73. ax.spines['right'].set_visible(False)
  74. plt.show()
  75. temp = np.median(all_corrs,axis=1)
  76. layer = layers[np.where(temp==temp.max())[0][0].astype('int')]
  77. return all_corrs, layer
  78. def compute_val_corrs(gen_path,name,trained_model,roi):
  79. data_filename = gen_path + name + '/data_THINGS_array'+ roi +'_v1.pkl'
  80. f = open(data_filename,"rb")
  81. cc = pickle.load(f)
  82. val_img_data = cc['val_img_data']
  83. val_outputs = trained_model(val_img_data).squeeze()
  84. val_data = cc['val_data']
  85. corrs = []
  86. for n in range(val_outputs.shape[1]):
  87. corrs.append(np.corrcoef(val_outputs[:,n].cpu().detach().numpy(),val_data[:,n])[1,0])
  88. return corrs
  89. def good_neurons(trained_model,n_neurons,corrs,make_plots=True):
  90. z = 1
  91. all_good_rfs = []
  92. goods = []
  93. for n in range(n_neurons):
  94. trained_model_rf = np.abs(np.reshape(trained_model.w_s[n].squeeze().detach().cpu().numpy(),
  95. [np.sqrt(trained_model.w_s.shape[2]).astype('int'),np.sqrt(trained_model.w_s.shape[2]).astype('int')]))
  96. if corrs[n] > 0:
  97. z += 1
  98. goods.append(n)
  99. trained_model_rf_norm = (trained_model_rf+np.abs(trained_model_rf.min()))/(np.abs(trained_model_rf.max())+np.abs(trained_model_rf.min()))
  100. all_good_rfs.append(trained_model_rf_norm)
  101. if make_plots:
  102. print(z,corrs[n])
  103. plt.imshow(trained_model_rf, cmap='seismic')
  104. plt.colorbar()
  105. plt.show()
  106. plt.plot(trained_model.w_f[0,n].squeeze().detach().cpu().numpy())
  107. plt.show()
  108. goods = np.array(goods)
  109. all_good_rfs = np.array(all_good_rfs)
  110. return goods,all_good_rfs
  111. def gaussian_RFs(gen_path,all_good_rfs,goods,corrs,all_corrs,pixperdeg,stim_size,mask_size,trained_model,name,roi,shift_x=0,shift_y=0):
  112. centrex = []
  113. centrey = []
  114. szx = []
  115. szy = []
  116. szdeg = []
  117. masks = []
  118. sparsity = []
  119. all_feats = []
  120. pixperdeg_reduced = all_good_rfs[0].shape[0]/all_good_rfs*pixperdeg
  121. scaling_f = stim_size/all_good_rfs[0].shape[0]
  122. scaling_mask = mask_size/all_good_rfs[0].shape[0]
  123. pixperdeg_reduced_mask = all_good_rfs[0].shape[0]/mask_size*pixperdeg
  124. fovea_x = stim_size/2 - shift_x
  125. fovea_y = stim_size/2 - shift_y
  126. ax = plt.gca()
  127. ax.set_title('Spatial RFs:')
  128. for n in range(len(goods)):
  129. # fit 2d Gaussian to W_s
  130. data = all_good_rfs[n]
  131. params = fitgaussian(data)
  132. fit = gaussian(*params)
  133. plt.imshow(data, cmap=plt.cm.gist_earth_r)
  134. (height, y, x, width_y, width_x) = params # x and y are shifted in img coords
  135. circle = Ellipse((x, y), 2*1.65*width_x, 2*1.65*width_y, edgecolor='r', facecolor='None', clip_on=True)
  136. ax.add_patch(circle)
  137. centrex.append(x*scaling_f-fovea_x)
  138. centrey.append(fovea_y-y*scaling_f)
  139. szx.append(2*1.65*width_x*scaling_f)
  140. szy.append(2*1.65*width_y*scaling_f)
  141. szdeg.append(np.mean((2*1.65*width_x/pixperdeg_reduced,2*1.65*width_y/pixperdeg_reduced)))
  142. # create binary mask summing up the estimated 2d Gaussians (95% CI)
  143. mask = np.zeros(shape=(mask_size,mask_size), dtype="bool")
  144. circ = shapely.geometry.Point((x*scaling_mask,y*scaling_mask)).buffer(1)
  145. ell = shapely.affinity.scale(circ, 1.65*width_x*scaling_mask+pixperdeg_reduced_mask, 1.65*width_y*scaling_mask+pixperdeg_reduced_mask)
  146. ell_coords = np.array(list(ell.exterior.coords))
  147. cc, rr = polygon(ell_coords[:,0], ell_coords[:,1], mask.shape)
  148. mask[rr,cc] = True
  149. masks.append(mask)
  150. feats = trained_model.w_f[0,goods[n]].squeeze().detach().cpu().numpy()
  151. chn_spars = (1 - (((np.sum(np.abs(feats)) / len(feats)) ** 2) / (np.sum(np.abs(feats) ** 2) / len(feats)))) / (1 - (1 / len(feats)))
  152. sparsity.append(chn_spars)
  153. all_feats.append(feats)
  154. plt.show()
  155. masks = np.array(masks)
  156. centrex = np.array(centrex)
  157. centrey = np.array(centrey)
  158. szx = np.array(szx)
  159. szy = np.array(szy)
  160. szdeg = np.array(szdeg)
  161. sparsity = np.array(sparsity)
  162. all_feats = np.array(all_feats)
  163. # plot:
  164. labels = ['Sparsity index','Frequency (neurons)','']
  165. c = '#3399ff'
  166. xline = [np.nanmedian(sparsity)]
  167. fig1, ax = plt.subplots()
  168. ax.hist(sparsity,bins=10,color=c)
  169. if len(xline) != 0:
  170. plt.axvline(xline[0], color='k', linestyle='dashed')
  171. if len(labels) != 0:
  172. ax.set_xlabel(labels[0])
  173. ax.set_ylabel(labels[1])
  174. ax.set_title(labels[2])
  175. ax.set_xlim([0,1])
  176. ax.spines['top'].set_visible(False)
  177. ax.spines['right'].set_visible(False)
  178. ax.set_title('Feature sparsness:')
  179. plt.show()
  180. all_corrs = np.moveaxis(all_corrs,0,-1)
  181. output = {'goods':goods,
  182. 'all_good_rfs':all_good_rfs,
  183. 'cross_val_corr':corrs,
  184. 'layers_cross_val_corr':all_corrs,
  185. 'centrex':centrex,
  186. 'centrey':centrey,
  187. 'szx':szx,
  188. 'szy':szy,
  189. 'szdeg':szdeg,
  190. 'all_feats':all_feats,
  191. 'sparsity':sparsity}
  192. savemat(gen_path + name + roi + '_good_rfs' + '.mat', output)
  193. return masks
  194. def generate_MEIS(gen_path,trained_model,name,goods=None,roi='',sign=1,batch_n=1):
  195. if sign == 1:
  196. sign_label = 'MEIs/'
  197. elif sign == -1:
  198. sign_label = 'MIIs/'
  199. else:
  200. raise Exception('The variable sign must be either 1 (for MEIs) or -1 (for MIIs)')
  201. Path(gen_path + sign_label + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/').mkdir(parents=True, exist_ok=True)
  202. if goods is None:
  203. goods = np.linspace(0, trained_model.n_neurons-1, trained_model.n_neurons).astype(int)
  204. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  205. trained_model.to(device).eval()
  206. all_transforms = [
  207. #transform.pad(2),
  208. transform.jitter(8),
  209. transform.random_scale([n/1000. for n in range(975, 1050)]),
  210. transform.random_rotate(list(range(-5,5))),
  211. transform.jitter(4)
  212. #transforms.Grayscale(3),
  213. ]
  214. batch_param_f = lambda: param.image(128, batch=batch_n, fft=True, decorrelate=True)
  215. cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
  216. for chn in range(len(goods)):
  217. obj = objectives.channel("output", int(goods[chn]))*sign
  218. _ = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
  219. show_inline=True, thresholds=(50,))
  220. for n in range(_[0].shape[0]):
  221. filename = gen_path + sign_label + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_' + str(n) + '.bmp'
  222. temp = _[0][n]#*np.repeat(np.expand_dims(masks[chn],-1),3,axis=-1)+np.repeat((1-np.expand_dims(masks[chn],-1))*.5,3,axis=-1)
  223. temp_pic = Image.fromarray(np.uint8((temp)*255))
  224. temp_pic = temp_pic.resize((500,500),Image.BILINEAR)
  225. temp_pic.save(filename)
  226. def generate_surround_MEIS(gen_path,trained_model,name,goods=None,roi=''):
  227. Path(gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/').mkdir(parents=True, exist_ok=True)
  228. if goods is None:
  229. goods = np.linspace(0, trained_model.n_neurons-1, trained_model.n_neurons).astype(int)
  230. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  231. trained_model.to(device).eval()
  232. all_transforms = [
  233. transform.pad(2),
  234. transform.jitter(4),
  235. transform.random_scale([n/1000. for n in range(975, 1050)]),
  236. transform.random_rotate(list(range(-5,5))),
  237. transforms.Grayscale(3),
  238. ]
  239. batch_param_f = lambda: param.image(128, batch=1, fft=True, decorrelate=True)
  240. cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
  241. for chn in range(len(goods)):
  242. obj = objectives.channel("output", int(goods[chn]))
  243. mei = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
  244. show_inline=True, thresholds=(50,))
  245. mask = trained_model.w_s[chn].squeeze()
  246. mask = torch.reshape(mask,(np.sqrt(mask.shape[0]).astype('int'),np.sqrt(mask.shape[0]).astype('int')))
  247. mask = np.abs(nnf.interpolate(mask[None,None,:],mei[0].shape[1]).squeeze().cpu().detach().numpy())
  248. mask = smoothing(mask>(np.std(mask)*3),1)
  249. mask = np.repeat(mask[:,:,np.newaxis],3,2)
  250. all_transforms = [
  251. occlude_pic(mei[0], mask,device),
  252. transform.pad(2),
  253. transform.jitter(4),
  254. transform.random_scale([n/1000. for n in range(975, 1050)]),
  255. transform.random_rotate(list(range(-5,5))),
  256. transforms.Grayscale(3),
  257. ]
  258. for i in range(2):
  259. if i:
  260. obj = objectives.channel("output", int(goods[chn]))
  261. filename = gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_positive_Surround.bmp'
  262. else:
  263. obj = -objectives.channel("output", int(goods[chn]))
  264. filename = gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_negative_Surround.bmp'
  265. _ = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
  266. show_inline=True, thresholds=(50,))
  267. temp = _[0][0]
  268. temp_pic = Image.fromarray(np.uint8((temp)*255))
  269. temp_pic = temp_pic.resize((500,500),Image.BILINEAR)
  270. temp_pic.save(filename)
  271. def ori_tuning(gen_path,stim_size,wavelengths,orientations,phases,contrasts,trained_model,name,goods,roi):
  272. # load data:
  273. data_path = gen_path + name + roi + '_good_rfs' + '.mat'
  274. data_dict = {}
  275. f = loadmat(data_path)
  276. for k, v in f.items():
  277. data_dict[k] = np.array(v)
  278. centrex = data_dict['centrex'].astype(int)[0]
  279. centrey = data_dict['centrey'].astype(int)[0]
  280. # define transforms:
  281. #transform = transforms.Compose([transforms.ToPILImage(),
  282. # transforms.Resize([224,224]),
  283. # transforms.ToTensor(),
  284. # transforms.Normalize(mean=[0.485, 0.456, 0.406],
  285. # std=[0.229, 0.224, 0.225])])
  286. transform = transforms.Compose([transforms.ToPILImage(),
  287. transforms.Resize([224,224]),
  288. transforms.ToTensor()])
  289. # RF center:
  290. px_x = np.median(centrex)
  291. px_y = np.median(centrey)
  292. # iterate gabors:
  293. stimuli = []
  294. all_params = []
  295. for w in range(len(wavelengths)):
  296. for o in range(len(orientations)):
  297. for p in range(len(phases)):
  298. for c in range(len(contrasts)):
  299. min_b = 0.5 - contrasts[c]/2
  300. max_b = 0.5 + contrasts[c]/2
  301. # create stimulus:
  302. stimulus = gabor_patch([stim_size,stim_size], pos_yx = [px_y,px_x],
  303. radius = stim_size, wavelength = wavelengths[w],
  304. orientation = orientations[o], phase = phases[p],
  305. min_brightness = min_b, max_brightness = max_b)
  306. stimuli.append(stimulus)
  307. all_params.append([wavelengths[w],orientations[o],phases[p],contrasts[c]])
  308. all_params = np.array(all_params)
  309. # get response:
  310. stimuli = np.array(stimuli).squeeze()
  311. all_resps_tot = []
  312. for s in range(stimuli.shape[0]):
  313. temp = transform(np.moveaxis(stimuli[s]*255,0,-1).squeeze().astype(np.uint8))#.to('cpu')
  314. all_resps_tot.append(trained_model(temp[None,:,:,:]).squeeze()[goods].detach().numpy())
  315. all_resps_tot = np.array(all_resps_tot)
  316. OI_sel = []
  317. BEST_params = []
  318. TUN_curves = []
  319. SFS_curves = []
  320. CNTR_curves = []
  321. for n in range(len(goods)):
  322. # reshape the overall response:
  323. all_resps = all_resps_tot[:,n].reshape(len(wavelengths),len(orientations),len(phases),len(contrasts))
  324. # compute the orientation index OI:
  325. norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
  326. max_idx = np.where(norm_resp == np.max(norm_resp))
  327. if len(max_idx)>1:
  328. max_idx = np.array(max_idx)
  329. max_idx = max_idx[:,-1]
  330. ortho_idx = np.where((180 * orientations / math.pi) == (180 * orientations / math.pi)[max_idx[1]] + 90)
  331. if np.array(ortho_idx).size == 0:
  332. ortho_idx = np.where((180 * orientations / math.pi) == (180 * orientations / math.pi)[max_idx[1]] - 90)
  333. chn_sel = (norm_resp.max() - norm_resp[max_idx[0],ortho_idx,max_idx[2],max_idx[3]]) / (norm_resp.max() + norm_resp[max_idx[0],ortho_idx,max_idx[2],max_idx[3]])
  334. OI_sel.append(np.squeeze(chn_sel))
  335. # find best parameters (= max response)
  336. #print(np.where(all_resps_tot == np.max(all_resps))[0][0])
  337. par_max = all_params[np.int(np.where(all_resps_tot == np.max(all_resps))[0][0]),:]
  338. BEST_params.append(par_max)
  339. # output all tuning curve
  340. ori_curve = np.squeeze(norm_resp[max_idx[0],:,max_idx[2],max_idx[3]])
  341. TUN_curves.append(ori_curve)
  342. #sf_curve = np.squeeze(norm_resp[:,max_idx[1],max_idx[2],max_idx[3]])
  343. sf_curve = np.squeeze(all_resps[:,max_idx[1],max_idx[2],max_idx[3]])
  344. SFS_curves.append(sf_curve)
  345. con_curve = np.mean(norm_resp,axis=(0,1,2))
  346. CNTR_curves.append(con_curve)
  347. OI_sel = np.array(OI_sel)
  348. BEST_params = np.array(BEST_params)
  349. TUN_curves = np.array(TUN_curves)
  350. CNTR_curves = np.array(CNTR_curves)
  351. # plot:
  352. labels = ['Orientation selectivity index','Frequency (neurons)','']
  353. c = '#3399ff'
  354. xline = [np.nanmedian(OI_sel)]
  355. fig1, ax = plt.subplots()
  356. ax.hist(OI_sel,bins=10,color=c)
  357. if len(xline) != 0:
  358. plt.axvline(xline[0], color='k', linestyle='dashed')
  359. if len(labels) != 0:
  360. ax.set_xlabel(labels[0])
  361. ax.set_ylabel(labels[1])
  362. ax.set_title(labels[2])
  363. ax.set_xlim([0,1])
  364. ax.spines['top'].set_visible(False)
  365. ax.spines['right'].set_visible(False)
  366. ax.set_title('Orientation tuning:')
  367. plt.show()
  368. output = {'goods':goods,
  369. 'OI_sel':OI_sel,
  370. 'SFS_curves':SFS_curves,
  371. 'TUN_curves':TUN_curves,
  372. 'CNTR_curves':CNTR_curves,
  373. 'BEST_params':BEST_params}
  374. savemat(gen_path + name + roi + '_ori_sel' + '.mat', output)
  375. def size_tuning(gen_path,stim_size,radii,trained_model,name,goods,roi):
  376. # load data:
  377. data_path = gen_path + name + roi + '_good_rfs' + '.mat'
  378. data_dict = {}
  379. f = loadmat(data_path)
  380. for k, v in f.items():
  381. data_dict[k] = np.array(v)
  382. centrex = data_dict['centrex'].astype(int)[0]
  383. centrey = data_dict['centrey'].astype(int)[0]
  384. radii_f = radii.astype(int)
  385. data_path = gen_path + name + roi + '_ori_sel' + '.mat'
  386. data_dict = {}
  387. f = loadmat(data_path)
  388. for k, v in f.items():
  389. data_dict[k] = np.array(v)
  390. BEST_params = data_dict['BEST_params']
  391. # define transforms:
  392. transform = transforms.Compose([transforms.ToPILImage(),
  393. transforms.Resize([224,224]),
  394. transforms.ToTensor(),
  395. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  396. std=[0.229, 0.224, 0.225])])
  397. SSI_sel = []
  398. BEST_size = []
  399. S_TUN_curves = []
  400. for n in tqdm(range(len(goods)), disable=False):
  401. # RF center:
  402. px_x = centrex[n]
  403. px_y = centrey[n]
  404. params = BEST_params[n,:]
  405. # iterate gabors:
  406. stimuli = []
  407. for r in range(len(radii_f)):
  408. # create stimulus:
  409. stimulus = gabor_patch([stim_size,stim_size], pos_yx = [px_y,px_x],
  410. radius = radii_f[r], wavelength = params[0],
  411. orientation = params[1], phase = params[2])
  412. stimuli.append(stimulus)
  413. # get response:
  414. stimuli = np.array(stimuli).squeeze()
  415. all_resps = []
  416. for s in range(stimuli.shape[0]):
  417. temp = transform(np.moveaxis(stimuli[s]*255,0,-1).squeeze().astype(np.uint8))#.to('cpu')
  418. all_resps.append(trained_model(temp[None,:,:,:]).squeeze()[goods[n]].detach().numpy())
  419. all_resps = np.array(all_resps)
  420. # find best size (= max response)
  421. index = np.where(all_resps == np.max(all_resps))[0]
  422. if len(index) > 1:
  423. index = index[0]
  424. rad_max = radii[np.int(index)]
  425. BEST_size.append(rad_max)
  426. # compute the surround suppression index SS:
  427. norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
  428. max_resp = norm_resp[np.int(index)]
  429. fin_resp = norm_resp[-1]
  430. if max_resp == fin_resp:
  431. chn_sel = 0
  432. else:
  433. chn_sel = (max_resp - fin_resp) / (max_resp + fin_resp)
  434. SSI_sel.append(np.squeeze(chn_sel))
  435. S_TUN_curves.append(norm_resp)
  436. SSI_sel = np.array(SSI_sel)
  437. BEST_size = np.array(BEST_size)
  438. S_TUN_curves = np.array(S_TUN_curves)
  439. # plot:
  440. labels = ['Surround suppression index','Frequency (neurons)','']
  441. c = '#3399ff'
  442. xline = [np.nanmedian(SSI_sel)]
  443. fig1, ax = plt.subplots()
  444. ax.hist(SSI_sel,bins=10,color=c)
  445. if len(xline) != 0:
  446. plt.axvline(xline[0], color='k', linestyle='dashed')
  447. if len(labels) != 0:
  448. ax.set_xlabel(labels[0])
  449. ax.set_ylabel(labels[1])
  450. ax.set_title(labels[2])
  451. ax.set_xlim([0,1])
  452. ax.spines['top'].set_visible(False)
  453. ax.spines['right'].set_visible(False)
  454. ax.set_title('Surround suppression:')
  455. plt.show()
  456. output = {'goods':goods,
  457. 'SSI_sel':SSI_sel,
  458. 'S_TUN_curves':S_TUN_curves,
  459. 'BEST_size':BEST_size}
  460. savemat(gen_path + name + roi + '_size_sel' + '.mat', output)
  461. def get_full_resps(gen_path,trained_model,goods,name,roi,folder):
  462. trained_model.to('cpu').eval()
  463. trained_model.w_s = torch.nn.Parameter(torch.ones(trained_model.w_s.shape),requires_grad=False)
  464. transform = transforms.Compose([transforms.Resize([224,224]),
  465. transforms.ToTensor(),
  466. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  467. std=[0.229, 0.224, 0.225])])
  468. dataset = datasets.ImageFolder(folder, transform=transform)
  469. dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset.imgs), shuffle=False)
  470. img_data, junk = next(iter(dataloader))
  471. resps = trained_model(img_data).squeeze().detach().numpy()[:,goods]
  472. output = {'goods':goods,
  473. 'responses':resps}
  474. savemat(gen_path + name + roi + '_responses' + '.mat', output)
  475. return resps, img_data
  476. def output(gen_path,name,roi,goods,good_neurons,resps=False):
  477. # the output must be re-arranged in the same format as the input files
  478. # but we applied two selections (one for reliability and one for val-corr)
  479. max_n = len(good_neurons)
  480. tot = np.array(np.where(good_neurons.squeeze()==1)).squeeze()
  481. goods_ok = tot[goods]
  482. data_dict = {}
  483. data_path = gen_path + name + roi + '_good_rfs' + '.mat'
  484. f = loadmat(data_path)
  485. for k, v in f.items():
  486. if k[0] != '_':
  487. temp = np.array(v).squeeze()
  488. #idx = np.array(np.where(np.array(temp.shape) == np.array(goods.shape[0]))).squeeze().astype(int)
  489. ts = np.array(temp.shape)
  490. ts[0] = max_n
  491. temp_r = np.empty(ts)
  492. temp_r.fill(np.nan)
  493. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  494. temp_r[goods_ok] = temp
  495. else:
  496. temp_r[tot] = temp
  497. data_dict[k] = temp_r
  498. data_path = gen_path + name + roi + '_ori_sel' + '.mat'
  499. f = loadmat(data_path)
  500. for k, v in f.items():
  501. if k[0] != '_':
  502. temp = np.array(v).squeeze()
  503. #idx = np.array(np.where(np.array(temp.shape) == np.array(goods.shape[0]))).squeeze().astype(int)
  504. ts = np.array(temp.shape)
  505. ts[0] = max_n
  506. temp_r = np.empty(ts)
  507. temp_r.fill(np.nan)
  508. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  509. temp_r[goods_ok] = temp
  510. else:
  511. temp_r[tot] = temp
  512. data_dict[k] = temp_r
  513. data_path = gen_path + name + roi + '_size_sel' + '.mat'
  514. f = loadmat(data_path)
  515. for k, v in f.items():
  516. if k[0] != '_':
  517. temp = np.array(v).squeeze()
  518. #idx = np.array(np.where(np.array(temp.shape) == np.array(goods.shape[0]))).squeeze().astype(int)
  519. ts = np.array(temp.shape)
  520. ts[0] = max_n
  521. temp_r = np.empty(ts)
  522. temp_r.fill(np.nan)
  523. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  524. temp_r[goods_ok] = temp
  525. else:
  526. temp_r[tot] = temp
  527. data_dict[k] = temp_r
  528. data_dict['good_chns'] = good_neurons
  529. if resps:
  530. data_path = gen_path + name + roi + '_responses' + '.mat'
  531. f = loadmat(data_path)
  532. for k, v in f.items():
  533. if k[0] != '_':
  534. temp = np.transpose(np.array(v).squeeze())
  535. ts = np.array(temp.shape)
  536. ts[0] = max_n
  537. temp_r = np.empty(ts)
  538. temp_r.fill(np.nan)
  539. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  540. temp_r[goods_ok] = temp
  541. else:
  542. temp_r[tot] = temp
  543. data_dict[k] = temp_r
  544. savemat(gen_path + name + roi + '_OUTPUT' + '.mat', data_dict)