"""Functions for array similarity metrics"""
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import ot
import ot.plot
import scipy as sp
# `sp.optimize` and `sp.spatial` are used below; importing the submodules
# explicitly keeps that working on SciPy versions without lazy submodule loading.
import scipy.optimize
import scipy.spatial
from scipy.signal import correlate
from skimage.morphology import binary_dilation

from scold import utils


# Keyword arguments handed to `partial_wasserstein_trans()` by `arr_sim()` when
# the caller does not supply any. Kept at module level (and never mutated) so
# that `arr_sim()` does not need a mutable default argument.
_DEFAULT_PARTIAL_WASSERSTEIN_KWARGS = {
    'scale_mass': True,
    'mass_normalise': True,
    'distance_normalise': True,
    'translation': 'opt',
    'n_startvals': 9,
    'solver': 'Nelder-Mead',
    'search_method': 'grid',
}


def arr_sim(a, b, measure='all', partial_wasserstein_kwargs=None):
    """Calculate a measure of array similarity for 2 (aligned) binary arrays of equal shape.

    Parameters
    ----------
    a : ndarray
    b : ndarray
    measure : str
        one of the following, to select the desired measure:
            'all' returns a dictionary of all values, as well as the intersection and union. Because of speed, partial Wasserstein is not included.
            'overlap' returns the Overlap Coefficient
            'jaccard' returns Jaccard Similarity
            'dice' returns the Sørensen-Dice Coefficient
            'px_dist' returns the total pixel distance (`px_dist()` without translation)
            'partial_wasserstein' returns the result of `partial_wasserstein_trans()`, called with `partial_wasserstein_kwargs`
    partial_wasserstein_kwargs : dict, optional
        Keyword arguments passed to `partial_wasserstein_trans()` when `measure == 'partial_wasserstein'`. If `None` (default), uses `scale_mass=True`, `mass_normalise=True`, `distance_normalise=True`, `translation='opt'`, `n_startvals=9`, `solver='Nelder-Mead'`, `search_method='grid'`. A supplied dictionary replaces these defaults entirely (it is not merged with them).

    Returns
    -------
    float or dict
        Returns a float for the requested similarity value, or a dictionary with all possible values if `measure == 'all'`, or a dictionary in the case of 'partial_wasserstein'. An unrecognised `measure` returns `np.nan`.
    """
    # early return for partial wasserstein
    if measure == 'partial_wasserstein':
        if partial_wasserstein_kwargs is None:
            partial_wasserstein_kwargs = _DEFAULT_PARTIAL_WASSERSTEIN_KWARGS
        return partial_wasserstein_trans(a, b, **partial_wasserstein_kwargs)

    # elementwise minimum generalises the binary intersection (a + b == 2)
    intersect = np.sum(np.minimum(a, b))  # area of intersection of A and B
    area_a = np.sum(a)
    area_b = np.sum(b)
    union = area_a + area_b - intersect

    if measure == 'all':
        return {
            'intersection': intersect,
            'union': union,
            'overlap': intersect / min((area_a, area_b)),
            'jaccard': intersect / union,
            'dice': (intersect / (area_a + area_b)) * 2,
            'px_dist': px_dist(a, b, translate=False),
        }
    elif measure == 'overlap':
        return intersect / min((area_a, area_b))
    elif measure == 'jaccard':
        return intersect / union
    elif measure == 'dice':
        # multiply by two to normalise within [0, 1] for consistency with other measures
        return (intersect / (area_a + area_b)) * 2
    elif measure == 'px_dist':
        return px_dist(a, b, translate=False)
    else:
        return np.nan


def _shift_pad(offset):
    """Return the `(before, after)` pad widths that move an axis by `offset` places."""
    if offset > 0:
        return (offset, 0)
    if offset < 0:
        return (0, np.abs(offset))
    return (0, 0)


def _trim_for_shift(arr, offset, axis):
    """Drop the `abs(offset)` entries along `axis` that a shift by `offset` pushes off the edge."""
    if offset > 0:
        keep = slice(None, -offset)
    elif offset < 0:
        keep = slice(np.abs(offset), None)
    else:
        return arr
    index = [slice(None)] * arr.ndim
    index[axis] = keep
    return arr[tuple(index)]


def translate_ov(a, b, keep_a_constant=True, return_first_only=False, constant_values=0):
    """Maximise array overlap, permitting translation only, via cross-correlation.

    Parameters
    ----------
    a : ndarray
    b : ndarray
    keep_a_constant : bool
        Refers to the method for aligning the matrices a and b - if False, a and b are both padded to achieve the overlap, if True, a remains the same size, and b is padded from one side and trimmed from the other side.
    return_first_only : bool
        Two arrays may have multiple solutions with the same maximal overlap. Should only the first solution (the one requiring the smallest Euclidean shift) be returned? If False, will return list of lists of solutions.
    constant_values : float or list or array
        Passed to `np.pad`. Will usually want to be the value for the background (default = 0).

    Returns
    -------
    list
        If `return_first_only` is `True`, returns a list in form `[a, b, shift]`, where `a` and `b` are the aligned matrices, and `shift` contains translation from default location in coordinates `(x, y)`, where `(0, 0)` would be the default location. If `return_first_only` is `False`, will return a list of all solutions that reach the same maximum, in form `[[a, b, shift], [a, b, shift], ...]`, sorted by ascending Euclidean length of `shift`.
    """
    a, b = utils.pad_for_translation(a, b, constant_values=constant_values)

    # cross correlate (rounded for floating point precision in fft)
    ab_corr = np.round(correlate(a, b, method='fft', mode='same'), 5)
    max_corr_idx = np.where(ab_corr == np.max(ab_corr))

    # translation from centre, for each identified overlap
    shift = [
        np.negative(np.round(a.shape[0] / 2 - max_corr_idx[0])).astype(int),
        np.negative(np.round(a.shape[1] / 2 - max_corr_idx[1])).astype(int),
    ]

    # sort overlaps ascendingly by Euclidean distance from centre, so that the
    # solutions requiring minimal translation come first
    euc_dist_order = np.sqrt(shift[0]**2 + shift[1]**2).argsort()
    shift[0] = shift[0][euc_dist_order]
    shift[1] = shift[1][euc_dist_order]

    n_solutions = 1 if return_first_only else shift[0].shape[0]

    solutions = []
    for ov_nr in range(n_solutions):
        sh_i = (shift[0][ov_nr], shift[1][ov_nr])
        x_pad = _shift_pad(sh_i[0])
        y_pad = _shift_pad(sh_i[1])

        if keep_a_constant:
            # trim b on the side it is shifted towards, then pad the opposite
            # side, so that b keeps the same shape as a
            b_trim = _trim_for_shift(_trim_for_shift(b, sh_i[0], axis=0), sh_i[1], axis=1)
            b_pad = np.pad(b_trim, (x_pad, y_pad), constant_values=constant_values)
            solutions.append([a, b_pad, sh_i])
        else:
            # pad b on one side and a on the opposite side; both arrays grow
            b_pad = np.pad(b, (x_pad, y_pad), constant_values=constant_values)
            a_pad = np.pad(a, (np.flip(x_pad), np.flip(y_pad)), constant_values=constant_values)
            solutions.append([a_pad, b_pad, sh_i])

    return solutions[0] if return_first_only else solutions


def px_dist(a_arr, b_arr, partial=False, translate=True, constant_values=0, return_res=False, plot=False):
    """Calculate the number of pixels needed to be added/deleted to make one array into the other, at the arrays' positions of optimal overlap from `translate_ov()`. Default behaviour will also weight costs by pixel values if not binarised.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    partial : bool
        Should the distance be partial pixel distance? If `True`, positive entries of `a - b` (mass in `a` that is absent from `b`) are set to zero, so only mass in `b` that is absent from `a` contributes to the distance.
        NOTE(review): earlier documentation described the opposite direction (counting mass in `a` unexplained by `b`); the description here follows what the code does - confirm which is intended.
    translate : bool
        Should the distance be calculated at positions of maximum overlap (from `translate_ov()`)? If `False`, will just use default locations.
    constant_values : float or list or array
        Passed to `np.pad()` by `translate_ov()`. Will usually want to be the value for the background (default = 0).
    return_res : bool
        Should the function return a dictionary of results split by direction of addition/deletion? If True, will return list with regular result as first entry, and a dictionary of results split by direction as second entry.
    plot : bool
        Should the result be plotted?

    Returns
    -------
    float or list
        Returns a float for the requested distance value, or, if `return_res` is `True`, a list whose first entry is the overall distance and whose second entry is a dictionary with the results split by array: 'a_addition' (summed positive part of `a - b`) and 'b_addition' (summed magnitude of the negative part of `a - b`).

    Examples
    --------
    >>> a = np.zeros((4, 4))
    >>> a[:, 3] = 1
    >>> b = np.zeros((4, 5))
    >>> b[:, 1] = 1
    >>> b[2, 3] = 0
    >>> b[1, 0:2] = 1
    >>> px_dist(a, b, return_res=True)
    [1.0, {'a_addition': 0.0, 'b_addition': 1.0}]
    """
    if translate:
        a_aligned, b_aligned, *_ = translate_ov(a_arr, b_arr, return_first_only=True, constant_values=constant_values)
    else:
        a_aligned = a_arr
        b_aligned = b_arr

    # the subtraction allocates a new array, so the inputs are never modified below
    arrs_diff = a_aligned - b_aligned

    if partial:
        arrs_diff[arrs_diff > 0] = 0

    if plot:
        # symmetric colour limits so that zero difference sits at the centre of the colourmap
        max_abs_diff = np.abs(arrs_diff).max()
        plt.imshow(utils.crop_zeros(x=arrs_diff, y=a_aligned+b_aligned), interpolation='none', cmap='RdBu', norm=mpl.colors.TwoSlopeNorm(vmin=-max_abs_diff, vcenter=0, vmax=max_abs_diff))
        plt.colorbar()
        plt.title('a - b')

    total = np.sum(np.abs(arrs_diff))

    if return_res:
        res = {'a_addition': np.sum(np.abs(arrs_diff[arrs_diff > 0])),
               'b_addition': np.sum(np.abs(arrs_diff[arrs_diff < 0]))}
        return [total, res]

    return total


def _equalise_mass(src, tgt, method):
    """Scale two 1-D float mass vectors, in place, so that they have equal total mass.

    Parameters
    ----------
    src : ndarray
    tgt : ndarray
        Float arrays of masses. Both are modified in place.
    method : str
        'proportion' divides each array by its own sum; 'upscale' multiplies the less massive array up to the total of the more massive one; 'downscale' multiplies the more massive array down to the total of the less massive one.

    Raises
    ------
    ValueError
        If `method` is not one of the options above.
    """
    src_sum = src.sum()
    tgt_sum = tgt.sum()

    if method == 'proportion':
        src /= src_sum
        tgt /= tgt_sum
    elif method == 'upscale':
        if src_sum > tgt_sum:
            tgt *= src_sum / tgt_sum
        elif src_sum < tgt_sum:
            src *= tgt_sum / src_sum
    elif method == 'downscale':
        if src_sum > tgt_sum:
            src *= tgt_sum / src_sum
        elif src_sum < tgt_sum:
            tgt *= src_sum / tgt_sum
    else:
        raise ValueError('Unknown mass-scaling method!')

    # account for any precision errors (the tolerance may need adjusting!)
    imprec_diff = np.abs(src.sum() - tgt.sum())
    assert imprec_diff < 1e-8  # check the imprecision is reasonably small before adjusting for it
    if src.sum() > tgt.sum():
        tgt[0] += imprec_diff
    elif src.sum() < tgt.sum():
        src[0] += imprec_diff
    # exact equality is deliberately not asserted afterwards: precision errors
    # too small for the adjustment above to remove are not a concern


def partial_wasserstein(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, distance_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, entropic_reg_term=0.0, return_res=False, trans_manual=(0,0), constant_values=0, distance_metric='Euclidean', max_emd_iter=int(1e7), max_entropic_iter=int(1e5), nb_dummies=1, plot=False):
    """Calculate the partial Wasserstein metric between two arrays, with optional translation applied first.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    scale_mass : bool
        Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
    scale_mass_method : str
        The method used to scale mass. One of the following:
            'upscale': Upscale the masses in the less massive array to sum to the same amount as in the more massive.
            'downscale': Downscale the masses in the more massive array to sum to the same amount as in the less massive.
            'proportion': Divide masses in both arrays by their sum.
    mass_normalise : bool
        Should the Partial Wasserstein metric be normalised (divided by) total mass transported?
    distance_normalise : bool
        Should the distance matrix be normalised (divided by) the maximum distance in the distance matrix? (This is done prior to solving optimal transport.)
    del_weight : float
        Weight applied to the excess mass of `a_arr` over `b_arr` (deletion cost).
    ins_weight : float
        Weight applied to the excess mass of `b_arr` over `a_arr` (insertion cost).
    trans_weight : float
        Weight applied to the (optionally mass-normalised) transport cost.
    entropic_reg_term : float
        Regularisation term for entropic partial Wasserstein. If entropic_reg_term==0, will use the standard EMD solver for the exact solution.
    return_res : bool
        Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` containing `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, `'tp'` (transport plan).
    trans_manual : tuple
        A tuple of the form `(x, y)`, where `x` and `y` can be floats, for a translation that should be applied to `a_arr` prior to calculating the metric.
    constant_values : float or list or array
        Passed to `utils.pad_for_translation()`. Will usually want to be the value for the background (default = 0).
    distance_metric : str
        Which distance metric to use (default='Euclidean'). See ?ot.dist documentation for options.
    max_emd_iter : int
        The maximum number of iterations to allow the EMD solver used by the POT library.
    max_entropic_iter : int
        The maximum number of iterations to allow the entropic OT solver used by the POT library.
    nb_dummies: int
        The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities. Ignored if entropic_reg_term!=0.
    plot: bool
        Should the solution be plotted?

    Returns
    -------
    float or dict
        If `return_res` is `False`, will return the total cost (`del_cost + ins_cost + trans_cost`). If `return_res` is `True`, will return a `dict` with `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'` and `'tp'` (the transport plan).

    Raises
    ------
    ValueError
        If `scale_mass` is `True` and `scale_mass_method` is not recognised.

    Examples
    --------
    >>> a = np.zeros((3, 3))
    >>> a[0, :] = 1
    >>> a[2, 2] = 1
    >>> b = np.zeros((3, 3))
    >>> b[0, :] = 1
    >>> b[2, 0] = 1
    >>> partial_wasserstein(a, b)
    2.0
    >>> # permits continuous translation (not just whole pixels)
    >>> partial_wasserstein(a, b, trans_manual=(0.01, -0.2))
    2.8007722589903596
    """
    # bring both arrays to a common shape and assign to source and target arrays
    s, t = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=constant_values)

    # coordinates of mass
    xs = np.transpose(np.array(np.where(s != 0)))
    xt = np.transpose(np.array(np.where(t != 0)))

    # mass values in the same (row-major) order as the coordinates above; cast
    # to float copies so that in-place scaling also works for integer or
    # boolean input and never touches the caller's arrays
    s_hist = s[s != 0].astype(np.float64)
    t_hist = t[t != 0].astype(np.float64)

    if scale_mass:
        # scale such that there is equal total mass in both arrays
        _equalise_mass(s_hist, t_hist, scale_mass_method)

    # apply the translation
    xs_t = xs.astype(np.float64)
    xs_t[:, 0] -= trans_manual[0]  # -= for consistency with behaviour of utils.pad_translate_mat()
    xs_t[:, 1] -= trans_manual[1]

    # loss matrix
    M = ot.dist(xs_t, xt, metric=distance_metric)
    if distance_normalise:
        M /= M.max()

    # mass to transport: as much as the less massive array allows
    m = np.min([s_hist.sum(), t_hist.sum()])

    def solve_plan():
        # transport plan from whichever solver `entropic_reg_term` selects
        if entropic_reg_term == 0:
            return ot.partial.partial_wasserstein(s_hist, t_hist, M=M, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)
        return ot.partial.entropic_partial_wasserstein(s_hist, t_hist, M=M, reg=entropic_reg_term, m=m, numItermax=max_entropic_iter)

    # calculate partial wasserstein
    if entropic_reg_term == 0:
        pw_metric = ot.partial.partial_wasserstein2(s_hist, t_hist, M=M, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)
    else:
        pw_metric = np.sum(M * solve_plan())

    # normalise by mass
    if mass_normalise:
        pw_metric /= m

    # calculate total cost
    s_mass = s_hist.sum()
    t_mass = t_hist.sum()
    del_cost = del_weight * (s_mass - t_mass) if s_mass > t_mass else 0.0
    ins_cost = ins_weight * (t_mass - s_mass) if t_mass > s_mass else 0.0
    trans_cost = trans_weight * pw_metric
    total_cost = del_cost + ins_cost + trans_cost

    # the plan is only needed for plotting or for the detailed result
    tp = solve_plan() if (plot or return_res) else None

    if plot:
        if np.any(np.array(trans_manual) != 0):
            warnings.warn('The plot will only display *original* translation values, as translation is optimised continuously')

        # source mass in the red channel, target mass in the blue channel
        pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
        pl_rgb[:, :, 0] = s/s.max()
        pl_rgb[:, :, 2] = t/t.max()
        fig = plt.figure()
        plt.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
        ot.plot.plot2D_samples_mat(xs, xt, G=tp, color='black', thr=1e-3)
        fig.show()

    if return_res:
        return {'total_cost': total_cost, 'trans_cost': trans_cost, 'ins_cost': ins_cost, 'del_cost': del_cost, 'tp': tp}

    return total_cost


def partial_wasserstein_trans(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, distance_normalise=False, translation='opt', del_weight=0.0, ins_weight=0.0, trans_weight=1.0, entropic_reg_term=0.0, return_res=False, plot=False, constant_values=0, distance_metric='Euclidean', max_emd_iter=int(1e7), max_entropic_iter=int(1e5), nb_dummies=1, n_startvals=7, solver='Nelder-Mead', search_method='grid', options=None):
    """Calculate the partial Wasserstein metric between two arrays, with translation permitted.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    scale_mass : bool
        Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
    scale_mass_method : str
        The method used to scale mass. One of the following:
            'upscale': Upscale the masses in the less massive array to sum to the same amount as in the more massive.
            'downscale': Downscale the masses in the more massive array to sum to the same amount as in the less massive.
            'proportion': Divide masses in both arrays by their sum.
    mass_normalise : bool
        Should the Partial Wasserstein metric be normalised (divided by) total mass transported?
    distance_normalise : bool
        Should the distance matrix be normalised (divided by) the maximum distance in the distance matrix? (This is done prior to solving optimal transport.)
    translation : str
        Method for aligning the two arrays prior to distance calculation. Options are:
            'opt': Use non-linear optimisation to find the overlay that minimises the partial Wasserstein distance. This method will be slow.
            'crosscor': Use cross-correlation to align the matrices at the location of their maximal overlap.
            None: use the default positions.
    del_weight : float
    ins_weight : float
    trans_weight : float
        Cost weights, passed to `partial_wasserstein()`.
    entropic_reg_term : float
        Regularisation term for entropic partial Wasserstein. If entropic_reg_term==0, will use the standard EMD solver for the exact solution.
    return_res : bool
        Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` containing the usual `'trans'` and `'metric'`, but also `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, and `'tp'` (transport plan).
    plot: bool
        Should the solution be plotted?
    constant_values : float or list or array
        Passed to `np.pad()`. Will usually want to be the value for the background (default = 0).
    distance_metric : str
        Which distance metric to use (default='Euclidean'). See ?ot.dist documentation for options.
    max_emd_iter : int
        The maximum number of iterations to allow the EMD solver used by the POT library.
    max_entropic_iter : int
        The maximum number of iterations to allow the entropic OT solver used by the POT library.
    nb_dummies: int
        The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities. Ignored if entropic_reg_term!=0.
    n_startvals : int
        The number of starting values to try in optimising translation, if `translation=='opt'`.
    solver : str
        The solver to use in optimisation, if `translation='opt'`. Possible values are those available to `scipy.optimize.minimize()`.
    search_method : str
        Method for setting starting values if `translation=='opt'`. Options are:
            'grid': set in equal steps from the lower to the upper bound
            'random': set randomly between the lower and upper bound
    options : dict
        Options to pass to the solver. E.g., `{'maxiter': 100}`.

    Returns
    -------
    dict
        A `dict` containing `'trans'` (the translation used) and `'metric'`. If `return_res` is `True`, this `dict` will also contain `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'` and `'tp'` (the transport plan).

    Raises
    ------
    ValueError
        If `translation` or (for `translation == 'opt'`) `search_method` is not one of the documented options.

    Examples
    --------
    >>> # compare solutions - crosscorrelation maximises overlap; optimiser minimises distance cost
    >>> import draw, utils
    >>> a_arr = draw.text_array('d', size=50)
    >>> b_arr = draw.text_array('p', size=50)
    >>> # padding is only used for the plotting; unpadded result will be identical
    >>> a_pad, b_pad = utils.pad_for_translation(a_arr, b_arr)
    >>> cc_sol = partial_wasserstein_trans(a_pad, b_pad, translation='crosscor')
    >>> pw_sol = partial_wasserstein_trans(a_pad, b_pad, translation='opt')
    >>> import matplotlib.pyplot as plt
    >>> plt.imshow(utils.pad_translate_mat(a_pad, cc_sol['trans'][0], cc_sol['trans'][1]) + b_pad)
    >>> plt.imshow(utils.pad_translate_mat(a_pad, int(pw_sol['trans'][0]), int(pw_sol['trans'][1])) + b_pad)
    """
    # settings shared by every call to partial_wasserstein() below
    pw_kwargs = {
        'scale_mass': scale_mass,
        'scale_mass_method': scale_mass_method,
        'mass_normalise': mass_normalise,
        'distance_normalise': distance_normalise,
        'del_weight': del_weight,
        'ins_weight': ins_weight,
        'trans_weight': trans_weight,
        'entropic_reg_term': entropic_reg_term,
        'constant_values': constant_values,
        'distance_metric': distance_metric,
        'max_emd_iter': max_emd_iter,
        'max_entropic_iter': max_entropic_iter,
        'nb_dummies': nb_dummies,
    }

    if translation == 'opt':
        # find the best translation using non-linear optimisation to minimise distance
        # (note that the translation applied by partial_wasserstein() permits float changes, i.e., not just whole pixels)
        max_shifts = [max(dims) for dims in zip(a_arr.shape, b_arr.shape)]
        bounds = [(-sh, sh) for sh in max_shifts]

        if search_method == 'grid':
            start_vals_arr = np.array([np.linspace(-sh, sh, num=n_startvals, endpoint=True) for sh in max_shifts])
        elif search_method == 'random':
            start_vals_arr = np.array([np.random.uniform(-sh, sh, size=n_startvals) for sh in max_shifts])
        else:
            raise ValueError(f'Unknown search method: {search_method!r}')

        # the i-th start point pairs the i-th candidate along each of the first two axes
        start_vals_tuples = list(zip(start_vals_arr[0], start_vals_arr[1]))

        def f(x0, a_arr=a_arr, b_arr=b_arr):
            return partial_wasserstein(a_arr=a_arr, b_arr=b_arr, trans_manual=x0, **pw_kwargs)

        iter_res = [
            sp.optimize.minimize(f, x0=st, method=solver, bounds=bounds, options=options)
            for st in start_vals_tuples
        ]

        # get results of each iteration from the search
        fun_vals = np.array([i['fun'] for i in iter_res])
        metric = np.min(fun_vals)  # optimal value

        # get optimal translation as that which minimises the optimal transport, and, if there are multiple solutions of equal OT values, that with the minimal Euclidean shift from starting positions
        poss_sols = [i['x'] for i in iter_res if i['fun'] == metric]
        poss_sols_costs = np.array([np.sqrt(x[0]**2 + x[1]**2) for x in poss_sols])
        # argmin gives the first minimal entry as a plain index
        trans_res = poss_sols[int(np.argmin(poss_sols_costs))]

        # check whether rounding to nearest pixel improves the result - if so, return the rounded values
        rounded_res = f(x0=np.round(trans_res))
        if rounded_res < metric:
            res = {'trans': np.round(trans_res), 'metric': rounded_res}
        else:
            res = {'trans': trans_res, 'metric': metric}

        # arrays and translation at which the reported metric is obtained
        eval_a, eval_b, eval_trans = a_arr, b_arr, res['trans']

        if plot:
            _ = partial_wasserstein(a_arr=eval_a, b_arr=eval_b, trans_manual=eval_trans, plot=plot, **pw_kwargs)

    elif translation == 'crosscor':
        a_pad, b_pad, shift = translate_ov(a_arr, b_arr, constant_values=constant_values, return_first_only=True)

        # crop to reduce size; the arrays are already aligned, so no further translation is applied
        eval_a = utils.crop_zeros(a_pad, a_pad+b_pad)
        eval_b = utils.crop_zeros(b_pad, a_pad+b_pad)
        eval_trans = (0, 0)

        metric = partial_wasserstein(a_arr=eval_a, b_arr=eval_b, trans_manual=eval_trans, plot=plot, **pw_kwargs)
        res = {'trans': shift, 'metric': metric}

    elif translation is None:
        eval_a, eval_b = utils.pad_for_translation(a_arr, b_arr, pad=False)
        eval_trans = (0, 0)

        metric = partial_wasserstein(a_arr=eval_a, b_arr=eval_b, trans_manual=eval_trans, plot=plot, **pw_kwargs)
        res = {'trans': (0, 0), 'metric': metric}

    else:
        raise ValueError(f'Unknown translation method: {translation!r}')

    if return_res:
        # get all the values requested and merge the dictionaries
        full_res = partial_wasserstein(a_arr=eval_a, b_arr=eval_b, trans_manual=eval_trans, return_res=True, **pw_kwargs)
        res = {**res, **full_res}

    return res


def partial_gromov_wasserstein(a_arr, b_arr, scale_mass=False, scale_mass_method='upscale', mass_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, return_res=False, max_emd_iter=int(1e7), nb_dummies=1, plot=False):
    """Calculate the partial Gromov-Wasserstein metric between two arrays.

    Parameters
    ----------
    a_arr : ndarray
    b_arr : ndarray
    scale_mass : bool
        Should mass be scaled prior to calculating distance? If `True`, masses in `a_arr` and `b_arr` are scaled prior to solving the optimal transport plan, with the method provided in `scale_mass_method`. If `scale_mass` is `True`, then `del_weight` and `ins_weight` will have no effect on the calculated distance, as no mass is being added or removed, only transported.
    scale_mass_method : str
        The method used to scale mass. One of the following:
            'upscale': Upscale the masses in the less massive array to sum to the same amount as in the more massive.
            'downscale': Downscale the masses in the more massive array to sum to the same amount as in the less massive.
            'proportion': Divide masses in both arrays by their sum.
    mass_normalise : bool
        Should the Partial Gromov-Wasserstein metric be normalised (divided by) total mass transported?
    del_weight : float
        Weight applied to the excess mass of `a_arr` over `b_arr` (deletion cost).
    ins_weight : float
        Weight applied to the excess mass of `b_arr` over `a_arr` (insertion cost).
    trans_weight : float
        Weight applied to the (optionally mass-normalised) transport cost.
    return_res : bool
        Should the function return the transport plan, trans_cost, ins_cost, and del_cost, as well as the metric? If `True`, will return a `dict()` containing `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'`, `'tp'` (transport plan).
    max_emd_iter : int
        The maximum number of iterations to allow the solver used by the POT library.
    nb_dummies: int
        The number of dummy points used by the EMD solver of the POT library. Can avoid instabilities.
    plot: bool
        Should the solution be plotted?

    Returns
    -------
    float or dict
        If `return_res` is `False`, will return the total cost (`del_cost + ins_cost + trans_cost`). If `return_res` is `True`, will return a `dict` with `'total_cost'`, `'trans_cost'`, `'ins_cost'`, `'del_cost'` and `'tp'` (the transport plan).

    Raises
    ------
    ValueError
        If `scale_mass` is `True` and `scale_mass_method` is not recognised.

    Examples
    --------
    >>> # rotation-invariant
    >>> a = np.zeros((3, 3))
    >>> a[0, :] = 1
    >>> a[1, 2] = 1
    >>> b = np.zeros((3, 3))
    >>> b[0, :] = 1
    >>> b[1, 0] = 1
    >>> partial_gromov_wasserstein(a, b)
    0.0
    >>> # translation-invariant
    >>> c = utils.pad_translate_mat(b, -1, 0)
    >>> partial_gromov_wasserstein(b, c)
    0.0
    >>> # compare to
    >>> d = np.zeros((3, 3))
    >>> d[0, :] = 1
    >>> d[2, 1] = 1
    >>> partial_gromov_wasserstein(a, d)
    0.44546019305093054
    >>> # and
    >>> e = np.zeros((3, 3))
    >>> e[:, 1] = 1
    >>> f = np.zeros((3, 3))
    >>> f[0:2, 1] = 1
    >>> f[2, 2] = 1
    >>> partial_gromov_wasserstein(e, f)
    0.020330872466366223
    >>> # finally
    >>> import draw
    >>> a_arr = draw.text_array('d', size=50)
    >>> b_arr = draw.text_array('p', size=50)
    >>> partial_gromov_wasserstein(a_arr, b_arr, scale_mass=True)
    33.843035785027006
    >>> # vs.
    >>> rng = np.random.default_rng()
    >>> noise = a_arr.copy()
    >>> rng.shuffle(noise, axis=0)
    >>> rng.shuffle(noise, axis=1)
    >>> partial_gromov_wasserstein(a_arr, noise, scale_mass=True)
    397.4592829552584
    """
    # coordinates of mass
    xs = np.transpose(np.array(np.where(a_arr != 0)))
    xt = np.transpose(np.array(np.where(b_arr != 0)))

    # within-array distance kernels, each normalised by its own maximum
    C1 = sp.spatial.distance.cdist(xs, xs)
    C2 = sp.spatial.distance.cdist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()

    # mass values in the same (row-major) order as the coordinates above; cast
    # to float copies so that in-place scaling also works for integer or
    # boolean input and never touches the caller's arrays
    p = a_arr[a_arr != 0].astype(np.float64)
    q = b_arr[b_arr != 0].astype(np.float64)

    if scale_mass:
        # scale such that there is equal total mass in both arrays
        _equalise_mass(p, q, scale_mass_method)

    # amount of mass to be moved should be the maximum possible
    m = np.min([np.sum(p), np.sum(q)])

    pgw_metric = ot.partial.partial_gromov_wasserstein2(C1=C1, C2=C2, p=p, q=q, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)

    # normalise by mass
    if mass_normalise:
        pgw_metric /= m

    # calculate total cost
    p_mass = p.sum()
    q_mass = q.sum()
    del_cost = del_weight * (p_mass - q_mass) if p_mass > q_mass else 0.0
    ins_cost = ins_weight * (q_mass - p_mass) if q_mass > p_mass else 0.0
    trans_cost = trans_weight * pgw_metric
    total_cost = del_cost + ins_cost + trans_cost

    # the plan is only needed for plotting or for the detailed result; it is
    # solved with the same solver settings as the metric so the two correspond
    if plot or return_res:
        tp = ot.partial.partial_gromov_wasserstein(C1=C1, C2=C2, p=p, q=q, m=m, numItermax=max_emd_iter, nb_dummies=nb_dummies)

    if plot:
        # not in the same space, but plot as though they are
        s, t = utils.pad_for_translation(a_arr, b_arr, pad=False, constant_values=0.0)
        # coordinates of mass in the common plotting frame
        xs = np.transpose(np.array(np.where(s != 0)))
        xt = np.transpose(np.array(np.where(t != 0)))
        # source mass in the red channel, target mass in the blue channel
        pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
        pl_rgb[:, :, 0] = s/s.max()
        pl_rgb[:, :, 2] = t/t.max()
        fig = plt.figure()
        plt.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
        ot.plot.plot2D_samples_mat(xs, xt, G=tp, color='black', thr=1e-3)
        fig.show()

    if return_res:
        return {'total_cost': total_cost, 'trans_cost': trans_cost, 'ins_cost': ins_cost, 'del_cost': del_cost, 'tp': tp}

    return total_cost


# import draw
# import matplotlib.pyplot as plt
# from string import ascii_lowercase
# from tqdm import tqdm

# chars = [draw.text_array(c, size=50) for c in ascii_lowercase]

# all_res = [partial_wasserstein_trans(c_i, c_j, translation='opt', n_startvals=3, search_method='random') for c_i in tqdm(chars) for c_j in chars]
# all_dists = np.array([x['metric'] for x in all_res]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
# plt.imshow(all_dists, interpolation='none')
# plt.colorbar()

# all_res_sc = [partial_wasserstein_trans(c_i, c_j, translation='opt', n_startvals=3, search_method='random', scale_mass=True) for c_i in tqdm(chars) for c_j in chars]
# all_dists_sc = np.array([x['metric'] for x in all_res_sc]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
# plt.imshow(all_dists_sc, interpolation='none')
# plt.colorbar()

# all_res_cc = [partial_wasserstein_trans(c_i, c_j, translation='crosscor') for c_i in tqdm(chars) for c_j in chars]
# all_dists_cc = np.array([x['metric'] for x in all_res_cc]).reshape((len(ascii_lowercase), len(ascii_lowercase)))
# plt.imshow(all_dists_cc, interpolation='none')
# plt.colorbar()

# playing with getting sub-character neighbourhood information from transport plans
# looks like you get sensible results if you scale the mass when a.sum()>b.sum(), and don't scale the mass when b.sum()>a.sum()