classRF_train.m 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. %**************************************************************
  2. %* mex interface to Andy Liaw et al.'s C code (used in R package randomForest)
  3. %* Added by Abhishek Jaiantilal ( abhishek.jaiantilal@colorado.edu )
  4. %* License: GPLv2
  5. %* Version: 0.02
  6. %
  7. % Calls Classification Random Forest
  8. % A wrapper matlab file that calls the mex file
  9. % This does training given the data and labels
  10. % Documentation copied from R-packages pdf
  11. % http://cran.r-project.org/web/packages/randomForest/randomForest.pdf
  12. % Tutorial on getting this working in tutorial_ClassRF.m
  13. %**************************************************************
  14. % function model = classRF_train(X,Y,ntree,mtry, extra_options)
  15. %
  16. %___Options
  17. % requires 2 arguments and the rest 3 are optional
  18. % X: data matrix
  19. % Y: target values
  20. % ntree (optional): number of trees (default is 500). also if set to 0
  21. % will default to 500
  22. % mtry (default is floor(sqrt(size(X,2))) D=number of features in X). also if set to 0
  23. % will default to floor(sqrt(D))
  24. %
  25. %
  26. % Note: TRUE = 1 and FALSE = 0 below
  27. % extra_options represent a structure containing various misc. options to
  28. % control the RF
  29. % extra_options.replace = 0 or 1 (default is 1) sampling with or without
  30. % replacement
  31. % extra_options.classwt = priors of classes. Here the function first gets
  32. % the labels in ascending order and assumes the
  33. % priors are given in the same order. So if the class
  34. % labels are [-1 1 2] and classwt is [0.1 2 3] then
  35. % there is a 1-1 correspondence. (ascending order of
  36. % class labels). Once this is set the freq of labels in
  37. % train data also affects.
  38. % extra_options.cutoff (Classification only) = A vector of length equal to number of classes. The 'winning'
  39. % class for an observation is the one with the maximum ratio of proportion
  40. % of votes to cutoff. Default is 1/k where k is the number of classes (i.e., majority
  41. % vote wins).
  42. % extra_options.strata = (not yet stable in code) variable that is used for stratified
  43. % sampling. I don't yet know how this works. Disabled
  44. % by default
  45. % extra_options.sampsize = Size(s) of sample to draw. For classification,
  46. % if sampsize is a vector of the length the number of strata, then sampling is stratified by strata,
  47. % and the elements of sampsize indicate the numbers to be
  48. % drawn from the strata.
  49. % extra_options.nodesize = Minimum size of terminal nodes. Setting this number larger causes smaller trees
  50. % to be grown (and thus take less time). Note that the default values are different
  51. % for classification (1) and regression (5).
  52. % extra_options.importance = Should importance of predictors be assessed?
  53. % extra_options.localImp = Should casewise importance measure be computed? (Setting this to TRUE will
  54. % override importance.)
  55. % extra_options.proximity = Should proximity measure among the rows be calculated?
  56. % extra_options.oob_prox = Should proximity be calculated only on 'out-of-bag' data?
  57. % extra_options.do_trace = If set to TRUE, give a more verbose output as randomForest is run. If set to
  58. % some integer, then running output is printed for every
  59. % do_trace trees.
  60. % extra_options.keep_inbag Should an n by ntree matrix be returned that keeps track of which samples are
  61. % 'in-bag' in which trees (but not how many times, if sampling with replacement)
  62. %
  63. % Options eliminated
  64. % corr_bias which happens only for regression omitted
  65. % norm_votes - always set to return total votes for each class.
  66. %
  67. %___Returns model which has
  68. % importance = a matrix with nclass + 2 (for classification) or two (for regression) columns.
  69. % For classification, the first nclass columns are the class-specific measures
  70. % computed as mean decrease in accuracy. The nclass + 1st column is the
  71. % mean decrease in accuracy over all classes. The last column is the mean decrease
  72. % in Gini index. For Regression, the first column is the mean decrease in
  73. % accuracy and the second the mean decrease in MSE. If importance=FALSE,
  74. % the last measure is still returned as a vector.
  75. % importanceSD = The 'standard errors' of the permutation-based importance measure. For classification,
  76. % a p by nclass + 1 matrix corresponding to the first nclass + 1
  77. % columns of the importance matrix. For regression, a length p vector.
  78. % localImp = a p by n matrix containing the casewise importance measures, the [i,j] element
  79. % of which is the importance of i-th variable on the j-th case. NULL if
  80. % localImp=FALSE.
  81. % ntree = number of trees grown.
  82. % mtry = number of predictors sampled for spliting at each node.
  83. % votes (classification only) a matrix with one row for each input data point and one
  84. % column for each class, giving the fraction or number of 'votes' from the random
  85. % forest.
  86. % oob_times number of times cases are 'out-of-bag' (and thus used in computing OOB error
  87. % estimate)
  88. % proximity if proximity=TRUE when randomForest is called, a matrix of proximity
  89. % measures among the input (based on the frequency that pairs of data points are
  90. % in the same terminal nodes).
  91. % errtr = first column is OOB Err rate, second is for class 1 and so on
  92. function model=classRF_train(X,Y,ntree,mtry, extra_options)
  93. DEFAULTS_ON =0;
  94. %DEBUG_ON=0;
  95. TRUE=1;
  96. FALSE=0;
  97. orig_labels = sort(unique(Y));
  98. Y_new = Y;
  99. new_labels = 1:length(orig_labels);
  100. for i=1:length(orig_labels)
  101. Y_new(find(Y==orig_labels(i)))=Inf;
  102. Y_new(isinf(Y_new))=new_labels(i);
  103. end
  104. Y = Y_new;
  105. if exist('extra_options','var')
  106. if isfield(extra_options,'DEBUG_ON'); DEBUG_ON = extra_options.DEBUG_ON; end
  107. if isfield(extra_options,'replace'); replace = extra_options.replace; end
  108. if isfield(extra_options,'classwt'); classwt = extra_options.classwt; end
  109. if isfield(extra_options,'cutoff'); cutoff = extra_options.cutoff; end
  110. if isfield(extra_options,'strata'); strata = extra_options.strata; end
  111. if isfield(extra_options,'sampsize'); sampsize = extra_options.sampsize; end
  112. if isfield(extra_options,'nodesize'); nodesize = extra_options.nodesize; end
  113. if isfield(extra_options,'importance'); importance = extra_options.importance; end
  114. if isfield(extra_options,'localImp'); localImp = extra_options.localImp; end
  115. if isfield(extra_options,'nPerm'); nPerm = extra_options.nPerm; end
  116. if isfield(extra_options,'proximity'); proximity = extra_options.proximity; end
  117. if isfield(extra_options,'oob_prox'); oob_prox = extra_options.oob_prox; end
  118. %if isfield(extra_options,'norm_votes'); norm_votes = extra_options.norm_votes; end
  119. if isfield(extra_options,'do_trace'); do_trace = extra_options.do_trace; end
  120. %if isfield(extra_options,'corr_bias'); corr_bias = extra_options.corr_bias; end
  121. if isfield(extra_options,'keep_inbag'); keep_inbag = extra_options.keep_inbag; end
  122. end
  123. keep_forest=1; %always save the trees :)
  124. %set defaults if not already set
  125. if ~exist('DEBUG_ON','var') DEBUG_ON=FALSE; end
  126. if ~exist('replace','var'); replace = TRUE; end
  127. %if ~exist('classwt','var'); classwt = []; end %will handle these three later
  128. %if ~exist('cutoff','var'); cutoff = 1; end
  129. %if ~exist('strata','var'); strata = 1; end
  130. if ~exist('sampsize','var');
  131. if (replace)
  132. sampsize = size(X,1);
  133. else
  134. sampsize = ceil(0.632*size(X,1));
  135. end;
  136. end
  137. if ~exist('nodesize','var'); nodesize = 1; end %classification=1, regression=5
  138. if ~exist('importance','var'); importance = FALSE; end
  139. if ~exist('localImp','var'); localImp = FALSE; end
  140. if ~exist('nPerm','var'); nPerm = 1; end
  141. %if ~exist('proximity','var'); proximity = 1; end %will handle these two later
  142. %if ~exist('oob_prox','var'); oob_prox = 1; end
  143. %if ~exist('norm_votes','var'); norm_votes = TRUE; end
  144. if ~exist('do_trace','var'); do_trace = FALSE; end
  145. %if ~exist('corr_bias','var'); corr_bias = FALSE; end
  146. if ~exist('keep_inbag','var'); keep_inbag = FALSE; end
  147. if ~exist('ntree','var') | ntree<=0
  148. ntree=500;
  149. DEFAULTS_ON=1;
  150. end
  151. if ~exist('mtry','var') | mtry<=0 | mtry>size(X,2)
  152. mtry =floor(sqrt(size(X,2)));
  153. end
  154. addclass =isempty(Y);
  155. if (~addclass && length(unique(Y))<2)
  156. error('need atleast two classes for classification');
  157. end
  158. [N D] = size(X);
  159. if N==0; error(' data (X) has 0 rows');end
  160. if (mtry <1 || mtry > D)
  161. DEFAULTS_ON=1;
  162. end
  163. mtry = max(1,min(D,round(mtry)));
  164. if DEFAULTS_ON
  165. fprintf('\tSetting to defaults %d trees and mtry=%d\n',ntree,mtry);
  166. end
  167. if ~isempty(Y)
  168. if length(Y)~=N,
  169. error('Y size is not the same as X size');
  170. end
  171. addclass = FALSE;
  172. else
  173. if ~addclass,
  174. addclass=TRUE;
  175. end
  176. error('have to fill stuff here')
  177. end
  178. if ~isempty(find(isnan(X))); error('NaNs in X'); end
  179. if ~isempty(find(isnan(Y))); error('NaNs in Y'); end
  180. %now handle categories. Problem is that categories in R are more
  181. %enhanced. In this i ask the user to specify the column/features to
  182. %consider as categories, 1 if all the values are real values else
  183. %specify the number of categories here
  184. if exist ('extra_options','var') && isfield(extra_options,'categories')
  185. ncat = extra_options.categories;
  186. else
  187. ncat = ones(1,D);
  188. end
  189. maxcat = max(ncat);
  190. if maxcat>32
  191. error('Can not handle categorical predictors with more than 32 categories');
  192. end
  193. %classRF - line 88 in randomForest.default.R
  194. nclass = length(unique(Y));
  195. if ~exist('cutoff','var')
  196. cutoff = ones(1,nclass)* (1/nclass);
  197. else
  198. if sum(cutoff)>1 || sum(cutoff)<0 || length(find(cutoff<=0))>0 || length(cutoff)~=nclass
  199. error('Incorrect cutoff specified');
  200. end
  201. end
  202. if ~exist('classwt','var')
  203. classwt = ones(1,nclass);
  204. ipi=0;
  205. else
  206. if length(classwt)~=nclass
  207. error('Length of classwt not equal to the number of classes')
  208. end
  209. if ~isempty(find(classwt<=0))
  210. error('classwt must be positive');
  211. end
  212. ipi=1;
  213. end
  214. if ~exist('proximity','var')
  215. proximity = addclass;
  216. oob_prox = proximity;
  217. end
  218. if ~exist('oob_prox','var')
  219. oob_prox = proximity;
  220. end
  221. %i handle the below in the mex file
  222. % if proximity
  223. % prox = zeros(N,N);
  224. % proxts = 1;
  225. % else
  226. % prox = 1;
  227. % proxts = 1;
  228. % end
  229. %i handle the below in the mex file
  230. if localImp
  231. importance = TRUE;
  232. % impmat = zeors(D,N);
  233. else
  234. % impmat = 1;
  235. end
  236. if importance
  237. if (nPerm<1)
  238. nPerm = int32(1);
  239. else
  240. nPerm = int32(nPerm);
  241. end
  242. %classRF
  243. % impout = zeros(D,nclass+2);
  244. % impSD = zeros(D,nclass+1);
  245. else
  246. % impout = zeros(D,1);
  247. % impSD = 1;
  248. end
  249. %i handle the below in the mex file
  250. %somewhere near line 157 in randomForest.default.R
  251. if addclass
  252. % nsample = 2*n;
  253. else
  254. % nsample = n;
  255. end
  256. Stratify = (length(sampsize)>1);
  257. if (~Stratify && sampsize>N)
  258. error('Sampsize too large')
  259. end
  260. if Stratify
  261. if ~exist('strata','var')
  262. strata = Y;
  263. end
  264. nsum = sum(sampsize);
  265. if ( ~isempty(find(sampsize<=0)) || nsum==0)
  266. error('Bad sampsize specification');
  267. end
  268. else
  269. nsum = sampsize;
  270. end
  271. %i handle the below in the mex file
  272. %nrnodes = 2*floor(nsum/nodesize)+1;
  273. %xtest = 1;
  274. %ytest = 1;
  275. %ntest = 1;
  276. %labelts = FALSE;
  277. %nt = ntree;
  278. %[ldau,rdau,nodestatus,nrnodes,upper,avnode,mbest,ndtree]=
  279. %keyboard
  280. if Stratify
  281. strata = int32(strata);
  282. else
  283. strata = int32(1);
  284. end
  285. Options = int32([addclass, importance, localImp, proximity, oob_prox, do_trace, keep_forest, replace, Stratify, keep_inbag]);
  286. if DEBUG_ON
  287. %print the parameters that i am sending in
  288. fprintf('size(x) %d\n',size(X));
  289. fprintf('size(y) %d\n',size(Y));
  290. fprintf('nclass %d\n',nclass);
  291. fprintf('size(ncat) %d\n',size(ncat));
  292. fprintf('maxcat %d\n',maxcat);
  293. fprintf('size(sampsize) %d\n',size(sampsize));
  294. fprintf('sampsize[0] %d\n',sampsize(1));
  295. fprintf('Stratify %d\n',Stratify);
  296. fprintf('Proximity %d\n',proximity);
  297. fprintf('oob_prox %d\n',oob_prox);
  298. fprintf('strata %d\n',strata);
  299. fprintf('ntree %d\n',ntree);
  300. fprintf('mtry %d\n',mtry);
  301. fprintf('ipi %d\n',ipi);
  302. fprintf('classwt %f\n',classwt);
  303. fprintf('cutoff %f\n',cutoff);
  304. fprintf('nodesize %f\n',nodesize);
  305. end
  306. [nrnodes,ntree,xbestsplit,classwt,cutoff,treemap,nodestatus,nodeclass,bestvar,ndbigtree,mtry ...
  307. outcl, counttr, prox, impmat, impout, impSD, errtr, inbag] ...
  308. = mexClassRF_train(X',int32(Y_new),length(unique(Y)),ntree,mtry,int32(ncat), ...
  309. int32(maxcat), int32(sampsize), strata, Options, int32(ipi), ...
  310. classwt, cutoff, int32(nodesize),int32(nsum));
  311. model.nrnodes=nrnodes;
  312. model.ntree=ntree;
  313. model.xbestsplit=xbestsplit;
  314. model.classwt=classwt;
  315. model.cutoff=cutoff;
  316. model.treemap=treemap;
  317. model.nodestatus=nodestatus;
  318. model.nodeclass=nodeclass;
  319. model.bestvar = bestvar;
  320. model.ndbigtree = ndbigtree;
  321. model.mtry = mtry;
  322. model.orig_labels=orig_labels;
  323. model.new_labels=new_labels;
  324. model.nclass = length(unique(Y));
  325. model.outcl = outcl;
  326. model.counttr = counttr;
  327. if proximity
  328. model.proximity = prox;
  329. else
  330. model.proximity = [];
  331. end
  332. model.localImp = impmat;
  333. model.importance = impout;
  334. model.importanceSD = impSD;
  335. model.errtr = errtr';
  336. model.inbag = inbag;
  337. model.votes = counttr';
  338. model.oob_times = sum(counttr)';
  339. clear mexClassRF_train
  340. %keyboard
  341. 1;