### Functions for plotting, storing and fitting MEIs
"""Helpers to select, fit, visualise and characterise neural readout models.

The module loads a trained readout (``neural_model.Model`` on top of lucent's
InceptionV1) and computes validation correlations. It fits 2-D Gaussians to the
spatial readout weights and synthesises MEIs/MIIs and surround images with
lucent. It measures orientation and size tuning with gabor patches, and
re-packs the saved ``.mat`` results into one ``*_OUTPUT.mat`` file.
"""
import torch
import torchvision.models as models
import torchvision.transforms.functional as F
import torch.nn.functional as nnf
from torchvision import datasets, transforms
from lucent.optvis import render, param, transform, objectives
from lucent.modelzoo import inceptionv1, util, inceptionv1_avgPool
import numpy as np
import pickle
from scipy.io import savemat, loadmat
from pathlib import Path
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import shapely.affinity
import shapely.geometry
from skimage.draw import polygon
from skimage.morphology import convex_hull_image
from skimage.filters import gaussian as smoothing
from PIL import Image
import math
from tqdm import tqdm
from helper import gaussian, moments, fitgaussian, load, gabor_patch, occlude_pic
from neural_model import Model

# Multiplier applied to the fitted Gaussian widths. The original code labels
# the resulting region the "95% CI" of the RF.
_CI_FACTOR = 1.65
_BLUE = '#3399ff'
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_MEI_DIR_SUFFIX = '_fft_decorr_lr_1e-2_l2_1e-3/'


def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file afterwards.

    NOTE(review): pickle can execute arbitrary code; only use on trusted files.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _load_mat(path):
    """Load a .mat file into a dict of numpy arrays (MATLAB '__*__' keys included)."""
    return {k: np.array(v) for k, v in loadmat(path).items()}


def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device('cpu')


def _plot_index_histogram(values, xlabel, title):
    """Show a 10-bin histogram of a per-neuron index on [0, 1] with a dashed median line."""
    fig, ax = plt.subplots()
    ax.hist(values, bins=10, color=_BLUE)
    ax.axvline(np.nanmedian(values), color='k', linestyle='dashed')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency (neurons)')
    ax.set_xlim([0, 1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title(title)
    plt.show()


def _sparsity_index(values):
    """Return (1 - mean(|v|)**2 / mean(|v|**2)) / (1 - 1/len(v)) for a 1-D vector ``v``."""
    mags = np.abs(values)
    n = len(mags)
    return (1 - (((np.sum(mags) / n) ** 2) / (np.sum(mags ** 2) / n))) / (1 - (1 / n))


def _save_render(img, filename):
    """Save a float RGB image as 8-bit, bilinearly resized to 500x500.

    Pixel values are multiplied by 255 before the uint8 cast. This presumes
    the values are in [0, 1], as lucent's renders are expected to be.
    """
    pic = Image.fromarray(np.uint8(img * 255))
    pic.resize((500, 500), Image.BILINEAR).save(filename)


def load_trained_model(gen_path, name, roi, layer=None):
    """Load a trained readout model and return ``(trained_model, n_neurons)``.

    Args:
        gen_path: Root folder (string prefix, concatenated directly).
        name: Dataset/animal name used in all file names.
        roi: ROI suffix used in all file names.
        layer: InceptionV1 layer name. If None, the layer whose entry in
            ``<name>/snapshots/grid_search_array<roi>.pkl`` has the highest
            ``val_corrs`` value (NaNs ignored) is used.

    ``n_neurons`` comes from ``val_corr_<name><roi>_<layer>.pkl`` when that file
    exists and is non-empty. Otherwise it is the number of columns of
    ``val_data`` in ``<name>/data_THINGS_array<roi>_v1.pkl``. The model is
    built and loaded on the CPU.
    """
    if layer is None:
        cc = _load_pickle(gen_path + name + '/snapshots/grid_search_array' + roi + '.pkl')
        val_corrs = np.array(cc['val_corrs'])
        # First-axis index of the best entry (works for 1-D or per-neuron 2-D arrays).
        best = np.unravel_index(np.nanargmax(val_corrs), val_corrs.shape)[0]
        layer = cc['params'][int(best)][0]
    n_neurons = 0
    layer_filename = gen_path + 'val_corr_' + name + roi + '_' + layer + '.pkl'
    if isfile(layer_filename):
        n_neurons = len(_load_pickle(layer_filename)['val_corrs'])
    if n_neurons == 0:
        cc = _load_pickle(gen_path + name + '/data_THINGS_array' + roi + '_v1.pkl')
        n_neurons = cc['val_data'].shape[1]
    pretrained_model = inceptionv1(pretrained=True)
    trained_model = Model(pretrained_model, layer, n_neurons, device='cpu')
    weights_path = gen_path + name + '/snapshots/' + name + roi + '_' + layer + '_neural_model.pt'
    trained_model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    return trained_model, n_neurons


def plot_layer_corrs(gen_path, name, layers, roi):
    """Box-plot per-layer validation correlations and return the best layer.

    Reads ``val_corr_<name><roi>_<layer>.pkl`` for each layer in ``layers``.

    Returns:
        ``(all_corrs, layer)``. ``all_corrs`` has one row per layer. ``layer``
        is the layer with the highest median correlation.
    """
    all_corrs = np.array([
        _load_pickle(gen_path + 'val_corr_' + name + roi + '_' + layer + '.pkl')['val_corrs']
        for layer in layers
    ])
    fig1, ax = plt.subplots()
    ax.set_title('Layers:')
    ax.boxplot(all_corrs.transpose(), labels=layers, notch=True, widths=.5, showcaps=False,
               whis=[2.5, 97.5], patch_artist=True, showfliers=False,
               boxprops=dict(facecolor=_BLUE, color=_BLUE), capprops=dict(color=_BLUE),
               whiskerprops=dict(color=_BLUE),
               flierprops=dict(color=_BLUE, markeredgecolor=_BLUE),
               medianprops=dict(color='w', linewidth=2))
    ax.set_ylabel('Cross-validated Pearson r')
    ax.set_ylim([0, 1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()
    median_corrs = np.median(all_corrs, axis=1)
    layer = layers[int(np.argmax(median_corrs))]
    return all_corrs, layer


def compute_val_corrs(gen_path, name, trained_model, roi):
    """Return per-neuron Pearson r between model output and ``val_data``.

    Runs ``trained_model`` on ``val_img_data`` from
    ``<name>/data_THINGS_array<roi>_v1.pkl`` without tracking gradients.
    """
    cc = _load_pickle(gen_path + name + '/data_THINGS_array' + roi + '_v1.pkl')
    with torch.no_grad():
        val_outputs = trained_model(cc['val_img_data']).squeeze().cpu().numpy()
    val_data = cc['val_data']
    return [np.corrcoef(val_outputs[:, n], val_data[:, n])[1, 0]
            for n in range(val_outputs.shape[1])]


def good_neurons(trained_model, n_neurons, corrs, make_plots=True):
    """Select neurons with a positive validation correlation.

    For each selected neuron the absolute spatial readout weights ``|w_s[n]|``
    are reshaped to a square map. The map is rescaled by
    ``(rf + |min|) / (|max| + |min|)``, so its maximum maps to 1.

    Args:
        trained_model: Readout model exposing ``w_s`` and ``w_f``.
        n_neurons: Number of model output units to scan.
        corrs: Per-neuron validation correlations. NaN entries are rejected.
        make_plots: If True, print a running count and the correlation, and
            plot each map and its feature weights.

    Returns:
        ``(goods, all_good_rfs)``: neuron indices and their normalised maps.
    """
    side = int(np.sqrt(trained_model.w_s.shape[2]))
    counter = 1
    goods, all_good_rfs = [], []
    for n in range(n_neurons):
        if not corrs[n] > 0:  # written this way so NaN correlations are skipped too
            continue
        rf = np.abs(trained_model.w_s[n].squeeze().detach().cpu().numpy().reshape(side, side))
        counter += 1
        goods.append(n)
        all_good_rfs.append((rf + np.abs(rf.min())) / (np.abs(rf.max()) + np.abs(rf.min())))
        if make_plots:
            print(counter, corrs[n])
            plt.imshow(rf, cmap='seismic')
            plt.colorbar()
            plt.show()
            plt.plot(trained_model.w_f[0, n].squeeze().detach().cpu().numpy())
            plt.show()
    return np.array(goods), np.array(all_good_rfs)


def gaussian_RFs(gen_path, all_good_rfs, goods, corrs, all_corrs, pixperdeg, stim_size,
                 mask_size, trained_model, name, roi, shift_x=0, shift_y=0):
    """Fit 2-D Gaussians to the RF maps, save the summary and return binary masks.

    For each selected neuron this computes:
      * RF centre ``centrex``/``centrey`` in stimulus pixels, relative to the
        stimulus centre offset by ``(shift_x, shift_y)``. ``centrey`` grows upwards.
      * RF width/height ``szx``/``szy`` in stimulus pixels, and their mean in
        degrees ``szdeg``.
      * A ``mask_size`` x ``mask_size`` boolean mask of the fitted ellipse.
      * The sparsity index of the feature weights ``w_f[0, unit]``.

    The results are saved to ``<gen_path><name><roi>_good_rfs.mat``, and
    histograms of the RF fits and of sparsity are shown.

    Returns:
        Array of masks with shape ``(len(goods), mask_size, mask_size)``.
    """
    grid = all_good_rfs[0].shape[0]
    # RF-grid pixels per degree. Previously this was divided by the whole
    # all_good_rfs array instead of stim_size.
    pixperdeg_reduced = grid / stim_size * pixperdeg
    scaling_f = stim_size / grid
    scaling_mask = mask_size / grid
    # NOTE(review): added to each mask semi-axis as padding; confirm intended units.
    pixperdeg_reduced_mask = grid / mask_size * pixperdeg
    fovea_x = stim_size / 2 - shift_x
    fovea_y = stim_size / 2 - shift_y

    centrex, centrey, szx, szy, szdeg = [], [], [], [], []
    masks, sparsity, all_feats = [], [], []
    ax = plt.gca()
    ax.set_title('Spatial RFs:')
    for n in range(len(goods)):
        data = all_good_rfs[n]
        # Parameter order follows helper.fitgaussian; x and y are image column/row.
        _height, y, x, width_y, width_x = fitgaussian(data)
        plt.imshow(data, cmap=plt.cm.gist_earth_r)
        ax.add_patch(Ellipse((x, y), 2 * _CI_FACTOR * width_x, 2 * _CI_FACTOR * width_y,
                             edgecolor='r', facecolor='None', clip_on=True))
        centrex.append(x * scaling_f - fovea_x)
        centrey.append(fovea_y - y * scaling_f)
        szx.append(2 * _CI_FACTOR * width_x * scaling_f)
        szy.append(2 * _CI_FACTOR * width_y * scaling_f)
        szdeg.append(np.mean((2 * _CI_FACTOR * width_x / pixperdeg_reduced,
                              2 * _CI_FACTOR * width_y / pixperdeg_reduced)))

        # Binary mask: rasterise a unit circle scaled into the fitted ellipse.
        mask = np.zeros(shape=(mask_size, mask_size), dtype="bool")
        circ = shapely.geometry.Point((x * scaling_mask, y * scaling_mask)).buffer(1)
        ell = shapely.affinity.scale(circ,
                                     _CI_FACTOR * width_x * scaling_mask + pixperdeg_reduced_mask,
                                     _CI_FACTOR * width_y * scaling_mask + pixperdeg_reduced_mask)
        ell_coords = np.array(list(ell.exterior.coords))
        cols, rows = polygon(ell_coords[:, 0], ell_coords[:, 1], mask.shape)
        mask[rows, cols] = True
        masks.append(mask)

        feats = trained_model.w_f[0, goods[n]].squeeze().detach().cpu().numpy()
        sparsity.append(_sparsity_index(feats))
        all_feats.append(feats)
    plt.show()

    sparsity = np.array(sparsity)
    _plot_index_histogram(sparsity, 'Sparsity index', 'Feature sparseness:')

    results = {'goods': goods, 'all_good_rfs': all_good_rfs, 'cross_val_corr': corrs,
               'layers_cross_val_corr': np.moveaxis(all_corrs, 0, -1),
               'centrex': np.array(centrex), 'centrey': np.array(centrey),
               'szx': np.array(szx), 'szy': np.array(szy), 'szdeg': np.array(szdeg),
               'all_feats': np.array(all_feats), 'sparsity': sparsity}
    savemat(gen_path + name + roi + '_good_rfs' + '.mat', results)
    return np.array(masks)


def generate_MEIS(gen_path, trained_model, name, goods=None, roi='', sign=1, batch_n=1):
    """Synthesise and save MEIs (``sign=1``) or MIIs (``sign=-1``) with lucent.

    One ``render_vis`` run is made per unit in ``goods`` (default: all
    ``trained_model.n_neurons`` units), with ``batch_n`` images per run. The
    images are saved as ``<unit>_<i>.bmp`` under
    ``<gen_path>{MEIs|MIIs}/<name><roi>_fft_decorr_lr_1e-2_l2_1e-3/``. The model
    is moved to ``cuda:0`` when available and put in eval mode.

    Raises:
        ValueError: if ``sign`` is neither 1 nor -1.
    """
    if sign == 1:
        sign_label = 'MEIs/'
    elif sign == -1:
        sign_label = 'MIIs/'
    else:
        raise ValueError('The variable sign must be either 1 (for MEIs) or -1 (for MIIs)')
    out_dir = gen_path + sign_label + name + roi + _MEI_DIR_SUFFIX
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    if goods is None:
        goods = np.arange(trained_model.n_neurons)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trained_model.to(device).eval()
    all_transforms = [
        transform.jitter(8),
        transform.random_scale([n / 1000. for n in range(975, 1050)]),
        transform.random_rotate(list(range(-5, 5))),
        transform.jitter(4),
    ]
    batch_param_f = lambda: param.image(128, batch=batch_n, fft=True, decorrelate=True)
    cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
    for unit in goods:
        obj = objectives.channel("output", int(unit)) * sign
        images = render.render_vis(trained_model, obj, batch_param_f, cppn_opt,
                                   transforms=all_transforms, show_inline=True, thresholds=(50,))
        for i, img in enumerate(images[0]):
            _save_render(img, out_dir + str(unit) + '_' + str(i) + '.bmp')


def generate_surround_MEIS(gen_path, trained_model, name, goods=None, roi=''):
    """Synthesise excitatory/inhibitory surround images around each unit's MEI.

    For each unit:
      1. Render its MEI.
      2. Build a mask from ``|w_s[unit]|``: upsample it, keep values above
         3 std, then smooth.
      3. Optimise the unit's objective twice with ``occlude_pic(mei, mask)``
         prepended to the transforms, first negated, then positive.

    Results are saved as ``<unit>_negative_Surround.bmp`` and
    ``<unit>_positive_Surround.bmp`` under
    ``<gen_path>surrMEIs/<name><roi>_fft_decorr_lr_1e-2_l2_1e-3/``.
    """
    out_dir = gen_path + 'surrMEIs/' + name + roi + _MEI_DIR_SUFFIX
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    if goods is None:
        goods = np.arange(trained_model.n_neurons)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trained_model.to(device).eval()
    # Kept separate from the surround transforms: the old code overwrote this
    # list, so later MEIs were rendered through a previous unit's occlusion.
    mei_transforms = [
        transform.pad(2),
        transform.jitter(4),
        transform.random_scale([n / 1000. for n in range(975, 1050)]),
        transform.random_rotate(list(range(-5, 5))),
        transforms.Grayscale(3),
    ]
    batch_param_f = lambda: param.image(128, batch=1, fft=True, decorrelate=True)
    cppn_opt = lambda params: torch.optim.Adam(params, 1e-2, weight_decay=1e-3)
    for goods_val in goods:
        unit = int(goods_val)
        mei = render.render_vis(trained_model, objectives.channel("output", unit), batch_param_f,
                                cppn_opt, transforms=mei_transforms, show_inline=True,
                                thresholds=(50,))
        # Spatial weights of this unit (previously indexed by loop position).
        mask = trained_model.w_s[unit].squeeze()
        side = int(np.sqrt(mask.shape[0]))
        mask = torch.reshape(mask, (side, side))
        mask = np.abs(nnf.interpolate(mask[None, None, :], mei[0].shape[1]).squeeze().cpu().detach().numpy())
        mask = smoothing(mask > (np.std(mask) * 3), 1)
        mask = np.repeat(mask[:, :, np.newaxis], 3, 2)
        surround_transforms = [
            occlude_pic(mei[0], mask, device),
            transform.pad(2),
            transform.jitter(4),
            transform.random_scale([n / 1000. for n in range(975, 1050)]),
            transform.random_rotate(list(range(-5, 5))),
            transforms.Grayscale(3),
        ]
        for negate, suffix in ((True, '_negative_Surround.bmp'), (False, '_positive_Surround.bmp')):
            obj = objectives.channel("output", unit)
            if negate:
                obj = -obj
            result = render.render_vis(trained_model, obj, batch_param_f, cppn_opt,
                                       transforms=surround_transforms, show_inline=True,
                                       thresholds=(50,))
            _save_render(result[0][0], out_dir + str(goods_val) + suffix)


def ori_tuning(gen_path, stim_size, wavelengths, orientations, phases, contrasts,
               trained_model, name, goods, roi):
    """Measure gabor tuning (orientation, spatial frequency, contrast) of ``goods``.

    Gabors cover the full grid wavelengths x orientations x phases x contrasts,
    with ``orientations`` in radians. They are centred on the median RF centre
    read from ``<name><roi>_good_rfs.mat``. For each neuron this computes:
      * An orientation index comparing the normalised peak response with the
        response at the orientation 90 degrees away, with the other parameters
        held at the peak.
      * The stimulus parameters of the maximal response.
      * Orientation (normalised), spatial-frequency (raw) and mean contrast
        tuning curves.

    The results are saved to ``<gen_path><name><roi>_ori_sel.mat``.
    """
    rf_data = _load_mat(gen_path + name + roi + '_good_rfs' + '.mat')
    centrex = rf_data['centrex'].astype(int)[0]
    centrey = rf_data['centrey'].astype(int)[0]
    # NOTE(review): unlike size_tuning/get_full_resps there is no ImageNet
    # Normalize step here (it was commented out originally); confirm intent.
    preprocess = transforms.Compose([transforms.ToPILImage(),
                                     transforms.Resize([224, 224]),
                                     transforms.ToTensor()])
    px_x = np.median(centrex)
    px_y = np.median(centrey)

    # Loop nesting order must match the reshape of the responses below.
    stimuli, all_params = [], []
    for wl in wavelengths:
        for ori in orientations:
            for ph in phases:
                for con in contrasts:
                    stimuli.append(gabor_patch([stim_size, stim_size], pos_yx=[px_y, px_x],
                                               radius=stim_size, wavelength=wl, orientation=ori,
                                               phase=ph, min_brightness=0.5 - con / 2,
                                               max_brightness=0.5 + con / 2))
                    all_params.append([wl, ori, ph, con])
    all_params = np.array(all_params)

    stimuli = np.array(stimuli).squeeze()
    device = _model_device(trained_model)
    all_resps_tot = []
    with torch.no_grad():
        for stim in stimuli:
            img = preprocess(np.moveaxis(stim * 255, 0, -1).squeeze().astype(np.uint8)).to(device)
            all_resps_tot.append(trained_model(img[None]).squeeze()[goods].cpu().numpy())
    all_resps_tot = np.array(all_resps_tot)

    grid_shape = (len(wavelengths), len(orientations), len(phases), len(contrasts))
    ori_deg = 180 * np.asarray(orientations) / math.pi
    OI_sel, BEST_params, TUN_curves, SFS_curves, CNTR_curves = [], [], [], [], []
    for n in range(len(goods)):
        all_resps = all_resps_tot[:, n].reshape(grid_shape)
        norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
        # Grid position of the peak; if tied, the last one in C order.
        max_idx = np.array(np.where(norm_resp == np.max(norm_resp)))[:, -1]
        # Orthogonal orientation: +90 deg if present in the grid, else -90 deg.
        # isclose avoids missing it through float rounding of the degree values.
        peak_deg = ori_deg[max_idx[1]]
        ortho_idx = np.flatnonzero(np.isclose(ori_deg, peak_deg + 90))
        if ortho_idx.size == 0:
            ortho_idx = np.flatnonzero(np.isclose(ori_deg, peak_deg - 90))
        ortho_resp = norm_resp[max_idx[0], ortho_idx, max_idx[2], max_idx[3]]
        chn_sel = (norm_resp.max() - ortho_resp) / (norm_resp.max() + ortho_resp)
        OI_sel.append(np.squeeze(chn_sel))
        # Best parameters: this neuron's own maximal response (not any neuron's).
        BEST_params.append(all_params[int(np.argmax(all_resps_tot[:, n])), :])
        TUN_curves.append(np.squeeze(norm_resp[max_idx[0], :, max_idx[2], max_idx[3]]))
        SFS_curves.append(np.squeeze(all_resps[:, max_idx[1], max_idx[2], max_idx[3]]))
        CNTR_curves.append(np.mean(norm_resp, axis=(0, 1, 2)))
    OI_sel = np.array(OI_sel)

    _plot_index_histogram(OI_sel, 'Orientation selectivity index', 'Orientation tuning:')

    results = {'goods': goods, 'OI_sel': OI_sel, 'SFS_curves': np.array(SFS_curves),
               'TUN_curves': np.array(TUN_curves), 'CNTR_curves': np.array(CNTR_curves),
               'BEST_params': np.array(BEST_params)}
    savemat(gen_path + name + roi + '_ori_sel' + '.mat', results)


def size_tuning(gen_path, stim_size, radii, trained_model, name, goods, roi):
    """Measure size tuning and surround suppression of ``goods``.

    Each neuron is shown gabors of each radius in ``radii``, truncated to int.
    Each gabor is centred on the neuron's own RF centre from
    ``<name><roi>_good_rfs.mat``. Wavelength, orientation and phase come from
    ``BEST_params`` in ``<name><roi>_ori_sel.mat``. The suppression index
    compares the normalised response at the preferred radius with the response
    at the largest radius; it is 0 when the two are equal.

    The results are saved to ``<gen_path><name><roi>_size_sel.mat``.
    """
    rf_data = _load_mat(gen_path + name + roi + '_good_rfs' + '.mat')
    centrex = rf_data['centrex'].astype(int)[0]
    centrey = rf_data['centrey'].astype(int)[0]
    radii_f = radii.astype(int)
    BEST_params = _load_mat(gen_path + name + roi + '_ori_sel' + '.mat')['BEST_params']
    preprocess = transforms.Compose([transforms.ToPILImage(),
                                     transforms.Resize([224, 224]),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)])
    device = _model_device(trained_model)

    SSI_sel, BEST_size, S_TUN_curves = [], [], []
    for n in tqdm(range(len(goods)), disable=False):
        params = BEST_params[n, :]
        stimuli = np.array([
            gabor_patch([stim_size, stim_size], pos_yx=[centrey[n], centrex[n]], radius=r,
                        wavelength=params[0], orientation=params[1], phase=params[2])
            for r in radii_f
        ]).squeeze()
        all_resps = []
        with torch.no_grad():
            for stim in stimuli:
                img = preprocess(np.moveaxis(stim * 255, 0, -1).squeeze().astype(np.uint8)).to(device)
                all_resps.append(trained_model(img[None]).squeeze()[goods[n]].cpu().numpy())
        all_resps = np.array(all_resps)

        best = int(np.argmax(all_resps))  # first radius with the maximal response
        BEST_size.append(radii[best])
        norm_resp = (all_resps - np.min(all_resps)) / (np.max(all_resps) - np.min(all_resps))
        max_resp = norm_resp[best]
        fin_resp = norm_resp[-1]
        if max_resp == fin_resp:
            chn_sel = 0
        else:
            chn_sel = (max_resp - fin_resp) / (max_resp + fin_resp)
        SSI_sel.append(np.squeeze(chn_sel))
        S_TUN_curves.append(norm_resp)
    SSI_sel = np.array(SSI_sel)

    _plot_index_histogram(SSI_sel, 'Surround suppression index', 'Surround suppression:')

    results = {'goods': goods, 'SSI_sel': SSI_sel, 'S_TUN_curves': np.array(S_TUN_curves),
               'BEST_size': np.array(BEST_size)}
    savemat(gen_path + name + roi + '_size_sel' + '.mat', results)


def get_full_resps(gen_path, trained_model, goods, name, roi, folder):
    """Compute responses of ``goods`` to every image in ``folder`` without spatial masking.

    NOTE: this permanently replaces ``trained_model.w_s`` with a frozen
    all-ones parameter and moves the model to the CPU.

    All images (``torchvision.datasets.ImageFolder`` layout) are loaded as one
    batch, resized to 224x224 and ImageNet-normalised.

    Returns:
        ``(resps, img_data)``. ``resps`` has shape ``(n_images, len(goods))``.
        The result is also saved to ``<gen_path><name><roi>_responses.mat``.
    """
    trained_model.to('cpu').eval()
    trained_model.w_s = torch.nn.Parameter(torch.ones(trained_model.w_s.shape), requires_grad=False)
    preprocess = transforms.Compose([transforms.Resize([224, 224]),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)])
    dataset = datasets.ImageFolder(folder, transform=preprocess)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset.imgs), shuffle=False)
    img_data, _labels = next(iter(dataloader))
    with torch.no_grad():
        resps = trained_model(img_data).squeeze().numpy()[:, goods]
    savemat(gen_path + name + roi + '_responses' + '.mat', {'goods': goods, 'responses': resps})
    return resps, img_data


def _realign(values, goods, goods_ok, tot, max_n):
    """Scatter per-neuron rows back into a NaN-filled array with ``max_n`` rows.

    Arrays with one row per selected neuron (``len(goods)`` rows) go to rows
    ``goods_ok``. All other arrays go to rows ``tot``.
    """
    shape = np.array(values.shape)
    shape[0] = max_n
    realigned = np.full(shape, np.nan)
    if values.shape[0] == goods.shape[0]:
        realigned[goods_ok] = values
    else:
        realigned[tot] = values
    return realigned


def output(gen_path, name, roi, goods, good_neurons, resps=False):
    """Merge the saved RF, orientation, size (and optionally response) results.

    The results went through two selections: ``good_neurons == 1`` (reliable
    neurons) and then ``goods`` (indices into those with positive validation
    correlation). Every non-metadata variable from ``_good_rfs``, ``_ori_sel``,
    ``_size_sel`` and, if ``resps``, ``_responses`` is re-indexed so its first
    axis has one NaN-padded row per entry of ``good_neurons``. Response matrices
    are transposed first. The merged dict, plus ``good_chns``, is saved to
    ``<gen_path><name><roi>_OUTPUT.mat``.
    """
    flags = np.asarray(good_neurons).squeeze()
    max_n = flags.size
    # flatnonzero stays 1-D even when a single neuron is flagged.
    tot = np.flatnonzero(flags == 1)
    goods_ok = tot[goods]

    data_dict = {}
    for suffix in ('_good_rfs', '_ori_sel', '_size_sel'):
        for k, v in loadmat(gen_path + name + roi + suffix + '.mat').items():
            if k[0] != '_':
                data_dict[k] = _realign(np.array(v).squeeze(), goods, goods_ok, tot, max_n)
    data_dict['good_chns'] = good_neurons
    if resps:
        for k, v in loadmat(gen_path + name + roi + '_responses' + '.mat').items():
            if k[0] != '_':
                data_dict[k] = _realign(np.transpose(np.array(v).squeeze()),
                                        goods, goods_ok, tot, max_n)
    savemat(gen_path + name + roi + '_OUTPUT' + '.mat', data_dict)