helper_RW_V.m 11 KB


  function ms = helper_RW_V(os, varargin)
  % helper_RW_V Fits neural models to test RW model predictions ~ value
  % ms = helper_RW_V(os, varargin)
  % INPUTS
  % os: behavioral data structure
  %   .spikeCount: number of spikes in a particular time bin
  %   .rewards: 0 for maltodextrin, 1 for sucrose
  %   .timeLocked: 1 if the animal licked within 2s, 0 if the animal did not (logical)
  %       timeLocked will always be more than or equal to the number of spike trials
  %   .cueInfo: N x 3; left column 1 for suc cue, middle column for mal, right noninformative
  %       (only read when a *_cue model is fitted)
  % varargin (name/value pairs)
  %   StartingPoints: number of random starting points to optimize from for
  %       each model (positive integer, default = 1)
  %   ParticularModel: model name, or cell array of model names, to fit
  %       (default = {'base','base_cue','prev','prev_cue','mean','mean_cue'});
  %       'base_asymm' and 'base_asymm_cue' are also available
  %   RNG: seed passed to rng before drawing starting points
  %       (default = [], meaning the generator is NOT reseeded)
  % OUTPUTS
  % ms: model structure of fits; ms.(modelName) contains
  %   .bestParams          best parameter vector over all starting points ([] for 'mean')
  %   .probSpike, .V, .mean_predictedSpikes, .RPEs
  %                        outputs 2-5 of the model's objective at bestParams
  %   .paramNames          names of the entries of bestParams
  %   .LH                  minus the objective's first output at bestParams
  %                        (treated as the log-likelihood by BIC/AIC/AICc)
  %   .BIC, .AIC, .AICc    information criteria, using length(os.spikeCount)
  %                        as the number of observations
  p = inputParser;
  p.addParameter('StartingPoints', 1, ...
      @(x) isnumeric(x) && isscalar(x) && x >= 1 && x == round(x));
  p.addParameter('ParticularModel', []);
  p.addParameter('RNG', []);
  p.parse(varargin{:});
  if ~isempty(p.Results.RNG)
      rng(p.Results.RNG)
  end

  % Models to fit; a single name (char or string) is accepted as well as a cell array.
  if isempty(p.Results.ParticularModel)
      modelNames = {'base', 'base_cue', ...
          'prev', 'prev_cue', ...
          'mean', 'mean_cue'};
  else
      modelNames = cellstr(p.Results.ParticularModel);
  end

  % Look every model up before fitting anything, so a misspelled name fails
  % immediately instead of after the preceding models have been optimized.
  specs = cellfun(@getModelSpec, modelNames, 'UniformOutput', false);

  % Set up optimization problem
  options = optimset('Algorithm', 'interior-point', 'ObjectiveLimit', ...
      -1.000000000e+300, 'TolFun', 1e-15, 'Display', 'off');
  runs = p.Results.StartingPoints;
  nObs = length(os.spikeCount);
  ms = struct();

  for k = 1:numel(modelNames)
      currMod = modelNames{k};
      spec = specs{k};
      % Collect only the os fields this model's objective takes, so e.g.
      % os.cueInfo is required only when a *_cue model is requested.
      args = cellfun(@(f) os.(f), spec.argFields, 'UniformOutput', false);
      numParam = size(spec.bounds, 1);

      if numParam == 0
          % No free parameters (the 'mean' model): nothing to optimize, the
          % objective is simply evaluated once with an empty parameter vector.
          bestParams = [];
      else
          bestParams = fitMultiStart(spec.objFun, spec.bounds, runs, options, args);
      end

      ms.(currMod).bestParams = bestParams;
      [negLL, ms.(currMod).probSpike, ms.(currMod).V, ...
          ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
          spec.objFun(bestParams, args{:});

      ms.(currMod).paramNames = spec.paramNames;
      ms.(currMod).LH = -1 * negLL;
      ms.(currMod).BIC = log(nObs)*numParam - 2*ms.(currMod).LH;
      ms.(currMod).AIC = 2*numParam - 2*ms.(currMod).LH;
      ms.(currMod).AICc = ms.(currMod).AIC + ...
          (2*numParam^2 + 2*numParam)/(nObs - numParam - 1);
  end

  function bestParams = fitMultiStart(objFun, bounds, runs, options, args)
  % fitMultiStart Minimize objFun from `runs` uniformly random starting points
  % inside bounds (numParam x 2, one [lower upper] row per parameter) and return
  % the minimizer with the lowest objective value. args are passed to objFun
  % after the parameter vector.
  lb = bounds(:, 1)';
  ub = bounds(:, 2)';
  numParam = numel(lb);
  % rand(runs, numParam) fills column by column, so these starting points are
  % the same as drawing rand(runs, 1) once per parameter, in parameter order.
  startValues = bsxfun(@plus, bsxfun(@times, rand(runs, numParam), ub - lb), lb);
  % Bind the data through a closure instead of fmincon's legacy
  % trailing-argument syntax.
  objective = @(x) objFun(x, args{:});
  allParams = zeros(runs, numParam);
  negLL = zeros(runs, 1);
  parfor r = 1:runs
      % Bounds go in lb/ub rather than as A*x <= b so the interior-point
      % solver keeps every iterate inside them.
      [allParams(r, :), negLL(r)] = fmincon(objective, startValues(r, :), ...
          [], [], [], [], lb, ub, [], options);
  end
  [~, bestFit] = min(negLL);
  bestParams = allParams(bestFit, :);

  function spec = getModelSpec(modelName)
  % getModelSpec Look up the fitting specification for one model name.
  % spec.objFun     handle to the model's objective (its first output is minimized)
  % spec.paramNames names of the free parameters, in order
  % spec.bounds     numParam x 2 matrix of [lower upper] bounds, one row per parameter
  % spec.argFields  os fields passed to objFun after the parameter vector
  % Errors with 'RW model: Model name not found' for an unknown name.
  alpha_range = [0 1];
  slope_range = [0 20];       % reward sensitivity
  intercept_range = [-20 20]; % baseline spiking
  vsuc_range = [-5 5];        % value of sucrose cue
  vmal_range = [-5 5];        % value of mal cue
  trialFields = {'spikeCount', 'rewards', 'timeLocked'};
  cueFields = [trialFields, {'cueInfo'}];
  switch modelName
      case 'base'
          objFun = @ott_RW_V_base;
          paramNames = {'alpha','slope','intercept'};
          bounds = [alpha_range; slope_range; intercept_range];
          argFields = trialFields;
      case 'base_asymm'
          objFun = @ott_RW_V_base_asymm;
          paramNames = {'alphaPPE','alphaNPE','slope','intercept'};
          bounds = [alpha_range; alpha_range; slope_range; intercept_range];
          argFields = trialFields;
      case 'base_cue'
          objFun = @ott_RW_V_base_cue;
          paramNames = {'alpha','slope','intercept','vsuc','vmal'};
          bounds = [alpha_range; slope_range; intercept_range; vsuc_range; vmal_range];
          argFields = cueFields;
      case 'base_asymm_cue'
          objFun = @ott_RW_V_base_asymm_cue;
          paramNames = {'alphaPPE','alphaNPE','slope','intercept','vsuc','vmal'};
          bounds = [alpha_range; alpha_range; slope_range; intercept_range; ...
              vsuc_range; vmal_range];
          argFields = cueFields;
      case 'prev'
          objFun = @ott_RW_V_prev;
          paramNames = {'slope','intercept'};
          bounds = [slope_range; intercept_range];
          argFields = trialFields;
      case 'prev_cue'
          objFun = @ott_RW_V_prev_cue;
          paramNames = {'slope','intercept','vsuc','vmal'};
          bounds = [slope_range; intercept_range; vsuc_range; vmal_range];
          argFields = cueFields;
      case 'mean'
          % No free parameters; the objective only takes the spike counts.
          objFun = @ott_RW_V_mean;
          paramNames = {''};
          bounds = zeros(0, 2);
          argFields = {'spikeCount'};
      case 'mean_cue'
          objFun = @ott_RW_V_mean_cue;
          paramNames = {'vsuc','vmal'};
          bounds = [vsuc_range; vmal_range];
          argFields = cueFields;
      otherwise
          error('RW model: Model name not found')
  end
  spec.objFun = objFun;
  spec.paramNames = paramNames;
  spec.bounds = bounds;
  spec.argFields = argFields;