123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
function [C,h,Ph,F,Fa,Fc] = spm_reml_A(YY,X,Q,N,hE,hC,V)
% ReML estimation of covariance components from y*y' - factored components
% FORMAT [C,h,Ph,F,Fa,Fc] = spm_reml_A(YY,X,Q,N,[hE,hC,V])
%
% YY  - (m x m) sample covariance matrix Y*Y'  {Y = (m x N) data matrix}
% X   - (m x p) design matrix (may be empty)
% Q   - {1 x q} covariance components (factors)
% N   - number of samples                      [default = 1]
%
% hE  - hyperprior expectation                 [default = 0]
% hC  - hyperprior covariance                  [default = 256]
% V   - fixed covariance component             [default = 1e-16]
%
% N, hE, hC and V may each be omitted or passed as [] to use the default.
% A scalar hE is replicated over components; a scalar V (or hC) is taken
% as a multiple of the identity.
%
% C   - (m x m) estimated errors: C = V + A*A': A = h(1)*Q{1} + h(2)*Q{2} + ...
% h   - (q x 1) ReML hyperparameters h
% Ph  - (q x q) conditional precision of h
%
% F   - [-ve] free energy F = log evidence = p(Y|X,Q) = ReML objective
%
% Fa  - accuracy
% Fc  - complexity (F = Fa - Fc)
%
% F, Fa and Fc are only evaluated when more than three outputs are
% requested.
%
% Performs a Fisher-Scoring ascent on F to find MAP variance parameter
% estimates. NB: uses weakly informative normal hyperpriors on the factors.
%
% NB: C and Ph are evaluated at the hyperparameters entering the final
% iteration; the returned h additionally includes the final update dh.
%
%__________________________________________________________________________
%
% SPM ReML routines:
%
%      spm_reml:    no positivity constraints on covariance parameters
%      spm_reml_sc: positivity constraints on covariance parameters
%      spm_sp_reml: for sparse patterns (c.f., ARD)
%
%__________________________________________________________________________
% Copyright (C) 2010-2017 Wellcome Trust Centre for Neuroimaging

% Karl Friston
% $Id: spm_reml_A.m 7192 2017-10-18 14:59:01Z guillaume $

% defaults: a single sample and a negligible fixed component
% (applied when the argument is omitted or empty)
%--------------------------------------------------------------------------
if nargin < 4 || isempty(N), N = 1;     end
if nargin < 7 || isempty(V), V = 1e-16; end

% initialise h
%--------------------------------------------------------------------------
n     = length(Q{1});
m     = length(Q);
h     = zeros(m,1) + exp(-4);
dFdh  = zeros(m,1);
dFdhh = zeros(m,m);

% ortho-normalise X
%--------------------------------------------------------------------------
if isempty(X)
    X = sparse(n,0);
else
    X = spm_svd(X,0);
end

% check fixed component: expand a scalar to a multiple of the identity
%--------------------------------------------------------------------------
if length(V) == 1
    V = V*speye(n,n);
end

% initialise and specify hyperpriors
%==========================================================================

% hyperprior expectation (column vector) and precision
%--------------------------------------------------------------------------
if nargin < 5 || isempty(hE)
    hE = 0;
else
    hE = hE(:);
end
hP = 1/256;
if nargin > 5 && ~isempty(hC)
    
    % retain the default precision if hC cannot be inverted
    %----------------------------------------------------------------------
    try, hP = spm_inv(hC); catch, hP = 1/256; end
end

% check size: replicate (the first element of) hE and hP over components
%--------------------------------------------------------------------------
if length(hE) < m, hE = hE(1)*ones(m,1);   end
if length(hP) < m, hP = hP(1)*speye(m,m);  end

% initialise h at the prior expectation if any prior precision is very high
%--------------------------------------------------------------------------
if any(diag(hP) > exp(16))
    h = hE;
end

% ReML (EM/VB)
%--------------------------------------------------------------------------
dF    = Inf;
t     = 4;                              % regularisation passed to spm_dx
for k = 1:32
    
    % compute current estimate of covariance
    %----------------------------------------------------------------------
    A     = 0;
    for i = 1:m
        A = A + Q{i}*h(i);
    end
    C     = V + A*A';
    iC    = spm_inv(C);
    
    % E-step: conditional covariance cov(B|y) {Cq}
    %======================================================================
    iCX   = iC*X;
    if ~isempty(X)
        Cq = inv(X'*iCX);
    else
        Cq = sparse(0);
    end
    
    % M-step: ReML estimate of hyperparameters
    %======================================================================
    
    % Gradient dF/dh (first derivatives)
    %----------------------------------------------------------------------
    P     = iC - iCX*Cq*iCX';
    U     = speye(n) - P*YY/N;
    PQ    = cell(m,1);
    for i = 1:m
        
        % dF/dh: uses dC/dh(i) = d(A*A')/dh(i) = A*Q{i}' + Q{i}*A'
        % (identical to A*Q{i}' + Q{i}'*A when the factors are symmetric)
        %------------------------------------------------------------------
        PQ{i}   = P*(A*Q{i}' + Q{i}*A');
        dFdh(i) = -spm_trace(PQ{i},U)*N/2;
        
    end
    
    % Expected curvature E{dF/dhh} (second derivatives)
    % dF/dhh = -trace{P*dC/dh(i)*P*dC/dh(j)}*N/2
    %----------------------------------------------------------------------
    for i = 1:m
        for j = i:m
            dFdhh(i,j) = -spm_trace(PQ{i},PQ{j})*N/2;
            dFdhh(j,i) =  dFdhh(i,j);
        end
    end
    
    % add hyperpriors
    %----------------------------------------------------------------------
    e     = h     - hE;
    dFdh  = dFdh  - hP*e;
    dFdhh = dFdhh - hP;
    
    % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
    %----------------------------------------------------------------------
    dh    = spm_dx(dFdhh,dFdh,{t});
    h     = h + dh;
    
    % predicted change in F - increase regularisation if increasing
    %----------------------------------------------------------------------
    pF    = dFdh'*dh;
    if pF > dF
        t = t - 1;
    else
        t = t + 1/8;
    end
    dF    = pF;
    
    % convergence: stop once the predicted improvement is below 1e-2
    %----------------------------------------------------------------------
    fprintf('%s %-23d: %10s%e [%+3.2f]\n',' ReML Iteration',k,'...',full(dF),t);
    if dF < 1e-2
        break
    end
end

% log evidence = ln p(y|X,Q) = ReML objective = F = trace(R'*iC*R*YY)/2 ...
%--------------------------------------------------------------------------
Ph    = -dFdhh;
if nargout > 3
    
    % tr(hP*inv(Ph)) - nh (complexity KL cost of parameters = 0)
    %----------------------------------------------------------------------
    Ft = trace(hP/Ph) - length(Ph);
    
    % complexity - KL(Ph,hP)
    %----------------------------------------------------------------------
    Fc = Ft/2 + e'*hP*e/2 + spm_logdet(Ph/hP)/2;
    
    % Accuracy - ln p(Y|h)
    %----------------------------------------------------------------------
    Fa = Ft/2 - trace(C*P*YY*P)/2 - N*n*log(2*pi)/2 - N*spm_logdet(C)/2;
    
    % Free-energy
    %----------------------------------------------------------------------
    F  = Fa - Fc;
    
end
|