12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
function [K_opt,distance_opt] = kmeans_param_opt(Data,center)
%KMEANS_PARAM_OPT Heuristically choose K and a distance metric for k-means.
%   [K_opt,distance_opt] = KMEANS_PARAM_OPT(Data,center) runs k-means on
%   Data for K = 1..round(N/4) (N = number of observations) with each of
%   the distances 'sqeuclidean', 'correlation', 'cityblock' and 'cosine'.
%
%   Step 1 (error curve): for every distance, the total of the
%   within-cluster sums of point-to-centroid distances is normalized by
%   its maximum over K. The "elbow" is the first K at which the local
%   standard deviation of that curve (3-point window, STDFILT) drops below
%   0.05. If no K satisfies this, K_max is used and a warning is issued.
%
%   Step 2 (silhouette): each distance is re-clustered with its elbow K.
%   The distance whose silhouette values have the smallest variance is
%   returned together with its K.
%
%   The result is a heuristic, not a guaranteed optimum. Three figures are
%   created as a side effect: per-distance error curves, normalized error
%   curves, and per-distance silhouette plots.
%
%   Inputs:
%     Data   - data to cluster; rows are observations, columns are
%              variables (features). At least 2 rows are required.
%     center - 1 to z-score each column before clustering (recommended
%              when the features have very different ranges). Optional;
%              defaults to 0 (no normalization).
%
%   Outputs:
%     K_opt        - selected number of clusters.
%     distance_opt - name of the selected distance metric (char), usable
%                    as the 'Distance' argument of KMEANS.
%
%   See also KMEANS, SILHOUETTE, STDFILT, ZSCORE.

%% Initializing
if nargin < 2 || isempty(center)
    center = 0;
end
if center == 1
    Data = zscore(Data);
end

n_obs = size(Data,1);
K_max = round(n_obs/4);            % maximum number of clusters tried
if K_max < 1
    error('kmeans_param_opt:tooFewObservations', ...
        'At least 2 observations are required (got %d).', n_obs);
end

Distances_to_consider = {'sqeuclidean','correlation','cityblock','cosine'};
n_dist = numel(Distances_to_consider);
n_plot_cols = ceil(n_dist/2);
% Options shared by every kmeans call below.
kmeans_opts = {'Replicates',10,'Start','plus','MaxIter',10000,'EmptyAction','drop'};
elbow_std_tol = 0.05;              % local-std threshold that defines the elbow

%% Finding the elbow K for each distance (error based)
err_all = zeros(n_dist,K_max);
err_all_normalized = zeros(n_dist,K_max);
K_opt_distances = zeros(n_dist,1);
figure(); sgtitle('Error vs. Number of Clusters');
for i = 1:n_dist
    distance = Distances_to_consider{i};
    for j = 1:K_max
        [~,~,sumd] = kmeans(Data,j,kmeans_opts{:},'Distance',distance);
        % 'omitnan' keeps any NaN entries (e.g. for clusters removed by
        % 'EmptyAction','drop') from turning the total into NaN.
        err_all(i,j) = sum(sumd,'omitnan');
    end
    err_all_normalized(i,:) = err_all(i,:)./max(err_all(i,:));
    % Local (3-point) standard deviation of the normalized curve; named so
    % it does not shadow the built-in STD function.
    local_std = stdfilt(err_all_normalized(i,:),ones(1,3));
    k_elbow = find(local_std < elbow_std_tol,1,'first');
    if isempty(k_elbow)
        % The curve never flattens out: fall back to the largest K tried
        % instead of failing on an empty assignment.
        warning('kmeans_param_opt:noElbow', ...
            'No elbow found for distance ''%s''; using K_max = %d.', distance, K_max);
        k_elbow = K_max;
    end
    K_opt_distances(i) = k_elbow;

    subplot(2,n_plot_cols,i);
    plot(err_all(i,:));
    % Concatenate with [] rather than STRCAT, which strips the trailing
    % space from 'Distance: '.
    title(['Distance: ' distance]);
    xline(K_opt_distances(i),'--r','Stopping Point');
    xlabel('Number of Clusters','FontSize',12);
    ylabel('Error','FontSize',12);
end

%% Comparing the normalized error curves across distances
figure();
plot(err_all_normalized','LineWidth',2);
xlabel('Number of Clusters','FontSize',12);
ylabel('Normalized error','FontSize',12);
legend(Distances_to_consider,'Location','northeastoutside');
legend('boxoff');

%% Choosing the distance (silhouette based)
figure();
S = zeros(n_obs,n_dist);
for i = 1:n_dist
    distance = Distances_to_consider{i};
    idx_clus = kmeans(Data,K_opt_distances(i),kmeans_opts{:},'Distance',distance);
    subplot(2,n_plot_cols,i);
    % Requesting a second output makes SILHOUETTE draw its plot; with a
    % single output it would only return the values.
    [S(:,i),~] = silhouette(Data,idx_clus,distance);
    title(['Distance: ' distance]);
    xline(mean(S(:,i)),'--r','Average');
end
% Pick the distance whose silhouette values are least spread out.
dispersity = var(S);
[~,opt_idx] = min(dispersity);
K_opt = K_opt_distances(opt_idx);
distance_opt = Distances_to_consider{opt_idx};

end
|