% ott_RW_RPE_threeValue_asymm.m (1.5 KB)
  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_RPE_threeValue_asymm(startValues, spikeCounts, rewards, timeLocked, cueInfo)
  % OTT_RW_RPE_THREEVALUE_ASYMM  Negative log-likelihood of spike counts under
  % a Rescorla-Wagner reward-prediction-error (RPE) model with asymmetric
  % learning rates, for the cued experiment.
  %
  % Inputs
  %   startValues  - parameter vector:
  %                    (1) alphaPPE  learning rate applied when RPE >= 0
  %                    (2) alphaNPE  learning rate applied when RPE <  0
  %                    (3) slope     gain in rate = exp(slope*RPE + intercept)
  %                    (4) intercept offset in the same expression
  %   spikeCounts  - observed spike counts, passed to poisspdf together with
  %                  the per-trial rates selected by timeLocked
  %   rewards      - per-trial reward values; its length sets the trial count
  %   timeLocked   - index/mask of the trials that enter the likelihood
  %                  (excludes trials where the animal didn't lick fast enough)
  %   cueInfo      - per-trial cue flags; column 1 == 1 marks a sucrose cue,
  %                  column 2 == 1 a malto cue, column 3 == 1 a nonpredictive cue
  %
  % Outputs
  %   LH                   - negative summed log-likelihood, or 1e9 when any
  %                          log-probability is not finite (-Inf, Inf or NaN)
  %   probSpike            - Poisson probability of each observed spike count
  %   V                    - value held at the start of each selected trial
  %   mean_predictedSpikes - Poisson rate for each selected trial
  %   RPE                  - prediction error on nonpredictive-cue trials
  %                          (zero on all other trials), selected trials only
  %
  % Value is only learned on nonpredictive-cue trials. Sucrose and malto cues
  % use the fixed expectations 1 and 0 and leave the learned value untouched.

  alphaPPE  = startValues(1);
  alphaNPE  = startValues(2);
  slope     = startValues(3);
  intercept = startValues(4);

  % Initial value computed from the two learning rates.
  Vinit = alphaPPE / (alphaPPE + alphaNPE);
  Vsuc  = 1;   % fixed expectation after a sucrose cue
  Vmal  = 0;   % fixed expectation after a malto cue

  trials    = length(rewards);
  V         = zeros(trials + 1, 1);
  RPE       = zeros(trials, 1);   % learning RPE (nonpredictive trials only)
  RPEforObs = zeros(trials, 1);   % RPE that drives the predicted firing rate
  V(1) = Vinit;

  % Learning rule
  for t = 1:trials
      if cueInfo(t, 1) == 1            % sucrose cue
          RPEforObs(t) = rewards(t) - Vsuc;
          V(t + 1) = V(t);             % carry the learned value forward
      elseif cueInfo(t, 2) == 1        % malto cue
          RPEforObs(t) = rewards(t) - Vmal;
          V(t + 1) = V(t);             % carry the learned value forward
      elseif cueInfo(t, 3) == 1        % nonpredictive cue
          RPE(t) = rewards(t) - V(t);
          if RPE(t) >= 0
              V(t + 1) = V(t) + alphaPPE*RPE(t);
          else
              V(t + 1) = V(t) + alphaNPE*RPE(t);
          end
          RPEforObs(t) = RPE(t);
      else
          % No cue flagged on this trial: keep the learned value instead of
          % letting it fall back to the zero pre-fill of V.
          V(t + 1) = V(t);
      end
  end

  rateParam = exp(slope*RPEforObs + intercept);

  % Restrict to the trials selected by timeLocked.
  mean_predictedSpikes = rateParam(timeLocked);
  probSpike = poisspdf(spikeCounts, mean_predictedSpikes);

  V   = V(1:trials);
  V   = V(timeLocked);
  RPE = RPE(timeLocked);

  % Return a large finite penalty whenever the likelihood is degenerate.
  % isfinite also rejects NaN, which isinf alone would let through.
  logProb = log(probSpike);
  if any(~isfinite(logProb))
      LH = 1e9;
  else
      LH = -1 * sum(logProb);
  end