% File: kernel_regress_spikes.m

  function [kernel, event_weights, S] = kernel_regress_spikes(event_ts, spktimes, varargin)
  % [kernel, event_weights, S] = kernel_regress_spikes(event_ts, spktimes, [krn_pre, krn_post, krn_bin_size])
  %
  % Inputs:
  %   event_ts      An N x M matrix, where N is the number of trials and M is
  %                 the number of distinct events in each trial. Each element
  %                 is the time of the event relative to the beginning of the
  %                 session.
  %   spktimes      The times (in seconds) of the spikes of a cell relative to
  %                 the start of the session.
  %
  % Optional inputs (parsed from varargin by utils.inputordefault; presumably
  % given as 'name', value pairs -- confirm against that helper):
  %   krn_pre       [0]   Either a scalar or a vector of length M. The
  %                       beginning of the kernel relative to the event time.
  %   krn_post      [2]   Either a scalar or a vector of length M. The end of
  %                       the kernel relative to the event time.
  %   krn_bin_size  [0.2] The resolution of the kernel. The kernel is sampled
  %                       at krn_pre:krn_bin_size:krn_post.
  %
  % Outputs:
  %   kernel, event_weights, S
  %                 NOTE(review): the code that computes these is not present
  %                 in this file as it stands; document them once it is.
  %
  % Calling the function with no arguments runs a self-test: it simulates
  % spikes from two known kernels with random per-trial weights, calls this
  % function on the simulated data and plots true vs. estimated weights and
  % kernels. The self-test needs gampdf (Statistics and Machine Learning
  % Toolbox) and the project packages stats.* and draw.*.

  if nargin==0
      fprintf(1,'Running test code.\n');
      krn_bin_size = 0.1;

      % Ground-truth kernels: gamma-shaped bumps sampled at krn_bin_size.
      kernel = {gampdf(0:krn_bin_size:2,3,0.1), gampdf(0:krn_bin_size:6,5,0.4)};
      % Scale each bump to a fixed peak (7 and 2) and prepend an equal-length
      % run of zeros.
      kernel{1} = [zeros(size(kernel{1})), kernel{1}./max(kernel{1})*7];
      kernel{2} = [zeros(size(kernel{2})), kernel{2}./max(kernel{2})*2];

      % 100 trials, two events per trial: the trial start and a second event
      % 0.1-1.1 s later.
      n_trials = 100;
      trial_starts = linspace(0,500, n_trials )';
      event_ts = [trial_starts trial_starts + rand(size(trial_starts))+0.1];
      % Random per-trial, per-event weights drawn from {-3, 0, 3, ..., 15}.
      event_weights = randi([-1 5], n_trials, numel(kernel))*3;

      [spktimes] = stats.simulate_spikes(event_ts, kernel, event_weights,'krn_bin_size',krn_bin_size,'baseline',20);

      % Estimate at the same resolution the data were simulated with; the
      % default (0.2) would not match the bin size used for the axes below.
      [est_krn, est_weights] = kernel_regress_spikes(event_ts, spktimes, 'krn_bin_size', krn_bin_size);

      figure(1); clf;
      ax(1) = draw.jaxes([0.15 0.2 0.3 0.3]);
      ax(2) = draw.jaxes([0.55 0.2 0.3 0.3]);

      % Left panel: true vs. estimated weights with a unity line.
      plot(ax(1), event_weights(:), est_weights(:),'o');
      draw.unity(ax(1));
      xlabel(ax(1), 'True Weights');
      ylabel(ax(1), 'Estimated Weights');

      % Right panel: true (solid) and estimated (dashed) kernels. Hold is
      % switched on explicitly so each plot call adds to the axes instead of
      % replacing what was drawn before.
      hold(ax(2), 'on');
      for kx = 1:numel(kernel)
          % Duration in seconds = number of bins * seconds per bin.
          ktime = numel(kernel{kx})*krn_bin_size;
          kax = linspace(-ktime/2, ktime/2, numel(kernel{kx}));
          h = plot(ax(2),kax,kernel{kx},'LineWidth',2);

          ktime = numel(est_krn{kx})*krn_bin_size;
          kax = linspace(-ktime/2, ktime/2, numel(est_krn{kx}));
          plot(ax(2),kax,est_krn{kx},'--','Color',h.Color,'LineWidth',2);
      end
      % NOTE(review): there is no return here, so after the demo execution
      % continues below using the simulated event_ts/spktimes and the default
      % options. Confirm whether that is intended.
  end % End Demo

  % Pull the optional settings out of varargin, falling back to the defaults.
  inpd = @utils.inputordefault;
  [krn_pre, args]=inpd('krn_pre',0,varargin);
  [krn_post, args]=inpd('krn_post',2,args);
  [krn_bin_size, args]=inpd('krn_bin_size',0.2,args);

  % Anything left over was not recognised; report it rather than ignore it.
  if ~isempty(args)
      warning('Unused arguments in kernel_regress_spikes.');
      disp(args);
  end

  % First, set all the co
  % NOTE(review): the comment above is cut off and nothing follows it, so
  % kernel, event_weights and S are never assigned on this path.
  end