12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906 |
- function [DEM] = spm_LAPF(DEM)
- % Laplacian model inversion (see also spm_LAPS)
- % FORMAT DEM = spm_LAPF(DEM)
- %
- % DEM.M  - hierarchical model
- % DEM.Y  - response variable, output or data
- % DEM.U  - explanatory variables, inputs or prior expectation of causes
- %__________________________________________________________________________
- %
- % generative model
- %--------------------------------------------------------------------------
- %   M(i).g  = v     = g(x,v,P)    {inline function, string or m-file}
- %   M(i).f  = dx/dt = f(x,v,P)    {inline function, string or m-file}
- %
- %   M(i).ph = pi(v) = ph(x,v,h,M) {inline function, string or m-file}
- %   M(i).pg = pi(x) = pg(x,v,g,M) {inline function, string or m-file}
- %   (if ph or pg is missing or cannot be evaluated, spm_LAP_ph or
- %    spm_LAP_pg respectively is used instead)
- %
- %   M(i).pE = prior expectation of p model-parameters
- %   M(i).pC = prior covariances of p model-parameters
- %   M(i).hE = prior expectation of h log-precision (cause noise)
- %   M(i).hC = prior covariances of h log-precision (cause noise)
- %   M(i).gE = prior expectation of g log-precision (state noise)
- %   M(i).gC = prior covariances of g log-precision (state noise)
- %   M(i).xP = precision (states)
- %   M(i).Q  = precision components (input noise)
- %   M(i).R  = precision components (state noise)
- %   M(i).V  = fixed precision (input noise)
- %   M(i).W  = fixed precision (state noise)
- %
- %   M(i).m  = number of inputs v(i + 1);
- %   M(i).n  = number of states x(i);
- %   M(i).l  = number of output v(i);
- %
- % options read from M(1).E (after spm_DEM_set)
- %--------------------------------------------------------------------------
- %   M(1).E.nN     = maximum number of iterations (default 16; forced to 1
- %                   when there are no parameters or hyperparameters)
- %   M(1).E.nD     = number of D-steps per sample (default 1)
- %   M(1).E.d, .n  = embedding orders (d + 1 and n + 1 orders are used)
- %   M(1).E.s, .form = passed to spm_DEM_R to build the precision of
- %                   generalised errors (form defaults to 'Gaussian')
- %   M(1).E.method = switches h, g, x, v. If absent: h = g = 1, x = v = 0;
- %                   otherwise missing switches default to 0. v (x) is set
- %                   to 1 when a user-supplied ph (pg) can be evaluated.
- %   M(1).E.linear is overwritten with 4
- %
- %   DEM.M(1).nograph = 1 suppresses graphics (a missing field is set to 0)
- %
- % conditional moments of model-states - q(u)
- %--------------------------------------------------------------------------
- %   qU.x    = Conditional expectation of hidden states
- %   qU.v    = Conditional expectation of causal states
- %             (qU.v{1} = predicted response: data minus prediction error)
- %   qU.w    = Conditional prediction error (states)
- %   qU.z    = Conditional prediction error (causes)
- %   qU.W    = Precision-weighted prediction error (states)
- %   qU.Z    = Precision-weighted prediction error (causes)
- %   qU.C    = Conditional covariance: cov(v)
- %   qU.S    = Conditional covariance: cov(x)
- %
- % conditional moments of model-parameters - q(p)
- %--------------------------------------------------------------------------
- %   qP.P    = Conditional expectation (precision-weighted average over samples)
- %   qP.C    = Conditional covariance
- %   qP.V    = Conditional variances (diag(qP.C), in the form of M.pE)
- %   qP.U    = Basis of the reduced parameter space used internally
- %   qP.p{t}, qP.c{t} = reduced-space mode and covariance at sample t
- %
- % conditional moments of hyper-parameters (log-transformed) - q(h)
- %--------------------------------------------------------------------------
- %   qH.h    = Conditional expectation (cause noise)
- %   qH.g    = Conditional expectation (state noise)
- %   qH.C    = Conditional covariance
- %   qH.V, qH.W = Conditional variances of h and g
- %   qH.p{t}, qH.c{t} = mode and covariance of [h; g] at sample t
- %
- %   DEM.F   = [-ve] free energy (log-evidence bound) at each iteration
- %   DEM.S   = [-ve] free action at each iteration
- %   DEM.M, DEM.U = model and causes as returned by spm_DEM_set
- %__________________________________________________________________________
- %
- % spm_LAPF implements a variational scheme under the Laplace
- % approximation to the conditional joint density q on states (u), parameters
- % (p) and hyperparameters (h,g) of any analytic nonlinear hierarchical dynamic
- % model, with additive Gaussian innovations.
- %
- %            q(u,p,h,g) = max <L(t)>q
- %
- % L is the ln p(y,u,p,h,g|M) under the model M. The conditional covariances
- % are obtained analytically from the curvature of L with respect to the
- % unknowns.
- %__________________________________________________________________________
- % Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging
-
- % Karl Friston
- % $Id: spm_LAPF.m 6018 2014-05-25 09:24:14Z karl $
- % graphics: get (or create) the 'DEM' window unless DEM.M(1).nograph is
- % set; a missing nograph field is added with value 0 (graphics enabled)
- %--------------------------------------------------------------------------
- try
-     nograph = DEM.M(1).nograph;
- catch
-     nograph = 0;
-     DEM.M(1).nograph = nograph;
- end
- if ~nograph, Fdem = spm_figure('GetWin','DEM'); end
- % check model, data and priors: spm_DEM_set returns the checked model (M),
- % data (Y) and causes (U)
- %==========================================================================
- [M,Y,U] = spm_DEM_set(DEM);
-
- % D-steps per sample (nD, default 1) and iterations (nN, default 16)
- %--------------------------------------------------------------------------
- try
-     nD = M(1).E.nD;
- catch
-     nD = 1;
- end
- try
-     nN = M(1).E.nN;
- catch
-     nN = 16;
- end
-
- % ensure integration scheme evaluates gradients at each time-step
- %--------------------------------------------------------------------------
- M(1).E.linear = 4;
-
- % switches for state-dependent precisions: without M(1).E.method assume
- % precisions depend on, and only on, hyperparameters (h = g = 1, x = v = 0);
- % any switch missing from a user-supplied method defaults to 0
- %--------------------------------------------------------------------------
- try
-     method = M(1).E.method;
- catch
-     method.h = 1;
-     method.g = 1;
-     method.x = 0;
-     method.v = 0;
- end
- for sw = {'h','g','x','v'}
-     try
-         method.(sw{1});
-     catch
-         method.(sw{1}) = 0;
-     end
- end
-
- % form of the precision of generalised errors (passed to spm_DEM_R)
- %--------------------------------------------------------------------------
- try
-     form = M(1).E.form;
- catch
-     form = 'Gaussian';
- end
-
- % checks for Laplace models (precision functions; ph and pg): a user
- % supplied ph (pg) that can be evaluated is kept and switches on method.v
- % (method.x); otherwise spm_LAP_ph (spm_LAP_pg) is used. At the last level
- % M(i + 1) does not exist, so the default is always installed there.
- %--------------------------------------------------------------------------
- for i = 1:length(M)
-     try
-         feval(M(i).ph,M(i).x,M(i + 1).v,M(i).hE,M(i));
-         method.v = 1;
-     catch
-         M(i).ph = inline('spm_LAP_ph(x,v,h,M)','x','v','h','M');
-     end
-     try
-         feval(M(i).pg,M(i).x,M(i + 1).v,M(i).gE,M(i));
-         method.x = 1;
-     catch
-         M(i).pg = inline('spm_LAP_pg(x,v,h,M)','x','v','h','M');
-     end
- end
- M(1).E.method = method;
-
- % embedding orders (d = n = 1 for static models)
- %==========================================================================
- d  = M(1).E.d + 1;                      % embedding order of q(v)
- n  = M(1).E.n + 1;                      % embedding order of q(x)
-
- % numbers of samples, levels, states, generalised states and errors
- %--------------------------------------------------------------------------
- ns = size(Y,2);                         % number of samples
- nl = size(M,2);                         % number of levels
- nv = sum(spm_vec(M.m));                 % number of v (causal states)
- nx = sum(spm_vec(M.n));                 % number of x (hidden states)
- ny = M(1).l;                            % number of y (responses)
- nc = M(end).l;                          % number of c (prior causes)
- nu = nv*d + nx*n;                       % number of generalised states
- ne = (nv + nx + ny)*n;                  % number of generalised errors
-
- % precision (R) of generalised errors and null matrices for concatenation
- %==========================================================================
- s  = M(1).E.s;
- Rh = spm_DEM_R(n,s,form);               % paired with the h (cause) block
- Rg = spm_DEM_R(n,s,form);               % paired with the g (state) block
- V  = sparse((ny + nv)*n,(ny + nv)*n);   % null block: responses and causes
- W  = sparse(nx*n,nx*n);                 % null block: hidden states
-
- % fixed priors on states (u): precision M.xP on first-order hidden states
- % only; generalised causes carry zero prior precision
- %--------------------------------------------------------------------------
- Px    = kron(sparse(1,1,1,n,n),spm_cat(spm_diag({M.xP})));
- Pv    = sparse(d*nv,d*nv);
- pu.ic = spm_cat(spm_diag({Px Pv}));
-
- % conditional expectations of hyperparameters start at their priors
- %--------------------------------------------------------------------------
- qh.h  = {M.hE};                         % conditional expectation h
- qh.g  = {M.gE};                         % conditional expectation g
- nh    = length(spm_vec(qh.h));          % number of hyperparameters h
- ng    = length(spm_vec(qh.g));          % number of hyperparameters g
- nb    = nh + ng;                        % number of hyperparameters
-
- % hyperpriors
- %--------------------------------------------------------------------------
- ph.h  = spm_vec({M.hE M.gE});           % prior expectation of h,g
- ph.c  = spm_cat(spm_diag({M.hC M.gC})); % prior covariances of h,g
- ph.ic = spm_pinv(ph.c);                 % prior precision of h,g
- % priors on parameters (in a reduced, eigenvector parameter space) and
- % initialisation of the conditional density q(p)
- %==========================================================================
- pp.c = cell(nl,nl);
- qp.p = cell(nl,1);
- for i = 1:(nl - 1)
-
-     % eigenvector reduction: p <- pE + qp.u*qp.p
-     %----------------------------------------------------------------------
-     qp.u{i}   = spm_svd(M(i).pC);             % basis for parameters
-     M(i).p    = size(qp.u{i},2);              % number of qp.p
-     pp.c{i,i} = qp.u{i}'*M(i).pC*qp.u{i};     % prior covariance
-     qp.p{i}   = sparse(M(i).p,1);             % initial deviates
-
-     % start from M(i).P (projected onto the basis) when it is supplied;
-     % otherwise (or if the projection fails) keep zero deviates
-     %----------------------------------------------------------------------
-     try
-         qp.p{i} = qp.p{i} + qp.u{i}'*(spm_vec(M(i).P) - spm_vec(M(i).pE));
-     end
- end
- Up = spm_cat(spm_diag(qp.u));                 % basis over all levels
- np = size(Up,2);                              % number of reduced parameters
-
- % prior expectation, covariance and precision of parameters
- %--------------------------------------------------------------------------
- pp.p  = spm_vec(M.pE);
- pp.c  = spm_cat(pp.c);
- pp.ic = spm_inv(pp.c);
- % initialise cell arrays for D-Step; e{i + 1} = (d/dt)^i[e] = e[i]
- % (n cells per quantity, each starting at zero)
- %==========================================================================
- qu.x = repmat({sparse(nx,1)},n,1);
- qu.v = repmat({sparse(nv,1)},n,1);
- qu.y = repmat({sparse(ny,1)},n,1);
- qu.u = repmat({sparse(nc,1)},n,1);
-
- % first-order hidden (x[0]) and causal (v[0]) states from the model; the
- % hierarchical cell arrays x and v are kept as templates for spm_unvec
- %--------------------------------------------------------------------------
- x       = {M(1:end - 1).x};
- v       = {M(2:end).v};
- qu.x{1} = spm_vec(x);
- qu.v{1} = spm_vec(v);
-
- % derivative operators (built from spm_speye(.,.,1)) for the Jacobian of
- % the D-step
- %--------------------------------------------------------------------------
- Dx = kron(spm_speye(n,n,1),spm_speye(nx,nx));
- Dv = kron(spm_speye(d,d,1),spm_speye(nv,nv));
- Dy = kron(spm_speye(n,n,1),spm_speye(ny,ny));
- Dc = kron(spm_speye(d,d,1),spm_speye(nc,nc));
- Du = spm_cat(spm_diag({Dx,Dv}));
-
- % identities, rates of change of [hyper]parameters and the precision of
- % fluctuations on parameters and hyperparameters
- %--------------------------------------------------------------------------
- Ip    = spm_speye(np,np);
- Ih    = spm_speye(nb,nb);
- qp.dp = sparse(np,1);                   % conditional expectation of dp/dt
- qh.dp = sparse(nb,1);                   % conditional expectation of dh/dt
- Kp    = ns*Ip;
- Kh    = ns*Ih;
-
- % gradients of generalised weighted errors (preallocated)
- %--------------------------------------------------------------------------
- dedh  = sparse(nh,ne);
- dedg  = sparse(ng,ne);
- dedv  = sparse(nv,ne);
- dedx  = sparse(nx,ne);
- dedhh = sparse(nh,nh);
- dedgg = sparse(ng,ng);
-
- % gradients of the entropy term w.r.t. hyperparameters, parameters and
- % states (the state gradients dHdx, dHdv are never updated and stay zero)
- %--------------------------------------------------------------------------
- dHdh = sparse(nh,1);
- dHdg = sparse(ng,1);
- dHdp = sparse(np,1);
- dHdx = sparse(nx*n,1);
- dHdv = sparse(nv*d,1);
-
- % preclude unnecessary iterations (nothing to learn) and set switches:
- % precision gradients w.r.t. x and v are only computed if method.x/v is set
- %--------------------------------------------------------------------------
- if ~(np || nh || ng), nN = 1; end
- mnx = nx*~~method.x;
- mnv = nv*~~method.v;
- % Iterate Laplace scheme: each iteration (iN) integrates the D-step over all
- % ns samples and evaluates the free energy. The scheme stops once the free
- % energy decreases (after the fourth iteration); otherwise the conditional
- % moments of every sample are saved in qU, qP and qH.
- %==========================================================================
- Fa = -Inf;
- for iN = 1:nN
-
-     % get time and clear persistent variables in evaluation routines
-     %----------------------------------------------------------------------
-     tic; clear spm_DEM_eval
-
-     % [re-]set states & their derivatives to those saved for the first
-     % sample on the previous iteration (Q does not exist on the first one)
-     %----------------------------------------------------------------------
-     try, qu = Q(1).u; end
-
-     % D-Step: (nD D-Steps for each sample)
-     %======================================================================
-     for is = 1:ns
-
-         % D-Step: until convergence for static systems
-         %==================================================================
-         for iD = 1:nD
-
-             % sampling time
-             %--------------------------------------------------------------
-             ts = is + (iD - 1)/nD;
-
-             % derivatives of responses and inputs (using M(1).delays when
-             % the model specifies them)
-             %--------------------------------------------------------------
-             try
-                 qu.y(1:n) = spm_DEM_embed(Y,n,ts,1,M(1).delays);
-                 qu.u(1:d) = spm_DEM_embed(U,d,ts);
-             catch
-                 qu.y(1:n) = spm_DEM_embed(Y,n,ts);
-                 qu.u(1:d) = spm_DEM_embed(U,d,ts);
-             end
-
-             % evaluate functions and derivatives
-             %==============================================================
-
-             % prediction errors (E) and precision vectors (p)
-             %--------------------------------------------------------------
-             [E,dE] = spm_DEM_eval(M,qu,qp);
-             [p,dp] = spm_LAP_eval(M,qu,qh);
-
-             % gradients of log(det(iS)) dDd...
-             %==============================================================
-
-             % get precision matrices
-             %--------------------------------------------------------------
-             iSh = diag(exp(p.h));
-             iSg = diag(exp(p.g));
-             iS  = blkdiag(kron(Rh,iSh),kron(Rg,iSg));
-
-             % gradients of trace(diag(p)) = sum(p); p = precision vector
-             %--------------------------------------------------------------
-             dpdx = n*sum(spm_cat({dp.h.dx; dp.g.dx}));
-             dpdv = n*sum(spm_cat({dp.h.dv; dp.g.dv}));
-             dpdh = n*sum(dp.h.dh);
-             dpdg = n*sum(dp.g.dg);
-             dpdx = kron(sparse(1,1,1,1,n),dpdx);
-             dpdv = kron(sparse(1,1,1,1,d),dpdv);
-             dDdu = [dpdx dpdv]';
-             dDdh = [dpdh dpdg]';
-
-             % gradients of precision-weighted generalised errors dSd..
-             %==============================================================
-             % gradients w.r.t. hyperparameters
-             %--------------------------------------------------------------
-             for i = 1:nh
-                 diS = diag(dp.h.dh(:,i).*exp(p.h));
-                 diSdh{i} = blkdiag(kron(Rh,diS),W);
-                 dedh(i,:) = E'*diSdh{i};
-             end
-             for i = 1:ng
-                 diS = diag(dp.g.dg(:,i).*exp(p.g));
-                 diSdg{i} = blkdiag(V,kron(Rg,diS));
-                 dedg(i,:) = E'*diSdg{i};
-             end
-
-             % gradients w.r.t. hidden states (only when method.x is set)
-             %--------------------------------------------------------------
-             for i = 1:mnx
-                 diV = diag(dp.h.dx(:,i).*exp(p.h));
-                 diW = diag(dp.g.dx(:,i).*exp(p.g));
-                 diSdx{i} = blkdiag(kron(Rh,diV),kron(Rg,diW));
-                 dedx(i,:) = E'*diSdx{i};
-             end
-
-             % gradients w.r.t. causal states (only when method.v is set)
-             %--------------------------------------------------------------
-             for i = 1:mnv
-                 diV = diag(dp.h.dv(:,i).*exp(p.h));
-                 diW = diag(dp.g.dv(:,i).*exp(p.g));
-                 diSdv{i} = blkdiag(kron(Rh,diV),kron(Rg,diW));
-                 dedv(i,:) = E'*diSdv{i};
-             end
-             dSdx = kron(sparse(1,1,1,n,1),dedx);
-             dSdv = kron(sparse(1,1,1,d,1),dedv);
-             dSdu = [dSdx; dSdv];
-             dEdh = [dedh; dedg];
-             dEdp = dE.dp'*iS;
-             dEdu = dE.du'*iS;
-
-             % curvatures w.r.t. hyperparameters (symmetric: fill both halves)
-             %--------------------------------------------------------------
-             for i = 1:nh
-                 for j = i:nh
-                     diS = diag(dp.h.dh(:,i).*dp.h.dh(:,j).*exp(p.h));
-                     diS = blkdiag(kron(Rh,diS),W);
-                     dedhh(i,j) = E'*diS*E;
-                     dedhh(j,i) = dedhh(i,j);
-                 end
-             end
-             for i = 1:ng
-                 for j = i:ng
-                     diS = diag(dp.g.dg(:,i).*dp.g.dg(:,j).*exp(p.g));
-                     diS = blkdiag(V,kron(Rg,diS));
-                     dedgg(i,j) = E'*diS*E;
-                     dedgg(j,i) = dedgg(i,j);
-                 end
-             end
-
-             % combined curvature
-             %--------------------------------------------------------------
-             dSdhh = spm_cat({dedhh [] ;
-                              [] dedgg});
-
-             % errors (from prior expectations) (NB pp.p = 0)
-             %--------------------------------------------------------------
-             Eu = spm_vec(qu.x(1:n),qu.v(1:d));
-             Ep = spm_vec(qp.p);
-             Eh = spm_vec(qh.h,qh.g) - ph.h;
-
-             % first-order derivatives of Gibb's Energy
-             %==============================================================
-             dLdu = dEdu*E + dSdu*E/2 - dDdu/2 + pu.ic*Eu;
-             dLdh = dEdh*E/2 - dDdh/2 + ph.ic*Eh;
-             dLdp = dEdp*E + pp.ic*Ep;
-
-             % and second-order derivatives of Gibb's Energy
-             %--------------------------------------------------------------
-             % dLduu = dEdu*dE.du + dSdu*dE.du + dE.du'*dSdu' + pu.ic;
-             % dLdup = dEdu*dE.dp + dSdu*dE.dp;
-             dLduu = dEdu*dE.du + pu.ic;
-             dLdpp = dEdp*dE.dp + pp.ic;
-             dLdhh = dSdhh/2 + ph.ic;
-             dLdup = dEdu*dE.dp;
-             dLdhu = dEdh*dE.du;
-             dLduy = dEdu*dE.dy;
-             dLduc = dEdu*dE.dc;
-             dLdpy = dEdp*dE.dy;
-             dLdpc = dEdp*dE.dc;
-             dLdhy = dEdh*dE.dy;
-             dLdhc = dEdh*dE.dc;
-             dLdhp = dEdh*dE.dp;
-             dLdpu = dLdup';
-             dLdph = dLdhp';
-
-             % precision and covariances (of states and parameters)
-             %--------------------------------------------------------------
-             iC = spm_cat({dLduu dLdup;
-                           dLdpu dLdpp});
-
-             C  = spm_inv(iC);
-
-             % first-order derivatives of Entropy term
-             %==============================================================
-
-             % log-precision
-             %--------------------------------------------------------------
-             for i = 1:nh
-                 Luub = dE.du'*diSdh{i}*dE.du;
-                 Lpub = dE.dp'*diSdh{i}*dE.du;
-                 Lppb = dE.dp'*diSdh{i}*dE.dp;
-                 diCdh = spm_cat({Luub Lpub';
-                                  Lpub Lppb});
-                 dHdh(i) = sum(sum(diCdh.*C))/2;
-             end
-             for i = 1:ng
-                 Luub = dE.du'*diSdg{i}*dE.du;
-                 Lpub = dE.dp'*diSdg{i}*dE.du;
-                 Lppb = dE.dp'*diSdg{i}*dE.dp;
-                 diCdg = spm_cat({Luub Lpub';
-                                  Lpub Lppb});
-                 dHdg(i) = sum(sum(diCdg.*C))/2;
-             end
-
-             % parameters
-             %--------------------------------------------------------------
-             for i = 1:np
-                 Luup = dE.dup{i}'*dEdu';
-                 Lpup = dEdp*dE.dup{i};
-                 Luup = Luup + Luup';
-                 diCdp = spm_cat({Luup Lpup';
-                                  Lpup [] });
-                 dHdp(i) = sum(sum(diCdp.*C))/2;
-             end
-             % % hidden and causal states
-             % %--------------------------------------------------------------
-             % for i = 1:mnx
-             %     Luux = dE.du'*diSdx{i}*dE.du;
-             %     Lpux = dE.dp'*diSdx{i}*dE.du;
-             %     Lppx = dE.dp'*diSdx{i}*dE.dp;
-             %     diCdx = spm_cat({Luux Lpux';
-             %                      Lpux Lppx});
-             %     dHdx(i) = sum(sum(diCdx.*C))/2;
-             %
-             % end
-             % for i = 1:mnv
-             %     Luuv = dE.du'*diSdv{i}*dE.du;
-             %     Lpuv = dE.dp'*diSdv{i}*dE.du;
-             %     Lppv = dE.dp'*diSdv{i}*dE.dp;
-             %     diCdv = spm_cat({Luuv Lpuv';
-             %                      Lpuv Lppv});
-             %     dHdv(i) = sum(sum(diCdv.*C))/2;
-             % end
-
-             % NB: with the state terms above disabled dHdx and dHdv stay at
-             % their initial (zero) values, so dHdu is zero
-             dHdb = [dHdh; dHdg];
-             dHdu = [dHdx; dHdv];
-
-             % save conditional moments (and prediction error) at Q{t}
-             % (once per sample, on the first D-step)
-             %==============================================================
-             if iD == 1
-
-                 % save means
-                 %----------------------------------------------------------
-                 Q(is).e = E;
-                 Q(is).E = iS*E;
-                 Q(is).u = qu;
-                 Q(is).p = qp;
-                 Q(is).h = qh;
-
-                 % and conditional covariances
-                 %----------------------------------------------------------
-                 Q(is).u.s = C((1:nx),(1:nx));
-                 Q(is).u.c = C((1:nv) + nx*n, (1:nv) + nx*n);
-                 Q(is).p.c = C((1:np) + nu, (1:np) + nu);
-                 Q(is).h.c = spm_inv(dLdhh);
-                 Cu = C(1:nu,1:nu);
-
-                 % Free-energy (states)
-                 %----------------------------------------------------------
-                 L(is) = ...
-                     - E'*iS*E/2 + spm_logdet(iS)/2 - n*ny*log(2*pi)/2 ...
-                     - Eu'*pu.ic*Eu/2 + spm_logdet(pu.ic)/2 + spm_logdet(Cu)/2;
-
-                 % Free-energy (states and parameters)
-                 %----------------------------------------------------------
-                 A(is) = - E'*iS*E/2 + spm_logdet(iS)/2 ...
-                     - Eu'*pu.ic*Eu/2 + spm_logdet(pu.ic)/2 ...
-                     - Ep'*pp.ic*Ep/2 + spm_logdet(pp.ic)/2 ...
-                     - Eh'*ph.ic*Eh/2 + spm_logdet(ph.ic)/2 ...
-                     - n*ny*log(2*pi)/2 - spm_logdet(iC)/2 - spm_logdet(dLdhh)/2;
-             end
-
-             % update conditional moments
-             %==============================================================
-
-             % update running averages of the curvatures of
-             % [hyper]parameters (dLdPP and dLdHH do not exist before the
-             % first update and persist across iterations)
-             %--------------------------------------------------------------
-             try
-                 dLdPP = dLdPP*(1 - 1/ns) + dLdpp/ns;
-                 dLdHH = dLdHH*(1 - 1/ns) + dLdhh/ns;
-             catch
-                 dLdPP = dLdpp;
-                 dLdHH = dLdhh;
-             end
-
-             % rotate and scale gradient (and curvatures)
-             %--------------------------------------------------------------
-             [Vp,Sp] = spm_svd(dLdPP,0);
-             [Vh,Sh] = spm_svd(dLdHH,0);
-             Sp = diag(1./(diag(sqrt(Sp))));
-             Sh = diag(1./(diag(sqrt(Sh))));
-
-             dLdp  = Sp*Vp'*dLdp;
-             dHdp  = Sp*Vp'*dHdp;
-             dLdpy = Sp*Vp'*dLdpy;
-             dLdpu = Sp*Vp'*dLdpu;
-             dLdpc = Sp*Vp'*dLdpc;
-             dLdph = Sp*Vp'*dLdph;
-             dLdpp = Sp*Vp'*dLdpp*Vp;
-             dLdhp = dLdhp*Vp;
-
-             dLdh  = Sh*Vh'*dLdh;
-             dHdb  = Sh*Vh'*dHdb;
-             dLdhy = Sh*Vh'*dLdhy;
-             dLdhu = Sh*Vh'*dLdhu;
-             dLdhc = Sh*Vh'*dLdhc;
-             dLdhp = Sh*Vh'*dLdhp;
-             dLdhh = Sh*Vh'*dLdhh*Vh;
-             dLdph = dLdph*Vh;
-
-             % assemble conditional means (parameters and hyperparameters
-             % in the rotated coordinates)
-             %--------------------------------------------------------------
-             q{1} = qu.y(1:n);
-             q{2} = qu.x(1:n);
-             q{3} = qu.v(1:d);
-             q{4} = qu.u(1:d);
-             q{5} = spm_unvec(Vp'*spm_vec(qp.p),qp.p);
-             qb   = spm_unvec(Vh'*spm_vec({qh.h qh.g}),{qh.h qh.g});
-             q{6} = qb{1};
-             q{7} = qb{2};
-             q{8} = Vp'*qp.dp;
-             q{9} = Vh'*qh.dp;
-
-             % flow
-             %--------------------------------------------------------------
-             f{1} = Dy*spm_vec(q{1});
-             f{2} = Du*spm_vec(q{2:3}) - dLdu - dHdu;
-             f{3} = Dc*spm_vec(q{4});
-             f{4} = spm_vec(q{8});
-             f{5} = spm_vec(q{9});
-             f{6} = -Kp*spm_vec(q{8}) - dLdp - dHdp;
-             f{7} = -Kh*spm_vec(q{9}) - dLdh - dHdb;
-
-             % and Jacobian
-             %--------------------------------------------------------------
-             dfdq = spm_cat({Dy [] [] [] [] [] [];
-                             -dLduy Du-dLduu -dLduc [] [] [] [];
-                             [] [] Dc [] [] [] [];
-                             [] [] [] [] [] Ip [];
-                             [] [] [] [] [] [] Ih;
-                             -dLdpy -dLdpu -dLdpc -dLdpp -dLdph -Kp [];
-                             -dLdhy -dLdhu -dLdhc -dLdhp -dLdhh [] -Kh});
-
-             % update conditional modes of states
-             %==============================================================
-             dq = spm_dx(dfdq, spm_vec(f), 1/nD);
-             q  = spm_unvec(spm_vec(q) + dq,q);
-
-             % unpack conditional means (rotating back)
-             %--------------------------------------------------------------
-             qu.x(1:n) = q{2};
-             qu.v(1:d) = q{3};
-             qp.p  = spm_unvec(Vp*spm_vec(q{5}),qp.p);
-             qb    = spm_unvec(Vh*spm_vec(q{6:7}),{qh.h qh.g});
-             qh.h  = qb{1};
-             qh.g  = qb{2};
-             qp.dp = Vp*q{8};
-             qh.dp = Vh*q{9};
-
-         end % D-Step
-
-     end % sequence (ns)
-
-     % Bayesian parameter averaging
-     %======================================================================
-
-     % Conditional moments of time-averaged parameters
-     %----------------------------------------------------------------------
-     Pp = 0;
-     Ep = 0;
-     for i = 1:ns
-         P  = spm_inv(Q(i).p.c);
-         Ep = Ep + P*spm_vec(Q(i).p.p);
-         Pp = Pp + P;
-     end
-     Cp = spm_inv(Pp);
-     Ep = Cp*Ep;
-
-     % conditional moments of hyper-parameters (as deviations from priors)
-     %----------------------------------------------------------------------
-     Ph = 0;
-     Eh = 0;
-     for i = 1:ns
-         P  = spm_inv(Q(i).h.c);
-         Ph = Ph + P;
-         Eh = Eh + P*spm_vec({Q(i).h.h Q(i).h.g});
-     end
-     Ch = spm_inv(Ph);
-     Eh = Ch*Eh - ph.h;
-
-     % Free-action of states plus free-energy of parameters
-     %======================================================================
-     Fs = sum(A);
-     Fi = sum(L) ...
-         - Ep'*pp.ic*Ep/2 + spm_logdet(pp.ic)/2 - spm_logdet(Pp)/2 ...
-         - Eh'*ph.ic*Eh/2 + spm_logdet(ph.ic)/2 - spm_logdet(Ph)/2;
-
-     % if F has decreased (after the fourth iteration) terminate
-     %----------------------------------------------------------------------
-     if Fi < Fa && iN > 4
-         break
-     else
-         Fa = Fi;
-         F(iN) = Fi;
-         S(iN) = Fs;
-     end
-
-     % otherwise save conditional moments (for each time point)
-     %======================================================================
-     for t = 1:length(Q)
-
-         % states and predictions
-         %------------------------------------------------------------------
-         v = spm_unvec(Q(t).u.v{1},v);
-         x = spm_unvec(Q(t).u.x{1},x);
-         z = spm_unvec(Q(t).e(1:(ny + nv)),{M.v});
-         Z = spm_unvec(Q(t).E(1:(ny + nv)),{M.v});
-         w = spm_unvec(Q(t).e((1:nx) + (ny + nv)*n),{M.x});
-         X = spm_unvec(Q(t).E((1:nx) + (ny + nv)*n),{M.x});
-         for i = 1:(nl - 1)
-             if M(i).m, qU.v{i + 1}(:,t) = spm_vec(v{i}); end
-             if M(i).n, qU.x{i}(:,t) = spm_vec(x{i}); end
-             if M(i).n, qU.w{i}(:,t) = spm_vec(w{i}); end
-             if M(i).l, qU.z{i}(:,t) = spm_vec(z{i}); end
-             if M(i).n, qU.W{i}(:,t) = spm_vec(X{i}); end
-             if M(i).l, qU.Z{i}(:,t) = spm_vec(Z{i}); end
-         end
-         if M(nl).l, qU.z{nl}(:,t) = spm_vec(z{nl}); end
-         if M(nl).l, qU.Z{nl}(:,t) = spm_vec(Z{nl}); end
-         qU.v{1}(:,t) = spm_vec(Q(t).u.y{1}) - spm_vec(z{1});
-
-         % and conditional covariances
-         %------------------------------------------------------------------
-         qU.S{t} = Q(t).u.s;
-         qU.C{t} = Q(t).u.c;
-
-         % parameters
-         %------------------------------------------------------------------
-         qP.p{t} = spm_vec(Q(t).p.p);
-         qP.c{t} = Q(t).p.c;
-
-         % hyperparameters
-         %------------------------------------------------------------------
-         qH.p{t} = spm_vec({Q(t).h.h Q(t).h.g});
-         qH.c{t} = Q(t).h.c;
-
-     end
-
-     % graphics (states, parameters and log-precisions). The DEM window
-     % (Fdem) is only created when graphics are enabled, so this section is
-     % skipped when DEM.M(1).nograph is set (previously it failed because
-     % Fdem was undefined).
-     %----------------------------------------------------------------------
-     if ~DEM.M(1).nograph
-         figure(Fdem)
-         spm_DEM_qU(qU)
-
-         if np && nb
-             subplot(2*nl,2,4*nl - 2)
-             plot(1:ns,spm_cat(qP.p))
-             set(gca,'XLim',[1 ns])
-             title('parameters (modes)','FontSize',16)
-
-             subplot(2*nl,2,4*nl)
-             plot(1:ns,spm_cat(qH.p))
-             set(gca,'XLim',[1 ns])
-             title('log-precision','FontSize',16)
-
-         elseif nb
-             subplot(nl,2,2*nl)
-             plot(1:ns,spm_cat(qH.p))
-             set(gca,'XLim',[1 ns])
-             title('log-precision','FontSize',16)
-
-         elseif np
-             subplot(nl,2,2*nl)
-             plot(1:ns,spm_cat(qP.p))
-             set(gca,'XLim',[1 ns])
-             title('parameters (modes)','FontSize',16)
-         end
-         drawnow
-     end
-
-     % report (EM-Steps); dF is 0 on the first iteration
-     %----------------------------------------------------------------------
-     try
-         dF = F(end) - F(end - 1);
-     catch
-         dF = 0;
-     end
-     str{1} = sprintf('LAP: %i (%i)', iN,iD);
-     str{2} = sprintf('F:%.4e', full(F(iN) - F(1)));
-     str{3} = sprintf('dF:%.2e', full(dF));
-     str{4} = sprintf('(%.2e sec)', full(toc));
-     fprintf('%-16s%-16s%-14s%-16s\n',str{:})
-
- end
-
-
- % Place Bayesian parameter averages in output arguments: per-sample
- % conditional modes are averaged with weights given by their conditional
- % precisions
- %==========================================================================
-
- % time-averaged parameters (reduced space), mapped back via p = pE + Up*Ep
- %--------------------------------------------------------------------------
- Pp = 0;                                  % summed precision
- Ep = 0;                                  % precision-weighted sum of modes
- for k = 1:ns
-     Pk = spm_inv(qP.c{k});
-     Ep = Ep + Pk*qP.p{k};
-     Pp = Pp + Pk;
- end
- Cp = spm_inv(Pp);
- Ep = Cp*Ep;
-
- pTemplate = {M.pE};
- qP.P = spm_unvec(Up*Ep + pp.p,pTemplate);
- qP.C = Up*Cp*Up';
- qP.V = spm_unvec(diag(qP.C),pTemplate);
- qP.U = Up;
-
- % time-averaged hyperparameters (h and g stacked)
- %--------------------------------------------------------------------------
- Ph = 0;
- Eh = 0;
- for k = 1:ns
-     Pk = spm_inv(qH.c{k});
-     Ph = Ph + Pk;
-     Eh = Eh + Pk*qH.p{k};
- end
- Ch = spm_inv(Ph);
- Eh = Ch*Eh;
-
- hg   = spm_unvec(Eh,{qh.h qh.g});
- qH.h = hg{1};
- qH.g = hg{2};
- qH.C = Ch;
- hg   = spm_unvec(diag(qH.C),hg);
- qH.V = hg{1};
- qH.W = hg{2};
-
- % assign output variables
- %--------------------------------------------------------------------------
- DEM.M  = M;                              % model
- DEM.U  = U;                              % causes
-
- DEM.qU = qU;                             % conditional moments of model-states
- DEM.qP = qP;                             % conditional moments of model-parameters
- DEM.qH = qH;                             % conditional moments of hyper-parameters
-
- DEM.F  = F;                              % [-ve] Free energy
- DEM.S  = S;                              % [-ve] Free action
- return
- % Notes (check on curvature)
- %==========================================================================
- % NB: everything below follows the 'return' above and is never executed;
- % it is kept as development notes. It refers to dLduh and kp, which are not
- % defined anywhere in this function, so it would fail if it were reached.
- %
- % analytic form (full curvature over states, parameters, hyperparameters)
- %----------------------------------------------------------
- iC = spm_cat({dLduu dLdup dLduh;
- dLdpu dLdpp dLdph;
- dLdhu dLdhp dLdhh});
-
- % numerical approximations (spm_diff with [1 1] - presumably the second
- % derivative of spm_LAP_F w.r.t. qq; TODO confirm) for comparison
- %----------------------------------------------------------
- qq.x = qu.x(1:n);
- qq.v = qu.v(1:d);
- qq.p = qp.p;
- qq.h = qh.h;
- qq.g = qh.g;
- dLdqq = spm_diff('spm_LAP_F',qq,qu,qp,qh,pu,pp,ph,M,[1 1]);
- dLdqq = spm_cat(dLdqq');
-
- % display numerical and analytic curvatures and their difference
- subplot(2,2,1);imagesc(dLdqq); axis square
- subplot(2,2,2);imagesc(iC); axis square
- subplot(2,2,3);imagesc(dLdqq - iC);axis square
- subplot(2,2,4);plot(iC,':k');hold on;
- plot(dLdqq - iC,'r');hold off; axis square
- drawnow
-
- % Notes (descent on parameters)
- %==========================================================================
- % eigenvalues of the Jacobian J = [0 I; -Luu -k*I] of a damped second-order
- % descent with curvature Luu and damping k
- I = eye(length(dLdpp));
- k = kp;
- Luu = dLdpp;
- J = spm_cat({[] I;
- -Luu -k*I});
- [u s] = eig(full(J));
- max(diag(s))
- % scalar case: roots of s^2 + k*s + Luu = 0 for the smallest eigenvalue
- [uj sj] = eig(full(dLdpp));
- Luu = min(diag(sj));
- % Luu = max(diag(sj));
- k = kp;
- ss(1) = -(k + sqrt(k^2 - 4*Luu))/2;
- ss(2) = -(k - sqrt(k^2 - 4*Luu))/2;
- max(ss)
- % -1/real(s) of the slower root as a function of damping k
- k = (1:128);
- s = -(k - sqrt(k.^2 - 4*Luu))/2;
- plot(k,-1./real(s))
|