helper_gRW_V.m 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. function ms = helper_RW_V(os, varargin)
  2. % helper_RW_V Fits neural models to test RW model predictions ~ value
  3. % ms = helper_RW_V(os, varargin)
  4. % INPUTS
  5. % os: behavioral data structure
  6. % .spikeCount: number of spikes in a particular time bin
  7. % .rewards: 0 for maltodextrin, 1 for sucrose
  8. % .timeLocked: 1 if the animal licked within 2s, 0 if the animal did not (logical)
  9. % timeLocked will always be more than or equal to the number of spike trials
  10. % varargin
  11. % StartingPoints: determines how many points to optimize from
  12. % ParticularModel: cell array of strings of models to use
  13. % RNG: random number generator seed (default = 1)
  14. % OUTPUTS
  15. % ms: model structure of fits
  16. p = inputParser;
  17. p.addParameter('StartingPoints', 1)
  18. p.addParameter('ParticularModel', []);
  19. p.addParameter('RNG', []);
  20. p.parse(varargin{:});
  21. if ~isempty(p.Results.RNG)
  22. rng(p.Results.RNG)
  23. end
  24. % Initialize models
  25. if isempty(p.Results.ParticularModel)
  26. modelNames = {'base','prev','mean'};
  27. else
  28. modelNames = p.Results.ParticularModel;
  29. end
  30. % Set up optimization problem
  31. options = optimset('Algorithm', 'interior-point','ObjectiveLimit',...
  32. -1.000000000e+300,'TolFun',1e-15, 'Display','off');
  33. % set boundary conditions
  34. alpha_range = [0 1];
  35. slope_range = [0 20]; % reward sensitivity
  36. intercept_range = [-20 20]; % baseline spiking
  37. for currMod = modelNames
  38. currMod = currMod{:};
  39. % initialize output variables
  40. runs = p.Results.StartingPoints;
  41. LH = zeros(runs, 1);
  42. if strcmp(currMod, 'base')
  43. paramNames = {'alpha','slope','intercept'};
  44. startValues = [rand(runs, 1)*diff(alpha_range) + alpha_range(1) ...
  45. rand(runs, 1)*diff(slope_range) + slope_range(1) ...
  46. rand(runs, 1)*diff(intercept_range) + intercept_range(1)];
  47. numParam = size(startValues, 2);
  48. allParams = zeros(runs, numParam);
  49. A=[eye(numParam); -eye(numParam)];
  50. b=[ alpha_range(2); slope_range(2); intercept_range(2);
  51. -alpha_range(1); -slope_range(1); -intercept_range(1)];
  52. parfor r = 1:runs
  53. [allParams(r, :), LH(r, :)] = ...
  54. fmincon(@ott_RW_V_base, startValues(r, :), A, b, [], [], [], [], [], options, os.spikeCount, os.rewards, os.timeLocked);
  55. end
  56. [~, bestFit] = min(LH);
  57. ms.(currMod).bestParams = allParams(bestFit, :);
  58. [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  59. ott_RW_V_base(ms.(currMod).bestParams, os.spikeCount, os.rewards, os.timeLocked);
  60. elseif strcmp(currMod, 'base_asymm')
  61. paramNames = {'alphaPPE','alphaNPE','slope','intercept'};
  62. startValues = [rand(runs, 1)*diff(alpha_range) + alpha_range(1) ...
  63. rand(runs, 1)*diff(alpha_range) + alpha_range(1) ...
  64. rand(runs, 1)*diff(slope_range) + slope_range(1) ...
  65. rand(runs, 1)*diff(intercept_range) + intercept_range(1)];
  66. numParam = size(startValues, 2);
  67. allParams = zeros(runs, numParam);
  68. A=[eye(numParam); -eye(numParam)];
  69. b=[ alpha_range(2); alpha_range(2); slope_range(2); intercept_range(2);
  70. -alpha_range(1); -alpha_range(1); -slope_range(1); -intercept_range(1)];
  71. parfor r = 1:runs
  72. [allParams(r, :), LH(r, :)] = ...
  73. fmincon(@ott_RW_V_base_asymm, startValues(r, :), A, b, [], [], [], [], [], options, os.spikeCount, os.rewards, os.timeLocked);
  74. end
  75. [~, bestFit] = min(LH);
  76. ms.(currMod).bestParams = allParams(bestFit, :);
  77. [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  78. ott_RW_V_base_asymm(ms.(currMod).bestParams, os.spikeCount, os.rewards, os.timeLocked);
  79. elseif strcmp(currMod, 'prev')
  80. paramNames = {'slope','intercept'};
  81. startValues = [rand(runs, 1)*diff(slope_range) + slope_range(1) ...
  82. rand(runs, 1)*diff(intercept_range) + intercept_range(1)];
  83. numParam = size(startValues, 2);
  84. allParams = zeros(runs, numParam);
  85. A=[eye(numParam); -eye(numParam)];
  86. b=[ slope_range(2); intercept_range(2);
  87. -slope_range(1); -intercept_range(1)];
  88. parfor r = 1:runs
  89. [allParams(r, :), LH(r, :)] = ...
  90. fmincon(@ott_RW_V_prev, startValues(r, :), A, b, [], [], [], [], [], options, os.spikeCount, os.rewards, os.timeLocked);
  91. end
  92. [~, bestFit] = min(LH);
  93. ms.(currMod).bestParams = allParams(bestFit, :);
  94. [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  95. ott_RW_V_prev(ms.(currMod).bestParams, os.spikeCount, os.rewards, os.timeLocked);
  96. elseif strcmp(currMod, 'mean')
  97. paramNames = {''};
  98. numParam = 0;
  99. [LH, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  100. ott_RW_V_mean([], os.spikeCount);
  101. bestFit = 1;
  102. ms.(currMod).bestParams = [];
  103. else
  104. error('RW model: Model name not found')
  105. end
  106. ms.(currMod).paramNames = paramNames;
  107. ms.(currMod).LH = -1 * LH(bestFit, :);
  108. ms.(currMod).BIC = log(length(os.spikeCount))*numParam - 2*ms.(currMod).LH;
  109. ms.(currMod).AIC = 2*numParam - 2*ms.(currMod).LH;
  110. ms.(currMod).AICc = ms.(currMod).AIC + (2*numParam^2 + 2*numParam)/(length(os.spikeCount) - numParam - 1);
  111. end