% ott_RW_V_base_asymm_cue.m
  1. function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_base_asymm_cue(startValues, spikeCounts, rewards, timeLocked, cueInfo)
  2. % cued experiment
  3. alphaPPE = startValues(1);
  4. alphaNPE = startValues(2);
  5. slope = startValues(3);
  6. intercept = startValues(4);
  7. V_sucCue = startValues(5);
  8. V_malCue = startValues(6);
  9. Vinit = alphaPPE / (alphaPPE + alphaNPE);
  10. trials = length(rewards);
  11. V = zeros(trials + 1, 1);
  12. RPE = zeros(trials, 1);
  13. rateParam = zeros(trials, 1);
  14. V(1) = Vinit;
  15. % Call learning rule
  16. for t = 1:trials
  17. RPE(t) = rewards(t) - V(t);
  18. if RPE(t) >= 0
  19. V(t + 1) = V(t) + alphaPPE*RPE(t);
  20. else
  21. V(t + 1) = V(t) + alphaNPE*RPE(t);
  22. end
  23. if cueInfo(t, 1) == 1 % sucrose cue
  24. rateParam(t) = exp(slope*V(t) + intercept + V_sucCue);
  25. elseif cueInfo(t, 2) == 1 % malto cue
  26. rateParam(t) = exp(slope*V(t) + intercept + V_malCue);
  27. elseif cueInfo(t, 3) == 1 % nonpredictive cue
  28. rateParam(t) = exp(slope*V(t) + intercept);
  29. else
  30. error('cueInfo is 0 for all columns\n');
  31. end
  32. end
  33. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  34. mean_predictedSpikes = rateParam(timeLocked);
  35. V = V(1:trials);
  36. V = V(timeLocked);
  37. RPE = RPE(timeLocked);
  38. if any(isinf(log(probSpike)))
  39. LH = 1e9;
  40. else
  41. LH = -1 * sum(log(probSpike));
  42. end