function [DEM] = spm_ADEM(DEM)
% Dynamic expectation maximisation: Active inversion
% FORMAT DEM = spm_ADEM(DEM)
%
% DEM.G  - generative process
% DEM.M  - recognition model
% DEM.C  - causes
% DEM.U  - prior expectation of causes
%__________________________________________________________________________
%
% This implementation of DEM is the same as spm_DEM but integrates both the
% generative process and model inversion in parallel. Its functionality is
% exactly the same apart from the fact that confounds are not accommodated
% explicitly. The generative model is specified by DEM.G and the veridical
% causes by DEM.C; these may or may not be used as priors on the causes for
% the inversion model DEM.M (i.e., DEM.U = DEM.C). Clearly, DEM.G does not
% require any priors or precision components; it will use the values of the
% parameters specified in the prior expectation fields.
%
% This routine is not used for model inversion per se but to simulate the
% dynamical inversion of models. Critically, it includes action
% variables a - that couple the model back to the generative process.
% This enables active inference (c.f., action-perception) or embodied
% inference.
%
% hierarchical models M(i)
%--------------------------------------------------------------------------
%   M(i).g  = y(t)  = g(x,v,P)    {inline function, string or m-file}
%   M(i).f  = dx/dt = f(x,v,P)    {inline function, string or m-file}
%
%   M(i).pE = prior expectation of p model-parameters
%   M(i).pC = prior covariances of p model-parameters
%   M(i).hE = prior expectation of h hyper-parameters (cause noise)
%   M(i).hC = prior covariances of h hyper-parameters (cause noise)
%   M(i).gE = prior expectation of g hyper-parameters (state noise)
%   M(i).gC = prior covariances of g hyper-parameters (state noise)
%   M(i).Q  = precision components (input noise)
%   M(i).R  = precision components (state noise)
%   M(i).V  = fixed precision (input noise)
%   M(i).W  = fixed precision (state noise)
%   M(i).xP = precision (states)
%
%   M(i).m  = number of inputs v(i + 1);
%   M(i).n  = number of states x(i)
%   M(i).l  = number of output v(i)
%   M(i).k  = number of action a(i)
%
% hierarchical process G(i)
%--------------------------------------------------------------------------
%   G(i).g  = y(t)  = g(x,v,a,P)    {inline function, string or m-file}
%   G(i).f  = dx/dt = f(x,v,a,P)    {inline function, string or m-file}
%
%   G(i).pE = model-parameters
%   G(i).U  = precision (action)
%   G(i).V  = precision (input noise)
%   G(i).W  = precision (state noise)
%
%   G(1).R  = restriction or rate matrix for action [default: 1];
%   G(i).aP = precision (action)   [default: exp(-2)]
%
%   G(i).m  = number of inputs v(i + 1);
%   G(i).n  = number of states x(i)
%   G(i).l  = number of output v(i)
%   G(i).k  = number of action a(i)
%
%
% Returns the following fields of DEM
%--------------------------------------------------------------------------
%
% true model-states - u
%--------------------------------------------------------------------------
%   pU.x    = true hidden states
%   pU.v    = true causal states v{1} = response (Y)
%   pU.C    = prior covariance: cov(v)
%   pU.S    = prior covariance: cov(x)
%
% model-parameters - p
%--------------------------------------------------------------------------
%   pP.P    = parameters for each level
%
% hyper-parameters (log-transformed) - h,g
%--------------------------------------------------------------------------
%   pH.h    = cause noise
%   pH.g    = state noise
%
% conditional moments of model-states - q(u)
%--------------------------------------------------------------------------
%   qU.a    = Action
%   qU.x    = Conditional expectation of hidden states
%   qU.v    = Conditional expectation of causal states
%   qU.z    = Conditional prediction errors (v)
%   qU.C    = Conditional covariance: cov(v)
%   qU.S    = Conditional covariance: cov(x)
%
% conditional moments of model-parameters - q(p)
%--------------------------------------------------------------------------
%   qP.P    = Conditional expectation
%   qP.C    = Conditional covariance
%
% conditional moments of hyper-parameters (log-transformed) - q(h)
%--------------------------------------------------------------------------
%   qH.h    = Conditional expectation (cause noise)
%   qH.g    = Conditional expectation (state noise)
%   qH.C    = Conditional covariance
%
% F         = log evidence = log marginal likelihood = negative free energy
%__________________________________________________________________________
%
% spm_ADEM implements a variational Bayes (VB) scheme under the Laplace
% approximation to the conditional densities of states (u), parameters (p)
% and hyperparameters (h) of any analytic nonlinear hierarchical dynamic
% model, with additive Gaussian innovations. It comprises three
% variational steps (D,E and M) that update the conditional moments of u, p
% and h respectively
%
%                D: qu.u = max <L>q(p,h)
%                E: qp.p = max <L>q(u,h)
%                M: qh.h = max <L>q(u,p)
%
% where qu.u corresponds to the conditional expectation of hidden states x
% and causal states v and so on. L is the ln p(y,u,p,h|M) under the model
% M. The conditional covariances obtain analytically from the curvature of
% L with respect to u, p and h.
%
% The D-step is embedded in the E-step because q(u) changes with each
% sequential observation. The dynamical model is transformed into a static
% model using temporal derivatives at each time point. Continuity of the
% conditional trajectories q(u,t) is assured by a continuous ascent of F(t)
% in generalised co-ordinates. This means DEM can deconvolve online and
% represents an alternative to Kalman filtering or alternative Bayesian
% update procedures.
%
%__________________________________________________________________________
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: spm_ADEM.m 7145 2017-07-31 13:57:39Z karl $

% check model, data, priors and unpack
%--------------------------------------------------------------------------
DEM   = spm_ADEM_set(DEM);
M     = DEM.M;
G     = DEM.G;
C     = DEM.C;
U     = DEM.U;

% check whether to print (DEM.db is optional; default is to display)
%--------------------------------------------------------------------------
try
    db = DEM.db;
catch
    db = 1;
end

% find or create a DEM figure
%--------------------------------------------------------------------------
if db
    Fdem = spm_figure('GetWin','DEM');
end

% ensure embedding dimensions are compatible
% (the process uses the model's order n for both of its embedding orders)
%--------------------------------------------------------------------------
G(1).E.n = M(1).E.n;
G(1).E.d = M(1).E.n;

% order parameters (d = n = 1 for static models) and checks
%==========================================================================
d    = M(1).E.d + 1;                   % embedding order of q(v)
n    = M(1).E.n + 1;                   % embedding order of q(x)
s    = M(1).E.s;                       % smoothness - s.d. (bins)

% number of states and parameters - generative model
%--------------------------------------------------------------------------
nY   = size(C,2);                      % number of samples
nl   = size(M,2);                      % number of levels
nv   = sum(spm_vec(M.m));              % number of v (causal states)
nx   = sum(spm_vec(M.n));              % number of x (hidden states)
ny   = M(1).l;                         % number of y (inputs)
nc   = M(end).l;                       % number of c (prior causes)
nu   = nv*d + nx*n;                    % number of generalised states

% number of states and parameters - generative process
%--------------------------------------------------------------------------
gr   = sum(spm_vec(G.l));              % number of v (outputs)
ga   = sum(spm_vec(G.k));              % number of a (active states)
gx   = sum(spm_vec(G.n));              % number of x (hidden states)
gy   = G(1).l;                         % number of y (inputs)
na   = ga;                             % number of a (action)

% number of iterations (optional fields of M(1).E, with defaults)
%--------------------------------------------------------------------------
try, nE = M(1).E.nE; catch, nE = 16; end
try, nM = M(1).E.nM; catch, nM = 8;  end
try, dt = M(1).E.dt; catch, dt = 1;  end

% initialise regularisation parameters
%--------------------------------------------------------------------------
te = 2;                                % log integration time for E-Step
global t

% precision (roughness) of generalised fluctuations
%--------------------------------------------------------------------------
iV = spm_DEM_R(n,s);
iG = spm_DEM_R(n,s);

% time-delay operators (absorb motor delays into motor gain matrix)
% - both are best-effort: skipped if M(1).Ta or M(1).Ty cannot be used
%--------------------------------------------------------------------------
try
    nG = norm(iG);
    iG = iG*spm_DEM_T(n,-M(1).Ta);
    iG = iG*nG/norm(iG);
end
try
    Ty = spm_DEM_T(n,-M(1).Ty);
    Ty = kron(Ty,speye(ny,ny));
end

% precision components Q{} requiring [Re]ML estimators (M-Step)
%==========================================================================
Q     = {};
for i = 1:nl
    q0{i,i} = sparse(M(i).l,M(i).l); %#ok<AGROW>
    r0{i,i} = sparse(M(i).n,M(i).n);
end
Q0    = kron(iV,spm_cat(q0));
R0    = kron(iV,spm_cat(r0));
for i = 1:nl
    for j = 1:length(M(i).Q)
        q          = q0;
        q{i,i}     = M(i).Q{j};
        Q{end + 1} = blkdiag(kron(iV,spm_cat(q)),R0);
    end
    for j = 1:length(M(i).R)
        q          = r0;
        q{i,i}     = M(i).R{j};
        Q{end + 1} = blkdiag(Q0,kron(iV,spm_cat(q)));
    end
end

% and fixed components P
%--------------------------------------------------------------------------
Q0    = kron(iV,spm_cat(spm_diag({M.V})));
R0    = kron(iV,spm_cat(spm_diag({M.W})));
Qp    = blkdiag(Q0,R0);
nh    = length(Q);                     % number of hyperparameters
iR    = [zeros(1,ny),ones(1,nv),ones(1,nx)];    % for empirical priors
iR    = kron(speye(n,n),diag(iR));

% restriction or rate matrices - in terms of precision
%--------------------------------------------------------------------------
q0{1} = G(1).U;
Q0    = kron(iG,spm_cat(q0));
R0    = kron(iG,spm_cat(r0));
iG    = blkdiag(Q0,R0);

% restriction or rate matrices - in terms of dE/da
% (falls back to R = 1 if G(1).R is absent or cannot be assigned)
%--------------------------------------------------------------------------
try
    R         = sparse(sum(spm_vec(G.l)),na);
    R(1:ny,:) = G(1).R;
    R         = kron(spm_speye(n,1,0),R);
catch
    R = 1;
end

% fixed priors on action (a)
%--------------------------------------------------------------------------
try
    aP = G(1).aP;
catch
    aP = exp(-2);
end

% fixed priors on states (u)
%--------------------------------------------------------------------------
xP    = spm_cat(spm_diag({M.xP}));
Px    = kron(iV(1:n,1:n),speye(nx,nx)*exp(-8) + xP);
Pv    = kron(iV(1:d,1:d),speye(nv,nv)*exp(-8));
Pa    = spm_speye(na,na)*aP;
Pu    = spm_cat(spm_diag({Px Pv}));

% hyperpriors
%--------------------------------------------------------------------------
ph.h  = spm_vec({M.hE M.gE});              % prior expectation of h
ph.c  = spm_cat(spm_diag({M.hC M.gC}));    % prior covariances of h
qh.h  = ph.h;                              % conditional expectation
qh.c  = ph.c;                              % conditional covariance
ph.ic = spm_inv(ph.c);                     % prior precision

% priors on parameters (in reduced parameter space)
%==========================================================================
pp.c  = cell(nl,nl);
qp.p  = cell(nl,1);
for i = 1:(nl - 1)

    % eigenvector reduction: p <- pE + qp.u*qp.p
    %----------------------------------------------------------------------
    qp.u{i}   = spm_svd(M(i).pC);                    % basis for parameters
    M(i).p    = size(qp.u{i},2);                     % number of qp.p
    qp.p{i}   = sparse(M(i).p,1);                    % initial qp.p
    pp.c{i,i} = qp.u{i}'*M(i).pC*qp.u{i};            % prior covariance

    % start from M(i).P (if supplied), expressed as a deviation from pE
    %----------------------------------------------------------------------
    try
        qp.e{i} = qp.p{i} + qp.u{i}'*(spm_vec(M(i).P) - spm_vec(M(i).pE));
    catch
        qp.e{i} = qp.p{i};                           % initial qp.e
    end

end
Up    = spm_cat(spm_diag(qp.u));

% initialise and augment with confound parameters B; with flat priors
%--------------------------------------------------------------------------
np    = sum(spm_vec(M.p));                 % number of model parameters
pp.c  = spm_cat(pp.c);
pp.ic = spm_inv(pp.c);

% initialise conditional density q(p) (for D-Step)
%--------------------------------------------------------------------------
qp.e  = spm_vec(qp.e);
qp.c  = sparse(np,np);

% initialise cell arrays for D-Step; e{i + 1} = (d/dt)^i[e] = e[i]
%==========================================================================
qu.x      = cell(n,1);
qu.v      = cell(n,1);
qu.a      = cell(1,1);
qu.y      = cell(n,1);
qu.u      = cell(n,1);
pu.v      = cell(n,1);
pu.x      = cell(n,1);
pu.z      = cell(n,1);
pu.w      = cell(n,1);

[qu.x{:}] = deal(sparse(nx,1));
[qu.v{:}] = deal(sparse(nv,1));
[qu.a{:}] = deal(sparse(na,1));
[qu.y{:}] = deal(sparse(ny,1));
[qu.u{:}] = deal(sparse(nc,1));
[pu.v{:}] = deal(sparse(gr,1));
[pu.x{:}] = deal(sparse(gx,1));
[pu.z{:}] = deal(sparse(gr,1));
[pu.w{:}] = deal(sparse(gx,1));

% initialise cell arrays for hierarchical structure of x[0] and v[0]
%--------------------------------------------------------------------------
qu.x{1}   = spm_vec({M(1:end - 1).x});
qu.v{1}   = spm_vec({M(1 + 1:end).v});
qu.a{1}   = spm_vec({G.a});
pu.x{1}   = spm_vec({G.x});
pu.v{1}   = spm_vec({G.v});

% derivatives for Jacobian of D-step
%--------------------------------------------------------------------------
Dx    = kron(spm_speye(n,n,1),spm_speye(nx,nx,0));
Dv    = kron(spm_speye(d,d,1),spm_speye(nv,nv,0));
Dc    = kron(spm_speye(d,d,1),spm_speye(nc,nc,0));
Da    = kron(spm_speye(1,1,1),sparse(na,na));
Du    = spm_cat(spm_diag({Dx,Dv}));
Dq    = spm_cat(spm_diag({Dx,Dv,Dc,Da}));

% (Dx and Dv are now re-used for the generative process dimensions)
%--------------------------------------------------------------------------
Dx    = kron(spm_speye(n,n,1),spm_speye(gx,gx,0));
Dv    = kron(spm_speye(n,n,1),spm_speye(gr,gr,0));
Dp    = spm_cat(spm_diag({Dv,Dx,Dv,Dx}));
dfdw  = kron(speye(n,n),speye(gx,gx));
dydv  = kron(speye(n,n),speye(gy,gr));

% and null blocks
%--------------------------------------------------------------------------
dVdc  = sparse(d*nc,1);

% gradients and curvatures for conditional uncertainty
%--------------------------------------------------------------------------
dWdu  = sparse(nu,1);
dWduu = sparse(nu,nu);

% preclude unnecessary iterations
%--------------------------------------------------------------------------
if ~np && ~nh, nE = 1; end

% create innovations (and add causes)
%--------------------------------------------------------------------------
[z,w]  = spm_DEM_z(G,nY);
z{end} = C + z{end};
a      = {G.a};
Z      = spm_cat(z(:));
W      = spm_cat(w(:));
A      = spm_cat(a(:));

% Iterate DEM
%==========================================================================
F      = -Inf;
for iE = 1:nE

    % get time and clear persistent variables in evaluation routines
    %----------------------------------------------------------------------
    tic; clear spm_DEM_eval

    % E-Step: (with embedded D-Step)
    %======================================================================

    % [re-]set accumulators for E-Step
    %----------------------------------------------------------------------
    dFdp  = zeros(np,1);
    dFdpp = zeros(np,np);
    EE    = sparse(0);
    ECE   = sparse(0);
    EiSE  = sparse(0);
    qp.ic = sparse(0);
    Hqu.c = sparse(0);

    % [re-]set precisions using [hyper]parameter estimates
    %----------------------------------------------------------------------
    iS    = Qp;
    for i = 1:nh
        iS = iS + Q{i}*exp(qh.h(i));
    end

    % precision for empirical priors
    %----------------------------------------------------------------------
    iP    = iR*iS*iR;

    % [re-]set states & their derivatives
    % (qU and pU do not exist on the first E-step; the try then leaves the
    % initial qu and pu in place)
    %----------------------------------------------------------------------
    try
        qu = qU(1);
        pu = pU(1);
    end

    % D-Step: (nY samples)
    %======================================================================
    for iY = 1:nY

        % time (GLOBAL variable for non-autonomous systems)
        %------------------------------------------------------------------
        t      = iY/nY;

        % pass action to pu.a (external states)
        % (best-effort: A is left unchanged if qU is not yet available)
        %==================================================================
        try, A = spm_cat({qU.a qu.a}); end

        % derivatives of responses and random fluctuations
        %------------------------------------------------------------------
        pu.z   = spm_DEM_embed(Z,n,iY);
        pu.w   = spm_DEM_embed(W,n,iY);
        pu.a   = spm_DEM_embed(A,n,iY);
        qu.u   = spm_DEM_embed(U,n,iY);

        % evaluate generative process
        %------------------------------------------------------------------
        [pu,dg,df] = spm_ADEM_diff(G,pu);

        % and pass response to qu.y
        %==================================================================
        for i = 1:n
            y       = spm_unvec(pu.v{i},{G.v});
            qu.y{i} = y{1};
        end

        % sensory delays (skipped if Ty was not created above)
        %------------------------------------------------------------------
        try, qu.y = spm_unvec(Ty*spm_vec(qu.y),qu.y); end

        % evaluate generative model
        %------------------------------------------------------------------
        [E,dE] = spm_DEM_eval(M,qu,qp);

        % conditional covariance [of states {u}]
        %------------------------------------------------------------------
        qu.c   = spm_inv(dE.du'*iS*dE.du + Pu);
        pu.c   = spm_inv(dE.du'*iP*dE.du + Pu);
        Hqu.c  = Hqu.c + spm_logdet(qu.c);

        % save at qu(t)
        %------------------------------------------------------------------
        qE{iY} = E;
        qC{iY} = qu.c;
        pC{iY} = pu.c;
        qU(iY) = qu;
        pU(iY) = pu;

        % and conditional precision
        %------------------------------------------------------------------
        if nh
            ECEu = dE.du*qu.c*dE.du';
            ECEp = dE.dp*qp.c*dE.dp';
        end

        % uncertainty about parameters dWdv, ... ; W = ln(|qp.c|)
        %==================================================================
        if np
            for i = 1:nu
                CJp(:,i)   = spm_vec(qp.c*dE.dpu{i}'*iS);
                dEdpu(:,i) = spm_vec(dE.dpu{i}');
            end
            dWdu  = CJp'*spm_vec(dE.dp');
            dWduu = CJp'*dEdpu;
        end

        % tensor products for Jacobian (generative process)
        %------------------------------------------------------------------
        Dgda  = kron(spm_speye(n,1,1),dg.da);
        Dgdv  = kron(spm_speye(n,n,1),dg.dv);
        Dgdx  = kron(spm_speye(n,n,1),dg.dx);
        dfda  = kron(spm_speye(n,1,0),df.da);
        dfdv  = kron(spm_speye(n,n,0),df.dv);
        dfdx  = kron(spm_speye(n,n,0),df.dx);

        dgda  = kron(spm_speye(n,1,0),dg.da);
        dgdx  = kron(spm_speye(n,n,0),dg.dx);

        % change in error w.r.t. action
        %------------------------------------------------------------------
        Dfdx  = 0;
        for i = 1:n
            Dfdx = Dfdx + kron(spm_speye(n,n,-i),df.dx^(i - 1));
        end

        % dE/da with restriction (R)
        %------------------------------------------------------------------
        dE.dv = dE.dy*dydv;
        dE.da = dE.dv*((dgda + dgdx*Dfdx*dfda).*R);

        % first-order derivatives
        %------------------------------------------------------------------
        dVdu  = -dE.du'*iS*E - Pu*spm_vec({qu.x{1:n} qu.v{1:d}}) - dWdu/2;
        dVda  = -dE.da'*iG*E - Pa*spm_vec( qu.a{1:1});

        % and second-order derivatives
        %------------------------------------------------------------------
        dVduu = -dE.du'*iS*dE.du - Pu - dWduu/2 ;
        dVdaa = -dE.da'*iG*dE.da - Pa;
        dVduv = -dE.du'*iS*dE.dv;
        dVduc = -dE.du'*iS*dE.dc;
        dVdua = -dE.du'*iS*dE.da;
        dVdav = -dE.da'*iG*dE.dv;
        dVdau = -dE.da'*iG*dE.du;
        dVdac = -dE.da'*iG*dE.dc;

        % D-step update: of causes v{i}, and hidden states x(i)
        %==================================================================

        % states and conditional modes
        %------------------------------------------------------------------
        p     = {pu.v{1:n} pu.x{1:n} pu.z{1:n} pu.w{1:n}};
        q     = {qu.x{1:n} qu.v{1:d} qu.u{1:d} qu.a{1:1}};
        u     = [p q];

        % gradient
        %------------------------------------------------------------------
        dFdu  = [                              Dp*spm_vec(p);
                 spm_vec({dVdu; dVdc; dVda}) + Dq*spm_vec(q)];

        % Jacobian (variational flow)
        % rows and columns are ordered as u = {pu.v pu.x pu.z pu.w | qu.x/qu.v qu.u qu.a}
        %------------------------------------------------------------------
        dFduu = spm_cat(...
            {Dgdv  Dgdx Dv   []   []       []    Dgda;
             dfdv  dfdx []   dfdw []       []    dfda;
             []    []   Dv   []   []       []    [];
             []    []   []   Dx   []       []    [];
             dVduv []   []   []   Du+dVduu dVduc dVdua;
             []    []   []   []   []       Dc    []
             dVdav []   []   []   dVdau    dVdac dVdaa});

        % update states q = {x,v,z,w} and conditional modes
        %==================================================================
        du    = spm_dx(dFduu,dFdu,dt);
        u     = spm_unvec(spm_vec(u) + du,u);

        % and save them (offsets index into the cell list u = [p q])
        %------------------------------------------------------------------
        pu.v(1:n) = u((1:n));
        pu.x(1:n) = u((1:n) + n);
        qu.x(1:n) = u((1:n) + n + n + n + n);
        qu.v(1:d) = u((1:d) + n + n + n + n + n);
        qu.a(1:1) = u((1:1) + n + n + n + n + n + d + d);

        % Gradients and curvatures for E-Step: W = tr(C*J'*iS*J)
        %==================================================================
        if np
            for i = 1:np
                CJu(:,i)   = spm_vec(qu.c*dE.dup{i}'*iS);
                dEdup(:,i) = spm_vec(dE.dup{i}');
            end
            dWdp  = CJu'*spm_vec(dE.du');
            dWdpp = CJu'*dEdup;

            % Accumulate; dF/dP = <dL/dp>, dF/dpp = ...
            %--------------------------------------------------------------
            dFdp  = dFdp  - dWdp/2  - dE.dp'*iS*E;
            dFdpp = dFdpp - dWdpp/2 - dE.dp'*iS*dE.dp;
            qp.ic = qp.ic           + dE.dp'*iS*dE.dp;
        end

        % accumulate SSE
        %------------------------------------------------------------------
        EiSE = EiSE + E'*iS*E;

        % and quantities for M-Step
        %------------------------------------------------------------------
        if nh
            EE  = E*E'+ EE;
            ECE = ECE + ECEu + ECEp;
        end

        if nE == 1

            % evaluate objective function (F) for this sample
            %==============================================================
            J(iY) = - trace(E'*iS*E)/2  ...            % states (u)
                    + spm_logdet(qu.c)  ...            % entropy q(u)
                    + spm_logdet(iS)/2;                % entropy - error
        end

    end % sequence (nY)

    % augment with priors
    %----------------------------------------------------------------------
    dFdp   = dFdp  - pp.ic*qp.e;
    dFdpp  = dFdpp - pp.ic;
    qp.ic  = qp.ic + pp.ic;
    qp.c   = spm_inv(qp.ic);

    % E-step: update expectation (p)
    %======================================================================

    % update conditional expectation
    %----------------------------------------------------------------------
    dp     = spm_dx(dFdpp,dFdp,{te});
    qp.e   = qp.e + dp;
    qp.p   = spm_unvec(qp.e,qp.p);

    % M-step - hyperparameters (h = exp(l))
    %======================================================================
    mh     = zeros(nh,1);
    dFdh   = zeros(nh,1);
    dFdhh  = zeros(nh,nh);

    % ensure qh.e exists for the free-energy evaluation below, even if the
    % M-step loop is not entered (nM = 0); the first pass of the loop
    % assigns the same value
    %----------------------------------------------------------------------
    qh.e   = qh.h - ph.h;
    for iM = 1:nM

        % [re-]set precisions using [hyper]parameter estimates
        %------------------------------------------------------------------
        iS    = Qp;
        for i = 1:nh
            iS = iS + Q{i}*exp(qh.h(i));
        end
        S     = spm_inv(iS);
        dS    = ECE + EE - S*nY;

        % 1st-order derivatives: dFdh = dF/dh
        %------------------------------------------------------------------
        for i = 1:nh
            dPdh{i}   =  Q{i}*exp(qh.h(i));
            dFdh(i,1) = -trace(dPdh{i}*dS)/2;
        end

        % 2nd-order derivatives: dFdhh
        %------------------------------------------------------------------
        for i = 1:nh
            for j = 1:nh
                dFdhh(i,j) = -trace(dPdh{i}*S*dPdh{j}*S*nY)/2;
            end
        end

        % hyperpriors
        %------------------------------------------------------------------
        qh.e  = qh.h  - ph.h;
        dFdh  = dFdh  - ph.ic*qh.e;
        dFdhh = dFdhh - ph.ic;

        % update ReML estimate of parameters
        %------------------------------------------------------------------
        dh    = spm_dx(dFdhh,dFdh);
        qh.h  = qh.h + dh;
        mh    = mh   + dh;

        % conditional covariance of hyperparameters
        %------------------------------------------------------------------
        qh.c  = -spm_inv(dFdhh);

        % convergence (M-Step)
        %------------------------------------------------------------------
        if (dFdh'*dh < 1e-2) || (norm(dh,1) < exp(-8)), break, end

    end % M-Step

    % evaluate objective function (F)
    %======================================================================
    L   = - trace(EiSE)/2               ...  % states (u)
          - trace(qp.e'*pp.ic*qp.e)/2   ...  % parameters (p)
          - trace(qh.e'*ph.ic*qh.e)/2   ...  % hyperparameters (h)
          + Hqu.c/2                     ...  % entropy q(u)
          + spm_logdet(qp.c)/2          ...  % entropy q(p)
          + spm_logdet(qh.c)/2          ...  % entropy q(h)
          - spm_logdet(pp.c)/2          ...  % entropy - prior p
          - spm_logdet(ph.c)/2          ...  % entropy - prior h
          + spm_logdet(iS)*nY/2         ...  % entropy - error
          - n*ny*nY*log(2*pi)/2;

    % if F is increasing (or during the first two E-steps), save expansion
    % point and derivatives
    %----------------------------------------------------------------------
    if L > F(end) || iE < 3

        % save model-states (for each time point)
        % NB: a local index (it) is used here so that the GLOBAL time
        % variable t is not overwritten by the loop counter
        %==================================================================
        for it = 1:length(qU)

            % states
            %--------------------------------------------------------------
            a     = spm_unvec(qU(it).a{1},{G.a});
            v     = spm_unvec(pU(it).v{1},{G.v});
            x     = spm_unvec(pU(it).x{1},{G.x});
            z     = spm_unvec(pU(it).z{1},{G.v});
            w     = spm_unvec(pU(it).w{1},{G.x});
            for i = 1:nl
                try
                    PU.v{i}(:,it) = spm_vec(v{i});
                    PU.z{i}(:,it) = spm_vec(z{i});
                end
                try
                    PU.x{i}(:,it) = spm_vec(x{i});
                    PU.w{i}(:,it) = spm_vec(w{i});
                end
                try
                    QU.a{i}(:,it) = spm_vec(a{i});
                end
            end

            % conditional modes
            %--------------------------------------------------------------
            v     = spm_unvec(qU(it).v{1},{M(1 + 1:end).v});
            x     = spm_unvec(qU(it).x{1},{M(1:end - 1).x});
            z     = spm_unvec(qE{it}(1:(ny + nv)),{M.v});
            w     = spm_unvec(qE{it}((1:nx) + (ny + nv)*n),{M.x});
            for i = 1:(nl - 1)
                if M(i).m, QU.v{i + 1}(:,it) = spm_vec(v{i}); end
                if M(i).l, QU.z{i}(:,it)     = spm_vec(z{i}); end
                if M(i).n, QU.x{i}(:,it)     = spm_vec(x{i}); end
                if M(i).n, QU.w{i}(:,it)     = spm_vec(w{i}); end
            end
            QU.v{1}(:,it)  = spm_vec(qU(it).y{1}) - spm_vec(z{1});
            QU.z{nl}(:,it) = spm_vec(z{nl});

            % and conditional covariances
            %--------------------------------------------------------------
            i        = (1:nx);
            QU.S{it} = qC{it}(i,i);
            PU.S{it} = pC{it}(i,i);
            i        = (1:nv) + nx*n;
            QU.C{it} = qC{it}(i,i);
            PU.C{it} = pC{it}(i,i);
        end

        % save conditional densities
        %------------------------------------------------------------------
        B.QU  = QU;
        B.PU  = PU;
        B.qp  = qp;
        B.qh  = qh;

        % decrease regularisation
        %------------------------------------------------------------------
        F(iE) = L;
        te    = min(te + 1,8);

    else

        % otherwise, return to previous expansion point
        %------------------------------------------------------------------
        QU    = B.QU;
        PU    = B.PU;
        qp    = B.qp;
        qh    = B.qh;

        % increase regularisation
        %------------------------------------------------------------------
        F(iE) = F(end);
        te    = min(te - 1,0);

    end

    % report and break if convergence
    %======================================================================
    if db
        figure(Fdem)
        spm_DEM_qU(QU)
        if np
            subplot(nl,4,4*nl)
            bar(full(Up*qp.e))
            xlabel({'parameters';'{minus prior}'})
            axis square, grid on
        end
        if length(F) > 2
            subplot(nl,4,4*nl - 1)
            plot(F - F(1))
            xlabel('updates')
            title('log-evidence')
            axis square, grid on
        end
        drawnow

        % report (EM-Steps)
        %------------------------------------------------------------------
        str{1} = sprintf('ADEM: %i (%i)',iE,iM);
        str{2} = sprintf('F:%.4e',full(L - F(1)));
        str{3} = sprintf('p:%.2e',full(dp'*dp));
        str{4} = sprintf('h:%.2e',full(mh'*mh));
        str{5} = sprintf('(%.2e sec)',full(toc));
        fprintf('%-16s%-16s%-14s%-14s%-16s\n',str{:})
    end

    if (norm(dp,1) < exp(-8)) && (norm(mh,1) < exp(-8)), break, end

end

% assemble output arguments
%==========================================================================

% conditional moments of model-parameters (rotated into original space)
%--------------------------------------------------------------------------
qP.P   = spm_unvec(Up*qp.e + spm_vec(M.pE),M.pE);
qP.C   = Up*qp.c*Up';
qP.V   = spm_unvec(diag(qP.C),M.pE);

% conditional moments of hyper-parameters (log-transformed)
%--------------------------------------------------------------------------
qH.h   = spm_unvec(qh.h,{{M.hE} {M.gE}});
qH.g   = qH.h{2};
qH.h   = qH.h{1};
qH.C   = qh.c;
qH.V   = spm_unvec(diag(qH.C),{{M.hE} {M.gE}});
qH.W   = qH.V{2};
qH.V   = qH.V{1};

% Fill in DEM with response and its causes
%--------------------------------------------------------------------------
DEM.pP.P = {G.pE};            % parameters encoding process

DEM.M    = M;                 % generative model
DEM.U    = U;                 % causes
DEM.Y    = PU.v{1};           % response

DEM.pU   = PU;                % prior moments of model-states
DEM.qU   = QU;                % conditional moments of model-states
DEM.qP   = qP;                % conditional moments of model-parameters
DEM.qH   = qH;                % conditional moments of hyper-parameters

DEM.F    = F;                 % [-ve] Free energy

% J only exists when a single E-step was requested (nE == 1)
%--------------------------------------------------------------------------
try
    DEM.J = J;                % [-ve] Free energy (over samples)
end