spm_nlsi_Newton.m 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. function [Ep,Cp,F] = spm_nlsi_Newton(M,U,Y)
  2. % Variational Lapalce for nonlinear models - Newton's method
  3. % FORMAT [Ep,Cp,F] = spm_nlsi_Newton(M,U,Y)
  4. %
  5. % Eplicit log-likihood model
  6. %__________________________________________________________________________
  7. %
  8. % M.L - log likelihood function @(P,M,U,Y)
  9. % P - free parameters
  10. % M - model
  11. %
  12. % M.P - starting estimates for model parameters [optional]
  13. % M.pE - prior expectation - E{P} of model parameters
  14. % M.pC - prior covariance - Cov{P} of model parameters
  15. %
  16. % U - inputs or causes
  17. % Y - output or response
  18. %
  19. % Parameter estimates
  20. %--------------------------------------------------------------------------
  21. % Ep - (p x 1) conditional expectation E{P|y}
  22. % Cp - (p x p) conditional covariance Cov{P|y}
  23. %
  24. % log evidence
  25. %--------------------------------------------------------------------------
  26. % F - [-ve] free energy F = log evidence = p(Y|pE,pC) = p(y|m)
  27. %
  28. %__________________________________________________________________________
  29. % Returns the moments of the posterior p.d.f. of the parameters of a
  30. % nonlinear model with a log likelihood function L(P,M,U,Y).
  31. %
  32. % Priors on the free parameters P are specified in terms of expectation pE
  33. % and covariance pC. This Variational Laplace scheme uses an explicit
  34. % (numerical) curvature to implement a gradient ascent on variational free
  35. % energy using Newton's method. An example of its application is provided at
  36. % the end of this routine using a simple general linear model. This example
  37. % eschews the mean field approximation aassociated with standard
  38. % inversions.
  39. %
  40. % For generic aspects of the scheme see:
  41. %
  42. % Friston K, Mattout J, Trujillo-Barreto N, Ashburner J, Penny W.
  43. % Variational free energy and the Laplace approximation.
  44. % NeuroImage. 2007 Jan 1;34(1):220-34.
  45. %__________________________________________________________________________
  46. % Copyright (C) 2001-2015 Wellcome Trust Centre for Neuroimaging
  47. % Karl Friston
  48. % $Id: spm_nlsi_Newton.m 6587 2015-11-02 10:29:49Z karl $
  49. % options
  50. %--------------------------------------------------------------------------
  51. try, M.nograph; catch, M.nograph = 0; end
  52. try, M.noprint; catch, M.noprint = 0; end
  53. try, M.Nmax; catch, M.Nmax = 128; end
  54. % converted to function handle
  55. %--------------------------------------------------------------------------
  56. L = spm_funcheck(M.L);
  57. % initial parameters
  58. %--------------------------------------------------------------------------
  59. try
  60. M.P; fprintf('\nParameter initialisation successful\n')
  61. catch
  62. M.P = M.pE;
  63. end
  64. % prior moments (assume uninformative priors if not specifed)
  65. %--------------------------------------------------------------------------
  66. pE = M.pE;
  67. try
  68. pC = M.pC;
  69. catch
  70. np = spm_length(M.pE);
  71. pC = speye(np,np)*exp(16);
  72. end
  73. % unpack covariance
  74. %--------------------------------------------------------------------------
  75. if isstruct(pC);
  76. pC = spm_diag(spm_vec(pC));
  77. end
  78. % dimension reduction of parameter space
  79. %--------------------------------------------------------------------------
  80. V = spm_svd(pC,0);
  81. % second-order moments (in reduced space)
  82. %--------------------------------------------------------------------------
  83. pC = V'*pC*V;
  84. ipC = inv(pC);
  85. % initialize conditional density
  86. %--------------------------------------------------------------------------
  87. p = V'*(spm_vec(M.P) - spm_vec(M.pE));
  88. Ep = spm_unvec(spm_vec(pE) + V*p,pE);
  89. % figure (unless disabled)
  90. %--------------------------------------------------------------------------
  91. if ~M.nograph, Fsi = spm_figure('GetWin','SI'); clf, end
  92. % Wariational Laplace
  93. %==========================================================================
  94. criterion = [0 0 0 0];
  95. C.F = -Inf; % free energy
  96. v = -4; % log ascent rate
  97. for k = 1:M.Nmax
  98. % time
  99. %----------------------------------------------------------------------
  100. tStart = tic;
  101. % Log-likelihood f, gradients; dfdp and curvature dfdpp
  102. %======================================================================
  103. [dfdpp,dfdp,f] = spm_diff(L,Ep,M,U,Y,[1 1],{V});
  104. dfdp = dfdp';
  105. dfdpp = full(spm_cat(dfdpp'));
  106. % enure prior bounds on curvature
  107. %----------------------------------------------------------------------
  108. [E,D] = eig(dfdpp);
  109. D = diag(D);
  110. dfdpp = E*diag(D.*(D < 0))*E';
  111. % condiitonal covariance
  112. %----------------------------------------------------------------------
  113. Cp = inv(ipC - dfdpp);
  114. % Fre energy: F(p) = log evidence - divergence
  115. %======================================================================
  116. F = f - p'*ipC*p/2 + spm_logdet(ipC*Cp)/2;
  117. G(k) = F;
  118. % record increases and reference log-evidence for reporting
  119. %----------------------------------------------------------------------
  120. if k > 1
  121. if ~M.noprint
  122. fprintf(' actual: %.3e (%.2f sec)\n',full(F - C.F),toc(tStart))
  123. end
  124. else
  125. F0 = F;
  126. end
  127. % if F has increased, update gradients and curvatures for E-Step
  128. %----------------------------------------------------------------------
  129. if F > C.F || k < 8
  130. % accept current estimates
  131. %------------------------------------------------------------------
  132. C.p = p;
  133. C.F = F;
  134. C.Cp = Cp;
  135. % E-Step: Conditional update of gradients and curvature
  136. %------------------------------------------------------------------
  137. dFdp = dfdp - ipC*p;
  138. dFdpp = dfdpp - ipC;
  139. % decrease regularization
  140. %------------------------------------------------------------------
  141. v = min(v + 1/2,4);
  142. str = 'EM:(+)';
  143. else
  144. % reset expansion point
  145. %------------------------------------------------------------------
  146. p = C.p;
  147. % and increase regularization
  148. %------------------------------------------------------------------
  149. v = min(v - 2,-4);
  150. str = 'EM:(-)';
  151. end
  152. % E-Step: update
  153. %======================================================================
  154. dp = spm_dx(dFdpp,dFdp,{v});
  155. p = p + dp;
  156. Ep = spm_unvec(spm_vec(pE) + V*p,pE);
  157. % Graphics
  158. %======================================================================
  159. if exist('Fsi', 'var')
  160. spm_figure('Select', Fsi)
  161. % trajectory in parameter space
  162. %------------------------------------------------------------------
  163. subplot(2,2,1)
  164. plot(0,0,'r.','MarkerSize',32), hold on
  165. col = [exp(-k/4) exp(-k) 1];
  166. try
  167. plot(V(1,:)*p,V(2,:)*p,'.','MarkerSize',32,'Color',col), hold on
  168. xlabel('1st parameter')
  169. ylabel('2nd parameter')
  170. catch
  171. plot(k,V(1,:)*p,'.','MarkerSize',32,'Color',col), hold on
  172. xlabel('Iteration')
  173. ylabel('1st parameter')
  174. end
  175. title('Trajectory','FontSize',16)
  176. grid on, axis square
  177. % trajectory in parameter space
  178. %------------------------------------------------------------------
  179. subplot(2,2,2)
  180. bar(full(G - F0),'c')
  181. xlabel('Iteration')
  182. ylabel('Log-evidence')
  183. title('Free energy','FontSize',16)
  184. grid on, axis square
  185. % subplot parameters
  186. %--------------------------------------------------------------
  187. subplot(2,2,3)
  188. bar(full(spm_vec(pE) + V*p))
  189. xlabel('Parameter')
  190. tstr = 'Conditional expectation';
  191. title(tstr,'FontSize',16)
  192. grid on, axis square
  193. % subplot parameters (eigenmodes)
  194. %------------------------------------------------------------------
  195. subplot(2,2,4)
  196. spm_plot_ci(p,Cp)
  197. xlabel('Parameter (eigenmodes)')
  198. title('Posterior deviations','FontSize',16)
  199. grid on, axis square
  200. drawnow
  201. end
  202. % convergence
  203. %----------------------------------------------------------------------
  204. dF = dFdp'*dp;
  205. if ~M.noprint
  206. fprintf('%-6s: %i %6s %-6.3e %6s %.3e ',str,k,'F:',full(C.F - F0),'dF predicted:',full(dF))
  207. end
  208. criterion = [(dF < 1e-1) criterion(1:end - 1)];
  209. if all(criterion)
  210. if ~M.noprint
  211. fprintf(' convergence\n')
  212. end
  213. break
  214. end
  215. end
  216. if exist('Fsi', 'var')
  217. spm_figure('Focus', Fsi)
  218. end
  219. % outputs
  220. %--------------------------------------------------------------------------
  221. Ep = spm_unvec(spm_vec(pE) + V*C.p,pE);
  222. Cp = V*C.Cp*V';
  223. F = C.F;
  224. return
  225. % NB: notes - illustrative application (a simple linear model)
  226. %==========================================================================
  227. % parameters P and design matrix U
  228. %--------------------------------------------------------------------------
  229. U = randn(32,2); % design matrix
  230. P.beta = [4;2]; % parameters of GLM
  231. P.pi = 2; % log precision
  232. % generate data
  233. %--------------------------------------------------------------------------
  234. Y = U*P.beta + exp(-P.pi/2)*randn(32,1);
  235. % model specification with log-likelihood function M.L
  236. %--------------------------------------------------------------------------
  237. M.L = @(P,M,U,Y) sum(log( spm_Npdf(Y, U*P.beta, exp(-P.pi)) ));
  238. M.pE = spm_zeros(P); % prior means (parameters)
  239. M.pC = eye(spm_length(P)); % prior variance (parameters)
  240. % Variational Laplace
  241. %--------------------------------------------------------------------------
  242. [Ep,Cp,F] = spm_nlsi_Newton(M,U,Y);
  243. % overlay true values on confidence intervals
  244. %--------------------------------------------------------------------------
  245. subplot(2,2,4),hold on
  246. bar(spm_vec(P),1/4)