spm_reml_sc.m 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  function [C,h,Ph,F,Fa,Fc,Eh,Ch,hE,hC,Q] = spm_reml_sc(YY,X,Q,N,hE,hC,V)
  % ReML estimation of covariance components from y*y' - proper components
  % FORMAT [C,h,Ph,F,Fa,Fc,Eh,Ch,hE,hC,Q] = spm_reml_sc(YY,X,Q,N,[hE,hC,V])
  %
  % YY  - (m x m) sample covariance matrix Y*Y'  {Y = (m x N) data matrix}
  % X   - (m x p) design matrix (may be empty)
  % Q   - {1 x q} covariance components
  % N   - number of samples                              [default = 1]
  %
  % hE  - hyperprior expectation in log-space            [default = -32]
  % hC  - hyperprior covariance  in log-space            [default = 256]
  % V   - fixed covariance component (scalar or matrix)  [default = 0]
  %
  % N, hE, hC and V may be omitted, or passed as [], to select the default.
  %
  % C   - (m x m) estimated errors = h(1)*Q{1} + h(2)*Q{2} + ...
  % h   - (q x 1) ReML hyperparameters h
  % Ph  - (q x q) conditional precision of log(h)
  %
  % hE  - prior expectation of log scale parameters
  % hC  - prior covariances of log scale parameters
  % Eh  - posterior expectation of log scale parameters
  % Ch  - posterior covariances of log scale parameters
  %
  % Q   - scaled covariance components
  %
  % F   - [-ve] free energy F = log evidence = p(Y|X,Q) = ReML objective
  %
  % Fa  - accuracy
  % Fc  - complexity (F = Fa - Fc)
  %
  % Performs a Fisher-Scoring ascent on F to find MAP variance parameter
  % estimates. NB: uses weakly informative log-normal hyperpriors.
  % See also spm_reml for an unconstrained version that allows for negative
  % hyperparameters.
  %
  % F, Fa and Fc are only evaluated when at least four outputs are requested;
  % Eh, Ch and the rescaled hE, hC only when at least seven are requested.
  %
  %__________________________________________________________________________
  %
  % SPM ReML routines:
  %
  %      spm_reml:    no positivity constraints on covariance parameters
  %      spm_reml_sc: positivity constraints on covariance parameters
  %      spm_sp_reml: for sparse patterns (c.f., ARD)
  %
  %__________________________________________________________________________
  % Copyright (C) 2007-2017 Wellcome Trust Centre for Neuroimaging

  % Karl Friston
  % $Id: spm_reml_sc.m 7305 2018-05-07 13:35:06Z karl $

  % defaults: a single sample and no fixed component if not specified
  %--------------------------------------------------------------------------
  if nargin < 4 || isempty(N), N = 1; end
  if nargin < 7 || isempty(V), V = 0; end

  % initialise h
  %--------------------------------------------------------------------------
  n     = length(Q{1});
  m     = length(Q);
  h     = zeros(m,1);
  dFdh  = zeros(m,1);
  dFdhh = zeros(m,m);
  Inn   = speye(n,n);

  [PQ{1:m}] = deal(zeros(n,n));

  % ortho-normalise X and form the residual forming matrix R
  %--------------------------------------------------------------------------
  if isempty(X)
      X = sparse(n,0);
      R = Inn;
  else
      X = spm_svd(X,0);
      R = Inn - X*X';
  end

  % check fixed component (expand a scalar to a scaled identity)
  %--------------------------------------------------------------------------
  if length(V) == 1
      V = V*Inn;
  end

  % initialise and specify hyperpriors
  %==========================================================================

  % scale Q and YY (the scaling is undone before returning)
  %--------------------------------------------------------------------------
  sY = spm_trace(R,YY)/(N*n);
  YY = YY/sY;
  V  = V/sY;
  sh = zeros(m,1);
  for i = 1:m
      sh(i,1) = spm_trace(R,Q{i})/n;
      Q{i}    = Q{i}/sh(i);
  end

  % hyperpriors (hP is the prior precision, i.e. the inverse of hC)
  %--------------------------------------------------------------------------
  if nargin < 5 || isempty(hE)
      hE = -32;
  else
      hE = hE(:);
  end
  hP = 1/256;
  if nargin > 5 && ~isempty(hC)
      % best effort: fall back to the default precision if inversion fails
      try, hP = spm_inv(hC); catch, hP = 1/256; end
  end

  % check size (expand scalar hyperpriors to all m components)
  %--------------------------------------------------------------------------
  if length(hE) < m, hE = hE(1)*ones(m,1);  end
  if length(hP) < m, hP = hP(1)*speye(m,m); end

  % initialise h: start from the prior expectation if any prior is very tight
  %--------------------------------------------------------------------------
  if any(diag(hP) > exp(16))
      h = hE;
  end

  % ReML (EM/VB)
  %--------------------------------------------------------------------------
  dF = Inf;
  as = 1:m;                              % indices of active components
  t  = 4;                                % regularisation passed to spm_dx
  for k = 1:32

      % compute current estimate of covariance
      %----------------------------------------------------------------------
      C = V;
      for i = as
          C = C + Q{i}*exp(h(i));
      end
      iC = spm_inv(C);

      % E-step: conditional covariance cov(B|y) {Cq}
      %======================================================================
      iCX = iC*X;
      if ~isempty(X)
          Cq = inv(X'*iCX);
      else
          Cq = sparse(0);
      end

      % M-step: ReML estimate of hyperparameters
      %======================================================================

      % Gradient dF/dh (first derivatives)
      %----------------------------------------------------------------------
      P = iC - iCX*Cq*iCX';
      U = Inn - P*YY/N;
      for i = as

          % dF/dh = -trace(dF/diC*iC*Q{i}*iC)
          %------------------------------------------------------------------
          PQ{i}   = P*Q{i};
          dFdh(i) = -spm_trace(PQ{i},U)*N/2;

      end

      % Expected curvature E{dF/dhh} (second derivatives)
      %----------------------------------------------------------------------
      % 'as' is in ascending order, so restricting j to as <= i visits every
      % unordered pair once and stores the same value in both (i,j) and (j,i)
      % that a full double loop over 'as' would leave there.
      for i = as
          for j = as(as <= i)

              % dF/dhh = -trace{P*Q{i}*P*Q{j}}
              %--------------------------------------------------------------
              dFdhh(i,j) = -spm_trace(PQ{i},PQ{j})*N/2;
              dFdhh(j,i) =  dFdhh(i,j);

          end
      end

      % modulate (chain rule for the log-parameterisation)
      %----------------------------------------------------------------------
      % NOTE(review): entries of dFdh and dFdhh belonging to components that
      % have been eliminated from 'as' are not recomputed above; they carry
      % over from the previous iteration and are modulated again here. Only
      % the (as) entries drive the update, but the full dFdhh is returned as
      % Ph - confirm this is intended.
      dFdh  = dFdh.*exp(h);
      dFdhh = dFdhh.*(exp(h)*exp(h)');

      % add hyperpriors
      %----------------------------------------------------------------------
      e     = h     - hE;
      dFdh  = dFdh  - hP*e;
      dFdhh = dFdhh - hP;

      % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
      %----------------------------------------------------------------------
      dh    = spm_dx(dFdhh(as,as),dFdh(as),{t});
      h(as) = h(as) + dh;

      % predicted change in F - increase regularisation if increasing
      %----------------------------------------------------------------------
      pF = dFdh(as)'*dh;
      if pF > dF
          t = t - 1;
      else
          t = t + 1/8;
      end
      dF = pF;

      % convergence
      %----------------------------------------------------------------------
      fprintf('%s %-23d: %10s%e [%+3.2f]\n',' ReML Iteration',k,'...',full(dF),t);
      if dF < 1e-2
          break
      else

          % eliminate redundant components (automatic selection)
          %------------------------------------------------------------------
          as = find(h > hE);
          as = as(:)';

      end
  end

  % log evidence = ln p(y|X,Q) = ReML objective = F = trace(R'*iC*R*YY)/2 ...
  %--------------------------------------------------------------------------
  Ph = -dFdhh;
  if nargout > 3

      % tr(hP*inv(Ph)) - nh (complexity KL cost of parameters = 0)
      %----------------------------------------------------------------------
      Ft = trace(hP/Ph) - length(Ph);

      % complexity - KL(Ph,hP)
      %----------------------------------------------------------------------
      Fc = Ft/2 + e'*hP*e/2 + spm_logdet(Ph/hP)/2;

      % Accuracy - ln p(Y|h)
      %----------------------------------------------------------------------
      Fa = Ft/2 - spm_trace(C*P,YY*P)/2 - N*n*log(2*pi)/2 - N*spm_logdet(C)/2;

      % Free-energy (last term corrects for the scaling of YY by sY)
      %----------------------------------------------------------------------
      F  = Fa - Fc - N*n*log(sY)/2;

  end

  % priors and posteriors of log parameters (with scaling)
  %--------------------------------------------------------------------------
  % Eh is the 7th output, so this must run whenever seven or more outputs are
  % requested; otherwise Eh would be left unassigned.
  if nargout > 6

      hE = hE + log(sY) - log(sh);
      hC = spm_inv(hP);
      Eh = h  + log(sY) - log(sh);
      Ch = spm_inv(Ph);

  end

  % return exp(h) hyperpriors and rescale
  %--------------------------------------------------------------------------
  h = sY*exp(h)./sh;
  C = sY*C;