% select_RPEmods.m
  1. function bestModStruct = select_RPEmods(os, timePeriod, varargin)
  2. % timePeriod should be 'RD', 'cue', or 'PE'
  3. % scoreToUse should be 'AIC', 'BIC', or 'LR'
  4. % pValue is used for LR tests
  5. p = inputParser();
  6. p.addParameter('scoreToUse', 'AIC')
  7. p.addParameter('plotModels_Flag', true)
  8. p.addParameter('particularModels', '')
  9. p.addParameter('normalizeHistogram_Flag', false)
  10. p.addParameter('pValue',0.05)
  11. p.parse(varargin{:});
  12. bestModStruct = struct();
  13. bestModStruct.all_scores = [];
  14. modToUse = ['mod_' timePeriod];
  15. allMods = fields(os(1).(modToUse))';
  16. if ~isempty(p.Results.particularModels)
  17. allMods = intersect(allMods, p.Results.particularModels, 'stable');
  18. end
  19. switch p.Results.scoreToUse
  20. % for AIC or BIC, grab the pre-computed score
  21. case {'AIC','BIC'}
  22. for i = 1:length(os)
  23. tmp_scores = [];
  24. for curr_mod = allMods
  25. curr_mod = curr_mod{:};
  26. tmp_scores = [tmp_scores; os(i).(modToUse).(curr_mod).(p.Results.scoreToUse)];
  27. end
  28. bestModStruct.all_scores = [bestModStruct.all_scores tmp_scores];
  29. end
  30. [~, bestMod] = min(bestModStruct.all_scores);
  31. case 'LR'
  32. warning('off','econ:lratiotest:RLLExceedsULL') % if LH values are very close, you'll get a ton of warnings; turn off
  33. % for LR test, a touch more involved; start with unrestricted model and go down stepwise
  34. % first find the dof for each model and sort in descending order
  35. LR_dof = [];
  36. for curr_mod = allMods
  37. curr_mod = curr_mod{:};
  38. LR_dof = [LR_dof length(os(1).(modToUse).(curr_mod).bestParams)];
  39. end
  40. [LR_dof, LR_order] = sort(LR_dof,'descend');
  41. allMods = allMods(LR_order); % reorder the models if necessary
  42. bestMod = [];
  43. for i = 1:length(os)
  44. modFound_flag = false;
  45. modInd = 1;
  46. while modFound_flag == false
  47. h = lratiotest(os(i).(modToUse).(allMods{modInd}).LH, ...
  48. os(i).(modToUse).(allMods{modInd + 1}).LH, ...
  49. LR_dof(modInd) - LR_dof(modInd + 1), ...
  50. p.Results.pValue);
  51. if h == 1 % if the improvement is significant, keep it and move on
  52. modFound_flag = true;
  53. bestMod = [bestMod modInd];
  54. else % if the improvement is insigificant, mov down
  55. modInd = modInd + 1; % iterate to the next model
  56. if modInd == length(allMods) % if this is the last model, then it's best
  57. modFound_flag = true;
  58. bestMod = [bestMod modInd];
  59. end
  60. end
  61. end
  62. end
  63. warning('on','econ:lratiotest:RLLExceedsULL') % turn warnings back on
  64. end
  65. bestModStruct.bestMod = bestMod;
  66. bestModStruct.bestMod_name = allMods(bestMod);
  67. for curr_mod = allMods
  68. curr_mod = curr_mod{:};
  69. bestModStruct.(['mask_' curr_mod]) = strcmp(bestModStruct.bestMod_name, curr_mod);
  70. end
  71. if p.Results.plotModels_Flag == true
  72. figure
  73. bins = 0.5:length(allMods) + 0.5;
  74. if p.Results.normalizeHistogram_Flag == true
  75. histogram(bestMod, bins, 'Normalization', 'probability')
  76. else
  77. histogram(bestMod, bins)
  78. end
  79. xlim([min(bins)-0.5 max(bins)+0.5])
  80. set(gca,'tickdir','out', 'xtick', 1:length(allMods) + 0.5,...
  81. 'xticklabel', strrep(allMods, '_', '-'), ...
  82. 'xticklabelrotation',60)
  83. ylabel('Number of neurons')
  84. title(['Lowest (best) ' p.Results.scoreToUse])
  85. end