helper.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import torch
  2. from torch import nn, Tensor, FloatTensor
  3. import torch.nn.functional as F
  4. from typing import Dict, Iterable, Callable
  5. from PIL import Image
  6. import numpy as np
  7. from math import exp
  8. from scipy import optimize
  9. from lucent.optvis import param, transform
  10. from lucent.optvis.objectives import wrap_objective
  11. def makeGaussian(size, fwhm = None, center=None):
  12. x = np.arange(0, size, 1, float)
  13. y = x[:,np.newaxis]
  14. if center is None:
  15. x0 = y0 = size // 2
  16. else:
  17. x0 = center[0]
  18. y0 = center[1]
  19. return np.exp(-4*np.log(2) * ((x-x0)**2 + (y-y0)**2) / fwhm**2)
  20. class FeatureExtractor(nn.Module):
  21. def __init__(self, model: nn.Module, layers: Iterable[str]):
  22. super().__init__()
  23. self.model = model
  24. self.layers = layers
  25. self._features = {layer: torch.empty(0) for layer in layers}
  26. for layer_id in layers:
  27. layer = dict([*self.model.named_modules()])[layer_id]
  28. layer.register_forward_hook(self.save_outputs_hook(layer_id))
  29. def save_outputs_hook(self, layer_id: str) -> Callable:
  30. def fn(_, __, output):
  31. self._features[layer_id] = output
  32. return fn
  33. def forward(self, x: Tensor):
  34. _ = self.model(x)
  35. return self._features
  36. def fix_parameters(module, value=None):
  37. """
  38. Set requires_grad = False for all parameters.
  39. If a value is passed all parameters are fixed to the value.
  40. """
  41. for param in module.parameters():
  42. if value:
  43. param.data = FloatTensor(param.data.size()).fill_(value)
  44. param.requires_grad = False
  45. return module
  46. def smoothing_laplacian_loss(data, device, weight=1e-3, L=None):
  47. if L is None:
  48. L = torch.tensor([[0,-1,0],[-1,-4,-1],[0,-1,0]],device=device)
  49. temp = torch.reshape(data.squeeze(), [data.squeeze().shape[0],
  50. np.sqrt(data.squeeze().shape[1]).astype('int'),
  51. np.sqrt(data.squeeze().shape[1]).astype('int')])
  52. temp = torch.square(F.conv2d(temp.unsqueeze(1),L.unsqueeze(0).unsqueeze(0).float(),
  53. padding=5))
  54. return weight * torch.sqrt(torch.sum(temp))
  55. def sparsity_loss(data_1, data_2, weight=1e-3):
  56. return weight * torch.sum(torch.sum(torch.abs(data_1))) * torch.norm(torch.sum(torch.abs(data_2)))
  57. def smoothing_laplacian_loss_v2(data, device, weight=1e-3, L=None):
  58. if L is None:
  59. L = torch.tensor([[0,-1,0],[-1,-4,-1],[0,-1,0]],device=device)
  60. L = L.unsqueeze(0).unsqueeze(0)
  61. temp = F.conv2d(data.permute([3,0,1,2]),L.repeat_interleave(data.shape[0],1).float())
  62. return weight * torch.mean(torch.sum(torch.square(temp),[1,2,3]))
  63. def sparsity_loss_v2(data, weight=1e-3):
  64. return weight * torch.mean(torch.sum(torch.sqrt(torch.sum(torch.square(data),[0,1])),1))
  65. def l1_loss(data, weight=1e-3):
  66. return weight * torch.mean(torch.sum(torch.abs(data),[0,1,2]))
  67. def sta(neural_data,img_data,size=20):
  68. sta = []
  69. for c in range(neural_data.shape[1]):
  70. res = []
  71. for n in range(neural_data.shape[0]):
  72. temp = (neural_data[n,c])
  73. res.append((temp*torch.mean(F.interpolate(img_data[n].unsqueeze(0),
  74. size=[size,size]).squeeze(),0)).detach().numpy())
  75. res = np.asarray(res)
  76. res = np.sum(res,0)
  77. res = (res + np.abs(res.min()))
  78. sta.append(res/res.max())
  79. return np.asarray(sta)
  80. def load_sta(sta,mod_shape,device):
  81. out = []
  82. for i in range(len(sta)):
  83. temp = torch.tensor(sta[i]).to(device)
  84. temp = F.interpolate(temp.unsqueeze(0).unsqueeze(0), size=[mod_shape[2],mod_shape[3]]).squeeze()
  85. temp = temp.unsqueeze(0)
  86. out.append(temp.repeat_interleave(mod_shape[1],0))
  87. out = torch.stack(out)
  88. return out.permute([1,2,3,0])
  89. def load(path,sz):
  90. return np.array(Image.open(path).resize((sz,sz))) / 255
  91. def load_crop(path,sz):
  92. im = Image.open(path)
  93. width, height = im.size # Get dimensions
  94. min_size = np.min([width,height])
  95. left = (width - min_size)/2
  96. top = (height - min_size)/2
  97. right = (width + min_size)/2
  98. bottom = (height + min_size)/2
  99. # Crop the center of the image
  100. im = im.crop((left, top, right, bottom))
  101. return np.array(im.resize((sz,sz))) / 255
  102. def mean_L1(a, b):
  103. return torch.abs(a-b).mean()
  104. @wrap_objective()
  105. def activation_difference(layer_names, activation_loss_f=mean_L1, transform_f=None):
  106. def inner(T):
  107. # first we collect the (constant) activations of image we're computing the difference to
  108. image_activations = [T(layer_name)[1] for layer_name in layer_names]
  109. if transform_f is not None:
  110. image_activations = [transform_f(act) for act in image_activations]
  111. # we also set get the activations of the optimized image which will change during optimization
  112. optimization_activations = [T(layer)[0] for layer in layer_names]
  113. if transform_f is not None:
  114. optimization_activations = [transform_f(act) for act in optimization_activations]
  115. # we use the supplied loss function to compute the actual losses
  116. losses = [activation_loss_f(a, b) for a, b in zip(image_activations, optimization_activations)]
  117. return sum(losses)
  118. return inner
  119. def gram_matrix(features, normalize=True):
  120. C, H, W = features.shape
  121. features = features.view(C, -1)
  122. gram = torch.matmul(features, torch.transpose(features, 0, 1))
  123. if normalize:
  124. gram = gram / (H * W)
  125. return gram
  126. def gaussian(height, center_x, center_y, width_x, width_y):
  127. """Returns a gaussian function with the given parameters"""
  128. width_x = float(width_x)
  129. width_y = float(width_y)
  130. return lambda x,y: height*np.exp(
  131. -(((center_x-x)/width_x)**2+((center_y-y)/width_y)**2)/2)
  132. def moments(data):
  133. """Returns (height, x, y, width_x, width_y)
  134. the gaussian parameters of a 2D distribution by calculating its
  135. moments """
  136. total = data.sum()
  137. X, Y = np.indices(data.shape)
  138. x = (X*data).sum()/total
  139. y = (Y*data).sum()/total
  140. col = data[:, int(y)]
  141. width_x = np.sqrt(np.abs((np.arange(col.size)-x)**2*col).sum()/col.sum())
  142. row = data[int(x), :]
  143. width_y = np.sqrt(np.abs((np.arange(row.size)-y)**2*row).sum()/row.sum())
  144. height = data.max()
  145. return height, x, y, width_x, width_y
  146. def fitgaussian(data):
  147. """Returns (height, x, y, width_x, width_y)
  148. the gaussian parameters of a 2D distribution found by a fit"""
  149. params = moments(data)
  150. errorfunction = lambda p: np.ravel(gaussian(*p)(*np.indices(data.shape)) -
  151. data)
  152. p, success = optimize.leastsq(errorfunction, params)
  153. return p
  154. def gabor_patch(shape, pos_yx = None, radius = None, wavelength = None, orientation = None,
  155. phase = None, min_brightness = 0, max_brightness = 1):
  156. if pos_yx is None:
  157. pos_yx = (shape[0] / 2, shape[1] / 2)
  158. assert len(shape) >= 2
  159. H = shape[len(shape) - 2]
  160. W = shape[len(shape) - 1]
  161. assert len(pos_yx) == 2
  162. # "Bounding box"
  163. xmax = W
  164. ymax = H
  165. xmin = 0
  166. ymin = 0
  167. (x, y) = np.meshgrid(np.arange(xmin, xmax), np.arange(ymin, ymax))
  168. # Rotation (around center of image)
  169. # The value along the x of the grating is constant
  170. y_theta = (x * np.sin(orientation) +
  171. y * np.cos(orientation))
  172. d = np.sqrt((y - pos_yx[0]) ** 2 + (x - pos_yx[1]) ** 2) - radius
  173. d[d < - 0.5] = - 0.5
  174. d[d > 0.5] = 0.5
  175. envelope = 0.5 - d
  176. # initially make gratings vary from 0 to 1
  177. gratings = (np.cos(2 * np.pi / wavelength * y_theta + phase) + 1) / 2
  178. # make gratings between min_brightness and max_brightness
  179. gratings = gratings * (max_brightness - min_brightness) + min_brightness
  180. gb = np.multiply(
  181. envelope,
  182. gratings)
  183. # set the background color
  184. gb += (1-envelope) * ((max_brightness - min_brightness) / 2 + min_brightness)
  185. gb = np.reshape(gb, [1, 1, H, W])
  186. gb = np.repeat(gb, 3, 1)
  187. return gb
  188. class occlude_pic(object):
  189. def __init__(self, occluder, mask, device = 'cpu'):
  190. self.occluder_t = torch.tensor(np.transpose(occluder[0], [2, 0, 1])).float().to(device)
  191. self.mask_t = torch.tensor(np.transpose(mask, [2, 0, 1])).float().to(device)
  192. def __call__(self,img):
  193. occluded_input = img[0] * (1-self.mask_t) + self.occluder_t * self.mask_t
  194. return occluded_input[None,:,:,:]