helper_RW_RPE.m 16 KB


  function ms = helper_RW_RPE(os, varargin)
  % helper_RW_RPE Fits neural models to test RW model predictions ~ RPE
  % ms = helper_RW_RPE(os, varargin)
  % INPUTS
  % os: behavioral data structure
  %   .spikeCount: number of spikes in a particular time bin
  %   .rewards: 0 for maltodextrin, 1 for sucrose
  %   .timeLocked: 1 if the animal licked within 2s, 0 if the animal did not (logical)
  %       timeLocked will always be more than or equal to the number of spike trials
  %   .cueInfo: N x 3; left column 1 for suc cue, middle column for mal, right noninformative
  %       (only passed to the *_cue and threeValue* models)
  % varargin (name/value pairs)
  %   StartingPoints: number of random starting points per model (default = 1)
  %   ParticularModel: model name, or cell/string array of model names, to fit
  %       (default = {'base','base_cue','curr','curr_cue','mean','mean_cue'}).
  %       Recognised names: base, base_asymm, base_cue, base_asymm_cue,
  %       threeValue, threeValue_asymm, threeValue_cue, threeValue_asymm_cue,
  %       curr, curr_cue, mean, mean_cue
  %   RNG: seed passed to rng before starting points are drawn
  %       (default = [], i.e. the current global random stream is used as is)
  % OUTPUTS
  % ms: model structure of fits; ms.(modelName) contains
  %   bestParams: parameters of the best of the StartingPoints fits
  %   probSpike, V, mean_predictedSpikes, RPEs: extra outputs of the model's
  %       objective function evaluated at bestParams
  %   paramNames: names of the entries of bestParams
  %   LH: negated minimum of the objective (presumably the log-likelihood)
  %   BIC, AIC, AICc: information criteria computed from LH
      p = inputParser;
      p.addParameter('StartingPoints', 1);
      p.addParameter('ParticularModel', []);
      p.addParameter('RNG', []);
      p.parse(varargin{:});
      if ~isempty(p.Results.RNG)
          rng(p.Results.RNG)
      end
  
      % Initialize models
      if isempty(p.Results.ParticularModel)
          modelNames = {'base', 'base_cue', ...
              'curr', 'curr_cue', ...
              'mean', 'mean_cue'};
      else
          % Accept a single char/string as well as a cell array; a bare char
          % would otherwise be iterated one character at a time.
          modelNames = cellstr(p.Results.ParticularModel);
      end
      % Row orientation so that every model name is visited exactly once.
      modelNames = modelNames(:)';
  
      % Set up optimization problem
      options = optimset('Algorithm', 'interior-point', 'ObjectiveLimit', ...
          -1.000000000e+300, 'TolFun', 1e-15, 'Display', 'off');
      % set boundary conditions ([lower upper])
      alpha_range = [0 1];
      slope_range = [0 20];        % reward sensitivity
      intercept_range = [-20 20];  % baseline spiking
      vsuc_range = [-5 5];         % value of sucrose cue
      vmal_range = [-5 5];         % value of mal cue
  
      nStarts = p.Results.StartingPoints;
      nObs = length(os.spikeCount);
      ms = struct();
  
      for iMod = 1:numel(modelNames)
          currMod = modelNames{iMod};
  
          % Describe the model: parameter names, one [lower upper] row per
          % parameter (same order), its objective and whether it needs cueInfo.
          switch currMod
              case 'base'
                  paramNames = {'alpha','slope','intercept'};
                  ranges = [alpha_range; slope_range; intercept_range];
                  fitFun = @ott_RW_RPE_base;
                  useCue = false;
              case 'base_asymm'
                  paramNames = {'alphaPPE','alphaNPE','slope','intercept'};
                  ranges = [alpha_range; alpha_range; slope_range; intercept_range];
                  fitFun = @ott_RW_RPE_base_asymm;
                  useCue = false;
              case 'base_cue'
                  paramNames = {'alpha','slope','intercept','vsuc','vmal'};
                  ranges = [alpha_range; slope_range; intercept_range; vsuc_range; vmal_range];
                  fitFun = @ott_RW_RPE_base_cue;
                  useCue = true;
              case 'base_asymm_cue'
                  paramNames = {'alphaPPE','alphaNPE','slope','intercept','vsuc','vmal'};
                  ranges = [alpha_range; alpha_range; slope_range; intercept_range; vsuc_range; vmal_range];
                  % NOTE(review): this model used to be fitted with the
                  % 5-parameter ott_RW_RPE_base_cue; confirm that
                  % ott_RW_RPE_base_asymm_cue is on the path.
                  fitFun = @ott_RW_RPE_base_asymm_cue;
                  useCue = true;
              case 'threeValue'
                  paramNames = {'alpha','slope','intercept'};
                  ranges = [alpha_range; slope_range; intercept_range];
                  fitFun = @ott_RW_RPE_threeValue;
                  useCue = true;
              case 'threeValue_asymm'
                  paramNames = {'alphaPPE','alphaNPE','slope','intercept'};
                  ranges = [alpha_range; alpha_range; slope_range; intercept_range];
                  fitFun = @ott_RW_RPE_threeValue_asymm;
                  useCue = true;
              case 'threeValue_cue'
                  paramNames = {'alpha','slope','intercept','vsuc','vmal'};
                  ranges = [alpha_range; slope_range; intercept_range; vsuc_range; vmal_range];
                  fitFun = @ott_RW_RPE_threeValue_cue;
                  useCue = true;
              case 'threeValue_asymm_cue'
                  paramNames = {'alphaPPE','alphaNPE','slope','intercept','vsuc','vmal'};
                  ranges = [alpha_range; alpha_range; slope_range; intercept_range; vsuc_range; vmal_range];
                  fitFun = @ott_RW_RPE_threeValue_asymm_cue;
                  useCue = true;
              case 'curr'
                  paramNames = {'slope','intercept'};
                  ranges = [slope_range; intercept_range];
                  fitFun = @ott_RW_RPE_curr;
                  useCue = false;
              case 'curr_cue'
                  paramNames = {'slope','intercept','vsuc','vmal'};
                  ranges = [slope_range; intercept_range; vsuc_range; vmal_range];
                  fitFun = @ott_RW_RPE_curr_cue;
                  useCue = true;
              case 'mean'
                  % No free parameters: the objective is evaluated once.
                  paramNames = {''};
                  ranges = zeros(0, 2);
                  fitFun = @ott_RW_RPE_mean;
                  useCue = false;
              case 'mean_cue'
                  paramNames = {'vsuc','vmal'};
                  ranges = [vsuc_range; vmal_range];
                  fitFun = @ott_RW_RPE_mean_cue;
                  useCue = true;
              otherwise
                  error('RW model: Model name not found')
          end
          numParam = size(ranges, 1);
  
          if numParam == 0
              % Only the 'mean' model lands here; it takes just the spike counts.
              [LH, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
                  fitFun([], os.spikeCount);
              bestFit = 1;
              ms.(currMod).bestParams = [];
          else
              if useCue
                  dataArgs = {os.spikeCount, os.rewards, os.timeLocked, os.cueInfo};
              else
                  dataArgs = {os.spikeCount, os.rewards, os.timeLocked};
              end
              objective = @(x) fitFun(x, dataArgs{:});
  
              % Uniform random starting points inside the bounds, drawn one
              % parameter column at a time (same draw order as before).
              startValues = zeros(nStarts, numParam);
              for k = 1:numParam
                  startValues(:, k) = rand(nStarts, 1)*diff(ranges(k, :)) + ranges(k, 1);
              end
  
              % Box constraints expressed as A*x <= b: x <= upper and -x <= -lower.
              A = [eye(numParam); -eye(numParam)];
              b = [ranges(:, 2); -ranges(:, 1)];
  
              allParams = zeros(nStarts, numParam);
              LH = zeros(nStarts, 1);
              parfor r = 1:nStarts
                  [allParams(r, :), LH(r)] = ...
                      fmincon(objective, startValues(r, :), A, b, [], [], [], [], [], options);
              end
              [~, bestFit] = min(LH);
              ms.(currMod).bestParams = allParams(bestFit, :);
              [~, ms.(currMod).probSpike, ms.(currMod).V, ms.(currMod).mean_predictedSpikes, ms.(currMod).RPEs] = ...
                  objective(ms.(currMod).bestParams);
          end
  
          ms.(currMod).paramNames = paramNames;
          % The objective is minimised, so its negation is the (log-)likelihood.
          ms.(currMod).LH = -1 * LH(bestFit, :);
          ms.(currMod).BIC = log(nObs)*numParam - 2*ms.(currMod).LH;
          ms.(currMod).AIC = 2*numParam - 2*ms.(currMod).LH;
          ms.(currMod).AICc = ms.(currMod).AIC + (2*numParam^2 + 2*numParam)/(nObs - numParam - 1);
          % ms.(currMod).CIvals = sqrt(diag(inv(hess)))'*1.96;
      end