npc.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. ### Functions for plotting, storing and fitting MEIs
  2. import torch
  3. import torchvision.models as models
  4. import torchvision.transforms.functional as F
  5. import torch.nn.functional as nnf
  6. from torchvision import datasets, transforms
  7. from lucent.optvis import render, param, transform, objectives
  8. from lucent.modelzoo import inceptionv1, util, inceptionv1_avgPool
  9. import numpy as np
  10. import pickle
  11. from scipy.io import savemat,loadmat
  12. from pathlib import Path
  13. from os import listdir
  14. from os.path import isfile, join
  15. import matplotlib.pyplot as plt
  16. from matplotlib.patches import Ellipse
  17. import shapely.affinity
  18. import shapely.geometry
  19. from skimage.draw import polygon
  20. from skimage.morphology import convex_hull_image
  21. from skimage.filters import gaussian as smoothing
  22. from PIL import Image
  23. import math
  24. from tqdm import tqdm
  25. from helper import gaussian, moments, fitgaussian, load, gabor_patch, occlude_pic
  26. from neural_model import Model
  27. def load_trained_model(gen_path,name,roi,layer = None):
  28. if layer is None:
  29. data_filename = gen_path + name + '/snapshots/grid_search_array'+ roi + '.pkl'
  30. f = open(data_filename,"rb")
  31. cc = pickle.load(f)
  32. val_corrs = cc['val_corrs']
  33. params = cc['params']
  34. val_corrs = np.array(val_corrs)
  35. layer = params[np.where(val_corrs==val_corrs.max())[0][0].astype('int')][0]
  36. n_neurons = 0
  37. layer_filename = gen_path + 'val_corr_'+ name + roi + '_' + layer + '.pkl'
  38. if isfile(layer_filename):
  39. f = open(layer_filename,"rb")
  40. cc = pickle.load(f)
  41. val_corrs = cc['val_corrs']
  42. n_neurons = len(val_corrs)
  43. if n_neurons == 0:
  44. data_filename = gen_path + name + '/data_THINGS_array'+ roi +'_v1.pkl'
  45. f = open(data_filename,"rb")
  46. cc = pickle.load(f)
  47. val_data = cc['val_data']
  48. n_neurons = val_data.shape[1]
  49. pretrained_model = inceptionv1(pretrained=True)
  50. trained_model = Model(pretrained_model,layer,n_neurons,device='cpu')
  51. trained_model.load_state_dict(torch.load(gen_path + name + '/snapshots/'+ name + roi + '_' + layer + '_neural_model.pt',map_location=torch.device('cpu')))
  52. return trained_model,n_neurons#,good_neurons
  53. def plot_layer_corrs(gen_path,name,layers,roi):
  54. all_corrs = []
  55. for layer in layers:
  56. layer_filename = gen_path + 'val_corr_'+ name + roi + '_' + layer + '.pkl'
  57. f = open(layer_filename,"rb")
  58. cc = pickle.load(f)
  59. val_corrs = cc['val_corrs']
  60. all_corrs.append(val_corrs)
  61. all_corrs = np.array(all_corrs)
  62. fig1, ax = plt.subplots()
  63. c = '#3399ff'
  64. ax.set_title('Layers:')
  65. ax.boxplot(all_corrs.transpose(), labels=layers, notch=True,
  66. widths=.5,showcaps=False, whis=[2.5,97.5],patch_artist=True,
  67. showfliers=False,boxprops=dict(facecolor=c, color=c),
  68. capprops=dict(color=c),whiskerprops=dict(color=c),
  69. flierprops=dict(color=c, markeredgecolor=c),medianprops=dict(color='w',linewidth=2))
  70. ax.set_ylabel('Cross-validated Pearson r')
  71. ax.set_ylim([0,1])
  72. ax.spines['top'].set_visible(False)
  73. ax.spines['right'].set_visible(False)
  74. plt.show()
  75. temp = np.median(all_corrs,axis=1)
  76. layer = layers[np.where(temp==temp.max())[0][0].astype('int')]
  77. return all_corrs, layer
  78. def compute_val_corrs(gen_path,name,trained_model,roi):
  79. data_filename = gen_path + name + '/data_THINGS_array'+ roi +'_v1.pkl'
  80. f = open(data_filename,"rb")
  81. cc = pickle.load(f)
  82. val_img_data = cc['val_img_data']
  83. val_outputs = trained_model(val_img_data).squeeze()
  84. val_data = cc['val_data']
  85. corrs = []
  86. for n in range(val_outputs.shape[1]):
  87. corrs.append(np.corrcoef(val_outputs[:,n].cpu().detach().numpy(),val_data[:,n])[1,0])
  88. return corrs
  89. def good_neurons(trained_model,n_neurons,corrs,make_plots=True):
  90. z = 1
  91. all_good_rfs = []
  92. goods = []
  93. for n in range(n_neurons):
  94. trained_model_rf = np.abs(np.reshape(trained_model.w_s[n].squeeze().detach().cpu().numpy(),
  95. [np.sqrt(trained_model.w_s.shape[2]).astype('int'),np.sqrt(trained_model.w_s.shape[2]).astype('int')]))
  96. if corrs[n] > 0:
  97. z += 1
  98. goods.append(n)
  99. trained_model_rf_norm = (trained_model_rf+np.abs(trained_model_rf.min()))/(np.abs(trained_model_rf.max())+np.abs(trained_model_rf.min()))
  100. all_good_rfs.append(trained_model_rf_norm)
  101. if make_plots:
  102. print(z,corrs[n])
  103. plt.imshow(trained_model_rf, cmap='seismic')
  104. plt.colorbar()
  105. plt.show()
  106. plt.plot(trained_model.w_f[0,n].squeeze().detach().cpu().numpy())
  107. plt.show()
  108. goods = np.array(goods)
  109. all_good_rfs = np.array(all_good_rfs)
  110. return goods,all_good_rfs
  111. def gaussian_RFs(gen_path,all_good_rfs,goods,corrs,all_corrs,pixperdeg,stim_size,mask_size,trained_model,name,roi,shift_x=0,shift_y=0):
  112. centrex = []
  113. centrey = []
  114. szx = []
  115. szy = []
  116. szdeg = []
  117. masks = []
  118. sparsity = []
  119. all_feats = []
  120. pixperdeg_reduced = all_good_rfs[0].shape[0]/all_good_rfs*pixperdeg
  121. scaling_f = stim_size/all_good_rfs[0].shape[0]
  122. scaling_mask = mask_size/all_good_rfs[0].shape[0]
  123. pixperdeg_reduced_mask = all_good_rfs[0].shape[0]/mask_size*pixperdeg
  124. fovea_x = stim_size/2 - shift_x
  125. fovea_y = stim_size/2 - shift_y
  126. ax = plt.gca()
  127. ax.set_title('Spatial RFs:')
  128. for n in range(len(goods)):
  129. # fit 2d Gaussian to W_s
  130. data = all_good_rfs[n]
  131. params = fitgaussian(data)
  132. fit = gaussian(*params)
  133. plt.imshow(data, cmap=plt.cm.gist_earth_r)
  134. (height, y, x, width_y, width_x) = params # x and y are shifted in img coords
  135. circle = Ellipse((x, y), 2*1.65*width_x, 2*1.65*width_y, edgecolor='r', facecolor='None', clip_on=True)
  136. ax.add_patch(circle)
  137. centrex.append(x*scaling_f-fovea_x)
  138. centrey.append(fovea_y-y*scaling_f)
  139. szx.append(2*1.65*width_x*scaling_f)
  140. szy.append(2*1.65*width_y*scaling_f)
  141. szdeg.append(np.mean((2*1.65*width_x/pixperdeg_reduced,2*1.65*width_y/pixperdeg_reduced)))
  142. # create binary mask summing up the estimated 2d Gaussians (95% CI)
  143. mask = np.zeros(shape=(mask_size,mask_size), dtype="bool")
  144. circ = shapely.geometry.Point((x*scaling_mask,y*scaling_mask)).buffer(1)
  145. ell = shapely.affinity.scale(circ, 1.65*width_x*scaling_mask+pixperdeg_reduced_mask, 1.65*width_y*scaling_mask+pixperdeg_reduced_mask)
  146. ell_coords = np.array(list(ell.exterior.coords))
  147. cc, rr = polygon(ell_coords[:,0], ell_coords[:,1], mask.shape)
  148. mask[rr,cc] = True
  149. masks.append(mask)
  150. feats = trained_model.w_f[0,goods[n]].squeeze().detach().cpu().numpy()
  151. chn_spars = (1 - (((np.sum(np.abs(feats)) / len(feats)) ** 2) / (np.sum(np.abs(feats) ** 2) / len(feats)))) / (1 - (1 / len(feats)))
  152. sparsity.append(chn_spars)
  153. all_feats.append(feats)
  154. plt.show()
  155. masks = np.array(masks)
  156. centrex = np.array(centrex)
  157. centrey = np.array(centrey)
  158. szx = np.array(szx)
  159. szy = np.array(szy)
  160. szdeg = np.array(szdeg)
  161. sparsity = np.array(sparsity)
  162. all_feats = np.array(all_feats)
  163. # plot:
  164. labels = ['Sparsity index','Frequency (neurons)','']
  165. c = '#3399ff'
  166. xline = [np.nanmedian(sparsity)]
  167. fig1, ax = plt.subplots()
  168. ax.hist(sparsity,bins=10,color=c)
  169. if len(xline) != 0:
  170. plt.axvline(xline[0], color='k', linestyle='dashed')
  171. if len(labels) != 0:
  172. ax.set_xlabel(labels[0])
  173. ax.set_ylabel(labels[1])
  174. ax.set_title(labels[2])
  175. ax.set_xlim([0,1])
  176. ax.spines['top'].set_visible(False)
  177. ax.spines['right'].set_visible(False)
  178. ax.set_title('Feature sparsness:')
  179. plt.show()
  180. all_corrs = np.moveaxis(all_corrs,0,-1)
  181. output = {'goods':goods,
  182. 'all_good_rfs':all_good_rfs,
  183. 'cross_val_corr':corrs,
  184. 'layers_cross_val_corr':all_corrs,
  185. 'centrex':centrex,
  186. 'centrey':centrey,
  187. 'szx':szx,
  188. 'szy':szy,
  189. 'szdeg':szdeg,
  190. 'all_feats':all_feats,
  191. 'sparsity':sparsity}
  192. savemat(gen_path + name + roi + '_good_rfs' + '.mat', output)
  193. return masks
  194. def generate_MEIS(gen_path,trained_model,name,goods=None,roi='',sign=1,batch_n=1):
  195. if sign == 1:
  196. sign_label = 'MEIs/'
  197. elif sign == -1:
  198. sign_label = 'MIIs/'
  199. else:
  200. raise Exception('The variable sign must be either 1 (for MEIs) or -1 (for MIIs)')
  201. Path(gen_path + sign_label + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/').mkdir(parents=True, exist_ok=True)
  202. if goods is None:
  203. goods = np.linspace(0, trained_model.n_neurons-1, trained_model.n_neurons).astype(int)
  204. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  205. trained_model.to(device).eval()
  206. all_transforms = [
  207. #transform.pad(2),
  208. transform.jitter(8),
  209. transform.random_scale([n/1000. for n in range(975, 1050)]),
  210. transform.random_rotate(list(range(-5,5))),
  211. transform.jitter(4)
  212. #transforms.Grayscale(3),
  213. ]
  214. batch_param_f = lambda: param.image(128, batch=batch_n, fft=True, decorrelate=True)
  215. cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
  216. for chn in range(len(goods)):
  217. obj = objectives.channel("output", int(goods[chn]))*sign
  218. _ = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
  219. show_inline=True, thresholds=(50,))
  220. for n in range(_[0].shape[0]):
  221. filename = gen_path + sign_label + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_' + str(n) + '.bmp'
  222. temp = _[0][n]#*np.repeat(np.expand_dims(masks[chn],-1),3,axis=-1)+np.repeat((1-np.expand_dims(masks[chn],-1))*.5,3,axis=-1)
  223. temp_pic = Image.fromarray(np.uint8((temp)*255))
  224. temp_pic = temp_pic.resize((500,500),Image.BILINEAR)
  225. temp_pic.save(filename)
  226. def generate_surround_MEIS(gen_path,trained_model,name,goods=None,roi=''):
  227. Path(gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/').mkdir(parents=True, exist_ok=True)
  228. if goods is None:
  229. goods = np.linspace(0, trained_model.n_neurons-1, trained_model.n_neurons).astype(int)
  230. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  231. trained_model.to(device).eval()
  232. all_transforms = [
  233. transform.pad(2),
  234. transform.jitter(4),
  235. transform.random_scale([n/1000. for n in range(975, 1050)]),
  236. transform.random_rotate(list(range(-5,5))),
  237. transforms.Grayscale(3),
  238. ]
  239. batch_param_f = lambda: param.image(128, batch=1, fft=True, decorrelate=True)
  240. cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
  241. for chn in range(len(goods)):
  242. obj = objectives.channel("output", int(goods[chn]))
  243. mei = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
  244. show_inline=True, thresholds=(50,))
  245. mask = trained_model.w_s[chn].squeeze()
  246. mask = torch.reshape(mask,(np.sqrt(mask.shape[0]).astype('int'),np.sqrt(mask.shape[0]).astype('int')))
  247. mask = np.abs(nnf.interpolate(mask[None,None,:],mei[0].shape[1]).squeeze().cpu().detach().numpy())
  248. mask = smoothing(mask>(np.std(mask)*3),1)
  249. mask = np.repeat(mask[:,:,np.newaxis],3,2)
  250. all_transforms = [
  251. occlude_pic(mei[0], mask,device),
  252. transform.pad(2),
  253. transform.jitter(4),
  254. transform.random_scale([n/1000. for n in range(975, 1050)]),
  255. transform.random_rotate(list(range(-5,5))),
  256. transforms.Grayscale(3),
  257. ]
  258. for i in range(2):
  259. if i:
  260. obj = objectives.channel("output", int(goods[chn]))
  261. filename = gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_positive_Surround.bmp'
  262. else:
  263. obj = -objectives.channel("output", int(goods[chn]))
  264. filename = gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_negative_Surround.bmp'
  265. _ = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
  266. show_inline=True, thresholds=(50,))
  267. temp = _[0][0]
  268. temp_pic = Image.fromarray(np.uint8((temp)*255))
  269. temp_pic = temp_pic.resize((500,500),Image.BILINEAR)
  270. temp_pic.save(filename)
  271. def ori_tuning(gen_path,stim_size,wavelengths,orientations,phases,contrasts,trained_model,name,goods,roi):
  272. # load data:
  273. data_path = gen_path + name + roi + '_good_rfs' + '.mat'
  274. data_dict = {}
  275. f = loadmat(data_path)
  276. for k, v in f.items():
  277. data_dict[k] = np.array(v)
  278. centrex = data_dict['centrex'].astype(int)[0]
  279. centrey = data_dict['centrey'].astype(int)[0]
  280. # define transforms:
  281. #transform = transforms.Compose([transforms.ToPILImage(),
  282. # transforms.Resize([224,224]),
  283. # transforms.ToTensor(),
  284. # transforms.Normalize(mean=[0.485, 0.456, 0.406],
  285. # std=[0.229, 0.224, 0.225])])
  286. transform = transforms.Compose([transforms.ToPILImage(),
  287. transforms.Resize([224,224]),
  288. transforms.ToTensor()])
  289. # RF center:
  290. px_x = np.median(centrex)
  291. px_y = np.median(centrey)
  292. # iterate gabors:
  293. stimuli = []
  294. all_params = []
  295. for w in range(len(wavelengths)):
  296. for o in range(len(orientations)):
  297. for p in range(len(phases)):
  298. for c in range(len(contrasts)):
  299. min_b = 0.5 - contrasts[c]/2
  300. max_b = 0.5 + contrasts[c]/2
  301. # create stimulus:
  302. stimulus = gabor_patch([stim_size,stim_size], pos_yx = [px_y,px_x],
  303. radius = stim_size, wavelength = wavelengths[w],
  304. orientation = orientations[o], phase = phases[p],
  305. min_brightness = min_b, max_brightness = max_b)
  306. stimuli.append(stimulus)
  307. all_params.append([wavelengths[w],orientations[o],phases[p],contrasts[c]])
  308. all_params = np.array(all_params)
  309. # get response:
  310. stimuli = np.array(stimuli).squeeze()
  311. all_resps_tot = []
  312. for s in range(stimuli.shape[0]):
  313. temp = transform(np.moveaxis(stimuli[s]*255,0,-1).squeeze().astype(np.uint8))#.to('cpu')
  314. all_resps_tot.append(trained_model(temp[None,:,:,:]).squeeze()[goods].detach().numpy())
  315. all_resps_tot = np.array(all_resps_tot)
  316. OI_sel = []
  317. BEST_params = []
  318. TUN_curves = []
  319. SFS_curves = []
  320. CNTR_curves = []
  321. for n in range(len(goods)):
  322. # reshape the overall response:
  323. all_resps = all_resps_tot[:,n].reshape(len(wavelengths),len(orientations),len(phases),len(contrasts))
  324. # compute the orientation index OI:
  325. norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
  326. max_idx = np.where(norm_resp == np.max(norm_resp))
  327. if len(max_idx)>1:
  328. max_idx = np.array(max_idx)
  329. max_idx = max_idx[:,-1]
  330. ortho_idx = np.where((180 * orientations / math.pi) == (180 * orientations / math.pi)[max_idx[1]] + 90)
  331. if np.array(ortho_idx).size == 0:
  332. ortho_idx = np.where((180 * orientations / math.pi) == (180 * orientations / math.pi)[max_idx[1]] - 90)
  333. chn_sel = (norm_resp.max() - norm_resp[max_idx[0],ortho_idx,max_idx[2],max_idx[3]]) / (norm_resp.max() + norm_resp[max_idx[0],ortho_idx,max_idx[2],max_idx[3]])
  334. OI_sel.append(np.squeeze(chn_sel))
  335. # find best parameters (= max response)
  336. #print(np.where(all_resps_tot == np.max(all_resps))[0][0])
  337. par_max = all_params[np.int(np.where(all_resps_tot == np.max(all_resps))[0][0]),:]
  338. BEST_params.append(par_max)
  339. # output all tuning curve
  340. ori_curve = np.squeeze(norm_resp[max_idx[0],:,max_idx[2],max_idx[3]])
  341. TUN_curves.append(ori_curve)
  342. #sf_curve = np.squeeze(norm_resp[:,max_idx[1],max_idx[2],max_idx[3]])
  343. sf_curve = np.squeeze(all_resps[:,max_idx[1],max_idx[2],max_idx[3]])
  344. SFS_curves.append(sf_curve)
  345. con_curve = np.mean(norm_resp,axis=(0,1,2))
  346. CNTR_curves.append(con_curve)
  347. OI_sel = np.array(OI_sel)
  348. BEST_params = np.array(BEST_params)
  349. TUN_curves = np.array(TUN_curves)
  350. CNTR_curves = np.array(CNTR_curves)
  351. # plot:
  352. labels = ['Orientation selectivity index','Frequency (neurons)','']
  353. c = '#3399ff'
  354. xline = [np.nanmedian(OI_sel)]
  355. fig1, ax = plt.subplots()
  356. ax.hist(OI_sel,bins=10,color=c)
  357. if len(xline) != 0:
  358. plt.axvline(xline[0], color='k', linestyle='dashed')
  359. if len(labels) != 0:
  360. ax.set_xlabel(labels[0])
  361. ax.set_ylabel(labels[1])
  362. ax.set_title(labels[2])
  363. ax.set_xlim([0,1])
  364. ax.spines['top'].set_visible(False)
  365. ax.spines['right'].set_visible(False)
  366. ax.set_title('Orientation tuning:')
  367. plt.show()
  368. output = {'goods':goods,
  369. 'OI_sel':OI_sel,
  370. 'SFS_curves':SFS_curves,
  371. 'TUN_curves':TUN_curves,
  372. 'CNTR_curves':CNTR_curves,
  373. 'BEST_params':BEST_params}
  374. savemat(gen_path + name + roi + '_ori_sel' + '.mat', output)
  375. def size_tuning(gen_path,stim_size,radii,trained_model,name,goods,roi):
  376. # load data:
  377. data_path = gen_path + name + roi + '_good_rfs' + '.mat'
  378. data_dict = {}
  379. f = loadmat(data_path)
  380. for k, v in f.items():
  381. data_dict[k] = np.array(v)
  382. centrex = data_dict['centrex'].astype(int)[0]
  383. centrey = data_dict['centrey'].astype(int)[0]
  384. radii_f = radii.astype(int)
  385. data_path = gen_path + name + roi + '_ori_sel' + '.mat'
  386. data_dict = {}
  387. f = loadmat(data_path)
  388. for k, v in f.items():
  389. data_dict[k] = np.array(v)
  390. BEST_params = data_dict['BEST_params']
  391. # define transforms:
  392. transform = transforms.Compose([transforms.ToPILImage(),
  393. transforms.Resize([224,224]),
  394. transforms.ToTensor(),
  395. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  396. std=[0.229, 0.224, 0.225])])
  397. SSI_sel = []
  398. BEST_size = []
  399. S_TUN_curves = []
  400. for n in tqdm(range(len(goods)), disable=False):
  401. # RF center:
  402. px_x = centrex[n]
  403. px_y = centrey[n]
  404. params = BEST_params[n,:]
  405. # iterate gabors:
  406. stimuli = []
  407. for r in range(len(radii_f)):
  408. # create stimulus:
  409. stimulus = gabor_patch([stim_size,stim_size], pos_yx = [px_y,px_x],
  410. radius = radii_f[r], wavelength = params[0],
  411. orientation = params[1], phase = params[2])
  412. stimuli.append(stimulus)
  413. # get response:
  414. stimuli = np.array(stimuli).squeeze()
  415. all_resps = []
  416. for s in range(stimuli.shape[0]):
  417. temp = transform(np.moveaxis(stimuli[s]*255,0,-1).squeeze().astype(np.uint8))#.to('cpu')
  418. all_resps.append(trained_model(temp[None,:,:,:]).squeeze()[goods[n]].detach().numpy())
  419. all_resps = np.array(all_resps)
  420. # find best size (= max response)
  421. index = np.where(all_resps == np.max(all_resps))[0]
  422. if len(index) > 1:
  423. index = index[0]
  424. rad_max = radii[np.int(index)]
  425. BEST_size.append(rad_max)
  426. # compute the surround suppression index SS:
  427. norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
  428. max_resp = norm_resp[np.int(index)]
  429. fin_resp = norm_resp[-1]
  430. if max_resp == fin_resp:
  431. chn_sel = 0
  432. else:
  433. chn_sel = (max_resp - fin_resp) / (max_resp + fin_resp)
  434. SSI_sel.append(np.squeeze(chn_sel))
  435. S_TUN_curves.append(norm_resp)
  436. SSI_sel = np.array(SSI_sel)
  437. BEST_size = np.array(BEST_size)
  438. S_TUN_curves = np.array(S_TUN_curves)
  439. # plot:
  440. labels = ['Surround suppression index','Frequency (neurons)','']
  441. c = '#3399ff'
  442. xline = [np.nanmedian(SSI_sel)]
  443. fig1, ax = plt.subplots()
  444. ax.hist(SSI_sel,bins=10,color=c)
  445. if len(xline) != 0:
  446. plt.axvline(xline[0], color='k', linestyle='dashed')
  447. if len(labels) != 0:
  448. ax.set_xlabel(labels[0])
  449. ax.set_ylabel(labels[1])
  450. ax.set_title(labels[2])
  451. ax.set_xlim([0,1])
  452. ax.spines['top'].set_visible(False)
  453. ax.spines['right'].set_visible(False)
  454. ax.set_title('Surround suppression:')
  455. plt.show()
  456. output = {'goods':goods,
  457. 'SSI_sel':SSI_sel,
  458. 'S_TUN_curves':S_TUN_curves,
  459. 'BEST_size':BEST_size}
  460. savemat(gen_path + name + roi + '_size_sel' + '.mat', output)
  461. def get_full_resps(gen_path,trained_model,goods,name,roi,folder):
  462. trained_model.to('cpu').eval()
  463. trained_model.w_s = torch.nn.Parameter(torch.ones(trained_model.w_s.shape),requires_grad=False)
  464. transform = transforms.Compose([transforms.Resize([224,224]),
  465. transforms.ToTensor(),
  466. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  467. std=[0.229, 0.224, 0.225])])
  468. dataset = datasets.ImageFolder(folder, transform=transform)
  469. dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset.imgs), shuffle=False)
  470. img_data, junk = next(iter(dataloader))
  471. resps = trained_model(img_data).squeeze().detach().numpy()[:,goods]
  472. output = {'goods':goods,
  473. 'responses':resps}
  474. savemat(gen_path + name + roi + '_responses' + '.mat', output)
  475. return resps, img_data
  476. def output(gen_path,name,roi,goods,good_neurons,resps=False):
  477. # the output must be re-arranged in the same format as the input files
  478. # but we applied two selections (one for reliability and one for val-corr)
  479. max_n = len(good_neurons)
  480. tot = np.array(np.where(good_neurons.squeeze()==1)).squeeze()
  481. goods_ok = tot[goods]
  482. data_dict = {}
  483. data_path = gen_path + name + roi + '_good_rfs' + '.mat'
  484. f = loadmat(data_path)
  485. for k, v in f.items():
  486. if k[0] != '_':
  487. temp = np.array(v).squeeze()
  488. #idx = np.array(np.where(np.array(temp.shape) == np.array(goods.shape[0]))).squeeze().astype(int)
  489. ts = np.array(temp.shape)
  490. ts[0] = max_n
  491. temp_r = np.empty(ts)
  492. temp_r.fill(np.nan)
  493. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  494. temp_r[goods_ok] = temp
  495. else:
  496. temp_r[tot] = temp
  497. data_dict[k] = temp_r
  498. data_path = gen_path + name + roi + '_ori_sel' + '.mat'
  499. f = loadmat(data_path)
  500. for k, v in f.items():
  501. if k[0] != '_':
  502. temp = np.array(v).squeeze()
  503. #idx = np.array(np.where(np.array(temp.shape) == np.array(goods.shape[0]))).squeeze().astype(int)
  504. ts = np.array(temp.shape)
  505. ts[0] = max_n
  506. temp_r = np.empty(ts)
  507. temp_r.fill(np.nan)
  508. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  509. temp_r[goods_ok] = temp
  510. else:
  511. temp_r[tot] = temp
  512. data_dict[k] = temp_r
  513. data_path = gen_path + name + roi + '_size_sel' + '.mat'
  514. f = loadmat(data_path)
  515. for k, v in f.items():
  516. if k[0] != '_':
  517. temp = np.array(v).squeeze()
  518. #idx = np.array(np.where(np.array(temp.shape) == np.array(goods.shape[0]))).squeeze().astype(int)
  519. ts = np.array(temp.shape)
  520. ts[0] = max_n
  521. temp_r = np.empty(ts)
  522. temp_r.fill(np.nan)
  523. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  524. temp_r[goods_ok] = temp
  525. else:
  526. temp_r[tot] = temp
  527. data_dict[k] = temp_r
  528. data_dict['good_chns'] = good_neurons
  529. if resps:
  530. data_path = gen_path + name + roi + '_responses' + '.mat'
  531. f = loadmat(data_path)
  532. for k, v in f.items():
  533. if k[0] != '_':
  534. temp = np.transpose(np.array(v).squeeze())
  535. ts = np.array(temp.shape)
  536. ts[0] = max_n
  537. temp_r = np.empty(ts)
  538. temp_r.fill(np.nan)
  539. if np.array(temp.shape[0]) == np.array(goods.shape[0]):
  540. temp_r[goods_ok] = temp
  541. else:
  542. temp_r[tot] = temp
  543. data_dict[k] = temp_r
  544. savemat(gen_path + name + roi + '_OUTPUT' + '.mat', data_dict)