spm_LAPF.m 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906
  1. function [DEM] = spm_LAPF(DEM)
  2. % Laplacian model inversion (see also spm_LAPS)
  3. % FORMAT DEM = spm_LAPF(DEM)
  4. %
  5. % DEM.M - hierarchical model
  6. % DEM.Y - response variable, output or data
  7. % DEM.U - explanatory variables, inputs or prior expectation of causes
  8. %__________________________________________________________________________
  9. %
  10. % generative model
  11. %--------------------------------------------------------------------------
  12. % M(i).g = v = g(x,v,P) {inline function, string or m-file}
  13. % M(i).f = dx/dt = f(x,v,P) {inline function, string or m-file}
  14. %
  15. % M(i).ph = pi(v) = ph(x,v,h,M) {inline function, string or m-file}
  16. % M(i).pg = pi(x) = pg(x,v,g,M) {inline function, string or m-file}
  17. %
  18. % M(i).pE = prior expectation of p model-parameters
  19. % M(i).pC = prior covariances of p model-parameters
  20. % M(i).hE = prior expectation of h log-precision (cause noise)
  21. % M(i).hC = prior covariances of h log-precision (cause noise)
  22. % M(i).gE = prior expectation of g log-precision (state noise)
  23. % M(i).gC = prior covariances of g log-precision (state noise)
  24. % M(i).xP = precision (states)
  25. % M(i).Q = precision components (input noise)
  26. % M(i).R = precision components (state noise)
  27. % M(i).V = fixed precision (input noise)
  28. % M(i).W = fixed precision (state noise)
  29. %
  30. % M(i).m = number of inputs v(i + 1);
  31. % M(i).n = number of states x(i);
  32. % M(i).l = number of output v(i);
  33. %
  34. % conditional moments of model-states - q(u)
  35. %--------------------------------------------------------------------------
  36. % qU.x = Conditional expectation of hidden states
  37. % qU.v = Conditional expectation of causal states
  38. % qU.w = Conditional prediction error (states)
  39. % qU.z = Conditional prediction error (causes)
  40. % qU.C = Conditional covariance: cov(v)
  41. % qU.S = Conditional covariance: cov(x)
  42. %
  43. % conditional moments of model-parameters - q(p)
  44. %--------------------------------------------------------------------------
  45. % qP.P = Conditional expectation
  46. % qP.C = Conditional covariance
  47. %
  48. % conditional moments of hyper-parameters (log-transformed) - q(h)
  49. %--------------------------------------------------------------------------
  50. % qH.h = Conditional expectation (cause noise)
  51. % qH.g = Conditional expectation (state noise)
  52. % qH.C = Conditional covariance
  53. %
  54. % F = log-evidence = log-marginal likelihood = negative free-energy
  55. %__________________________________________________________________________
  56. %
  57. % spm_LAPF implements a variational scheme under the Laplace
  58. % approximation to the conditional joint density q on states (u), parameters
  59. % (p) and hyperparameters (h,g) of any analytic nonlinear hierarchical dynamic
  60. % model, with additive Gaussian innovations.
  61. %
  62. % q(u,p,h,g) = max <L(t)>q
  63. %
  64. % L is the ln p(y,u,p,h,g|M) under the model M. The conditional covariances
  65. % obtain analytically from the curvature of L with respect to the unknowns.
  66. %__________________________________________________________________________
  67. % Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging
  68. % Karl Friston
  69. % $Id: spm_LAPF.m 6018 2014-05-25 09:24:14Z karl $
  70. % find or create a DEM figure
  71. %--------------------------------------------------------------------------
  72. try
  73. DEM.M(1).nograph;
  74. catch
  75. DEM.M(1).nograph = 0;
  76. end
  77. if ~DEM.M(1).nograph
  78. Fdem = spm_figure('GetWin','DEM');
  79. end
  80. % check model, data and priors
  81. %==========================================================================
  82. [M,Y,U] = spm_DEM_set(DEM);
  83. % number of iterations
  84. %--------------------------------------------------------------------------
  85. try, nD = M(1).E.nD; catch, nD = 1; end
  86. try, nN = M(1).E.nN; catch, nN = 16; end
  87. % ensure integration scheme evaluates gradients at each time-step
  88. %--------------------------------------------------------------------------
  89. M(1).E.linear = 4;
  90. % assume precisions are a function of, and only of, hyperparameters
  91. %--------------------------------------------------------------------------
  92. try
  93. method = M(1).E.method;
  94. catch
  95. method.h = 1;
  96. method.g = 1;
  97. method.x = 0;
  98. method.v = 0;
  99. end
  100. try method.h; catch, method.h = 0; end
  101. try method.g; catch, method.g = 0; end
  102. try method.x; catch, method.x = 0; end
  103. try method.v; catch, method.v = 0; end
  104. % assume precisions are a function of, and only of, hyperparameters
  105. %--------------------------------------------------------------------------
  106. try
  107. form = M(1).E.form;
  108. catch
  109. form = 'Gaussian';
  110. end
  111. % checks for Laplace models (precision functions; ph and pg)
  112. %--------------------------------------------------------------------------
  113. for i = 1:length(M)
  114. try
  115. feval(M(i).ph,M(i).x,M(i + 1).v,M(i).hE,M(i)); method.v = 1;
  116. catch
  117. M(i).ph = inline('spm_LAP_ph(x,v,h,M)','x','v','h','M');
  118. end
  119. try
  120. feval(M(i).pg,M(i).x,M(i + 1).v,M(i).gE,M(i)); method.x = 1;
  121. catch
  122. M(i).pg = inline('spm_LAP_pg(x,v,h,M)','x','v','h','M');
  123. end
  124. end
  125. M(1).E.method = method;
  126. % order parameters (d = n = 1 for static models) and checks
  127. %==========================================================================
  128. d = M(1).E.d + 1; % embedding order of q(v)
  129. n = M(1).E.n + 1; % embedding order of q(x)
  130. % number of states and parameters
  131. %--------------------------------------------------------------------------
  132. ns = size(Y,2); % number of samples
  133. nl = size(M,2); % number of levels
  134. nv = sum(spm_vec(M.m)); % number of v (casual states)
  135. nx = sum(spm_vec(M.n)); % number of x (hidden states)
  136. ny = M(1).l; % number of y (inputs)
  137. nc = M(end).l; % number of c (prior causes)
  138. nu = nv*d + nx*n; % number of generalised states
  139. ne = nv*n + nx*n + ny*n; % number of generalised errors
  140. % precision (R) of generalised errors and null matrices for concatenation
  141. %==========================================================================
  142. s = M(1).E.s;
  143. Rh = spm_DEM_R(n,s,form);
  144. Rg = spm_DEM_R(n,s,form);
  145. W = sparse(nx*n,nx*n);
  146. V = sparse((ny + nv)*n,(ny + nv)*n);
  147. % fixed priors on states (u)
  148. %--------------------------------------------------------------------------
  149. Px = kron(sparse(1,1,1,n,n),spm_cat(spm_diag({M.xP})));
  150. Pv = kron(sparse(1,1,1,d,d),sparse(nv,nv));
  151. pu.ic = spm_cat(spm_diag({Px Pv}));
  152. % hyperpriors
  153. %--------------------------------------------------------------------------
  154. ph.h = spm_vec({M.hE M.gE}); % prior expectation of h,g
  155. ph.c = spm_cat(spm_diag({M.hC M.gC})); % prior covariances of h,g
  156. ph.ic = spm_pinv(ph.c); % prior precision of h,g
  157. qh.h = {M.hE}; % conditional expectation h
  158. qh.g = {M.gE}; % conditional expectation g
  159. nh = length(spm_vec(qh.h)); % number of hyperparameters h
  160. ng = length(spm_vec(qh.g)); % number of hyperparameters g
  161. nb = nh + ng; % number of hyerparameters
  162. % priors on parameters (in reduced parameter space)
  163. %==========================================================================
  164. pp.c = cell(nl,nl);
  165. qp.p = cell(nl,1);
  166. for i = 1:(nl - 1)
  167. % eigenvector reduction: p <- pE + qp.u*qp.p
  168. %----------------------------------------------------------------------
  169. qp.u{i} = spm_svd(M(i).pC); % basis for parameters
  170. M(i).p = size(qp.u{i},2); % number of qp.p
  171. qp.p{i} = sparse(M(i).p,1); % initial deviates
  172. pp.c{i,i} = qp.u{i}'*M(i).pC*qp.u{i}; % prior covariance
  173. end
  174. Up = spm_cat(spm_diag(qp.u));
  175. % priors on parameters
  176. %--------------------------------------------------------------------------
  177. pp.p = spm_vec(M.pE);
  178. pp.c = spm_cat(pp.c);
  179. pp.ic = spm_inv(pp.c);
  180. % initialise conditional density q(p)
  181. %--------------------------------------------------------------------------
  182. for i = 1:(nl - 1)
  183. try
  184. qp.p{i} = qp.p{i} + qp.u{i}'*(spm_vec(M(i).P) - spm_vec(M(i).pE));
  185. end
  186. end
  187. np = size(Up,2);
  188. % initialise cell arrays for D-Step; e{i + 1} = (d/dt)^i[e] = e[i]
  189. %==========================================================================
  190. qu.x = cell(n,1);
  191. qu.v = cell(n,1);
  192. qu.y = cell(n,1);
  193. qu.u = cell(n,1);
  194. [qu.x{:}] = deal(sparse(nx,1));
  195. [qu.v{:}] = deal(sparse(nv,1));
  196. [qu.y{:}] = deal(sparse(ny,1));
  197. [qu.u{:}] = deal(sparse(nc,1));
  198. % initialise cell arrays for hierarchical structure of x[0] and v[0]
  199. %--------------------------------------------------------------------------
  200. x = {M(1:end - 1).x};
  201. v = {M(1 + 1:end).v};
  202. qu.x{1} = spm_vec(x);
  203. qu.v{1} = spm_vec(v);
  204. % derivatives for Jacobian of D-step
  205. %--------------------------------------------------------------------------
  206. Dx = kron(spm_speye(n,n,1),spm_speye(nx,nx));
  207. Dv = kron(spm_speye(d,d,1),spm_speye(nv,nv));
  208. Dy = kron(spm_speye(n,n,1),spm_speye(ny,ny));
  209. Dc = kron(spm_speye(d,d,1),spm_speye(nc,nc));
  210. Du = spm_cat(spm_diag({Dx,Dv}));
  211. Ip = spm_speye(np,np);
  212. Ih = spm_speye(nb,nb);
  213. qp.dp = sparse(np,1); % conditional expectation of dp/dt
  214. qh.dp = sparse(nb,1); % conditional expectation of dh/dt
  215. % precision of fluctuations on parameters of hyperparameters
  216. %--------------------------------------------------------------------------
  217. Kp = ns*Ip;
  218. Kh = ns*Ih;
  219. % gradients of generalised weighted errors
  220. %--------------------------------------------------------------------------
  221. dedh = sparse(nh,ne);
  222. dedg = sparse(ng,ne);
  223. dedv = sparse(nv,ne);
  224. dedx = sparse(nx,ne);
  225. dedhh = sparse(nh,nh);
  226. dedgg = sparse(ng,ng);
  227. % curvatures of Gibb's energy w.r.t. hyperparameters
  228. %--------------------------------------------------------------------------
  229. dHdh = sparse(nh, 1);
  230. dHdg = sparse(ng, 1);
  231. dHdp = sparse(np, 1);
  232. dHdx = sparse(nx*n,1);
  233. dHdv = sparse(nv*d,1);
  234. % preclude unnecessary iterations and set switchs
  235. %--------------------------------------------------------------------------
  236. if ~np && ~nh && ~ng, nN = 1; end
  237. mnx = nx*~~method.x;
  238. mnv = nv*~~method.v;
  239. % Iterate Lapalace scheme
  240. %==========================================================================
  241. Fa = -Inf;
  242. for iN = 1:nN
  243. % get time and clear persistent variables in evaluation routines
  244. %----------------------------------------------------------------------
  245. tic; clear spm_DEM_eval
  246. % [re-]set states & their derivatives
  247. %----------------------------------------------------------------------
  248. try, qu = Q(1).u; end
  249. % D-Step: (nD D-Steps for each sample)
  250. %======================================================================
  251. for is = 1:ns
  252. % D-Step: until convergence for static systems
  253. %==================================================================
  254. for iD = 1:nD
  255. % sampling time
  256. %--------------------------------------------------------------
  257. ts = is + (iD - 1)/nD;
  258. % derivatives of responses and inputs
  259. %--------------------------------------------------------------
  260. try
  261. qu.y(1:n) = spm_DEM_embed(Y,n,ts,1,M(1).delays);
  262. qu.u(1:d) = spm_DEM_embed(U,d,ts);
  263. catch
  264. qu.y(1:n) = spm_DEM_embed(Y,n,ts);
  265. qu.u(1:d) = spm_DEM_embed(U,d,ts);
  266. end
  267. % evaluate functions and derivatives
  268. %==============================================================
  269. % prediction errors (E) and precision vectors (p)
  270. %--------------------------------------------------------------
  271. [E,dE] = spm_DEM_eval(M,qu,qp);
  272. [p,dp] = spm_LAP_eval(M,qu,qh);
  273. % gradients of log(det(iS)) dDd...
  274. %==============================================================
  275. % get precision matrices
  276. %--------------------------------------------------------------
  277. iSh = diag(exp(p.h));
  278. iSg = diag(exp(p.g));
  279. iS = blkdiag(kron(Rh,iSh),kron(Rg,iSg));
  280. % gradients of trace(diag(p)) = sum(p); p = precision vector
  281. %--------------------------------------------------------------
  282. dpdx = n*sum(spm_cat({dp.h.dx; dp.g.dx}));
  283. dpdv = n*sum(spm_cat({dp.h.dv; dp.g.dv}));
  284. dpdh = n*sum(dp.h.dh);
  285. dpdg = n*sum(dp.g.dg);
  286. dpdx = kron(sparse(1,1,1,1,n),dpdx);
  287. dpdv = kron(sparse(1,1,1,1,d),dpdv);
  288. dDdu = [dpdx dpdv]';
  289. dDdh = [dpdh dpdg]';
  290. % gradients precision-weighted generalised error dSd..
  291. %==============================================================
  292. % gradients w.r.t. hyperparameters
  293. %--------------------------------------------------------------
  294. for i = 1:nh
  295. diS = diag(dp.h.dh(:,i).*exp(p.h));
  296. diSdh{i} = blkdiag(kron(Rh,diS),W);
  297. dedh(i,:) = E'*diSdh{i};
  298. end
  299. for i = 1:ng
  300. diS = diag(dp.g.dg(:,i).*exp(p.g));
  301. diSdg{i} = blkdiag(V,kron(Rg,diS));
  302. dedg(i,:) = E'*diSdg{i};
  303. end
  304. % gradients w.r.t. hidden states
  305. %--------------------------------------------------------------
  306. for i = 1:mnx
  307. diV = diag(dp.h.dx(:,i).*exp(p.h));
  308. diW = diag(dp.g.dx(:,i).*exp(p.g));
  309. diSdx{i} = blkdiag(kron(Rh,diV),kron(Rg,diW));
  310. dedx(i,:) = E'*diSdx{i};
  311. end
  312. % gradients w.r.t. causal states
  313. %--------------------------------------------------------------
  314. for i = 1:mnv
  315. diV = diag(dp.h.dv(:,i).*exp(p.h));
  316. diW = diag(dp.g.dv(:,i).*exp(p.g));
  317. diSdv{i} = blkdiag(kron(Rh,diV),kron(Rg,diW));
  318. dedv(i,:) = E'*diSdv{i};
  319. end
  320. dSdx = kron(sparse(1,1,1,n,1),dedx);
  321. dSdv = kron(sparse(1,1,1,d,1),dedv);
  322. dSdu = [dSdx; dSdv];
  323. dEdh = [dedh; dedg];
  324. dEdp = dE.dp'*iS;
  325. dEdu = dE.du'*iS;
  326. % curvatures w.r.t. hyperparameters
  327. %--------------------------------------------------------------
  328. for i = 1:nh
  329. for j = i:nh
  330. diS = diag(dp.h.dh(:,i).*dp.h.dh(:,j).*exp(p.h));
  331. diS = blkdiag(kron(Rh,diS),W);
  332. dedhh(i,j) = E'*diS*E;
  333. dedhh(j,i) = dedhh(i,j);
  334. end
  335. end
  336. for i = 1:ng
  337. for j = i:ng
  338. diS = diag(dp.g.dg(:,i).*dp.g.dg(:,j).*exp(p.g));
  339. diS = blkdiag(V,kron(Rg,diS));
  340. dedgg(i,j) = E'*diS*E;
  341. dedgg(j,i) = dedgg(i,j);
  342. end
  343. end
  344. % combined curvature
  345. %--------------------------------------------------------------
  346. dSdhh = spm_cat({dedhh [] ;
  347. [] dedgg});
  348. % errors (from prior expectations) (NB pp.p = 0)
  349. %--------------------------------------------------------------
  350. Eu = spm_vec(qu.x(1:n),qu.v(1:d));
  351. Ep = spm_vec(qp.p);
  352. Eh = spm_vec(qh.h,qh.g) - ph.h;
  353. % first-order derivatives of Gibb's Energy
  354. %==============================================================
  355. dLdu = dEdu*E + dSdu*E/2 - dDdu/2 + pu.ic*Eu;
  356. dLdh = dEdh*E/2 - dDdh/2 + ph.ic*Eh;
  357. dLdp = dEdp*E + pp.ic*Ep;
  358. % and second-order derivatives of Gibb's Energy
  359. %--------------------------------------------------------------
  360. % dLduu = dEdu*dE.du + dSdu*dE.du + dE.du'*dSdu' + pu.ic;
  361. % dLdup = dEdu*dE.dp + dSdu*dE.dp;
  362. dLduu = dEdu*dE.du + pu.ic;
  363. dLdpp = dEdp*dE.dp + pp.ic;
  364. dLdhh = dSdhh/2 + ph.ic;
  365. dLdup = dEdu*dE.dp;
  366. dLdhu = dEdh*dE.du;
  367. dLduy = dEdu*dE.dy;
  368. dLduc = dEdu*dE.dc;
  369. dLdpy = dEdp*dE.dy;
  370. dLdpc = dEdp*dE.dc;
  371. dLdhy = dEdh*dE.dy;
  372. dLdhc = dEdh*dE.dc;
  373. dLdhp = dEdh*dE.dp;
  374. dLdpu = dLdup';
  375. dLdph = dLdhp';
  376. % precision and covariances
  377. %--------------------------------------------------------------
  378. iC = spm_cat({dLduu dLdup;
  379. dLdpu dLdpp});
  380. C = spm_inv(iC);
  381. % first-order derivatives of Entropy term
  382. %==============================================================
  383. % log-precision
  384. %--------------------------------------------------------------
  385. for i = 1:nh
  386. Luub = dE.du'*diSdh{i}*dE.du;
  387. Lpub = dE.dp'*diSdh{i}*dE.du;
  388. Lppb = dE.dp'*diSdh{i}*dE.dp;
  389. diCdh = spm_cat({Luub Lpub';
  390. Lpub Lppb});
  391. dHdh(i) = sum(sum(diCdh.*C))/2;
  392. end
  393. for i = 1:ng
  394. Luub = dE.du'*diSdg{i}*dE.du;
  395. Lpub = dE.dp'*diSdg{i}*dE.du;
  396. Lppb = dE.dp'*diSdg{i}*dE.dp;
  397. diCdg = spm_cat({Luub Lpub';
  398. Lpub Lppb});
  399. dHdg(i) = sum(sum(diCdg.*C))/2;
  400. end
  401. % parameters
  402. %--------------------------------------------------------------
  403. for i = 1:np
  404. Luup = dE.dup{i}'*dEdu';
  405. Lpup = dEdp*dE.dup{i};
  406. Luup = Luup + Luup';
  407. diCdp = spm_cat({Luup Lpup';
  408. Lpup [] });
  409. dHdp(i) = sum(sum(diCdp.*C))/2;
  410. end
  411. % % hidden and causal states
  412. % %--------------------------------------------------------------
  413. % for i = 1:mnx
  414. % Luux = dE.du'*diSdx{i}*dE.du;
  415. % Lpux = dE.dp'*diSdx{i}*dE.du;
  416. % Lppx = dE.dp'*diSdx{i}*dE.dp;
  417. % diCdx = spm_cat({Luux Lpux';
  418. % Lpux Lppx});
  419. % dHdx(i) = sum(sum(diCdx.*C))/2;
  420. %
  421. % end
  422. % for i = 1:mnv
  423. % Luuv = dE.du'*diSdv{i}*dE.du;
  424. % Lpuv = dE.dp'*diSdv{i}*dE.du;
  425. % Lppv = dE.dp'*diSdv{i}*dE.dp;
  426. % diCdv = spm_cat({Luuv Lpuv';
  427. % Lpuv Lppv});
  428. % dHdv(i) = sum(sum(diCdv.*C))/2;
  429. % end
  430. dHdb = [dHdh; dHdg];
  431. dHdu = [dHdx; dHdv];
  432. % save conditional moments (and prediction error) at Q{t}
  433. %==============================================================
  434. if iD == 1
  435. % save means
  436. %----------------------------------------------------------
  437. Q(is).e = E;
  438. Q(is).E = iS*E;
  439. Q(is).u = qu;
  440. Q(is).p = qp;
  441. Q(is).h = qh;
  442. % and conditional covariances
  443. %----------------------------------------------------------
  444. Q(is).u.s = C((1:nx),(1:nx));
  445. Q(is).u.c = C((1:nv) + nx*n, (1:nv) + nx*n);
  446. Q(is).p.c = C((1:np) + nu, (1:np) + nu);
  447. Q(is).h.c = spm_inv(dLdhh);
  448. Cu = C(1:nu,1:nu);
  449. % Free-energy (states)
  450. %----------------------------------------------------------
  451. L(is) = ...
  452. - E'*iS*E/2 + spm_logdet(iS)/2 - n*ny*log(2*pi)/2 ...
  453. - Eu'*pu.ic*Eu/2 + spm_logdet(pu.ic)/2 + spm_logdet(Cu)/2;
  454. % Free-energy (states and parameters)
  455. %----------------------------------------------------------
  456. A(is) = - E'*iS*E/2 + spm_logdet(iS)/2 ...
  457. - Eu'*pu.ic*Eu/2 + spm_logdet(pu.ic)/2 ...
  458. - Ep'*pp.ic*Ep/2 + spm_logdet(pp.ic)/2 ...
  459. - Eh'*ph.ic*Eh/2 + spm_logdet(ph.ic)/2 ...
  460. - n*ny*log(2*pi)/2 - spm_logdet(iC)/2 - spm_logdet(dLdhh)/2;
  461. end
  462. % update conditional moments
  463. %==============================================================
  464. % uopdate curvatures of [hyper]paramters
  465. %--------------------------------------------------------------
  466. try
  467. dLdPP = dLdPP*(1 - 1/ns) + dLdpp/ns;
  468. dLdHH = dLdHH*(1 - 1/ns) + dLdhh/ns;
  469. catch
  470. dLdPP = dLdpp;
  471. dLdHH = dLdhh;
  472. end
  473. % rotate and scale gradient (and curvatures)
  474. %--------------------------------------------------------------
  475. [Vp,Sp] = spm_svd(dLdPP,0);
  476. [Vh,Sh] = spm_svd(dLdHH,0);
  477. Sp = diag(1./(diag(sqrt(Sp))));
  478. Sh = diag(1./(diag(sqrt(Sh))));
  479. dLdp = Sp*Vp'*dLdp;
  480. dHdp = Sp*Vp'*dHdp;
  481. dLdpy = Sp*Vp'*dLdpy;
  482. dLdpu = Sp*Vp'*dLdpu;
  483. dLdpc = Sp*Vp'*dLdpc;
  484. dLdph = Sp*Vp'*dLdph;
  485. dLdpp = Sp*Vp'*dLdpp*Vp;
  486. dLdhp = dLdhp*Vp;
  487. dLdh = Sh*Vh'*dLdh;
  488. dHdb = Sh*Vh'*dHdb;
  489. dLdhy = Sh*Vh'*dLdhy;
  490. dLdhu = Sh*Vh'*dLdhu;
  491. dLdhc = Sh*Vh'*dLdhc;
  492. dLdhp = Sh*Vh'*dLdhp;
  493. dLdhh = Sh*Vh'*dLdhh*Vh;
  494. dLdph = dLdph*Vh;
  495. % assemble conditional means
  496. %--------------------------------------------------------------
  497. q{1} = qu.y(1:n);
  498. q{2} = qu.x(1:n);
  499. q{3} = qu.v(1:d);
  500. q{4} = qu.u(1:d);
  501. q{5} = spm_unvec(Vp'*spm_vec(qp.p),qp.p);
  502. qb = spm_unvec(Vh'*spm_vec({qh.h qh.g}),{qh.h qh.g});
  503. q{6} = qb{1};
  504. q{7} = qb{2};
  505. q{8} = Vp'*qp.dp;
  506. q{9} = Vh'*qh.dp;
  507. % flow
  508. %--------------------------------------------------------------
  509. f{1} = Dy*spm_vec(q{1});
  510. f{2} = Du*spm_vec(q{2:3}) - dLdu - dHdu;
  511. f{3} = Dc*spm_vec(q{4});
  512. f{4} = spm_vec(q{8});
  513. f{5} = spm_vec(q{9});
  514. f{6} = -Kp*spm_vec(q{8}) - dLdp - dHdp;
  515. f{7} = -Kh*spm_vec(q{9}) - dLdh - dHdb;
  516. % and Jacobian
  517. %--------------------------------------------------------------
  518. dfdq = spm_cat({Dy [] [] [] [] [] [];
  519. -dLduy Du-dLduu -dLduc [] [] [] [];
  520. [] [] Dc [] [] [] [];
  521. [] [] [] [] [] Ip [];
  522. [] [] [] [] [] [] Ih;
  523. -dLdpy -dLdpu -dLdpc -dLdpp -dLdph -Kp [];
  524. -dLdhy -dLdhu -dLdhc -dLdhp -dLdhh [] -Kh});
  525. % update conditional modes of states
  526. %==============================================================
  527. dq = spm_dx(dfdq, spm_vec(f), 1/nD);
  528. q = spm_unvec(spm_vec(q) + dq,q);
  529. % unpack conditional means
  530. %--------------------------------------------------------------
  531. qu.x(1:n) = q{2};
  532. qu.v(1:d) = q{3};
  533. qp.p = spm_unvec(Vp*spm_vec(q{5}),qp.p);
  534. qb = spm_unvec(Vh*spm_vec(q{6:7}),{qh.h qh.g});
  535. qh.h = qb{1};
  536. qh.g = qb{2};
  537. qp.dp = Vp*q{8};
  538. qh.dp = Vh*q{9};
  539. end % D-Step
  540. end % sequence (ns)
  541. % Bayesian parameter averaging
  542. %======================================================================
  543. % Conditional moments of time-averaged parameters
  544. %----------------------------------------------------------------------
  545. Pp = 0;
  546. Ep = 0;
  547. for i = 1:ns
  548. P = spm_inv(Q(i).p.c);
  549. Ep = Ep + P*spm_vec(Q(i).p.p);
  550. Pp = Pp + P;
  551. end
  552. Cp = spm_inv(Pp);
  553. Ep = Cp*Ep;
  554. % conditional moments of hyper-parameters
  555. %----------------------------------------------------------------------
  556. Ph = 0;
  557. Eh = 0;
  558. for i = 1:ns
  559. P = spm_inv(Q(i).h.c);
  560. Ph = Ph + P;
  561. Eh = Eh + P*spm_vec({Q(i).h.h Q(i).h.g});
  562. end
  563. Ch = spm_inv(Ph);
  564. Eh = Ch*Eh - ph.h;
  565. % Free-action of states plus free-energy of parameters
  566. %======================================================================
  567. Fs = sum(A);
  568. Fi = sum(L) ...
  569. - Ep'*pp.ic*Ep/2 + spm_logdet(pp.ic)/2 - spm_logdet(Pp)/2 ...
  570. - Eh'*ph.ic*Eh/2 + spm_logdet(ph.ic)/2 - spm_logdet(Ph)/2;
  571. % if F is increasing terminate
  572. %----------------------------------------------------------------------
  573. if Fi < Fa && iN > 4
  574. break
  575. else
  576. Fa = Fi;
  577. F(iN) = Fi;
  578. S(iN) = Fs;
  579. end
  580. % otherwise save conditional moments (for each time point)
  581. %======================================================================
  582. for t = 1:length(Q)
  583. % states and predictions
  584. %------------------------------------------------------------------
  585. v = spm_unvec(Q(t).u.v{1},v);
  586. x = spm_unvec(Q(t).u.x{1},x);
  587. z = spm_unvec(Q(t).e(1:(ny + nv)),{M.v});
  588. Z = spm_unvec(Q(t).E(1:(ny + nv)),{M.v});
  589. w = spm_unvec(Q(t).e((1:nx) + (ny + nv)*n),{M.x});
  590. X = spm_unvec(Q(t).E((1:nx) + (ny + nv)*n),{M.x});
  591. for i = 1:(nl - 1)
  592. if M(i).m, qU.v{i + 1}(:,t) = spm_vec(v{i}); end
  593. if M(i).n, qU.x{i}(:,t) = spm_vec(x{i}); end
  594. if M(i).n, qU.w{i}(:,t) = spm_vec(w{i}); end
  595. if M(i).l, qU.z{i}(:,t) = spm_vec(z{i}); end
  596. if M(i).n, qU.W{i}(:,t) = spm_vec(X{i}); end
  597. if M(i).l, qU.Z{i}(:,t) = spm_vec(Z{i}); end
  598. end
  599. if M(nl).l, qU.z{nl}(:,t) = spm_vec(z{nl}); end
  600. if M(nl).l, qU.Z{nl}(:,t) = spm_vec(Z{nl}); end
  601. qU.v{1}(:,t) = spm_vec(Q(t).u.y{1}) - spm_vec(z{1});
  602. % and conditional covariances
  603. %------------------------------------------------------------------
  604. qU.S{t} = Q(t).u.s;
  605. qU.C{t} = Q(t).u.c;
  606. % parameters
  607. %------------------------------------------------------------------
  608. qP.p{t} = spm_vec(Q(t).p.p);
  609. qP.c{t} = Q(t).p.c;
  610. % hyperparameters
  611. %------------------------------------------------------------------
  612. qH.p{t} = spm_vec({Q(t).h.h Q(t).h.g});
  613. qH.c{t} = Q(t).h.c;
  614. end
  615. % graphics (states)
  616. %----------------------------------------------------------------------
  617. figure(Fdem)
  618. spm_DEM_qU(qU)
  619. % graphics (parameters and log-precisions)
  620. %----------------------------------------------------------------------
  621. if np && nb
  622. subplot(2*nl,2,4*nl - 2)
  623. plot(1:ns,spm_cat(qP.p))
  624. set(gca,'XLim',[1 ns])
  625. title('parameters (modes)','FontSize',16)
  626. subplot(2*nl,2,4*nl)
  627. plot(1:ns,spm_cat(qH.p))
  628. set(gca,'XLim',[1 ns])
  629. title('log-precision','FontSize',16)
  630. elseif nb
  631. subplot(nl,2,2*nl)
  632. plot(1:ns,spm_cat(qH.p))
  633. set(gca,'XLim',[1 ns])
  634. title('log-precision','FontSize',16)
  635. elseif np
  636. subplot(nl,2,2*nl)
  637. plot(1:ns,spm_cat(qP.p))
  638. set(gca,'XLim',[1 ns])
  639. title('parameters (modes)','FontSize',16)
  640. end
  641. drawnow
  642. % report (EM-Steps)
  643. %----------------------------------------------------------------------
  644. try
  645. dF = F(end) - F(end - 1);
  646. catch
  647. dF = 0;
  648. end
  649. str{1} = sprintf('LAP: %i (%i)', iN,iD);
  650. str{2} = sprintf('F:%.4e', full(F(iN) - F(1)));
  651. str{3} = sprintf('dF:%.2e', full(dF));
  652. str{4} = sprintf('(%.2e sec)', full(toc));
  653. fprintf('%-16s%-16s%-14s%-16s\n',str{:})
  654. end
  655. % Place Bayesian parameter averages in output arguments
  656. %==========================================================================
  657. % Conditional moments of time-averaged parameters
  658. %--------------------------------------------------------------------------
  659. Pp = 0;
  660. Ep = 0;
  661. for i = 1:ns
  662. % weight in proportion to precisions
  663. %----------------------------------------------------------------------
  664. P = spm_inv(qP.c{i});
  665. Ep = Ep + P*qP.p{i};
  666. Pp = Pp + P;
  667. end
  668. Cp = spm_inv(Pp);
  669. Ep = Cp*Ep;
  670. P = {M.pE};
  671. qP.P = spm_unvec(Up*Ep + pp.p,P);
  672. qP.C = Up*Cp*Up';
  673. qP.V = spm_unvec(diag(qP.C),P);
  674. qP.U = Up;
  675. % conditional moments of hyper-parameters
  676. %--------------------------------------------------------------------------
  677. Ph = 0;
  678. Eh = 0;
  679. for i = 1:ns
  680. % weight in proportion to precisions
  681. %----------------------------------------------------------------------
  682. P = spm_inv(qH.c{i});
  683. Ph = Ph + P;
  684. Eh = Eh + P*qH.p{i};
  685. end
  686. Ch = spm_inv(Ph);
  687. Eh = Ch*Eh;
  688. P = {qh.h qh.g};
  689. P = spm_unvec(Eh,P);
  690. qH.h = P{1};
  691. qH.g = P{2};
  692. qH.C = Ch;
  693. P = spm_unvec(diag(qH.C),P);
  694. qH.V = P{1};
  695. qH.W = P{2};
  696. % assign output variables
  697. %--------------------------------------------------------------------------
  698. DEM.M = M; % model
  699. DEM.U = U; % causes
  700. DEM.qU = qU; % conditional moments of model-states
  701. DEM.qP = qP; % conditional moments of model-parameters
  702. DEM.qH = qH; % conditional moments of hyper-parameters
  703. DEM.F = F; % [-ve] Free energy
  704. DEM.S = S; % [-ve] Free action
  705. return
  706. % Notes (check on curvature)
  707. %==========================================================================
  708. % analytic form
  709. %----------------------------------------------------------
  710. iC = spm_cat({dLduu dLdup dLduh;
  711. dLdpu dLdpp dLdph;
  712. dLdhu dLdhp dLdhh});
  713. % numerical approximations
  714. %----------------------------------------------------------
  715. qq.x = qu.x(1:n);
  716. qq.v = qu.v(1:d);
  717. qq.p = qp.p;
  718. qq.h = qh.h;
  719. qq.g = qh.g;
  720. dLdqq = spm_diff('spm_LAP_F',qq,qu,qp,qh,pu,pp,ph,M,[1 1]);
  721. dLdqq = spm_cat(dLdqq');
  722. subplot(2,2,1);imagesc(dLdqq); axis square
  723. subplot(2,2,2);imagesc(iC); axis square
  724. subplot(2,2,3);imagesc(dLdqq - iC);axis square
  725. subplot(2,2,4);plot(iC,':k');hold on;
  726. plot(dLdqq - iC,'r');hold off; axis square
  727. drawnow
  728. % Notes (descent on parameters
  729. %==========================================================================
  730. I = eye(length(dLdpp));
  731. k = kp;
  732. Luu = dLdpp;
  733. J = spm_cat({[] I;
  734. -Luu -k*I});
  735. [u s] = eig(full(J));
  736. max(diag(s))
  737. [uj sj] = eig(full(dLdpp));
  738. Luu = min(diag(sj));
  739. % Luu = max(diag(sj));
  740. k = kp;
  741. ss(1) = -(k + sqrt(k^2 - 4*Luu))/2;
  742. ss(2) = -(k - sqrt(k^2 - 4*Luu))/2;
  743. max(ss)
  744. k = (1:128);
  745. s = -(k - sqrt(k.^2 - 4*Luu))/2;
  746. plot(k,-1./real(s))