% PositionDecoding.m (4.6 KB)
  function [pos_err, pos_real] = PositionDecoding(pos_map_trial, tcs, deconv, behavior, varargin)
  % PositionDecoding  Decode track position from deconvolved activity.
  %
  % Tuning curves are averaged over odd trials (training set). Each decoding
  % window of ~tau seconds is assigned the position bin x that maximizes
  %   prod_i f_i(x)^n_i * exp(-tau * sum_i f_i(x)),
  % where f_i is the tuning curve of unit i and n_i its mean smoothed activity
  % in the window. Errors are reported only for even (held-out) trials and only
  % for windows where the interpolated speed exceeds 3 cm/s.
  %
  % [pos_err, pos_real] = PositionDecoding(pos_map_trial, tcs, deconv, behavior, ...)
  %
  % Inputs:
  %   pos_map_trial : occupancy-normalized position activity map in the format
  %                   nTrial x nBin x nUnit. Its third dimension is indexed with
  %                   order_idx, just like the columns of deconv(:, unit_list).
  %                   NOTE(review): assumes both share the same unit order and
  %                   that tuning-curve values are non-negative -- confirm.
  %   tcs           : struct; tcs.tt holds the timestamps of the rows of deconv.
  %                   size(tcs.ratio, 2) gives the default number of units.
  %   deconv        : activity matrix, one row per tcs.tt sample, one column per unit.
  %   behavior      : struct with fields ts (timestamps), pos_norm (position,
  %                   multiplied by tracklen to get cm), trial (trial number of
  %                   each sample) and speed (compared against 3 cm/s).
  %
  % Name/value options:
  %   'tracklen'  track length in cm (default 90)
  %   'tau'       decoding time window in seconds, positive scalar (default 0.5)
  %   'unit_list' columns of deconv to use (default: 1:size(tcs.ratio, 2))
  %   'order_idx' indices into the selected units (default: 1:numel(unit_list))
  %   'ifplot'    if true, plot true and decoded position into the current axes
  %
  % Outputs:
  %   pos_err  : absolute difference between decoded and actual position in cm
  %   pos_real : actual position in cm for the same decoding windows

    nFixedArgs = 4;   % varargin{1} is input argument number 5
    order_idx = [];   % default depends on unit_list; resolved after parsing
    unit_list = 1:size(tcs.ratio, 2);
    spd_thrsd = 3;    % speed threshold: 3 cm/s
    tau = 0.5;        % decoding time window
    tracklen = 90;    % in cm
    ifplot = false;

    % Parse property/value pairs
    if mod(numel(varargin), 2) ~= 0
        error('Optional arguments must be given as property/value pairs');
    end
    for i = 1:2:numel(varargin)
        if ~ischar(varargin{i})
            error(['Parameter ' num2str(i + nFixedArgs) ' is not a property']);
        end
        switch lower(varargin{i})
            case 'tracklen'
                tracklen = varargin{i+1};
            case 'tau'
                tau = varargin{i+1};
                if ~isnumeric(tau) || ~isscalar(tau) || ~(tau > 0)
                    error('Incorrect value for property ''tau''');
                end
            case 'unit_list'
                unit_list = varargin{i+1};
                if ~isvector(unit_list)
                    error('Incorrect value for property ''unit_list''');
                end
            case 'order_idx'
                order_idx = varargin{i+1};
                if ~isvector(order_idx)
                    error('Incorrect value for property ''order_idx''');
                end
            case 'ifplot'
                ifplot = varargin{i+1};
            otherwise
                error(['Unknown property ''' num2str(varargin{i}) '''']);
        end
    end
    if isempty(order_idx)
        % order_idx indexes the columns of deconv(:, unit_list), so by default
        % it must not exceed numel(unit_list)
        order_idx = 1:numel(unit_list);
    end

    ts_tcs = tcs.tt;
    ts_beh = behavior.ts;
    pos = behavior.pos_norm * tracklen; % scale to track length (cm)
    trial = behavior.trial;
    spd = behavior.speed;
    nTrial = max(trial);
    nNeuron = numel(unit_list);

    % Samples per decoding window. Use the sample interval tt(2) - tt(1)
    % (tt(2) alone equals it only when tt(1) == 0). Never let the window shrink
    % to zero samples.
    dt = ts_tcs(2) - ts_tcs(1);
    nX = max(1, round(tau / dt));

    deconv = deconv(:, unit_list);
    deconv_sm_size = 10;
    % NOTE(review): [deconv_sm_size, 0] presumably smooths along rows (time)
    % only -- confirm against Smooth's documentation.
    deconv_sm = Smooth(double(deconv), [deconv_sm_size, 0]);
    deconv_ordered = deconv_sm(:, order_idx);
    % deconv_ordered(deconv_ordered < 0.3) = 0;

    % Tuning curves from odd trials only (training set). reshape rather than
    % squeeze keeps the result nBin x nUnit even when only one unit is used.
    nBin = size(pos_map_trial, 2);
    pos_tuning_curve = reshape(mean(pos_map_trial(1:2:end, :, order_idx), 1), ...
        nBin, numel(order_idx));

    % Decode in the log domain. The argmax of
    %   sum_i n_i*log f_i(x) - tau*sum_i f_i(x)
    % is the same bin as the argmax of prod_i f_i(x)^n_i * exp(-tau*sum_i f_i(x)),
    % but the log form does not underflow to 0 for every bin. That underflow
    % made max() return bin 1 when many units were active.
    log_tc = log(pos_tuning_curve);
    log_rate_term = -tau * sum(pos_tuning_curve, 2);

    % Each window holds at least one sample, so numel(ts_tcs) bounds the window
    % count for non-overlapping trials. MATLAB grows the arrays if it is exceeded.
    maxWin = numel(ts_tcs);
    ts_decoded = zeros(maxWin, 1);
    pos_decoded = zeros(maxWin, 1);   % decoded bin index (converted to cm below)
    trial_decoded = zeros(maxWin, 1);
    nWin = 0;
    for iTrial = 1:nTrial
        ts_trial = ts_beh(trial == iTrial);
        if isempty(ts_trial)
            continue; % no behavior samples carry this trial number
        end
        ok2 = ts_tcs > ts_trial(1) & ts_tcs < ts_trial(end);
        len_this = sum(ok2);
        if len_this == 0
            continue;
        end
        dcv_this = deconv_ordered(ok2, :);
        ts_this = ts_tcs(ok2);
        for st = 1:nX:len_this
            ed = min(st + nX - 1, len_this); % last window may be shorter
            dcv_tmp = mean(dcv_this(st:ed, :), 1);
            % Units with zero activity contribute f^0 == 1 (log 0) even where
            % f == 0; excluding them avoids 0 * log(0) = NaN.
            active = dcv_tmp ~= 0;
            log_post = log_tc(:, active) * dcv_tmp(active).' + log_rate_term;
            [~, idx] = max(log_post);
            nWin = nWin + 1;
            ts_decoded(nWin) = mean(ts_this(st:ed));
            pos_decoded(nWin) = idx;
            trial_decoded(nWin) = iTrial;
        end
    end
    ts_decoded = ts_decoded(1:nWin);
    pos_decoded = pos_decoded(1:nWin);
    trial_decoded = trial_decoded(1:nWin);

    % Running windows: interpolated speed above threshold (NaN outside the
    % behavior time range compares false).
    spd_decoded = interp1(ts_beh, spd, ts_decoded, 'linear');
    ok3 = spd_decoded(:) > spd_thrsd;

    % Contiguous running periods, used only for plotting. Padding with 0 on
    % both sides closes a period that lasts until the last window. ed_run is
    % the last running window, not the first stationary window after it.
    run_edges = diff([0; ok3; 0]);
    st_run = find(run_edges == 1);
    ed_run = find(run_edges == -1) - 1;

    if ifplot
        hold on
        % True position and decoded bin are both rescaled to [0, nNeuron] for display
        for iTrial = 1:nTrial
            ok = trial == iTrial;
            plot(ts_beh(ok), pos(ok) / tracklen * nNeuron, 'b');
        end
        for iRun = 1:numel(st_run)
            idx_this = st_run(iRun):ed_run(iRun);
            ts_decoded_this = ts_decoded(idx_this);
            pos_decoded_this = pos_decoded(idx_this) / nBin * nNeuron;
            % Break the trace where decoded position jumps by more than 80%
            % of the (rescaled) track, so wrap-arounds are not drawn as lines
            pos_break = [0; find(abs(diff(pos_decoded_this)) > round(nNeuron * 0.8)); ...
                numel(pos_decoded_this)];
            for i = 1:numel(pos_break) - 1
                idx_now = pos_break(i)+1:pos_break(i+1);
                plot(ts_decoded_this(idx_now), pos_decoded_this(idx_now), 'r');
            end
        end
        hold off
    end

    % Keep running windows, convert bin index to cm, and score even trials.
    % NOTE(review): idx / nBin * tracklen maps bin idx to its upper edge. If
    % bin idx spans [(idx-1), idx] * tracklen / nBin, the bin centre would be
    % (idx - 0.5) / nBin * tracklen -- confirm against how pos_map_trial is binned.
    ts_decoded = ts_decoded(ok3);
    pos_decoded = pos_decoded(ok3) / nBin * tracklen;
    trial_decoded = trial_decoded(ok3);
    pos_true = interp1(ts_beh, pos, ts_decoded, 'linear');
    ok4 = rem(trial_decoded, 2) == 0; % calculate errors for only the even trials
    pos_err = abs(pos_decoded(ok4) - pos_true(ok4));
    pos_real = pos_true(ok4);