% spm_reml_A.m (5.8 KB)
  1. function [C,h,Ph,F,Fa,Fc] = spm_reml_A(YY,X,Q,N,hE,hC,V)
  2. % ReML estimation of covariance components from y*y' - factored components
  3. % FORMAT [C,h,Ph,F,Fa,Fc] = spm_reml_A(YY,X,Q,N,[hE,hC,V])
  4. %
  5. % YY - (m x m) sample covariance matrix Y*Y' {Y = (m x N) data matrix}
  6. % X - (m x p) design matrix
  7. % Q - {1 x q} covariance components (factors)
  8. % N - number of samples
  9. %
  10. % hE - hyperprior expectation [default = 0]
  11. % hC - hyperprior covariance [default = 256]
  12. % V - fixed covariance component
  13. %
  14. % C - (m x m) estimated errors: C = A*A': A = h(1)*Q{1} + h(2)*Q{2} + ...
  15. % h - (q x 1) ReML hyperparameters h
  16. % Ph - (q x q) conditional precision of h
  17. %
  18. % F - [-ve] free energy F = log evidence = p(Y|X,Q) = ReML objective
  19. %
  20. % Fa - accuracy
  21. % Fc - complexity (F = Fa - Fc)
  22. %
  23. % Performs a Fisher-Scoring ascent on F to find MAP variance parameter
  24. % estimates. NB: uses weakly informative normal hyperpriors on the factors.
  25. %
  26. %__________________________________________________________________________
  27. %
  28. % SPM ReML routines:
  29. %
  30. % spm_reml: no positivity constraints on covariance parameters
  31. % spm_reml_sc: positivity constraints on covariance parameters
  32. % spm_sp_reml: for sparse patterns (c.f., ARD)
  33. %
  34. %__________________________________________________________________________
  35. % Copyright (C) 2010-2017 Wellcome Trust Centre for Neuroimaging
  36. % Karl Friston
  37. % $Id: spm_reml_A.m 7192 2017-10-18 14:59:01Z guillaume $
  38. % assume a single sample if not specified
  39. %--------------------------------------------------------------------------
  40. try, N; catch, N = 1; end
  41. try, V; catch, V = 1e-16; end
  42. % initialise h
  43. %--------------------------------------------------------------------------
  44. n = length(Q{1});
  45. m = length(Q);
  46. h = zeros(m,1) + exp(-4);
  47. dFdh = zeros(m,1);
  48. dFdhh = zeros(m,m);
  49. % ortho-normalise X
  50. %--------------------------------------------------------------------------
  51. if isempty(X)
  52. X = sparse(n,0);
  53. else
  54. X = spm_svd(X,0);
  55. end
  56. % check fixed component
  57. %--------------------------------------------------------------------------
  58. if length(V) == 1
  59. V = V*speye(n,n);
  60. end
  61. % initialise and specify hyperpriors
  62. %==========================================================================
  63. % hyperpriors
  64. %--------------------------------------------------------------------------
  65. try, hE = hE(:); catch, hE = 0; end
  66. try, hP = spm_inv(hC); catch, hP = 1/256; end
  67. % check sise
  68. %--------------------------------------------------------------------------
  69. if length(hE) < m, hE = hE(1)*ones(m,1); end
  70. if length(hP) < m, hP = hP(1)*speye(m,m); end
  71. % intialise h: so that sum(exp(h)) = 1
  72. %--------------------------------------------------------------------------
  73. if any(diag(hP) > exp(16))
  74. h = hE;
  75. end
  76. % ReML (EM/VB)
  77. %--------------------------------------------------------------------------
  78. dF = Inf;
  79. t = 4;
  80. for k = 1:32
  81. % compute current estimate of covariance
  82. %----------------------------------------------------------------------
  83. A = 0;
  84. for i = 1:m
  85. A = A + Q{i}*h(i);
  86. end
  87. C = V + A*A';
  88. iC = spm_inv(C);
  89. % E-step: conditional covariance cov(B|y) {Cq}
  90. %======================================================================
  91. iCX = iC*X;
  92. if ~isempty(X)
  93. Cq = inv(X'*iCX);
  94. else
  95. Cq = sparse(0);
  96. end
  97. % M-step: ReML estimate of hyperparameters
  98. %======================================================================
  99. % Gradient dF/dh (first derivatives)
  100. %----------------------------------------------------------------------
  101. P = iC - iCX*Cq*iCX';
  102. U = speye(n) - P*YY/N;
  103. PQ = cell(m,1);
  104. for i = 1:m
  105. % dF/dh
  106. %------------------------------------------------------------------
  107. PQ{i} = P*(A*Q{i}' + Q{i}'*A);
  108. dFdh(i) = -spm_trace(PQ{i},U)*N/2;
  109. end
  110. % Expected curvature E{dF/dhh} (second derivatives)
  111. % dF/dhh = -trace{P*Q{i}*P*Q{j}}
  112. %----------------------------------------------------------------------
  113. for i = 1:m
  114. for j = i:m
  115. dFdhh(i,j) = -spm_trace(PQ{i},PQ{j})*N/2;
  116. dFdhh(j,i) = dFdhh(i,j);
  117. end
  118. end
  119. % add hyperpriors
  120. %----------------------------------------------------------------------
  121. e = h - hE;
  122. dFdh = dFdh - hP*e;
  123. dFdhh = dFdhh - hP;
  124. % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
  125. %----------------------------------------------------------------------
  126. dh = spm_dx(dFdhh,dFdh,{t});
  127. h = h + dh;
  128. % predicted change in F - increase regularisation if increasing
  129. %----------------------------------------------------------------------
  130. pF = dFdh'*dh;
  131. if pF > dF
  132. t = t - 1;
  133. else
  134. t = t + 1/8;
  135. end
  136. dF = pF;
  137. % convergence
  138. %----------------------------------------------------------------------
  139. fprintf('%s %-23d: %10s%e [%+3.2f]\n',' ReML Iteration',k,'...',full(dF),t);
  140. if dF < 1e-2
  141. break
  142. end
  143. end
  144. % log evidence = ln p(y|X,Q) = ReML objective = F = trace(R'*iC*R*YY)/2 ...
  145. %--------------------------------------------------------------------------
  146. Ph = -dFdhh;
  147. if nargout > 3
  148. % tr(hP*inv(Ph)) - nh (complexity KL cost of parameters = 0)
  149. %----------------------------------------------------------------------
  150. Ft = trace(hP/Ph) - length(Ph);
  151. % complexity - KL(Ph,hP)
  152. %----------------------------------------------------------------------
  153. Fc = Ft/2 + e'*hP*e/2 + spm_logdet(Ph/hP)/2;
  154. % Accuracy - ln p(Y|h)
  155. %----------------------------------------------------------------------
  156. Fa = Ft/2 - trace(C*P*YY*P)/2 - N*n*log(2*pi)/2 - N*spm_logdet(C)/2;
  157. % Free-energy
  158. %----------------------------------------------------------------------
  159. F = Fa - Fc;
  160. end