% fit_RW_MLE.m
  % Script: fit the listed models to every neuron's reward-delivery (RD) and
  % cue-period spike counts and save the fits.
  %
  % Steps:
  %   1. Load the data set for `task` with loadData_ott.
  %   2. For each neuron, build a struct holding block info, region, spike
  %      counts, rewards and the time-locked trial mask.
  %   3. Fit the RD-period counts with helper_RW_RPE and the cue-period counts
  %      with helper_RW_V, each restricted to the model lists below.
  %   4. Save the per-neuron results with save_MLEfit_ott.
  clear; clc

  task = 'intBlocks';
  dir_MLEfit_ott(task)
  ott_data = loadData_ott(task);

  nStart  = 10;   % passed to the helpers as 'StartingPoints'
  RNG_val = 1;    % passed to the helpers as 'RNG'

  %%
  % Cleared here (in addition to the `clear` above) so that this section can
  % be re-run on its own without stale struct arrays from a previous run.
  clear os_temp os

  models_of_interest_RPE = {'base','base_flipped','base_asymm','base_asymm_flipped','base_rect', ...
      'adapt','habit','habit_asymm','curr','curr_flipped','mean'};
  models_of_interest_V = {'base','base_asymm','mean'};

  % RDHz is indexed linearly below, so numel is the matching element count.
  nNeurons = numel(ott_data.RDHz);

  for ind = 1:nNeurons % for all neurons
      fprintf('n %i of %i\n', ind, nNeurons)

      os_temp(ind).Blocks   = ott_data.Blocks(ind);
      os_temp(ind).Blocks12 = ott_data.Blocks12(ind);
      os_temp(ind).Region   = ott_data.Region{ind};

      % Convert firing rates (Hz) to spike counts by multiplying by the
      % period duration and rounding to the nearest integer.
      os_temp(ind).spikeCount_RD  = round(ott_data.RDHz{ind}*1.2);   % RD period; 1.2s long
      os_temp(ind).spikeCount_cue = round(ott_data.CueHz{ind}*0.75); % cue period; 0.75s long
      % NOTE(review): unlike RD and cue, the PE values are stored as-is
      % (no rounding). PE is not fitted in this script -- confirm whether
      % integer counts are expected downstream.
      os_temp(ind).spikeCount_PE  = ott_data.PEHz{ind};              % PE period; 1.0s long

      os_temp(ind).rewards    = ott_data.AllTrials{ind}(:, 1);          % 0 mal, 1 suc
      os_temp(ind).timeLocked = logical(ott_data.AllTrials{ind}(:, 2)); % trials fast enough to have time-locked responses

      % spikeCount is a temporary field: it is set to whichever period is
      % about to be fitted before each helper call.

      % fit RD
      os_temp(ind).spikeCount = os_temp(ind).spikeCount_RD;
      ms = helper_RW_RPE(os_temp(ind), 'StartingPoints', nStart, 'RNG', RNG_val, ...
          'ParticularModel', models_of_interest_RPE);
      os_temp(ind).mod_RD = ms;

      % fit cue
      os_temp(ind).spikeCount = os_temp(ind).spikeCount_cue;
      ms = helper_RW_V(os_temp(ind), 'StartingPoints', nStart, 'RNG', RNG_val, ...
          'ParticularModel', models_of_interest_V);
      os_temp(ind).mod_cue = ms;

      % remove spikeCount to avoid future confusion
      os(ind) = rmfield(os_temp(ind), 'spikeCount');
  end

  fprintf('Finished\n')
  save_MLEfit_ott(task, os);