spm_DEM_eval.m 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  function [E,dE,f,g] = spm_DEM_eval(M,qu,qp)
  % Evaluates state equations and derivatives for DEM schemes
  % FORMAT [E,dE,f,g] = spm_DEM_eval(M,qu,qp)
  %
  % M  - model structure
  % qu - conditional mode of states
  %  qu.v{i} - causal states
  %  qu.x{i} - hidden states
  %  qu.y{i} - response
  %  qu.u{i} - input
  % qp - conditional density of parameters
  %  qp.p{i} - parameter deviates for i-th level
  %  qp.u{i} - basis set
  %  qp.x(i) - expansion point ( = prior expectation)
  %
  % E  - generalised errors (i.e., y - g(x,v,P); x[1] - f(x,v,P))
  %
  % dE:
  %  dE.du   - de[1:n]/du
  %  dE.dy   - de[1:n]/dy[1:n]
  %  dE.dc   - de[1:n]/dc[1:d]
  %  dE.dp   - de[1:n]/dp
  %  dE.dup  - d/dp[de[1:n]/du]
  %  dE.dpu  - d/du[de[1:n]/dp]
  %
  % where u = x[1:d]; v[1:d]
  %
  % f  - cell array: f{i} is M(i).f evaluated at level i
  % g  - cell array: g{i} is M(i).g evaluated at level i
  %
  % To accelerate computations one can specify the nature of the model using
  % the field:
  %
  % M(1).E.linear = 0: full        - evaluates 1st and 2nd derivatives
  % M(1).E.linear = 1: linear      - equations are linear in x and v
  % M(1).E.linear = 2: bilinear    - equations are linear in x, v & x*v
  % M(1).E.linear = 3: nonlinear   - equations are linear in x, v, x*v, & x*x
  % M(1).E.linear = 4: full linear - evaluates 1st derivatives (for generalised
  %                                  filtering, where parameters change)
  %
  % If M(1).E.linear cannot be read, method 0 is used. Any other value raises
  % an error.
  %
  % NB: for methods 1 to 4 the derivatives (and expansion point) are cached in
  % a persistent variable and are only evaluated from scratch while that cache
  % is empty. Nothing in this function empties the cache.
  % NOTE(review): presumably callers reset it with 'clear spm_DEM_eval' before
  % starting a new inversion - confirm against the calling scheme.
  %__________________________________________________________________________
  % Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

  % Karl Friston
  % $Id: spm_DEM_eval.m 6270 2014-11-29 12:04:48Z karl $


  % get dimensions
  %==========================================================================
  nl    = size(M,2);                        % number of levels
  ne    = sum(spm_vec(M.l));                % number of e (errors)
  nv    = sum(spm_vec(M.m));                % number of v (causal states)
  nx    = sum(spm_vec(M.n));                % number of x (hidden states)
  np    = sum(spm_vec(M.p));                % number of p (parameters)

  % evaluate functions at each hierarchical level
  %==========================================================================

  % Get states {qu.v{1},qu.x{1}} in hierarchical form (v{i},x{i})
  %--------------------------------------------------------------------------
  v     = spm_unvec(qu.v{1},{M(1 + 1:end).v});
  x     = spm_unvec(qu.x{1},{M(1:end - 1).x});
  for i = 1:(nl - 1)

      % parameters of level i: prior expectation plus projected deviates
      %----------------------------------------------------------------------
      P      = spm_unvec(spm_vec(M(i).pE) + qp.u{i}*qp.p{i},M(i).pE);
      f{i,1} = feval(M(i).f,x{i},v{i},P);
      g{i,1} = feval(M(i).g,x{i},v{i},P);
  end

  % Get Derivatives
  %==========================================================================
  persistent D
  try
      method = M(1).E.linear;
  catch
      method = 0;
  end

  % G holds the derivatives used for this call; D is the persistent cache
  %--------------------------------------------------------------------------
  switch method

      % get derivatives at each iteration of D-step - full evaluation
      %----------------------------------------------------------------------
      case{0}

          D = spm_DEM_eval_diff(x,v,qp,M);
          G = D;

      % linear: assume equations are linear in x and v
      %----------------------------------------------------------------------
      case{1}

          if isempty(D)

              % get derivatives and store expansion point (states)
              %--------------------------------------------------------------
              D   = spm_DEM_eval_diff(x,v,qp,M);
              D.x = x;
              D.v = v;
              G   = D;

          else

              % linear expansion for derivatives w.r.t. parameters
              %--------------------------------------------------------------
              dx  = spm_vec(qu.x{1}) - spm_vec(D.x);
              dv  = spm_vec(qu.v{1}) - spm_vec(D.v);
              G   = local_shift_dp(D,D,dx,dv,np,nx);

          end

      % bilinear: assume equations are linear in x and v and x*v
      %----------------------------------------------------------------------
      case{2}

          if isempty(D)

              % get high-order derivatives and store expansion point
              %--------------------------------------------------------------
              [Dv,D] = spm_diff('spm_DEM_eval_diff',x,v,qp,M,2);
              for i = 1:nv, Dv{i} = spm_unvec(Dv{i},D); end
              D.x    = x;
              D.v    = v;
              D.Dv   = Dv;
              G      = D;

          else

              % states (relative to expansion point); only displacements in
              % the causes are used in this branch
              %--------------------------------------------------------------
              dv  = spm_vec(qu.v{1}) - spm_vec(D.v);

              % first and second-order derivatives at the current states
              %--------------------------------------------------------------
              G   = local_expand(D,D.Dv,dv,nv,np);

              % gradients w.r.t. parameters (using the mean of the
              % second-order derivatives at the two points)
              %--------------------------------------------------------------
              for p = 1:np
                  Dgdvp       = (D.dgdvp{p} + G.dgdvp{p})/2;
                  Dfdvp       = (D.dfdvp{p} + G.dfdvp{p})/2;
                  G.dgdp(:,p) = G.dgdp(:,p) + Dgdvp*dv;
                  G.dfdp(:,p) = G.dfdp(:,p) + Dfdvp*dv;
              end

          end

      % nonlinear: assume equations are linear in x, v, x*v and x*x
      %----------------------------------------------------------------------
      case{3}

          if isempty(D)

              % get high-order derivatives and store expansion point
              %--------------------------------------------------------------
              [Dx,D] = spm_diff('spm_DEM_eval_diff',x,v,qp,M,1,'q');
              [Dv,D] = spm_diff('spm_DEM_eval_diff',x,v,qp,M,2,'q');
              for i = 1:nx, Dx{i} = spm_unvec(Dx{i},D); end
              for i = 1:nv, Dv{i} = spm_unvec(Dv{i},D); end
              D.x    = x;
              D.v    = v;
              D.Dx   = Dx;
              D.Dv   = Dv;
              G      = D;

          else

              % states (relative to expansion point)
              %--------------------------------------------------------------
              dx  = spm_vec(qu.x{1}) - spm_vec(D.x);
              dv  = spm_vec(qu.v{1}) - spm_vec(D.v);

              % first and second-order derivatives at the current states
              % (hidden-state terms are accumulated before cause terms)
              %--------------------------------------------------------------
              G   = local_expand(D,D.Dx,dx,nx,np);
              G   = local_expand(G,D.Dv,dv,nv,np);

              % gradients w.r.t. parameters (using the mean of the
              % second-order derivatives at the two points)
              %--------------------------------------------------------------
              for p = 1:np
                  Dgdxp       = (D.dgdxp{p} + G.dgdxp{p})/2;
                  Dgdvp       = (D.dgdvp{p} + G.dgdvp{p})/2;
                  Dfdxp       = (D.dfdxp{p} + G.dfdxp{p})/2;
                  Dfdvp       = (D.dfdvp{p} + G.dfdvp{p})/2;
                  G.dgdp(:,p) = G.dgdp(:,p) + Dgdxp*dx + Dgdvp*dv;
                  G.dfdp(:,p) = G.dfdp(:,p) + Dfdxp*dx + Dfdvp*dv;
              end

          end

      % repeated evaluation of first order derivatives (for Laplace scheme)
      %----------------------------------------------------------------------
      case{4}

          if isempty(D)

              % get derivatives and store expansion point (states)
              %--------------------------------------------------------------
              D   = spm_DEM_eval_diff(x,v,qp,M);
              D.x = x;
              D.v = v;
              G   = D;

          else

              % retain second-order gradients
              %--------------------------------------------------------------
              H       = D;

              % re-evaluate first-order gradients
              %--------------------------------------------------------------
              D       = spm_DEM_eval_diff(x,v,qp,M,0);

              % replace second-order gradients
              %--------------------------------------------------------------
              D.dgdxp = H.dgdxp;
              D.dfdxp = H.dfdxp;
              D.dgdvp = H.dgdvp;
              D.dfdvp = H.dfdvp;

              % gradients w.r.t. parameters; x and v were unpacked from
              % qu.x{1} and qu.v{1} above, so these displacements are
              % presumably zero - NOTE(review): confirm this is intended
              %--------------------------------------------------------------
              dx      = spm_vec(qu.x{1}) - spm_vec(x);
              dv      = spm_vec(qu.v{1}) - spm_vec(v);
              G       = local_shift_dp(D,D,dx,dv,np,nx);

          end

      otherwise

          % fail here rather than later on an undefined derivative
          %------------------------------------------------------------------
          error('spm_DEM_eval:unknownMethod', ...
                'Unknown method: M(1).E.linear should be 0, 1, 2, 3 or 4');
  end

  % gradients w.r.t. causes and data
  %--------------------------------------------------------------------------
  dedy  = G.dedy;
  dedc  = G.dedc;
  dfdy  = G.dfdy;
  dfdc  = G.dfdc;

  % gradients w.r.t. states
  %--------------------------------------------------------------------------
  dgdx  = G.dgdx;
  dgdv  = G.dgdv;
  dfdv  = G.dfdv;
  dfdx  = G.dfdx;

  % gradients w.r.t. parameters (state-dependent)
  %--------------------------------------------------------------------------
  dgdxp = G.dgdxp;
  dfdxp = G.dfdxp;
  dgdvp = G.dgdvp;
  dfdvp = G.dfdvp;

  % gradients w.r.t. parameters
  %--------------------------------------------------------------------------
  dgdp  = G.dgdp;
  dfdp  = G.dfdp;

  % order parameters (d = n = 1 for static models)
  %--------------------------------------------------------------------------
  d     = M(1).E.d + 1;                     % generalisation order of q(v)
  n     = M(1).E.n + 1;                     % embedding order (n >= d)

  % Generalised prediction errors and derivatives
  %==========================================================================
  Ex      = cell(n,1);
  Ev      = cell(n,1);
  [Ex{:}] = deal(sparse(nx,1));
  [Ev{:}] = deal(sparse(ne,1));

  % prediction error (E) - causes
  %--------------------------------------------------------------------------
  for i = 1:n
      qu.y{i} = spm_vec(qu.y{i});
  end
  Ev{1} = [qu.y{1}; qu.v{1}] - [spm_vec(g); qu.u{1}];
  for i = 2:n
      Ev{i} = dedy*qu.y{i} + dedc*qu.u{i} ...    % generalised response
            - dgdx*qu.x{i} - dgdv*qu.v{i};       % and prediction
  end

  % prediction error (E) - states; if this assignment fails (e.g., there is
  % no qu.x{2}) the error is ignored and Ex{1} keeps its sparse(nx,1) default
  %--------------------------------------------------------------------------
  try
      Ex{1} = qu.x{2} - spm_vec(f);
  end
  for i = 2:n - 1
      Ex{i} = qu.x{i + 1} ...                    % generalised motion
            - dfdx*qu.x{i} - dfdv*qu.v{i};       % and prediction
  end

  % error
  %--------------------------------------------------------------------------
  E     = spm_vec({Ev,Ex});

  % Kronecker forms of derivatives for generalised motion
  %==========================================================================
  if nargout < 2, return, end

  % dE.dp (parameters)
  %--------------------------------------------------------------------------
  dgdp  = {dgdp};
  dfdp  = {dfdp};
  for i = 2:n
      dgdp{i,1} = dgdp{1};
      dfdp{i,1} = dfdp{1};
      for p = 1:np
          dgdp{i,1}(:,p) = dgdxp{p}*qu.x{i} + dgdvp{p}*qu.v{i};
          dfdp{i,1}(:,p) = dfdxp{p}*qu.x{i} + dfdvp{p}*qu.v{i};
      end
  end

  % generalised temporal derivatives: dE.du (states)
  %--------------------------------------------------------------------------
  dedy  = kron(spm_speye(n,n),dedy);
  dedc  = kron(spm_speye(n,d),dedc);
  dfdy  = kron(spm_speye(n,n),dfdy);
  dfdc  = kron(spm_speye(n,d),dfdc);
  dgdx  = kron(spm_speye(n,n),dgdx);
  dgdv  = kron(spm_speye(n,d),dgdv);
  dfdv  = kron(spm_speye(n,d),dfdv);
  dfdx  = kron(spm_speye(n,n),dfdx) - kron(spm_speye(n,n,1),speye(nx,nx));

  % 1st error derivatives (states)
  %--------------------------------------------------------------------------
  dE.dy = spm_cat({dedy; dfdy});
  dE.dc = spm_cat({dedc; dfdc});
  dE.dp = -spm_cat({dgdp; dfdp});
  dE.du = -spm_cat({dgdx, dgdv; ...
                    dfdx, dfdv});

  % bilinear derivatives
  %--------------------------------------------------------------------------
  for i = 1:np
      dgdxp{i}  = kron(spm_speye(n,n),dgdxp{i});
      dfdxp{i}  = kron(spm_speye(n,n),dfdxp{i});
      dgdvp{i}  = kron(spm_speye(n,d),dgdvp{i});
      dfdvp{i}  = kron(spm_speye(n,d),dfdvp{i});
      dE.dup{i} = -spm_cat({dgdxp{i}, dgdvp{i}; ...
                            dfdxp{i}, dfdvp{i}});
  end
  if np
      dE.dpu = spm_cell_swap(dE.dup);
  else
      dE.dpu = {};
  end

  return


  %==========================================================================
  function G = local_shift_dp(G,D,dx,dv,np,nx)
  % Linear correction of the parameter gradients for displaced states
  % G     - derivative structure to update (returned)
  % D     - derivative structure supplying dgdp, dfdp and dg/df dxp, dvp
  % dx,dv - displacement of hidden and causal states
  % np    - number of parameters
  % nx    - number of hidden states (dfdp is only updated if nx is nonzero)
  %--------------------------------------------------------------------------
  for p = 1:np
      G.dgdp(:,p) = D.dgdp(:,p) + D.dgdxp{p}*dx + D.dgdvp{p}*dv;
      if nx
          G.dfdp(:,p) = D.dfdp(:,p) + D.dfdxp{p}*dx + D.dfdvp{p}*dv;
      end
  end


  %==========================================================================
  function G = local_expand(G,Dq,dq,nq,np)
  % Adds Dq{i}.(field)*dq(i), for i = 1:nq, to the state derivatives in G
  % G     - derivative structure to update (returned)
  % Dq    - cell array of derivative structures (one per state)
  % dq    - displacement of the states from the expansion point
  % nq    - number of states
  % np    - number of parameters
  %--------------------------------------------------------------------------
  for i = 1:nq

      % gradients w.r.t. states
      %----------------------------------------------------------------------
      G.dgdx = G.dgdx + Dq{i}.dgdx*dq(i);
      G.dgdv = G.dgdv + Dq{i}.dgdv*dq(i);
      G.dfdx = G.dfdx + Dq{i}.dfdx*dq(i);
      G.dfdv = G.dfdv + Dq{i}.dfdv*dq(i);

      % second-order derivatives
      %----------------------------------------------------------------------
      for p = 1:np
          G.dgdxp{p} = G.dgdxp{p} + Dq{i}.dgdxp{p}*dq(i);
          G.dgdvp{p} = G.dgdvp{p} + Dq{i}.dgdvp{p}*dq(i);
          G.dfdxp{p} = G.dfdxp{p} + Dq{i}.dfdxp{p}*dq(i);
          G.dfdvp{p} = G.dfdvp{p} + Dq{i}.dfdvp{p}*dq(i);
      end
  end