% File: WNNetsOverPQ_Fine_Coarse.m (16 KB)

  %function [NNeighPerf, NormCpxty, NormPQCpxty]= WNNetsOverPQ_Kmeans(kk,mm,nRecall,Npeaks)
  % CONFIGURATION SECTION
  % Sets up the library paths, the data-set sizes, the fine/coarse product
  % quantization parameters, the Willshaw-network parameters and the name of
  % the .mat file in which the final results are cached.
  % clear all
  % close all
  clc
  addpath('./PQ Library');
  addpath('./WNN Library');
  addpath('./PQ Library/Yael Library');
  set(0,'defaulttextinterpreter','latex');
  % ---- Quantities of interest for the input Data Set --------------
  nDim=784;                     % Dimensionality of the Data
  NLearn=6e4;                   % Number of descriptors learn set
  Ntrain=6e4;                   % Number of descriptors train set
  Ntests=1e4;                   % Number of descriptors test set
  % ---- Quantities of interest for Fine Product Quantization -------
  kkf=256;                      % Number of clusters kmeans
  mmf=16;                       % Number of splits for PQ (must divide nDim)
  nRecall=100;
  recAtR=[1 2 5 10 20 50 100 200 500 1000 2000 5000 10000];
  recAtR=recAtR(recAtR<=nRecall);
  % ---- Quantities of interest for Coarse Product Quantization -----
  kkc=41;                       % Number of clusters kmeans
  mmc=2;                        % Number of splits for PQ (must divide nDim)
  % ---- Quantities of interest for Willshaw Networks ---------------
  nn=10;                        % Number of vectors per WNN
  % Fail early: the quantities below are used as array sizes and index
  % ranges, so a fractional value would only surface later as an indexing
  % error in the middle of a long computation.
  if mod(nDim,mmf)~=0 || mod(nDim,mmc)~=0
    error('WNNetsOverPQ:badSplit', ...
      'nDim=%d must be divisible by both mmf=%d and mmc=%d.',nDim,mmf,mmc);
  end
  if mod(Ntrain,nn)~=0
    error('WNNetsOverPQ:badNetSize', ...
      'Ntrain=%d must be divisible by nn=%d.',Ntrain,nn);
  end
  mDimf=nDim/mmf;               % Dimensionality of splitted vectors (fine)
  mDimc=nDim/mmc;               % Dimensionality of splitted vectors (coarse)
  qq=Ntrain/nn;                 % Number of different WNNs
  %Npeaks=round([1, qq./[10 5 4 3 2 4/3 1]]); % Number of peaks to select
  Npeaks=round([1 qq./[1000 600 300 200 100 20 10 5 4 3 2 4/3 1]]);
  %Npeaks=round([1, qq./[2000 1000 200 100 200/3 50 20 10 5]]);
  %Npeaks=round([1 2 3 4 5 6 10 qq./[ 20 10 5 4 3 2 4/3 1]]);
  binaryLab=0;
  Zoomed=0;
  % The cache file name encodes every parameter; binaryLab selects the
  % folder and Zoomed appends a suffix.
  if binaryLab
    resultsDirName='./Binary Result Sets';
  else
    resultsDirName='./Result Sets';
  end
  if Zoomed
    zoomSuffix='_Zoomed';
  else
    zoomSuffix='';
  end
  resultsFileName=sprintf('%s/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d%s.mat', ...
    resultsDirName,mmf,kkf,nn,mmc,kkc,zoomSuffix);
  clear resultsDirName zoomSuffix
  % If a results file for this configuration already exists it is loaded and
  % the whole computation is skipped; otherwise everything up to the
  % matching `end` (just before the visualization sections) is executed.
  if exist(resultsFileName,'file')==2
  %% LOADING RESULTS SECTION
  load(resultsFileName);
  else
  %% LOADING DATA SECTION
  % Each row: file path, number of header bytes to discard, number of
  % payload bytes expected. Every file is read as raw uint8 values.
  dataFileSpecs={ ...
    '../Data/train-images.idx3-ubyte', 16, nDim*Ntrain; ...
    '../Data/t10k-images.idx3-ubyte',  16, nDim*Ntests; ...
    '../Data/train-labels.idx1-ubyte',  8, Ntrain; ...
    '../Data/t10k-labels.idx1-ubyte',   8, Ntests};
  dataPayloads=cell(size(dataFileSpecs,1),1);
  for ff=1:size(dataFileSpecs,1)
    dataPath=dataFileSpecs{ff,1};
    headerBytes=dataFileSpecs{ff,2};
    payloadBytes=dataFileSpecs{ff,3};
    fileID=fopen(dataPath,'r');
    if fileID<0
      % Without this check fread would fail with an unrelated message.
      error('WNNetsOverPQ:cannotOpen','Cannot open data file "%s".',dataPath);
    end
    rawBytes=fread(fileID,headerBytes+payloadBytes,'uint8');
    fclose(fileID);
    if numel(rawBytes)~=headerBytes+payloadBytes
      % A truncated file would otherwise only show up as a reshape error.
      error('WNNetsOverPQ:shortRead', ...
        'File "%s": expected %d bytes, read %d.', ...
        dataPath,headerBytes+payloadBytes,numel(rawBytes));
    end
    dataPayloads{ff}=rawBytes(headerBytes+1:end);
  end
  bWSetMat=reshape(dataPayloads{1},nDim,Ntrain);   % one training vector per column
  bZSetMat=reshape(dataPayloads{2},nDim,Ntests);   % one test vector per column
  bWLab=dataPayloads{3};                           % training labels (column vector)
  bZLab=dataPayloads{4};                           % test labels (column vector)
  clear dataFileSpecs dataPayloads dataPath headerBytes payloadBytes rawBytes fileID ff
  %imshow(reshape(bWSetMat(:,1),sqrt(nDim),sqrt(nDim))')
  %bVSetMat=double(fvecs_read('../Data/sift/sift_learn.fvecs'));
  %bWSetMat=double(fvecs_read('../Data/sift/sift_base.fvecs'));
  %bZSetMat=double(fvecs_read('../Data/sift/sift_query.fvecs'));
  %iMinDistExh=ivecs_read('../Data/sift/sift_groundtruth.ivecs')+1;
  %% COARSE PRODUCT QUANTIZATION SECTION
  % Produces, for each of the mmc sub-spaces, the coarse centroid index of
  % every training vector (CWidxc) and either the centroid index of every
  % test vector (CZidxc, binaryLab) or a soft similarity of every test vector
  % to every centroid (sZSetMatc). Results are cached in coarseFileName.
  if(binaryLab)
    coarseFileName=sprintf('./Coarse Quantization Indices and Dissimilarities/coarse_CWidx_CZidx_k%d_m%d.mat',kkc,mmc);
  else
    coarseFileName=sprintf('./Coarse Quantization Indices and Dissimilarities/coarse_CWidx_sZSetMat_k%d_m%d.mat',kkc,mmc);
  end
  if exist(coarseFileName,'file')==2
    load(coarseFileName);
  else
    str=sprintf('Coarse PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmc,kkc);
    fprintf('%s%s ',str,repmat('.',1,max(0,55-length(str))));
    CSubSetMatc=zeros(mDimc,kkc,mmc);
    CWidxc=zeros(Ntrain,mmc);
    CZidxc=zeros(Ntests,mmc);
    sZSetMatc=zeros(kkc*mmc,Ntests);
    pperc=[];
    nsplits=10;                 % training set is quantized in nsplits chunks
    if mod(Ntrain,nsplits)~=0
      error('WNNetsOverPQ:badChunking', ...
        'Ntrain=%d must be divisible by nsplits=%d.',Ntrain,nsplits);
    end
    nnsp=Ntrain/nsplits;
    for jj=1:mmc
      dimRange=(jj-1)*mDimc+1:jj*mDimc;       % rows of the jj-th sub-space
      [CSubSetMatc(:,:,jj), ~ ]=yael_kmeans(...
        single(bWSetMat(dimRange,:)),kkc, 'niter', 100, 'verbose', 0);
      for rr=1:nsplits
        chunkRange=(rr-1)*nnsp+1:rr*nnsp;     % training vectors of this chunk
        [~,CWidxc(chunkRange,jj)]=Quantization(...
          bWSetMat(dimRange,chunkRange),...
          CSubSetMatc(:,:,jj));
      end
      if(binaryLab)
        [~,CZidxc(:,jj)]=Quantization(...
          bZSetMat(dimRange,:),...
          CSubSetMatc(:,:,jj));
      else
        ZDistMatc=EuclideanDistancesMat(bZSetMat(dimRange,:),CSubSetMatc(:,:,jj));
        %sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=exp(-sqrt(ZDistMatc)');
        %sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=exp(1-bsxfun(@rdivide, abs(ZDistMatc)', min(abs(ZDistMatc)')));
        % Similarity = (smallest distance of the test vector) ./ (distance
        % to each centroid). The minimum is taken explicitly along
        % dimension 1 so that kkc==1 does not collapse it to a scalar.
        % NOTE(review): a test vector at zero distance from a centroid
        % yields 0/0 = NaN here, as in the original expression.
        absDistT=abs(ZDistMatc)';
        sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=bsxfun(@rdivide, min(absDistT,[],1), absDistT);
      end
      perc=sprintf('%2.2f%%',jj/mmc*100);
      fprintf(1,'%s',repmat(char(8),1,length(pperc)));   % erase previous percentage
      fprintf(1,'%s',perc); pperc=perc;
    end
    %clearvars CSubSetMatc ZDistMatc
    fprintf(1,'%s',repmat(char(8),1,length(pperc)));
    fprintf('Done.\n');
    % Create the cache folder if needed so the k-means work is not lost to
    % a failing save.
    coarseDirName=fileparts(coarseFileName);
    if ~isempty(coarseDirName) && exist(coarseDirName,'dir')~=7
      mkdir(coarseDirName);
    end
    if(binaryLab)
      save(coarseFileName,'CWidxc','CZidxc');
    else
      save(coarseFileName,'CWidxc','sZSetMatc');
    end
    clear coarseDirName dimRange chunkRange absDistT
  end
  %% FINE PRODUCT QUANTIZATION SECTION
  % Produces, for each of the mmf sub-spaces, the fine centroid index of
  % every training vector (CWidxf) and the distances between every test
  % vector and every fine centroid (ZDistMatf). Results are cached in
  % fineFileName.
  fineFileName=sprintf('./Fine Quantization Indices and Distances/fine_CWidx_ZDistMat_k%d_m%d.mat',kkf,mmf);
  if exist(fineFileName,'file')==2
    load(fineFileName);
  else
    str=sprintf('Fine PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmf,kkf);
    fprintf('%s%s ',str,repmat('.',1,max(0,55-length(str))));
    CSubSetMatf=zeros(mDimf,kkf,mmf);
    CWidxf=zeros(Ntrain,mmf);
    ZDistMatf=zeros(Ntests,kkf,mmf);
    pperc=[];
    for jj=1:mmf
      dimRange=(jj-1)*mDimf+1:jj*mDimf;       % rows of the jj-th sub-space
      trainSub=bWSetMat(dimRange,:);          % sliced once, used twice below
      [CSubSetMatf(:,:,jj), ~ ]=yael_kmeans(...
        single(trainSub),kkf, 'niter', 100, 'verbose', 0);
      ZDistMatf(:,:,jj)=EuclideanDistancesMat(bZSetMat(dimRange,:),CSubSetMatf(:,:,jj));
      [~,CWidxf(:,jj)]=Quantization(trainSub,CSubSetMatf(:,:,jj));
      perc=sprintf('%2.2f%%',jj/mmf*100);
      fprintf(1,'%s',repmat(char(8),1,length(pperc)));   % erase previous percentage
      fprintf(1,'%s',perc);pperc=perc;
    end
    clearvars CSubSetMatf bZSetMat trainSub dimRange
    fprintf(1,'%s',repmat(char(8),1,length(pperc)));
    fprintf('Done.\n');
    % Create the cache folder if needed so the k-means work is not lost to
    % a failing save.
    fineDirName=fileparts(fineFileName);
    if ~isempty(fineDirName) && exist(fineDirName,'dir')~=7
      mkdir(fineDirName);
    end
    save(fineFileName,'CWidxf','ZDistMatf');
    clear fineDirName
  end
  clearvars bWSetMat bVSetMat
  %% PREPROCESSING DATA TO STORE IN WNNets SECTION
  % Assigns every training vector to one of the qq networks and collects,
  % per network, the training indices (RunPermIndices) and the coarse
  % centroid indices (RunCentroidsIndices, one column per stored vector).
  str=sprintf('Building Learning Data Sets ');
  fprintf('%s%s ',str,repmat('.',1,55-length(str)));
  [minDh,learningNetsIdx]=QuantHammingCanonicalIndices_ceil(CWidxc,qq);
  nStoredVec=hist(learningNetsIdx,1:qq);      % vectors stored per network
  fprintf('Done.\n');
  %break;
  eWDiffFeatsSplittedRpi(1:qq)=struct('RunPermIndices',[]);
  eWDiffFeatsSplittedRci(1:qq)=struct('RunCentroidsIndices',[]);
  multiCount=ones(1,qq);                      % next free slot in each network
  str=sprintf('Preprocessing Data for Learning ');
  fprintf('%s%s ',str,repmat('.',1,55-length(str)));
  pperc=[];
  ii=1;
  while ii<=Ntrain
    netId=learningNetsIdx(ii);
    slot=multiCount(netId);
    eWDiffFeatsSplittedRpi(netId).RunPermIndices(slot)=ii;
    eWDiffFeatsSplittedRci(netId).RunCentroidsIndices(:,slot)=CWidxc(ii,:);
    multiCount(netId)=slot+1;
    perc=sprintf('%2.2f%%',ii/Ntrain*100);
    fprintf(1,'%s',repmat(char(8),1,length(pperc)));     % erase previous percentage
    fprintf(1,'%s',perc);
    pperc=perc;
    ii=ii+1;
  end
  ii=Ntrain;                                  % same final value a for-loop would leave
  fprintf(1,'%s',repmat(char(8),1,length(pperc)));
  fprintf('Done.\n');
  clearvars eWSetMat permRand CWidxc netId slot %multiCount nStoredVec
  %% BUILDING WNNets AND SCORES COMPUTATION SECTION
  % For every network jj, builds its connection matrix from the distinct
  % coarse index tuples stored in it and computes, for every test vector z,
  % the normalised quadratic form z'*W*z. Scores go in sX0 (Ntests x qq);
  % avgP0 is the mean of mean(mean(W)) over all networks.
  str=sprintf('Building %d Willshaw NNets and Computing Scores ',qq);
  fprintf('%s%s ',str,repmat('.',1,max(0,55-length(str))));
  II=eye(kkc);
  P0=0;
  %Ntests=1;
  sX0=zeros(Ntests,qq);
  pperc=[];
  zc=zeros(kkc*mmc,Ntests);
  if(binaryLab)
    for jj=1:mmc
      % Rows of the identity matrix selected by centroid index, stacked
      % per sub-space.
      zc((jj-1)*kkc+1:jj*kkc,:)=II(CZidxc(1:Ntests,jj),:)';
    end
  end
  % The query matrix and its normalisation do not depend on the network, so
  % they are computed once instead of once per network.
  % NOTE(review): the two branches normalise differently (sum of squares
  % vs. squared sum). This is kept exactly as in the original; confirm
  % that the difference is intended.
  if(binaryLab)
    queryMat=zc(:,1:Ntests);
    scoreNorm=1./(sum(queryMat.^2,1)');
  else
    queryMat=sZSetMatc(:,1:Ntests);
    scoreNorm=1./((sum(queryMat,1).^2)');
  end
  for jj=1:qq
    uniqCent=unique(eWDiffFeatsSplittedRci(jj).RunCentroidsIndices','rows');
    WXsmatrix=BuildNetworkfromIdxs(uniqCent',II);
    P0=P0+mean(mean(WXsmatrix));
    % z_i'*W*z_i for every column z_i at once. This equals
    % diag(queryMat'*W*queryMat) without forming the Ntests x Ntests
    % product whose off-diagonal entries were discarded.
    sX0(:,jj)=scoreNorm.*(sum(queryMat.*(double(WXsmatrix)*queryMat),1)');
    perc=sprintf('%2.2f%%',jj/qq*100);
    fprintf(1,'%s',repmat(char(8),1,length(pperc)));     % erase previous percentage
    fprintf(1,'%s',perc);pperc=perc;
  end
  fprintf(1,'%s',repmat(char(8),1,length(pperc)));
  fprintf('Done.\n');
  avgP0=P0/qq;
  clearvars sZSetMatc eWDiffFeatsSplittedRsm WXsmatrix eWDiffFeatsSplittedRci uniqCent queryMat scoreNorm
  %% ADC COMPUTATION FOR PERFORMANCES COMPUTATION SECTION
  % For every test vector and every value L0 in Npeaks, gathers the training
  % vectors stored in the L0 best-scoring networks, ranks them with
  % AsymmetricDistanceComputationWDist and counts, for every R in recAtR,
  % whether any of the R best candidates carries the test vector's label.
  % NNClassPerf(rr,tt) ends up as the fraction of test vectors for which
  % that happened.
  str=sprintf('Computing Performances: L=[%d:%d], recalls@[%d:%d] ',Npeaks(1),Npeaks(end),recAtR(1),recAtR(end));
  fprintf('%s%s ',str,repmat('.',1,max(0,55-length(str))));
  NNClassPerf=zeros(length(recAtR),length(Npeaks));
  % One sort is enough: the original sorted sX0 twice with identical
  % arguments. ssxo/isxo are kept because they were left in the workspace.
  [ssxo,isxo]=sort(sX0,2,'descend');
  NNmostLikely=isxo(:,1:Npeaks(end))';        % column jj: networks ranked for test jj
  %%
  % close all;
  % for itest=500
  %
  % lrp2=0;
  % rpixcs2=[];
  % for jj=1:qq
  % lrp2(jj)=length(eWDiffFeatsSplittedRpi(jj).RunPermIndices);
  % rpixcs2=[ rpixcs2 eWDiffFeatsSplittedRpi(jj).RunPermIndices];
  % end
  %
  % %for jj=1:qq
  % % exhNN(jj)=any(eWDiffFeatsSplittedRpi(jj).RunPermIndices == iMinDistExh(1,itest));
  % %end
  % %exhNNidx=find(exhNN==1);
  %
  % interl=randperm(qq);
  % revint(interl)=1:qq;
  %
  % figure()
  % plot((sX0(itest,interl))); hold on;
  %
  % [~,indL0]=sort(sX0(itest,:),'descend');
  % stem(revint(indL0(1:10)),(sX0(itest,indL0(1:10)))); hold on;
  %
  % stem(revint(exhNNidx)*[1 1],[0 sX0(itest,exhNNidx)]);
  % end
  % xlabel('Network Index $l=1,\dots,L$','interpreter','latex');
  % ylabel('Scores $s(\mathbf{z}_0,\mathcal{Z}_l)$','interpreter','latex');
  % h=legend('Score of the $l$-th Network','$L_0$ Highest Scores Networks','Score of the NN$(\mathbf{x}_0)$ Network');
  % set(h,'interpreter','latex');
  %%
  plot(nStoredVec); hold on;
  plot(mean(nStoredVec)*ones(size(nStoredVec)));
  title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ // Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  xlabel('Network Index $l=1,\dots,L$','interpreter','latex');
  ylabel('Number of vectors stored in each Neural Network $\mathcal{Z}_l$','interpreter','latex');
  pperc=[];
  for jj=1:Ntests
    carriedIdxs=zeros(1,0);     % best candidates kept from the previous L0
    firstNet=1;                 % first network rank not yet visited
    for tt=1:length(Npeaks)
      % Only the networks added since the previous L0 are visited; their
      % stored indices are appended to the candidates already retained.
      newNets=NNmostLikely(firstNet:Npeaks(tt),jj);
      firstNet=Npeaks(tt)+1;
      retIdxs=[carriedIdxs, eWDiffFeatsSplittedRpi(newNets).RunPermIndices];
      numRetIdxs=length(retIdxs);
      % Number of genuine candidates available, capped at nRecall.
      nValid=min(numRetIdxs,nRecall);
      if nValid>0
        iXMinDistADC=AsymmetricDistanceComputationWDist(CWidxf(retIdxs,:)', ZDistMatf(jj,:,:), nRecall);
        actIXMinDistADC=reshape(retIdxs(iXMinDistADC(1:nValid)),1,[]);
      else
        actIXMinDistADC=zeros(1,0);
      end
      % The original padded short lists with training index 1. That
      % placeholder was then scored as a real neighbour and also fed into
      % the next candidate set. Only genuine candidates are kept here.
      carriedIdxs=actIXMinDistADC;
      for rr=1:length(recAtR)
        topR=actIXMinDistADC(1:min(recAtR(rr),nValid));
        isHit=~isempty(topR) && any(bWLab(topR(:))==bZLab(jj));
        NNClassPerf(rr,tt)=NNClassPerf(rr,tt)+isHit;
      end
      %NNClassPerf(:,tt) = NNClassPerf(:,tt)+ RecallTest(actIXMinDistADC,nRecall,iMinDistExh(:,jj))';
      perc=sprintf('%2.2f%%',((jj-1)*length(Npeaks)+tt)/(Ntests*length(Npeaks))*100);
      fprintf(1,'%s',repmat(char(8),1,length(pperc)));   % erase previous percentage
      fprintf(1,'%s',perc);pperc=perc;
    end
  end
  NNClassPerf=NNClassPerf/Ntests;
  fprintf(1,'%s',repmat(char(8),1,length(pperc)));
  fprintf('Done.\n');
  % NNmostLikely was misspelled as NNMostLikely in the original clear list,
  % so the (large) matrix was never released.
  clearvars SortQueryNetIdx actIXMinDistADC RetrievedIdxs bZSetMat CWidxf NNmostLikely eWDiffFeatsSplittedRpi sX0
  clearvars carriedIdxs firstNet newNets nValid topR isHit
  %% COMPUTATIONAL COST EVALUATION SECTION
  % Cpxty is a row vector with one entry per value in Npeaks; only its last
  % term depends on Npeaks. It is normalised first by nDim*Ntrain and then
  % by the equally normalised cost expression of the fine PQ alone.
  Cpxty=kkf*nDim+kkc*nDim+avgP0*(mmc*kkc)^2*qq+Npeaks*nn*mmf;
  NormCpxty=Cpxty/(nDim*Ntrain);
  NormPQCpxty=(kkf*nDim+mmf*Ntrain)/(nDim*Ntrain);
  NormNormCpxty=NormCpxty/NormPQCpxty;
  %% SAVING RESULTS SECTION
  % Create the results folder if needed so the whole computation is not
  % lost to a failing save.
  resultsDirName=fileparts(resultsFileName);
  if ~isempty(resultsDirName) && exist(resultsDirName,'dir')~=7
    mkdir(resultsDirName);
  end
  save(resultsFileName,'NormCpxty','NormPQCpxty','NNClassPerf','Npeaks');
  clear resultsDirName
  end   % closes "if exist(resultsFileName,'file')==2 ... else"
  %% RESULTS VISUALIZATION SECTION
  % Plots one curve per row of NNClassPerf against the cost normalised to
  % the fine PQ.
  figure();
  plot((NormCpxty(1:end)/NormPQCpxty),NNClassPerf(:,1:end)','x-'); hold on;
  %for jj=1:length(Npeaks)-1
  %text(NormCpxty(jj)/NormPQCpxty,1.005,sprintf('L=%d',Npeaks(jj)),'interpreter','latex');
  %end
  %plot((NormPQCpxty),NNClassPerf(:,end),'*');
  %text((NormPQCpxty),1.005,'PQ','interpreter','latex');
  title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ + Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  ylabel('Nearest Neighbour Search Performances','interpreter','latex');
  xlabel('Computational Cost (Normalized to Exhaustive Search)','interpreter','latex');
  % Legend entries follow recAtR instead of a fixed list of seven, so the
  % labels stay correct when nRecall changes. With nRecall=100 they are the
  % same seven strings as before.
  nCurves=min(size(NNClassPerf,1),numel(recAtR));
  legendLabels=arrayfun(@(r) sprintf('recall @%d',r),recAtR(1:nCurves),'UniformOutput',false);
  h=legend(legendLabels{:},'Location','southeast');
  grid on; grid minor;
  set(h,'interpreter','latex');
  %ylim([0.84 1.002]);
  %xlim([NormCpxty(1)-0.0005 NormPQCpxty(end)+.0005])
  ylim([0.84 1.01]);
  xlim([0.21 1.1]);
  clear nCurves legendLabels
  %% Lower RESULTS VISUALIZATION SECTION
  % Zoomed plot: markers at the operating points listed in MostRelPeaks,
  % lines through every point up to the last of them, and a text label with
  % the L0 value next to each marked point.
  MostRelPeaks=[1 6 60 300 600];
  % ismember tolerates values that are missing from, or repeated in,
  % Npeaks; arrayfun(find(...)) raised an error in both cases.
  [isListed,mri]=ismember(MostRelPeaks,Npeaks);
  mri=mri(isListed);
  if isempty(mri)
    warning('WNNetsOverPQ:noRelevantPeaks', ...
      'None of MostRelPeaks is present in Npeaks; marking every point.');
    mri=1:length(Npeaks);
  end
  figure();
  plot((NormCpxty(mri)/NormPQCpxty),NNClassPerf(:,mri)','x'); hold on;
  plot((NormCpxty(1:mri(end))/NormPQCpxty),NNClassPerf(:,1:mri(end))'); hold on;
  for jj=mri
    % The original format string '$L_0=\\frac{d}{d}$' had no conversion
    % specifier, so every label was identical and Npeaks(jj) was unused.
    text(NormCpxty(jj)/NormPQCpxty+0.001,NNClassPerf(1,jj)-0.002,sprintf('$L_0=%d$',Npeaks(jj)),'interpreter','latex');
  end
  %plot((NormPQCpxty),NNClassPerf(:,end),'*');
  %text((NormPQCpxty),1.005,'PQ','interpreter','latex');
  %title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ + Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
  ylabel('Classification Performances','interpreter','latex');
  xlabel('Computational Cost (Normalized to Fine PQ)','interpreter','latex');
  % Legend entries follow recAtR; with nRecall=100 they are the same seven
  % strings as before. They label the marker series, which are plotted
  % first.
  nCurves=min(size(NNClassPerf,1),numel(recAtR));
  legendLabels=arrayfun(@(r) sprintf('Recall@%d',r),recAtR(1:nCurves),'UniformOutput',false);
  h=legend(legendLabels{:},'Location','southeast');
  grid on; grid minor;
  set(h,'interpreter','latex');
  %ylim([0.84 1.002]);
  %xlim([NormCpxty(1)-0.0005 NormPQCpxty(end)+.0005])
  ylim([0.864 1.006]);
  xlim([0.22 .31]);
  clear isListed nCurves legendLabels