### Functions for plotting, storing and fitting MEIs
import torch
import torchvision.models as models
import torchvision.transforms.functional as F
import torch.nn.functional as nnf
from torchvision import datasets, transforms
from lucent.optvis import render, param, transform, objectives
from lucent.modelzoo import inceptionv1, util, inceptionv1_avgPool
import numpy as np
import pickle
from scipy.io import savemat, loadmat
from pathlib import Path
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import shapely.affinity
import shapely.geometry
from skimage.draw import polygon
from skimage.morphology import convex_hull_image
from skimage.filters import gaussian as smoothing
from PIL import Image
import math
from tqdm import tqdm
from helper import gaussian, moments, fitgaussian, load, gabor_patch, occlude_pic
from neural_model import Model
def load_trained_model(gen_path, name, roi, layer=None):
    if layer is None:
        data_filename = gen_path + name + '/snapshots/grid_search_array' + roi + '.pkl'
        f = open(data_filename, "rb")
        cc = pickle.load(f)
        val_corrs = cc['val_corrs']
        params = cc['params']
        val_corrs = np.array(val_corrs)
        layer = params[np.where(val_corrs == val_corrs.max())[0][0].astype('int')][0]

    n_neurons = 0
    layer_filename = gen_path + 'val_corr_' + name + roi + '_' + layer + '.pkl'
    if isfile(layer_filename):
        f = open(layer_filename, "rb")
        cc = pickle.load(f)
        val_corrs = cc['val_corrs']
        n_neurons = len(val_corrs)

    if n_neurons == 0:
        data_filename = gen_path + name + '/data_THINGS_array' + roi + '_v1.pkl'
        f = open(data_filename, "rb")
        cc = pickle.load(f)
        val_data = cc['val_data']
        n_neurons = val_data.shape[1]
    pretrained_model = inceptionv1(pretrained=True)
    trained_model = Model(pretrained_model, layer, n_neurons, device='cpu')
    trained_model.load_state_dict(torch.load(gen_path + name + '/snapshots/' + name + roi + '_' + layer + '_neural_model.pt', map_location=torch.device('cpu')))
    return trained_model, n_neurons  # ,good_neurons
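

# Example (sketch): loading one trained readout model. The base path, subject name and
# ROI string below are placeholders for whatever was used when the model was fitted.
#   >>> trained_model, n_neurons = load_trained_model('/path/to/results/', 'subject01', '1')
# If `layer` is None the best layer is taken from the grid-search snapshot; otherwise pass
# the readout layer name explicitly (using the same layer naming the grid search used).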
def plot_layer_corrs(gen_path, name, layers, roi):
    all_corrs = []
    for layer in layers:
        layer_filename = gen_path + 'val_corr_' + name + roi + '_' + layer + '.pkl'
        f = open(layer_filename, "rb")
        cc = pickle.load(f)
        val_corrs = cc['val_corrs']
        all_corrs.append(val_corrs)

    all_corrs = np.array(all_corrs)
    fig1, ax = plt.subplots()
    c = '#3399ff'
    ax.set_title('Layers:')
    ax.boxplot(all_corrs.transpose(), labels=layers, notch=True,
               widths=.5, showcaps=False, whis=[2.5, 97.5], patch_artist=True,
               showfliers=False, boxprops=dict(facecolor=c, color=c),
               capprops=dict(color=c), whiskerprops=dict(color=c),
               flierprops=dict(color=c, markeredgecolor=c), medianprops=dict(color='w', linewidth=2))
    ax.set_ylabel('Cross-validated Pearson r')
    ax.set_ylim([0, 1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()
    temp = np.median(all_corrs, axis=1)
    layer = layers[np.where(temp == temp.max())[0][0].astype('int')]

    return all_corrs, layer
def compute_val_corrs(gen_path, name, trained_model, roi):
    data_filename = gen_path + name + '/data_THINGS_array' + roi + '_v1.pkl'
    f = open(data_filename, "rb")
    cc = pickle.load(f)
    val_img_data = cc['val_img_data']
    val_outputs = trained_model(val_img_data).squeeze()
    val_data = cc['val_data']
    corrs = []
    for n in range(val_outputs.shape[1]):
        corrs.append(np.corrcoef(val_outputs[:, n].cpu().detach().numpy(), val_data[:, n])[1, 0])
    return corrs
def good_neurons(trained_model, n_neurons, corrs, make_plots=True):
    z = 1
    all_good_rfs = []
    goods = []
    for n in range(n_neurons):
        trained_model_rf = np.abs(np.reshape(trained_model.w_s[n].squeeze().detach().cpu().numpy(),
                                             [np.sqrt(trained_model.w_s.shape[2]).astype('int'), np.sqrt(trained_model.w_s.shape[2]).astype('int')]))
        if corrs[n] > 0:
            z += 1
            goods.append(n)
            trained_model_rf_norm = (trained_model_rf + np.abs(trained_model_rf.min())) / (np.abs(trained_model_rf.max()) + np.abs(trained_model_rf.min()))
            all_good_rfs.append(trained_model_rf_norm)
            if make_plots:
                print(z, corrs[n])
                plt.imshow(trained_model_rf, cmap='seismic')
                plt.colorbar()
                plt.show()
                plt.plot(trained_model.w_f[0, n].squeeze().detach().cpu().numpy())
                plt.show()
    goods = np.array(goods)
    all_good_rfs = np.array(all_good_rfs)
    return goods, all_good_rfs
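

# Example (sketch): chaining the selection steps above. All arguments follow from the
# load_trained_model() example; paths and names remain placeholders.
#   >>> corrs = compute_val_corrs(gen_path, name, trained_model, roi)
#   >>> goods, all_good_rfs = good_neurons(trained_model, n_neurons, corrs, make_plots=False)
# `goods` holds the indices of units with positive validation correlation and
# `all_good_rfs` their normalised spatial readout maps, which gaussian_RFs() fits next.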
def gaussian_RFs(gen_path, all_good_rfs, goods, corrs, all_corrs, pixperdeg, stim_size, mask_size, trained_model, name, roi, shift_x=0, shift_y=0):
    centrex = []
    centrey = []
    szx = []
    szy = []
    szdeg = []
    masks = []
    sparsity = []
    all_feats = []
    pixperdeg_reduced = all_good_rfs[0].shape[0] / stim_size * pixperdeg  # pixels per degree on the reduced RF grid
    scaling_f = stim_size / all_good_rfs[0].shape[0]
    scaling_mask = mask_size / all_good_rfs[0].shape[0]
    pixperdeg_reduced_mask = all_good_rfs[0].shape[0] / mask_size * pixperdeg
    fovea_x = stim_size / 2 - shift_x
    fovea_y = stim_size / 2 - shift_y
    ax = plt.gca()
    ax.set_title('Spatial RFs:')
    for n in range(len(goods)):
        # fit 2d Gaussian to W_s
        data = all_good_rfs[n]
        params = fitgaussian(data)
        fit = gaussian(*params)
        plt.imshow(data, cmap=plt.cm.gist_earth_r)
        (height, y, x, width_y, width_x) = params  # x and y are shifted in img coords
        circle = Ellipse((x, y), 2 * 1.65 * width_x, 2 * 1.65 * width_y, edgecolor='r', facecolor='None', clip_on=True)
        ax.add_patch(circle)
        centrex.append(x * scaling_f - fovea_x)
        centrey.append(fovea_y - y * scaling_f)
        szx.append(2 * 1.65 * width_x * scaling_f)
        szy.append(2 * 1.65 * width_y * scaling_f)
        szdeg.append(np.mean((2 * 1.65 * width_x / pixperdeg_reduced, 2 * 1.65 * width_y / pixperdeg_reduced)))
        # create binary mask summing up the estimated 2d Gaussians (95% CI)
        mask = np.zeros(shape=(mask_size, mask_size), dtype="bool")
        circ = shapely.geometry.Point((x * scaling_mask, y * scaling_mask)).buffer(1)
        ell = shapely.affinity.scale(circ, 1.65 * width_x * scaling_mask + pixperdeg_reduced_mask, 1.65 * width_y * scaling_mask + pixperdeg_reduced_mask)
        ell_coords = np.array(list(ell.exterior.coords))
        cc, rr = polygon(ell_coords[:, 0], ell_coords[:, 1], mask.shape)
        mask[rr, cc] = True
        masks.append(mask)

        feats = trained_model.w_f[0, goods[n]].squeeze().detach().cpu().numpy()
        chn_spars = (1 - (((np.sum(np.abs(feats)) / len(feats)) ** 2) / (np.sum(np.abs(feats) ** 2) / len(feats)))) / (1 - (1 / len(feats)))
        sparsity.append(chn_spars)
        all_feats.append(feats)
    plt.show()
    masks = np.array(masks)
    centrex = np.array(centrex)
    centrey = np.array(centrey)
    szx = np.array(szx)
    szy = np.array(szy)
    szdeg = np.array(szdeg)
    sparsity = np.array(sparsity)
    all_feats = np.array(all_feats)
    # plot:
    labels = ['Sparsity index', 'Frequency (neurons)', '']
    c = '#3399ff'
    xline = [np.nanmedian(sparsity)]
    fig1, ax = plt.subplots()
    ax.hist(sparsity, bins=10, color=c)
    if len(xline) != 0:
        plt.axvline(xline[0], color='k', linestyle='dashed')
    if len(labels) != 0:
        ax.set_xlabel(labels[0])
        ax.set_ylabel(labels[1])
        ax.set_title(labels[2])
    ax.set_xlim([0, 1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title('Feature sparseness:')
    plt.show()

    all_corrs = np.moveaxis(all_corrs, 0, -1)

    output = {'goods': goods,
              'all_good_rfs': all_good_rfs,
              'cross_val_corr': corrs,
              'layers_cross_val_corr': all_corrs,
              'centrex': centrex,
              'centrey': centrey,
              'szx': szx,
              'szy': szy,
              'szdeg': szdeg,
              'all_feats': all_feats,
              'sparsity': sparsity}
    savemat(gen_path + name + roi + '_good_rfs' + '.mat', output)
    return masks
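

# Note on the sparsity index computed above: for a feature-weight vector w of length n,
#     S = (1 - (sum(|w|)/n)**2 / (sum(w**2)/n)) / (1 - 1/n)
# (a Vinje & Gallant style sparseness measure). A one-hot vector gives S = 1 and a
# uniform vector gives S = 0, so S quantifies how concentrated a unit's feature weights
# w_f are across channels.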
def generate_MEIS(gen_path, trained_model, name, goods=None, roi='', sign=1, batch_n=1):
    if sign == 1:
        sign_label = 'MEIs/'
    elif sign == -1:
        sign_label = 'MIIs/'
    else:
        raise Exception('The variable sign must be either 1 (for MEIs) or -1 (for MIIs)')

    Path(gen_path + sign_label + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/').mkdir(parents=True, exist_ok=True)
    if goods is None:
        goods = np.linspace(0, trained_model.n_neurons - 1, trained_model.n_neurons).astype(int)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trained_model.to(device).eval()
    all_transforms = [
        # transform.pad(2),
        transform.jitter(8),
        transform.random_scale([n / 1000. for n in range(975, 1050)]),
        transform.random_rotate(list(range(-5, 5))),
        transform.jitter(4)
        # transforms.Grayscale(3),
    ]
    batch_param_f = lambda: param.image(128, batch=batch_n, fft=True, decorrelate=True)
    cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
    for chn in range(len(goods)):
        obj = objectives.channel("output", int(goods[chn])) * sign
        _ = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
                              show_inline=True, thresholds=(50,))
        for n in range(_[0].shape[0]):
            filename = gen_path + sign_label + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_' + str(n) + '.bmp'
            temp = _[0][n]  # *np.repeat(np.expand_dims(masks[chn],-1),3,axis=-1)+np.repeat((1-np.expand_dims(masks[chn],-1))*.5,3,axis=-1)
            temp_pic = Image.fromarray(np.uint8((temp) * 255))
            temp_pic = temp_pic.resize((500, 500), Image.BILINEAR)
            temp_pic.save(filename)
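

# Example (sketch): rendering MEIs and MIIs for the selected units; paths are placeholders.
#   >>> generate_MEIS(gen_path, trained_model, name, goods=goods, roi=roi, sign=1)   # MEIs
#   >>> generate_MEIS(gen_path, trained_model, name, goods=goods, roi=roi, sign=-1)  # MIIs
# Each optimised image is upsampled to 500x500 and written as a .bmp file under
# <gen_path>/MEIs/ or <gen_path>/MIIs/, one file per unit and batch element.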
def generate_surround_MEIS(gen_path, trained_model, name, goods=None, roi=''):
    Path(gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/').mkdir(parents=True, exist_ok=True)
    if goods is None:
        goods = np.linspace(0, trained_model.n_neurons - 1, trained_model.n_neurons).astype(int)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trained_model.to(device).eval()
    all_transforms = [
        transform.pad(2),
        transform.jitter(4),
        transform.random_scale([n / 1000. for n in range(975, 1050)]),
        transform.random_rotate(list(range(-5, 5))),
        transforms.Grayscale(3),
    ]
    batch_param_f = lambda: param.image(128, batch=1, fft=True, decorrelate=True)
    cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
    for chn in range(len(goods)):
        obj = objectives.channel("output", int(goods[chn]))
        mei = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=all_transforms,
                                show_inline=True, thresholds=(50,))
        mask = trained_model.w_s[int(goods[chn])].squeeze()  # spatial weights of the targeted neuron
        mask = torch.reshape(mask, (np.sqrt(mask.shape[0]).astype('int'), np.sqrt(mask.shape[0]).astype('int')))
        mask = np.abs(nnf.interpolate(mask[None, None, :], mei[0].shape[1]).squeeze().cpu().detach().numpy())
        mask = smoothing(mask > (np.std(mask) * 3), 1)
        mask = np.repeat(mask[:, :, np.newaxis], 3, 2)

        # occluded transform list for the surround renders (kept separate so the
        # unmasked MEI render above always uses the original transforms)
        surround_transforms = [
            occlude_pic(mei[0], mask, device),
            transform.pad(2),
            transform.jitter(4),
            transform.random_scale([n / 1000. for n in range(975, 1050)]),
            transform.random_rotate(list(range(-5, 5))),
            transforms.Grayscale(3),
        ]

        for i in range(2):
            if i:
                obj = objectives.channel("output", int(goods[chn]))
                filename = gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_positive_Surround.bmp'
            else:
                obj = -objectives.channel("output", int(goods[chn]))
                filename = gen_path + 'surrMEIs/' + name + roi + '_fft_decorr_lr_1e-2_l2_1e-3/' + str(goods[chn]) + '_negative_Surround.bmp'
            _ = render.render_vis(trained_model, obj, batch_param_f, cppn_opt, transforms=surround_transforms,
                                  show_inline=True, thresholds=(50,))
            temp = _[0][0]
            temp_pic = Image.fromarray(np.uint8((temp) * 255))
            temp_pic = temp_pic.resize((500, 500), Image.BILINEAR)
            temp_pic.save(filename)
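
# Design note (sketch of the logic above, assuming occlude_pic() keeps the pixels inside
# the RF mask fixed to the rendered MEI while letting the rest of the image vary): the
# first render produces the unconstrained MEI, the thresholded spatial weights w_s give a
# smoothed RF mask, and the two renders in the inner loop then optimise only the surround,
# either driving the unit further (positive surround) or suppressing it (negative surround).
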
def ori_tuning(gen_path, stim_size, wavelengths, orientations, phases, contrasts, trained_model, name, goods, roi):
    # load data:
    data_path = gen_path + name + roi + '_good_rfs' + '.mat'
    data_dict = {}
    f = loadmat(data_path)
    for k, v in f.items():
        data_dict[k] = np.array(v)
    centrex = data_dict['centrex'].astype(int)[0]
    centrey = data_dict['centrey'].astype(int)[0]
    # define transforms:
    # transform = transforms.Compose([transforms.ToPILImage(),
    #                                 transforms.Resize([224,224]),
    #                                 transforms.ToTensor(),
    #                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                                      std=[0.229, 0.224, 0.225])])

    transform = transforms.Compose([transforms.ToPILImage(),
                                    transforms.Resize([224, 224]),
                                    transforms.ToTensor()])
    # RF center:
    px_x = np.median(centrex)
    px_y = np.median(centrey)
    # iterate gabors:
    stimuli = []
    all_params = []
    for w in range(len(wavelengths)):
        for o in range(len(orientations)):
            for p in range(len(phases)):
                for c in range(len(contrasts)):

                    min_b = 0.5 - contrasts[c] / 2
                    max_b = 0.5 + contrasts[c] / 2
                    # create stimulus:
                    stimulus = gabor_patch([stim_size, stim_size], pos_yx=[px_y, px_x],
                                           radius=stim_size, wavelength=wavelengths[w],
                                           orientation=orientations[o], phase=phases[p],
                                           min_brightness=min_b, max_brightness=max_b)
                    stimuli.append(stimulus)
                    all_params.append([wavelengths[w], orientations[o], phases[p], contrasts[c]])
    all_params = np.array(all_params)
    # get response:
    stimuli = np.array(stimuli).squeeze()
    all_resps_tot = []
    for s in range(stimuli.shape[0]):
        temp = transform(np.moveaxis(stimuli[s] * 255, 0, -1).squeeze().astype(np.uint8))  # .to('cpu')
        all_resps_tot.append(trained_model(temp[None, :, :, :]).squeeze()[goods].detach().numpy())
    all_resps_tot = np.array(all_resps_tot)
    OI_sel = []
    BEST_params = []
    TUN_curves = []
    SFS_curves = []
    CNTR_curves = []
    for n in range(len(goods)):
        # reshape the overall response:
        all_resps = all_resps_tot[:, n].reshape(len(wavelengths), len(orientations), len(phases), len(contrasts))
        # compute the orientation index OI:
        norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
        max_idx = np.where(norm_resp == np.max(norm_resp))
        if len(max_idx) > 1:
            max_idx = np.array(max_idx)
            max_idx = max_idx[:, -1]
        ortho_idx = np.where((180 * orientations / math.pi) == (180 * orientations / math.pi)[max_idx[1]] + 90)
        if np.array(ortho_idx).size == 0:
            ortho_idx = np.where((180 * orientations / math.pi) == (180 * orientations / math.pi)[max_idx[1]] - 90)
        chn_sel = (norm_resp.max() - norm_resp[max_idx[0], ortho_idx, max_idx[2], max_idx[3]]) / (norm_resp.max() + norm_resp[max_idx[0], ortho_idx, max_idx[2], max_idx[3]])
        OI_sel.append(np.squeeze(chn_sel))

        # find best parameters (= max response of this neuron)
        par_max = all_params[int(np.where(all_resps_tot[:, n] == np.max(all_resps))[0][0]), :]
        BEST_params.append(par_max)

        # output all tuning curves
        ori_curve = np.squeeze(norm_resp[max_idx[0], :, max_idx[2], max_idx[3]])
        TUN_curves.append(ori_curve)

        # sf_curve = np.squeeze(norm_resp[:,max_idx[1],max_idx[2],max_idx[3]])
        sf_curve = np.squeeze(all_resps[:, max_idx[1], max_idx[2], max_idx[3]])
        SFS_curves.append(sf_curve)

        con_curve = np.mean(norm_resp, axis=(0, 1, 2))
        CNTR_curves.append(con_curve)
    OI_sel = np.array(OI_sel)
    BEST_params = np.array(BEST_params)
    TUN_curves = np.array(TUN_curves)
    SFS_curves = np.array(SFS_curves)
    CNTR_curves = np.array(CNTR_curves)
    # plot:
    labels = ['Orientation selectivity index', 'Frequency (neurons)', '']
    c = '#3399ff'
    xline = [np.nanmedian(OI_sel)]
    fig1, ax = plt.subplots()
    ax.hist(OI_sel, bins=10, color=c)
    if len(xline) != 0:
        plt.axvline(xline[0], color='k', linestyle='dashed')
    if len(labels) != 0:
        ax.set_xlabel(labels[0])
        ax.set_ylabel(labels[1])
        ax.set_title(labels[2])
    ax.set_xlim([0, 1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title('Orientation tuning:')
    plt.show()

    output = {'goods': goods,
              'OI_sel': OI_sel,
              'SFS_curves': SFS_curves,
              'TUN_curves': TUN_curves,
              'CNTR_curves': CNTR_curves,
              'BEST_params': BEST_params}
    savemat(gen_path + name + roi + '_ori_sel' + '.mat', output)

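
# Note on the orientation index computed above: with R_pref the normalised response at the
# preferred orientation (at the best wavelength/phase/contrast) and R_orth the response at
# the orientation 90 degrees away,
#     OI = (R_pref - R_orth) / (R_pref + R_orth)
# so OI = 1 means no response to the orthogonal grating and OI = 0 means no orientation
# preference. BEST_params stores [wavelength, orientation, phase, contrast] of the most
# driving Gabor and is reused by size_tuning() below.
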
def size_tuning(gen_path, stim_size, radii, trained_model, name, goods, roi):
    # load data:
    data_path = gen_path + name + roi + '_good_rfs' + '.mat'
    data_dict = {}
    f = loadmat(data_path)
    for k, v in f.items():
        data_dict[k] = np.array(v)
    centrex = data_dict['centrex'].astype(int)[0]
    centrey = data_dict['centrey'].astype(int)[0]
    radii_f = radii.astype(int)
    data_path = gen_path + name + roi + '_ori_sel' + '.mat'
    data_dict = {}
    f = loadmat(data_path)
    for k, v in f.items():
        data_dict[k] = np.array(v)
    BEST_params = data_dict['BEST_params']
    # define transforms:
    transform = transforms.Compose([transforms.ToPILImage(),
                                    transforms.Resize([224, 224]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])

    SSI_sel = []
    BEST_size = []
    S_TUN_curves = []
    for n in tqdm(range(len(goods)), disable=False):
        # RF center:
        px_x = centrex[n]
        px_y = centrey[n]
        params = BEST_params[n, :]
        # iterate gabors:
        stimuli = []
        for r in range(len(radii_f)):
            # create stimulus:
            stimulus = gabor_patch([stim_size, stim_size], pos_yx=[px_y, px_x],
                                   radius=radii_f[r], wavelength=params[0],
                                   orientation=params[1], phase=params[2])
            stimuli.append(stimulus)
        # get response:
        stimuli = np.array(stimuli).squeeze()
        all_resps = []
        for s in range(stimuli.shape[0]):
            temp = transform(np.moveaxis(stimuli[s] * 255, 0, -1).squeeze().astype(np.uint8))  # .to('cpu')
            all_resps.append(trained_model(temp[None, :, :, :]).squeeze()[goods[n]].detach().numpy())
        all_resps = np.array(all_resps)
        # find best size (= max response)
        index = np.where(all_resps == np.max(all_resps))[0]
        if len(index) > 1:
            index = index[0]
        rad_max = radii[int(index)]
        BEST_size.append(rad_max)

        # compute the surround suppression index SSI:
        norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
        max_resp = norm_resp[int(index)]
        fin_resp = norm_resp[-1]
        if max_resp == fin_resp:
            chn_sel = 0
        else:
            chn_sel = (max_resp - fin_resp) / (max_resp + fin_resp)
        SSI_sel.append(np.squeeze(chn_sel))
        S_TUN_curves.append(norm_resp)
    SSI_sel = np.array(SSI_sel)
    BEST_size = np.array(BEST_size)
    S_TUN_curves = np.array(S_TUN_curves)
    # plot:
    labels = ['Surround suppression index', 'Frequency (neurons)', '']
    c = '#3399ff'
    xline = [np.nanmedian(SSI_sel)]
    fig1, ax = plt.subplots()
    ax.hist(SSI_sel, bins=10, color=c)
    if len(xline) != 0:
        plt.axvline(xline[0], color='k', linestyle='dashed')
    if len(labels) != 0:
        ax.set_xlabel(labels[0])
        ax.set_ylabel(labels[1])
        ax.set_title(labels[2])
    ax.set_xlim([0, 1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title('Surround suppression:')
    plt.show()

    output = {'goods': goods,
              'SSI_sel': SSI_sel,
              'S_TUN_curves': S_TUN_curves,
              'BEST_size': BEST_size}
    savemat(gen_path + name + roi + '_size_sel' + '.mat', output)
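

# Note on the surround suppression index computed above: with R_opt the min-max normalised
# response at the preferred radius (= 1 by construction) and R_full the normalised response
# at the largest radius tested,
#     SSI = (R_opt - R_full) / (R_opt + R_full)
# so SSI = 0 means no suppression as the grating grows, and SSI = 1 means the largest
# grating yields the weakest response in the size-tuning curve.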
def get_full_resps(gen_path, trained_model, goods, name, roi, folder):
    trained_model.to('cpu').eval()
    trained_model.w_s = torch.nn.Parameter(torch.ones(trained_model.w_s.shape), requires_grad=False)
    transform = transforms.Compose([transforms.Resize([224, 224]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])
    dataset = datasets.ImageFolder(folder, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset.imgs), shuffle=False)
    img_data, junk = next(iter(dataloader))
    resps = trained_model(img_data).squeeze().detach().numpy()[:, goods]

    output = {'goods': goods,
              'responses': resps}
    savemat(gen_path + name + roi + '_responses' + '.mat', output)
    return resps, img_data
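

# Caution: get_full_resps() replaces trained_model.w_s with an all-ones (non-trainable)
# parameter in place, so the spatial readout weighting is effectively removed when these
# image responses are computed. Reload the model via load_trained_model() before reusing
# the same object for the RF-dependent analyses above.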
def output(gen_path, name, roi, goods, good_neurons, resps=False):
    # the output must be re-arranged in the same format as the input files,
    # but we applied two selections (one for reliability and one for val-corr)
    max_n = len(good_neurons)
    tot = np.array(np.where(good_neurons.squeeze() == 1)).squeeze()
    goods_ok = tot[goods]

    data_dict = {}

    # expand each results file back to the original channel count, padding with NaNs
    for suffix in ['_good_rfs', '_ori_sel', '_size_sel']:
        data_path = gen_path + name + roi + suffix + '.mat'
        f = loadmat(data_path)
        for k, v in f.items():
            if k[0] != '_':
                temp = np.array(v).squeeze()
                ts = np.array(temp.shape)
                ts[0] = max_n
                temp_r = np.empty(ts)
                temp_r.fill(np.nan)
                if np.array(temp.shape[0]) == np.array(goods.shape[0]):
                    temp_r[goods_ok] = temp
                else:
                    temp_r[tot] = temp
                data_dict[k] = temp_r

    data_dict['good_chns'] = good_neurons
    if resps:
        data_path = gen_path + name + roi + '_responses' + '.mat'
        f = loadmat(data_path)
        for k, v in f.items():
            if k[0] != '_':
                temp = np.transpose(np.array(v).squeeze())
                ts = np.array(temp.shape)
                ts[0] = max_n
                temp_r = np.empty(ts)
                temp_r.fill(np.nan)
                if np.array(temp.shape[0]) == np.array(goods.shape[0]):
                    temp_r[goods_ok] = temp
                else:
                    temp_r[tot] = temp
                data_dict[k] = temp_r

    savemat(gen_path + name + roi + '_OUTPUT' + '.mat', data_dict)
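
# Example (sketch): final packaging. `good_chns` is the 0/1 reliability mask over the
# originally recorded channels and `goods` indexes the model units kept after the
# validation-correlation selection; both names here are placeholders.
#   >>> output(gen_path, name, roi, goods, good_chns, resps=True)
# This writes <name><roi>_OUTPUT.mat with every measure expanded back to the original
# channel count, filling de-selected channels with NaN.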