ott_RW_V_prev.m 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_prev(startValues, spikeCounts, rewards, timeLocked)
  2. % cued experiment; weight determines how much V_sucrose and V_mal influence spike counts
  3. % cueInfo: N x 3 where the left - suc, mid - mal, right - none
  4. slope = startValues(1);
  5. intercept = startValues(2);
  6. alphaLearn = 1; % only current reward can be encoded
  7. Vinit = 0.5;
  8. trials = length(rewards);
  9. V = zeros(trials + 1, 1);
  10. RPE = zeros(trials, 1); % RPE for trial-by-trial learning
  11. rateParam = zeros(trials, 1);
  12. V(1) = Vinit;
  13. % Call learning rule
  14. for t = 1:trials
  15. RPE(t) = rewards(t) - V(t);
  16. V(t + 1) = V(t) + alphaLearn*RPE(t);
  17. end
  18. rateParam = exp(slope*V(1:trials) + intercept);
  19. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  20. mean_predictedSpikes = rateParam(timeLocked);
  21. V = V(1:trials);
  22. V = V(timeLocked);
  23. RPE = RPE(timeLocked);
  24. if any(isinf(log(probSpike)))
  25. LH = 1e9;
  26. else
  27. LH = -1 * sum(log(probSpike));
  28. end