mexlf.c 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
/*=================================================================
 * mexlocfit.c
 *
 * Starting point for a MATLAB (MEX) interface to locfit.
 *
 * $Revision: 1.5 $
 *=================================================================*/
#include <string.h>

#include "mex.h"
#include "lfev.h"
  9. design des;
  10. lfit lf;
  11. int lf_error;
  12. extern void lfmxdata(), lfmxsp(), lfmxevs();
  13. void
  14. mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
  15. {
  16. int d, i, nvc[5], nvm, vc, nv, nc, mvc, lw[7];
  17. double *y;
  18. mxArray *iwkc, *temparray;
  19. const char * fpcnames[] = {"evaluation_points","fitted_values","evaluation_vectors","fit_limits","family_link","kappa"};
  20. const char * evsnames[] = {"cell","splitvar","lo","hi"};
  21. mut_redirect(mexPrintf);
  22. if (nrhs != 3) mexErrMsgTxt("mexlf requires 3 inputs.");
  23. lf_error = 0;
  24. lfit_alloc(&lf);
  25. lfit_init(&lf);
  26. lfmxdata(&lf.lfd,prhs[0]);
  27. d = lf.lfd.d;
  28. lfmxsp(&lf.sp,prhs[2],d);
  29. lfmxevs(&lf,prhs[1]);
  30. guessnv(&lf.evs,&lf.sp,&lf.mdl,lf.lfd.n,d,lw,nvc);
  31. nvm = lf.fp.nvm = nvc[0];
  32. vc = nvc[2];
  33. if (ev(&lf.evs) != EPRES) lf.fp.xev = mxCalloc(nvm*d,sizeof(double));
  34. lf.fp.lev = nvm*d;
  35. lf.fp.wk = lf.fp.coef = mxCalloc(lw[1],sizeof(double));
  36. lf.fp.lwk = lw[1];
  37. lf.evs.iwk = mxCalloc(lw[2],sizeof(int));
  38. lf.evs.liw = lw[2];
  39. plhs[1] = mxCreateDoubleMatrix(1,lw[3],mxREAL);
  40. lf.pc.wk = mxGetPr(plhs[1]);
  41. lf.pc.lwk = lw[3];
  42. lf.fp.kap = mxCalloc(lw[5],sizeof(double));
  43. /* should also allocate design here */
  44. if (!lf_error) startmodule(&lf,&des,NULL,NULL);
  45. /* now, store the results:
  46. plhs[0] stores informtion about fit points and evaluation structure.
  47. it is now a matlab structure, not a cell
  48. fit_points.evaluation_points - matrix of evaluation points.
  49. fit_points.fitted_values - matrix fitted values etc.
  50. fit_points.evaluation_vectors - structure of 'integer' vectors {cell,splitvar,lo,hi}
  51. fit_points.fit_limits - fit limit (matrix with 2 cols).
  52. fit_points.family_link - [familt link] numeric vector.
  53. fit_points.kappa - kap vector.
  54. */
  55. plhs[0] = mxCreateStructMatrix(1,1,6,fpcnames);
  56. if ( plhs[0] == NULL ) {
  57. printf("Problem with CreateStructMatrix for plhs[0]\n");fflush(stdout);
  58. }
  59. mxSetField(plhs[0],0,"evaluation_points",mxCreateDoubleMatrix(d,lf.fp.nv,mxREAL));
  60. memcpy(mxGetPr(mxGetField(plhs[0],0,"evaluation_points")), lf.fp.xev, d*lf.fp.nv*sizeof(double));
  61. mxSetField(plhs[0],0,"fitted_values",mxCreateDoubleMatrix(lf.fp.nv,lf.mdl.keepv,mxREAL));
  62. for (i=0; i<lf.mdl.keepv; i++)
  63. memcpy(&mxGetPr(mxGetField(plhs[0],0,"fitted_values"))[i*lf.fp.nv], &lf.fp.coef[i*nvm], lf.fp.nv*sizeof(double));
  64. /* another bit to save here? -- split values, kdtree */
  65. temparray = mxCreateStructMatrix(1,1,4,evsnames);
  66. if ( temparray == NULL ) {
  67. printf("Problem with CreateStructMatrix for temparray\n");fflush(stdout);
  68. }
  69. mxSetField(plhs[0],0,"evaluation_vectors",temparray);
  70. iwkc = mxGetField(plhs[0],0,"evaluation_vectors");
  71. nv = lf.fp.nv;
  72. nc = lf.evs.nce;
  73. mvc = (nv>nc) ? nv : nc;
  74. mxSetField(iwkc,0,"cell",mxCreateDoubleMatrix(vc,nc,mxREAL)); /* ce */
  75. mxSetField(iwkc,0,"splitvar",mxCreateDoubleMatrix(1,mvc,mxREAL)); /* s */
  76. mxSetField(iwkc,0,"lo",mxCreateDoubleMatrix(1,mvc,mxREAL)); /* lo */
  77. mxSetField(iwkc,0,"hi",mxCreateDoubleMatrix(1,mvc,mxREAL)); /* hi */
  78. y = mxGetPr(mxGetField(iwkc,0,"cell"));
  79. for (i=0; i<vc*nc; i++) y[i] = lf.evs.ce[i];
  80. y = mxGetPr(mxGetField(iwkc,0,"splitvar"));
  81. for (i=0; i<mvc; i++) y[i] = lf.evs.s[i];
  82. y = mxGetPr(mxGetField(iwkc,0,"lo"));
  83. for (i=0; i<mvc; i++) y[i] = lf.evs.lo[i];
  84. y = mxGetPr(mxGetField(iwkc,0,"hi"));
  85. for (i=0; i<mvc; i++) y[i] = lf.evs.hi[i];
  86. mxSetField(plhs[0],0,"fit_limits",mxCreateDoubleMatrix(d,2,mxREAL));
  87. memcpy(mxGetPr(mxGetField(plhs[0],0,"fit_limits")), lf.evs.fl, 2*d*sizeof(double));
  88. mxSetField(plhs[0],0,"family_link",mxCreateDoubleMatrix(1,2,mxREAL));
  89. y = mxGetPr(mxGetField(plhs[0],0,"family_link"));
  90. y[0] = fam(&lf.sp);
  91. y[1] = link(&lf.sp);
  92. mxSetField(plhs[0],0,"kappa",mxCreateDoubleMatrix(1,lf.mdl.keepc,mxREAL));
  93. memcpy(mxGetPr(mxGetField(plhs[0],0,"kappa")),lf.fp.kap,lf.mdl.keepc*sizeof(double));
  94. }