% plot_linearRegressions.m

  % plot_linearRegressions: regress z-scored VP spike counts on current and
  % past rewards, with neurons grouped by their best-fitting model.
  clear; clc
  % The MAT-file is expected to provide `os`, the per-neuron struct array used below.
  load(fullfile(ottBari2020_root, 'Data', 'Modeling', 'ModelFits', 'threeOutcomes_MLEfits.mat'));
  myColors = importColors_bb;
  VP_color = myColors.bluishGreen;

  % --- Select the relevant behavior models for each neuron ---
  modelCriterion = 'AIC';
  plotFlag = false;
  models_of_interest_RPE = {'base','curr','mean'};
  % NOTE(review): the cue-period analysis later uses mod2 = 'prev', which is not
  % listed here -- confirm that bm_cue.mask_prev is populated as intended.
  models_of_interest_V = {'base','mean'};

  % Options shared by both calls. The option name is spelled identically in both
  % calls ('particularModels'); previously one call used 'particularModel'.
  selectOpts = {'scoreToUse', modelCriterion, 'plotModels_Flag', plotFlag};

  timePeriod = 'RD';
  bm_RD = select_RPEmods(os, timePeriod, selectOpts{:}, ...
      'particularModels', models_of_interest_RPE);
  timePeriod = 'cue';
  bm_cue = select_RPEmods(os, timePeriod, selectOpts{:}, ...
      'particularModels', models_of_interest_V);
  %% Reward-history regressions: recorded vs. simulated VP neurons
  % Neurons are grouped by best-fitting model (base / mod2 / mean). Within each
  % group, z-scored spike counts of all neurons are stacked, with NaN padding
  % between neurons, and regressed (fitlm) on the current reward plus nHistory
  % previous rewards. The same regression is run on Poisson spike counts drawn
  % from each neuron's model prediction.
  % The padding presumably keeps history lags from spanning two neurons, since
  % fitlm drops NaN rows. TODO confirm against generateHistoryMatrix.
  timePeriod = 'RD';
  nanPad = 50;      % NaN rows inserted before each neuron's data
  nHistory = 10;    % number of previous-trial reward regressors
  rho_mean = 0.8;   % maltodextrin value used for 'mean' neurons (they have no fitted rho)
  mod_type = ['mod_' timePeriod];
  switch timePeriod
      case 'RD'
          bm = bm_RD;
          mod2 = 'curr';
      case 'cue'
          bm = bm_cue;
          mod2 = 'prev';
      otherwise
          error('timePeriod not found')
  end

  % --- Collect recoded rewards and (simulated) activity per model group ---
  rwd_base_VP = [];  fr_base_VP = [];  fr_base_pred_VP = [];
  rwd_mod2_VP = [];  fr_mod2_VP = [];  fr_mod2_pred_VP = [];
  rwd_mean_VP = [];  fr_mean_VP = [];  fr_mean_pred_VP = [];
  pad = NaN(nanPad, 1);
  for n = 1:length(os)
      spike_count = os(n).(['spikeCount_' timePeriod]);
      rewards = os(n).rewards(os(n).timeLocked);
      if bm.mask_base(n) == 1 % base neuron
          group = 'base';
          fitInfo = os(n).(mod_type).base;
          % Sign of a fitted parameter orients the neuron's activity.
          % NOTE(review): presumably the reward-gain parameter -- confirm.
          sign_flip = sign(fitInfo.bestParams(2));
          pred_spike_count = poissrnd(fitInfo.mean_predictedSpikes);
          rho = fitInfo.bestParams(4);
      elseif bm.(['mask_' mod2])(n) == 1 % curr/prev neuron
          group = 'mod2';
          fitInfo = os(n).(mod_type).(mod2);
          sign_flip = sign(fitInfo.bestParams(1));
          pred_spike_count = poissrnd(fitInfo.mean_predictedSpikes);
          rho = fitInfo.bestParams(3);
      elseif bm.mask_mean(n) == 1 % mean neuron
          group = 'mean';
          sign_flip = 1; % mean neurons are not sign-aligned
          % Surrogate: constant-rate Poisson counts at the neuron's mean count.
          pred_spike_count = poissrnd(mean(spike_count), length(spike_count), 1);
          rho = rho_mean;
      else
          continue % neuron not assigned to any model of interest
      end

      % Recode outcomes: water (2) -> 0, maltodextrin (0) -> rho (between 0 and 1),
      % all other values unchanged. Masks are taken before any assignment.
      water_ind = rewards == 2;
      mal_ind = rewards == 0;
      rewards(water_ind) = 0;
      rewards(mal_ind) = rho;

      rwd_n = [pad; rewards];
      fr_n = [pad; sign_flip*normalize(spike_count)];
      fr_pred_n = [pad; sign_flip*normalize(pred_spike_count)];
      switch group
          case 'base'
              rwd_base_VP = [rwd_base_VP; rwd_n];
              fr_base_VP = [fr_base_VP; fr_n];
              fr_base_pred_VP = [fr_base_pred_VP; fr_pred_n];
          case 'mod2'
              rwd_mod2_VP = [rwd_mod2_VP; rwd_n];
              fr_mod2_VP = [fr_mod2_VP; fr_n];
              fr_mod2_pred_VP = [fr_mod2_pred_VP; fr_pred_n];
          case 'mean'
              rwd_mean_VP = [rwd_mean_VP; rwd_n];
              fr_mean_VP = [fr_mean_VP; fr_n];
              fr_mean_pred_VP = [fr_mean_pred_VP; fr_pred_n];
      end
  end

  % --- Regress activity on current reward plus nHistory previous rewards ---
  rwdHx_base_VP = [rwd_base_VP generateHistoryMatrix(rwd_base_VP, nHistory)];
  rwdHx_mod2_VP = [rwd_mod2_VP generateHistoryMatrix(rwd_mod2_VP, nHistory)];
  rwdHx_mean_VP = [rwd_mean_VP generateHistoryMatrix(rwd_mean_VP, nHistory)];
  base_mod_VP = fitlm(rwdHx_base_VP, fr_base_VP);
  mod2_mod_VP = fitlm(rwdHx_mod2_VP, fr_mod2_VP);
  mean_mod_VP = fitlm(rwdHx_mean_VP, fr_mean_VP);
  base_pred_mod_VP = fitlm(rwdHx_base_VP, fr_base_pred_VP);
  mod2_pred_mod_VP = fitlm(rwdHx_mod2_VP, fr_mod2_pred_VP);
  mean_pred_mod_VP = fitlm(rwdHx_mean_VP, fr_mean_pred_VP);

  % --- Figure: coefficients for recorded (left) and simulated (right) neurons ---
  coefIdx = 2:(nHistory + 2);   % coefficient 1 is the intercept; skip it
  xLimits = [-0.5, nHistory + 0.5];
  legendLabels = {'Base', mod2, 'Mean'};
  h_bg = figure;
  h_VP = subplot(121); hold on
  h_VP_pred = subplot(122); hold on

  subplot(h_VP)
  title('VP')
  t_VP(1) = plotRegressionWithCI(base_mod_VP, coefIdx, h_VP, 'XOffset', -1, 'Color', VP_color);
  t_VP(2) = plotRegressionWithCI(mod2_mod_VP, coefIdx, h_VP, 'XOffset', -1, 'Color', myColors.darkGray);
  t_VP(3) = plotRegressionWithCI(mean_mod_VP, coefIdx, h_VP, 'XOffset', -1, 'Color', myColors.lightGray);

  subplot(h_VP_pred)
  title('VP - simulated neurons')
  t_VP_sim(1) = plotRegressionWithCI(base_pred_mod_VP, coefIdx, h_VP_pred, 'XOffset', -1, 'Color', VP_color);
  t_VP_sim(2) = plotRegressionWithCI(mod2_pred_mod_VP, coefIdx, h_VP_pred, 'XOffset', -1, 'Color', myColors.darkGray);
  t_VP_sim(3) = plotRegressionWithCI(mean_pred_mod_VP, coefIdx, h_VP_pred, 'XOffset', -1, 'Color', myColors.lightGray);

  for cP = [h_VP h_VP_pred]
      subplot(cP)
      set(cP,'tickdir','out')
      plot(xLimits, [0 0], 'k--')
      xlim(xLimits)
      ylim([-0.6 2.0])
      xlabel('Reward n trials back')
      ylabel('Coefficient ($\pm 95\%$ CI)', 'Interpreter', 'latex')
  end
  % Name the target axes explicitly so each legend lands on its own subplot.
  legend(h_VP, t_VP, legendLabels)
  legend(h_VP_pred, t_VP_sim, legendLabels)
  %% generate R2-maps
  % Sweep the maltodextrin value rho over [0, 1]. For each rho, refit the
  % reward-history regression for base and mod2 neurons and record the
  % adjusted R^2. Then plot R^2 against rho and mark the best rho per group.
  % Runs on its own given os, bm_RD/bm_cue, VP_color and myColors from the setup cell.
  timePeriod = 'RD';
  nanPad = 50;                    % NaN rows inserted before each neuron's data
  nHistory = 10;                  % number of previous-trial reward regressors
  all_rho = linspace(0,1,100);    % candidate maltodextrin values
  mod_type = ['mod_' timePeriod];
  switch timePeriod
      case 'RD'
          bm = bm_RD;
          mod2 = 'curr';
      case 'cue'
          bm = bm_cue;
          mod2 = 'prev';
      case 'PE'
          % bm_PE is not computed anywhere in this script; fail with a clear message.
          if ~exist('bm_PE', 'var')
              error('bm_PE is not defined; compute it with select_RPEmods before using timePeriod ''PE''')
          end
          bm = bm_PE;
          mod2 = 'prev';
      otherwise
          error('timePeriod not found')
  end

  % --- Collect raw (not yet recoded) rewards and sign-aligned z-scored activity ---
  % Only recorded activity is needed here. Model predictions and the previous
  % cell's accumulators (e.g. fr_mod2_pred_VP) are deliberately not touched.
  rwd_base_VP = [];
  fr_base_VP = [];
  rwd_mod2_VP = [];
  fr_mod2_VP = [];
  pad = NaN(nanPad, 1);
  for n = 1:length(os)
      spike_count = os(n).(['spikeCount_' timePeriod]);
      rewards = os(n).rewards(os(n).timeLocked);
      if bm.mask_base(n) == 1 % base neuron
          sign_flip = sign(os(n).(mod_type).base.bestParams(2));
          rwd_base_VP = [rwd_base_VP; pad; rewards];
          fr_base_VP = [fr_base_VP; pad; sign_flip*normalize(spike_count)];
      elseif bm.(['mask_' mod2])(n) == 1 % curr/prev neuron
          sign_flip = sign(os(n).(mod_type).(mod2).bestParams(1));
          rwd_mod2_VP = [rwd_mod2_VP; pad; rewards];
          fr_mod2_VP = [fr_mod2_VP; pad; sign_flip*normalize(spike_count)];
      end
  end

  % --- Sweep rho ---
  % Outcome masks do not depend on rho, so compute them once.
  % Water (2) -> 0, maltodextrin (0) -> rho; NaN padding matches neither mask.
  water_base = rwd_base_VP == 2;
  mal_base = rwd_base_VP == 0;
  water_mod2 = rwd_mod2_VP == 2;
  mal_mod2 = rwd_mod2_VP == 0;
  nRho = numel(all_rho);
  all_R2_base = NaN(1, nRho);
  all_R2_mod2 = NaN(1, nRho);
  for k = 1:nRho
      rho = all_rho(k);
      rwd_base_VP_new = rwd_base_VP;
      rwd_base_VP_new(water_base) = 0;
      rwd_base_VP_new(mal_base) = rho; % scale mal between 0 and 1
      rwd_mod2_VP_new = rwd_mod2_VP;
      rwd_mod2_VP_new(water_mod2) = 0;
      rwd_mod2_VP_new(mal_mod2) = rho;
      base_mod_VP = fitlm([rwd_base_VP_new generateHistoryMatrix(rwd_base_VP_new, nHistory)], fr_base_VP);
      mod2_mod_VP = fitlm([rwd_mod2_VP_new generateHistoryMatrix(rwd_mod2_VP_new, nHistory)], fr_mod2_VP);
      all_R2_base(k) = base_mod_VP.Rsquared.Adjusted;
      all_R2_mod2(k) = mod2_mod_VP.Rsquared.Adjusted;
  end

  % --- Plot adjusted R^2 vs rho, marking and labelling the best rho per group ---
  h_rho = figure; hold on
  [R2max_base, imax_base] = max(all_R2_base);
  [R2max_mod2, imax_mod2] = max(all_R2_mod2);
  % Label offset scaled to both curves, so it stays below the marker even when R^2 < 0.
  textOffset_y = -max(abs([all_R2_base all_R2_mod2]))/15;
  t = gobjects(1, 2); % fresh handle array; a stale workspace `t` must not reach the legend
  t(1) = plot(all_rho, all_R2_base, 'linewidth', 2, 'Color', VP_color);
  plot(all_rho(imax_base), R2max_base, 'o', 'Color', VP_color);
  text(all_rho(imax_base), R2max_base + textOffset_y, sprintf('%0.2f', all_rho(imax_base)), 'Color', VP_color)
  t(2) = plot(all_rho, all_R2_mod2, 'linewidth', 2, 'Color', myColors.darkGray);
  plot(all_rho(imax_mod2), R2max_mod2, 'o', 'Color', myColors.darkGray);
  text(all_rho(imax_mod2), R2max_mod2 + textOffset_y, sprintf('%0.2f', all_rho(imax_mod2)), 'Color', myColors.darkGray)
  xlabel('$\rho$', 'interpreter', 'latex')
  ylabel('$R^2$', 'interpreter', 'latex')
  set(gca, 'tickdir', 'out')
  legend(t, {'Base',mod2}, 'location', 'best')