% ott_RW_RPE_curr_cue.m (1.5 KB)
  1. function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_RPE_curr_cue(startValues, spikeCounts, rewards, timeLocked, cueInfo)
  2. % cued experiment; weight determines how much V_sucrose and V_mal influence spike counts
  3. % cueInfo: N x 3 where the left - suc, mid - mal, right - none
  4. slope = startValues(1);
  5. intercept = startValues(2);
  6. V_sucCue = startValues(3);
  7. V_malCue = startValues(4);
  8. alphaLearn = 0; % only current reward can be encoded
  9. Vinit = 0.5;
  10. trials = length(rewards);
  11. V = zeros(trials + 1, 1);
  12. RPE = zeros(trials, 1); % RPE for trial-by-trial learning
  13. rateParam = zeros(trials, 1);
  14. V(1) = Vinit;
  15. % Call learning rule
  16. for t = 1:trials
  17. RPE(t) = rewards(t) - V(t);
  18. V(t + 1) = V(t) + alphaLearn*RPE(t);
  19. if cueInfo(t, 1) == 1 % sucrose cue
  20. rateParam(t) = exp(slope*RPE(t) + intercept + V_sucCue);
  21. elseif cueInfo(t, 2) == 1 % malto cue
  22. rateParam(t) = exp(slope*RPE(t) + intercept + V_malCue);
  23. elseif cueInfo(t, 3) == 1 % nonpredictive cue
  24. rateParam(t) = exp(slope*RPE(t) + intercept);
  25. else
  26. error('cueInfo is 0 for all columns\n');
  27. end
  28. end
  29. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  30. mean_predictedSpikes = rateParam(timeLocked);
  31. V = V(1:trials);
  32. V = V(timeLocked);
  33. RPE = RPE(timeLocked);
  34. if any(isinf(log(probSpike)))
  35. LH = 1e9;
  36. else
  37. LH = -1 * sum(log(probSpike));
  38. end