% plot_linearRegressions.m


  % Load the fitted-model file. The rest of the script relies on this call
  % placing a struct array named `os` in the workspace.
  fitsFile = fullfile(ottBari2020_root, 'Data', 'Modeling', 'ModelFits', ...
                      'intBlocks_MLEfits.mat');
  load(fitsFile);

  % Keep only the entries whose Blocks field equals 0.
  int_task = [os.Blocks] == 0;
  os = os(int_task);

  % Logical mask: true where the Region string contains 'VP'. Entries that
  % are false are treated as NAc further down in the script.
  VP_mask = contains({os.Region}, 'VP');

  % Plot colors taken from the project palette.
  myColors  = importColors_bb;
  VP_color  = myColors.bluishGreen;
  NAc_color = myColors.vermillion;

  % get relevant behavior models
  % The options shared by both select_RPEmods calls are collected once; the
  % candidate model list differs between the two time periods.
  modelCriterion = 'AIC';
  plotFlag = false;
  models_of_interest_RPE = {'base','curr','mean'};
  models_of_interest_V   = {'base','mean'};
  selectArgs = {'scoreToUse', modelCriterion, 'plotModels_Flag', plotFlag};

  timePeriod = 'RD';
  bm_RD = select_RPEmods(os, timePeriod, selectArgs{:}, ...
                         'particularModel', models_of_interest_RPE);

  timePeriod = 'cue';
  bm_cue = select_RPEmods(os, timePeriod, selectArgs{:}, ...
                          'particularModel', models_of_interest_V);
  %% Choose the time period to analyse and derive the dependent settings
  timePeriod = 'RD'; % RD or cue
  nanPad = 50;                      % number of NaN rows inserted before each neuron's trials
  mod_type = ['mod_' timePeriod];   % field of os(n) that holds this period's model fits

  if strcmp(timePeriod, 'RD')
      bm = bm_RD;
      latent_var = 'RPEs';
      mod2 = 'curr';                % second model class, only defined for 'RD'
  elseif strcmp(timePeriod, 'cue')
      bm = bm_cue;
      latent_var = 'V';
  else
      error('timePeriod not found')
  end
  % VP first
  % Pool trials across VP neurons, separately for each model class
  % (base / mod2 / mean). For every neuron three aligned column vectors are
  % appended: rewards, normalized observed spike counts, and normalized
  % simulated (Poisson-sampled) spike counts. Each neuron's trials are
  % preceded by nanPad NaN rows.
  isRD = strcmp(timePeriod, 'RD');

  rwd_base_VP = [];
  fr_base_VP = [];
  fr_base_pred_VP = [];
  if isRD
      % the mod2 class is only used for the 'RD' period
      rwd_mod2_VP = [];
      fr_mod2_VP = [];
      fr_mod2_pred_VP = [];
  end
  rwd_mean_VP = [];
  fr_mean_VP = [];
  fr_mean_pred_VP = [];

  gapRows = NaN(nanPad, 1);
  for iNeuron = find(VP_mask)       % indices of VP neurons, in ascending order
      obsSpikes = os(iNeuron).(['spikeCount_' timePeriod]);
      trialRwd  = os(iNeuron).rewards(os(iNeuron).timeLocked);

      if bm.mask_base(iNeuron) == 1 % base neuron
          fitInfo   = os(iNeuron).(mod_type).base;
          sgn       = sign(fitInfo.bestParams(2));   % multiplies both observed and simulated responses
          simSpikes = poissrnd(fitInfo.mean_predictedSpikes);
          rwd_base_VP     = [rwd_base_VP;     gapRows; trialRwd];
          fr_base_VP      = [fr_base_VP;      gapRows; sgn*normalize(obsSpikes)];
          fr_base_pred_VP = [fr_base_pred_VP; gapRows; sgn*normalize(simSpikes)];
      elseif isRD && bm.(['mask_' mod2])(iNeuron) == 1 % curr/prev neuron
          fitInfo   = os(iNeuron).(mod_type).(mod2);
          sgn       = sign(fitInfo.bestParams(1));
          simSpikes = poissrnd(fitInfo.mean_predictedSpikes);
          rwd_mod2_VP     = [rwd_mod2_VP;     gapRows; trialRwd];
          fr_mod2_VP      = [fr_mod2_VP;      gapRows; sgn*normalize(obsSpikes)];
          fr_mod2_pred_VP = [fr_mod2_pred_VP; gapRows; sgn*normalize(simSpikes)];
      elseif bm.mask_mean(iNeuron) == 1 % mean neuron
          % simulated counts are Poisson draws around the neuron's own mean
          % count; no sign factor is applied for this class
          simSpikes = poissrnd(mean(obsSpikes), length(obsSpikes), 1);
          rwd_mean_VP     = [rwd_mean_VP;     gapRows; trialRwd];
          fr_mean_VP      = [fr_mean_VP;      gapRows; normalize(obsSpikes)];
          fr_mean_pred_VP = [fr_mean_pred_VP; gapRows; normalize(simSpikes)];
      end
  end

  % Predictor matrix: the pooled reward vector in column 1, followed by the
  % output of generateHistoryMatrix(..., 10) (presumably rewards 1-10 trials
  % back -- confirm against that helper).
  nBack = 10;
  rwdHx_base_VP = [rwd_base_VP generateHistoryMatrix(rwd_base_VP, nBack)];
  rwdHx_mean_VP = [rwd_mean_VP generateHistoryMatrix(rwd_mean_VP, nBack)];

  % Linear regressions on observed and on simulated responses.
  base_mod_VP      = fitlm(rwdHx_base_VP, fr_base_VP);
  mean_mod_VP      = fitlm(rwdHx_mean_VP, fr_mean_VP);
  base_pred_mod_VP = fitlm(rwdHx_base_VP, fr_base_pred_VP);
  mean_pred_mod_VP = fitlm(rwdHx_mean_VP, fr_mean_pred_VP);
  if isRD
      rwdHx_mod2_VP    = [rwd_mod2_VP generateHistoryMatrix(rwd_mod2_VP, nBack)];
      mod2_mod_VP      = fitlm(rwdHx_mod2_VP, fr_mod2_VP);
      mod2_pred_mod_VP = fitlm(rwdHx_mod2_VP, fr_mod2_pred_VP);
  end
  % NAc second
  % Same pooling and regression procedure as for VP, applied to the neurons
  % for which VP_mask is false. Each neuron contributes nanPad NaN rows
  % followed by its trials to the accumulators of its model class.
  isRD = strcmp(timePeriod, 'RD');

  rwd_base_NAc = [];
  fr_base_NAc = [];
  fr_base_pred_NAc = [];
  if isRD
      % the mod2 class is only used for the 'RD' period
      rwd_mod2_NAc = [];
      fr_mod2_NAc = [];
      fr_mod2_pred_NAc = [];
  end
  rwd_mean_NAc = [];
  fr_mean_NAc = [];
  fr_mean_pred_NAc = [];

  gapRows = NaN(nanPad, 1);
  for iNeuron = find(~VP_mask)      % indices of non-VP (NAc) neurons, ascending
      obsSpikes = os(iNeuron).(['spikeCount_' timePeriod]);
      trialRwd  = os(iNeuron).rewards(os(iNeuron).timeLocked);

      if bm.mask_base(iNeuron) == 1 % base neuron
          fitInfo   = os(iNeuron).(mod_type).base;
          sgn       = sign(fitInfo.bestParams(2));   % multiplies both observed and simulated responses
          simSpikes = poissrnd(fitInfo.mean_predictedSpikes);
          rwd_base_NAc     = [rwd_base_NAc;     gapRows; trialRwd];
          fr_base_NAc      = [fr_base_NAc;      gapRows; sgn*normalize(obsSpikes)];
          fr_base_pred_NAc = [fr_base_pred_NAc; gapRows; sgn*normalize(simSpikes)];
      elseif isRD && bm.(['mask_' mod2])(iNeuron) == 1 % curr/prev neuron
          fitInfo   = os(iNeuron).(mod_type).(mod2);
          sgn       = sign(fitInfo.bestParams(1));
          simSpikes = poissrnd(fitInfo.mean_predictedSpikes);
          rwd_mod2_NAc     = [rwd_mod2_NAc;     gapRows; trialRwd];
          fr_mod2_NAc      = [fr_mod2_NAc;      gapRows; sgn*normalize(obsSpikes)];
          fr_mod2_pred_NAc = [fr_mod2_pred_NAc; gapRows; sgn*normalize(simSpikes)];
      elseif bm.mask_mean(iNeuron) == 1 % mean neuron
          % simulated counts are Poisson draws around the neuron's own mean
          % count; no sign factor is applied for this class
          simSpikes = poissrnd(mean(obsSpikes), length(obsSpikes), 1);
          rwd_mean_NAc     = [rwd_mean_NAc;     gapRows; trialRwd];
          fr_mean_NAc      = [fr_mean_NAc;      gapRows; normalize(obsSpikes)];
          fr_mean_pred_NAc = [fr_mean_pred_NAc; gapRows; normalize(simSpikes)];
      end
  end

  % Predictor matrix: the pooled reward vector in column 1, followed by the
  % output of generateHistoryMatrix(..., 10) (presumably rewards 1-10 trials
  % back -- confirm against that helper).
  nBack = 10;
  rwdHx_base_NAc = [rwd_base_NAc generateHistoryMatrix(rwd_base_NAc, nBack)];
  rwdHx_mean_NAc = [rwd_mean_NAc generateHistoryMatrix(rwd_mean_NAc, nBack)];

  % Linear regressions on observed and on simulated responses.
  base_mod_NAc      = fitlm(rwdHx_base_NAc, fr_base_NAc);
  mean_mod_NAc      = fitlm(rwdHx_mean_NAc, fr_mean_NAc);
  base_pred_mod_NAc = fitlm(rwdHx_base_NAc, fr_base_pred_NAc);
  mean_pred_mod_NAc = fitlm(rwdHx_mean_NAc, fr_mean_pred_NAc);
  if isRD
      rwdHx_mod2_NAc    = [rwd_mod2_NAc generateHistoryMatrix(rwd_mod2_NAc, nBack)];
      mod2_mod_NAc      = fitlm(rwdHx_mod2_NAc, fr_mod2_NAc);
      mod2_pred_mod_NAc = fitlm(rwdHx_mod2_NAc, fr_mod2_pred_NAc);
  end
  % figure
  % 2x2 layout: observed data in the left column, simulated neurons in the
  % right column; VP in the top row, NAc in the bottom row.
  h_bg = figure;
  h_VP       = subplot(2, 2, 1); hold on
  h_VP_pred  = subplot(2, 2, 2); hold on
  h_NAc      = subplot(2, 2, 3); hold on
  h_NAc_pred = subplot(2, 2, 4); hold on

  % Arguments shared by every plotRegressionWithCI call. Index 1 of a fitlm
  % model's coefficients is the intercept, so 2:12 selects the 11 reward
  % predictors; the -1 XOffset presumably places them at x = 0..10 (confirm
  % against plotRegressionWithCI).
  isRD      = strcmp(timePeriod, 'RD');
  coefIdx   = 2:12;
  offsetArg = {'XOffset', -1};

  % Slot 2 of each handle array is reserved for the mod2 model, which only
  % exists for the 'RD' period.
  subplot(h_VP)
  title('VP')
  t_VP(1) = plotRegressionWithCI(base_mod_VP, coefIdx, h_VP, offsetArg{:}, 'Color', VP_color);
  if isRD
      t_VP(2) = plotRegressionWithCI(mod2_mod_VP, coefIdx, h_VP, offsetArg{:}, 'Color', myColors.darkGray);
  end
  t_VP(3) = plotRegressionWithCI(mean_mod_VP, coefIdx, h_VP, offsetArg{:}, 'Color', myColors.lightGray);

  subplot(h_NAc)
  title('NAc')
  t_NAc(1) = plotRegressionWithCI(base_mod_NAc, coefIdx, h_NAc, offsetArg{:}, 'Color', NAc_color);
  if isRD
      t_NAc(2) = plotRegressionWithCI(mod2_mod_NAc, coefIdx, h_NAc, offsetArg{:}, 'Color', myColors.darkGray);
  end
  t_NAc(3) = plotRegressionWithCI(mean_mod_NAc, coefIdx, h_NAc, offsetArg{:}, 'Color', myColors.lightGray);

  subplot(h_VP_pred)
  title('VP - simulated neurons')
  t_VP_sim(1) = plotRegressionWithCI(base_pred_mod_VP, coefIdx, h_VP_pred, offsetArg{:}, 'Color', VP_color);
  if isRD
      t_VP_sim(2) = plotRegressionWithCI(mod2_pred_mod_VP, coefIdx, h_VP_pred, offsetArg{:}, 'Color', myColors.darkGray);
  end
  t_VP_sim(3) = plotRegressionWithCI(mean_pred_mod_VP, coefIdx, h_VP_pred, offsetArg{:}, 'Color', myColors.lightGray);

  subplot(h_NAc_pred)
  title('NAc - simulated neurons')
  t_NAc_sim(1) = plotRegressionWithCI(base_pred_mod_NAc, coefIdx, h_NAc_pred, offsetArg{:}, 'Color', NAc_color);
  if isRD
      t_NAc_sim(2) = plotRegressionWithCI(mod2_pred_mod_NAc, coefIdx, h_NAc_pred, offsetArg{:}, 'Color', myColors.darkGray);
  end
  t_NAc_sim(3) = plotRegressionWithCI(mean_pred_mod_NAc, coefIdx, h_NAc_pred, offsetArg{:}, 'Color', myColors.lightGray);

  % Identical cosmetics for all four panels: outward ticks, dashed zero
  % line, fixed axis limits and axis labels.
  allPanels = [h_VP h_NAc h_VP_pred h_NAc_pred];
  for iPanel = 1:numel(allPanels)
      cP = allPanels(iPanel);
      subplot(cP)                   % make this panel current so plot/xlim/ylim act on it
      set(cP,'tickdir','out')
      plot([-0.5 10.5],[0 0],'k--')
      xlim([-0.5 10.5])
      ylim([-0.6 1.1])
      xlabel('Reward n trials back')
      ylabel('Coefficient ($\pm 95\%$ CI)', 'Interpreter', 'latex')
  end

  % Legends. Outside the 'RD' period slot 2 of each handle array was never
  % assigned, so it is removed before the legend is built.
  if isRD
      legendLabels = {'Base', mod2, 'Mean'};
  else
      t_VP(2) = [];
      t_VP_sim(2) = [];
      t_NAc(2) = [];
      t_NAc_sim(2) = [];
      legendLabels = {'Base', 'Mean'};
  end
  legend(t_VP, legendLabels)
  legend(t_VP_sim, legendLabels)
  legend(t_NAc, legendLabels)
  legend(t_NAc_sim, legendLabels)