% recover_RW_MLE_sim.m (10 KB)


  % Model- and parameter-recovery simulation for the Rescorla-Wagner MLE fits.
  % For each of nNeurons synthetic neurons: draw random learning parameters,
  % simulate a random reward sequence, generate Poisson spike counts under
  % several encoding hypotheses, fit every simulated spike train with
  % helper_RW_RPE (reward-delivery, "RD", neurons) or helper_RW_V (cue
  % neurons), and save the resulting struct arrays os_RPE and os_V.
  % NOTE(review): no rng seed is set, so each run draws a new simulated
  % population; call rng(...) beforehand if reproducibility is needed.
  clear os_RPE_temp os_V_temp os_RPE os_V
  nStart = 10;    % passed to the fit helpers as 'StartingPoints'
  nTrials = 55;   % median number of trials
  nNeurons = 200;
  Vinit = 0.5;    % initial value for both the symmetric and asymmetric learners
  m_RPE = {'base_asymm','base','curr','mean'};  % candidate models for RD neurons
  m_V = {'base_asymm','base','mean'};           % candidate models for cue neurons
  % Suffixes of the spikeCount_<suffix> fields. Each is fit in this order and
  % the fit result is stored in the matching mod_<suffix> field.
  rpeSuffixes = {'RPEasymm','RPE','curr','mean'};
  vSuffixes = {'Vasymm','V','prev','mean'};
  for ind = 1:nNeurons % for all neurons
      fprintf('n %i of %i\n', ind, nNeurons)
      V = NaN(nTrials + 1, 1);
      RPE = NaN(nTrials, 1);
      Vasymm = NaN(nTrials + 1, 1);
      RPEasymm = NaN(nTrials, 1);
      rwd = binornd(1, 0.5, nTrials, 1);  % 0/1 reward per trial, p = 0.5
      rwd_prev = [0; rwd(1:end - 1)];     % previous trial's reward (0 before trial 1)
      alpha = rand; % 0 to 1
      alphaPPE = rand; % 0 to 1
      alphaNPE = rand; % 0 to 1
      slope = rand*3 + 1; % 1 to 4
      int = rand*10 - 5; % -5 to 5
      V(1) = Vinit;
      Vasymm(1) = Vinit;
      for t = 1:nTrials
          % symmetric learner: one learning rate
          RPE(t) = rwd(t) - V(t);
          V(t + 1) = V(t) + alpha*RPE(t);
          % asymmetric learner: alphaPPE for RPE >= 0, alphaNPE otherwise
          RPEasymm(t) = rwd(t) - Vasymm(t);
          if RPEasymm(t) >= 0
              Vasymm(t + 1) = Vasymm(t) + alphaPPE*RPEasymm(t);
          else
              Vasymm(t + 1) = Vasymm(t) + alphaNPE*RPEasymm(t);
          end
      end
      % simulate RD neurons (rate = exp(slope*x + int))
      sp_RPEasymm = poissrnd(exp(slope*RPEasymm + int));
      sp_RPE = poissrnd(exp(slope*RPE + int));
      sp_curr = poissrnd(exp(slope*rwd + int));
      % simulate cue neurons (pre-outcome value of each trial)
      sp_Vasymm = poissrnd(exp(slope*Vasymm(1:nTrials) + int));
      sp_V = poissrnd(exp(slope*V(1:nTrials) + int));
      sp_prev = poissrnd(exp(slope*rwd_prev + int));
      % simulate mean neurons (both for RD and cue): constant rate exp(int)
      sp_mean = poissrnd(exp(int), nTrials, 1);
      % true simulated parameters, stored identically for both neuron types
      params = struct('alpha', alpha, 'alphaPPE', alphaPPE, 'alphaNPE', alphaNPE, ...
          'slope', slope, 'int', int);

      % RD neurons
      os_RPE_temp(ind).params = params;
      os_RPE_temp(ind).rewards = rwd;
      os_RPE_temp(ind).timeLocked = true(size(rwd));
      os_RPE_temp(ind).spikeCount_RPEasymm = sp_RPEasymm;
      os_RPE_temp(ind).spikeCount_RPE = sp_RPE;
      os_RPE_temp(ind).spikeCount_curr = sp_curr;
      os_RPE_temp(ind).spikeCount_mean = sp_mean;
      % spikeCount is a temporary field holding the spike train currently being fit
      for iFit = 1:numel(rpeSuffixes)
          suffix = rpeSuffixes{iFit};
          os_RPE_temp(ind).spikeCount = os_RPE_temp(ind).(['spikeCount_' suffix]);
          ms = helper_RW_RPE(os_RPE_temp(ind), 'StartingPoints', nStart, 'particularModel', m_RPE);
          os_RPE_temp(ind).(['mod_' suffix]) = ms;
      end

      % cue neurons
      os_V_temp(ind).params = params;
      os_V_temp(ind).rewards = rwd;
      os_V_temp(ind).timeLocked = true(size(rwd));
      os_V_temp(ind).spikeCount_Vasymm = sp_Vasymm;
      os_V_temp(ind).spikeCount_V = sp_V;
      os_V_temp(ind).spikeCount_prev = sp_prev;
      os_V_temp(ind).spikeCount_mean = sp_mean;
      % spikeCount is a temporary field holding the spike train currently being fit
      for iFit = 1:numel(vSuffixes)
          suffix = vSuffixes{iFit};
          os_V_temp(ind).spikeCount = os_V_temp(ind).(['spikeCount_' suffix]);
          ms = helper_RW_V(os_V_temp(ind), 'StartingPoints', nStart, 'particularModel', m_V);
          os_V_temp(ind).(['mod_' suffix]) = ms;
      end

      % remove spikeCount to avoid future confusion
      os_RPE(ind) = rmfield(os_RPE_temp(ind), 'spikeCount');
      os_V(ind) = rmfield(os_V_temp(ind), 'spikeCount');
  end
  fprintf('Finished\n')
  save(fullfile(ottBari2020_root, 'Data', 'Modeling', 'ModelFits', ...
      'intBlocks_MLEfits_simulated_offSim.mat'), ...
      'os_RPE', 'os_V');
  %%
  % Reload the saved simulation fits and run select_RPEmods on each simulated
  % spike-train type. Its outputs (structs with mask_<model> fields) are used
  % below as the "recovered model" assignments.
  load(fullfile(ottBari2020_root, 'Data', 'Modeling', 'ModelFits', 'intBlocks_MLEfits_simulated_offSim.mat'))
  scoreToUse = 'AIC';
  plotModels_Flag = true;
  % Candidate sets for selection; 'base_asymm' is not offered here.
  m_RPE = {'base','curr','mean'};
  m_V = {'base','mean'};
  % Shared name-value arguments, kept in the same order each call used.
  selArgs_RPE = {'scoreToUse', scoreToUse, 'plotModels_Flag', plotModels_Flag, ...
      'particularModels', m_RPE};
  selArgs_V = {'particularModels', m_V, 'scoreToUse', scoreToUse, ...
      'plotModels_Flag', plotModels_Flag};
  % RD neurons: spike trains simulated from RPE, current reward, and constant mean
  bm_RPE_aRPE = select_RPEmods(os_RPE, 'RPE', selArgs_RPE{:});
  bm_RPE_acurr = select_RPEmods(os_RPE, 'curr', selArgs_RPE{:});
  bm_RPE_aamean = select_RPEmods(os_RPE, 'mean', selArgs_RPE{:});
  % cue neurons: spike trains simulated from V and constant mean
  bm_V_aV = select_RPEmods(os_V, 'V', selArgs_V{:});
  bm_V_amean = select_RPEmods(os_V, 'mean', selArgs_V{:});
  clear selArgs_RPE selArgs_V
  myColors = importColors_bb;
  %%
  % Model-recovery confusion matrix for one neuron class. Row i is the model
  % the spike trains were simulated from (true model); column j is the
  % fraction of all simulated neurons whose selection mask for model j is set
  % (recovered model).
  model = 'V';  % 'RPE' (reward-delivery neurons) or 'V' (cue neurons)
  switch model
      case 'RPE'
          mod2 = 'curr';
          % selection results for spike trains simulated from base, curr, mean
          bmByTrueModel = {bm_RPE_aRPE, bm_RPE_acurr, bm_RPE_aamean};
          maskFields = {'mask_base', ['mask_' mod2], 'mask_mean'};
          axisLabel = {'Base',mod2,'Mean'};
          nNeurons = length(os_RPE);
      case 'V'
          % Cue neurons are compared over two models only (base, mean), so the
          % matrix is 2x2 and needs exactly two axis labels.
          bmByTrueModel = {bm_V_aV, bm_V_amean};
          maskFields = {'mask_base', 'mask_mean'};
          axisLabel = {'Base','Mean'};
          nNeurons = length(os_V);
      otherwise
          error('model not found')
  end
  nTrue = numel(bmByTrueModel);
  nRecovered = numel(maskFields);
  mat_for_hmap = zeros(nTrue, nRecovered);
  for iTrue = 1:nTrue
      for iRec = 1:nRecovered
          mat_for_hmap(iTrue, iRec) = sum(bmByTrueModel{iTrue}.(maskFields{iRec}))/nNeurons;
      end
  end
  h_heatmap = figure;
  cmap_toUse = cmap_customColors(64, 'whiteBlue');
  [hImage, hText, hTick] = heatmap_AD(mat_for_hmap, axisLabel, axisLabel, '%0.2f', ...
      'Colormap', cmap_toUse, ...
      'ShowAllTicks', true, ...
      'UseFigureColormap', false, ...
      'Colorbar', true, ...
      'FontSize', 10, ...
      'MinColorValue', 0, ...
      'MaxColorValue', 1, ...
      'GridLines', '-');
  xlabel('Recovered model')
  ylabel('True model')
  set(gca,'tickdir','out')
  title(model)
  username = getenv('USERNAME');
  %% recover parameters
  % Parameter recovery for RD neurons. Take the neurons whose 'RPE' spike
  % train has bm_RPE_aRPE.mask_base set, and compare each simulated parameter
  % (params.<name>) with the 'base' model's best-fit value.
  param_struct = struct();
  % RPE neurons
  n_RPE = os_RPE(bm_RPE_aRPE.mask_base);
  % NOTE(review): assumes mod_RPE.base.bestParams is ordered
  % [alpha, slope, intercept]; confirm against helper_RW_RPE.
  paramNames = {'alpha', 'slope', 'int'};
  for iParam = 1:numel(paramNames)
      pName = paramNames{iParam};
      param_struct.RPE.(pName).actual = arrayfun(@(nrn) nrn.params.(pName), n_RPE);
      param_struct.RPE.(pName).recovered = arrayfun(@(nrn) nrn.mod_RPE.base.bestParams(iParam), n_RPE);
  end
  % plot it: histogram of (actual - recovered) for each parameter, with a
  % dashed line at zero error. Values outside binEdges are not drawn.
  binEdges = -1.1:0.2:1.1;
  xLabels = {'$\alpha$ (actual - recovered)', ...
      'slope (actual - recovered)', ...
      'intercept (actual - recovered)'};
  nParams = numel(paramNames);
  h_paramRecovery = figure;
  % fresh handle array so stale handles from an earlier workspace 'h' are never iterated
  h = gobjects(1, nParams);
  for iParam = 1:nParams
      pName = paramNames{iParam};
      h(iParam) = subplot(1, nParams, iParam); hold on
      recErr = param_struct.RPE.(pName).actual - param_struct.RPE.(pName).recovered;
      histogram(recErr, binEdges, 'EdgeColor','none','normalization','probability')
      xlabel(xLabels{iParam},'interpreter','latex')
  end
  for curr_h = h
      subplot(curr_h)
      ylim_range = get(curr_h, 'YLim');
      plot([0 0],ylim_range,'--','Color', myColors.gray)
      ylabel('Probability')
      set(curr_h,'tickdir','out')
  end
  % NOTE(review): saveLoc is not assigned anywhere in this script; it must
  % already exist in the workspace.
  saveFigureIteration_ottBari2019(h_paramRecovery, saveLoc, 'recovery_paramBias','FigureSize','max')