% fitnonlinearmodel_helper.m
  function results = fitnonlinearmodel_helper(opt,stimulus,tmatrix,smatrix,trainfun,testfun)
  % This is a helper function for fitnonlinearmodel.m. Not for external use!
  %
  % Inputs (as used by this function):
  %   <opt> is the options struct prepared by fitnonlinearmodel.m. Fields read here:
  %     .model        - either a function handle (linear case), called as
  %                     feval(opt.model,X,y) and expected to return the parameter
  %                     estimates; or a cell vector of model stages (nonlinear case),
  %                     each of the form {SEED BOUNDS MODELFUN TRANSFORM}. For the
  %                     first stage these are used as-is (SEED may have one row per
  %                     seed). For later stages, SEED, MODELFUN, and TRANSFORM are
  %                     functions that receive the previous stage's final parameters.
  %                     Parameters whose BOUNDS(1,:) entry is NaN are not optimized.
  %     .data         - data to be fit (see Notes)
  %     .dontsave, .dosave - consulted only to decide whether 'modelfit' is wanted
  %     .optimoptions - options passed to lsqcurvefit (nonlinear case)
  %     .outputfcn    - optional output function; if it takes 4 inputs, the
  %                     (normalized) training data is passed as the 4th input
  %     .metric       - function called as feval(opt.metric,modelprediction,data)
  %   <stimulus> is a cell vector of stimuli. Frames may be stacked along the
  %     third dimension, in which case the prediction for each row is the mean
  %     of the per-frame predictions.
  %   <tmatrix>,<smatrix> are resampled by <trainfun>/<testfun> and converted with
  %     projectionmatrix.m. The T projection is applied to data and model before
  %     fitting; the S projection is applied before computing metrics.
  %   <trainfun>,<testfun> are cell vectors (one entry per resampling case) of
  %     functions that extract the training/testing portion of their input.
  %
  % Output <results> is a struct with fields:
  %   .params            - resampling cases x parameters
  %   .testdata, .modelpred, .modelfit - per-case results concatenated along
  %                        dimension 1 (via catcell)
  %   .trainperformance, .testperformance - 1 x resampling cases
  %                        (testperformance is NaN when there is no test data)
  %   .aggregatedtestperformance - metric on the concatenated test data (NaN if none)
  %   .numiters          - cases x seeds x model stages ([] in the linear case)
  %   .resnorms          - cases x seeds, resnorm of the final stage
  %                        ([] in the linear case)
  %
  % Notes:
  % - opt.data is always a cell vector and contains only one voxel
  % - in the nonlinear case, the seed to use has been hacked into model{1}{1} and may have multiple rows

  % calc
  islinear = isa(opt.model,'function_handle');
  if ~islinear
    numseeds = size(opt.model{1}{1},1);
    nummodels = length(opt.model);
    ismultipleseeds = numseeds > 1;
    ismultiplemodels = nummodels > 1;
  end
  numcases = length(trainfun);

  % calc
  wantmodelfit = ~(ismember('modelfit',opt.dontsave) && ~ismember('modelfit',opt.dosave));
  if islinear
    numparams = size(stimulus{1},2);
  else
    numparams = size(opt.model{end}{2},2);
  end

  % init
  results = struct;
  results.params = zeros(numcases,numparams);
  results.testdata = cell(1,numcases);   % but converted to a matrix at the end
  results.modelpred = cell(1,numcases);  % but converted to a matrix at the end
  results.modelfit = cell(1,numcases);   % but converted to a matrix at the end
  results.trainperformance = zeros(1,numcases);
  results.testperformance = zeros(1,numcases);
  results.aggregatedtestperformance = [];
  if islinear
    results.numiters = [];
    results.resnorms = [];
  else
    results.numiters = zeros(numcases,numseeds,nummodels);
    results.resnorms = zeros(numcases,numseeds);
  end

  % the full stimulus does not depend on the resampling case, so build it once
  % (and only if the user actually wants 'modelfit', to save on memory)
  if wantmodelfit
    allstim = catcell(1,stimulus);
  end

  % loop over resampling cases
  for rr=1:numcases
    fprintf(' starting resampling case %d of %d.\n',rr,numcases);

    % deal with resampling
    trainstim = feval(trainfun{rr},stimulus);
    traindata = feval(trainfun{rr},opt.data);  % result is a column vector
    trainT = projectionmatrix(feval(trainfun{rr},tmatrix));  % NOTE: potentially slow step. make sparse? [or CACHE]
    trainS = projectionmatrix(feval(trainfun{rr},smatrix));  % NOTE: potentially slow step. make sparse? [or CACHE]
    teststim = feval(testfun{rr},stimulus);
    testdata = feval(testfun{rr},opt.data);  % result is a column vector
    testS = projectionmatrix(feval(testfun{rr},smatrix));    % NOTE: potentially slow step. make sparse? [or CACHE]
    % (a test-side T projection is not needed: T is only used for fitting)

    % precompute: remove regressors from data (fitting)
    traindataT = trainT*traindata;

    % deal with last-minute data division. the model output is divided by the
    % same factor inside the lsqcurvefit call below, so the parameters are
    % unaffected by this scaling.
    if ~islinear
      datastd = std(traindataT);
      if datastd == 0
        datastd = 1;
      end
      traindataT = traindataT / datastd;
    end

    % deal with options
    if ~islinear
      options = opt.optimoptions;
      if ~isempty(opt.outputfcn)
        if nargin(opt.outputfcn) == 4
          options.OutputFcn = @(a,b,c) feval(opt.outputfcn,a,b,c,traindataT);
        else
          options.OutputFcn = opt.outputfcn;
        end
      end
    end

    if islinear
      % ok, deal with linear case.
      % do the fitting. note that we take the mean across the third dimension
      % to deal with the case where the stimulus consists of multiple frames.
      % the same frame average is used for all predictions below.
      trainstimavg = mean(trainstim,3);
      finalparams = feval(opt.model,trainT*trainstimavg,traindataT);
      % report
      fprintf(' the estimated parameters are ['); fprintf('%.3f ',finalparams); fprintf('].\n');

    else
      % ok, deal with nonlinear case. loop over seeds, running the full cascade
      % of model stages for each seed.
      params = [];
      seedmodels = cell(1,numseeds);      % final-stage model function for each seed
      seedtransforms = cell(1,numseeds);  % final-stage transform for each seed
      for ss=1:numseeds
        if ismultipleseeds
          fprintf(' trying seed %d of %d.\n',ss,numseeds);
        end

        % loop through models
        for mm=1:nummodels
          % which parameters are we actually fitting?
          ix = ~isnan(opt.model{mm}{2}(1,:));

          % calculate seed, model, and transform (later stages derive these
          % from the previous stage's final parameters)
          if mm==1
            seed = opt.model{mm}{1}(ss,:);
            model = opt.model{mm}{3};
            transform = opt.model{mm}{4};
          else
            seed = feval(opt.model{mm}{1},params0);
            model = feval(opt.model{mm}{3},params0);
            transform = feval(opt.model{mm}{4},params0);
          end

          % in the special case that the stimulus consists of multiple frames,
          % wrap the model so that it averages across the predicted response
          % associated with each frame. the chunk sizes are taken from the
          % argument <dd> (rows x ... x frames) rather than from the training
          % stimulus, so that the wrapped model also works on the test and full
          % stimuli, which have different numbers of rows.
          if size(trainstim,3) > 1
            model = @(pp,dd) chunkfun(feval(model,pp,squish(permute(dd,[3 1 2]),2)), ...
                                      repmat(size(dd,3),[1 size(dd,1)]),@(x) mean(x,1)).';
          end

          % figure out bounds to use (levenberg-marquardt does not accept bounds)
          if isequal(options.Algorithm,'levenberg-marquardt')
            lb = [];
            ub = [];
          else
            lb = opt.model{mm}{2}(1,ix);
            ub = opt.model{mm}{2}(2,ix);
          end

          % precompute
          trainstimTRANSFORM = feval(transform,trainstim);

          % define the final model function (only the free parameters <pp> vary)
          fun = @(pp) trainT*feval(model,copymatrix(seed,ix,pp),trainstimTRANSFORM);

          % report
          if ismultiplemodels
            fprintf(' for model %d of %d, the seed is [',mm,nummodels); fprintf('%.3f ',seed); fprintf('].\n');
          else
            fprintf(' the seed is ['); fprintf('%.3f ',seed); fprintf('].\n');
          end

          % perform the fit (NOTICE THE DIVISION BY DATASTD, THE NAN PROTECTION, THE CONVERSION TO DOUBLE)
          if ~any(ix)
            params0 = seed;  % if no parameters are to be optimized, just return the seed
            resnorm = NaN;
            output = [];
            output.iterations = NaN;
          else
            [params0,resnorm,~,~,output] = ...
              lsqcurvefit(@(x,y) double(nanreplace(feval(fun,x) / datastd,0,2)),seed(ix),[],double(traindataT),lb,ub,options);
            params0 = copymatrix(seed,ix,params0);
          end

          % report
          fprintf(' the estimated parameters are ['); fprintf('%.3f ',params0); fprintf('].\n');

          % record
          results.numiters(rr,ss,mm) = output.iterations;
        end

        % record the outcome of this seed, including the model and transform of
        % its final stage (these may depend on this seed's intermediate results)
        results.resnorms(rr,ss) = resnorm;  % final resnorm
        params(ss,:) = params0;             % final parameters
        seedmodels{ss} = model;
        seedtransforms{ss} = transform;
      end

      % which seed produced the best results? (min ignores NaN resnorms)
      [~,mnix] = min(results.resnorms(rr,:));
      finalparams = params(mnix,:);

      % use the model and transform belonging to the best seed for all
      % predictions below; trainstimTRANSFORM currently belongs to the last
      % seed, so recompute it if the best seed is a different one.
      model = seedmodels{mnix};
      transform = seedtransforms{mnix};
      if mnix ~= numseeds
        trainstimTRANSFORM = feval(transform,trainstim);
      end
    end

    % record the results
    results.params(rr,:) = finalparams;

    % report
    if ~islinear && ismultipleseeds
      fprintf(' seed %d was best. final estimated parameters are [',mnix); fprintf('%.3f ',finalparams); fprintf('].\n');
    end

    % prepare data and model fits
    traindatatemp = trainS*traindata;
    if islinear
      modelfittemp = trainS*(trainstimavg*finalparams');
    else
      modelfittemp = nanreplace(trainS*feval(model,finalparams,trainstimTRANSFORM),0,2);
    end
    if isempty(testdata)  % handle this case explicitly, just to avoid problems
      results.testdata{rr} = [];
      results.modelpred{rr} = [];
    else
      results.testdata{rr} = testS*testdata;
      if islinear
        results.modelpred{rr} = testS*(mean(teststim,3)*finalparams');
      else
        results.modelpred{rr} = nanreplace(testS*feval(model,finalparams,feval(transform,teststim)),0,2);
      end
    end

    % prepare modelfit
    if wantmodelfit
      if islinear
        results.modelfit{rr} = (mean(allstim,3)*finalparams')';
      else
        results.modelfit{rr} = nanreplace(feval(model,finalparams,feval(transform,allstim)),0,2)';
      end
    else
      results.modelfit{rr} = [];  % if not wanted by user, don't bother computing
    end

    % compute metrics
    results.trainperformance(rr) = feval(opt.metric,modelfittemp,traindatatemp);
    if isempty(results.testdata{rr})  % handle this case explicitly, just to avoid problems
      results.testperformance(rr) = NaN;
    else
      results.testperformance(rr) = feval(opt.metric,results.modelpred{rr},results.testdata{rr});
    end

    % report
    fprintf(' trainperformance is %.2f. testperformance is %.2f.\n', ...
            results.trainperformance(rr),results.testperformance(rr));
  end

  % compute aggregated metrics
  results.testdata = catcell(1,results.testdata);
  results.modelpred = catcell(1,results.modelpred);
  results.modelfit = catcell(1,results.modelfit);
  if isempty(results.testdata)
    results.aggregatedtestperformance = NaN;
  else
    results.aggregatedtestperformance = feval(opt.metric,results.modelpred,results.testdata);
  end

  % report
  fprintf(' aggregatedtestperformance is %.2f.\n',results.aggregatedtestperformance);