align_psth.m 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. function [offset,inc_t,x,y]=align_psth(ev, ts,varargin)
  2. % [offset, inc_t,x,y]=align_psth(ev, ts, [key, val])
  3. %
  4. % Inputs:
  5. % ------
  6. % ev A list of the time in each trial (relative to session start) of a reference event (in seconds). E.g. Time of stimulus presentation.
  7. % ts Times (in seconds relative to session start) of spikes (or any point-process of interest)
  8. %
  9. % Optional Inputs (=default value)
  10. %
  11. %
  12. % pre= 3; % Include 3 seconds before the reference event
  13. % post= 3; % Include 3 seconds after the reference event
  14. % binsz= 0.001; % 1 ms bin size
  15. % krn= 0.15; % Use a Normal smoothing kernel of 150 ms
  16. % pre_mask= -inf; % This can be a scalar or vector of times relative to the reference event.
  17. % % Data before this time gets converted to NaN
  18. % post_mask=+inf; % Like pre_mask but to mask end of trial
  19. % max_offset=1; % How far can each trial be shifted from mean PSTH.
  20. % do_plot= false; % Plot? Useful for debugging and checking whether fits work.
  21. % max_iter=50; % Maximum iterations (if var_thres is not reached).
  22. % var_thres=0.05; % If the difference in the variation of the mean PSTH is less than this fraction
  23. % % then end.
  24. % mark_this= []; % A list of times (relative to ev) to mark on the plots. (e.g. a go cue)
  25. % save_plot=''; % if you want to save the process to a eps, add a name here.
  26. %
  27. % Outputs
  28. % -------
  29. % offset A double vector (same length as ev) with the relative time offsets of each trial.
  30. % inc_t A logical vector (same length as ev) which is false
  31. % x The time axis of the PSTH (e.g. -pre:binsz:post)
  32. % y A matrix [ev rows and same columns as x] of the aligned trials.
  33. %
  34. % Example:
  35. %
  36. % ev = (1:30)*5;
  37. % ts = [];
  38. % for i=1:29
  39. % ts = [ts; i*5 + normrnd(1+rand*2,0.4,[100,1])]
  40. % end
  41. % ts=sort(ts)
  42. % [offset,inct, ax,ay] = stats.align_psth(ev,ts,'pre',0,'krn',0.1,'do_plot',true);
  43. pre= 3; % Include 3 seconds before the reference event
  44. post= 3; % Include 3 seconds after the reference event
  45. binsz= 0.001; % 1 ms bin size
  46. krn= 0.15; % Use a Normal smoothing kernel of 150 ms
  47. pre_mask= -inf; % This can be a scalar or vector of times relative to the reference event.
  48. % Data before this time gets converted to NaN
  49. post_mask=+inf; % Like pre_mask but to mask end of trial
  50. max_offset=1; % How far can each trial be shifted from mean PSTH.
  51. do_plot= false; % Plot? Useful for debugging and checking whether fits work.
  52. max_iter=50; % Maximum iterations (if var_thres is not reached).
  53. var_thres=0.05; % If the difference in the variation of the mean PSTH is less than this fraction
  54. % then end.
  55. mark_this= []; % A list of times (relative to ev) to mark on the plots. (e.g. a go cue)
  56. save_plot=''; % if you want to save the process to a eps, add a name here.
  57. utils.overridedefaults(who,varargin);
  58. if isscalar(pre_mask)
  59. pre_mask=zeros(1,numel(ev))+pre_mask;
  60. elseif numel(pre_mask)~=numel(ev)
  61. fprintf(1,'numel(pre_mask) must equal num ref events or be scalar');
  62. return;
  63. end
  64. if isscalar(post_mask)
  65. post_mask=zeros(1,numel(ev))+post_mask;
  66. elseif numel(post_mask)~=numel(ev)
  67. fprintf(1,'numel(post_mask) must equal num ref events or be scalar');
  68. return;
  69. end
  70. if isscalar(krn)
  71. % If krn is scalar then create a smoothing kernel that is Normal with S.D. krn.
  72. dx=ceil(5*krn);
  73. kx=-dx:binsz:dx;
  74. krn=normpdf(kx,0, krn);
  75. % krn(1:(find(kx==0)-1))=0;
  76. krn=krn/sum(krn);
  77. end
  78. old_var=10e10;
  79. done=0;
  80. thres=50;
  81. offset=zeros(size(ev));
  82. if do_plot
  83. clf;ax=axes('Position',[0.1 0.1 0.2 0.2]);
  84. hold on;
  85. end
  86. cnt=1;
  87. inc_t=ones(size(ev))==1;
  88. inc_t(isnan(ev))=false;
  89. %% Calculate the mean and ci of the
  90. while ~done
  91. [y,x]=stats.spike_filter(ev+offset,ts,krn,'pre',pre,'post',post,'kernel_bin_size',binsz);
  92. % xcorr doesn't handle nans well, i think. so this was commented out. -jce
  93. % [y x]=maskraster(x,y,pre_mask(ref),post_mask(ref));
  94. ymn = nanmean(y(inc_t,:));
  95. yst = stats.nanstderr(y(inc_t,:));
  96. if cnt==1
  97. maxy=2*max(ymn); % this is used to set the ylim for the plot
  98. end
  99. if do_plot
  100. cla
  101. imagesc(x,[1 maxy],y(inc_t,:),'Parent',ax);
  102. colormap('hot')
  103. hold(ax,'on')
  104. % % This code plotted the average PSTH on top of the heat map. But was a bit distracting.
  105. % plot(ax,x,ymn-yst,x,ymn+yst,'Color',[1-0.2*cnt, 0 ,0]);
  106. % plot(ax,x,ymn-yst,x,ymn+yst,'Color',[0.2 0.2 0.9],'LineWidth',2);
  107. xlim([-pre post])
  108. ylim([1 maxy])
  109. xlabel('Time (s)')
  110. % ylabel('Spike/sec')
  111. set(ax,'YTick',[]);
  112. yss=linspace(1,maxy,sum(inc_t==1));
  113. plot(ax,-offset(inc_t), yss,'c+');
  114. if ~isempty(mark_this)
  115. plot(ax,-offset(inc_t)+mark_this(inc_t), yss, 'gx');
  116. end
  117. if cnt==1
  118. cbh=colorbar('peer',ax(1),'East');
  119. set(cbh,'Position',get(cbh,'Position')-[0.2 0 0 0])
  120. end
  121. drawnow
  122. if ~isempty(save_plot)
  123. saveas(gcf,sprintf('ap_%s_%d.eps',save_plot,cnt),'epsc2');
  124. end
  125. if cnt==1
  126. set(cbh,'Visible','off')
  127. end
  128. end
  129. for tx=1:numel(ev);
  130. if inc_t(tx)
  131. [xcy,xcx]=xcorr(y(tx,:)-mean(y(tx,:)),ymn-mean(ymn));
  132. [v,peakx]=max(xcy);
  133. offset(tx)=offset(tx)+xcx(peakx)*binsz;
  134. if abs(offset(tx))>max_offset
  135. inc_t(tx)=false;
  136. end
  137. end
  138. end
  139. new_var=sum(nanvar(y));
  140. var_diff=(old_var-new_var)/old_var;
  141. if do_plot
  142. fprintf('Variance improved by %2.3g %% of total variance\n',100*var_diff);
  143. end
  144. old_var=new_var;
  145. cnt=cnt+1;
  146. if abs(var_diff)<var_thres || cnt>max_iter
  147. done=true;
  148. end
  149. end