ott_RW_RPE_base_asymm.m 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_RPE_base_asymm(startValues, spikeCounts, rewards, timeLocked)
  % OTT_RW_RPE_BASE_ASYMM Negative log-likelihood of spike counts under an
  % asymmetric Rescorla-Wagner reward-prediction-error (RPE) model.
  %
  %   Trial values are updated with separate learning rates for positive
  %   and negative RPEs. Each trial's RPE is mapped to a Poisson rate via
  %   rate = exp(slope*RPE + intercept). The observed spike counts on the
  %   trials selected by timeLocked are then scored against those rates.
  %
  %   Inputs
  %     startValues - [alphaPPE, alphaNPE, slope, intercept, rho]
  %                     alphaPPE  : learning rate used when RPE >= 0
  %                     alphaNPE  : learning rate used when RPE <  0
  %                     slope, intercept : exponential link parameters
  %                     rho       : value of maltodextrin on a
  %                                 water (0) -> malto -> sucrose (1) scale
  %     spikeCounts - one spike count per trial selected by timeLocked
  %                   (a row or column vector)
  %     rewards     - reward code per trial: 0 = maltodextrin,
  %                   1 = sucrose, 2 = water
  %     timeLocked  - logical mask or index vector selecting the scored
  %                   trials (excludes trials where the animal didn't lick
  %                   fast enough)
  %
  %   Outputs
  %     LH                   - negative log-likelihood, or 1e9 when any
  %                            scored trial has a non-finite (-Inf or NaN)
  %                            log-probability
  %     probSpike            - Poisson probability of each observed count
  %     V                    - value before each scored trial
  %     mean_predictedSpikes - Poisson rate for each scored trial
  %     RPE                  - prediction error for each scored trial
  alphaPPE  = startValues(1);
  alphaNPE  = startValues(2);
  slope     = startValues(3);
  intercept = startValues(4);
  rho       = startValues(5); % how valuable is maltodextrin on a water -> malto -> sucrose scale

  % Initial value: 2/3 weight on the ratio alphaPPE/(alphaPPE+alphaNPE)
  % plus 1/3 weight on rho. If both learning rates are 0 this is NaN.
  % The NaN propagates to a non-finite likelihood, which is penalized below.
  Vinit = 2/3*(alphaPPE / (alphaPPE + alphaNPE)) + 1/3*rho;

  % Recode rewards: water -> 0, maltodextrin -> rho, sucrose stays 1.
  % Both masks are taken before any assignment. Otherwise water recoded
  % to 0 would be re-matched as maltodextrin.
  water_ind = rewards == 2;
  mal_ind   = rewards == 0;
  rewards(water_ind) = 0;
  rewards(mal_ind)   = rho;

  trials = length(rewards);
  V   = zeros(trials + 1, 1);
  RPE = zeros(trials, 1);
  V(1) = Vinit;
  % Asymmetric Rescorla-Wagner update. Non-negative RPEs use alphaPPE;
  % all other RPEs (including NaN) use alphaNPE.
  for t = 1:trials
      RPE(t) = rewards(t) - V(t);
      if RPE(t) >= 0
          V(t + 1) = V(t) + alphaPPE*RPE(t);
      else
          V(t + 1) = V(t) + alphaNPE*RPE(t);
      end
  end

  % Exponential link. The rate is strictly positive by construction (it can
  % at most underflow to exactly 0), so no clamp for negative rates is
  % needed. A zero rate with a nonzero count yields a -Inf log-probability,
  % which is penalized below.
  rateParam = exp(slope*RPE + intercept);
  % Keep only the time-locked trials (a column, since RPE is a column).
  mean_predictedSpikes = rateParam(timeLocked);

  % Score counts as a column so they line up element-wise with the column
  % of rates. A row of counts would otherwise mismatch the rates in size.
  if isvector(spikeCounts)
      spikeCounts = spikeCounts(:);
  end
  probSpike = poisspdf(spikeCounts, mean_predictedSpikes);

  V   = V(1:trials);    % drop the value after the final trial
  V   = V(timeLocked);
  RPE = RPE(timeLocked);

  logProb = log(probSpike);
  if any(~isfinite(logProb))
      % Return a large finite penalty instead of Inf/NaN. Presumably LH is
      % minimized by an optimizer that cannot handle non-finite values.
      LH = 1e9;
  else
      LH = -sum(logProb);
  end