ott_RW_RPE_base_asymm.m 992 B

12345678910111213141516171819202122232425262728293031323334353637383940
  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_RPE_base_asymm(startValues, spikeCounts, rewards, timeLocked)
  % OTT_RW_RPE_BASE_ASYMM  Negative log-likelihood of spike counts under an
  % asymmetric-learning-rate value model with a Poisson readout of the
  % reward prediction error (cued experiment).
  %
  % Inputs
  %   startValues  - parameter vector, read as:
  %                    (1) alphaPPE  : learning rate applied when RPE >= 0
  %                    (2) alphaNPE  : learning rate applied when RPE <  0
  %                    (3) slope     : gain on RPE inside the exponential
  %                    (4) intercept : offset inside the exponential
  %   spikeCounts  - observed spike counts; must be compatible in size with
  %                  the rate vector after it is indexed by timeLocked
  %   rewards      - reward obtained on each trial (one entry per trial)
  %   timeLocked   - index/mask selecting the trials kept for the
  %                  likelihood (per the original note: excludes trials
  %                  where the animal didn't lick fast enough)
  %
  % Outputs
  %   LH                   - negative summed log-probability of spikeCounts,
  %                          or 1e9 if any log-probability is not finite
  %   probSpike            - Poisson probability of each kept spike count
  %   V                    - value at the start of each kept trial
  %   mean_predictedSpikes - Poisson rate on each kept trial
  %   RPE                  - reward prediction error on each kept trial

  alphaPPE  = startValues(1);
  alphaNPE  = startValues(2);
  slope     = startValues(3);
  intercept = startValues(4);

  % Initial value is the fixed ratio of the two learning rates.
  Vinit = alphaPPE / (alphaPPE + alphaNPE);

  trials = length(rewards);
  V   = zeros(trials + 1, 1);   % one extra slot for the post-final-trial value
  RPE = zeros(trials, 1);
  V(1) = Vinit;

  % Learning rule: delta-rule update with sign-dependent learning rate.
  for t = 1:trials
      RPE(t) = rewards(t) - V(t);
      if RPE(t) >= 0
          V(t + 1) = V(t) + alphaPPE * RPE(t);
      else
          V(t + 1) = V(t) + alphaNPE * RPE(t);
      end
  end

  % Log-linear mapping from RPE to Poisson rate, on every trial.
  rateParam = exp(slope * RPE + intercept);

  % Mask to the kept trials once and reuse for both outputs.
  mean_predictedSpikes = rateParam(timeLocked);
  probSpike = poisspdf(spikeCounts, mean_predictedSpikes);

  % Drop the trailing post-update value, then apply the same trial mask.
  V   = V(1:trials);
  V   = V(timeLocked);
  RPE = RPE(timeLocked);

  % Penalise any degenerate likelihood. isfinite rejects NaN as well as
  % +/-Inf, so NaN probabilities (e.g. from NaN parameters) yield the
  % penalty value instead of propagating NaN to the caller.
  logProb = log(probSpike);
  if ~all(isfinite(logProb(:)))
      LH = 1e9;
  else
      LH = -1 * sum(logProb);
  end