Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

WNNetsOverPQ_Fine_Coarse_All_Centroids_Assets.m 11 KB

(line-number gutter captured from the repository's web file viewer — not part of the .m file contents)
  1. %function [NNeighPerf, NormCpxty, NormPQCpxty]= WNNetsOverPQ_Kmeans(kk,mm,nRecall,Npeaks)
% =========================================================================
% Approximate nearest-neighbour search experiment on the SIFT1M benchmark:
% a coarse Product Quantizer (PQ) routes database vectors into Willshaw
% neural networks (WNNs); at query time the top-scoring WNNs preselect
% candidates that a fine PQ then re-ranks (ADC). The script reports
% recall@R against computational cost normalized to exhaustive search.
% NOTE(review): every code line carries an "N. " prefix copied from the
% web file viewer this text was exported from; those prefixes must be
% stripped before the script can run.
% NOTE(review): the commented-out function header above suggests this was
% meant to be a parameterized function; it currently runs as a script.
% =========================================================================
  2. clear all
  3. close all
  4. clc
% Project-local libraries: fvecs/ivecs readers, PQ helpers, WNN helpers,
% and the Yael k-means wrapper (yael_kmeans) — none are on the default path.
  5. addpath('./SIFTs Library');
  6. addpath('./PQ Library');
  7. addpath('./WNN Library');
  8. addpath('./PQ Library/Yael Library');
  9. set(0,'defaulttextinterpreter','latex');
  10. % ---- Quantities of interest for the input Data Set --------------
  11. nDim=128; % Dimensionality of the Data
  12. NLearn=1e5; % Number of descriptors learn set
  13. Ntrain=1e6; % Number of descriptors train set
  14. Ntests=1e4; % Number of descriptors test set
  15. % ---- Quantities of interest for Fine Product Quantization -------
  16. kkf=256; % Number of clusters kmeans
  17. mmf=16; % Number of splits for PQ (divides 128)
  18. mDimf=nDim/mmf; % Dimensionality of splitted vectors
  19. nRecall=100;
% Recall is evaluated at these R values, truncated to R <= nRecall
% (7 values survive for nRecall=100; the hard-coded legend at the bottom
% of the script matches exactly these 7 entries).
  20. recAtR=[1 2 5 10 20 50 100 200 500 1000 2000 5000 10000];
  21. recAtR=recAtR(recAtR<=nRecall);
  22. % ---- Quantities of interest for Coarse Product Quantization -----
  23. kkc=10; % Number of clusters kmeans
  24. mmc=2; % Number of splits for PQ (divides 128)
  25. mDimc=nDim/mmc; % Dimensionality of splitted vectors
  26. % ---- Quantities of interest for Willshaw Networks ---------------
  27. nn=1e4; % Number of vectors per WNN
  28. qq=Ntrain/nn; % Number of different WNNs
% Npeaks: number of top-scoring WNNs visited per query at each operating
% point; grows from 1 up to all qq networks (for qq=100: 1,10,20,25,33,
% 50,75,100). The last point visits every network, i.e. approaches plain PQ.
  29. Npeaks=round([1, qq./[10 5 4 3 2 4/3 1]]); % Number of peaks to select
% Results cache for this exact parameter combination (see the disabled
% "if 0%exist(...)" guard right below).
  30. resultsFileName=sprintf('./Result Sets/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d.mat',mmf,kkf,nn,mmc,kkc);
% NOTE(review): the cache-load branch is hard-disabled by "if 0" — the "%"
% comments out the original exist() test, so results are always recomputed.
% Restore "if exist(resultsFileName,'file')==2" to re-enable the cache.
  31. if 0%exist(resultsFileName,'file')==2
  32. %% LOADING RESULTS SECTION
  33. load(resultsFileName);
  34. else
  35. %% LOADING DATA SECTION
% SIFT1M in TEXMEX fvecs/ivecs format: learn (PQ training), base (database),
% query sets, plus exhaustive-search ground truth. +1 converts the stored
% 0-based neighbour indices to MATLAB's 1-based indexing.
  36. bVSetMat=double(fvecs_read('../Data/sift/sift_learn.fvecs'));
  37. bWSetMat=double(fvecs_read('../Data/sift/sift_base.fvecs'));
  38. bZSetMat=double(fvecs_read('../Data/sift/sift_query.fvecs'));
  39. iMinDistExh=ivecs_read('../Data/sift/sift_groundtruth.ivecs')+1;
  40. %% COARSE PRODUCT QUANTIZATION SECTION
% Coarse PQ (mmc subspaces, kkc centroids each); results are cached on disk.
  41. coarseFileName=sprintf('./Coarse Quantization Indices and Dissimilarities/coarse_CWidx_sZSetMat_k%d_m%d.mat',kkc,mmc);
  42. if exist(coarseFileName,'file')==2
  43. load(coarseFileName);
  44. else
  45. str=sprintf('Coarse PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmc,kkc);
% Progress-bar trick used throughout: ('.')*ones(1,n) is a numeric vector of
% 46s, which fprintf '%s' renders as a run of '.' characters; the same trick
% with sprintf('\b') emits backspaces to overwrite the previous percentage.
  46. fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  47. CSubSetMatc=zeros(mDimc,kkc,mmc);
  48. CWidxc=zeros(Ntrain,mmc);
  49. sZSetMatc=zeros(kkc*mmc,Ntests);
  50. pperc=[];
% Per subspace jj: train kkc centroids on the learn set, quantize the base
% set (CWidxc), and turn query-to-centroid distances into similarities.
  51. for jj=1:mmc
  52. [CSubSetMatc(:,:,jj), ~ ]=yael_kmeans(...
  53. single(bVSetMat((jj-1)*mDimc+1:jj*mDimc,:)),kkc, 'niter', 100, 'verbose', 0);
  54. ZDistMatc=EuclideanDistancesMat(bZSetMat((jj-1)*mDimc+1:jj*mDimc,:),CSubSetMatc(:,:,jj));
  55. [~,CWidxc(:,jj)]=Quantization(bWSetMat((jj-1)*mDimc+1:jj*mDimc,:),CSubSetMatc(:,:,jj));
% Soft assignment: similarity = exp(-sqrt(distance)) per query/centroid,
% stacked over subspaces into a (kkc*mmc) x Ntests matrix.
  56. sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=exp(-sqrt(ZDistMatc)');
  57. perc=sprintf('%2.2f%%',jj/mmc*100);
  58. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  59. fprintf(1,'%s',perc); pperc=perc;
  60. end
  61. clearvars CSubSetMatc ZDistMatc
  62. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  63. fprintf('Done.\n');
  64. save(coarseFileName,'CWidxc','sZSetMatc');
  65. end
  66. %% FINE PRODUCT QUANTIZATION SECTION
% Fine PQ (mmf subspaces, kkf centroids each), same structure as above but
% keeping the full query-to-centroid distance tables (ZDistMatf) for ADC.
  67. fineFileName=sprintf('./Fine Quantization Indices and Distances/fine_CWidx_ZDistMat_k%d_m%d.mat',kkf,mmf);
  68. if exist(fineFileName,'file')==2
  69. load(fineFileName);
  70. else
  71. str=sprintf('Fine PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmf,kkf);
  72. fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  73. CSubSetMatf=zeros(mDimf,kkf,mmf);
  74. CWidxf=zeros(Ntrain,mmf);
  75. ZDistMatf=zeros(Ntests,kkf,mmf);
  76. pperc=[];
  77. for jj=1:mmf
  78. [CSubSetMatf(:,:,jj), ~ ]=yael_kmeans(...
  79. single(bVSetMat((jj-1)*mDimf+1:jj*mDimf,:)),kkf, 'niter', 100, 'verbose', 0);
  80. ZDistMatf(:,:,jj)=EuclideanDistancesMat(bZSetMat((jj-1)*mDimf+1:jj*mDimf,:),CSubSetMatf(:,:,jj));
  81. [~,CWidxf(:,jj)]=Quantization(bWSetMat((jj-1)*mDimf+1:jj*mDimf,:),CSubSetMatf(:,:,jj));
  82. perc=sprintf('%2.2f%%',jj/mmf*100);
  83. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  84. fprintf(1,'%s',perc);pperc=perc;
  85. end
  86. clearvars CSubSetMatf bZSetMat
  87. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  88. fprintf('Done.\n');
  89. save(fineFileName,'CWidxf','ZDistMatf');
  90. end
% Raw descriptor matrices are no longer needed; only PQ codes/distances remain.
  91. clearvars bWSetMat bVSetMat
  92. %% PREPROCESSING DATA TO STORE IN WNNets SECTION
% Route each of the Ntrain base vectors to one of the qq WNNs according to
% its coarse PQ code.
  93. str=sprintf('Building Learning Data Sets ');
  94. fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  95. %permRand=randperm(Ntrain,qq);
  96. %% JUST FOR kkc=10 mmc=4
  97. %AllCentSet=[arrayfun(@(ll)ceil(ll/1000),1:10000);
  98. % repmat(arrayfun(@(ll)ceil(ll/100),1:1000),1,10);
  99. % repmat(arrayfun(@(ll)ceil(ll/10),1:100),1,100); repmat([1:10],1,1000)]';
  100. %% JUST FOR kkc=10 mmc=2
% Enumerates all kkc^mmc = 100 coarse-code combinations (rows), one per WNN.
% NOTE(review): hard-wired to kkc=10, mmc=2 (see the commented mmc=4 variant
% above); changing kkc/mmc requires editing this table by hand.
  101. AllCentSet=[arrayfun(@(ll)ceil(ll/10),1:100); repmat([1:10],1,10)]';
  102. learningNetsIdx=QuantHammingAllOfTheCanonicalIndices(CWidxc,AllCentSet);
  103. %learningNetsIdx=QuantHammingCanonicalIndices(CWidxc,permRand);
  104. LearningDistribution=hist(learningNetsIdx,1:qq);
  105. fprintf('Done.\n');
% Per-WNN buckets: Rpi holds the original base-vector indices assigned to
% each network, Rci their coarse codes (one column per stored vector).
% NOTE(review): "1:100" hard-codes qq=100 here — should be 1:qq to stay
% consistent with the parameters above.
  106. eWDiffFeatsSplittedRpi(1:100)=struct('RunPermIndices',[]);
  107. eWDiffFeatsSplittedRci(1:100)=struct('RunCentroidsIndices',[]);
  108. multiCount=ones(1,qq);
  109. str=sprintf('Preprocessing Data for Learning ');
  110. fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  111. pperc=[];
% multiCount(k) is the next free slot in bucket k; the struct arrays grow
% incrementally (no preallocation of the per-bucket arrays).
  112. for ii=1:Ntrain
  113. eWDiffFeatsSplittedRpi(learningNetsIdx(ii)).RunPermIndices(multiCount(learningNetsIdx(ii)))=ii;
  114. eWDiffFeatsSplittedRci(learningNetsIdx(ii)).RunCentroidsIndices(:,multiCount(learningNetsIdx(ii)))=CWidxc(ii,:);
  115. multiCount(learningNetsIdx(ii))=multiCount(learningNetsIdx(ii))+1;
  116. perc=sprintf('%2.2f%%',ii/Ntrain*100);
  117. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  118. fprintf(1,'%s',perc); pperc=perc;
  119. end
  120. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  121. fprintf('Done.\n');
  122. clearvars eWSetMat permRand CWidxc %multiCount nStoredVec
  123. %% BUILDING WNNets AND SCORES COMPUTATION SECTION
% For each WNN jj: build its binary weight matrix from the unique coarse
% codes it stores, then score every query against it. sX0(i,jj) is query
% i's score for network jj; avgP0 tracks the mean fill ratio of the weight
% matrices (used by the cost model below).
  124. str=sprintf('Building %d Willshaw NNets and Computing Scores ',qq);
  125. fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  126. II=eye(kkc);
  127. P0=0;
  128. %Ntests=1;
  129. sX0=zeros(Ntests,qq);
  130. pperc=[];
  131. for jj=1:qq
  132. uniqCent=unique(eWDiffFeatsSplittedRci(jj).RunCentroidsIndices','rows');
  133. if ~isempty(uniqCent)
  134. WXsmatrix=BuildNetworkfromIdxs(uniqCent',II);
  135. P0=P0+mean(mean(WXsmatrix));
% Quadratic form s_i = z_i' * W * z_i for every query column z_i.
% NOTE(review): diag(A'*W*A) materializes the full Ntests x Ntests product
% just to read its diagonal; sum((WXsmatrix*sZSetMatc).*sZSetMatc,1)' would
% give the same result in O(Ntests) memory.
  136. sX0(:,jj)=diag(sZSetMatc(:,1:Ntests)'*double(WXsmatrix)*sZSetMatc(:,1:Ntests));
  137. end
  138. perc=sprintf('%2.2f%%',jj/qq*100);
  139. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  140. fprintf(1,'%s',perc);pperc=perc;
  141. end
  142. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  143. fprintf('Done.\n');
  144. avgP0=P0/qq;
  145. clearvars sZSetMatc eWDiffFeatsSplittedRsm WXsmatrix eWDiffFeatsSplittedRci uniqCent
  146. %% ADC COMPUTATION FOR PERFORMANCES COMPUTATION SECTION
% For each query, visit the Npeaks(tt) highest-scoring WNNs, gather the base
% vectors they store, re-rank them with fine-PQ asymmetric distances, and
% accumulate recall@R against the exhaustive ground truth.
  147. str=sprintf('Computing Performances: L=[%d:%d], recalls@[%d:%d] ',Npeaks(1),Npeaks(end),recAtR(1),recAtR(end));
  148. fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  149. NNeighPerf=zeros(length(recAtR),length(Npeaks));
  150. iXMinDistADC=zeros(1,nRecall);
  151. actIXMinDistADC=zeros(1,nRecall);
% NNmostLikely(ll,jj): the ll-th best WNN for query jj, by descending score.
  152. [~,sortemp]=sort(sX0,2,'descend');
  153. NNmostLikely=sortemp(:,1:Npeaks(end))';
  154. clearvars sortemp
  155. pperc=[];
  156. for jj=1:Ntests
  157. for tt=1:length(Npeaks)
% Incremental candidate set: for tt>1 start from the previous round's top
% nRecall survivors and add only the newly visited networks' vectors.
  158. if tt>1
  159. retIdxs=actIXMinDistADC;
  160. for ll=Npeaks(tt-1)+1:Npeaks(tt)
  161. retIdxs=[retIdxs, eWDiffFeatsSplittedRpi(NNmostLikely(ll,jj)).RunPermIndices];
  162. end
% usedVec(jj,tt): number of NEW candidates fetched at this operating point
% (feeds the cost model). NOTE(review): usedVec is never preallocated.
  163. usedVec(jj,tt)=length(retIdxs)-length(actIXMinDistADC);
  164. else
  165. retIdxs=[];
  166. for ll=1:Npeaks(tt)
  167. retIdxs=[retIdxs, eWDiffFeatsSplittedRpi(NNmostLikely(ll,jj)).RunPermIndices];
  168. end
  169. usedVec(jj,1)=length(retIdxs);
  170. end
% ADC re-ranking: returns the positions (within retIdxs) of the nRecall
% nearest candidates under the fine-PQ distance tables.
  171. iXMinDistADC=AsymmetricDistanceComputationWDist(CWidxf(retIdxs,:)', ZDistMatf(jj,:,:), nRecall);
  172. numRetIdxs=length(retIdxs);
  173. if numRetIdxs >= nRecall
  174. actIXMinDistADC=retIdxs(iXMinDistADC);
  175. else
% Fewer candidates than nRecall: pad the tail with index 1 (a dummy base
% vector) so the recall computation always sees nRecall entries.
  176. actIXMinDistADC(1:numRetIdxs)=retIdxs(iXMinDistADC(1:numRetIdxs));
  177. actIXMinDistADC(numRetIdxs+1:end)=ones(1,nRecall-numRetIdxs);
  178. end
  179. NNeighPerf(:,tt) = NNeighPerf(:,tt)+ RecallTest(actIXMinDistADC,nRecall,iMinDistExh(:,jj))';
  180. perc=sprintf('%2.2f%%',((jj-1)*length(Npeaks)+tt)/(Ntests*length(Npeaks))*100);
  181. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  182. fprintf(1,'%s',perc);pperc=perc;
  183. end
  184. end
% Average the accumulated recall counts over all queries.
  185. NNeighPerf=NNeighPerf/Ntests;
  186. fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  187. fprintf('Done.\n');
% NOTE(review): 'NNMostLikely' is a case mismatch — the variable defined at
% line 153 is 'NNmostLikely', so that (large) matrix is NOT actually cleared.
  188. clearvars SortQueryNetIdx actIXMinDistADC RetrievedIdxs bZSetMat CWidxf NNMostLikely eWDiffFeatsSplittedRpi sX0
  189. %% COMPUTATIONAL COST EVALUATION SECTION
% actNpeaks: mean (over queries) cumulative number of candidates examined,
% normalized by nn — the effective number of networks' worth of vectors read.
  190. actNpeaks=mean(cumsum(usedVec')'./nn);
% Cost model: fine + coarse centroid tables, WNN scoring (scaled by the
% average weight-matrix fill avgP0), plus ADC over the visited candidates.
  191. Cpxty=kkf*nDim+kkc*nDim+avgP0*(mmc*kkc)^2*qq+actNpeaks*nn*mmf;
  192. NormCpxty=Cpxty/(nDim*Ntrain);
% Baseline: plain (non-hierarchical) PQ search cost, same normalization.
  193. NormPQCpxty=(kkf*nDim+mmf*Ntrain)/(nDim*Ntrain);
  194. NormNormCpxty=NormCpxty/NormPQCpxty;
  195. %% SAVING RESULTS SECTION
  196. save(resultsFileName,'NormCpxty','NormPQCpxty','NNeighPerf','Npeaks','actNpeaks','avgP0');
  197. end
  198. %% RESULTS VISUALIZATION SECTION
% Recall@R vs. normalized cost, one curve per R in recAtR. The last
% operating point (all qq networks visited) is plotted at the plain-PQ
% cost NormPQCpxty instead of its own NormCpxty(end), and re-marked with
% '*' as the "PQ" baseline.
  199. figure();
  200. plot([NormCpxty(1:end-1) NormPQCpxty],NNeighPerf','x-'); hold on;
% Annotate each operating point with L = effective number of candidates read.
  201. for jj=1:length(actNpeaks)-1
  202. text(NormCpxty(jj),1.05,sprintf('L=%2.0f',actNpeaks(jj)*nn),'interpreter','latex');
  203. end
  204. plot(NormPQCpxty,NNeighPerf(:,end),'*');
  205. text(NormPQCpxty,1.05,'PQ','interpreter','latex');
  206. title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ + Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  207. ylabel('Nearest Neighbour Search Performances','interpreter','latex');
  208. xlabel('Computational Cost (Normalized to Exhaustive Search)','interpreter','latex');
% NOTE(review): legend entries are hard-coded for the 7 values of recAtR
% that survive the recAtR<=nRecall filter with nRecall=100; changing
% nRecall/recAtR silently desynchronizes this legend.
  209. h=legend('recall @1','recall @2','recall @5','recall @10','recall @20',...
  210. 'recall @50','recall @100','Location','southeast');
  211. grid on; grid minor;
  212. set(h,'interpreter','latex');
  213. ylim([0 1.1]);
  214. xlim([0 NormCpxty(end)+.005])