nets_lda.m 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  %
% [lda_percentages,grot] = nets_lda(x,A,nmethod);
% Steve Smith - 2013-2014
%
% apply one of a number of linear classifiers to the two-group data
% returning the classification accuracy found using leave-one-out (LOO) testing
%
% x = subjects X measurements - for example a "netmats" matrix, with two "groups" of ordered subjects
% A is number of subjects in first group; set to 0 for paired data
%   (paired data: with N subjects, subject n of the first half is paired with subject n+N/2)
%
% nmethod:
% 1 FLD (unstable in Octave)
% 2 FLD-mean (ignore covariance)
% 3 two-group T weighting
% 4 two-group maximum-T weighting
% 5 two-group thresholded-T weighting
% 6 two-group T/stddev weighting
% 7 Matlab's built-in SVM (not available in Octave)
% 8 LIBSVM's SVM - need to have libsvm installed
%
% lda_percentages = percentage of held-out subjects classified correctly
%   (a positive classifier output means "first group", a negative one "second group";
%   in paired mode only the first-group member of each held-out pair is scored)
% grot = misc return variable: empty, except for nmethod 8 where it holds the
%   LIBSVM prediction for the last held-out subject
%
function [lda_percentages,grot] = nets_lda(x,A,nmethod)

grot=[]; % misc return variable

x=x';         % now measurements X subjects
N=size(x,2);  % number of subjects

% validate inputs up front so bad arguments give a clear error instead of
% an obscure indexing failure or silently meaningless percentages
if ~isscalar(nmethod) || ~any(nmethod==(1:8))
  error('nets_lda:badMethod','nmethod must be one of 1..8');
end
if A==0
  if mod(N,2)~=0 || N<4
    error('nets_lda:pairedN','paired data (A=0) needs an even number of subjects, at least 4; got %d',N);
  end
  P=N/2;      % number of pairs
  ALL=1:P;    % held-out pairs, indexed by their first-group member
else
  if A<1 || A>N-1
    error('nets_lda:badA','A must be 0 (paired data) or between 1 and %d; got %g',N-1,A);
  end
  ALL=1:N;    % every subject is held out once
end

% drop measurements that are constant across all subjects
x=x(std(x,0,2)>0,:);

if A==0
  % remove each pair's mean, leaving only the within-pair difference
  % (assumes nets_demean subtracts the mean along the given dimension - confirm)
  for n=ALL
    x(:,[n n+P]) = nets_demean(x(:,[n n+P]),2);
  end
end

lda=zeros(1,numel(ALL)); % classifier output for each held-out subject

for n = ALL
  % training data: everyone except the held-out subject (or pair)
  if A==0
    xa=x(:,setdiff(1:P,n));
    xb=x(:,setdiff(P+1:N,n+P));
  else
    xa=x(:,setdiff(1:A,n));
    xb=x(:,setdiff(A+1:N,n));
  end
  na=size(xa,2); nb=size(xb,2);

  meana=mean(xa,2); meanb=mean(xb,2);
  deltamean=meana-meanb;        % first-group minus second-group mean
  meanab=0.5*(meana+meanb);     % midpoint between the two group means
  xx=x(:,n)-meanab;             % held-out subject relative to the midpoint

  E=[nets_demean(xa,2) nets_demean(xb,2)];  % within-group residuals
  Estd=std(E,0,2);
  % NOTE(review): a measurement that is constant within both training groups gives
  % Estd==0 and hence Inf/NaN weights in methods 3-8

  % T-like weighting; its overall scale only matters for the |t|>4 threshold of nmethod 5
  t=sqrt(N-2)*deltamean./Estd;

  switch nmethod
    case 1 % FLD: deltamean weighted by pinv of the within-group scatter, via the SVD of E
      [u,s]=svd(E,'econ'); ps=pinv(s);
      w=u*ps*ps*(u'*deltamean);
      lda(n)=w'*xx;
    case 2 % FLD mean (ignore covariance)
      lda(n)=deltamean'*xx;
    case 3 % T
      lda(n)=t'*xx;
    case 4 % Tmax: use only the single measurement with the largest |t|
      [~,iii]=max(abs(t));
      lda(n)=t(iii)*xx(iii);
    case 5 % Tthresh
      lda(n)=(t.*(abs(t)>4))'*xx;
    case 6 % T/std
      lda(n)=(deltamean./(Estd.*Estd))'*xx;
    case {7,8}
      % SVM inputs: centred on the group midpoint and scaled by the within-group std
      xan=(xa-repmat(meanab,1,na))./repmat(Estd,1,na);
      xbn=(xb-repmat(meanab,1,nb))./repmat(Estd,1,nb);
      xxn=xx./Estd;
      labels=[ones(na,1);zeros(nb,1)]; % 1 = first group, 0 = second group
      if nmethod==7 % matlab svm
        % NOTE: svmtrain/svmclassify were removed from recent MATLAB releases (fitcsvm replaces them);
        % whichever svmtrain is first on the path (MATLAB's or LIBSVM's) is the one called
        svmstruct=svmtrain([xan xbn]',labels);
        lda(n)=svmclassify(svmstruct,xxn')*2-1;
      else % LIBSVM svm, linear kernel
        svmstruct=svmtrain(labels,[xan xbn]','-q -t 0'); % change to "-t 2" for RBF nonlinear SVM
        [grot,~,~]=svmpredict(1,xxn',svmstruct,'-q');
        lda(n)=grot*2-1;
      end
  end
end

% score: positive output = first group, negative = second group (zero counts as wrong)
if A==0
  lda_percentages = 100*sum(lda>0,2)/numel(ALL);
else
  lda_percentages = 100*(sum(lda(1:A)>0,2)+sum(lda(A+1:N)<0,2))/numel(ALL);
end