% ott_RW_RPE_habit.m
  1. function [LH, probSpike, V, mean_predictedSpikes, RPE, slope_vec] = ott_RW_RPE_habit(startValues, spikeCounts, rewards, timeLocked)
  2. alphaLearn = startValues(1);
  3. slope_max = startValues(2); % slope positive
  4. slope_min = startValues(3);
  5. int = startValues(4); % intercept
  6. Vinit = 0.5;
  7. trials = length(rewards);
  8. V = zeros(trials + 1, 1);
  9. RPE = zeros(trials, 1);
  10. V(1) = Vinit;
  11. % Call learning rule
  12. for t = 1:trials
  13. RPE(t) = rewards(t) - V(t);
  14. V(t + 1) = V(t) + alphaLearn*RPE(t);
  15. end
  16. slope_vec = (slope_min - slope_max)*V(1:trials) + slope_min; % slope modulated by value
  17. % right side (log_mean_spikes - slope_vec*0.5) constrains firing rate to mean firing rate at r = 0.5
  18. rateParam = exp(slope_vec.*rewards + int - slope_vec*0.5); % firing rate modulated by rewards
  19. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  20. mean_predictedSpikes = rateParam(timeLocked);
  21. V = V(1:trials);
  22. V = V(timeLocked);
  23. RPE = RPE(timeLocked);
  24. slope_vec = slope_vec(timeLocked);
  25. if any(isinf(log(probSpike)))
  26. LH = 1e9;
  27. else
  28. LH = -1 * sum(log(probSpike));
  29. end