% ott_RW_V_base_cue.m
  1. function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_base_cue(startValues, spikeCounts, rewards, timeLocked, cueInfo)
  2. % cued experiment
  3. alphaLearn = startValues(1);
  4. slope = startValues(2);
  5. intercept = startValues(3);
  6. V_sucCue = startValues(4);
  7. V_malCue = startValues(5);
  8. Vinit = 0.5;
  9. trials = length(rewards);
  10. V = zeros(trials + 1, 1);
  11. RPE = zeros(trials, 1);
  12. rateParam = zeros(trials, 1);
  13. V(1) = Vinit;
  14. % Call learning rule
  15. for t = 1:trials
  16. RPE(t) = rewards(t) - V(t);
  17. V(t + 1) = V(t) + alphaLearn*RPE(t);
  18. if cueInfo(t, 1) == 1 % sucrose cue
  19. rateParam(t) = exp(slope*V(t) + intercept + V_sucCue);
  20. elseif cueInfo(t, 2) == 1 % malto cue
  21. rateParam(t) = exp(slope*V(t) + intercept + V_malCue);
  22. elseif cueInfo(t, 3) == 1 % nonpredictive cue
  23. rateParam(t) = exp(slope*V(t) + intercept);
  24. else
  25. error('cueInfo is 0 for all columns\n');
  26. end
  27. end
  28. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  29. mean_predictedSpikes = rateParam(timeLocked);
  30. V = V(1:trials);
  31. V = V(timeLocked);
  32. RPE = RPE(timeLocked);
  33. if any(isinf(log(probSpike)))
  34. LH = 1e9;
  35. else
  36. LH = -1 * sum(log(probSpike));
  37. end