maxDistanceBasedMetric.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from scipy.spatial import ConvexHull
  2. from scipy.spatial.distance import cdist
  3. from .swcFuncs import readSWC_numpy
  4. import numpy as np
  5. from pyemd import emd
  6. def cdist_1d_centripetal(list1, list2, center):
  7. """
  8. Return a matrix of pairwise displacements of entries of list1 from corresponding
  9. entries in list2, with displacements
  10. towards center having positive value and those away having negative values
  11. :param list1: iterable of floats or ints
  12. :param list2: iterable of floats or ints
  13. :param center: float or int
  14. :return: numpy.ndarray of shape (<len of list1>, <len of list2>)
  15. """
  16. assert all([isinstance(x, (float, int)) for x in list1])
  17. assert all([isinstance(x, (float, int)) for x in list2])
  18. assert isinstance(center, (float, int))
  19. mesh2 = np.array([list2] * len(list1))
  20. mesh1 = np.array([list1] * len(list2)).T
  21. mesh1CenteredAbs = np.abs(mesh1 - center)
  22. mesh2CenteredAbs = np.abs(mesh2 - center)
  23. return mesh2CenteredAbs - mesh1CenteredAbs
  24. def calcMaxDistances(swcList):
  25. """
  26. Compute the convex hull of the union of points from all swcs in swcList. For each vertex of this
  27. convex hull, compute the distance of the farthest point among the vertices of the hull and return
  28. them.
  29. :param swcList: list of valid SWC files on the file system, list of strings.
  30. :return: list of maximum distances
  31. """
  32. swcPointSets = []
  33. for swc in swcList:
  34. headr, swcData = readSWC_numpy(swc)
  35. swcPointSets.append(swcData[:, 2:5])
  36. unionWithDuplicates = np.concatenate(swcPointSets, axis=0)
  37. if any(np.abs(unionWithDuplicates).max(axis=0) == 0):
  38. raise(ValueError("The list of SWCs all lie on a plane or on a line and hence do not "
  39. "for a 3D point cloud. Such SWCs are not supported."))
  40. hull = ConvexHull(unionWithDuplicates)
  41. vertices = unionWithDuplicates[hull.vertices, :]
  42. distMatrix = cdist(unionWithDuplicates, vertices)
  43. maxDistances = distMatrix.max(axis=1).tolist()
  44. return maxDistances
  45. def maxDistEMD(swcList):
  46. """
  47. Calculate the maxDistance based metric. It is the size normalized Earth mover distance
  48. between the distribution of maxDistances (see calcMaxDistances above)
  49. of the pooled collection of points of all swcs in swcList and the distribution of
  50. pooled diagonal maxDistances of individual swcs in swcList
  51. :param swcList: list of valid swc files on the system, list of strings
  52. :return: float
  53. """
  54. individualMaxDistances = [calcMaxDistances([swc]) for swc in swcList]
  55. pooledIndividualMaxDistances = np.concatenate(individualMaxDistances, axis=0)
  56. meanPIMD = pooledIndividualMaxDistances.mean()
  57. PIMDNorm = (pooledIndividualMaxDistances - meanPIMD) / meanPIMD
  58. maxDistancesAllPts = np.array(calcMaxDistances(swcList))
  59. MDAPNorm = (maxDistancesAllPts - meanPIMD) / meanPIMD
  60. binWidth = 1 / meanPIMD
  61. bins = np.arange(MDAPNorm.min() - 0.5 * binWidth,
  62. MDAPNorm.max() + 0.5 * binWidth,
  63. binWidth)
  64. hist1, bins1 = np.histogram(MDAPNorm, bins)
  65. hist2, bins2 = np.histogram(PIMDNorm, bins1)
  66. hist1Normed = hist1 / float(hist1.sum())
  67. hist2Normed = hist2 / float(hist2.sum())
  68. dist_metric = cdist_1d_centripetal(bins1, bins2, center=0)
  69. emd_val = emd(np.asarray(hist2Normed, dtype=np.float64), np.asarray(hist1Normed, dtype=np.float64),
  70. np.asarray(dist_metric, dtype=np.float64))
  71. return emd_val