% kmeans_param_opt.m (2.6 KB)
  function [K_opt,distance_opt] = kmeans_param_opt(Data,center)
  %KMEANS_PARAM_OPT Heuristic choice of K and distance metric for k-means.
  %   [K_opt,distance_opt] = kmeans_param_opt(Data,center) tries to find a
  %   sub-optimal number of clusters K for k-means clustering and the best
  %   distance criterion for the data, using both the clustering error and
  %   the silhouette method.
  %
  %   Inputs:
  %     Data   - data to be clustered; rows are observations and columns
  %              are variables (features).
  %     center - (optional, default 0) pass 1 to z-score the data first
  %              (suggested if the feature ranges are very different).
  %
  %   Outputs:
  %     K_opt        - selected number of clusters.
  %     distance_opt - char vector naming the selected kmeans distance
  %                    ('sqeuclidean', 'correlation', 'cityblock' or
  %                    'cosine').
  %
  %   Side effects: opens three figures (error vs. K per distance,
  %   normalized error across distances, silhouette plots).
  %
  %   Uses kmeans, silhouette and zscore plus stdfilt, which are toolbox
  %   functions (Statistics and Machine Learning / Image Processing).

  %% initializing
  if nargin < 2 || isempty(center)
      center = 0; % default: leave the data un-normalized
  end
  if center==1
      Data = zscore(Data);
  end

  K_max = round(size(Data,1)/4); % maximum number of clusters
  if K_max < 1
      error('kmeans_param_opt:tooFewObservations', ...
          'Data must contain at least 2 observations (rows); got %d.', size(Data,1));
  end

  Distances_to_consider = {'sqeuclidean','correlation','cityblock','cosine'};
  n_dist = length(Distances_to_consider);
  flat_tol = 0.05; % local std below this marks the error curve as "flat"

  %% Finding Optimum k inside the same distance
  % based on error
  err_all = zeros([n_dist,K_max]);
  err_all_normalized = zeros([n_dist,K_max]);
  K_opt_distances = zeros(n_dist,1);
  figure();sgtitle('Error vs. Number of Clusters');
  for i=1:n_dist
      distance = Distances_to_consider{i};
      for j=1:K_max
          [~,~,err] = kmeans(Data,j,'Replicates',10,'Start','plus', ...
              'MaxIter',10000,'Distance',distance,'EmptyAction','drop');
          % NOTE(review): with 'EmptyAction','drop' a dropped cluster may
          % report NaN here - confirm; 'omitnan' keeps the total finite.
          err_all(i,j) = sum(err,'omitnan');
      end
      err_all_normalized(i,:) = err_all(i,:)./max(err_all(i,:));

      % Local standard deviation over a 3-sample window; the first K where
      % it drops below flat_tol is taken as the stopping point.
      % (Named local_sd so the built-in std function is not shadowed.)
      local_sd = stdfilt(err_all_normalized(i,:),ones(1,3));
      k_stop = find(local_sd < flat_tol,1,'first');
      if isempty(k_stop)
          % The curve never flattens within 1..K_max: fall back to the
          % largest K tried instead of failing on an empty assignment.
          k_stop = K_max;
      end
      K_opt_distances(i) = k_stop;

      subplot(2,ceil(n_dist/2),i)
      plot(err_all(i,:));
      % Bracket concatenation instead of strcat: strcat strips the trailing
      % space of char inputs, which glued the label to the distance name.
      title(['Distance: ',distance]);
      xline(K_opt_distances(i),'--r','Stopping Point');
      xlabel('Number of Clusters','FontSize',12);ylabel('Error','FontSize',12);
  end

  %% Comparing across distances
  figure()
  plot(err_all_normalized','LineWidth',2);
  xlabel('Number of Clusters','FontSize',12);ylabel('Normalized error','FontSize',12)
  legend(Distances_to_consider,'Location','northeastoutside');legend('boxoff')

  % based on silhouette
  figure()
  S = zeros(size(Data,1),n_dist);
  for i=1:n_dist
      distance = Distances_to_consider{i};
      idx_clus = kmeans(Data,K_opt_distances(i),'Replicates',10,'Start','plus', ...
          'MaxIter',10000,'Distance',distance,'EmptyAction','drop');
      subplot(2,ceil(n_dist/2),i);
      % Two outputs are requested on purpose: silhouette only draws the
      % plot when called with zero or two output arguments.
      [S(:,i),~] = silhouette(Data,idx_clus,distance);
      title(['Distance: ',distance]);
      xline(mean(S(:,i)),'--r','Average');
  end

  % Pick the distance whose silhouette values have the smallest variance.
  dispersity = var(S);
  [~,opt_idx] = min(dispersity);
  K_opt = K_opt_distances(opt_idx);
  distance_opt = Distances_to_consider{opt_idx};
  end