import torch from torch import nn, Tensor import torch.nn.functional as F from typing import Dict, Iterable, Callable import numpy as np from helper import makeGaussian, FeatureExtractor, fix_parameters, load_sta class Model(nn.Module): """ Model of neural responses """ def __init__(self,pretrained_model,layer,n_neurons,device=None,debug=False): super(Model, self).__init__() self.layer = layer self.debug = debug self.ann = fix_parameters(pretrained_model) self.inc_features = FeatureExtractor(self.ann, layers=[self.layer]) dummy_input = torch.ones(1, 3, 224, 224) dummy_feats = self.inc_features(dummy_input) self.mod_shape = dummy_feats[self.layer].shape if self.debug: self.w_s = torch.nn.Parameter(torch.randn(n_neurons, 1, self.mod_shape[-1]*self.mod_shape[-1], 1, device=device)) else: self.w_s = torch.nn.Parameter(torch.randn(n_neurons, 1, self.mod_shape[-1]*self.mod_shape[-1], 1, device=device,requires_grad=True)) self.w_f = torch.nn.Parameter(torch.randn(1, n_neurons, 1, self.mod_shape[1], device=device, requires_grad=True)) self.ann_bn = torch.nn.BatchNorm2d(self.mod_shape[1],momentum=0.9,eps=1e-4,affine=False) self.output = torch.nn.Identity() def forward(self,x): x = self.inc_features(x) x = x[self.layer] x = F.relu(self.ann_bn(x)) x = x.view(x.shape[0],x.shape[1],x.shape[2]*x.shape[3],1) x = x.permute(0,-1,2,1) x = F.conv2d(x,torch.abs(self.w_s)) x = torch.mul(x,self.w_f) x = torch.sum(x,-1,keepdim=True) return self.output(x) def initialize(self): nn.init.xavier_normal_(self.w_f) if self.debug: temp = np.ndarray.flatten(makeGaussian(self.mod_shape[-1], fwhm = self.mod_shape[-1]/20, center=[self.mod_shape[-1]*.3,self.mod_shape[-1]*.7])) for i in range(len(self.w_s)): self.w_s[i,0,:,0] = torch.tensor(temp) else: nn.init.xavier_normal_(self.w_s)