function [Ep,qC,qh,F] = spm_nlsi_LS(M,U,Y)
% Bayesian inversion of a nonlinear model using (Laplacian) sampling
% FORMAT [Ep,Cp,Eh,F] = spm_nlsi_LS(M,U,Y)
%
% Dynamical MIMO models
%__________________________________________________________________________
%
% M.IS - function name f(P,M,U) - generative model
%        This function specifies the nonlinear model:
%        y = Y.y = IS(P,M,U) + X0*P0 + e
%        where e ~ N(0,C). For dynamic systems this would be an integration
%        scheme (e.g. spm_int). spm_int expects the following:
%
%    M.f  - f(x,u,P,M)
%    M.g  - g(x,u,P,M)
%      x  - state variables
%      u  - inputs or causes
%      P  - free parameters
%      M  - fixed functional forms and parameters in M
%
% M.FS - function name f(y,M) - feature selection
%        This [optional] function performs feature selection assuming the
%        generalized model y = FS(y,M) = FS(IS(P,M,U),M) + X0*P0 + e
%
% M.P  - starting estimates for model parameters [optional]
%
% M.pE - prior expectation  - E{P}   of model parameters
% M.pC - prior covariance   - Cov{P} of model parameters
%
% M.hE - prior expectation  - E{h}   of log-precision parameters
% M.hC - prior covariance   - Cov{h} of log-precision parameters
%
% U.u  - inputs
% U.dt - sampling interval
%
% Y.y  - outputs
% Y.dt - sampling interval for outputs
% Y.X0 - confounds or null space      (over size(y,1) bins or all vec(y))
% Y.Q  - q error precision components (over size(y,1) bins or all vec(y))
%
%
% Parameter estimates
%--------------------------------------------------------------------------
% Ep  - (p x 1) conditional expectation    E{P|y}
% Cp  - (p x p) conditional covariance     Cov{P|y}
% Eh  - (q x 1) conditional log-precisions E{h|y}
%
% log evidence
%--------------------------------------------------------------------------
% F   - [-ve] free energy F = log evidence = log p(y|f,g,pE,pC) = log p(y|m)
%
%__________________________________________________________________________
% Returns the moments of the posterior p.d.f. of the parameters of a
% nonlinear model specified by IS(P,M,U) under Gaussian assumptions.
% Usually, IS is an integrator of a dynamic MIMO input-state-output model
%
%              dx/dt = f(x,u,P)
%              y     = g(x,u,P) + X0*P0 + e
%
% A static nonlinear observation model with fixed input or causes u
% obtains when x = [], i.e.
%
%              y = g([],u,P) + X0*P0 + e
%
% but static nonlinear models are specified more simply using
%
%              y = IS(P,M,U) + X0*P0 + e
%
% Priors on the free parameters P are specified in terms of expectation pE
% and covariance pC.
%
% For generic aspects of the scheme see:
%
% Friston K, Mattout J, Trujillo-Barreto N, Ashburner J, Penny W.
% Variational free energy and the Laplace approximation.
% NeuroImage. 2007 Jan 1;34(1):220-34.
%
% This scheme handles complex data along the lines originally described in:
%
% Sheppard RJ, Jordan BP, and Grant EH.
% Least squares analysis of complex data with applications to permittivity
% measurements.
% J. Phys. D: Appl. Phys. 1970 3:1759-1764.
%
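% Example (illustrative sketch): for a hypothetical prediction function
% my_gen(P,M,U) returning an (ns x nr) array and two free parameters,
% the scheme could be called as follows (all values are placeholders):
%
%     M.IS = 'my_gen';                 % prediction (or integration) scheme
%     M.pE = struct('a',0,'b',0);      % prior expectation of parameters
%     M.pC = diag([1 1]);              % prior covariance of parameters
%     M.hE = 4;                        % prior expectation of log-precision
%     U    = struct('u',u,'dt',dt);    % exogenous inputs
%     Y    = struct('y',y,'dt',dt);    % observed responses
%
%     [Ep,Cp,Eh,F] = spm_nlsi_LS(M,U,Y);
%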
%__________________________________________________________________________
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: spm_nlsi_LS.m 5219 2013-01-29 17:07:07Z spm $

% figure (unless disabled)
%--------------------------------------------------------------------------
try
    M.nograph;
catch
    M.nograph = 0;
end
if ~M.nograph
    Fsi = spm_figure('GetWin','SI');
end

% check integrator
%--------------------------------------------------------------------------
try
    M.IS;
catch
    M.IS = 'spm_int';
end

% composition of feature selection and prediction (usually an integrator)
%--------------------------------------------------------------------------
if isfield(M,'FS')

    % FS(y,M)
    %----------------------------------------------------------------------
    try
        y  = feval(M.FS,Y.y,M);
        IS = inline([M.FS '(' M.IS '(P,M,U),M)'],'P','M','U');

    % FS(y)
    %----------------------------------------------------------------------
    catch
        y  = feval(M.FS,Y.y);
        IS = inline([M.FS '(' M.IS '(P,M,U))'],'P','M','U');
    end
else

    % FS(y) = y
    %----------------------------------------------------------------------
    y  = Y.y;
    IS = inline([M.IS '(P,M,U)'],'P','M','U');
end
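
% Note: inline is deprecated in recent MATLAB releases; an anonymous
% function composition would behave equivalently, e.g. (sketch only)
%     IS = @(P,M,U) feval(M.FS, feval(M.IS,P,M,U), M);
% the inline forms above are retained for backwards compatibility.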

% size of data (usually samples x channels)
%--------------------------------------------------------------------------
if iscell(y)
    ns = size(y{1},1);
else
    ns = size(y,1);
end
nr   = length(spm_vec(y))/ns;       % number of samples and responses
M.ns = ns;                          % store in M.ns for integrator

% initial states
%--------------------------------------------------------------------------
try
    M.x;
catch
    if ~isfield(M,'n'), M.n = 0; end
    M.x = sparse(M.n,1);
end

% input
%--------------------------------------------------------------------------
try
    U;
catch
    U = [];
end

% initial parameters
%--------------------------------------------------------------------------
try
    spm_vec(M.P) - spm_vec(M.pE);
    fprintf('\nParameter initialisation successful\n')
catch
    M.P = M.pE;
end

% time-step
%--------------------------------------------------------------------------
try
    Y.dt;
catch
    Y.dt = 1;
end

% precision components Q
%--------------------------------------------------------------------------
try
    Q = Y.Q;
    if isnumeric(Q), Q = {Q}; end
catch
    Q = spm_Ce(ns*ones(1,nr));      % default: one i.i.d. component per response
end
nh = length(Q);                     % number of precision components
nt = length(Q{1});                  % number of time bins
nq = nr*ns/nt;                      % for compact Kronecker form of M-step
h  = zeros(nh,1);                   % initialise hyperparameters

% confounds (if specified)
%--------------------------------------------------------------------------
try
    nb   = size(Y.X0,1);            % number of bins
    nx   = nr*ns/nb;                % number of blocks
    dfdu = kron(speye(nx,nx),Y.X0);
catch
    dfdu = sparse(ns*nr,0);
end

% hyperpriors - expectation
%--------------------------------------------------------------------------
try
    hE = M.hE;
    if length(hE) ~= nh
        hE = hE(1)*ones(nh,1);      % replicate a scalar prior over components
    end
catch
    hE = sparse(nh,1);
end

% prior moments
%--------------------------------------------------------------------------
pE = M.pE;
pC = M.pC;
nu = size(dfdu,2);                  % number of parameters (confounds)
np = size(pC,2);                    % number of parameters (effective)

% second-order moments (in reduced space)
%--------------------------------------------------------------------------
ipC = spm_inv(pC);

% initialize conditional density
%--------------------------------------------------------------------------
Eu = spm_pinv(dfdu)*spm_vec(y);
Ep = pE;

% precision and conditional covariance
%--------------------------------------------------------------------------
iS = sparse(0);
for i = 1:nh
    iS = iS + Q{i}*(exp(-16) + exp(hE(i)));
end
S  = spm_inv(iS);
iS = kron(speye(nq),iS);
qS = spm_sqrtm(pC/32);
qE = spm_vec(pE);
pE = spm_vec(pE);
y  = spm_vec(y);
np = length(qE);

% Sampling
%==========================================================================
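% The loop below implements a simple sampling scheme: candidate parameters
% are drawn from a Gaussian proposal centred on the current best estimate
% qE with scale qS; each sample is scored by its Gibbs energy G (a log
% likelihood term in the prediction error plus the log prior); the best
% sample so far becomes the conditional mode, and the samples are then
% re-weighted by exp(G - Gmax) to estimate the conditional covariance.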
Gmax = -Inf;
for k = 1:64

    % time
    %----------------------------------------------------------------------
    tic;

    % Gibbs sampling
    %======================================================================
    for i = 1:128

        % prediction
        %------------------------------------------------------------------
        P(:,i) = qE + qS*randn(np,1);
        R(:,i) = spm_vec(feval(IS,spm_unvec(P(:,i),M.pE),M,U));

        % prediction error
        %------------------------------------------------------------------
        ey = R(:,i) - y;
        ep = P(:,i) - pE;

        % Gibbs energy
        %------------------------------------------------------------------
        qh     = real(ey')*iS*real(ey) + imag(ey)'*iS*imag(ey);
        G(i,1) = - ns*log(qh)/2 - ep'*ipC*ep/2;

        % conditional mode
        %------------------------------------------------------------------
        [maxG,j] = max(G);
        if maxG > Gmax
            qE   = P(:,j);
            f    = R(:,j);
            Gmax = maxG;
        end
        pE = qE;
        disp(i)

    end

    % conditional dispersion
    %----------------------------------------------------------------------
    q = exp((G - maxG));
    q = q/sum(q);
    for i = 1:np
        for j = 1:np
            qC(i,j) = ((P(i,:) - qE(i)).*(P(j,:) - qE(j)))*q;
        end
    end
    qS = spm_sqrtm(qC);
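
    % (q are normalised importance weights exp(G - Gmax); the weighted
    % second moments give the conditional covariance qC, whose matrix
    % square root qS becomes the proposal scale for the next iteration)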

    % objective function (approximated here by the maximum Gibbs energy;
    % the Laplace log-determinant term is computed but then overwritten)
    %======================================================================
    F = Gmax + spm_logdet(ipC*qC)/2;
    F = Gmax;

    % graphics
    %----------------------------------------------------------------------
    if exist('Fsi', 'var')
        spm_figure('Select', Fsi)

        % reshape prediction if necessary
        %------------------------------------------------------------------
        f = reshape(f,ns,nr);
        d = reshape(y,ns,nr);

        % subplot prediction
        %------------------------------------------------------------------
        x    = (1:ns)*Y.dt;
        xLab = 'time (seconds)';
        try
            if length(M.Hz) == ns
                x    = Y.Hz;
                xLab = 'Frequency (Hz)';
            end
        end

        if isreal(f)
            subplot(2,1,1)
            plot(x,f,x,d,':')
            xlabel(xLab)
            title(sprintf('%s: %i','prediction and response: E-Step',k))
            grid on
        else
            subplot(2,2,1)
            plot(x,real(f),x,real(d),':')
            xlabel(xLab)
            ylabel('real')
            title(sprintf('%s: %i','prediction and response: E-Step',k))
            grid on

            subplot(2,2,2)
            plot(x,imag(f),x,imag(d),':')
            xlabel(xLab)
            ylabel('imaginary')
            title(sprintf('%s: %i','prediction and response: E-Step',k))
            grid on
        end

        % subplot Gibbs sampling
        %------------------------------------------------------------------
        subplot(2,2,3)
        plot(G)
        xlabel('sample')
        title('Gibbs energy')

        % subplot parameters
        %------------------------------------------------------------------
        subplot(2,2,4)
        bar(full(qE - spm_vec(M.pE)))
        xlabel('parameter')
        title('conditional expectation')
        grid on
        drawnow
    end

    % convergence
    %----------------------------------------------------------------------
    try, dF = F - Fk; catch, dF = 0; end
    Fk = F;
    fprintf('%-6s: %i %6s %-6.3e %6s %.3e (%.2f)\n','LS',k,'F:',full(F),'dF:',full(dF),toc)
    if k > 4 && dF < 1e-4
        break
    end

end
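
% outputs - conditional expectation (sampled mode) in the structure of the
% priors; without this assignment Ep would retain its prior value
%--------------------------------------------------------------------------
Ep = spm_unvec(qE,M.pE);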

if exist('Fsi', 'var')
    spm_figure('Focus', Fsi)
end