helper_RW_V.m 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. function ms = helper_RW_V(os, varargin)
  2. % helper_RW_V Fits neural models to test RW model predictions ~ value
  3. % ms = helper_RW_V(os, varargin)
  4. % INPUTS
  5. % os: behavioral data structure
  6. % .spikeCount: number of spikes in a particular time bin
  7. % .rewards: 0 for maltodextrin, 1 for sucrose
  8. % .timeLocked: 1 if the animal licked within 2s, 0 if the animal did not (logical)
  9. % timeLocked will always be more than or equal to the number of spike trials
  10. % varargin
  11. % StartingPoints: determines how many points to optimize from
  12. % ParticularModel: cell array of strings of models to use
  13. % RNG: random number generator seed (default = 1)
  14. % OUTPUTS
  15. % ms: model structure of fits
  16. p = inputParser;
  17. p.addParameter('StartingPoints', 1)
  18. p.addParameter('ParticularModel', []);
  19. p.addParameter('RNG', []);
  20. p.parse(varargin{:});
  21. if ~isempty(p.Results.RNG)
  22. rng(p.Results.RNG)
  23. end
  24. % Initialize models
  25. if isempty(p.Results.ParticularModel)
  26. modelNames = {'base','prev','mean'};
  27. else
  28. modelNames = p.Results.ParticularModel;
  29. end
  30. % Set up optimization problem
  31. options = optimset('Algorithm', 'interior-point','ObjectiveLimit',...
  32. -1.000000000e+300,'TolFun',1e-15, 'Display','off');
  33. % set boundary conditions
  34. alpha_range = [0 1];
  35. slope_range = [0 20]; % reward sensitivity
  36. intercept_range = [-20 20]; % baseline spiking
  37. rho_range = [0 1];
  38. for currMod = modelNames
  39. currMod = currMod{:};
  40. % initialize output variables
  41. runs = p.Results.StartingPoints;
  42. LH = zeros(runs, 1);
  43. if strcmp(currMod, 'base')
  44. paramNames = {'alpha','slope','intercept','rho'};
  45. startValues = [rand(runs, 1)*diff(alpha_range) + alpha_range(1) ...
  46. rand(runs, 1)*diff(slope_range) + slope_range(1) ...
  47. rand(runs, 1)*diff(intercept_range) + intercept_range(1) ...
  48. rand(runs, 1)*diff(rho_range) + rho_range(1)];
  49. numParam = size(startValues, 2);
  50. allParams = zeros(runs, numParam);
  51. A=[eye(numParam); -eye(numParam)];
  52. b=[ alpha_range(2); slope_range(2); intercept_range(2); rho_range(2);
  53. -alpha_range(1); -slope_range(1); -intercept_range(1); -rho_range(1)];
  54. parfor r = 1:runs
  55. [allParams(r, :), LH(r, :)] = ...
  56. fmincon(@ott_RW_V_base, startValues(r, :), A, b, [], [], [], [], [], options, os.spikeCount, os.rewards, os.timeLocked);
  57. end
  58. [~, bestFit] = min(LH);
  59. ms.(currMod).bestParams = allParams(bestFit, :);
  60. [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  61. ott_RW_V_base(ms.(currMod).bestParams, os.spikeCount, os.rewards, os.timeLocked);
  62. elseif strcmp(currMod, 'base_asymm')
  63. paramNames = {'alphaPPE','alphaNPE','slope','intercept','rho'};
  64. startValues = [rand(runs, 1)*diff(alpha_range) + alpha_range(1) ...
  65. rand(runs, 1)*diff(alpha_range) + alpha_range(1) ...
  66. rand(runs, 1)*diff(slope_range) + slope_range(1) ...
  67. rand(runs, 1)*diff(intercept_range) + intercept_range(1) ...
  68. rand(runs, 1)*diff(rho_range) + rho_range(1)];
  69. numParam = size(startValues, 2);
  70. allParams = zeros(runs, numParam);
  71. A=[eye(numParam); -eye(numParam)];
  72. b=[ alpha_range(2); alpha_range(2); slope_range(2); intercept_range(2); rho_range(2);
  73. -alpha_range(1); -alpha_range(1); -slope_range(1); -intercept_range(1); -rho_range(1)];
  74. parfor r = 1:runs
  75. [allParams(r, :), LH(r, :)] = ...
  76. fmincon(@ott_RW_V_base_asymm, startValues(r, :), A, b, [], [], [], [], [], options, os.spikeCount, os.rewards, os.timeLocked);
  77. end
  78. [~, bestFit] = min(LH);
  79. ms.(currMod).bestParams = allParams(bestFit, :);
  80. [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  81. ott_RW_V_base_asymm(ms.(currMod).bestParams, os.spikeCount, os.rewards, os.timeLocked);
  82. elseif strcmp(currMod, 'prev')
  83. paramNames = {'slope','intercept','rho'};
  84. startValues = [rand(runs, 1)*diff(slope_range) + slope_range(1) ...
  85. rand(runs, 1)*diff(intercept_range) + intercept_range(1) ...
  86. rand(runs, 1)*diff(rho_range) + rho_range(1)];
  87. numParam = size(startValues, 2);
  88. allParams = zeros(runs, numParam);
  89. A=[eye(numParam); -eye(numParam)];
  90. b=[ slope_range(2); intercept_range(2); rho_range(2);
  91. -slope_range(1); -intercept_range(1); -rho_range(1)];
  92. parfor r = 1:runs
  93. [allParams(r, :), LH(r, :)] = ...
  94. fmincon(@ott_RW_V_prev, startValues(r, :), A, b, [], [], [], [], [], options, os.spikeCount, os.rewards, os.timeLocked);
  95. end
  96. [~, bestFit] = min(LH);
  97. ms.(currMod).bestParams = allParams(bestFit, :);
  98. [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  99. ott_RW_V_prev(ms.(currMod).bestParams, os.spikeCount, os.rewards, os.timeLocked);
  100. elseif strcmp(currMod, 'mean')
  101. paramNames = {''};
  102. numParam = 0;
  103. [LH, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
  104. ott_RW_V_mean([], os.spikeCount);
  105. bestFit = 1;
  106. ms.(currMod).bestParams = [];
  107. else
  108. error('RW model: Model name not found')
  109. end
  110. ms.(currMod).paramNames = paramNames;
  111. ms.(currMod).LH = -1 * LH(bestFit, :);
  112. ms.(currMod).BIC = log(length(os.spikeCount))*numParam - 2*ms.(currMod).LH;
  113. ms.(currMod).AIC = 2*numParam - 2*ms.(currMod).LH;
  114. ms.(currMod).AICc = ms.(currMod).AIC + (2*numParam^2 + 2*numParam)/(length(os.spikeCount) - numParam - 1);
  115. end