% plot_latentVariables.m
  % Compare recorded VP spike counts with spike counts simulated from the
  % fitted models, plotted against each model's latent variable.
  clear; clc
  % Load the model fits; after `clear`, this file is the only source of `os`
  % (indexed per neuron below).
  load(fullfile(ottBari2020_root, 'Data', 'Modeling', 'ModelFits', 'threeOutcomes_MLEfits.mat'));
  myColors = importColors_bb;
  VP_color = myColors.bluishGreen;

  % get relevant behavior models
  modelCriterion = 'AIC';
  plotFlag = false;
  models_of_interest_RPE = {'base','curr','mean'};
  models_of_interest_V = {'base','mean'};

  % Both calls use the same option name. The RD call previously passed
  % 'particularModel', which only works if select_RPEmods accepts abbreviated
  % option names. NOTE(review): confirm against select_RPEmods.
  timePeriod = 'RD';
  bm_RD = select_RPEmods(os, timePeriod, 'scoreToUse', modelCriterion, ...
      'plotModels_Flag', plotFlag, 'particularModels', models_of_interest_RPE);
  timePeriod = 'cue';
  bm_cue = select_RPEmods(os, timePeriod, 'scoreToUse', modelCriterion, ...
      'plotModels_Flag', plotFlag, 'particularModels', models_of_interest_V);
  %%
  % Analysis settings: task period and normalization of the latent variable.
  timePeriod = 'RD'; % RD, cue
  normalization = 'zscore'; % none, zscore, minmax
  mod_type = ['mod_' timePeriod];
  switch timePeriod
      case 'RD'
          latent_var = 'RPEs';
          bm = bm_RD;
      case 'cue'
          latent_var = 'V';
          bm = bm_cue;
      otherwise
          error('timePeriod not found')
  end
  % Simulated sessions per neuron. Odd, so that (absent NaNs) the median
  % correlation is one of the simulated values.
  nSim = 501;

  % VP first: for each selected neuron, draw Poisson spike counts from the base
  % model's predicted means and summarize real vs simulated counts.
  all_latent_VP = [];           % latent variable, concatenated across neurons
  norm_fr_real_VP = [];         % sign-flipped z-scored real spike counts
  norm_fr_sim_VP = [];          % same, for the median-correlation simulation
  mean_real_VP = [];
  mean_pred_VP = [];
  corr_spike_count_VP = [];     % per-neuron median real-vs-simulated correlation
  var_real_VP = [];
  var_pred_VP = [];
  trialComparison_pred_VP = []; % becomes a cell: one simulated trace per neuron
  % A for-loop iterates over columns, so force a row of indices. A column
  % would run a single pass with n equal to the whole index vector.
  for n = find(bm.mask_base(:))'
      base_fit = os(n).(mod_type).base;
      % NOTE(review): presumably bestParams(2) is the weight on the latent
      % variable; its sign orients this neuron's tuning curve. A value of
      % exactly 0 zeroes the neuron's normalized counts.
      sign_flip = sign(base_fit.bestParams(2));
      real_spike_count = os(n).(['spikeCount_' timePeriod])';
      % generate nSim predicted spike counts, one simulation per row
      lambda = base_fit.mean_predictedSpikes(:)';
      pred_spike_count = zeros(nSim, numel(lambda));
      for i = 1:nSim
          pred_spike_count(i,:) = poissrnd(lambda);
      end
      tmp_corr = corr(real_spike_count', pred_spike_count');
      % Median correlation, ignoring NaNs. A simulation with constant counts
      % has an undefined correlation and would otherwise make the median NaN.
      median_corr = median(tmp_corr, 'omitnan');
      corr_spike_count_VP = [corr_spike_count_VP median_corr];
      % Keep the simulation whose correlation is closest to the median. It is
      % the exact median whenever that is one of the values; an equality test
      % finds nothing when the median is NaN or averaged.
      [~, median_ind] = min(abs(tmp_corr - median_corr));
      median_sim = pred_spike_count(median_ind, :);
      % get latent variables for plotting
      tmp_latent = base_fit.(latent_var)';
      % normalize
      switch normalization
          case 'none'
              % leave the latent variable unscaled
          case 'zscore'
              tmp_latent = normalize(tmp_latent);
          case 'minmax'
              % divides by the largest absolute value, giving the range [-1, 1]
              norm_const = 1/max(abs(tmp_latent));
              tmp_latent = norm_const*tmp_latent;
          otherwise
              error('normalization not found')
      end
      all_latent_VP = [all_latent_VP tmp_latent];
      % normalize real and predicted spike counts for tuning curves
      norm_fr_real_VP = [norm_fr_real_VP sign_flip*normalize(real_spike_count)];
      norm_fr_sim_VP = [norm_fr_sim_VP sign_flip*normalize(median_sim)];
      % mean spike counts
      mean_real_VP = [mean_real_VP mean(real_spike_count)];
      mean_pred_VP = [mean_pred_VP mean(median_sim)];
      % Variance of spike counts. The prediction must use a simulated spike
      % count (presumably because the predicted means carry no Poisson noise).
      var_real_VP = [var_real_VP var(real_spike_count)];
      var_pred_VP = [var_pred_VP var(median_sim)];
      % save the median simulated spike count
      trialComparison_pred_VP = [trialComparison_pred_VP {median_sim}];
  end
  75. % save the median simulated spike count
  76. trialComparison_pred_VP = [trialComparison_pred_VP {pred_spike_count(median_ind, :)}];
  77. end
  78. nBins = 11;
  79. latent_bins = prctile(all_latent_VP, linspace(0, 100, nBins));
  80. spike_bins_real_VP = arrayfun(@(i, j) norm_fr_real_VP(all_latent_VP >= i & all_latent_VP < j), latent_bins(1:end -1), ...
  81. latent_bins(2:end), 'UniformOutput', false);
  82. spike_bins_sim_VP = arrayfun(@(i, j) norm_fr_sim_VP(all_latent_VP >= i & all_latent_VP < j), latent_bins(1:end -1), ...
  83. latent_bins(2:end), 'UniformOutput', false);
  84. % figure
  85. x_latent_bins = latent_bins(1:end - 1) + diff(latent_bins)/2;
  86. scatterSize = 15;
  87. h_figure = figure;
  88. h_lat_VP = subplot(221); hold on
  89. h_mean = subplot(222); hold on
  90. h_corr = subplot(223); hold on
  91. h_var = subplot(224); hold on
  92. subplot(h_lat_VP)
  93. t_lat_VP(1) = plotFilled(x_latent_bins, spike_bins_real_VP, VP_color, h_lat_VP);
  94. t_lat_VP(2) = plotFilled(x_latent_bins, spike_bins_sim_VP, myColors.blue_bright, h_lat_VP);
  95. subplot(h_mean)
  96. scatter(mean_real_VP, mean_pred_VP, scatterSize, 'filled', 'MarkerFaceColor', VP_color)
  97. maxVal = max([mean_real_VP mean_pred_VP]);
  98. plot([0 maxVal],[0 maxVal],'k--')
  99. xlabel('Real spikes (mean)')
  100. ylabel('Predicted spikes (mean)')
  101. subplot(h_corr)
  102. corr_bins = linspace(-1,1,40);
  103. histogram(corr_spike_count_VP, corr_bins, 'Normalization', 'Probability', 'EdgeColor', 'none', 'FaceColor', VP_color)
  104. xlabel('Correlation')
  105. ylabel('Probability')
  106. xlim([-0.05 1.05])
  107. subplot(h_var)
  108. scatter(var_real_VP, var_pred_VP, scatterSize, 'filled', 'MarkerFaceColor', VP_color)
  109. maxVal = max([var_real_VP var_pred_VP]);
  110. plot([0 maxVal],[0 maxVal],'k--')
  111. xlabel('Real spikes (variance)')
  112. ylabel('Predicted spikes (variance)')
  113. % clean it up
  114. legend(t_lat_VP, {'VP','Predicted'}, 'location', 'best')
  115. for cP = [h_lat_VP h_mean h_corr h_var]
  116. subplot(cP)
  117. set(cP,'tickdir','out')
  118. if cP == h_lat_VP
  119. xlabel(latent_var)
  120. ylabel([timePeriod ' spikes (z-score)'])
  121. if strcmp(latent_var, 'RPEs')
  122. % xlim([-1 1])
  123. else
  124. % xlim([0 1])
  125. end
  126. elseif cP == h_mean
  127. xlim([0 20]); ylim([0 20])
  128. elseif cP == h_corr
  129. elseif cP == h_var
  130. xlim([0 60]); ylim([0 60])
  131. end
  132. end
  %% plot neurons with particular cross correlations; go with median-correlation simulated spike count
  % Pick the VP neuron whose median correlation is closest to the
  % prtile_cutoff-th percentile across neurons. Overlay its real spike counts
  % and its saved median-correlation simulation across trials.
  prtile_cutoff = 80;
  VP_neur = find(bm.mask_base);
  [~, VP_neuron_ind] = min(abs(corr_spike_count_VP - prctile(corr_spike_count_VP, prtile_cutoff)));
  VP_neuron_corr = corr_spike_count_VP(VP_neuron_ind);
  os_VP = os(VP_neur);
  os_VP_neur = os_VP(VP_neuron_ind);
  h_figure = figure;
  h_VP = subplot(1,1,1); hold on
  % Real counts for the analyzed period. This was hard-coded to spikeCount_RD,
  % which mismatched the simulated trace when timePeriod = 'cue'.
  plot(h_VP, os_VP_neur.(['spikeCount_' timePeriod]), 'Color', VP_color, 'linewidth', 2)
  plot(h_VP, trialComparison_pred_VP{VP_neuron_ind}, 'Color', myColors.blue_bright, 'linewidth', 2)
  title(h_VP, sprintf('VP\nnind: %i, corr = %0.2f', VP_neuron_ind, VP_neuron_corr))
  xlabel(h_VP, 'Trials')
  ylabel(h_VP, 'Spike count')
  set(h_VP,'tickdir','out')
  149. ylabel('Spike count')
  150. set(cp,'tickdir','out')
  151. end