% spm_LAP_eval.m (4.3 KB)
  1. function [p,dp] = spm_LAP_eval(M,qu,qh)
  2. % Evaluate precisions for a LAP model
  3. % FORMAT [p dp] = spm_LAP_eval(M,qu,qh)
  4. %
  5. % p.h - vector of precisions for causal states (v)
  6. % p.g - vector of precisions for hidden states (x)
  7. %
  8. % dp.h.dx - dp.h/dx
  9. % dp.h.dv - dp.h/dv
  10. % dp.h.dh - dp.h/dh
  11. %
  12. % dp.g.dx - dp.g/dx
  13. % dp.g.dv - dp.g/dv
  14. % dp.g.dg - dp.g/dg
  15. %__________________________________________________________________________
  16. % Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging
  17. % Karl Friston
  18. % $Id: spm_LAP_eval.m 6290 2014-12-20 22:11:50Z karl $
  19. % Get states {qu.v{1},qu.x{1}} in hierarchical form (v{i},x{i})
  20. %--------------------------------------------------------------------------
  21. N = length(M);
  22. v = cell(N,1);
  23. x = cell(N,1);
  24. v(1:N - 1) = spm_unvec(qu.v{1},{M(1 + 1:N).v});
  25. x(1:N - 1) = spm_unvec(qu.x{1},{M(1:N - 1).x});
  26. % precisions
  27. %==========================================================================
  28. for i = 1:N
  29. % precision of causal and hidden states
  30. %----------------------------------------------------------------------
  31. try
  32. h{i,1} = spm_vec(feval(M(i).ph,x{i},v{i},qh.h{i},M(i)));
  33. catch
  34. h{i,1} = sparse(M(i).l,1);
  35. end
  36. try
  37. g{i,1} = spm_vec(feval(M(i).pg,x{i},v{i},qh.g{i},M(i)));
  38. catch
  39. g{i,1} = sparse(M(i).n,1);
  40. end
  41. end
  42. % Concatenate over hierarchical levels
  43. %--------------------------------------------------------------------------
  44. p.h = spm_cat(h);
  45. p.g = spm_cat(g);
  46. if nargout < 2, return, end
  47. % gradients
  48. %==========================================================================
  49. % assume precisions can be functions of hyper-parameters and states
  50. %--------------------------------------------------------------------------
  51. try method.h = M(1).E.method.h; catch, method.h = 1; end
  52. try method.g = M(1).E.method.g; catch, method.g = 1; end
  53. try method.x = M(1).E.method.x; catch, method.x = 1; end
  54. try method.v = M(1).E.method.v; catch, method.v = 1; end
  55. % number of variables
  56. %--------------------------------------------------------------------------
  57. nx = numel(spm_vec(x));
  58. nv = numel(spm_vec(v));
  59. hn = numel(spm_vec(qh.h));
  60. gn = numel(spm_vec(qh.g));
  61. nh = size(p.h,1);
  62. ng = size(p.g,1);
  63. dp.h.dh = sparse(nh,hn);
  64. dp.g.dg = sparse(ng,gn);
  65. dp.h.dx = sparse(nh,nx);
  66. dp.h.dv = sparse(nh,nv);
  67. dp.g.dx = sparse(ng,nx);
  68. dp.g.dv = sparse(ng,nv);
  69. % gradients w.r.t. h only (no state-dependent noise)
  70. %----------------------------------------------------------------------
  71. if method.h || method.g
  72. for i = 1:N
  73. % precision of causal and hidden states
  74. %--------------------------------------------------------------
  75. dhdh{i,i} = spm_diff(M(i).ph,x{i},v{i},qh.h{i},M(i),3);
  76. dgdg{i,i} = spm_diff(M(i).pg,x{i},v{i},qh.g{i},M(i),3);
  77. end
  78. % Concatenate over hierarchical levels
  79. %------------------------------------------------------------------
  80. dp.h.dh = spm_cat(dhdh);
  81. dp.g.dg = spm_cat(dgdg);
  82. end
  83. % gradients w.r.t. causal states
  84. %----------------------------------------------------------------------
  85. if method.v
  86. for i = 1:N
  87. % precision of causal states
  88. %--------------------------------------------------------------
  89. dhdv{i,i} = spm_diff(M(i).ph,x{i},v{i},qh.h{i},M(i),2);
  90. % precision of hidden states
  91. %--------------------------------------------------------------
  92. dgdv{i,i} = spm_diff(M(i).pg,x{i},v{i},qh.g{i},M(i),2);
  93. end
  94. % Concatenate over hierarchical levels
  95. %------------------------------------------------------------------
  96. dp.h.dv = spm_cat(dhdv);
  97. dp.g.dv = spm_cat(dgdv);
  98. end
  99. % gradients w.r.t. hidden states
  100. %----------------------------------------------------------------------
  101. if method.x
  102. for i = 1:N
  103. % precision of causal states
  104. %--------------------------------------------------------------
  105. dhdx{i,i} = spm_diff(M(i).ph,x{i},v{i},qh.h{i},M(i),1);
  106. % precision of hidden states
  107. %--------------------------------------------------------------
  108. dgdx{i,i} = spm_diff(M(i).pg,x{i},v{i},qh.g{i},M(i),1);
  109. end
  110. % Concatenate over hierarchical levels
  111. %------------------------------------------------------------------
  112. dp.h.dx = spm_cat(dhdx);
  113. dp.g.dx = spm_cat(dgdx);
  114. end