# correlation.py
import numpy as np
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import pdist
  3. def cluster_corr(corr_array, threshold=None, inplace=False):
  4. """
  5. Rearranges the correlation matrix, corr_array, so that groups of highly
  6. correlated variables are next to eachother
  7. Parameters
  8. ----------
  9. corr_array : pandas.DataFrame or numpy.ndarray
  10. a NxN correlation matrix
  11. Returns
  12. -------
  13. corr_array : a NxN correlation matrix with the columns and rows rearranged
  14. linkage : linkage of distances
  15. labels : cluster labels
  16. idx : sorted incides for original labels
  17. """
  18. pairwise_distances = sch.distance.pdist(corr_array)
  19. linkage = sch.linkage(pairwise_distances, method='complete')
  20. cluster_distance_threshold = pairwise_distances.max()/2 if threshold is None else threshold
  21. labels = sch.fcluster(linkage, cluster_distance_threshold, criterion='distance')
  22. idx = np.argsort(labels)
  23. if not inplace:
  24. corr_array = corr_array.copy()
  25. return corr_array[idx, :][:, idx], linkage, labels, idx