% File: ott_RW_RPE_threeValue_asymm_cue.m

  1. function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_RPE_threeValue_asymm_cue(startValues, spikeCounts, rewards, timeLocked, cueInfo)
  2. % cued experiment
  3. alphaPPE = startValues(1);
  4. alphaNPE = startValues(2);
  5. slope = startValues(3);
  6. intercept = startValues(4);
  7. V_sucCue = startValues(5);
  8. V_malCue = startValues(6);
  9. Vinit = alphaPPE / (alphaPPE + alphaNPE);
  10. Vsuc = 1;
  11. Vmal = 0;
  12. trials = length(rewards);
  13. V = zeros(trials + 1, 1);
  14. RPE = zeros(trials, 1);
  15. RPEforObs = zeros(trials, 1);
  16. rateParam = zeros(trials, 1);
  17. V(1) = Vinit;
  18. % Call learning rule
  19. for t = 1:trials
  20. if cueInfo(t, 1) == 1 % sucrose cue
  21. RPEforObs(t) = rewards(t) - Vsuc;
  22. V(t + 1) = V(t); % carry forward the value function for the nonpredictive cue
  23. rateParam(t) = exp(slope*RPEforObs(t) + intercept + V_sucCue);
  24. elseif cueInfo(t, 2) == 1 % malto cue
  25. RPEforObs(t) = rewards(t) - Vmal;
  26. V(t + 1) = V(t); % carry forward the value function
  27. rateParam(t) = exp(slope*RPEforObs(t) + intercept + V_malCue);
  28. elseif cueInfo(t, 3) == 1 % nonpredictive cue
  29. RPE(t) = rewards(t) - V(t);
  30. if RPE(t) >= 0
  31. V(t + 1) = V(t) + alphaPPE*RPE(t);
  32. else
  33. V(t + 1) = V(t) + alphaNPE*RPE(t);
  34. end
  35. RPEforObs(t) = RPE(t);
  36. rateParam(t) = exp(slope*RPEforObs(t) + intercept);
  37. else
  38. error('cueInfo is 0 for all columns\n');
  39. end
  40. end
  41. probSpike = poisspdf(spikeCounts, rateParam(timeLocked)); % mask rateParam to exclude trials where the animal didn't lick fast enough
  42. mean_predictedSpikes = rateParam(timeLocked);
  43. V = V(1:trials);
  44. V = V(timeLocked);
  45. RPE = RPE(timeLocked);
  46. if any(isinf(log(probSpike)))
  47. LH = 1e9;
  48. else
  49. LH = -1 * sum(log(probSpike));
  50. end