% spm_PEB.m
  1. function [C,P,F] = spm_PEB(y,P,OPT)
  2. % parametric empirical Bayes (PEB) for hierarchical linear models
  3. % FORMAT [C,P,F] = spm_PEB(y,P,OPT)
  4. %
  5. % y - (n x 1) response variable
  6. %
  7. % MODEL SPECIFICATION
  8. %
  9. % P{i}.X - (n x m) ith level design matrix i.e: constraints on <Eb{i - 1}>
  10. % P{i}.C - {q}(n x n) ith level contraints on Cov{e{i}} = Cov{b{i - 1}}
  11. %
  12. % OPT - enforces positively constraints on the covariance hyperparameters
  13. % by adopting a log-normal [flat] hyperprior. default = 0
  14. %
  15. % POSTERIOR OR CONDITIONAL ESTIMATES
  16. %
  17. % C{i}.E - (n x 1) conditional expectation E{b{i - 1}|y}
  18. % C{i}.C - (n x n) conditional covariance Cov{b{i - 1}|y} = Cov{e{i}|y}
  19. % C{i}.M - (n x n) ML estimate of Cov{b{i - 1}} = Cov{e{i}}
  20. % C{i}.h - (q x 1) ith level ReML hyperparameters for covariance:
  21. % Cov{e{i}} = P{i}.h(1)*P{i}.C{1} + ...
  22. %
  23. % LOG EVIDENCE
  24. %
  25. % F - [-ve] free energy F = log evidence = p(y|X,C)
  26. %
  27. % If P{i}.C is not a cell the covariance at that level is assumed to be kown
  28. % and Cov{e{i}} = P{i}.C (i.e. the hyperparameter is fixed at 1)
  29. %
  30. % If P{n}.C is not a cell this is taken to indicate that a full Bayesian
  31. % estimate is required where P{n}.X is the prior expectation and P{n}.C is
  32. % the known prior covariance. For consistency, with PEB, this is implemented
  33. % by setting b{n} = 1 through appropriate constraints at level {n + 1}.
  34. %
  35. % To implement non-hierarchical Bayes with priors on the parameters use
  36. % a two level model setting the second level design matrix to zeros.
  37. %__________________________________________________________________________
  38. %
  39. % Returns the moments of the posterior p.d.f. of the parameters of a
  40. % hierarchical linear observation model under Gaussian assumptions
  41. %
  42. % y = X{1}*b{1} + e{1}
  43. % b{1} = X{2}*b{2} + e{2}
  44. % ...
  45. %
  46. % b{n - 1} = X{n}*b{n} + e{n}
  47. %
  48. % e{n} ~ N{0,Ce{n}}
  49. %
  50. % using Parametic Emprical Bayes (PEB)
  51. %
  52. % Ref: Dempster A.P., Rubin D.B. and Tsutakawa R.K. (1981) Estimation in
  53. % covariance component models. J. Am. Stat. Assoc. 76;341-353
  54. %__________________________________________________________________________
  55. % Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging
  56. % Karl Friston
  57. % $Id: spm_PEB.m 7305 2018-05-07 13:35:06Z karl $
  58. % set default
  59. %--------------------------------------------------------------------------
  60. try
  61. OPT;
  62. catch
  63. OPT = 0;
  64. end
  65. % number of levels (p)
  66. %--------------------------------------------------------------------------
  67. M = 32; % maximum number of iterations
  68. p = length(P);
  69. % check covariance constraints - assume i.i.d. errors conforming to X{i}
  70. %--------------------------------------------------------------------------
  71. for i = 1:p
  72. if ~isfield(P{i},'C')
  73. [n,m] = size(P{i}.X);
  74. if i == 1
  75. P{i}.C = {speye(n,n)};
  76. else
  77. for j = 1:m
  78. k = find(P{i}.X(:,j));
  79. P{i}.C{j} = sparse(k,k,1,n,n);
  80. end
  81. end
  82. end
  83. end
  84. % Construct augmented non-hierarchical model
  85. %==========================================================================
  86. % design matrix and indices
  87. %--------------------------------------------------------------------------
  88. I = {0};
  89. J = {0};
  90. K = {0};
  91. XX = [];
  92. X = 1;
  93. for i = 1:p
  94. % design matrix
  95. %----------------------------------------------------------------------
  96. X = X*P{i}.X;
  97. XX = [XX X];
  98. % indices for ith level parameters
  99. %----------------------------------------------------------------------
  100. [n,m] = size(P{i}.X);
  101. I{i} = (1:n) + I{end}(end);
  102. J{i} = (1:m) + J{end}(end);
  103. end
  104. % augment design matrix and data
  105. %--------------------------------------------------------------------------
  106. n = size(XX,2);
  107. XX = [XX; speye(n,n)];
  108. y = [y; sparse(n,1)];
  109. % last level constraints
  110. %--------------------------------------------------------------------------
  111. n = size(P{p}.X,2);
  112. I{p + 1} = (1:n) + I{end}(end);
  113. q = I{end}(end);
  114. Cb = sparse(q,q);
  115. if ~iscell(P{end}.C)
  116. % Full Bayes: (i.e. Cov(b) = 0, <b> = 1)
  117. %----------------------------------------------------------------------
  118. y(I{end}) = sparse(1:n,1,1);
  119. else
  120. % Empirical Bayes: uniform priors (i.e. Cov(b) = Inf, <b> = 0)
  121. %----------------------------------------------------------------------
  122. Cb(I{end},I{end}) = sparse(1:n,1:n,exp(32));
  123. end
  124. % assemble augmented constraints Q: Cov{e} = Cb + h(i)*Q{i} + ...
  125. %==========================================================================
  126. if ~isfield(P{1},'Q')
  127. % covariance contraints Q on Cov{e{i}} = Cov{b{i - 1}}
  128. %----------------------------------------------------------------------
  129. h = [];
  130. Q = {};
  131. for i = 1:p
  132. % collect constraints on prior covariances - Cov{e{i}}
  133. %------------------------------------------------------------------
  134. if iscell(P{i}.C)
  135. m = length(P{i}.C);
  136. for j = 1:m
  137. [u,v,s] = find(P{i}.C{j});
  138. u = u + I{i}(1) - 1;
  139. v = v + I{i}(1) - 1;
  140. Q{end + 1} = sparse(u,v,s,q,q);
  141. end
  142. % indices for ith-level hyperparameters
  143. %--------------------------------------------------------------
  144. try
  145. K{i} = (1:m) + K{end}(end);
  146. catch
  147. K{i} = (1:m);
  148. end
  149. else
  150. % unless they are known - augment Cb
  151. %--------------------------------------------------------------
  152. [u,v,s] = find(P{i}.C + speye(length(P{i}.C))*1e-6);
  153. u = u + I{i}(1) - 1;
  154. v = v + I{i}(1) - 1;
  155. Cb = Cb + sparse(u,v,s,q,q);
  156. % indices for ith-level hyperparameters
  157. %--------------------------------------------------------------
  158. K{i} = [];
  159. end
  160. end
  161. % note overlapping bases - requiring 2nd order M-Step derivatives
  162. %----------------------------------------------------------------------
  163. m = length(Q);
  164. d = sparse(m,m);
  165. for i = 1:m
  166. XQX{i} = XX'*Q{i}*XX;
  167. end
  168. for i = 1:m
  169. for j = i:m
  170. o = nnz(XQX{i}*XQX{j});
  171. d(i,j) = o;
  172. d(j,i) = o;
  173. end
  174. end
  175. % log-transform and save
  176. %----------------------------------------------------------------------
  177. h = zeros(m,1);
  178. if OPT
  179. hP = speye(m,m)/16;
  180. else
  181. hP = speye(m,m)/exp(16);
  182. for i = 1:m
  183. h(i) = any(diag(Q{i}));
  184. end
  185. end
  186. P{1}.hP = hP;
  187. P{1}.Cb = Cb;
  188. P{1}.Q = Q;
  189. P{1}.h = h;
  190. P{1}.K = K;
  191. P{1}.d = d;
  192. end
  193. hP = P{1}.hP;
  194. Cb = P{1}.Cb;
  195. Q = P{1}.Q;
  196. h = P{1}.h;
  197. K = P{1}.K;
  198. d = P{1}.d;
  199. % Iterative EM
  200. %--------------------------------------------------------------------------
  201. m = length(Q);
  202. dFdh = zeros(m,1);
  203. dFdhh = zeros(m,m);
  204. for k = 1:M
  205. % inv(Cov(e)) - iC(h)
  206. %----------------------------------------------------------------------
  207. Ce = Cb;
  208. for i = 1:m
  209. if OPT
  210. Ce = Ce + Q{i}*exp(h(i));
  211. else
  212. Ce = Ce + Q{i}*h(i);
  213. end
  214. end
  215. iC = spm_inv(Ce,exp(-16));
  216. % E-step: conditional mean E{B|y} and covariance cov(B|y)
  217. %======================================================================
  218. iCX = iC*XX;
  219. Cby = spm_inv(XX'*iCX);
  220. B = Cby*(iCX'*y);
  221. % M-step: ReML estimate of hyperparameters (if m > 0)
  222. %======================================================================
  223. if m == 0, break, end
  224. % Gradient dF/dh (first derivatives)
  225. %----------------------------------------------------------------------
  226. Py = iC*(y - XX*B);
  227. iCXC = iCX*Cby;
  228. for i = 1:m
  229. % dF/dh = -trace(dF/diC*iC*Q{i}*iC)
  230. %------------------------------------------------------------------
  231. PQ{i} = iC*Q{i} - iCXC*(iCX'*Q{i});
  232. if OPT
  233. PQ{i} = PQ{i}*exp(h(i));
  234. end
  235. dFdh(i) = -trace(PQ{i})/2 + y'*PQ{i}*Py/2;
  236. end
  237. % Expected curvature E{ddF/dhh} (second derivatives)
  238. %----------------------------------------------------------------------
  239. for i = 1:m
  240. for j = i:m
  241. if d(i,j)
  242. % ddF/dhh = -trace{P*Q{i}*P*Q{j}}
  243. %----------------------------------------------------------
  244. dFdhh(i,j) = -spm_trace(PQ{i},PQ{j})/2;
  245. dFdhh(j,i) = dFdhh(i,j);
  246. end
  247. end
  248. end
  249. % add hyperpriors
  250. %----------------------------------------------------------------------
  251. dFdhh = dFdhh - hP;
  252. % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
  253. %----------------------------------------------------------------------
  254. dh = -pinv(dFdhh)*dFdh;
  255. h = h + dh;
  256. % Convergence
  257. %======================================================================
  258. w = norm(dh,1);
  259. fprintf('%-30s: %i %30s%e\n',' PEB Iteration',k,'...',full(dFdh'*dh));
  260. % if dF < 0.01
  261. %----------------------------------------------------------------------
  262. if dFdh'*dh < 1e-2, break, end
  263. % if dh^2 < 1e-8
  264. %----------------------------------------------------------------------
  265. if w < 1e-4, break, end
  266. % if log-normal hyperpriors and h < exp(-16)
  267. %----------------------------------------------------------------------
  268. if OPT && all(h < -16), break, end
  269. end
  270. % place hyperparameters in P{1} and output structure for {n + 1}
  271. %--------------------------------------------------------------------------
  272. P{1}.h = h + exp(-32);
  273. C{p + 1}.E = B(J{p});
  274. C{p + 1}.M = Cb(I{end},I{end});
  275. % recursive computation of conditional means E{b|y}
  276. %--------------------------------------------------------------------------
  277. for i = p:-1:2
  278. C{i}.E = B(J{i - 1}) + P{i}.X*C{i + 1}.E;
  279. end
  280. % hyperpriors - precision
  281. %--------------------------------------------------------------------------
  282. if OPT
  283. h = exp(h);
  284. end
  285. % conditional covariances Cov{b|y} and ReML esimtates of Ce{i) = Cb{i - 1}
  286. %--------------------------------------------------------------------------
  287. for i = 1:p
  288. C{i + 1}.C = Cby(J{i},J{i});
  289. C{i}.M = Ce(I{i},I{i});
  290. C{i}.h = h(K{i});
  291. end
  292. % log evidence = ln p(y|X,C) = F = [-ve] free energy
  293. %--------------------------------------------------------------------------
  294. if nargout > 2
  295. % condotional covariance of h
  296. %----------------------------------------------------------------------
  297. Ph = -dFdhh;
  298. % log evidence = F
  299. %----------------------------------------------------------------------
  300. F = - Py'*Ce*Py/2 ...
  301. - length(I{1})*log(2*pi)/2 ...
  302. - spm_logdet(Ce)/2 ...
  303. - spm_logdet(Ph)/2 ...
  304. + spm_logdet(hP)/2 ...
  305. + spm_logdet(Cby)/2;
  306. end
  307. % warning
  308. %--------------------------------------------------------------------------
  309. if k == M, warning('maximum number of iterations exceeded'), end