12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728 |
- """Functions for array similarity metrics"""
- import numpy as np
- from scipy.signal import correlate
- import scipy as sp
- from skimage.morphology import binary_dilation
- import ot
- import ot.plot
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- import warnings
- from scold import utils
def arr_sim(a, b, measure='all', partial_wasserstein_kwargs=None):
    """Calculate a measure of array similarity for 2 (aligned) binary arrays of equal shape.

    Parameters
    ----------
    a : ndarray
    b : ndarray
    measure : str
        one of the following, to select the desired measure:
            'all' returns a dictionary of all values, as well as the intersection and union. Because of speed, partial Wasserstein is not included.
            'overlap' returns the Overlap Coefficient
            'jaccard' returns Jaccard Similarity
            'dice' returns the Sørensen-Dice Coefficient
            'px_dist' returns the pixel distance from `px_dist()`, computed at the arrays' given positions (no translation)
            'partial_wasserstein' returns the result of `partial_wasserstein_trans()`, called with the keyword arguments in `partial_wasserstein_kwargs`
    partial_wasserstein_kwargs : dict or None
        Keyword arguments forwarded to `partial_wasserstein_trans()` when `measure == 'partial_wasserstein'`. If `None` (default), the following are used: `{'scale_mass': True, 'mass_normalise': True, 'distance_normalise': True, 'translation': 'opt', 'n_startvals': 9, 'solver': 'Nelder-Mead', 'search_method': 'grid'}`.

    Returns
    -------
    float or dict
        Returns a float for the requested similarity value, or a dictionary with all possible values if `measure == 'all'`, or the dictionary returned by `partial_wasserstein_trans()` in the case of 'partial_wasserstein'. An unrecognised `measure` returns `np.nan`.
    """
    # the defaults are built per call rather than stored as a (shared, mutable) default argument
    if partial_wasserstein_kwargs is None:
        partial_wasserstein_kwargs = {
            'scale_mass': True,
            'mass_normalise': True,
            'distance_normalise': True,
            'translation': 'opt',
            'n_startvals': 9,
            'solver': 'Nelder-Mead',
            'search_method': 'grid',
        }

    # early return for partial wasserstein, which needs none of the area statistics below
    if measure == 'partial_wasserstein':
        return partial_wasserstein_trans(a, b, **partial_wasserstein_kwargs)

    # element-wise minimum generalises "both pixels set" to non-binary values
    intersect = np.sum(np.minimum(a, b))
    area_a = np.sum(a)
    area_b = np.sum(b)
    union = area_a + area_b - intersect

    if measure == 'all':
        return {
            'intersection': intersect,
            'union': union,
            'overlap': intersect / min((area_a, area_b)),
            'jaccard': intersect / union,
            'dice': (intersect / (area_a + area_b)) * 2,
            'px_dist': px_dist(a, b, translate=False),
        }
    if measure == 'overlap':
        return intersect / min((area_a, area_b))
    if measure == 'jaccard':
        return intersect / union
    if measure == 'dice':
        # multiply by two to normalise within [0, 1] for consistency with other measures
        return (intersect / (area_a + area_b)) * 2
    if measure == 'px_dist':
        return px_dist(a, b, translate=False)

    # unknown measure: kept as NaN (rather than an exception) for backward compatibility
    return np.nan
def translate_ov(a, b, keep_a_constant=True, return_first_only=False, constant_values=0):
    """Maximise array overlap, permitting translation only, via cross-correlation.

    Both arrays are first passed through `utils.pad_for_translation()`. The
    translations of `b` that reach the maximum of the (rounded) cross-correlation
    are then applied, ordered so that the smallest translation comes first.

    Parameters
    ----------
    a : ndarray
    b : ndarray
    keep_a_constant : bool
        How the translation is applied. If `True`, `a` is returned as it is after `utils.pad_for_translation()`, and `b` is padded on one side and trimmed by the same amount on the other, so it keeps its shape. If `False`, nothing is trimmed: `b` is padded on one side and `a` is padded by the same amount on the opposite side.
    return_first_only : bool
        Two arrays may have multiple solutions with the same maximal overlap. If `True`, only the solution with the smallest Euclidean shift is returned. If `False`, all solutions are returned.
    constant_values : float or list or array
        Passed to `utils.pad_for_translation()` and `np.pad()`. Will usually want to be the value for the background (default = 0).

    Returns
    -------
    list
        If `return_first_only` is `True`, a list `[a, b, shift]`, where `a` and `b` are the aligned arrays and `shift` is a tuple of the integer offsets along axis 0 and axis 1, with `(0, 0)` meaning no translation. If `return_first_only` is `False`, a list of such lists, `[[a, b, shift], [a, b, shift], ...]`, one per solution.
    """
    def _pad_for(offset):
        # a positive offset pads before the data, a negative one pads after it
        if offset > 0:
            return (offset, 0)
        if offset < 0:
            return (0, np.abs(offset))
        return (0, 0)

    def _keep_for(offset):
        # the part of b that stays in frame once it has been moved by `offset`
        if offset > 0:
            return slice(None, -offset)
        if offset < 0:
            return slice(np.abs(offset), None)
        return slice(None)

    a, b = utils.pad_for_translation(a, b, constant_values=constant_values)

    # rounding guards the tie detection below against floating point noise from the FFT
    xcorr = np.round(correlate(a, b, method='fft', mode='same'), 5)
    peaks = np.where(xcorr == np.max(xcorr))

    # offset of every peak from the centre of the correlation surface
    row_shifts = np.round(peaks[0] - a.shape[0] / 2).astype(int)
    col_shifts = np.round(peaks[1] - a.shape[1] / 2).astype(int)

    # put the solutions needing the least translation (Euclidean) first
    nearest_first = np.sqrt(row_shifts**2 + col_shifts**2).argsort()
    row_shifts = row_shifts[nearest_first]
    col_shifts = col_shifts[nearest_first]

    n_solutions = 1 if return_first_only else row_shifts.shape[0]

    solutions = []
    for k in range(n_solutions):
        offset = (row_shifts[k], col_shifts[k])
        row_pad = _pad_for(offset[0])
        col_pad = _pad_for(offset[1])

        if keep_a_constant:
            # trim what moves out of frame, then pad the vacated side
            b_kept = b[_keep_for(offset[0]), _keep_for(offset[1])]
            a_out = a
            b_out = np.pad(b_kept, (row_pad, col_pad), constant_values=constant_values)
        else:
            # grow both arrays: b on one side, a on the opposite side
            b_out = np.pad(b, (row_pad, col_pad), constant_values=constant_values)
            a_out = np.pad(a, (row_pad[::-1], col_pad[::-1]), constant_values=constant_values)

        solutions.append([a_out, b_out, offset])

    if return_first_only:
        return solutions[0]
    return solutions
def px_dist(a_arr, b_arr, partial=False, translate=True, constant_values=0, return_res=False, plot=False):
    """Calculate the number of pixels needed to be added/deleted to make one array into the other, at the arrays' positions of optimal overlap from `translate_ov()`. Costs are weighted by pixel values if the arrays are not binarised.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    partial : bool
        Should the distance be partial pixel distance? If `True`, the positive entries of `a - b` are set to zero before summing, so only the entries where `b` exceeds `a` contribute.
        NOTE(review): the previous documentation described the opposite direction (counting what is in `a` but not in `b`) - confirm which is intended.
    translate : bool
        Should the distance be calculated at positions of maximum overlap (from `translate_ov()`)? If `False`, will just use default locations.
    constant_values : float or list or array
        Passed to `translate_ov()`. Will usually want to be the value for the background (default = 0).
    return_res : bool
        Should the function return a dictionary of results split by direction of addition/deletion? If True, will return list with regular result as first entry, and a dictionary of results split by direction as second entry.
    plot : bool
        Should the result be plotted?

    Returns
    -------
    float or list
        Returns the summed absolute difference, or, if `return_res` is `True`, a list whose first entry is that overall distance and whose second entry is a dictionary with the sum over the positive entries of `a - b` ('a_addition') and over the negative entries ('b_addition').

    Examples
    --------
    >>> a = np.zeros((4, 4))
    >>> a[:, 3] = 1
    >>> b = np.zeros((4, 5))
    >>> b[:, 1] = 1
    >>> b[2, 3] = 0
    >>> b[1, 0:2] = 1
    >>> px_dist(a, b, return_res=True)
    [1.0, {'a_addition': 0.0, 'b_addition': 1.0}]
    """
    def _signed(arr):
        # boolean arrays cannot be subtracted and unsigned ones wrap around
        # (e.g. uint8 0 - 1 == 255), so promote both to a signed integer type
        arr = np.asanyarray(arr)
        if arr.dtype.kind in 'bu':
            return arr.astype(np.int64)
        return arr

    if translate:
        a_aligned, b_aligned, *_ = translate_ov(a_arr, b_arr, return_first_only=True, constant_values=constant_values)
    else:
        a_aligned, b_aligned = a_arr, b_arr

    a_aligned = _signed(a_aligned)
    b_aligned = _signed(b_aligned)

    # positive where a exceeds b, negative where b exceeds a (a new array, so inputs are not modified)
    arrs_diff = a_aligned - b_aligned

    if partial:
        arrs_diff[arrs_diff > 0] = 0

    if plot:
        # symmetric colour limits keep zero difference at the centre of the diverging colour map
        max_abs_diff = np.abs(arrs_diff).max()
        plt.imshow(utils.crop_zeros(x=arrs_diff, y=a_aligned+b_aligned), interpolation='none', cmap='RdBu', norm=mpl.colors.TwoSlopeNorm(vmin=-max_abs_diff, vcenter=0, vmax=max_abs_diff))
        plt.colorbar()
        plt.title('a - b')

    total = np.sum(np.abs(arrs_diff))

    if return_res:
        res = {'a_addition': np.sum(np.abs(arrs_diff[arrs_diff > 0])),
               'b_addition': np.sum(np.abs(arrs_diff[arrs_diff < 0]))}
        return [total, res]

    return total
_MASS_SCALING_METHODS = ('proportion', 'upscale', 'downscale')


def _equalise_total_mass(s_hist, t_hist, method):
    """Return float copies of two mass vectors, rescaled by `method` so that their totals match."""
    # float copies: in-place scaling would fail on integer input, and the caller's data stays untouched
    s_hist = np.array(s_hist, dtype=np.float64)
    t_hist = np.array(t_hist, dtype=np.float64)
    s_sum = s_hist.sum()
    t_sum = t_hist.sum()

    if method == 'proportion':
        s_hist /= s_sum
        t_hist /= t_sum
    elif method == 'upscale':
        # the lighter vector is scaled up to the total of the heavier one
        if s_sum > t_sum:
            t_hist *= s_sum / t_sum
        elif s_sum < t_sum:
            s_hist *= t_sum / s_sum
    elif method == 'downscale':
        # the heavier vector is scaled down to the total of the lighter one
        if s_sum > t_sum:
            s_hist *= t_sum / s_sum
        elif s_sum < t_sum:
            t_hist *= s_sum / t_sum
    else:
        raise ValueError('Unknown mass-scaling method!')

    # account for any precision errors (the tolerance may need adjusting!)
    imprec_diff = np.abs(s_hist.sum() - t_hist.sum())
    assert imprec_diff < 1e-8  # check the imprecision is reasonably small before adjusting for it
    if s_hist.sum() > t_hist.sum():
        t_hist[0] += imprec_diff
    elif s_hist.sum() < t_hist.sum():
        s_hist[0] += imprec_diff

    return s_hist, t_hist


def partial_wasserstein(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, distance_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, entropic_reg_term=0.0, return_res=False, trans_manual=(0,0), constant_values=0, distance_metric='Euclidean', max_emd_iter=int(1e7), max_entropic_iter=int(1e5), nb_dummies=1, plot=False):
    """Calculate the partial Wasserstein metric between two arrays, with optional translation applied first.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    scale_mass : bool
        Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
    scale_mass_method : str
        The method used to scale mass. One of the following:
            'upscale': Scale up the masses in the less massive array so they sum to the same amount as in the more massive.
            'downscale': Scale down the masses in the more massive array so they sum to the same amount as in the less massive.
            'proportion': Divide masses in both arrays by their sum.
    mass_normalise : bool
        Should the Partial Wasserstein metric be normalised (divided by) total mass transported?
    distance_normalise : bool
        Should the distance matrix be normalised (divided by) the maximum distance in the distance matrix? (This is done prior to solving optimal transport.)
    del_weight : float
        Weight applied to the excess mass of `a_arr` over `b_arr` (the deletion cost).
    ins_weight : float
        Weight applied to the excess mass of `b_arr` over `a_arr` (the insertion cost).
    trans_weight : float
        Weight applied to the transport cost.
    entropic_reg_term : float
        Regularisation term for entropic partial Wasserstein. If entropic_reg_term==0, will use the standard EMD solver for the exact solution.
    return_res : bool
        Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` containing `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, `'tp'` (transport plan).
    trans_manual : tuple
        A tuple of two values, which can be floats, for a translation along axis 0 and axis 1 that should be applied to the mass coordinates of `a_arr` prior to calculating the metric.
    constant_values : float or list or array
        Passed to `utils.pad_for_translation()`. Will usually want to be the value for the background (default = 0).
    distance_metric : str
        Which distance metric to use (default='Euclidean'). See ?ot.dist documentation for options.
    max_emd_iter : int
        The maximum number of iterations to allow the EMD solver used by the POT library.
    max_entropic_iter : int
        The maximum number of iterations to allow the entropic OT solver used by the POT library.
    nb_dummies: int
        The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities. Ignored if entropic_reg_term!=0.
    plot: bool
        Should the solution be plotted?

    Returns
    -------
    float or dict
        If `return_res` is `False`, the total cost (`del_cost + ins_cost + trans_cost`). If `return_res` is `True`, a `dict` with `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'` and `'tp'` (the transport plan).

    Raises
    ------
    ValueError
        If `scale_mass` is `True` and `scale_mass_method` is not one of 'upscale', 'downscale', 'proportion'.

    Examples
    --------
    >>> a = np.zeros((3, 3))
    >>> a[0, :] = 1
    >>> a[2, 2] = 1
    >>> b = np.zeros((3, 3))
    >>> b[0, :] = 1
    >>> b[2, 0] = 1
    >>> partial_wasserstein(a, b)
    2.0
    >>> # permits continuous translation (not just whole pixels)
    >>> partial_wasserstein(a, b, trans_manual=(0.01, -0.2))
    2.8007722589903596
    """
    # fail fast on a bad method, before any padding or solving is done
    if scale_mass and scale_mass_method not in _MASS_SCALING_METHODS:
        raise ValueError('Unknown mass-scaling method!')

    # pad and assign to source and target arrays
    s, t = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=constant_values)

    # coordinates of mass (one row per non-zero cell, in row-major order)
    xs = np.transpose(np.array(np.where(s != 0)))
    xt = np.transpose(np.array(np.where(t != 0)))

    # mass values in the same row-major order as the coordinates, as float copies
    s_hist = s.flatten()[s.flatten() != 0].astype(np.float64)
    t_hist = t.flatten()[t.flatten() != 0].astype(np.float64)

    if scale_mass:
        # scale such that there is equal total mass in both arrays
        s_hist, t_hist = _equalise_total_mass(s_hist, t_hist, scale_mass_method)

    # apply the translation
    xs_t = xs.astype(np.float64)
    xs_t[:, 0] -= trans_manual[0]  # -= for consistency with behaviour of utils.pad_translate_mat()
    xs_t[:, 1] -= trans_manual[1]

    # loss matrix
    M = ot.dist(xs_t, xt, metric=distance_metric)
    if distance_normalise:
        M /= M.max()

    # mass to transport: as much as the lighter side allows
    m = np.min([s_hist.sum(), t_hist.sum()])

    # calculate partial wasserstein; the entropic plan is kept so it is not solved again below
    tp = None
    if entropic_reg_term == 0:
        pw_metric = ot.partial.partial_wasserstein2(s_hist, t_hist, M=M, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)
    else:
        tp = ot.partial.entropic_partial_wasserstein(s_hist, t_hist, M=M, reg=entropic_reg_term, m=m, numItermax=max_entropic_iter)
        pw_metric = np.sum(M * tp)

    # normalise by mass
    if mass_normalise:
        pw_metric /= m

    # calculate total cost
    del_cost = del_weight * (s_hist.sum() - t_hist.sum()) if s_hist.sum() > t_hist.sum() else 0.0
    ins_cost = ins_weight * (t_hist.sum() - s_hist.sum()) if t_hist.sum() > s_hist.sum() else 0.0
    trans_cost = trans_weight * pw_metric
    total_cost = del_cost + ins_cost + trans_cost

    # the exact transport plan is only solved for when it is actually needed, and only once
    if (plot or return_res) and tp is None:
        tp = ot.partial.partial_wasserstein(s_hist, t_hist, M=M, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)

    # plot
    if plot:
        if np.any(np.array(trans_manual) != 0):
            warnings.warn('The plot will only display *original* translation values, as translation is optimised continuously')
        pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
        pl_rgb[:, :, 0] = s/s.max()
        pl_rgb[:, :, 2] = t/t.max()
        fig = plt.figure()
        plt.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
        ot.plot.plot2D_samples_mat(xs, xt, G=tp, color='black', thr=1e-3)
        fig.show()

    # return the cost breakdown and transport plan if requested
    if return_res:
        return {'total_cost': total_cost, 'trans_cost': trans_cost, 'ins_cost': ins_cost, 'del_cost': del_cost, 'tp': tp}

    # return result
    return total_cost
def partial_wasserstein_trans(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, distance_normalise=False, translation='opt', del_weight=0.0, ins_weight=0.0, trans_weight=1.0, entropic_reg_term=0.0, return_res=False, plot=False, constant_values=0, distance_metric='Euclidean', max_emd_iter=int(1e7), max_entropic_iter=int(1e5), nb_dummies=1, n_startvals=7, solver='Nelder-Mead', search_method='grid', options=None):
    """Calculate the partial Wasserstein metric between two arrays, with translation permitted.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    scale_mass : bool
        Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
    scale_mass_method : str
        The method used to scale mass, passed to `partial_wasserstein()`. One of the following:
            'upscale': Scale up the masses in the less massive array so they sum to the same amount as in the more massive.
            'downscale': Scale down the masses in the more massive array so they sum to the same amount as in the less massive.
            'proportion': Divide masses in both arrays by their sum.
    mass_normalise : bool
        Should the Partial Wasserstein metric be normalised (divided by) total mass transported?
    distance_normalise : bool
        Should the distance matrix be normalised (divided by) the maximum distance in the distance matrix? (This is done prior to solving optimal transport.)
    translation : str or None
        Method for aligning the two arrays prior to distance calculation. Options are:
            'opt': Use non-linear optimisation to find the overlay that minimises the partial Wasserstein distance. This method will be slow.
            'crosscor': Use cross-correlation to align the matrices at the location of their maximal overlap.
            None: use the default positions.
    del_weight : float
    ins_weight : float
    trans_weight : float
        Weights passed to `partial_wasserstein()`.
    entropic_reg_term : float
        Regularisation term for entropic partial Wasserstein. If entropic_reg_term==0, will use the standard EMD solver for the exact solution.
    return_res : bool
        Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, the returned `dict()` contains the usual `'trans'` and `'metric'`, but also `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, and `'tp'` (transport plan).
    plot: bool
        Should the solution be plotted?
    constant_values : float or list or array
        Passed to the padding helpers. Will usually want to be the value for the background (default = 0).
    distance_metric : str
        Which distance metric to use (default='Euclidean'). See ?ot.dist documentation for options.
    max_emd_iter : int
        The maximum number of iterations to allow the EMD solver used by the POT library.
    max_entropic_iter : int
        The maximum number of iterations to allow the entropic OT solver used by the POT library.
    nb_dummies: int
        The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities. Ignored if entropic_reg_term!=0.
    n_startvals : int
        The number of starting values to try in optimising translation, if `translation=='opt'`.
    solver : str
        The solver to use in optimisation, if `translation='opt'`. Possible values are those available to `scipy.optimize.minimize()`.
    search_method : str
        Method for setting starting values if `translation=='opt'`. The values generated for the two axes are paired element-wise, giving `n_startvals` starting points. Options are:
            'grid': set in equal steps from the lower to the upper bound
            'random': set randomly between the lower and upper bound
    options : dict
        Options to pass to the solver. E.g., `{'maxiter': 100}`.

    Returns
    -------
    dict
        A `dict` containing `'trans'` (the translation used), and `'metric'`. If `return_res` is `True`, this `dict` will also contain the entries returned by `partial_wasserstein(..., return_res=True)`.

    Raises
    ------
    ValueError
        If `translation` is not 'opt', 'crosscor' or `None`, or if `translation=='opt'` and `search_method` is not 'grid' or 'random'.

    Examples
    --------
    >>> # compare solutions - crosscorrelation maximises overlap; optimiser minimises distance cost
    >>> import draw, utils
    >>> a_arr = draw.text_array('d', size=50)
    >>> b_arr = draw.text_array('p', size=50)
    >>> # padding is only used for the plotting; unpadded result will be identical
    >>> a_pad, b_pad = utils.pad_for_translation(a_arr, b_arr)
    >>> cc_sol = partial_wasserstein_trans(a_pad, b_pad, translation='crosscor')
    >>> pw_sol = partial_wasserstein_trans(a_pad, b_pad, translation='opt')
    >>> import matplotlib.pyplot as plt
    >>> plt.imshow(utils.pad_translate_mat(a_pad, cc_sol['trans'][0], cc_sol['trans'][1]) + b_pad)
    >>> plt.imshow(utils.pad_translate_mat(a_pad, int(pw_sol['trans'][0]), int(pw_sol['trans'][1])) + b_pad)
    """
    # reject unknown options up front instead of failing later on an unbound local variable
    if translation not in ('opt', 'crosscor', None):
        raise ValueError("Unknown translation method! Expected 'opt', 'crosscor', or None.")
    if translation == 'opt' and search_method not in ('grid', 'random'):
        raise ValueError("Unknown search method! Expected 'grid' or 'random'.")

    # keyword arguments shared by every call to partial_wasserstein() below
    pw_kwargs = {
        'scale_mass': scale_mass,
        'scale_mass_method': scale_mass_method,
        'mass_normalise': mass_normalise,
        'distance_normalise': distance_normalise,
        'del_weight': del_weight,
        'ins_weight': ins_weight,
        'trans_weight': trans_weight,
        'entropic_reg_term': entropic_reg_term,
        'constant_values': constant_values,
        'distance_metric': distance_metric,
        'max_emd_iter': max_emd_iter,
        'max_entropic_iter': max_entropic_iter,
        'nb_dummies': nb_dummies,
    }

    if translation == 'opt':
        # imported here so the optimiser is available regardless of which scipy submodules are already loaded
        from scipy.optimize import minimize

        # find the best translation using non-linear optimisation to minimise distance
        # (note that the translation applied in partial_wasserstein() permits float changes, i.e., not just whole pixels)
        max_shifts = [np.max([a_arr.shape[i], b_arr.shape[i]]) for i in range(len(a_arr.shape))]
        bounds = [(-sh, sh) for sh in max_shifts]

        if search_method == 'grid':
            start_vals_arr = np.array([np.linspace(-sh, sh, num=n_startvals, endpoint=True) for sh in max_shifts])
        else:
            start_vals_arr = np.array([np.random.uniform(-sh, sh, size=n_startvals) for sh in max_shifts])

        # the i-th value of each axis is paired into the i-th starting point
        start_vals_tuples = [(start_vals_arr[0, i], start_vals_arr[1, i]) for i in range(start_vals_arr.shape[1])]

        def f(x0):
            return partial_wasserstein(a_arr=a_arr, b_arr=b_arr, trans_manual=x0, **pw_kwargs)

        iter_res = [
            minimize(f, x0=st, method=solver, bounds=bounds, options=options)
            for st in start_vals_tuples
        ]

        # get results of each run from the search
        fun_vals = np.array([r['fun'] for r in iter_res])
        metric = np.min(fun_vals)  # optimal value

        # optimal translation is that which minimises the optimal transport, and, if there are multiple
        # solutions of equal OT values, that with the minimal Euclidean shift from starting positions
        poss_sols = [r['x'] for r in iter_res if r['fun'] == metric]
        poss_sols_costs = np.array([np.sqrt(x[0]**2 + x[1]**2) for x in poss_sols])
        trans_res = poss_sols[int(np.argmin(poss_sols_costs))]  # argmin gives the first minimum

        # check whether rounding to nearest pixel improves the result - if so, return the rounded values
        rounded_trans = np.round(trans_res)
        rounded_res = f(rounded_trans)

        if rounded_res < metric:
            res = {'trans': rounded_trans, 'metric': rounded_res}
        else:
            res = {'trans': trans_res, 'metric': metric}

        if plot or return_res:
            # a single extra solve at the chosen translation serves both the plot and the detailed results
            final = partial_wasserstein(a_arr=a_arr, b_arr=b_arr, trans_manual=res['trans'], plot=plot, return_res=return_res, **pw_kwargs)
            if return_res:
                res = {**res, **final}

        return res

    if translation == 'crosscor':
        a_pad, b_pad, trans = translate_ov(a_arr, b_arr, constant_values=constant_values, return_first_only=True)
        # crop to reduce size
        joint = a_pad + b_pad
        a_in = utils.crop_zeros(a_pad, joint)
        b_in = utils.crop_zeros(b_pad, joint)
    else:
        # translation is None: use the arrays at their default positions
        a_in, b_in = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=constant_values)
        trans = (0, 0)

    # the arrays are already positioned, so no further translation is applied
    out = partial_wasserstein(a_arr=a_in, b_arr=b_in, trans_manual=(0, 0), plot=plot, return_res=return_res, **pw_kwargs)

    if return_res:
        # merge the detailed results in after the usual two entries
        return {'trans': trans, 'metric': out['total_cost'], **out}

    return {'trans': trans, 'metric': out}
- def partial_gromov_wasserstein(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, return_res=False, max_emd_iter=int(1e7), nb_dummies=1, plot=False):
- """Calculate the partial Gromov-Wasserstein metric between two arrays.
-
- Parameters
- ----------
- a_arr : ndarray
- b_arr : ndarray
- scale_mass : bool
- Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` scaled prior solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
- `scale_mass_method` : str
- The method used to scale mass. One of the following:
- 'upscale': Upscale the masses in the more massive array to sum to the same amount as in the less massive.
- 'downscale': Downscale the masses in the less massive array to sum to the same amount as in the more massive.
- 'proportion': Divide masses in both arrays by their sum.
- mass_normalise : bool
- Should the Partial Wasserstein metric be normalised (divided by) total mass transported?
- del_weight : float
- ins_weight : float
- trans_weight : float
- return_res : bool
- Should the function return the transport plan, trans_cos, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` cotaining `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, `'tp'` (transport plan).
- max_emd_iter : int
- The maximum number of iterations to allow the EMD solver used by the POT library.
- nb_dummies: int
- The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities.
- plot: bool
- Should the solution be plotted?
-
- Returns
- -------
- float or ndarray
- If `return_plan` is `False`, will return a `float` with the metric. If `return_plan` is `True`, will return the transport plan.
- Examples
- --------
- >>> # rotation-invariant
- >>> a = np.zeros((3, 3))
- >>> a[0, :] = 1
- >>> a[1, 2] = 1
- >>> b = np.zeros((3, 3))
- >>> b[0, :] = 1
- >>> b[1, 0] = 1
- >>> partial_gromov_wasserstein(a, b)
- 0.0
- >>> # translation-invariant
- >>> c = utils.pad_translate_mat(b, -1, 0)
- >>> partial_gromov_wasserstein(b, c)
- 0.0
- >>> # compare to
- >>> d = np.zeros((3, 3))
- >>> d[0, :] = 1
- >>> d[2, 1] = 1
- >>> partial_gromov_wasserstein(a, d)
- 0.44546019305093054
- >>> # and
- >>> e = np.zeros((3, 3))
- >>> e[:, 1] = 1
- >>> f = np.zeros((3, 3))
- >>> f[0:2, 1] = 1
- >>> f[2, 2] = 1
- >>> partial_gromov_wasserstein(e, f)
- 0.020330872466366223
- >>> # finally
- >>> import draw
- >>> a_arr = draw.text_array('d', size=50)
- >>> b_arr = draw.text_array('p', size=50)
- >>> partial_gromov_wasserstein(a_arr, b_arr, scale_mass=True)
- 33.843035785027006
- >>> # vs.
- >>> rng = np.random.default_rng()
- >>> noise = a_arr.copy()
- >>> rng.shuffle(noise, axis=0)
- >>> rng.shuffle(noise, axis=1)
- >>> partial_gromov_wasserstein(a_arr, noise, scale_mass=True)
- 397.4592829552584
- """
- # coordinates of mass
- xs = np.transpose(np.array(np.where(a_arr!=0)))
- xt = np.transpose(np.array(np.where(b_arr!=0)))
- # distance kernels
- C1 = sp.spatial.distance.cdist(xs, xs)
- C2 = sp.spatial.distance.cdist(xt, xt)
- # normalise the distance kernels
- C1 /= C1.max()
- C2 /= C2.max()
- # get mass values that correspond to the indices of values (hist)
- p = a_arr.flatten()[a_arr.flatten()!=0]
- q = b_arr.flatten()[b_arr.flatten()!=0]
- if scale_mass:
- # scale such that there is equal total mass in both arrays
- if scale_mass_method == 'proportion':
- p /= np.sum(p)
- q /= np.sum(q)
- elif scale_mass_method == 'upscale':
- if np.sum(p) > np.sum(q):
- q *= np.sum(p)/np.sum(q)
- elif np.sum(p) < np.sum(q):
- p *= np.sum(q)/np.sum(p)
- elif scale_mass_method == 'downscale':
- if np.sum(p) > np.sum(q):
- p *= np.sum(q)/np.sum(p)
- elif np.sum(p) < np.sum(q):
- q *= np.sum(p)/np.sum(q)
- else:
- ValueError('Unknown mass-scaling method!')
-
- # account for any precision errors (the tolerance may need adjusting!)
- imprec_diff = np.abs(np.sum(p) - np.sum(q))
- assert imprec_diff < 1e-8 # check the imprecision is reasonably small before adjusting for it
- if np.sum(p) > np.sum(q):
- q[0] += imprec_diff
- elif np.sum(p) < np.sum(q):
- p[0] += imprec_diff
-
- # check they now have equal total mass
- # assert np.sum(p) == np.sum(q) # comment out since we don't care about precision errors too small for the method above to account for
- # amount of mass to be moved should be the maximum possible
- m = np.min([np.sum(p), np.sum(q)])
- pgw_metric = ot.partial.partial_gromov_wasserstein2(C1=C1, C2=C2, p=p, q=q, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)
- # normalise by mass
- if mass_normalise:
- pgw_metric /= m
- # calculate total cost
- del_cost = del_weight * (p.sum() - q.sum()) if p.sum() > q.sum() else 0.0
- ins_cost = ins_weight * (q.sum() - p.sum()) if q.sum() > p.sum() else 0.0
- trans_cost = trans_weight * pgw_metric
- total_cost = del_cost + ins_cost + trans_cost
-
- # plot
- if plot:
- tp = ot.partial.partial_gromov_wasserstein(C1=C1, C2=C2, p=p, q=q, m=m)
- # not in the same space, but plot as though they are
- s, t = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=0.0)
- # coordinates of mass
- xs = np.transpose(np.array(np.where(s!=0)))
- xt = np.transpose(np.array(np.where(t!=0)))
- pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
- pl_rgb[:, :, 0] = s/s.max()
- pl_rgb[:, :, 2] = t/t.max()
- # pl_rgb[pl_rgb.sum(axis=2)==0, :] = 1
- fig = plt.figure()
- plt.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
- ot.plot.plot2D_samples_mat(xs, xt, G=tp, color='black', thr=1e-3)
- fig.show()
- if return_res:
- tp = ot.partial.partial_gromov_wasserstein(C1=C1, C2=C2, p=p, q=q, m=m)
- return({'total_cost':total_cost, 'trans_cost':trans_cost, 'ins_cost':ins_cost, 'del_cost':del_cost, 'tp':tp})
- # return result
- return(total_cost)
- # import draw
- # import matplotlib.pyplot as plt
- # from string import ascii_lowercase
- # from tqdm import tqdm
- # chars = [draw.text_array(c, size=50) for c in ascii_lowercase]
- # all_res = [partial_wasserstein_trans(c_i, c_j, translation='opt', n_startvals=3, search_method='random') for c_i in tqdm(chars) for c_j in chars]
- # all_dists = np.array([x['metric'] for x in all_res]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
- # plt.imshow(all_dists, interpolation='none')
- # plt.colorbar()
- # all_res_sc = [partial_wasserstein_trans(c_i, c_j, translation='opt', n_startvals=3, search_method='random', scale_mass=True) for c_i in tqdm(chars) for c_j in chars]
- # all_dists_sc = np.array([x['metric'] for x in all_res_sc]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
- # plt.imshow(all_dists_sc, interpolation='none')
- # plt.colorbar()
- # all_res_cc = [partial_wasserstein_trans(c_i, c_j, translation='crosscor') for c_i in tqdm(chars) for c_j in chars]
- # all_dists_cc = np.array([x['metric'] for x in all_res_cc]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
- # plt.imshow(all_dists_cc, interpolation='none')
- # plt.colorbar()
- # playing with getting sub-character neighbourhood information from transport plans
- # looks like you get sensible results if you scale the mass when a.sum()>b.sum(), and don't scale the mass when b.sum()>a.sum()
|