classRF_predict.m 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
%**************************************************************
%* MEX interface to Andy Liaw et al.'s C code (used in the R package randomForest)
%* Added by Abhishek Jaiantilal ( abhishek.jaiantilal@colorado.edu )
%* License: GPLv2
%* Version: 0.02
%
% Classification Random Forest: prediction
% A MATLAB wrapper that calls the MEX file mexClassRF_predict.
% It predicts class labels for the given data using a trained model.
% Options follow the predict function documented in http://cran.r-project.org/web/packages/randomForest/randomForest.pdf
%**************************************************************
%function [Y_new, votes, prediction_per_tree] = classRF_predict(X, model, extra_options)
% Requires at least 2 arguments:
% X: data matrix
% model: generated by the classRF_train function
% extra_options.predict_all: optional; if set (nonzero), per-tree predictions
%   are also requested from the MEX routine. Defaults to 0.
%
%
% Returns
% Y_new - predicted labels for the data, mapped back to the original label
%   values (model.orig_labels)
% votes - unnormalized votes (transposed from the MEX output)
% prediction_per_tree - per-tree predictions. Following the R documentation,
%   when predict_all is set each column holds the class predicted by one tree
%   in the forest. These are the model's internal labels; they are not mapped
%   back to model.orig_labels.
%
% Not yet implemented:
% proximity
  29. function [Y_new, votes, prediction_per_tree] = classRF_predict(X,model, extra_options)
  30. if nargin<2
  31. error('need atleast 2 parameters,X matrix and model');
  32. end
  33. if exist('extra_options','var')
  34. if isfield(extra_options,'predict_all')
  35. predict_all = extra_options.predict_all;
  36. end
  37. end
  38. if ~exist('predict_all','var'); predict_all=0;end
  39. [Y_hat,prediction_per_tree,votes] = mexClassRF_predict(X',model.nrnodes,model.ntree,model.xbestsplit,model.classwt,model.cutoff,model.treemap,model.nodestatus,model.nodeclass,model.bestvar,model.ndbigtree,model.nclass, predict_all);
  40. %keyboard
  41. votes = votes';
  42. clear mexClassRF_predict
  43. Y_new = double(Y_hat);
  44. new_labels = model.new_labels;
  45. orig_labels = model.orig_labels;
  46. for i=1:length(orig_labels)
  47. Y_new(find(Y_hat==new_labels(i)))=Inf;
  48. Y_new(isinf(Y_new))=orig_labels(i);
  49. end
  50. 1;