ott_RW_RPE_threeValue_cue.m 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_RPE_threeValue_cue(startValues, spikeCounts, rewards, timeLocked, cueInfo)
  % OTT_RW_RPE_THREEVALUE_CUE  Negative log-likelihood of spike counts under a
  % Rescorla-Wagner RPE model for the cued experiment.
  %
  % Inputs
  %   startValues - parameter vector:
  %                   (1) learning rate, (2) slope on the RPE,
  %                   (3) intercept, (4) additive offset for sucrose-cue trials,
  %                   (5) additive offset for malto-cue trials
  %   spikeCounts - observed spike counts, one per entry selected by timeLocked
  %   rewards     - per-trial reward values; length(rewards) sets the trial count
  %   timeLocked  - index/mask into the per-trial vectors selecting the trials
  %                 that enter the likelihood (trials where the animal didn't
  %                 lick fast enough are excluded)
  %   cueInfo     - trials-by-3 indicator matrix; column 1 = sucrose cue,
  %                 column 2 = malto cue, column 3 = nonpredictive cue.
  %                 Columns are checked in that order and the first one equal
  %                 to 1 determines the trial type.
  %
  % Outputs
  %   LH                   - negative summed log Poisson likelihood, or 1e9 when
  %                          any log-probability is not finite (Inf or NaN)
  %   probSpike            - Poisson probability of each observed spike count
  %   V                    - value at the start of each selected trial
  %   mean_predictedSpikes - Poisson rate for each selected trial
  %   RPE                  - learning RPE for each selected trial (stays 0 on
  %                          predictive-cue trials, where no learning occurs)
  %
  % Raises an error if a trial has none of the three cueInfo columns equal to 1.

  % unpack free parameters
  alphaLearn = startValues(1);
  slope      = startValues(2);
  intercept  = startValues(3);
  V_sucCue   = startValues(4);
  V_malCue   = startValues(5);

  % fixed model constants
  Vinit = 0.5; % starting value for the learned (nonpredictive-cue) value
  Vsuc  = 1;   % fixed expectation on sucrose-cue trials
  Vmal  = 0;   % fixed expectation on malto-cue trials

  trials = length(rewards);
  V = zeros(trials + 1, 1);
  RPE = zeros(trials, 1);
  rateParam = zeros(trials, 1);
  V(1) = Vinit;

  % Call learning rule
  for t = 1:trials
      if cueInfo(t, 1) == 1 % sucrose cue
          obsRPE = rewards(t) - Vsuc;
          cueOffset = V_sucCue;
          V(t + 1) = V(t); % predictive cue: carry the learned value forward unchanged
      elseif cueInfo(t, 2) == 1 % malto cue
          obsRPE = rewards(t) - Vmal;
          cueOffset = V_malCue;
          V(t + 1) = V(t); % predictive cue: carry the learned value forward unchanged
      elseif cueInfo(t, 3) == 1 % nonpredictive cue
          RPE(t) = rewards(t) - V(t);
          V(t + 1) = V(t) + alphaLearn*RPE(t);
          obsRPE = RPE(t);
          cueOffset = 0;
      else
          % identifier + format form so the message is formatted properly
          % (the single-argument form leaves escape sequences such as \n literal)
          error('ott_RW_RPE_threeValue_cue:noCue', ...
              'cueInfo is 0 for all columns (trial %d)', t);
      end
      % log-linear Poisson rate driven by the RPE used for the observation
      rateParam(t) = exp(slope*obsRPE + intercept + cueOffset);
  end

  % mask rateParam to exclude trials where the animal didn't lick fast enough
  mean_predictedSpikes = rateParam(timeLocked);
  probSpike = poisspdf(spikeCounts, mean_predictedSpikes);

  V = V(1:trials); % drop the value computed after the final trial
  V = V(timeLocked);
  RPE = RPE(timeLocked);

  % Penalise any invalid likelihood. NaN is treated like Inf so that the
  % caller never receives a NaN objective value.
  logProb = log(probSpike);
  if any(~isfinite(logProb(:)))
      LH = 1e9;
  else
      LH = -sum(logProb);
  end