% plot_linearRegression.m

  clear; clc
  % Load the cue-task model fits. Because of the `clear` above, the struct
  % array `os` used throughout this script must come from this .mat file.
  fitsFile = fullfile(ottBari2020_root, 'Data', 'Modeling', 'ModelFits', 'cue_MLEfits.mat');
  load(fitsFile);
  myColors = importColors_bb;
  VP_color = myColors.bluishGreen;

  % Model selection for the 'RD' and 'cue' time periods. select_RPEmods is
  % given the candidate model names and the score to compare (AIC); it
  % presumably returns per-neuron masks of the winning model (mask_<name>),
  % as read later in this script -- confirm in select_RPEmods.
  modelCriterion = 'AIC';
  plotFlag = false;
  m_RPE = {'base','base_cue','curr','curr_cue','mean','mean_cue'};
  m_V = {'base','base_cue','mean','mean_cue'};
  sharedOpts = {'scoreToUse', modelCriterion, 'plotModels_Flag', plotFlag};

  timePeriod = 'RD';
  bm_RD = select_RPEmods(os, timePeriod, sharedOpts{:}, 'particularModel', m_RPE);

  timePeriod = 'cue';
  bm_cue = select_RPEmods(os, timePeriod, sharedOpts{:}, 'particularModel', m_V);
  %%
  % Choose the time period to analyze, the matching model-selection result
  % (bm), and the name of the second model group (mod2).
  timePeriod = 'RD';
  nanPad = 50;                          % NaN rows placed before each neuron's data when neurons are stacked
  mod_type = ['mod_' timePeriod];       % per-neuron field holding that period's model fits
  if strcmp(timePeriod, 'RD')
      bm = bm_RD;
      mod2 = 'curr';
  elseif strcmp(timePeriod, 'cue')
      % NOTE(review): m_V (the 'cue' candidate list) contains no 'prev' or
      % 'prev_cue' model, yet this branch relies on bm.mask_prev and
      % bm.mask_prev_cue -- confirm select_RPEmods still provides them.
      bm = bm_cue;
      mod2 = 'prev';
  else
      error('timePeriod not found')
  end
  % VP first.
  % Each neuron is assigned to at most one group, checked in this order:
  % base model, then mod2 (plain or _cue), then mean (plain or _cue). Three
  % things go into that group's column vectors:
  %   - the neuron's rewards,
  %   - its z-scored spike counts (normalize defaults to z-score),
  %   - z-scored Poisson-simulated counts.
  % Each neuron's data is preceded by nanPad NaN rows.
  appendPadded = @(acc, x) [acc; NaN(nanPad, 1); x];
  mod2MaskField = ['mask_' mod2];
  mod2CueMaskField = ['mask_' mod2 '_cue'];

  [rwd_base_VP, fr_base_VP, fr_base_pred_VP] = deal([]);
  [rwd_mod2_VP, fr_mod2_VP, fr_mod2_pred_VP] = deal([]);
  [rwd_mean_VP, fr_mean_VP, fr_mean_pred_VP] = deal([]);

  for n = 1:length(os)
      spike_count = os(n).(['spikeCount_' timePeriod]);
      rewards = os(n).rewards(os(n).timeLocked);

      if bm.mask_base(n) == 1
          % NOTE(review): sign_flip is assigned but never used afterwards --
          % confirm whether it was meant to flip negatively modulated neurons.
          % NOTE(review): base_cue neurons are not pooled here, although the
          % mod2 and mean groups include their _cue variants -- confirm intent.
          sign_flip = sign(os(n).(mod_type).base.bestParams(2));
          % Simulated counts: Poisson draws around the base model's mean prediction.
          pred_spike_count = poissrnd(os(n).(mod_type).base.mean_predictedSpikes);
          rwd_base_VP = appendPadded(rwd_base_VP, rewards);
          fr_base_VP = appendPadded(fr_base_VP, normalize(spike_count));
          fr_base_pred_VP = appendPadded(fr_base_pred_VP, normalize(pred_spike_count));
      elseif bm.(mod2MaskField)(n) == 1 || bm.(mod2CueMaskField)(n) == 1
          % Simulated counts: Poisson draws around the mod2 model's mean prediction.
          pred_spike_count = poissrnd(os(n).(mod_type).(mod2).mean_predictedSpikes);
          rwd_mod2_VP = appendPadded(rwd_mod2_VP, rewards);
          fr_mod2_VP = appendPadded(fr_mod2_VP, normalize(spike_count));
          fr_mod2_pred_VP = appendPadded(fr_mod2_pred_VP, normalize(pred_spike_count));
      elseif bm.mask_mean(n) == 1 || bm.mask_mean_cue(n) == 1
          % Simulated counts: Poisson draws around the neuron's own mean count.
          pred_spike_count = poissrnd(mean(spike_count), length(spike_count), 1);
          rwd_mean_VP = appendPadded(rwd_mean_VP, rewards);
          fr_mean_VP = appendPadded(fr_mean_VP, normalize(spike_count));
          fr_mean_pred_VP = appendPadded(fr_mean_pred_VP, normalize(pred_spike_count));
      end
  end
  % Design matrices: the reward column followed by generateHistoryMatrix's
  % output for nBack trials (presumably lagged copies of the rewards --
  % TODO confirm in generateHistoryMatrix).
  nBack = 10;
  withHistory = @(r) [r generateHistoryMatrix(r, nBack)];
  rwdHx_base_VP = withHistory(rwd_base_VP);
  rwdHx_mod2_VP = withHistory(rwd_mod2_VP);
  rwdHx_mean_VP = withHistory(rwd_mean_VP);

  % Regress firing on reward history, per group. fitlm ignores observations
  % containing NaN, so the NaN padding rows do not enter the fits.
  % Recorded firing:
  base_mod_VP = fitlm(rwdHx_base_VP, fr_base_VP);
  mod2_mod_VP = fitlm(rwdHx_mod2_VP, fr_mod2_VP);
  mean_mod_VP = fitlm(rwdHx_mean_VP, fr_mean_VP);
  % Simulated firing, against the same design matrices:
  base_pred_mod_VP = fitlm(rwdHx_base_VP, fr_base_pred_VP);
  mod2_pred_mod_VP = fitlm(rwdHx_mod2_VP, fr_mod2_pred_VP);
  mean_pred_mod_VP = fitlm(rwdHx_mean_VP, fr_mean_pred_VP);
  % figure
  % Left panel: regressions fit to the recorded VP firing; right panel: the
  % same regressions fit to the simulated firing. Group colors match across
  % panels: base (VP_color), mod2 (dark gray), mean (light gray).
  h_bg = figure;
  h_VP = subplot(121); hold on
  h_VP_pred = subplot(122); hold on

  % Coefficient indices handed to plotRegressionWithCI. Index 1 of a default
  % fitlm model is the intercept, so 2:12 presumably covers the current-reward
  % column plus the 10 history columns -- keep in sync with the history length.
  coefIdx = 2:12;
  xRange = [-0.5 10.5];
  groupColors = {VP_color, myColors.darkGray, myColors.lightGray};
  groupLabels = {'Base', mod2, 'Mean'};
  obsModels = {base_mod_VP, mod2_mod_VP, mean_mod_VP};
  simModels = {base_pred_mod_VP, mod2_pred_mod_VP, mean_pred_mod_VP};

  subplot(h_VP)
  title(h_VP, 'VP')
  for k = 1:numel(obsModels)
      t_VP(k) = plotRegressionWithCI(obsModels{k}, coefIdx, h_VP, 'XOffset', -1, 'Color', groupColors{k});
  end
  subplot(h_VP_pred)
  title(h_VP_pred, 'VP - simulated neurons')
  for k = 1:numel(simModels)
      t_VP_sim(k) = plotRegressionWithCI(simModels{k}, coefIdx, h_VP_pred, 'XOffset', -1, 'Color', groupColors{k});
  end

  % Shared formatting. Every call names its axes explicitly instead of
  % relying on whichever axes happens to be current.
  for ax = [h_VP h_VP_pred]
      set(ax, 'TickDir', 'out')
      plot(ax, xRange, [0 0], 'k--')   % zero reference line
      xlim(ax, xRange)
      ylim(ax, [-0.6 2.0])
      xlabel(ax, 'Reward n trials back')
      ylabel(ax, 'Coefficient ($\pm 95\%$ CI)', 'Interpreter', 'latex')
  end

  % The current axes is h_VP_pred here, so give each legend its target panel
  % explicitly. The handles returned by plotRegressionWithCI are assumed to be
  % graphics objects usable as a legend subset -- confirm in that helper.
  legend(h_VP, t_VP, groupLabels)
  legend(h_VP_pred, t_VP_sim, groupLabels)