% ott_RW_V_base.m

  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_base(startValues, spikeCounts, rewards, timeLocked, Vinit)
    % OTT_RW_V_BASE Negative log-likelihood of spike counts under a Rescorla-Wagner value model.
    %
    %   [LH, probSpike, V, mean_predictedSpikes, RPE] = ...
    %       ott_RW_V_base(startValues, spikeCounts, rewards, timeLocked)
    %   [...] = ott_RW_V_base(..., Vinit)
    %
    %   A Rescorla-Wagner rule tracks a value V across all trials:
    %       RPE(t)   = rewards(t) - V(t)
    %       V(t + 1) = V(t) + alphaLearn * RPE(t)
    %   The spike count on trial t is modeled as Poisson with rate
    %       exp(slope * V(t) + intercept)
    %   where V(t) is the value BEFORE that trial's update.
    %
    %   Inputs:
    %     startValues  - [alphaLearn, slope, intercept]
    %     spikeCounts  - observed spike counts, one per trial selected by
    %                    timeLocked (row or column vector)
    %     rewards      - reward on every trial; its length sets the trial count
    %     timeLocked   - mask (or index vector) selecting the trials used in
    %                    the fit, e.g. excluding trials where the animal did
    %                    not lick fast enough
    %     Vinit        - optional initial value V(1); defaults to 0.5
    %
    %   Outputs:
    %     LH                   - negative Poisson log-likelihood summed over the
    %                            selected trials; 1e9 if any trial's
    %                            log-likelihood is not finite (Inf/NaN)
    %     probSpike            - Poisson probability of each observed count
    %     V                    - pre-update value on the selected trials (column)
    %     mean_predictedSpikes - predicted Poisson rate on the selected trials
    %     RPE                  - reward prediction error on the selected trials
    if nargin < 5 || isempty(Vinit)
        Vinit = 0.5;
    end
    alphaLearn = startValues(1);
    slope = startValues(2);
    intercept = startValues(3);

    trials = length(rewards);
    V = zeros(trials + 1, 1);
    RPE = zeros(trials, 1);
    V(1) = Vinit;

    % Rescorla-Wagner learning rule
    for t = 1:trials
        RPE(t) = rewards(t) - V(t);
        V(t + 1) = V(t) + alphaLearn * RPE(t);
    end

    % Drop V(trials + 1): each trial is predicted from its pre-update value.
    V = V(1:trials);
    rateParam = exp(slope * V + intercept);

    % Keep only the trials selected by timeLocked.
    mean_predictedSpikes = rateParam(timeLocked);
    V = V(timeLocked);
    RPE = RPE(timeLocked);

    % Orient the rates like spikeCounts so row or column counts both work.
    lambda = mean_predictedSpikes;
    if numel(lambda) == numel(spikeCounts)
        lambda = reshape(lambda, size(spikeCounts));
    end
    probSpike = poisspdf(spikeCounts, lambda);

    % Expand any scalar operand to the common size poisspdf produced.
    k = spikeCounts + zeros(size(probSpike));
    lambda = lambda + zeros(size(probSpike));

    % Evaluate log P(k | lambda) in log space. log(poisspdf(...)) underflows
    % to -Inf for improbable but valid counts, which would trigger the penalty.
    % Counts that are not nonnegative integers have probability 0 in poisspdf,
    % so they keep logP = -Inf.
    valid = isfinite(k) & (k >= 0) & (k == round(k));
    logP = -Inf(size(k));
    kv = k(valid);
    lv = lambda(valid);
    logP(valid) = kv .* log(lv) - lv - gammaln(kv + 1);
    % For k == 0, log P = -lambda. This avoids 0 * log(0) = NaN when lambda is 0.
    zeroCount = valid & (k == 0);
    logP(zeroCount) = -lambda(zeroCount);

    if ~all(isfinite(logP))
        % Impossible observation or Inf/NaN rate: return a large penalty.
        LH = 1e9;
    else
        LH = -sum(logP);
    end