spm_dcm_bma.m 16 KB


  function bma = spm_dcm_bma(post,post_indx,subj,Nsamp,oddsr)
% Model-independent samples from DCM posterior
% FORMAT BMA = spm_dcm_bma(DCM)
% FORMAT bma = spm_dcm_bma(post,post_indx,subj,Nsamp,oddsr)
%
% DCM - {subjects x models} cell array of DCMs over which to average
% ---------------------------------------------------------------------
% DCM{i,j}.Ep - posterior expectation
% DCM{i,j}.Cp - posterior covariances
% DCM{i,j}.F  - free energy
%
% BMA - Bayesian model average structure
% ---------------------------------------------------------------------
% BMA.Ep      - BMA posterior mean
% BMA.Cp      - BMA posterior VARIANCE (not covariance)
% BMA.F       - Accumulated free energy over subjects, relative to the
%               best model (max(BMA.F) = 0)
% BMA.P       - (FFX) posterior model probability over subjects
%
% BMA.SUB.Ep  - subject specific BMA posterior mean
% BMA.SUB.Cp  - subject specific BMA posterior variance
% BMA.nsamp   - Number of samples
% BMA.Nocc    - number of models in Occam's window
% BMA.Mocc    - index of models in Occam's window
%
% If DCM is a cell array, free energies are summed over subjects (rows of
% DCM) to give FFX posterior model probabilities; Bayesian model averaging
% is then applied over models (columns) and subjects.
%
%--------------------------------------------------------------------------
% OR
%--------------------------------------------------------------------------
%
% post      [Ni x M] matrix of posterior model probabilities
%           If Ni > 1 then inference is based on subject-specific RFX
%           posterior (one row per subject)
% post_indx models to use in BMA (position of models in subj structure)
% subj      subj(n).sess(s).model(m).Ep - posterior mean
%           subj(n).sess(s).model(m).Cp - posterior covariance
%           subj(1).sess(1).model(1).fname: DCM filename, only used to
%           detect DCM for fMRI (and the number of regions and inputs)
% Nsamp     Number of samples (default = 1e3)
% oddsr     posterior odds ratio for defining Occam's window (default=0,
%           ie all models with non-zero posterior probability are used)
%
% bma       Returned data structure contains
%
%           .nsamp  Number of samples
%           .oddsr  odds ratio
%           .Nocc   number of models in Occam's window (RFX: per subject)
%           .Mocc   index of models in Occam's window (RFX: cell array
%                   with one entry per subject)
%           .indx   subject specific indices of models in Occam's window
%                   (cell array with one entry per subject)
%
% For `Subject Parameter Averaging (SPA)':
%
%           .mEp    posterior mean
%           .sEp    posterior SD
%           .mEps   subject specific posterior mean
%           .sEps   subject specific posterior SD
%
% use the above values in t-tests, ANOVAs to look for significant
% effects in the group
%
% For `Group Parameter Averaging (GPA)' (DCM for fMRI only):
%
% The following arrays contain samples of the DCM A, B, C and D
% matrices from the group posterior density. See pages 6 and 7 of [1]
%
%           .a      [nreg x nreg x Nsamp]
%           .b      [nreg x nreg x nmods x Nsamp]
%           .c      [nreg x DCM.M.m x Nsamp]
%           .d      [nreg x nreg x dimD x Nsamp] (all zeros with dimD = 0
%                   when no model in Occam's window has D parameters)
%
% Use these to make inferences using the group posterior density approach.
% Essentially, for each parameter, GPA gets a sample which is the average
% over subjects. The collection of samples then constitutes a distribution
% of the group mean from which inferences can be made directly. This is to
% be contrasted with SPA where, for each subject, we average over
% samples to get a mean for that subject. Group level inferences
% are then made using classical inference. SPA is the standard
% approach.
%
% For RFX BMA, different subjects can have different models in
% Occam's window (and different numbers of models in Occam's
% window)
%
% Errors are raised if a DCM has not been estimated (no Ep), if a subject
% (RFX) or the group (FFX) has no model in Occam's window, or if sessions
% differ in which parameters have non-zero posterior variance.
%
% This routine implements Bayesian averaging over models and subjects
%
% See [1] W Penny, K Stephan, J. Daunizeau, M. Rosa, K. Friston, T. Schofield
% and A Leff. Comparing Families of Dynamic Causal Models.
% PLoS Computational Biology, Mar 2010, 6(3), e1000709.
%__________________________________________________________________________
% Copyright (C) 2009 Wellcome Trust Centre for Neuroimaging

% Will Penny
% $Id: spm_dcm_bma.m 7081 2017-05-27 19:36:09Z karl $


% defaults
%--------------------------------------------------------------------------
if nargin < 4 || isempty(Nsamp)
    Nsamp = 1e3;
end
if nargin < 5 || isempty(oddsr)
    oddsr = 0;
end

% inputs are DCMs - assemble input arguments and call the sampling form
%--------------------------------------------------------------------------
if nargin == 1
    if ~iscell(post), post = {post}; end
    DCM   = post;
    [n,m] = size(DCM);
    F     = zeros(n,m);
    for i = 1:n
        for j = 1:m
            if ~isfield(DCM{i,j}, 'Ep')
                error(['Could not average: subject %d model %d ' ...
                    'not estimated'], i, j);
            end
            subj(i).sess(1).model(j).Ep = DCM{i,j}.Ep;
            subj(i).sess(1).model(j).Cp = DCM{i,j}.Cp;
            F(i,j) = DCM{i,j}.F;
        end
    end

    % (FFX) posterior over models: sum free energies over subjects
    %----------------------------------------------------------------------
    F    = sum(F,1);
    F    = F - max(F);          % relative to best model (avoids overflow)
    P    = exp(F);
    post = P/sum(P);
    indx = 1:m;

    % BMA (and BPA)
    %----------------------------------------------------------------------
    bma       = spm_dcm_bma(post,indx,subj,Nsamp);
    BMA.Ep    = bma.mEp;
    BMA.Cp    = spm_unvec(spm_vec(bma.sEp).^2,bma.sEp);   % SD -> variance
    BMA.nsamp = bma.nsamp;
    BMA.Nocc  = bma.Nocc;
    BMA.Mocc  = bma.Mocc;
    BMA.F     = F;
    BMA.P     = post;           % normalised posterior model probabilities
    for i = 1:n
        BMA.SUB(i).Ep = bma.mEps{i};
        BMA.SUB(i).Cp = spm_unvec(spm_vec(bma.sEps{i}).^2,bma.sEps{i});
    end
    bma = BMA;
    return
end

Nsub = length(subj);
Nses = length(subj(1).sess);

% Detect DCM for fMRI and get the number of regions and inputs from the
% first model's file. Loading into a struct keeps the file's variables from
% overwriting locals; any failure falls back to generic parameter vectors.
%--------------------------------------------------------------------------
dcm_fmri = 0;
dimD     = 0;
try
    S   = load(subj(1).sess(1).model(1).fname,'DCM');
    DCM = S.DCM;
    if isfield(DCM,'a')
        nreg     = DCM.n;
        nu       = DCM.M.m;     % columns of C (named nu to avoid shadowing MIN)
        dcm_fmri = 1;
    end
catch
    dcm_fmri = 0;
end

Ni  = size(post,1);
rfx = Ni > 1;

if rfx
    % RFX: subject-specific Occam's windows
    %----------------------------------------------------------------------
    post_ind = cell(1,Ni);
    Nocc     = zeros(1,Ni);
    for i = 1:Ni
        mp          = max(post(i,:));
        post_ind{i} = find(post(i,:) > mp*oddsr);
        Nocc(i)     = length(post_ind{i});
        disp(' ');
        fprintf('Subject %d has %d models in Occams window\n',i,Nocc(i));
        if Nocc(i) == 0
            error('Subject %d has no models in Occam''s window (oddsr = %g)',i,oddsr);
        end
        for occ = 1:Nocc(i)
            m = post_ind{i}(occ);
            fprintf('Model %d, <p(m|Y)>=%1.2f\n',m,post(i,m));
        end

        % Renormalise post prob to Occam group
        %------------------------------------------------------------------
        p              = post(i,post_ind{i});
        renorm(i).post = p/sum(p);
    end
    occ_ind = post_ind;
else
    % FFX: one Occam's window shared by all subjects
    %----------------------------------------------------------------------
    mp       = max(post);
    post_ind = find(post > mp*oddsr);
    Nocc     = length(post_ind);
    disp(' ');
    fprintf('%d models in Occams window:\n',Nocc);
    if Nocc == 0
        error('No models in Occam''s window (oddsr = %g)',oddsr);
    end
    for occ = 1:Nocc
        m = post_ind(occ);
        fprintf('\tModel %d, p(m|Y)=%1.2f\n',m,post(m));
    end

    % Renormalise post prob to Occam group
    %----------------------------------------------------------------------
    post    = post(post_ind);
    post    = post/sum(post);
    occ_ind = repmat({post_ind},1,Nsub);
end

% Load DCM posteriors for models in Occam's window (averaging over
% sessions if necessary) and eigen-decompose their posterior covariances
%--------------------------------------------------------------------------
firstsub = 1;
firstmod = 1;
for n = 1:numel(occ_ind)
    for kk = 1:numel(occ_ind{n})
        sel = post_indx(occ_ind{n}(kk));     % position of model in subj
        mdl = subj(n).sess(1).model(sel);
        params(n).model(kk).Ep  = mdl.Ep;
        params(n).model(kk).vEp = spm_vec(mdl.Ep);
        params(n).model(kk).Cp  = full(mdl.Cp);

        % the last model with D parameters sets the sample vector length
        %------------------------------------------------------------------
        if dcm_fmri
            dimDtmp = size(mdl.Ep.D,3);
            if dimDtmp ~= 0
                dimD = dimDtmp; firstsub = n; firstmod = kk;
            end
        end

        % Average sessions
        %------------------------------------------------------------------
        if Nses > 1
            disp('Averaging sessions...')
            miCp = [];
            mEp  = [];
            for ss = 1:Nses
                % Only parameters with non-zero posterior variance
                %----------------------------------------------------------
                sCp  = full(subj(n).sess(ss).model(sel).Cp);
                wsel = find(diag(sCp));
                if ss == 1
                    wsel_first = wsel;
                elseif ~isequal(wsel,wsel_first)
                    % isequal: "if ~(a == b)" on vectors only fired when
                    % every index differed
                    error('DCMs must have same structure');
                end

                % Get posterior precision matrix and mean
                %----------------------------------------------------------
                sEp          = spm_vec(subj(n).sess(ss).model(sel).Ep);
                miCp(:,:,ss) = inv(sCp(wsel,wsel));
                mEp(:,ss)    = full(sEp(wsel));
            end

            % Average sessions using Bayesian fixed-effects analysis:
            % precision-weighted mean; the last session supplies the
            % template for the remaining entries
            %==============================================================
            sCp(wsel,wsel) = inv(sum(miCp,3));
            weighted_Ep    = 0;
            for s = 1:Nses
                weighted_Ep = weighted_Ep + miCp(:,:,s)*mEp(:,s);
            end
            sEp(wsel) = sCp(wsel,wsel)*weighted_Ep;
            params(n).model(kk).Ep  = spm_unvec(sEp,subj(n).sess(Nses).model(sel).Ep);
            params(n).model(kk).vEp = sEp;
            params(n).model(kk).Cp  = sCp;
        end

        [evec,evals] = eig(params(n).model(kk).Cp);
        params(n).model(kk).dCp = diag(evals);
        params(n).model(kk).vCp = evec;
    end
end

% Pre-allocate sample arrays
%--------------------------------------------------------------------------
Np = length(params(firstsub).model(firstmod).vEp);

% get dimensions of a b c d parameters; the slices taken below assume A, B,
% C and D occupy the leading entries of the vectorised parameters, in order
%--------------------------------------------------------------------------
if dcm_fmri
    Nr     = nreg*nreg;
    nmods  = size(DCM.Ep.B,3);
    Etmp.A = zeros(nreg,nreg,Nsamp);
    Etmp.B = zeros(nreg,nreg,nmods,Nsamp);
    Etmp.C = zeros(nreg,nu,Nsamp);
    Etmp.D = zeros(nreg,nreg,dimD,Nsamp);
    dima   = Nr;
    dimb   = Nr + Nr*nmods;
    dimc   = Nr + Nr*nmods + nreg*nu;
end

disp(' ')
disp('Averaging models in Occams window...')
Ep_all = zeros(Np,Nsub);
Ep_sbj = zeros(Np,Nsub,Nsamp);
Ep     = zeros(Np,Nsamp);
for i = 1:Nsamp

    % Pick a model (FFX: the same model for every subject)
    %----------------------------------------------------------------------
    if ~rfx
        m = spm_multrnd(post,1);
    end

    % Pick parameters from model for each subject. Reset first so entries
    % beyond a shorter model's parameter vector are zero (as in Ep_sbj)
    % rather than left over from a previous sample.
    %----------------------------------------------------------------------
    Ep_all(:) = 0;
    for n = 1:Nsub
        if rfx
            m = spm_multrnd(renorm(n).post,1);
        end
        mu   = params(n).model(m).vEp;
        nmu  = length(mu);
        dsig = params(n).model(m).dCp(1:nmu,1);
        vsig = params(n).model(m).vCp(1:nmu,1:nmu);
        tmp  = spm_normrnd(mu,{dsig,vsig},1);
        Ep_all(1:nmu,n)   = tmp(:);
        Ep_sbj(1:nmu,n,i) = tmp(:);
    end

    % Average over subjects
    %----------------------------------------------------------------------
    Ep(:,i) = mean(Ep_all,2);
end

% save mean parameters (SPA)
%--------------------------------------------------------------------------
bma.mEp = spm_unvec(mean(Ep,2),params(1).model(1).Ep);
bma.sEp = spm_unvec(std(Ep,0,2),params(1).model(1).Ep);

Ep_avgsbj = mean(Ep_sbj,3);
Ep_stdsbj = std(Ep_sbj,0,3);
for is = 1:Nsub
    bma.mEps{is} = spm_unvec(Ep_avgsbj(:,is),params(1).model(1).Ep);
    bma.sEps{is} = spm_unvec(Ep_stdsbj(:,is),params(1).model(1).Ep);
end

% samples of A, B, C (and D) from the group posterior (GPA)
%--------------------------------------------------------------------------
if dcm_fmri
    bma.a = spm_unvec(Ep(1:dima,:),Etmp.A);
    bma.b = spm_unvec(Ep(dima+1:dimb,:),Etmp.B);
    bma.c = spm_unvec(Ep(dimb+1:dimc,:),Etmp.C);
    if dimD ~= 0
        bma.d = spm_unvec(Ep(dimc+1:dimc+Nr*dimD,:),Etmp.D);
    else
        bma.d = Etmp.D;
    end
end

% storing parameters
% -------------------------------------------------------------------------
bma.nsamp = Nsamp;
bma.oddsr = oddsr;
bma.Nocc  = Nocc;
bma.Mocc  = post_ind;
bma.indx  = occ_ind;