neural_model.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import torch
  2. from torch import nn, Tensor
  3. import torch.nn.functional as F
  4. from typing import Dict, Iterable, Callable
  5. import numpy as np
  6. from helper import makeGaussian, FeatureExtractor, fix_parameters, load_sta
  7. class Model(nn.Module):
  8. """
  9. Model of neural responses
  10. """
  11. def __init__(self,pretrained_model,layer,n_neurons,device=None,debug=False):
  12. super(Model, self).__init__()
  13. self.layer = layer
  14. self.debug = debug
  15. self.ann = fix_parameters(pretrained_model)
  16. self.inc_features = FeatureExtractor(self.ann, layers=[self.layer])
  17. dummy_input = torch.ones(1, 3, 224, 224)
  18. dummy_feats = self.inc_features(dummy_input)
  19. self.mod_shape = dummy_feats[self.layer].shape
  20. if self.debug:
  21. self.w_s = torch.nn.Parameter(torch.randn(n_neurons, 1, self.mod_shape[-1]*self.mod_shape[-1], 1,
  22. device=device))
  23. else:
  24. self.w_s = torch.nn.Parameter(torch.randn(n_neurons, 1, self.mod_shape[-1]*self.mod_shape[-1], 1,
  25. device=device,requires_grad=True))
  26. self.w_f = torch.nn.Parameter(torch.randn(1, n_neurons, 1, self.mod_shape[1],
  27. device=device, requires_grad=True))
  28. self.ann_bn = torch.nn.BatchNorm2d(self.mod_shape[1],momentum=0.9,eps=1e-4,affine=False)
  29. self.output = torch.nn.Identity()
  30. def forward(self,x):
  31. x = self.inc_features(x)
  32. x = x[self.layer]
  33. x = F.relu(self.ann_bn(x))
  34. x = x.view(x.shape[0],x.shape[1],x.shape[2]*x.shape[3],1)
  35. x = x.permute(0,-1,2,1)
  36. x = F.conv2d(x,torch.abs(self.w_s))
  37. x = torch.mul(x,self.w_f)
  38. x = torch.sum(x,-1,keepdim=True)
  39. return self.output(x)
  40. def initialize(self):
  41. nn.init.xavier_normal_(self.w_f)
  42. if self.debug:
  43. temp = np.ndarray.flatten(makeGaussian(self.mod_shape[-1], fwhm = self.mod_shape[-1]/20,
  44. center=[self.mod_shape[-1]*.3,self.mod_shape[-1]*.7]))
  45. for i in range(len(self.w_s)):
  46. self.w_s[i,0,:,0] = torch.tensor(temp)
  47. else:
  48. nn.init.xavier_normal_(self.w_s)