arr_sim.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728
  1. """Functions for array similarity metrics"""
  2. import numpy as np
  3. from scipy.signal import correlate
  4. import scipy as sp
  5. from skimage.morphology import binary_dilation
  6. import ot
  7. import ot.plot
  8. import matplotlib.pyplot as plt
  9. import matplotlib as mpl
  10. import warnings
  11. from scold import utils
  def arr_sim(a, b, measure='all', partial_wasserstein_kwargs=None):
      """Calculate a measure of array similarity for 2 (aligned) binary arrays of equal shape.

      Parameters
      ----------
      a : ndarray
      b : ndarray
          The arrays to compare, at their current positions. The intersection is
          computed as the sum of the element-wise minimum of `a` and `b`.
      measure : str
          One of the following, to select the desired measure:
          'all' returns a dictionary of all values, as well as the intersection and union. Because of speed, partial Wasserstein is not included.
          'overlap' returns the Overlap Coefficient
          'jaccard' returns the Jaccard Similarity
          'dice' returns the Sørensen-Dice Coefficient
          'px_dist' returns the total pixel distance, from `px_dist()` without translation
          'partial_wasserstein' returns the (partial) Wasserstein distance, with translation optionally optimised via `partial_wasserstein_trans()`, and with kwargs passed via `partial_wasserstein_kwargs`
      partial_wasserstein_kwargs : dict or None
          Keyword arguments passed to `partial_wasserstein_trans()`, used only when
          `measure == 'partial_wasserstein'`. If None (the default), uses
          `{'scale_mass': True, 'mass_normalise': True, 'distance_normalise': True,
          'translation': 'opt', 'n_startvals': 9, 'solver': 'Nelder-Mead',
          'search_method': 'grid'}`. A dict given here replaces these defaults
          entirely (it is not merged with them).

      Returns
      -------
      float or dict
          A float for the requested similarity value; a dictionary with keys
          'intersection', 'union', 'overlap', 'jaccard', 'dice' and 'px_dist' if
          `measure == 'all'`; the dictionary returned by
          `partial_wasserstein_trans()` for 'partial_wasserstein'; or `np.nan`
          if `measure` is not recognised.
      """
      # early return for partial Wasserstein (it does not need the overlap statistics)
      if measure == 'partial_wasserstein':
          # None sentinel rather than a dict default, so no mutable default
          # object is shared between calls
          if partial_wasserstein_kwargs is None:
              partial_wasserstein_kwargs = {
                  'scale_mass': True,
                  'mass_normalise': True,
                  'distance_normalise': True,
                  'translation': 'opt',
                  'n_startvals': 9,
                  'solver': 'Nelder-Mead',
                  'search_method': 'grid',
              }
          return partial_wasserstein_trans(a, b, **partial_wasserstein_kwargs)

      intersect = np.sum(np.minimum(a, b))  # area of intersection of A and B
      area_a = np.sum(a)
      area_b = np.sum(b)
      union = area_a + area_b - intersect

      if measure == 'all':
          return {
              'intersection': intersect,
              'union': union,
              'overlap': intersect / min(area_a, area_b),
              'jaccard': intersect / union,
              'dice': (intersect / (area_a + area_b)) * 2,
              'px_dist': px_dist(a, b, translate=False),
          }
      if measure == 'overlap':
          return intersect / min(area_a, area_b)
      if measure == 'jaccard':
          return intersect / union
      if measure == 'dice':
          # multiply by two to normalise within [0, 1] for consistency with other measures
          return (intersect / (area_a + area_b)) * 2
      if measure == 'px_dist':
          return px_dist(a, b, translate=False)
      return np.nan
  def translate_ov(a, b, keep_a_constant=True, return_first_only=False, constant_values=0):
      """Maximise array overlap, permitting translation only, via cross-correlation.

      Parameters
      ----------
      a : ndarray
      b : ndarray
          2D arrays to align. Both are first passed through
          `utils.pad_for_translation()`.
      keep_a_constant : bool
          Refers to the method for aligning the matrices a and b - if False, a and b are both padded to achieve the overlap, if True, a remains the same size, and b is padded from one side and trimmed from the other side.
      return_first_only : bool
          Two arrays may have multiple solutions with the same maximal overlap. Should only the first solution (the one requiring the smallest translation) be returned? If False, will return a list of lists of solutions.
      constant_values : float or list or array
          Passed to `np.pad`. Will usually want to be the value for the background (default = 0).

      Returns
      -------
      list
          If `return_first_only` is `True`, a list in form `[a, b, shift]`, where `a` and `b` are the aligned matrices, and `shift` is the translation `(x, y)` (offsets along axis 0 and axis 1) of `b` relative to `a`, where `(0, 0)` would be the default location. If `return_first_only` is `False`, a list of all solutions that reach the same maximum, in form `[[a, b, shift], [a, b, shift], ...]`, ordered by increasing Euclidean length of `shift`.
      """
      a, b = utils.pad_for_translation(a, b, constant_values=constant_values)
      # cross-correlate; rounded for floating-point precision in the FFT, so that
      # equally good overlaps compare equal
      ab_corr = np.round(correlate(a, b, method='fft', mode='same'), 5)
      max_corr_idx = np.nonzero(ab_corr == np.max(ab_corr))
      # In 'same' mode scipy centres the output on the 'full' correlation, which
      # puts zero lag at index size // 2 along each axis. Integer floor division is
      # used instead of rounding (size / 2 - idx): np.round rounds halves to even,
      # which gave off-by-one shifts for odd-sized arrays.
      # NOTE(review): assumes pad_for_translation returns equal-shaped arrays.
      shift_x = max_corr_idx[0] - a.shape[0] // 2
      shift_y = max_corr_idx[1] - a.shape[1] // 2
      # order solutions by Euclidean distance from the default position, so those
      # requiring minimal translation come first (stable sort: deterministic ties)
      order = np.argsort(np.sqrt(shift_x ** 2 + shift_y ** 2), kind='stable')
      shift_x = shift_x[order]
      shift_y = shift_y[order]

      n_solutions = 1 if return_first_only else shift_x.shape[0]
      res = []
      for sx, sy in zip(shift_x[:n_solutions], shift_y[:n_solutions]):
          sh_i = (sx, sy)
          x_pad = _shift_pad_widths(sx)
          y_pad = _shift_pad_widths(sy)
          if keep_a_constant:
              # move b by trimming the cells pushed out on one side and padding the
              # other side, so b keeps its shape and a is returned as-is
              b_trim = b[_shift_trim_slice(sx), _shift_trim_slice(sy)]
              b_pad = np.pad(b_trim, (x_pad, y_pad), constant_values=constant_values)
              sol = [a, b_pad, sh_i]
          else:
              # pad b on the shift side and a on the opposite side (both grow)
              b_pad = np.pad(b, (x_pad, y_pad), constant_values=constant_values)
              a_pad = np.pad(a, (x_pad[::-1], y_pad[::-1]), constant_values=constant_values)
              sol = [a_pad, b_pad, sh_i]
          if return_first_only:
              return sol
          res.append(sol)
      return res


  def _shift_pad_widths(shift):
      """Return the `np.pad` (before, after) widths that move one axis by `shift`."""
      return (shift, 0) if shift > 0 else (0, -shift)


  def _shift_trim_slice(shift):
      """Return the slice dropping the `abs(shift)` cells pushed out by a shift."""
      if shift > 0:
          return slice(None, -shift)
      if shift < 0:
          return slice(-shift, None)
      return slice(None)
  def px_dist(a_arr, b_arr, partial=False, translate=True, constant_values=0, return_res=False, plot=False):
      """Calculate the number of pixels needed to be added/deleted to make one array into the other, at the arrays' positions of optimal overlap from `translate_ov()`. Default behaviour will also weight costs by pixel values if not binarised.

      Parameters
      ----------
      a_arr : ndarray
      b_arr : ndarray
          Arrays to compare. The difference `a - b` is computed in a signed dtype
          (unsigned and boolean inputs are promoted), so unsigned inputs such as
          uint8 images do not wrap around. Float and signed-integer dtypes are
          used unchanged.
      partial : bool
          Should the distance be partial pixel distance? If `True`, positive entries of `a - b` (values present in `a` but not in `b`) are set to 0 before summing, so only the values in `b` that are not matched by `a` are counted.
          NOTE(review): an earlier description stated the reverse ("pixels from a that cannot be explained by b"); the code behaviour is what is documented here - confirm which is intended.
      translate : bool
          Should the distance be calculated at positions of maximum overlap (from `translate_ov()`)? If `False`, will just use default locations.
      constant_values : float or list or array
          Passed to `np.pad()` by `translate_ov()`. Will usually want to be the value for the background (default = 0).
      return_res : bool
          Should the function return a dictionary of results split by direction of addition/deletion? If True, will return list with regular result as first entry, and a dictionary of results split by direction as second entry.
      plot : bool
          Should the result be plotted?

      Returns
      -------
      float or list
          Returns a float for the requested distance value, or, if `return_res` is `True`, a list whose first entry is the overall distance and whose second entry is a dictionary with the results split by array ('a_addition': sum of positive `a - b` entries, 'b_addition': sum of absolute negative entries).

      Examples
      --------
      >>> a = np.zeros((4, 4))
      >>> a[:, 3] = 1
      >>> b = np.zeros((4, 5))
      >>> b[:, 1] = 1
      >>> b[2, 3] = 0
      >>> b[1, 0:2] = 1
      >>> px_dist(a, b, return_res=True)
      [1.0, {'a_addition': 0.0, 'b_addition': 1.0}]
      """
      if translate:
          a_aligned, b_aligned, *_ = translate_ov(a_arr, b_arr, return_first_only=True, constant_values=constant_values)
      else:
          a_aligned, b_aligned = a_arr, b_arr
      a_aligned = np.asarray(a_aligned)
      b_aligned = np.asarray(b_aligned)
      # Subtract in a signed dtype: unsigned inputs would otherwise wrap around
      # (0 - 1 -> 255 for uint8) and numpy refuses to subtract boolean arrays.
      # Promoting with int8 leaves float and signed-integer dtypes unchanged.
      diff_dtype = np.result_type(a_aligned.dtype, b_aligned.dtype, np.int8)
      arrs_diff = np.subtract(a_aligned, b_aligned, dtype=diff_dtype)
      if partial:
          arrs_diff[arrs_diff > 0] = 0
      if plot:
          vlim = np.abs(arrs_diff).max()
          if vlim == 0:
              # TwoSlopeNorm requires vmin < vcenter < vmax; an all-zero
              # difference would otherwise raise
              vlim = 1
          plt.imshow(utils.crop_zeros(x=arrs_diff, y=a_aligned+b_aligned), interpolation='none', cmap='RdBu', norm=mpl.colors.TwoSlopeNorm(vmin=-vlim, vcenter=0, vmax=vlim))
          plt.colorbar()
          plt.title('a - b')
      total = np.sum(np.abs(arrs_diff))
      if return_res:
          res = {'a_addition': np.sum(np.abs(arrs_diff[arrs_diff > 0])),
                 'b_addition': np.sum(np.abs(arrs_diff[arrs_diff < 0]))}
          return [total, res]
      return total
  def partial_wasserstein(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, distance_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, entropic_reg_term=0.0, return_res=False, trans_manual=(0,0), constant_values=0, distance_metric='Euclidean', max_emd_iter=int(1e7), max_entropic_iter=int(1e5), nb_dummies=1, plot=False):
      """Calculate the partial Wasserstein metric between two arrays, with optional translation applied first.

      Each non-zero element is treated as a point mass located at its
      (axis-0, axis-1) index, with mass equal to the element's value.

      Parameters
      ----------
      a_arr : ndarray
      b_arr : ndarray
      scale_mass : bool
          Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
      scale_mass_method : str
          The method used to scale mass (only used if `scale_mass` is `True`). One of the following:
          'upscale': Upscale the masses in the less massive array to sum to the same amount as in the more massive.
          'downscale': Downscale the masses in the more massive array to sum to the same amount as in the less massive.
          'proportion': Divide masses in both arrays by their sum.
      mass_normalise : bool
          Should the partial Wasserstein metric be normalised (divided by) total mass transported?
      distance_normalise : bool
          Should the distance matrix be normalised (divided by) the maximum distance in the distance matrix? (This is done prior to solving optimal transport; it is skipped if every distance is zero.)
      del_weight : float
          Cost per unit of mass by which `a_arr` exceeds `b_arr`.
      ins_weight : float
          Cost per unit of mass by which `b_arr` exceeds `a_arr`.
      trans_weight : float
          Multiplier applied to the (partial) Wasserstein metric.
      entropic_reg_term : float
          Regularisation term for entropic partial Wasserstein. If entropic_reg_term==0, will use the standard EMD solver for the exact solution.
      return_res : bool
          Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` containing `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, `'tp'` (transport plan).
      trans_manual : tuple
          A tuple of the form `(x, y)`, where `x` and `y` can be floats, for a translation that should be applied to `a_arr` prior to calculating the metric.
      constant_values : float or list or array
          Passed to `np.pad()`. Will usually want to be the value for the background (default = 0).
      distance_metric : str
          Which distance metric to use (default='Euclidean'). See ?ot.dist documentation for options.
      max_emd_iter : int
          The maximum number of iterations to allow the EMD solver used by the POT library.
      max_entropic_iter : int
          The maximum number of iterations to allow the entropic OT solver used by the POT library.
      nb_dummies: int
          The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities. Ignored if entropic_reg_term!=0.
      plot: bool
          Should the solution be plotted?

      Returns
      -------
      float or dict
          The total cost (`del_cost + ins_cost + trans_cost`). If `return_res` is `True`, a dict containing `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'` and `'tp'` (transport plan).

      Raises
      ------
      ValueError
          If `scale_mass` is `True` and `scale_mass_method` is not recognised.

      Examples
      --------
      >>> a = np.zeros((3, 3))
      >>> a[0, :] = 1
      >>> a[2, 2] = 1
      >>> b = np.zeros((3, 3))
      >>> b[0, :] = 1
      >>> b[2, 0] = 1
      >>> partial_wasserstein(a, b)
      2.0
      >>> # permits continuous translation (not just whole pixels)
      >>> partial_wasserstein(a, b, trans_manual=(0.01, -0.2))
      2.8007722589903596
      """
      # bring both arrays to a common frame and assign to source and target
      s, t = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=constant_values)
      # coordinates of mass, one row per non-zero element (C order)
      xs = np.argwhere(s != 0)
      xt = np.argwhere(t != 0)
      # mass values in the same C order as the coordinates; float64 so that the
      # in-place scaling below also works for integer/boolean inputs
      s_hist = s[s != 0].astype(np.float64)
      t_hist = t[t != 0].astype(np.float64)
      if scale_mass:
          s_hist, t_hist = _scale_hist_masses(s_hist, t_hist, scale_mass_method)

      # apply the translation (-= for consistency with utils.pad_translate_mat())
      xs_t = xs.astype(np.float64)
      xs_t[:, 0] -= trans_manual[0]
      xs_t[:, 1] -= trans_manual[1]

      # loss matrix
      M = ot.dist(xs_t, xt, metric=distance_metric)
      if distance_normalise:
          M_max = M.max()
          if M_max > 0:  # avoid 0/0 when all points coincide
              M /= M_max

      # calculate partial Wasserstein
      m = np.min([s_hist.sum(), t_hist.sum()])  # mass to transport
      tp = None
      if entropic_reg_term == 0:
          pw_metric = ot.partial.partial_wasserstein2(s_hist, t_hist, M=M, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)
      else:
          tp = ot.partial.entropic_partial_wasserstein(s_hist, t_hist, M=M, reg=entropic_reg_term, m=m, numItermax=max_entropic_iter)
          pw_metric = np.sum(M * tp)
      if mass_normalise:
          pw_metric /= m

      # calculate total cost
      s_mass = s_hist.sum()
      t_mass = t_hist.sum()
      del_cost = del_weight * (s_mass - t_mass) if s_mass > t_mass else 0.0
      ins_cost = ins_weight * (t_mass - s_mass) if t_mass > s_mass else 0.0
      trans_cost = trans_weight * pw_metric
      total_cost = del_cost + ins_cost + trans_cost

      # transport plan, computed once if the plot or the detailed result needs it
      if (plot or return_res) and tp is None:
          tp = ot.partial.partial_wasserstein(s_hist, t_hist, M=M, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)

      if plot:
          if np.any(np.array(trans_manual) != 0):
              warnings.warn('The plot will only display *original* translation values, as translation is optimised continuously')
          pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
          pl_rgb[:, :, 0] = s/s.max()
          pl_rgb[:, :, 2] = t/t.max()
          fig = plt.figure()
          plt.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
          ot.plot.plot2D_samples_mat(xs, xt, G=tp, color='black', thr=1e-3)
          fig.show()

      if return_res:
          return {'total_cost': total_cost, 'trans_cost': trans_cost, 'ins_cost': ins_cost, 'del_cost': del_cost, 'tp': tp}
      return total_cost


  def _scale_hist_masses(s_hist, t_hist, method):
      """Scale two float histograms (in place) to equal total mass; return them.

      Raises `ValueError` for an unknown `method`, and `AssertionError` if the
      residual floating-point difference after scaling is not below 1e-8.
      """
      s_sum = s_hist.sum()
      t_sum = t_hist.sum()
      if method == 'proportion':
          s_hist /= s_sum
          t_hist /= t_sum
      elif method == 'upscale':
          # scale the less massive histogram up to the more massive one
          if s_sum > t_sum:
              t_hist *= s_sum / t_sum
          elif s_sum < t_sum:
              s_hist *= t_sum / s_sum
      elif method == 'downscale':
          # scale the more massive histogram down to the less massive one
          if s_sum > t_sum:
              s_hist *= t_sum / s_sum
          elif s_sum < t_sum:
              t_hist *= s_sum / t_sum
      else:
          raise ValueError('Unknown mass-scaling method!')
      # account for any precision errors (the tolerance may need adjusting!)
      imprec_diff = np.abs(s_hist.sum() - t_hist.sum())
      assert imprec_diff < 1e-8  # check the imprecision is small before adjusting for it
      if s_hist.sum() > t_hist.sum():
          t_hist[0] += imprec_diff
      elif s_hist.sum() < t_hist.sum():
          s_hist[0] += imprec_diff
      return s_hist, t_hist
  def partial_wasserstein_trans(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, distance_normalise=False, translation='opt', del_weight=0.0, ins_weight=0.0, trans_weight=1.0, entropic_reg_term=0.0, return_res=False, plot=False, constant_values=0, distance_metric='Euclidean', max_emd_iter=int(1e7), max_entropic_iter=int(1e5), nb_dummies=1, n_startvals=7, solver='Nelder-Mead', search_method='grid', options=None):
      """Calculate the partial Wasserstein metric between two arrays, with translation permitted.

      Parameters
      ----------
      a_arr : ndarray
      b_arr : ndarray
      scale_mass : bool
          Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
      scale_mass_method : str
          The method used to scale mass. One of the following:
          'upscale': Upscale the masses in the less massive array to sum to the same amount as in the more massive.
          'downscale': Downscale the masses in the more massive array to sum to the same amount as in the less massive.
          'proportion': Divide masses in both arrays by their sum.
      mass_normalise : bool
          Should the partial Wasserstein metric be normalised (divided by) total mass transported?
      distance_normalise : bool
          Should the distance matrix be normalised (divided by) the maximum distance in the distance matrix? (This is done prior to solving optimal transport.)
      translation : str or None
          Method for aligning the two arrays prior to distance calculation. Options are:
          'opt': Use non-linear optimisation to find the overlay that minimises the partial Wasserstein distance. This method will be slow.
          'crosscor': Use cross-correlation to align the matrices at the location of their maximal overlap.
          None: use the default positions.
          Any other value raises `ValueError`.
      del_weight : float
      ins_weight : float
      trans_weight : float
          Passed to `partial_wasserstein()`.
      entropic_reg_term : float
          Regularisation term for entropic partial Wasserstein. If entropic_reg_term==0, will use the standard EMD solver for the exact solution.
      return_res : bool
          Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` containing the usual `'trans'` and `'metric'`, but also `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, and `'tp'` (transport plan).
      plot: bool
          Should the solution be plotted?
      constant_values : float or list or array
          Passed to `np.pad()`. Will usually want to be the value for the background (default = 0).
      distance_metric : str
          Which distance metric to use (default='Euclidean'). See ?ot.dist documentation for options.
      max_emd_iter : int
          The maximum number of iterations to allow the EMD solver used by the POT library.
      max_entropic_iter : int
          The maximum number of iterations to allow the entropic OT solver used by the POT library.
      nb_dummies: int
          The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities. Ignored if entropic_reg_term!=0.
      n_startvals : int
          The number of starting values to try in optimising translation, if `translation=='opt'`.
      solver : str
          The solver to use in optimisation, if `translation=='opt'`. Possible values are those available to `scipy.optimize.minimize()`.
      search_method : str
          Method for setting starting values if `translation=='opt'`. Options are:
          'grid': set in equal steps from the lower to the upper bound
          'random': set randomly between the lower and upper bound
          Any other value raises `ValueError` (when `translation=='opt'`).
      options : dict
          Options to pass to the solver. E.g., `{'maxiter': 100}`.

      Returns
      -------
      dict
          A `dict` containing `'trans'` (the translation used) and `'metric'`. If `return_res` is `True`, this `dict` also contains the entries returned by `partial_wasserstein(..., return_res=True)`, including the transport plan `'tp'`.

      Raises
      ------
      ValueError
          If `translation` or (for `translation=='opt'`) `search_method` is not recognised.

      Examples
      --------
      >>> # compare solutions - crosscorrelation maximises overlap; optimiser minimises distance cost
      >>> import draw, utils
      >>> a_arr = draw.text_array('d', size=50)
      >>> b_arr = draw.text_array('p', size=50)
      >>> # padding is only used for the plotting; unpadded result will be identical
      >>> a_pad, b_pad = utils.pad_for_translation(a_arr, b_arr)
      >>> cc_sol = partial_wasserstein_trans(a_pad, b_pad, translation='crosscor')
      >>> pw_sol = partial_wasserstein_trans(a_pad, b_pad, translation='opt')
      >>> import matplotlib.pyplot as plt
      >>> plt.imshow(utils.pad_translate_mat(a_pad, cc_sol['trans'][0], cc_sol['trans'][1]) + b_pad)
      >>> plt.imshow(utils.pad_translate_mat(a_pad, int(pw_sol['trans'][0]), int(pw_sol['trans'][1])) + b_pad)
      """
      # arguments shared by every call to partial_wasserstein() below
      pw_kwargs = dict(
          scale_mass=scale_mass, scale_mass_method=scale_mass_method,
          mass_normalise=mass_normalise, distance_normalise=distance_normalise,
          del_weight=del_weight, ins_weight=ins_weight, trans_weight=trans_weight,
          entropic_reg_term=entropic_reg_term, constant_values=constant_values,
          distance_metric=distance_metric, max_emd_iter=max_emd_iter,
          max_entropic_iter=max_entropic_iter, nb_dummies=nb_dummies,
      )

      if translation == 'opt':
          # find the best translation using non-linear optimisation to minimise distance
          # (partial_wasserstein() permits float translations, i.e., not just whole pixels)
          max_shifts = [max(a_arr.shape[i], b_arr.shape[i]) for i in range(len(a_arr.shape))]
          bounds = [(-sh, sh) for sh in max_shifts]
          if search_method == 'grid':
              start_vals_arr = np.array([np.linspace(-sh, sh, num=n_startvals, endpoint=True) for sh in max_shifts])
          elif search_method == 'random':
              start_vals_arr = np.array([np.random.uniform(-sh, sh, size=n_startvals) for sh in max_shifts])
          else:
              raise ValueError(f"Unknown search_method {search_method!r}; expected 'grid' or 'random'")
          # one (x, y) starting point per column of start_vals_arr
          start_vals_tuples = list(zip(start_vals_arr[0], start_vals_arr[1]))

          def f(x0):
              return partial_wasserstein(a_arr=a_arr, b_arr=b_arr, trans_manual=x0, **pw_kwargs)

          iter_res = [
              sp.optimize.minimize(f, x0=st, method=solver, bounds=bounds, options=options)
              for st in start_vals_tuples
          ]
          # optimal value across all starting points
          metric = np.min(np.array([r['fun'] for r in iter_res]))
          # among solutions reaching the optimum, prefer the one with the smallest
          # Euclidean shift from the default position (first such one on ties)
          poss_sols = [r['x'] for r in iter_res if r['fun'] == metric]
          poss_sols_costs = np.array([np.sqrt(x[0]**2 + x[1]**2) for x in poss_sols])
          trans_res = poss_sols[int(np.argmin(poss_sols_costs))]
          # check whether rounding to the nearest pixel improves the result - if so, use the rounded values
          rounded_res = f(np.round(trans_res))
          if rounded_res < metric:
              res = {'trans': np.round(trans_res), 'metric': rounded_res}
          else:
              res = {'trans': trans_res, 'metric': metric}
          if plot:
              partial_wasserstein(a_arr=a_arr, b_arr=b_arr, trans_manual=res['trans'], plot=plot, **pw_kwargs)
          if return_res:
              # get all the values requested and merge the dictionaries
              full_res = partial_wasserstein(a_arr=a_arr, b_arr=b_arr, trans_manual=res['trans'], return_res=True, **pw_kwargs)
              res = {**res, **full_res}
      elif translation == 'crosscor':
          a_pad, b_pad, shift = translate_ov(a_arr, b_arr, constant_values=constant_values, return_first_only=True)
          # crop both to the region where either is non-zero, to reduce problem size
          joint = a_pad + b_pad
          a_cr = utils.crop_zeros(a_pad, joint)
          b_cr = utils.crop_zeros(b_pad, joint)
          metric = partial_wasserstein(a_arr=a_cr, b_arr=b_cr, trans_manual=(0, 0), plot=plot, **pw_kwargs)
          res = {'trans': shift, 'metric': metric}
          if return_res:
              # get all the values requested and merge the dictionaries
              full_res = partial_wasserstein(a_arr=a_cr, b_arr=b_cr, trans_manual=(0, 0), return_res=True, **pw_kwargs)
              res = {**res, **full_res}
      elif translation is None:
          a_pad, b_pad = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=constant_values)
          metric = partial_wasserstein(a_arr=a_pad, b_arr=b_pad, trans_manual=(0, 0), plot=plot, **pw_kwargs)
          res = {'trans': (0, 0), 'metric': metric}
          if return_res:
              # get all the values requested and merge the dictionaries
              full_res = partial_wasserstein(a_arr=a_pad, b_arr=b_pad, trans_manual=res['trans'], return_res=True, **pw_kwargs)
              res = {**res, **full_res}
      else:
          raise ValueError(f"Unknown translation {translation!r}; expected 'opt', 'crosscor' or None")
      return res
  463. def partial_gromov_wasserstein(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, return_res=False, max_emd_iter=int(1e7), nb_dummies=1, plot=False):
  464. """Calculate the partial Gromov-Wasserstein metric between two arrays.
  465. Parameters
  466. ----------
  467. a_arr : ndarray
  468. b_arr : ndarray
  469. scale_mass : bool
  470. Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` scaled prior solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
  471. `scale_mass_method` : str
  472. The method used to scale mass. One of the following:
  473. 'upscale': Upscale the masses in the more massive array to sum to the same amount as in the less massive.
  474. 'downscale': Downscale the masses in the less massive array to sum to the same amount as in the more massive.
  475. 'proportion': Divide masses in both arrays by their sum.
  476. mass_normalise : bool
  477. Should the Partial Wasserstein metric be normalised (divided by) total mass transported?
  478. del_weight : float
  479. ins_weight : float
  480. trans_weight : float
  481. return_res : bool
  482. Should the function return the transport plan, trans_cos, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` cotaining `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, `'tp'` (transport plan).
  483. max_emd_iter : int
  484. The maximum number of iterations to allow the EMD solver used by the POT library.
  485. nb_dummies: int
  486. The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities.
  487. plot: bool
  488. Should the solution be plotted?
  489. Returns
  490. -------
  491. float or ndarray
  492. If `return_plan` is `False`, will return a `float` with the metric. If `return_plan` is `True`, will return the transport plan.
  493. Examples
  494. --------
  495. >>> # rotation-invariant
  496. >>> a = np.zeros((3, 3))
  497. >>> a[0, :] = 1
  498. >>> a[1, 2] = 1
  499. >>> b = np.zeros((3, 3))
  500. >>> b[0, :] = 1
  501. >>> b[1, 0] = 1
  502. >>> partial_gromov_wasserstein(a, b)
  503. 0.0
  504. >>> # translation-invariant
  505. >>> c = utils.pad_translate_mat(b, -1, 0)
  506. >>> partial_gromov_wasserstein(b, c)
  507. 0.0
  508. >>> # compare to
  509. >>> d = np.zeros((3, 3))
  510. >>> d[0, :] = 1
  511. >>> d[2, 1] = 1
  512. >>> partial_gromov_wasserstein(a, d)
  513. 0.44546019305093054
  514. >>> # and
  515. >>> e = np.zeros((3, 3))
  516. >>> e[:, 1] = 1
  517. >>> f = np.zeros((3, 3))
  518. >>> f[0:2, 1] = 1
  519. >>> f[2, 2] = 1
  520. >>> partial_gromov_wasserstein(e, f)
  521. 0.020330872466366223
  522. >>> # finally
  523. >>> import draw
  524. >>> a_arr = draw.text_array('d', size=50)
  525. >>> b_arr = draw.text_array('p', size=50)
  526. >>> partial_gromov_wasserstein(a_arr, b_arr, scale_mass=True)
  527. 33.843035785027006
  528. >>> # vs.
  529. >>> rng = np.random.default_rng()
  530. >>> noise = a_arr.copy()
  531. >>> rng.shuffle(noise, axis=0)
  532. >>> rng.shuffle(noise, axis=1)
  533. >>> partial_gromov_wasserstein(a_arr, noise, scale_mass=True)
  534. 397.4592829552584
  535. """
  536. # coordinates of mass
  537. xs = np.transpose(np.array(np.where(a_arr!=0)))
  538. xt = np.transpose(np.array(np.where(b_arr!=0)))
  539. # distance kernels
  540. C1 = sp.spatial.distance.cdist(xs, xs)
  541. C2 = sp.spatial.distance.cdist(xt, xt)
  542. # normalise the distance kernels
  543. C1 /= C1.max()
  544. C2 /= C2.max()
  545. # get mass values that correspond to the indices of values (hist)
  546. p = a_arr.flatten()[a_arr.flatten()!=0]
  547. q = b_arr.flatten()[b_arr.flatten()!=0]
  548. if scale_mass:
  549. # scale such that there is equal total mass between in both arrays
  550. if scale_mass_method == 'proportion':
  551. p /= np.sum(p)
  552. q /= np.sum(q)
  553. elif scale_mass_method == 'upscale':
  554. if np.sum(p) > np.sum(q):
  555. q *= np.sum(p)/np.sum(q)
  556. elif np.sum(p) < np.sum(q):
  557. p *= np.sum(q)/np.sum(p)
  558. elif scale_mass_method == 'downscale':
  559. if np.sum(p) > np.sum(q):
  560. p *= np.sum(q)/np.sum(p)
  561. elif np.sum(p) < np.sum(q):
  562. q *= np.sum(p)/np.sum(q)
  563. else:
  564. ValueError('Unknown mass-scaling method!')
  565. # account for any precision errors (the tolerance may need adjusting!)
  566. imprec_diff = np.abs(np.sum(p) - np.sum(q))
  567. assert imprec_diff < 1e-8 # check the imprecision is reasonably small before adjusting for it
  568. if np.sum(p) > np.sum(q):
  569. q[0] += imprec_diff
  570. elif np.sum(p) < np.sum(q):
  571. p[0] += imprec_diff
  572. # check they now have equal total mass
  573. # assert np.sum(p) == np.sum(q) # comment out since we don't care about precision errors too small for the method above to account for
  574. # amount of mass to be moved should be the maximum possible
  575. m = np.min([np.sum(p), np.sum(q)])
  576. pgw_metric = ot.partial.partial_gromov_wasserstein2(C1=C1, C2=C2, p=p, q=q, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)
  577. # normalise by mass
  578. if mass_normalise:
  579. pgw_metric /= m
  580. # calculate total cost
  581. del_cost = del_weight * (p.sum() - q.sum()) if p.sum() > q.sum() else 0.0
  582. ins_cost = ins_weight * (q.sum() - p.sum()) if q.sum() > p.sum() else 0.0
  583. trans_cost = trans_weight * pgw_metric
  584. total_cost = del_cost + ins_cost + trans_cost
  585. # plot
  586. if plot:
  587. tp = ot.partial.partial_gromov_wasserstein(C1=C1, C2=C2, p=p, q=q, m=m)
  588. # not in the same space, but plot as though they are
  589. s, t = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=0.0)
  590. # coordinates of mass
  591. xs = np.transpose(np.array(np.where(s!=0)))
  592. xt = np.transpose(np.array(np.where(t!=0)))
  593. pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
  594. pl_rgb[:, :, 0] = s/s.max()
  595. pl_rgb[:, :, 2] = t/t.max()
  596. # pl_rgb[pl_rgb.sum(axis=2)==0, :] = 1
  597. fig = plt.figure()
  598. plt.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
  599. ot.plot.plot2D_samples_mat(xs, xt, G=tp, color='black', thr=1e-3)
  600. fig.show()
  601. if return_res:
  602. tp = ot.partial.partial_gromov_wasserstein(C1=C1, C2=C2, p=p, q=q, m=m)
  603. return({'total_cost':total_cost, 'trans_cost':trans_cost, 'ins_cost':ins_cost, 'del_cost':del_cost, 'tp':tp})
  604. # return result
  605. return(total_cost)
  606. # import draw
  607. # import matplotlib.pyplot as plt
  608. # from string import ascii_lowercase
  609. # from tqdm import tqdm
  610. # chars = [draw.text_array(c, size=50) for c in ascii_lowercase]
  611. # all_res = [partial_wasserstein_trans(c_i, c_j, translation='opt', n_startvals=3, search_method='random') for c_i in tqdm(chars) for c_j in chars]
  612. # all_dists = np.array([x['metric'] for x in all_res]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
  613. # plt.imshow(all_dists, interpolation='none')
  614. # plt.colorbar()
  615. # all_res_sc = [partial_wasserstein_trans(c_i, c_j, translation='opt', n_startvals=3, search_method='random', scale_mass=True) for c_i in tqdm(chars) for c_j in chars]
  616. # all_dists_sc = np.array([x['metric'] for x in all_res_sc]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
  617. # plt.imshow(all_dists_sc, interpolation='none')
  618. # plt.colorbar()
  619. # all_res_cc = [partial_wasserstein_trans(c_i, c_j, translation='crosscor') for c_i in tqdm(chars) for c_j in chars]
  620. # all_dists_cc = np.array([x['metric'] for x in all_res_cc]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
  621. # plt.imshow(all_dists_cc, interpolation='none')
  622. # plt.colorbar()
  623. # playing with getting sub-character neighbourhood information from transport plans
  624. # looks like you get sensible results if you scale the mass when a.sum()>b.sum(), and don't scale the mass when b.sum()>a.sum()