WNNetsOverPQ_Fine_Coarse.m 15 KB


  %function [NNeighPerf, NormCpxty, NormPQCpxty]= WNNetsOverPQ_Kmeans(kk,mm,nRecall,Npeaks)
  % Script: nearest-neighbour search experiment that combines a fine product
  % quantizer (ADC re-ranking) with a coarse product quantizer whose codes are
  % stored in Willshaw neural networks.
  % clear all
  % close all
  clc
  % Helper libraries. They are added one at a time, in this order, so the
  % search-path precedence is that of sequential addpath calls.
  wnnLibDirs={'./SIFTs Library','./PQ Library','./WNN Library','./PQ Library/Yael Library'};
  for iLibDir=1:numel(wnnLibDirs)
      addpath(wnnLibDirs{iLibDir});
  end
  clear wnnLibDirs iLibDir
  % Render all text objects (labels, titles) with the LaTeX interpreter by default.
  set(0,'DefaultTextInterpreter','latex');
  % ---- Quantities of interest for the input Data Set --------------
  nDim=128;        % dimensionality of every descriptor
  NLearn=1e5;      % number of descriptors in the learn set
  Ntrain=1e6;      % number of descriptors in the base (train) set
  Ntests=1e4;      % number of query descriptors
  % ---- Fine product quantizer ----------------------------------------
  kkf=256;         % k-means centroids per sub-quantizer
  mmf=16;          % number of sub-vectors (must divide nDim)
  mDimf=nDim/mmf;  % dimensionality of each fine sub-vector
  nRecall=100;     % length of the returned shortlist
  recAtR=[1 2 5 10 20 50 100 200 500 1000 2000 5000 10000];
  recAtR=recAtR(recAtR<=nRecall);   % recall@R values reachable with nRecall
  % ---- Coarse product quantizer --------------------------------------
  kkc=41;          % k-means centroids per sub-quantizer
  mmc=2;           % number of sub-vectors (must divide nDim)
  mDimc=nDim/mmc;  % dimensionality of each coarse sub-vector
  % ---- Willshaw networks ----------------------------------------------
  nn=1e2;          % average number of base vectors per network
  qq=Ntrain/nn;    % number of networks
  % Numbers L_0 of top-scoring networks to probe: 1, then qq over each divisor.
  Npeaks=round([1, qq./[1000 100 40 20 10 5 4 3 2 4/3 1]]);
  %Npeaks=round([1, qq./[2000 1000 200 100 20 10 5 4 3 2 4/3 1]]);
  %Npeaks=round([1, qq./[2000 1000 200 100 200/3 50 20 10 5]]);
  binaryLab=0;     % 1: one-hot query codes, 0: graded query similarities
  Zoomed=0;
  % Results file: the folder depends on binaryLab, the '_Zoomed' suffix on Zoomed.
  if binaryLab
      if Zoomed
          resultsFmt='./Binary Result Sets/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d_Zoomed.mat';
      else
          resultsFmt='./Binary Result Sets/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d.mat';
      end
  elseif Zoomed
      resultsFmt='./Result Sets/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d_Zoomed.mat';
  else
      resultsFmt='./Result Sets/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d.mat';
  end
  resultsFileName=sprintf(resultsFmt,mmf,kkf,nn,mmc,kkc);
  if 0 %exist(resultsFileName,'file')==2
      %% LOADING RESULTS SECTION
      % Reuse of saved results is disabled (condition hard-wired to 0);
      % restore the exist() test to load a previously saved results file.
      load(resultsFileName);
  else
      %% LOADING DATA SECTION
      % Descriptor matrices hold one vector per column (rows are dimensions).
      bVSetMat=double(fvecs_read('../Data/sift/sift_learn.fvecs'));   % learn set (k-means)
      bWSetMat=double(fvecs_read('../Data/sift/sift_base.fvecs'));    % base set to index
      bZSetMat=double(fvecs_read('../Data/sift/sift_query.fvecs'));   % queries
      % Ground-truth neighbour ids; the +1 maps them to MATLAB 1-based indices
      % (assumes the file stores 0-based ids).
      iMinDistExh=ivecs_read('../Data/sift/sift_groundtruth.ivecs')+1;
  %% COARSE PRODUCT QUANTIZATION SECTION
  % The cache file name depends on which query representation is produced.
  if(binaryLab)
      coarseFileName=sprintf('./Coarse Quantization Indices and Dissimilarities/coarse_CWidx_CZidx_k%d_m%d.mat',kkc,mmc);
  else
      coarseFileName=sprintf('./Coarse Quantization Indices and Dissimilarities/coarse_CWidx_sZSetMat_k%d_m%d.mat',kkc,mmc);
  end
  if 0%exist(coarseFileName,'file')==2
      load(coarseFileName);
  else
      str=sprintf('Coarse PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmc,kkc);
      fprintf('%s%s ',str,repmat('.',1,55-length(str)));
      bsChar=sprintf('\b');
      CSubSetMatc=zeros(mDimc,kkc,mmc);   % sub-centroids, one page per sub-quantizer
      CWidxc=zeros(Ntrain,mmc);           % coarse code of every base vector
      CZidxc=zeros(Ntests,mmc);           % coarse code of every query (binary mode)
      sZSetMatc=zeros(kkc*mmc,Ntests);    % graded query similarities (non-binary mode)
      pperc=[];
      nsplits=10;                         % base set is quantized in nsplits chunks
      nnsp=Ntrain/nsplits;
      for jj=1:mmc
          subDims=(jj-1)*mDimc+1:jj*mDimc;    % dimensions of the jj-th sub-vector
          [CSubSetMatc(:,:,jj), ~ ]=yael_kmeans(single(bVSetMat(subDims,:)),kkc, 'niter', 100, 'verbose', 0);
          for rr=1:nsplits
              chunk=(rr-1)*nnsp+1:rr*nnsp;
              [~,CWidxc(chunk,jj)]=Quantization(bWSetMat(subDims,chunk),CSubSetMatc(:,:,jj));
          end
          if(binaryLab)
              [~,CZidxc(:,jj)]=Quantization(bZSetMat(subDims,:),CSubSetMatc(:,:,jj));
          else
              ZDistMatc=EuclideanDistancesMat(bZSetMat(subDims,:),CSubSetMatc(:,:,jj));
              %sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=exp(-sqrt(ZDistMatc)');
              %sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=exp(1-bsxfun(@rdivide, abs(ZDistMatc)', min(abs(ZDistMatc)')));
              % Similarity of each query to each centroid: the query's smallest
              % distance divided by its distance to that centroid (1 at the nearest).
              absDistT=abs(ZDistMatc)';
              sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=bsxfun(@rdivide, min(absDistT), absDistT);
          end
          perc=sprintf('%2.2f%%',jj/mmc*100);
          fprintf(1,'%s',repmat(bsChar,1,length(pperc)));
          fprintf(1,'%s',perc); pperc=perc;
      end
      %clearvars CSubSetMatc ZDistMatc
      fprintf(1,'%s',repmat(bsChar,1,length(pperc)));
      fprintf('Done.\n');
      if(binaryLab)
          save(coarseFileName,'CWidxc','CZidxc');
      else
          %save(coarseFileName,'CWidxc','sZSetMatc');
      end
  end
  %% FINE PRODUCT QUANTIZATION SECTION
  % Fine codes of the base set and the query-to-sub-centroid distance tables
  % are cached on disk and reloaded when the cache file exists.
  fineFileName=sprintf('./Fine Quantization Indices and Distances/fine_CWidx_ZDistMat_k%d_m%d.mat',kkf,mmf);
  if exist(fineFileName,'file')==2
      load(fineFileName);
  else
      str=sprintf('Fine PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmf,kkf);
      fprintf('%s%s ',str,repmat('.',1,55-length(str)));
      bsChar=sprintf('\b');
      CSubSetMatf=zeros(mDimf,kkf,mmf);   % sub-centroids, one page per sub-quantizer
      CWidxf=zeros(Ntrain,mmf);           % fine code of every base vector
      ZDistMatf=zeros(Ntests,kkf,mmf);    % query-to-sub-centroid distance tables
      pperc=[];
      for jj=1:mmf
          subDims=(jj-1)*mDimf+1:jj*mDimf;    % dimensions of the jj-th sub-vector
          [CSubSetMatf(:,:,jj), ~ ]=yael_kmeans(single(bVSetMat(subDims,:)),kkf, 'niter', 100, 'verbose', 0);
          ZDistMatf(:,:,jj)=EuclideanDistancesMat(bZSetMat(subDims,:),CSubSetMatf(:,:,jj));
          [~,CWidxf(:,jj)]=Quantization(bWSetMat(subDims,:),CSubSetMatf(:,:,jj));
          perc=sprintf('%2.2f%%',jj/mmf*100);
          fprintf(1,'%s',repmat(bsChar,1,length(pperc)));
          fprintf(1,'%s',perc); pperc=perc;
      end
      clearvars CSubSetMatf bZSetMat
      fprintf(1,'%s',repmat(bsChar,1,length(pperc)));
      fprintf('Done.\n');
      save(fineFileName,'CWidxf','ZDistMatf');
  end
  clearvars bWSetMat bVSetMat   % raw learn/base descriptors are no longer needed
  %% PREPROCESSING DATA TO STORE IN WNNets SECTION
  str=sprintf('Building Learning Data Sets ');
  fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  % Network (1..qq) assigned to every base vector, computed from its coarse code.
  % The _floor variant is computed for comparison only and is not used below.
  [minDh,learningNetsIdx]=QuantHammingCanonicalIndices_ceil(CWidxc,qq);
  [minDh2,learningNetsIdx2]=QuantHammingCanonicalIndices_floor(CWidxc,qq);
  % Number of base vectors routed to each network 1..qq.
  nStoredVec=hist(learningNetsIdx,1:qq);
  fprintf('Done.\n');
  % A stray top-level `break` used to sit here. MATLAB rejects `break` outside a
  % loop, and it made the rest of the pipeline unreachable. Use `return` here if
  % an early stop is needed while debugging.
  % Group base vectors by network. Entry l of eWDiffFeatsSplittedRpi lists, in
  % increasing order, the base indices routed to network l. The same entry of
  % eWDiffFeatsSplittedRci stores their coarse codes, one column per vector.
  eWDiffFeatsSplittedRpi(1:qq)=struct('RunPermIndices',[]);
  eWDiffFeatsSplittedRci(1:qq)=struct('RunCentroidsIndices',[]);
  multiCount=ones(1,qq);   % next free slot in each network
  str=sprintf('Preprocessing Data for Learning ');
  fprintf('%s%s ',str,repmat('.',1,55-length(str)));
  bsChar=sprintf('\b');
  pperc=[];
  for ii=1:Ntrain
      netIdx=learningNetsIdx(ii);
      slot=multiCount(netIdx);
      eWDiffFeatsSplittedRpi(netIdx).RunPermIndices(slot)=ii;
      eWDiffFeatsSplittedRci(netIdx).RunCentroidsIndices(:,slot)=CWidxc(ii,:);
      multiCount(netIdx)=slot+1;
      perc=sprintf('%2.2f%%',ii/Ntrain*100);
      fprintf(1,'%s',repmat(bsChar,1,length(pperc)));
      fprintf(1,'%s',perc); pperc=perc;
  end
  fprintf(1,'%s',repmat(bsChar,1,length(pperc)));
  fprintf('Done.\n');
  clearvars eWSetMat permRand CWidxc %multiCount nStoredVec
  %% BUILDING WNNets AND SCORES COMPUTATION SECTION
  % For every network: build its binary weight matrix from the distinct coarse
  % codes it stores, accumulate its mean density (P0), and score all queries with
  % the normalised quadratic form z'*W*z.
  str=sprintf('Building %d Willshaw NNets and Computing Scores ',qq);
  fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  II=eye(kkc);
  P0=0;
  %Ntests=1;
  sX0=zeros(Ntests,qq);
  pperc=[];
  zc=zeros(kkc*mmc,Ntests);
  if(binaryLab)
      % One-hot query codes: one kkc-long indicator block per sub-quantizer.
      for jj=1:mmc
          zc((jj-1)*kkc+1:jj*kkc,:)=II(CZidxc(1:Ntests,jj),:)';
      end
  end
  % The query matrix and the per-query normalisation do not depend on the
  % network, so they are computed once here instead of qq times.
  if(binaryLab)
      Zq=zc(:,1:Ntests);
      qNorm=1./(sum(Zq.^2))';
  else
      Zq=sZSetMatc(:,1:Ntests);
      % NOTE(review): square of the sum here vs sum of squares in the binary branch.
      % It is a per-query constant, so it does not change the network ranking.
      qNorm=1./(sum(Zq).^2)';
  end
  for jj=1:qq
      uniqCent=unique(eWDiffFeatsSplittedRci(jj).RunCentroidsIndices','rows');
      WXsmatrix=BuildNetworkfromIdxs(uniqCent',II);
      P0=P0+mean(mean(WXsmatrix));
      % Column-wise z'*W*z for all queries. This equals diag(Zq'*W*Zq) but avoids
      % forming the Ntests-by-Ntests product for every network.
      sX0(:,jj)=qNorm.*sum(Zq.*(double(WXsmatrix)*Zq),1)';
      perc=sprintf('%2.2f%%',jj/qq*100);
      fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
      fprintf(1,'%s',perc);pperc=perc;
  end
  fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  fprintf('Done.\n');
  avgP0=P0/qq;   % average weight-matrix density over all networks
  clearvars sZSetMatc eWDiffFeatsSplittedRsm WXsmatrix eWDiffFeatsSplittedRci uniqCent Zq qNorm
  %% ADC COMPUTATION FOR PERFORMANCES COMPUTATION SECTION
  str=sprintf('Computing Performances: L=[%d:%d], recalls@[%d:%d] ',Npeaks(1),Npeaks(end),recAtR(1),recAtR(end));
  fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
  NNeighPerf=zeros(length(recAtR),length(Npeaks));   % recall@R (rows) per L_0 (columns)
  iXMinDistADC=zeros(1,nRecall);
  actIXMinDistADC=zeros(1,nRecall);
  % Rank the networks of every query by score, once. The sorted scores/indices
  % are kept (ssxo, isxo). The top Npeaks(end) network indices, one column per
  % query, drive the candidate lists below. The former second, identical sort of
  % sX0 has been removed.
  [ssxo,isxo]=sort(sX0,2,'descend');
  NNmostLikely=isxo(:,1:Npeaks(end))';
  %% Score profile of one query (diagnostic plot)
  % For query itest: plot the scores of all networks in a random order, mark the
  % 10 highest-scoring networks and the network storing the exact nearest
  % neighbour of the query.
  close all;
  for itest=500
      % Stored-vector count per network and all stored base indices in network
      % order. These are not used below; they are kept for interactive inspection.
      % They are built in one pass instead of by concatenation inside a qq loop.
      lrp2=arrayfun(@(s) length(s.RunPermIndices), eWDiffFeatsSplittedRpi);
      rpixcs2=[eWDiffFeatsSplittedRpi.RunPermIndices];
      % Networks holding the exact nearest neighbour of query itest.
      exactNN=iMinDistExh(1,itest);
      exhNN=arrayfun(@(s) any(s.RunPermIndices == exactNN), eWDiffFeatsSplittedRpi);
      exhNNidx=find(exhNN==1);
      interl=randperm(qq);      % random display order of the networks
      revint(interl)=1:qq;      % inverse permutation: network -> display position
      figure()
      plot((sX0(itest,interl))); hold on;
      [~,indL0]=sort(sX0(itest,:),'descend');
      stem(revint(indL0(1:10)),(sX0(itest,indL0(1:10)))); hold on;
      stem(revint(exhNNidx)*[1 1],[0 sX0(itest,exhNNidx)]);
  end
  xlabel('Network Index $l=1,\dots,L$','interpreter','latex');
  ylabel('Scores $s(\mathbf{z}_0,\mathcal{Z}_l)$','interpreter','latex');
  h=legend('Score of the $l$-th Network','$L_0$ Highest Scores Networks','Score of the NN$(\mathbf{x}_0)$ Network');
  set(h,'interpreter','latex');
  %% Number of vectors stored in each network
  % Open a new figure. `hold on` is still active on the score plot above, which
  % would otherwise receive these curves and have its labels overwritten.
  figure();
  plot(nStoredVec); hold on;
  plot(mean(nStoredVec)*ones(size(nStoredVec)));   % average load per network
  title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ // Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  xlabel('Network Index $l=1,\dots,L$','interpreter','latex');
  ylabel('Number of vectors stored in each Neural Network $\mathcal{Z}_l$','interpreter','latex');
  clearvars sortemp
  % Recall evaluation, for each query jj and each L_0 = Npeaks(tt):
  %   - Rank with ADC on the fine codes the base vectors stored in the L_0
  %     top-scoring networks.
  %   - Compare the nRecall best against the exact neighbours.
  % Candidates are built incrementally: the shortlist of the previous L_0 is
  % merged with the vectors of the newly admitted networks.
  pperc=[];
  nValid=0;   % number of genuine (non-padding) entries at the front of actIXMinDistADC
  for jj=1:Ntests
      for tt=1:length(Npeaks)
          if tt>1
              % Reuse only the genuine part of the previous shortlist. The
              % padding entries (index 1) are placeholders, not retrieved vectors.
              retIdxs=actIXMinDistADC(1:nValid);
              firstNet=Npeaks(tt-1)+1;
          else
              retIdxs=[];
              firstNet=1;
          end
          for ll=firstNet:Npeaks(tt)
              retIdxs=[retIdxs, eWDiffFeatsSplittedRpi(NNmostLikely(ll,jj)).RunPermIndices];
          end
          numRetIdxs=length(retIdxs);
          if numRetIdxs==0
              % The probed networks are empty: there is nothing to rank.
              iXMinDistADC=[];
              actIXMinDistADC=ones(1,nRecall);
              nValid=0;
          else
              iXMinDistADC=AsymmetricDistanceComputationWDist(CWidxf(retIdxs,:)', ZDistMatf(jj,:,:), nRecall);
              if numRetIdxs >= nRecall
                  actIXMinDistADC=retIdxs(iXMinDistADC);
                  nValid=length(actIXMinDistADC);
              else
                  actIXMinDistADC(1:numRetIdxs)=retIdxs(iXMinDistADC(1:numRetIdxs));
                  % Pad the list to nRecall entries with index 1.
                  actIXMinDistADC(numRetIdxs+1:end)=ones(1,nRecall-numRetIdxs);
                  nValid=numRetIdxs;
              end
          end
          NNeighPerf(:,tt) = NNeighPerf(:,tt)+ RecallTest(actIXMinDistADC,nRecall,iMinDistExh(:,jj))';
          perc=sprintf('%2.2f%%',((jj-1)*length(Npeaks)+tt)/(Ntests*length(Npeaks))*100);
          fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
          fprintf(1,'%s',perc);pperc=perc;
      end
  end
  NNeighPerf=NNeighPerf/Ntests;   % average over queries
  fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
  fprintf('Done.\n');
  clearvars SortQueryNetIdx actIXMinDistADC RetrievedIdxs bZSetMat CWidxf NNmostLikely eWDiffFeatsSplittedRpi sX0
  %% COMPUTATIONAL COST EVALUATION SECTION
  % Cost model, one value per entry of Npeaks, made of three terms:
  %   - tableCost: fine and coarse distance tables.
  %   - netScoreCost: scoring of all qq networks, scaled by avgP0.
  %   - adcCost: Npeaks*nn*mmf lookups.
  tableCost=kkf*nDim+kkc*nDim;
  netScoreCost=avgP0*(mmc*kkc)^2*qq;
  adcCost=Npeaks*nn*mmf;
  Cpxty=tableCost+netScoreCost+adcCost;
  NormCpxty=Cpxty/(nDim*Ntrain);
  NormPQCpxty=(kkf*nDim+mmf*Ntrain)/(nDim*Ntrain);   % cost of a plain fine-PQ exhaustive ADC
  NormNormCpxty=NormCpxty/NormPQCpxty;
  %% SAVING RESULTS SECTION
  save(resultsFileName,'NormCpxty','NormPQCpxty','NNeighPerf','Npeaks');
  end
  %% RESULTS VISUALIZATION SECTION
  % Recall@R versus normalised cost: one curve per R in recAtR, one point per L_0.
  figure();
  plot(NormCpxty/NormPQCpxty,NNeighPerf','x-'); hold on;
  %for jj=1:length(Npeaks)-1
  %text(NormCpxty(jj)/NormPQCpxty,1.05,sprintf('L=%d',Npeaks(jj)),'interpreter','latex');
  %end
  %plot(NormPQCpxty/NormPQCpxty,NNeighPerf(:,end),'*');
  %text(NormPQCpxty/NormPQCpxty,1.05,'PQ','interpreter','latex');
  text(.1,1.05,'$L_0=\frac{L}{10}$','interpreter','latex');
  text(.9,1.05,'$L_0=L$','interpreter','latex');
  %title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ + Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  ylabel('NN Search Performances','interpreter','latex');
  xlabel('Computational Cost (Normalized to Fine PQ)','interpreter','latex');
  % Legend entries follow recAtR, so they match the plotted curves for any nRecall.
  recallLabels=arrayfun(@(r) sprintf('recall @%d',r),recAtR,'UniformOutput',false);
  h=legend(recallLabels{:},'Location','southeast');
  grid on; grid minor;
  set(h,'interpreter','latex');
  ylim([0 1.12]);
  xlim([0 1.05]);%NormCpxty(end)+.005])
  % Stop here; the section below is run on demand. `return` replaces a top-level
  % `break`, which MATLAB rejects outside a loop.
  return;
  %% Lower RESULTS VISUALIZATION SECTION
  % Low-cost region of the recall/cost curves, highlighting selected L_0 values.
  MostRelPeaks=[1 10 100 250 500 1000 2000];
  % Position of each highlighted L_0 in Npeaks. Values absent from Npeaks are
  % dropped; find() inside arrayfun used to error on them.
  [~,mri]=ismember(MostRelPeaks,Npeaks);
  mri=mri(mri>0);
  figure();
  plot((NormCpxty(mri)/NormPQCpxty),NNeighPerf(:,mri)','x'); hold on;
  plot((NormCpxty(1:mri(end))/NormPQCpxty),NNeighPerf(:,1:mri(end))'); hold on;
  for jj=mri
      % Label each point as L_0 = L/d with d = qq/L_0, matching how Npeaks is
      % defined. The old format '{d}' lacked the '%', so no value was printed.
      text(NormCpxty(jj)/NormPQCpxty+0.002,NNeighPerf(1,jj)-0.001,sprintf('$L_0=\\frac{L}{%g}$',qq/Npeaks(jj)),'interpreter','latex');
  end
  %plot((NormPQCpxty),NNClassPerf(:,end),'*');
  %text((NormPQCpxty),1.005,'PQ','interpreter','latex');
  %title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ + Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  ylabel('Nearest Neighbour Search Performances','interpreter','latex');
  xlabel('Computational Cost (Normalized to Fine PQ)','interpreter','latex');
  % Legend entries follow recAtR, so they match the plotted curves for any nRecall.
  recallLabels=arrayfun(@(r) sprintf('Recall@%d',r),recAtR,'UniformOutput',false);
  h=legend(recallLabels{:},'Location','southeast');
  grid on; grid minor;
  set(h,'interpreter','latex');
  %ylim([0.84 1.002]);
  %xlim([NormCpxty(1)-0.0005 NormPQCpxty(end)+.0005])
  ylim([0 1.04]);
  xlim([0 .21]);