maxDistanceBasedMetric.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from scipy.spatial import ConvexHull
  2. from scipy.spatial.distance import cdist
  3. from .swcFuncs import readSWC_numpy
  4. import numpy as np
  5. from pyemd import emd
  6. def cdist_1d_centripetal(list1, list2, center):
  7. """
  8. Return a matrix of pairwise displacements of entries of list1 from corresponding
  9. entries in list2, with displacements
  10. towards center having positive value and those away having negative values
  11. :param list1: iterable of floats or ints
  12. :param list2: iterable of floats or ints
  13. :param center: float or int
  14. :return: numpy.ndarray of shape (<len of list1>, <len of list2>)
  15. """
  16. assert all([isinstance(x, (float, int)) for x in list1])
  17. assert all([isinstance(x, (float, int)) for x in list2])
  18. assert isinstance(center, (float, int))
  19. mesh2 = np.array([list2] * len(list1))
  20. mesh1 = np.array([list1] * len(list2)).T
  21. mesh1CenteredAbs = np.abs(mesh1 - center)
  22. mesh2CenteredAbs = np.abs(mesh2 - center)
  23. return mesh2CenteredAbs - mesh1CenteredAbs
  24. def calcMaxDistances(swcList):
  25. """
  26. Compute the convex hull of the union of points from all swcs in swcList. For each vertex of this
  27. convex hull, compute the distance of the farthest point among the vertices of the hull and return
  28. them.
  29. :param swcList: list of valid SWC files on the file system, list of strings.
  30. :return: list of maximum distances
  31. """
  32. swcPointSets = []
  33. for swc in swcList:
  34. headr, swcData = readSWC_numpy(swc)
  35. swcPointSets.append(swcData[:, 2:5])
  36. unionWithDuplicates = np.concatenate(swcPointSets, axis=0)
  37. if any(np.abs(unionWithDuplicates).max(axis=0) == 0):
  38. raise ValueError(
  39. "The list of SWCs all lie on a plane or on a line and hence do not "
  40. "for a 3D point cloud. Such SWCs are not supported.")
  41. hull = ConvexHull(unionWithDuplicates)
  42. vertices = unionWithDuplicates[hull.vertices, :]
  43. distMatrix = cdist(unionWithDuplicates, vertices)
  44. maxDistances = distMatrix.max(axis=1).tolist()
  45. return maxDistances
  46. def maxDistEMD(swcList):
  47. """
  48. Calculate the maxDistance based metric. It is the size normalized Earth mover distance
  49. between the distribution of maxDistances (see calcMaxDistances above)
  50. of the pooled collection of points of all swcs in swcList and the distribution of
  51. pooled diagonal maxDistances of individual swcs in swcList
  52. :param swcList: list of valid swc files on the system, list of strings
  53. :return: float
  54. """
  55. individualMaxDistances = [calcMaxDistances([swc]) for swc in swcList]
  56. pooledIndividualMaxDistances = np.concatenate(individualMaxDistances, axis=0)
  57. meanPIMD = pooledIndividualMaxDistances.mean()
  58. PIMDNorm = (pooledIndividualMaxDistances - meanPIMD) / meanPIMD
  59. maxDistancesAllPts = np.array(calcMaxDistances(swcList))
  60. MDAPNorm = (maxDistancesAllPts - meanPIMD) / meanPIMD
  61. binWidth = 1 / meanPIMD
  62. bins = np.arange(MDAPNorm.min() - 0.5 * binWidth,
  63. MDAPNorm.max() + 0.5 * binWidth,
  64. binWidth)
  65. hist1, bins1 = np.histogram(MDAPNorm, bins)
  66. hist2, bins2 = np.histogram(PIMDNorm, bins1)
  67. hist1Normed = hist1 / float(hist1.sum())
  68. hist2Normed = hist2 / float(hist2.sum())
  69. dist_metric = cdist_1d_centripetal(bins1, bins2, center=0)
  70. emd_val = emd(np.asarray(hist2Normed, dtype=np.float64), np.asarray(hist1Normed, dtype=np.float64),
  71. np.asarray(dist_metric, dtype=np.float64))
  72. return emd_val