KernelRegressionA.m 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  classdef KernelRegressionA
      % KERNELREGRESSIONA Kernel regression of activity against event times.
      %   Builds a time-binned design matrix in which every event type (column
      %   of event_times) gets a kernel lasting kernel_duration seconds. Each
      %   kernel is parameterised by kernel_dof weights through the basis
      %   matrix core_kernel.
      %
      %   This is a value class: every method that changes state returns the
      %   updated object, which the caller must reassign, e.g.
      %       krm = KernelRegressionA(event_times, spiketimes);
      %       krm = krm.generateCoreKernel();
      %       krm = krm.generateKernelMatrix();

      properties
          baseline_per_trial = false % true: one baseline column per trial; false: a single shared baseline column
          core_kernel % [kernel_bins x kernel_dof] basis; used as a convolution kernel for every event type
          core_kernel_matrix % Design matrix: core_kernel convolved with the event times, followed by baseline column(s)
          event_times % [trials x event types], in the same time units as kernel_bin_size
          kernel_bin_size = 0.01 % seconds
          kernel_dof = 50 % number of free weights per kernel
          kernel_duration = 2 % seconds
          kernel_smoothing = 3 % smoothing width in units of one dof step (kernel_bins/kernel_dof bins); 0 disables smoothing
          kernel_smoothing_style = 'normal' % 'normal' or 'box'
          kernel_weights % [number_of_events x kernel_dof]; the result of the kernel estimation step
          spiketimes % stored by the constructor; not used by the methods in this file
          trial_types % one entry per trial; trials that share a type share a trial weight
          trial_weights % the result of the trial weight estimation step (one per unique trial type)
          weighted_kernel_matrix % core_kernel_matrix multiplied by trial_weights
      end

      properties (Dependent)
          number_of_events % number of event types (columns of event_times)
          kernel_bins % time bins spanned by one kernel
          kernels % [number_of_events x kernel_bins]: kernel_weights mapped through core_kernel
          total_time_steps % number of rows of the design matrix
      end

      methods
          function obj = KernelRegressionA(e, s, t)
              % Construct the model.
              %   e - event times: one row per trial, one column per event type.
              %   s - spike times (stored only; optional).
              %   t - trial type of each trial (optional; default: every trial
              %       is its own type). Sharing a type across trials lets you
              %       fit fewer trial weights than trials, e.g. to assume that
              %       trials of the same type have the same weight.
              if nargin == 0
                  return % default construction (needed e.g. for object arrays)
              end
              obj.event_times = e;
              if nargin >= 2
                  obj.spiketimes = s;
              end
              if nargin < 3 || isempty(t)
                  obj.trial_types = (1:size(e, 1))';
              else
                  obj.trial_types = t;
              end
              obj.trial_weights = ones(numel(unique(obj.trial_types)), 1);
          end

          function obj = run(obj)
              % Run the whole pipeline. Each step returns an updated copy of
              % the (value-class) object, so every result must be reassigned.
              % NOTE(review): generateCoreWeights, generateWeightMatrix and fit
              % are not defined in this file; presumably they are defined
              % elsewhere (e.g. in the class @-folder). Confirm.
              obj = generateCoreKernel(obj);
              obj = generateKernelMatrix(obj);
              obj = generateCoreWeights(obj);
              obj = generateWeightMatrix(obj);
              obj = fit(obj);
          end

          function obj = generateCoreKernel(obj)
              % Build obj.core_kernel and reset obj.kernel_weights to ones.
              %   core_kernel is [kernel_bins x kernel_dof]. Without smoothing,
              %   column k is 1 on its own contiguous run of
              %   kernel_bins/kernel_dof bins (a piecewise-constant basis).
              %   With smoothing, the columns are blurred along time and each
              %   row is then renormalised to sum to 1.
              %   Errors if kernel_bins is not a multiple of kernel_dof, or if
              %   kernel_smoothing_style is unknown.
              n_bins = obj.kernel_bins;
              bins_per_dof = n_bins / obj.kernel_dof;
              assert(rem(bins_per_dof, 1) == 0, ...
                  'kernel_bins (%d) must be an integer multiple of kernel_dof (%d)', ...
                  n_bins, obj.kernel_dof);
              obj.kernel_weights = ones(obj.number_of_events, obj.kernel_dof);
              % Block-diagonal indicator: column k covers rows
              % (k-1)*bins_per_dof+1 : k*bins_per_dof.
              obj.core_kernel = kron(eye(obj.kernel_dof), ones(bins_per_dof, 1));
              if obj.kernel_smoothing > 0
                  smooth_factor = obj.kernel_smoothing * bins_per_dof; % in bins
                  switch obj.kernel_smoothing_style
                      case 'normal'
                          % Gaussian with sd = smooth_factor bins, truncated at
                          % +/-5 sd. Written out (equal to normpdf(x,0,sd)) so
                          % the Statistics Toolbox is not required.
                          x = (-(5*smooth_factor):(5*smooth_factor))';
                          smooth_krn = exp(-0.5 * (x / smooth_factor).^2) ...
                              / (smooth_factor * sqrt(2*pi));
                      case 'box'
                          % ones() needs an integer length.
                          smooth_krn = ones(max(1, round(smooth_factor)), 1);
                      otherwise
                          error('KernelRegressionA:unknownSmoothingStyle', ...
                              'Do not know how to smooth using %s', ...
                              obj.kernel_smoothing_style);
                  end
                  obj.core_kernel = conv2(obj.core_kernel, smooth_krn, 'same');
                  obj.core_kernel = obj.core_kernel ./ sum(obj.core_kernel, 2);
              end
          end

          function obj = generateKernelMatrix(obj)
              % Build the design matrix obj.core_kernel_matrix.
              %   Rows are time bins of width kernel_bin_size; row 1 is the bin
              %   of the earliest event. Columns are
              %     [event 1 basis (kernel_dof cols), ..., event E basis, baseline]
              %   where baseline is one column of ones, or, when
              %   baseline_per_trial is true, one indicator column per trial.
              %   Every event occurrence adds a copy of core_kernel starting at
              %   its own bin, so kernels are causal (they start at the event).
              %   NaN event times are skipped. Events in the same bin add up.
              %   If core_kernel has not been built yet, it is built first.
              assert(any(isfinite(obj.event_times(:))), ...
                  'event_times must contain at least one finite event time');
              if isempty(obj.core_kernel)
                  obj = obj.generateCoreKernel();
              end
              n_steps = obj.total_time_steps;
              n_events = obj.number_of_events;
              n_dof = size(obj.core_kernel, 2);
              t0 = min(obj.event_times(:)); % min ignores NaN
              % 1-based design-matrix row of each event (NaN stays NaN).
              event_rows = floor((obj.event_times - t0) / obj.kernel_bin_size) + 1;

              kernel_part = zeros(n_steps, n_events * n_dof);
              for ev = 1:n_events
                  rows = event_rows(:, ev);
                  rows = rows(isfinite(rows));
                  % Count events per bin, so two events in one bin add two kernels.
                  impulses = accumarray(rows, 1, [n_steps 1]);
                  % 'full' convolution puts core_kernel row 1 at each event row.
                  % Keeping the first n_steps rows preserves that alignment.
                  full_conv = conv2(impulses, obj.core_kernel);
                  kernel_part(:, (ev - 1) * n_dof + (1:n_dof)) = full_conv(1:n_steps, :);
              end

              if obj.baseline_per_trial
                  baseline = obj.perTrialBaseline(event_rows, n_steps);
              else
                  baseline = ones(n_steps, 1);
              end
              obj.core_kernel_matrix = [kernel_part, baseline];
          end

          function number_of_events = get.number_of_events(obj)
              % Number of event types = number of columns of event_times.
              number_of_events = size(obj.event_times, 2);
          end

          function total_time_steps = get.total_time_steps(obj)
              % Bins from the earliest to the latest event, plus room for a
              % full kernel after the latest event. The floor matches the
              % binning used in generateKernelMatrix and keeps the size integral.
              span = max(obj.event_times(:)) - min(obj.event_times(:));
              total_time_steps = floor(span / obj.kernel_bin_size) + obj.kernel_bins;
          end

          function kernel_bins = get.kernel_bins(obj)
              % kernel_duration / kernel_bin_size, required to be an integer.
              % A relative tolerance absorbs floating-point error
              % (e.g. 0.3/0.1 evaluates to 2.9999999999999996).
              ratio = obj.kernel_duration / obj.kernel_bin_size;
              kernel_bins = round(ratio);
              assert(abs(ratio - kernel_bins) <= 1e-9 * max(1, abs(ratio)), ...
                  'The kernel_duration should be an integer multiple of kernel_bin_size');
          end

          function kernels = get.kernels(obj)
              % Time-domain kernels: each row of kernel_weights mapped through
              % the basis, giving [number_of_events x kernel_bins].
              kernels = obj.kernel_weights * obj.core_kernel';
          end
      end

      methods (Access = private)
          function baseline = perTrialBaseline(~, event_rows, n_steps)
              % One indicator column per trial (row of event_times).
              % NOTE(review): trial extents are not stored on the object, so a
              % trial is assumed to span from its earliest event bin up to the
              % bin before the next-starting trial's earliest event; the last
              % trial runs to the end. Trials with no finite events get an
              % all-zero column. Confirm this matches the intended model.
              n_trials = size(event_rows, 1);
              baseline = zeros(n_steps, n_trials);
              starts = min(event_rows, [], 2); % min ignores NaN
              valid = find(isfinite(starts));
              [sorted_starts, order] = sort(starts(valid));
              stops = [sorted_starts(2:end, 1) - 1; n_steps];
              for k = 1:numel(valid)
                  baseline(sorted_starts(k):stops(k), valid(order(k))) = 1;
              end
          end
      end
  end
  function y = col(x)
      % COL Return every element of x, in column-major order, as one column vector.
      y = reshape(x, [], 1);
  end