ott_RW_V_mean_cue.m 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  function [LH, probSpike, V, mean_predictedSpikes, RPE] = ott_RW_V_mean_cue(startValues, spikeCounts, rewards, timeLocked, cueInfo)
    % OTT_RW_V_MEAN_CUE Negative Poisson log-likelihood of spike counts under a
    % cue-value model anchored to the mean firing on nonpredictive-cue trials.
    %
    % The predicted spike count on each trial is
    %     rate = exp(log(mean_nonpredSpikes) + V_cue)
    % where mean_nonpredSpikes is the mean of spikeCounts over time-locked
    % trials whose cueInfo column 3 equals 1, and V_cue is startValues(1) on
    % sucrose-cue trials, startValues(2) on maltodextrin-cue trials and 0 on
    % nonpredictive-cue trials.
    %
    % Inputs:
    %   startValues - [V_sucCue, V_malCue]; additive offsets on the log rate.
    %   spikeCounts - observed counts for the time-locked trials only; must be
    %                 size-compatible with rateParam(timeLocked) for poisspdf.
    %   rewards     - only length(rewards) is used, as the total trial count.
    %   timeLocked  - logical mask (or index vector) over all trials selecting
    %                 the trials that have an entry in spikeCounts.
    %   cueInfo     - one row per trial; a 1 in column 1 / 2 / 3 marks a
    %                 sucrose / maltodextrin / nonpredictive cue. Columns are
    %                 checked in that order, so the first flagged column wins.
    %
    % Outputs:
    %   LH                   - negative log-likelihood, or 1e9 if any observed
    %                          count has probability 0 under the model. LH is
    %                          NaN when time-locked trials exist but none of
    %                          them has a nonpredictive cue (the anchor mean is
    %                          then NaN).
    %   probSpike            - poisspdf of each observed count.
    %   V                    - always NaN; this model does not compute it.
    %   mean_predictedSpikes - predicted rate for each time-locked trial.
    %   RPE                  - always NaN; this model does not compute it.
    %
    % Throws 'ott_RW_V_mean_cue:noCue' if any of the first length(rewards) rows
    % of cueInfo has no 1 in columns 1-3.
    V_sucCue = startValues(1);
    V_malCue = startValues(2);
    trials = length(rewards);

    % Anchor rate: mean count on nonpredictive-cue trials. The cue mask spans
    % all trials, so restrict it to the time-locked ones to line it up with
    % spikeCounts.
    nonpred_trials = cueInfo(:, 3) == 1;
    nonpred_trials = nonpred_trials(timeLocked);
    mean_nonpredSpikes = mean(spikeCounts(nonpred_trials));
    log_mean_nonpredSpikes = log(mean_nonpredSpikes); % exp(log(...)) recovers the mean

    % Every trial must carry at least one cue flag.
    cueFlags = cueInfo(1:trials, 1:3) == 1;
    unlabeled = find(~any(cueFlags, 2), 1);
    if ~isempty(unlabeled)
        % error() only formats its message (and escape sequences) when given
        % extra arguments, so pass the row index as a format argument.
        error('ott_RW_V_mean_cue:noCue', ...
            'cueInfo is 0 for all columns (row %d)', unlabeled);
    end

    % First flagged column wins: sucrose over maltodextrin over nonpredictive.
    isSuc = cueFlags(:, 1);
    isMal = cueFlags(:, 2) & ~isSuc;
    cueValue = zeros(trials, 1); % nonpredictive trials keep a 0 offset
    cueValue(isSuc) = V_sucCue;
    cueValue(isMal) = V_malCue;
    rateParam = exp(log_mean_nonpredSpikes + cueValue);

    % Mask rateParam to the trials where the animal licked fast enough, i.e.
    % the trials that have observed spike counts.
    mean_predictedSpikes = rateParam(timeLocked);
    probSpike = poisspdf(spikeCounts, mean_predictedSpikes);

    % A zero probability gives log = -Inf. Return a large finite penalty
    % instead, presumably so a minimizer still sees a usable value.
    logProb = log(probSpike);
    if any(isinf(logProb))
        LH = 1e9;
    else
        LH = -sum(logProb);
    end

    V = NaN;
    RPE = NaN;