% ott_RW_V_base.m
  1. function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_base(startValues, spikeCounts, rewards, timeLocked)
  2. % reward is 0: mal, 1: suc, 2: water
  3. alphaLearn = startValues(1);
  4. slope = startValues(2);
  5. intercept = startValues(3);
  6. rho = startValues(4); % how valuable is maltodextrin on a water -> malto -> sucrose scale
  7. Vinit = (1 + 0 + rho)/3;
  8. water_ind = rewards == 2;
  9. mal_ind = rewards == 0;
  10. rewards(water_ind) = 0;
  11. rewards(mal_ind) = rho; % scale mal between 0 and 1
  12. trials = length(rewards);
  13. V = zeros(trials + 1, 1);
  14. RPE = zeros(trials, 1);
  15. V(1) = Vinit;
  16. % Call learning rule
  17. for t = 1:trials
  18. RPE(t) = rewards(t) - V(t);
  19. V(t + 1) = V(t) + alphaLearn*RPE(t);
  20. end
  21. rateParam = exp(slope*V(1:trials) + intercept);
  22. rateParam(rateParam < 0) = 0.1; % set rate param to zero if it goes below; might consider a better rule in the future
  23. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  24. mean_predictedSpikes = rateParam(timeLocked);
  25. V = V(1:trials);
  26. V = V(timeLocked);
  27. RPE = RPE(timeLocked);
  28. if any(isinf(log(probSpike)))
  29. LH = 1e9;
  30. else
  31. LH = -1 * sum(log(probSpike));
  32. end