%function [NNeighPerf, NormCpxty, NormPQCpxty]= WNNetsOverPQ_Kmeans(kk,mm,nRecall,Npeaks)
clear all
close all
clc

addpath('./SIFTs Library');
addpath('./PQ Library');
addpath('./WNN Library');
addpath('./PQ Library/Yael Library');
set(0,'defaulttextinterpreter','latex');

% ---- Quantities of interest for the input Data Set --------------
nDim=128;           % Dimensionality of the Data
NLearn=1e5;         % Number of descriptors learn set
Ntrain=1e6;         % Number of descriptors train set
Ntests=1e4;         % Number of descriptors test set

% ---- Quantities of interest for Fine Product Quantization -------
kkf=256;            % Number of clusters kmeans
mmf=16;             % Number of splits for PQ (divides 128)
mDimf=nDim/mmf;     % Dimensionality of splitted vectors
nRecall=100;
recAtR=[1 2 5 10 20 50 100 200 500 1000 2000 5000 10000];
recAtR=recAtR(recAtR<=nRecall);

% ---- Quantities of interest for Coarse Product Quantization -----
kkc=10;             % Number of clusters kmeans
mmc=2;              % Number of splits for PQ (divides 128)
mDimc=nDim/mmc;     % Dimensionality of splitted vectors

% ---- Quantities of interest for Willshaw Networks ---------------
nn=1e4;             % Number of vectors per WNN
qq=Ntrain/nn;       % Number of different WNNs
Npeaks=round([1, qq./[10 5 4 3 2 4/3 1]]); % Number of peaks to select

resultsFileName=sprintf('./Result Sets/finePQ_m%d_k%d_coarseWNN_n%d_m%d_k%d.mat',mmf,kkf,nn,mmc,kkc);
if 0%exist(resultsFileName,'file')==2
    %% LOADING RESULTS SECTION
    load(resultsFileName);
else
    %% LOADING DATA SECTION
    bVSetMat=double(fvecs_read('../Data/sift/sift_learn.fvecs'));
    bWSetMat=double(fvecs_read('../Data/sift/sift_base.fvecs'));
    bZSetMat=double(fvecs_read('../Data/sift/sift_query.fvecs'));
    iMinDistExh=ivecs_read('../Data/sift/sift_groundtruth.ivecs')+1;

    %% COARSE PRODUCT QUANTIZATION SECTION
    coarseFileName=sprintf('./Coarse Quantization Indices and Dissimilarities/coarse_CWidx_sZSetMat_k%d_m%d.mat',kkc,mmc);
    if exist(coarseFileName,'file')==2
        load(coarseFileName);
    else
        str=sprintf('Coarse PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmc,kkc);
        fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
        CSubSetMatc=zeros(mDimc,kkc,mmc);
        CWidxc=zeros(Ntrain,mmc);
        sZSetMatc=zeros(kkc*mmc,Ntests);
        pperc=[];
        for jj=1:mmc
            [CSubSetMatc(:,:,jj), ~ ]=yael_kmeans(...
                single(bVSetMat((jj-1)*mDimc+1:jj*mDimc,:)),kkc, 'niter', 100, 'verbose', 0);
            ZDistMatc=EuclideanDistancesMat(bZSetMat((jj-1)*mDimc+1:jj*mDimc,:),CSubSetMatc(:,:,jj));
            [~,CWidxc(:,jj)]=Quantization(bWSetMat((jj-1)*mDimc+1:jj*mDimc,:),CSubSetMatc(:,:,jj));
            sZSetMatc((jj-1)*kkc+1:jj*kkc,:)=exp(-sqrt(ZDistMatc)');
            perc=sprintf('%2.2f%%',jj/mmc*100);
            fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
            fprintf(1,'%s',perc);
            pperc=perc;
        end
        clearvars CSubSetMatc ZDistMatc
        fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
        fprintf('Done.\n');
        save(coarseFileName,'CWidxc','sZSetMatc');
    end
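    % NOTE: sZSetMatc stacks, for each of the mmc coarse splits, a soft
    % assignment of every query to the kkc coarse sub-centroids, exp(-sqrt(d)),
    % where d is the value returned by EuclideanDistancesMat (presumably a
    % squared Euclidean distance, hence the sqrt). These soft codes are used
    % only as the read-out vectors of the Willshaw networks further below; the
    % exact distance convention depends on the PQ Library implementation.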
    %% FINE PRODUCT QUANTIZATION SECTION
    fineFileName=sprintf('./Fine Quantization Indices and Distances/fine_CWidx_ZDistMat_k%d_m%d.mat',kkf,mmf);
    if exist(fineFileName,'file')==2
        load(fineFileName);
    else
        str=sprintf('Fine PQ: computing m=%1.0f, k=%1.0f subcentroids ',mmf,kkf);
        fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
        CSubSetMatf=zeros(mDimf,kkf,mmf);
        CWidxf=zeros(Ntrain,mmf);
        ZDistMatf=zeros(Ntests,kkf,mmf);
        pperc=[];
        for jj=1:mmf
            [CSubSetMatf(:,:,jj), ~ ]=yael_kmeans(...
                single(bVSetMat((jj-1)*mDimf+1:jj*mDimf,:)),kkf, 'niter', 100, 'verbose', 0);
            ZDistMatf(:,:,jj)=EuclideanDistancesMat(bZSetMat((jj-1)*mDimf+1:jj*mDimf,:),CSubSetMatf(:,:,jj));
            [~,CWidxf(:,jj)]=Quantization(bWSetMat((jj-1)*mDimf+1:jj*mDimf,:),CSubSetMatf(:,:,jj));
            perc=sprintf('%2.2f%%',jj/mmf*100);
            fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
            fprintf(1,'%s',perc);
            pperc=perc;
        end
        clearvars CSubSetMatf bZSetMat
        fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
        fprintf('Done.\n');
        save(fineFileName,'CWidxf','ZDistMatf');
    end
    clearvars bWSetMat bVSetMat

    %% PREPROCESSING DATA TO STORE IN WNNets SECTION
    str=sprintf('Building Learning Data Sets ');
    fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
    %permRand=randperm(Ntrain,qq);
    %% JUST FOR kkc=10 mmc=4
    %AllCentSet=[arrayfun(@(ll)ceil(ll/1000),1:10000);
    %            repmat(arrayfun(@(ll)ceil(ll/100),1:1000),1,10);
    %            repmat(arrayfun(@(ll)ceil(ll/10),1:100),1,100);
    %            repmat([1:10],1,1000)]';
    %% JUST FOR kkc=10 mmc=2
    % Canonical coarse-centroid index combinations: each train vector is routed
    % to one of the qq WNNs according to its coarse code (layout below is valid
    % for kkc=10, mmc=2 only).
    AllCentSet=[arrayfun(@(ll)ceil(ll/10),1:100); repmat([1:10],1,10)]';
    learningNetsIdx=QuantHammingAllOfTheCanonicalIndices(CWidxc,AllCentSet);
    %learningNetsIdx=QuantHammingCanonicalIndices(CWidxc,permRand);
    LearningDistribution=hist(learningNetsIdx,1:qq);
    fprintf('Done.\n');

    % One struct entry per WNN (qq networks in total)
    eWDiffFeatsSplittedRpi(1:qq)=struct('RunPermIndices',[]);
    eWDiffFeatsSplittedRci(1:qq)=struct('RunCentroidsIndices',[]);
    multiCount=ones(1,qq);
    str=sprintf('Preprocessing Data for Learning ');
    fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
    pperc=[];
    % Group the indices of the train vectors and their coarse codes by assigned network
    for ii=1:Ntrain
        eWDiffFeatsSplittedRpi(learningNetsIdx(ii)).RunPermIndices(multiCount(learningNetsIdx(ii)))=ii;
        eWDiffFeatsSplittedRci(learningNetsIdx(ii)).RunCentroidsIndices(:,multiCount(learningNetsIdx(ii)))=CWidxc(ii,:);
        multiCount(learningNetsIdx(ii))=multiCount(learningNetsIdx(ii))+1;
        perc=sprintf('%2.2f%%',ii/Ntrain*100);
        fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
        fprintf(1,'%s',perc);
        pperc=perc;
    end
    fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
    fprintf('Done.\n');
    clearvars eWSetMat permRand CWidxc %multiCount nStoredVec

    %% BUILDING WNNets AND SCORES COMPUTATION SECTION
    str=sprintf('Building %d Willshaw NNets and Computing Scores ',qq);
    fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
    II=eye(kkc);
    P0=0;
    %Ntests=1;
    sX0=zeros(Ntests,qq);
    pperc=[];
    for jj=1:qq
        uniqCent=unique(eWDiffFeatsSplittedRci(jj).RunCentroidsIndices','rows');
        if ~isempty(uniqCent)
            WXsmatrix=BuildNetworkfromIdxs(uniqCent',II);
            P0=P0+mean(mean(WXsmatrix));
            sX0(:,jj)=diag(sZSetMatc(:,1:Ntests)'*double(WXsmatrix)*sZSetMatc(:,1:Ntests));
        end
        perc=sprintf('%2.2f%%',jj/qq*100);
        fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
        fprintf(1,'%s',perc);
        pperc=perc;
    end
    fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
    fprintf('Done.\n');
    avgP0=P0/qq;
    clearvars sZSetMatc eWDiffFeatsSplittedRsm WXsmatrix eWDiffFeatsSplittedRci uniqCent
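    % NOTE: each score sX0(:,jj) is the quadratic form z'*W*z, with W the binary
    % Willshaw matrix of network jj and z the query's concatenated coarse soft
    % code; per query, networks are then ranked by this score and only the
    % top-L "peaks" are visited by the ADC search below. A hedged sketch of the
    % assumed Willshaw storage rule is given as a local function at the end of
    % this file (illustration only; it is never called).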
    %% ADC COMPUTATION FOR PERFORMANCES COMPUTATION SECTION
    str=sprintf('Computing Performances: L=[%d:%d], recalls@[%d:%d] ',Npeaks(1),Npeaks(end),recAtR(1),recAtR(end));
    fprintf('%s%s ',str,('.')*ones(1,55-length(str)));
    NNeighPerf=zeros(length(recAtR),length(Npeaks));
    iXMinDistADC=zeros(1,nRecall);
    actIXMinDistADC=zeros(1,nRecall);
    usedVec=zeros(Ntests,length(Npeaks));   % Number of candidates retrieved at each peak level
    [~,sortemp]=sort(sX0,2,'descend');
    NNmostLikely=sortemp(:,1:Npeaks(end))';
    clearvars sortemp
    pperc=[];
    for jj=1:Ntests
        for tt=1:length(Npeaks)
            if tt>1
                retIdxs=actIXMinDistADC;
                for ll=Npeaks(tt-1)+1:Npeaks(tt)
                    retIdxs=[retIdxs, eWDiffFeatsSplittedRpi(NNmostLikely(ll,jj)).RunPermIndices];
                end
                usedVec(jj,tt)=length(retIdxs)-length(actIXMinDistADC);
            else
                retIdxs=[];
                for ll=1:Npeaks(tt)
                    retIdxs=[retIdxs, eWDiffFeatsSplittedRpi(NNmostLikely(ll,jj)).RunPermIndices];
                end
                usedVec(jj,1)=length(retIdxs);
            end
            iXMinDistADC=AsymmetricDistanceComputationWDist(CWidxf(retIdxs,:)', ZDistMatf(jj,:,:), nRecall);
            numRetIdxs=length(retIdxs);
            if numRetIdxs >= nRecall
                actIXMinDistADC=retIdxs(iXMinDistADC);
            else
                actIXMinDistADC(1:numRetIdxs)=retIdxs(iXMinDistADC(1:numRetIdxs));
                actIXMinDistADC(numRetIdxs+1:end)=ones(1,nRecall-numRetIdxs);
            end
            NNeighPerf(:,tt)=NNeighPerf(:,tt)+RecallTest(actIXMinDistADC,nRecall,iMinDistExh(:,jj))';
            perc=sprintf('%2.2f%%',((jj-1)*length(Npeaks)+tt)/(Ntests*length(Npeaks))*100);
            fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
            fprintf(1,'%s',perc);
            pperc=perc;
        end
    end
    NNeighPerf=NNeighPerf/Ntests;
    fprintf(1,'%s',sprintf('\b')*ones(1,length(pperc)));
    fprintf('Done.\n');
    clearvars SortQueryNetIdx actIXMinDistADC RetrievedIdxs bZSetMat CWidxf NNmostLikely eWDiffFeatsSplittedRpi sX0

    %% COMPUTATIONAL COST EVALUATION SECTION
    % Cost model (per query, roughly): fine and coarse codebook distance tables,
    % WNN read-out weighted by the average matrix density avgP0, and ADC over
    % the candidates actually visited; normalized by the exhaustive-search cost.
    actNpeaks=mean(cumsum(usedVec')'./nn);
    Cpxty=kkf*nDim+kkc*nDim+avgP0*(mmc*kkc)^2*qq+actNpeaks*nn*mmf;
    NormCpxty=Cpxty/(nDim*Ntrain);
    NormPQCpxty=(kkf*nDim+mmf*Ntrain)/(nDim*Ntrain);
    NormNormCpxty=NormCpxty/NormPQCpxty;

    %% SAVING RESULTS SECTION
    save(resultsFileName,'NormCpxty','NormPQCpxty','NNeighPerf','Npeaks','actNpeaks','avgP0');
end

%% RESULTS VISUALIZATION SECTION
figure();
plot([NormCpxty(1:end-1) NormPQCpxty],NNeighPerf','x-');
hold on;
for jj=1:length(actNpeaks)-1
    text(NormCpxty(jj),1.05,sprintf('L=%2.0f',actNpeaks(jj)*nn),'interpreter','latex');
end
plot(NormPQCpxty,NNeighPerf(:,end),'*');
text(NormPQCpxty,1.05,'PQ','interpreter','latex');
title(sprintf('Fine PQ $(k_f=%d, m_f=%d)$ + Coarse WNNets $(k_c=%d,m_c=%d,q=%d,n=%d)$',kkf,mmf,kkc,mmc,qq,nn));
ylabel('Nearest Neighbour Search Performances','interpreter','latex');
xlabel('Computational Cost (Normalized to Exhaustive Search)','interpreter','latex');
h=legend('recall @1','recall @2','recall @5','recall @10','recall @20',...
    'recall @50','recall @100','Location','southeast');
grid on; grid minor;
set(h,'interpreter','latex');
ylim([0 1.1]);
xlim([0 NormCpxty(end)+.005])
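
%% SKETCH: ASSUMED WILLSHAW STORAGE RULE (illustration only, never called)
% Minimal sketch of how a Willshaw matrix could be built from coarse centroid
% indices, assuming the usual clipped-Hebbian rule; the actual behaviour of
% BuildNetworkfromIdxs in the WNN Library may differ. Local functions in
% scripts require MATLAB R2016b or later; otherwise move this to its own file.
function W = willshaw_from_idxs_sketch(centIdxs,II)
    % centIdxs : mmc x P matrix of sub-centroid indices (one column per stored pattern)
    % II       : kkc x kkc identity used to one-hot encode each index
    [~,P]=size(centIdxs);
    kk=size(II,1);
    nBits=size(centIdxs,1)*kk;
    W=false(nBits);
    for pp=1:P
        x=reshape(II(:,centIdxs(:,pp)),[],1);   % concatenated one-hot code of pattern pp
        W=W | (x*x'>0);                         % clipped Hebbian update: store all co-activations
    end
end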