% fit_RW_MLE.m
  1. clear; clc
  2. task = 'cue';
  3. dir_MLEfit_ott(task)
  4. ott_data = loadData_ott(task);
  5. nStart = 10;
  6. RNG_val = 1;
  7. % just the included neurons
  8. my_n = ott_data.IncludedNeurons;
  9. ott_data.Ninfo = {ott_data.Ninfo{my_n,1}; ott_data.Ninfo{my_n,2}}';
  10. ott_data.RDHz = ott_data.RDHz(my_n);
  11. ott_data.PEHz = ott_data.PEHz(my_n);
  12. ott_data.CueHz = ott_data.CueHz(my_n);
  13. ott_data.Predictors = ott_data.Predictors(my_n);
  14. ott_data.AllTrials = ott_data.AllTrials(my_n);
  15. ott_data.PredCue = ott_data.PredCue(my_n);
  16. ott_data.PredCueAllTrials = ott_data.PredCueAllTrials(my_n);
  17. ott_data.Day = ott_data.Day(my_n);
  18. ott_data.Rat = ott_data.Rat(my_n);
  19. ott_data.IncludedNeurons = ott_data.IncludedNeurons(my_n);
  20. %%
  21. clear os_temp os
  22. models_of_interest_RPE = {'base','base_asymm','base_cue','base_asymm_cue',...
  23. 'threeValue','threeValue_asymm','threeValue_cue','threeValue_asymm_cue',...
  24. 'curr','curr_cue','mean','mean_cue'};
  25. models_of_interest_V = {'base','base_asymm','base_cue','base_asymm_cue','mean','mean_cue'};
  26. % models_of_interest_RPE = {'base_asymm','base_asymm_cue',...
  27. % 'threeValue_asymm','threeValue_asymm_cue'};
  28. % models_of_interest_V = {'base_asymm','base_asymm_cue'};
  29. all_fits = struct(); % initialize an empty structure
  30. tic
  31. for ind = 1:length(ott_data.AllTrials) % for all neurons
  32. fprintf('n %i of %i. Elapsed time is %0.2f min\n', ind, length(ott_data.AllTrials), toc/60)
  33. os_temp(ind).include = ott_data.IncludedNeurons(ind);
  34. os_temp(ind).Ninfo = ott_data.Ninfo(ind, :);
  35. os_temp(ind).day = ott_data.Day(ind);
  36. os_temp(ind).rat = ott_data.Rat(ind);
  37. os_temp(ind).spikeCount_RD = round(ott_data.RDHz{ind}*1.2); % RD period; 1.2s long
  38. os_temp(ind).spikeCount_cue = round(ott_data.CueHz{ind}*0.75); % cue period; 0.75s long
  39. os_temp(ind).spikeCount_PE = ott_data.PEHz{ind};
  40. os_temp(ind).rewards = ott_data.AllTrials{ind}(:, 1); % 0 mal, 1 suc
  41. os_temp(ind).timeLocked = logical(ott_data.AllTrials{ind}(:, 2)); % trials fast enough to have time-locked responses
  42. os_temp(ind).cueInfo = ott_data.PredCueAllTrials{ind}; % predictive cues logical mask [suc mal none]
  43. % spikeCount is a temporary field
  44. % fit RD
  45. os_temp(ind).spikeCount = os_temp(ind).spikeCount_RD;
  46. ms = helper_RW_RPE(os_temp(ind), 'StartingPoints', nStart, 'RNG', RNG_val, 'ParticularModel', models_of_interest_RPE);
  47. os_temp(ind).mod_RD = ms;
  48. % fit cue
  49. os_temp(ind).spikeCount = os_temp(ind).spikeCount_cue;
  50. ms = helper_RW_V(os_temp(ind), 'StartingPoints', nStart, 'RNG', RNG_val, 'ParticularModel', models_of_interest_V);
  51. os_temp(ind).mod_cue = ms;
  52. % remove spikeCount to avoid future confusion
  53. os(ind) = rmfield(os_temp(ind), 'spikeCount');
  54. end
  55. fprintf('Finished\n')
  56. save_MLEfit_ott(task, os);
  57. %%
  58. for n = 1:length(os)
  59. os(n).mod_RD.base_asymm = os_tmp(n).mod_RD.base_asymm;
  60. os(n).mod_RD.base_asymm_cue = os_tmp(n).mod_RD.base_asymm_cue;
  61. os(n).mod_RD.threeValue_asymm = os_tmp(n).mod_RD.threeValue_asymm;
  62. os(n).mod_RD.threeValue_asymm_cue = os_tmp(n).mod_RD.threeValue_asymm_cue;
  63. os(n).mod_cue.base_asymm = os_tmp(n).mod_cue.base_asymm;
  64. os(n).mod_cue.base_asymm_cue = os_tmp(n).mod_cue.base_asymm_cue;
  65. end