quadfitN.m 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  function [H, d, c, e, fit] = quadfitN(xs, Fs, at_ori)
  % QUADFITN  Least-squares fit of a quadratic function to points in N dimensions.
  %
  %   [H, d, c, e, fit] = quadfitN(xs, Fs)
  %   [H, d, c, e, fit] = quadfitN(xs, Fs, at_ori)
  %
  %   Fits the model
  %       f(x) = 0.5 * x' * H * x + d * x + c
  %   to the data in the least-squares sense. The coefficients are obtained
  %   with pinv, so a rank-deficient problem yields the minimum-norm solution.
  %
  % INPUTS:
  %
  %   xs      N x M matrix, where N is the number of dimensions and M is the
  %           number of data points (one point per column)
  %
  %   Fs      vector with M elements (row or column), the function values
  %           corresponding to each column of xs
  %
  %   at_ori  optional flag, 0 by default (an empty value also means 0).
  %           If nonzero, the fit is constrained to d = 0 and c = 0, i.e. the
  %           origin is a stationary point of the quadratic with value 0
  %           (a minimum, maximum or saddle point depending on H)
  %
  % OUTPUTS:
  %
  %   H       N x N symmetric matrix, the Hessian (second derivatives) of
  %           the fit
  %
  %   d       1 x N vector, linear terms (all zeros when at_ori is set)
  %
  %   c       scalar, the constant term (0 when at_ori is set)
  %
  %   e       1 x M vector, the residuals Fs - f(xs) of the computed fit
  %
  %   fit     function handle; fit(ns) evaluates f at every column of the
  %           N x K matrix ns and returns a 1 x K row vector
    if nargin < 3 || isempty(at_ori)
        at_ori = 0;
    end
    [N, M] = size(xs); % N dimensions, M points

    % Accept Fs as a row or a column; work with a row so the residual e is
    % 1 x M rather than an implicitly expanded M x M matrix.
    Fs = Fs(:).';
    if numel(Fs) ~= M
        error('quadfitN:sizeMismatch', ...
            'Fs must have one value per column of xs (expected %d, got %d).', ...
            M, numel(Fs));
    end

    % Index pairs (ii, jj) with ii <= jj enumerating the N*(N+1)/2 quadratic
    % terms. find walks the lower triangle column by column, so taking the
    % column index as ii gives the order (1,1),(1,2),...,(1,N),(2,2),...
    [jj, ii] = find(tril(ones(N)));
    P = numel(ii);

    % Quadratic regressors x_i*x_j, one row per term. Diagonal terms are
    % halved so that the fitted coefficients are directly the entries of H
    % in 0.5*x'*H*x (off-diagonal pairs appear twice in x'*H*x).
    Q = xs(ii, :) .* xs(jj, :);
    onDiag = (ii == jj);
    Q(onDiag, :) = Q(onDiag, :) / 2;

    if at_ori
        X = Q.';                       % M x P: quadratic terms only
    else
        X = [Q; xs; ones(1, M)].';     % M x (P+N+1): quadratic, linear, constant
    end

    % Solve X*A = Fs' in the least-squares sense
    A = pinv(X) * Fs.';

    % Unpack A into H (symmetric), d and c
    H = zeros(N, N);
    H(sub2ind([N N], ii, jj)) = A(1:P);  % upper triangle including diagonal
    H = H + triu(H, 1).';                % mirror to the lower triangle
    if at_ori
        d = zeros(1, N);
        c = 0;
    else
        d = A(P+1:P+N).';
        c = A(P+N+1);
    end

    % Evaluate 0.5*x'*H*x column by column; this avoids forming the K x K
    % matrix ns'*H*ns only to take its diagonal. The explicit dimension in
    % sum keeps the result 1 x K even when N == 1.
    fit = @(ns) 0.5 * sum(ns .* (H * ns), 1) + d * ns + c;
    e = Fs - fit(xs);