PQNNSwithADC.m 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. function [Performances, normCCost, iSet, iXMinDist] = PQNNSwithADC( kk, mm, nRecall, flagVec)
  2. %clear all
  3. close all
  4. clc
  5. addpath('../Functions');
  6. addpath('./Functions');
  7. addpath('./Functions/Yael Library');
  8. % kk=100; % Number of Centroids for each set
  9. % mm=16; % Number of Centroid Sets to take into account
  10. % kCent=kk^mm; % Total cardinality of search space
  11. % nRecall=1; % Number of Recalls
  12. imgSz=128; % Image Descriptor size
  13. subImgSz=imgSz/mm; % Size of sub images to analyze
  14. Nlearn=1e5; % Number of pictures learn set
  15. Ntrain=1e6; % Number of pictures train set
  16. Ntests=1e4; % Number of pictures test set
  17. PQCentCompon=flagVec(1); % flag to activate Centroids Computation
  18. Quantizationon=flagVec(2); % flag to activate Data Sub Quantization Indices
  19. ADCCompon=flagVec(3); % flag to activate ADC Computations
  20. dispPQCenton=flagVec(4); % flag to activate Centroids Displayt
  21. justGetResultson=flagVec(5);% flag to just get results from stored data
  22. if(~justGetResultson)
  23. %% Computing/Loading PQ SubCentroids
  24. ndots=10;
  25. str2print='';
  26. ll=10;
  27. while mm/ll>1
  28. ndots=ndots-1;
  29. ll=ll*10;
  30. end
  31. ll=10;
  32. while kk/ll>1
  33. ndots=ndots-1;
  34. ll=ll*10;
  35. end
  36. for ll=0:ndots
  37. str2print=strcat(str2print,'.');
  38. end
  39. if(PQCentCompon)
  40. %% Reading Binary Learning Data Set
  41. fprintf('Loading Learning Data Set ...................... ');
  42. TSetMat=fvecs_read('../Data/sift/sift_learn.fvecs');
  43. fprintf('Done.\n');
  44. %% Preparing Training Data for Computation
  45. fprintf('Computing m=%1.0f, k=%1.0f PQ subcentroids ...%s',mm,kk,str2print);
  46. CSubSetMat=zeros(subImgSz,kk,mm);
  47. for jj=1:mm
  48. CSubSetMat(:,:,jj)=yael_kmeans(...
  49. single(TSetMat((jj-1)*subImgSz+1:jj*subImgSz,:)),...
  50. kk, 'niter', 100, 'verbose', 0);
  51. %[~, CCtemp]=kmeans(TSubSetMat(:,:,jj)',kk);
  52. %CSubSetMat(:,:,jj)=CCtemp';
  53. pp=double(jj/mm*100);
  54. if(pp<10)
  55. fprintf('\b\b%2.0f%%',pp);
  56. else
  57. fprintf('\b\b\b%2.0f%%',pp);
  58. end
  59. end
  60. fprintf('\b\b Done.\n');
  61. clearvars TSetMat
  62. save(strcat('./Data/k',int2str(kk),'m',int2str(mm),'CSubSetMat.mat'),'CSubSetMat');
  63. else
  64. fprintf('Loading m=%1.0f, k=%1.0f PQ subcentroids ....%s',mm,kk,str2print);
  65. load(strcat('./Data/k',int2str(kk),'m',int2str(mm),'CSubSetMat.mat'));
  66. fprintf(' Done.\n');
  67. end
  68. clearvars TSubSetMat
  69. %% Displaying PQ SubCentroids
  70. if(dispPQCenton)
  71. fprintf('Displaying PQ subcentroids: ... \n',mm);
  72. mmax=min(3,mm);
  73. for jj=1:mmax
  74. figure('units','normalized','outerposition',[0 0 1 1])
  75. suptitle('Centroid Portions of Pictures')
  76. imshow255_texmex(0,imgSz,[255*ones(kk,(jj-1)*subImgSz) ...
  77. CSubSetMat(:,:,jj)' ...
  78. 255*ones(kk,(mm-jj)*subImgSz)]);
  79. end
  80. end
  81. %% Preparing Training Data for Computation
  82. fprintf('Loading Training Data Set ...................... ');
  83. XSetMat=fvecs_read('../Data/sift/sift_base.fvecs');
  84. fprintf('Done.\n');
  85. %% Computing/Loading PQ SubCentroids for Training Data
  86. if(Quantizationon)
  87. nSteps=100; % number of steps for computation, be sure it divides 1e6
  88. fprintf('Computing PQ subcentroids for Training Data .... ');
  89. iXQ=zeros(mm,Ntrain);
  90. for jj=1:mm
  91. for ll=1:Ntrain/nSteps
  92. [~,iXQ(jj,(ll-1)*nSteps+1:ll*nSteps)]=Quantization(...
  93. XSetMat((jj-1)*subImgSz+1:jj*subImgSz,(ll-1)*nSteps+1:ll*nSteps),...
  94. CSubSetMat(:,:,jj));
  95. end
  96. pp=double(jj/mm*100);
  97. if(pp<10)
  98. fprintf('\b\b%2.0f%%',pp);
  99. else
  100. fprintf('\b\b\b%2.0f%%',pp);
  101. end
  102. end
  103. fprintf('\b\b\b\b\b\bDone.\n');
  104. save(strcat('./Data/ADC/ADCk',int2str(kk),'m',int2str(mm),'iXQ.mat'),'iXQ');
  105. else
  106. fprintf('Loading PQ subcentroids for Training Data ...... ');
  107. load(strcat('./Data/ADC/ADCk',int2str(kk),'m',int2str(mm),'iXQ.mat'));
  108. fprintf('Done.\n');
  109. end
  110. clearvars XSetMat
  111. %% Preparing Testing Data for Computation
  112. fprintf('Loading Testing Data Set ....................... ');
  113. YSetMat=fvecs_read('../Data/sift/sift_query.fvecs');
  114. fprintf('Done.\n');
  115. %% Computing/Loading NNs by applying ADC
  116. if(ADCCompon)
  117. fprintf('Solving NNS Applying ADC ........................ ');
  118. [iXMinDist]= AsymmetricDistanceComputation(CSubSetMat, iXQ, YSetMat, nRecall);
  119. save(strcat('./Data/ADC/ADCk',int2str(kk),'m',int2str(mm),'iXMinDist.mat'),'iXMinDist');
  120. fprintf('\b\b\bDone.\n');
  121. else
  122. fprintf('Loading NNS results with ADC ...................... ');
  123. load(strcat('./Data/ADC/ADCk',int2str(kk),'m',int2str(mm),'iXMinDist.mat'));
  124. fprintf('Done.\n');
  125. end
  126. clearvars iXQ YSetMat CSubSetMat
  127. else
  128. fprintf('Loading NNS results with ADC ...................... ');
  129. load(strcat('./Data/ADC/ADCk',int2str(kk),'m',int2str(mm),'iXMinDist.mat'));
  130. fprintf('Done.\n');
  131. end
  132. %% Loading results
  133. fprintf('Loading Results Data Set ............');
  134. iSet=ivecs_read('../Data/sift/sift_groundtruth.ivecs')+1;
  135. fprintf('Done.\n');
  136. iSet=iSet(1,:);
  137. %% Testing PQ Data Set
  138. fprintf('Testing PQ Test Data Set ... \n');
  139. reqRecall=zeros(Ntests,1);
  140. for ii = 1:Ntests
  141. iXeq=find(iXMinDist(ii,:)==iSet(ii));
  142. if numel(iXeq)==1
  143. reqRecall(ii)=iXeq;
  144. else
  145. reqRecall(ii)=nRecall+1;
  146. end
  147. end
  148. reqRecall=sort(reqRecall);
  149. ll=1;
  150. recAtR=[1 2 5 10 20 50 100 200 500 1000 2000 5000 10000];
  151. recAtR=recAtR(recAtR<=nRecall);
  152. % recAtR=nRecall;
  153. Performances=zeros(1,length(recAtR));
  154. for ii=recAtR
  155. if ii <= nRecall
  156. Performances(ll) = length (find (reqRecall <= ii & reqRecall <= nRecall)) / Ntests * 100;
  157. fprintf ('Recall@%3d = %.3f\n', ii, Performances(ll));
  158. ll=ll+1;
  159. end
  160. end
  161. %% Computational Cost Evaluation
  162. maxCCost=Ntrain*Ntests*imgSz;
  163. quantCCost=Ntests*kk*imgSz;
  164. PQCCostADC=Ntrain*Ntests*mm+quantCCost;
  165. normCCost=PQCCostADC/maxCCost;
  166. fprintf('Computational Cost: %2.2f %% w.r.t Exhaustive NNS\n',normCCost*100);