Compute_accuracyNN_matrix.m 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. function [Acc,ConfMat] = Compute_accuracyNN_matrix(AmpResp,AllSamples,LOO)
  2. %% This function calculates the similarity matrix which estimates the correlation between pairs
  3. %% of sounds after correction for trial to trial noise
  4. %% INPUT
  5. % AmpResp = the responses to be correlations : ROI * stim * trials
  6. % AllSamples = 1 (default) calculates all possible pairs of trials, 0
  7. % instead does this for a random 50 combiations which is a lot quicker
  8. % LOO = 1 train using leave one out cross validation
  9. % LOO = 0 train using two halves of trials for cross validation
  10. %% OUTPUT
  11. % Acc : accuracy for identifying each sound using nearest neighbour
  12. % ConfMat : confusion matrix
  13. %% intialize
  14. if not(exist('AllSamples'))
  15. AllSamples = 1;
  16. end
  17. if not(exist('LOO'))
  18. LOO = 0;
  19. end
  20. if LOO ==1
  21. AllSamples = 1;
  22. end
  23. % clean up NaN cells
  24. NaN_cells = find(isnan(sum(squeeze(nanmean(AmpResp,3)'))));
  25. AmpResp(NaN_cells,:,:) = [];
  26. disp([num2str(length(NaN_cells)) ' ROIs out of ' num2str(size(AmpResp,1)) 'excluded because they had NaNs'])
  27. % Not all data sets have all of the trials
  28. maxgoodtrials = find(squeeze(sum(sum(isnan(AmpResp),1),2))==0,1,'last');
  29. disp([num2str(maxgoodtrials) ' trials without NaNs'])
  30. maxgoodtrials = floor(maxgoodtrials/2)*2;
  31. for neur = 1:size(AmpResp,1)
  32. AmpResp(neur,:,:) = AmpResp(neur,:,:) - nanmean(AmpResp(neur,:));
  33. end
  34. if LOO
  35. % Use leave one out
  36. AllTrialPairs = combnk(1:maxgoodtrials,maxgoodtrials-1);
  37. TrialsTrain = AllTrialPairs;
  38. TrialsTest = [maxgoodtrials:-1:1];
  39. TrialsTest = reshape(TrialsTest,maxgoodtrials,1);
  40. else
  41. % Get the trial pairs
  42. AllTrialPairs = combnk(1:maxgoodtrials,floor(maxgoodtrials/2));
  43. TrialsTrain = AllTrialPairs;
  44. TrialsTest = fliplr(TrialsTrain')';
  45. end
  46. if AllSamples
  47. else
  48. % a subset of trial pairs
  49. RandTrials = randperm(size(TrialsTrain,1),10);
  50. TrialsTrain = TrialsTrain(RandTrials,:);
  51. TrialsTest = TrialsTest(RandTrials,:);
  52. end
  53. %% Get accuracy
  54. for perm = 1:size(TrialsTest,1)
  55. TrainVect = squeeze(nanmean(AmpResp(:,:,TrialsTrain(perm,:)),3));
  56. if length(TrialsTest(perm,:))>1
  57. TestVect = squeeze((AmpResp(:,:,TrialsTest(perm,:))));
  58. else
  59. TestVect = ((AmpResp(:,:,TrialsTest(perm,:))));
  60. end
  61. for k = 1:size(TestVect,3)
  62. C = corr(TrainVect,squeeze(TestVect(:,:,k)));
  63. [val,ind(k,:)] = max(C);
  64. end
  65. Acc(perm,:) = nanmean((ind-repmat([1:size(TestVect,2)],[size(TestVect,3),1]))==0,1);
  66. for sd = 1:size(ind,2)
  67. for sd2 = 1:size(ind,2)
  68. ConfMat(perm,sd,sd2) = sum(ind(:,sd) ==sd2);
  69. end
  70. end
  71. for sd = 1:size(ind,2)
  72. for sd2 = 1:size(ind,2)
  73. ConfMat(perm,sd,sd2) = mean(ind(:,sd) ==sd2);
  74. end
  75. end
  76. end