ott_RW_V_prev.m 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839
  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_prev(startValues, spikeCounts, rewards, timeLocked)
  % OTT_RW_V_PREV  Negative Poisson log-likelihood of spike counts under a
  % Rescorla-Wagner value model whose learning rate is fixed to 1, so the value
  % on each trial equals the (recoded) outcome of the previous trial.
  %
  % Inputs
  %   startValues - [slope, intercept, rho]. The predicted rate on a trial is
  %                 exp(slope*V + intercept); rho is the value assigned to
  %                 maltodextrin (reward code 0).
  %   spikeCounts - spike counts for the trials selected by timeLocked, one
  %                 element per selected trial.
  %   rewards     - per-trial outcome codes. Code 2 (water) is recoded to 0 and
  %                 code 0 (maltodextrin) to rho; any other code is used as its
  %                 value unchanged (presumably 1 = sucrose -- confirm).
  %   timeLocked  - mask (or index vector) selecting the trials that enter the
  %                 likelihood; per the original note, it excludes trials where
  %                 the animal did not lick fast enough.
  %
  % Outputs
  %   LH                   - negative log-likelihood, or 1e9 when it is not
  %                          finite (penalty value for the optimizer).
  %   probSpike            - poisspdf of each selected spike count.
  %   V                    - value on each selected trial, before its outcome.
  %   mean_predictedSpikes - predicted Poisson rate on each selected trial.
  %   RPE                  - reward prediction error on each selected trial.
  slope = startValues(1);
  intercept = startValues(2);
  rho = startValues(3); % how valuable is maltodextrin on a water -> malto -> sucrose scale
  alphaLearn = 1; % value just reflects previous trial outcome
  Vinit = 0.5;    % value before the first trial

  % Recode outcomes onto a 0..1 value scale. Both masks are taken from the
  % original codes, so the two assignments cannot interfere with each other.
  water_ind = rewards == 2;
  mal_ind = rewards == 0;
  rewards(water_ind) = 0;
  rewards(mal_ind) = rho;

  % Rescorla-Wagner learning rule
  trials = numel(rewards);
  V = zeros(trials + 1, 1);
  RPE = zeros(trials, 1);
  V(1) = Vinit;
  for t = 1:trials
      RPE(t) = rewards(t) - V(t);
      V(t + 1) = V(t) + alphaLearn*RPE(t);
  end
  V = V(1:trials); % drop the value that follows the last outcome

  % Poisson rate model, restricted to the trials kept by timeLocked
  logRate = slope*V(timeLocked) + intercept;
  rateParam = exp(logRate);
  k = spikeCounts(:);
  probSpike = poisspdf(k, rateParam);
  mean_predictedSpikes = rateParam;
  V = V(timeLocked);
  RPE = RPE(timeLocked);

  % Log-likelihood evaluated in log space:
  %   log P(k | lambda) = k*log(lambda) - lambda - log(k!)
  % Using the linear predictor logRate directly avoids poisspdf underflowing
  % to 0 for very unlikely counts, which would otherwise trigger the penalty
  % even though the likelihood is finite.
  logLik = k .* logRate - rateParam - gammaln(k + 1);
  % Counts with zero probability (negative or non-integer) stay impossible,
  % exactly as poisspdf treats them.
  logLik(k < 0 | k ~= round(k)) = -Inf;
  % ~isfinite also catches NaN (e.g. from an overflowing rate), not just -Inf.
  if any(~isfinite(logLik))
      LH = 1e9;
  else
      LH = -sum(logLik);
  end