text_arr_sim_wasserstein.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. """Functions for calculating array similarity metrics for text"""
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4. from scipy.optimize import minimize
  5. from scold import draw
  6. from scold import arr_sim
  7. from scold import utils
  8. def text_arr_sim(a, b=None, font_a='arial.ttf', font_b='arial.ttf', b_arr=None, measure='partial_wasserstein', translate=None, fliplr=False, flipud=False, size=100, scale_val=1.0, rotate_val=0.0, translate_prop_val_x=0.0, translate_prop_val_y=0.0, plot=False, partial_wasserstein_kwargs={'scale_mass':True, 'mass_normalise':True, 'distance_normalise':True, 'ins_weight': 0.0, 'del_weight': 0.0}, **kwargs):
  9. """Calculate similarity metrics for two strings of text, translated to achieve optimal overlap.
  10. Parameters
  11. ----------
  12. a : str
  13. b : str, optional
  14. Must be defined if `b_arr` is not. Ignored if `b_arr` is defined.
  15. font_a : str, optional
  16. `.ttf` font to use to build text array from `a`
  17. font_b : str, optional
  18. `.ttf` font to use to build text array from `b`
  19. b_arr : ndarray, optional
  20. Array that the array built from `a` will be compared to. This option is included as it is faster to pre-build the array that `a` will be compared to, if applying this function in a loop. If both `b` and `b_arr` are defined, `b` will be ignored.
  21. measure : str
  22. Argument not used. Included for consistency with `text_arr_sim.text_arr_sim()`.
  23. translate : bool
  24. Argument not used. Included for consistency with `text_arr_sim.text_arr_sim()`.
  25. fliplr : bool
  26. Should the text built from `a` be mirrored horizontally?
  27. flipud : bool
  28. Should the text built from `a` be mirrored vertically?
  29. size : int
  30. The size of the text to draw.
  31. scale_val : float
  32. This is multiplied by `size` to calculate the size of the text array built from `a`, rounded to the nearest pixel. If `b_arr` is not pre-built, `b_arr` is built at size `size`, such that `scale_val` says how many times bigger `a_arr` should be than `b_arr`.
  33. rotate_val : float
  34. Degrees by which the text built from `a` should be rotated.
  35. translate_prop_val_x : float
  36. Translation in the x axis for the text built from `a`, given as a proportion of the max size of `a_arr` and `b_arr` in that dimension. If negative, translation is in a negative direction.
  37. translate_prop_val_y : float
  38. Translation in the y axis for the text built from `a`, given as a proportion of the max size of `a_arr` and `b_arr` in that dimension. If negative, translation is in a negative direction.
  39. plot : bool
  40. Should the solution be plotted?
  41. partial_wasserstein_kwargs : dict
  42. kwargs to be passed to arr_sim.partial_wasserstein().
  43. **kwargs
  44. Other arguments to pass to `draw.text_array()`.
  45. Returns
  46. -------
  47. dict
  48. Returns a dictionary with the similarity metrics calculated in `arr_sim.arr_sim()`, with an additional entry, `'shift'`, which contains the optimal translation values calculated by `arr_sim.translate_ov`, in form `(x, y)`.
  49. Examples
  50. --------
  51. >>> text_arr_sim('d', 'p')
  52. {'jaccard': 0.5915244261330195, 'shift': (1, 19)}
  53. >>> text_arr_sim('d', 'c')
  54. {'jaccard': 0.6390658174097664, 'shift': (1, 0)}
  55. >>> text_arr_sim('d', 'p', measure='partial_wasserstein', partial_wasserstein_kwargs={'scale_mass':True, 'mass_normalise':True, 'distance_normalise':True, 'translation':'crosscor', 'n_startvals':7, 'solver':'Nelder-Mead', 'search_method':'grid'})
  56. {'partial_wasserstein': 0.11199633634025599, 'shift': (1, 19)}
  57. >>> text_arr_sim('d', 'p', measure='partial_wasserstein', partial_wasserstein_kwargs={'scale_mass':True, 'mass_normalise':True, 'distance_normalise':True, 'translation':'opt', 'n_startvals':7, 'solver':'Nelder-Mead', 'search_method':'grid'})
  58. {'partial_wasserstein': 0.07609689664769317, 'shift': (5.0, 11.0)}
  59. """
  60. if measure!='partial_wasserstein':
  61. raise ValueError('The measure argument is only included for consistency with text_arr_sim.text_arr_sim. The only acceptable value for text_arr_sim_wasserstein is partial_wasserstein.')
  62. if translate is not None:
  63. raise ValueError('The translate argument is only included for consistency with text_arr_sim.text_arr_sim.')
  64. assert translate_prop_val_x>=-1 and translate_prop_val_x<=1 and translate_prop_val_y>=-1 and translate_prop_val_y<=1
  65. a_arr = draw.text_array(a, font=font_a, rotate=rotate_val, fliplr=fliplr, flipud=flipud, size=size*scale_val, **kwargs)
  66. if np.all(b_arr == None):
  67. b_arr = draw.text_array(b, font=font_b, rotate=0, fliplr=False, flipud=False, size=size, **kwargs) # faster to pre-define and give as argument if using the same b text in a loop
  68. # calculate the partial wasserstein value
  69. # map proportion to raw shift values, and pass as arguments to the partial wasserstein fun
  70. max_shifts = [np.max([a_arr.shape[i], b_arr.shape[i]]) for i in range(len(a_arr.shape))]
  71. translate_val_x = np.round(translate_prop_val_x * max_shifts[0], 5)
  72. translate_val_y = np.round(translate_prop_val_y * max_shifts[1], 5)
  73. partial_wasserstein_kwargs['trans_manual'] = (translate_val_x, translate_val_y)
  74. sim_res = {}
  75. partial_wasserstein_kwargs['plot'] = plot
  76. if plot:
  77. # flip the indices for the plot
  78. a_arr=a_arr.T
  79. b_arr=b_arr.T
  80. sim_res[measure] = arr_sim.partial_wasserstein(a_arr, b_arr, **partial_wasserstein_kwargs)
  81. # add shift to the results
  82. sim_res['shift'] = (translate_val_x, translate_val_y)
  83. return(sim_res)
  84. def _opt_text_arr_sim_flip_manual(a='a', b=None, font_a='arial.ttf', font_b='arial.ttf', b_arr=None, measure='partial_wasserstein', translate=True, scale=True, rotate=True, fliplr=False, flipud=False, size=100, rotation_bounds=(-np.Infinity, np.Infinity), max_scale_change_factor=2, max_translation_factor=0.999, rotation_eval_n=9, scale_eval_n=9, translation_eval_n=9, solver='Nelder-Mead', search_method='grid', plot=False, partial_wasserstein_kwargs={'scale_mass':True, 'mass_normalise':True, 'distance_normalise': True, 'ins_weight':0.0, 'del_weight':0.0}, **kwargs):
  85. """Find parameters for geometric operations of translation, scale, and rotation that minimise partial wasserstein between two arrays of drawn text.
  86. Parameters
  87. ----------
  88. a : str
  89. b : str, optional
  90. Must be defined if `b_arr` is not. Ignored if `b_arr` is defined.
  91. font_a : str, optional
  92. `.ttf` font to use to build text array from `a`
  93. font_b : str, optional
  94. `.ttf` font to use to build text array from `b`
  95. b_arr : ndarray, optional
  96. Array that the array built from `a` will be compared to. This option is included as it is faster to pre-build the array that `a` will be compared to, if applying this function in a loop. If both `b` and `b_arr` are defined, `b` will be ignored.
  97. measure : str
  98. Argument not used. Included for consistency with `text_arr_sim.text_arr_sim()`.
  99. translate : bool
  100. Should the translation operation be optimised? If `False`, will always use default positions.
  101. scale: bool
  102. Should scale be optimised?
  103. rotate : bool
  104. Should rotation be optimised?
  105. fliplr : bool
  106. Should `b` be flipped horizontally? Note that in this version of the function, `b` is not optimised. Instead, this function can be run with this set to True and False, and the best result taken.
  107. flipud : bool
  108. Should `b` be flipped vertically? Note that in this version of the function, `b` is not optimised. Instead, this function can be run with this set to True and False, and the best result taken.
  109. size : int
  110. Size for the text (the scale parameter will be multiplied by this value).
  111. rotation_bounds : tuple
  112. Limits for optimising rotation in form `(lowerbound, upperbound)`. For example, `(-90, 90)` will limit rotation to 90 degrees in either direction.
  113. max_scale_change_factor : float
  114. Maximum value for the optimised scale parameter. `max_scale_change_factor=2` will permit 100% bigger or 50% smaller, i.e., twice as large or twice as small.
  115. max_translation_factor : float
  116. Maximum value for the optimised translation parameter. `max_translation_factor=0.5` will permit translation 50% of the maximum bounds in any direction, where bounds are determined by the maximum size of the arrays being compared. Must be <1, as it is optimised on a logistic scale.
  117. rotation_eval_n : int
  118. How many starting values should be tried for optimising rotation?
  119. scale_eval_n : int
  120. How many starting values should be tried for optimising scale?
  121. translation_eval_n : int
  122. How many starting values should be tried for optimising translation (x and y)?
  123. solver : str
  124. Which solver to use? Possible values are those available to `scipy.optimize.minimize()`.
  125. search_method : str
  126. Method for setting starting values. Options are:
  127. 'grid': set in equal steps from the lower to the upper bound
  128. 'random': set randomly between the lower and upper bound
  129. plot : bool
  130. Should the optimal solution be plotted?
  131. partial_wasserstein_kwargs : dict
  132. kwargs to be passed to arr_sim.partial_wasserstein() or arr_sim.partial_wasserstein_trans()
  133. **kwargs
  134. Other arguments to pass to `text_arr_sim()`.
  135. Returns
  136. -------
  137. dict
  138. A dictionary containing the following values:
  139. 'translate': Whether translation was optimised
  140. 'scale': Whether scale was optimised
  141. 'rotate': Whether rotation was optimised
  142. 'fliplr': Placeholder for main function (always `None`)
  143. 'flipud': Placeholder for main function (always `None`)
  144. 'intersection', 'union', 'overlap', 'jaccard', 'dice', etc.: Values from `arr_sim.arr_sim()`
  145. 'translate_val_x': Optimal shift value in x dimension
  146. 'translate_val_y': Optimal shift value in y dimension
  147. 'scale_val': Optimal scale coefficient
  148. 'rotate_val': Optimal rotation coefficient
  149. 'flip_val': Whether the array was slipped horizontally
  150. """
  151. if measure!='partial_wasserstein':
  152. raise ValueError('The measure argument is only included for consistency with text_arr_sim.text_arr_sim. The only acceptable value for text_arr_sim_wasserstein is partial_wasserstein.')
  153. if np.all(b_arr == None):
  154. b_arr = draw.text_array(b, font=font_b, rotate=0, fliplr=False, flipud=False, size=size, **kwargs)
  155. # if neither scale, rotation, nor translation need to be optimised, just use the cross correlation approach to get optimal cold values...
  156. if (not scale) and (not rotate) and (not translate):
  157. sim_res = text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=1, rotate_val=0, translate_prop_val_x=0, translate_prop_val_y=0, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)
  158. poss_scale_vals = [0]
  159. poss_rotate_vals = [0]
  160. poss_translate_x_vals = [0]
  161. poss_translate_y_vals = [0]
  162. fun_vals = sim_res[measure]
  163. # otherwise, optimise translation, scale, and/or rotation
  164. else:
  165. # functions which will be optimised (note that scale is on a log-scale here for the optimiser - this is useful as centred on zero and will have same precision for increase and decrease, i.e. whether 2x or 0.5x) and prevent a scale of 0
  166. def sim_opt_translate_scale_rotate(x):
  167. # map logistic scale to raw scale
  168. translate_prop_val_x = utils.inv_logistic(x[0])
  169. translate_prop_val_y = utils.inv_logistic(x[1])
  170. # map log scale to raw scale
  171. scale_exp = np.exp(x[2])
  172. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=scale_exp, rotate_val=x[3], translate_prop_val_x=translate_prop_val_x, translate_prop_val_y=translate_prop_val_y, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  173. def sim_opt_translate_scale(x):
  174. # map logistic scale to raw scale
  175. translate_prop_val_x = utils.inv_logistic(x[0])
  176. translate_prop_val_y = utils.inv_logistic(x[1])
  177. # map log scale to raw scale
  178. scale_exp = np.exp(x[2])
  179. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=scale_exp, rotate_val=0, translate_prop_val_x=translate_prop_val_x, translate_prop_val_y=translate_prop_val_y, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  180. def sim_opt_translate_rotate(x):
  181. # map logistic scale to raw scale
  182. translate_prop_val_x = utils.inv_logistic(x[0])
  183. translate_prop_val_y = utils.inv_logistic(x[1])
  184. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=1, rotate_val=x[2], translate_prop_val_x=translate_prop_val_x, translate_prop_val_y=translate_prop_val_y, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  185. def sim_opt_scale_rotate(x):
  186. # map log scale to raw scale
  187. scale_exp = np.exp(x[0])
  188. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=scale_exp, rotate_val=x[1], translate_prop_val_x=0, translate_prop_val_y=0, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  189. def sim_opt_translate(x):
  190. # map logistic scale to raw scale
  191. translate_prop_val_x = utils.inv_logistic(x[0])
  192. translate_prop_val_y = utils.inv_logistic(x[1])
  193. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=1, rotate_val=0, translate_prop_val_x=translate_prop_val_x, translate_prop_val_y=translate_prop_val_y, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  194. def sim_opt_scale(x):
  195. # map log scale to raw scale
  196. scale_exp = np.exp(x[0])
  197. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=scale_exp, rotate_val=0, translate_prop_val_x=0, translate_prop_val_y=0, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  198. def sim_opt_rotate(x):
  199. return text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, scale_val=1, rotate_val=x[0], translate_prop_val_x=0, translate_prop_val_y=0, size=size, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)[measure]
  200. # bounds of translate & scale optimisation - used to calculate the starting values
  201. scale_bounds = (-np.log(max_scale_change_factor), np.log(max_scale_change_factor))
  202. # (translation units are in raw rather than link units here; they are transformed when passed to the minimisation fun)
  203. translate_bounds = (-max_translation_factor, max_translation_factor)
  204. # starting values for optimising translation, scale, and rotation
  205. if search_method=='grid':
  206. starting_points_scale = np.linspace(
  207. scale_bounds[0],
  208. scale_bounds[1],
  209. scale_eval_n, endpoint=True)
  210. starting_points_rotation = np.linspace(
  211. max((-180, min(rotation_bounds))),
  212. min((180, max(rotation_bounds))),
  213. rotation_eval_n, endpoint=True)
  214. starting_points_translation = utils.logistic(np.linspace(
  215. translate_bounds[0],
  216. translate_bounds[1],
  217. translation_eval_n, endpoint=True))
  218. elif search_method=='random':
  219. starting_points_scale = np.random.uniform(
  220. scale_bounds[0],
  221. scale_bounds[1],
  222. size=scale_eval_n)
  223. starting_points_rotation = np.random.uniform(
  224. max((-180, min(rotation_bounds))),
  225. min((180, max(rotation_bounds))),
  226. size=rotation_eval_n)
  227. starting_points_translation = utils.logistic(np.random.uniform(
  228. translate_bounds[0],
  229. translate_bounds[1],
  230. size=translation_eval_n))
  231. # list which will contain the results
  232. iter_res = []
  233. if translate:
  234. if (scale) and (rotate):
  235. for start_translate_x in starting_points_translation:
  236. for start_translate_y in starting_points_translation:
  237. for start_scale in starting_points_scale:
  238. for start_rotate in starting_points_rotation:
  239. iter_res.append(minimize(sim_opt_translate_scale_rotate, x0=[start_translate_x, start_translate_y, start_scale, start_rotate], method=solver, bounds=[utils.logistic(translate_bounds), utils.logistic(translate_bounds), scale_bounds, rotation_bounds]))
  240. elif (scale) and (not rotate):
  241. for start_translate_x in starting_points_translation:
  242. for start_translate_y in starting_points_translation:
  243. for start_scale in starting_points_scale:
  244. iter_res.append(minimize(sim_opt_translate_scale, x0=[start_translate_x, start_translate_y, start_scale], method=solver, bounds = [utils.logistic(translate_bounds), utils.logistic(translate_bounds), scale_bounds]))
  245. elif (not scale) and (rotate):
  246. for start_translate_x in starting_points_translation:
  247. for start_translate_y in starting_points_translation:
  248. for start_rotate in starting_points_rotation:
  249. iter_res.append(minimize(sim_opt_translate_rotate, x0=[start_translate_x, start_translate_y, start_rotate], method=solver, bounds = [utils.logistic(translate_bounds), utils.logistic(translate_bounds), rotation_bounds]))
  250. elif (not scale) and (not rotate):
  251. for start_translate_x in starting_points_translation:
  252. for start_translate_y in starting_points_translation:
  253. iter_res.append(minimize(sim_opt_translate, x0=[start_translate_x, start_translate_y], method=solver, bounds = [utils.logistic(translate_bounds), utils.logistic(translate_bounds)]))
  254. else:
  255. if (scale) and (rotate):
  256. for start_scale in starting_points_scale:
  257. for start_rotate in starting_points_rotation:
  258. iter_res.append(minimize(sim_opt_scale_rotate, x0=[start_scale, start_rotate], method=solver, bounds=[scale_bounds, rotation_bounds]))
  259. elif (scale) and (not rotate):
  260. for start_scale in starting_points_scale:
  261. iter_res.append(minimize(sim_opt_scale, x0=[start_scale], method=solver, bounds = [scale_bounds]))
  262. elif (not scale) and (rotate):
  263. for start_rotate in starting_points_rotation:
  264. iter_res.append(minimize(sim_opt_rotate, x0=[start_rotate], method=solver, bounds = [rotation_bounds]))
  265. fun_vals = np.array([i['fun'] for i in iter_res])
  266. # first, get indices of iterations which reached the best solution
  267. min_fun_idx = fun_vals == np.min(fun_vals)
  268. # use this to extract possible scale and rotation solutions
  269. if translate:
  270. if (scale) and (rotate):
  271. poss_translate_x_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  272. poss_translate_y_vals = np.array([i['x'][1] for i in iter_res])[min_fun_idx]
  273. poss_scale_vals = np.array([i['x'][2] for i in iter_res])[min_fun_idx]
  274. poss_rotate_vals = np.array([i['x'][3] for i in iter_res])[min_fun_idx]
  275. elif (scale) and (not rotate):
  276. poss_translate_x_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  277. poss_translate_y_vals = np.array([i['x'][1] for i in iter_res])[min_fun_idx]
  278. poss_scale_vals = np.array([i['x'][2] for i in iter_res])[min_fun_idx]
  279. poss_rotate_vals = np.zeros(poss_scale_vals.shape)
  280. elif (not scale) and (rotate):
  281. poss_translate_x_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  282. poss_translate_y_vals = np.array([i['x'][1] for i in iter_res])[min_fun_idx]
  283. poss_scale_vals = np.zeros(poss_translate_x_vals.shape)
  284. poss_rotate_vals = np.array([i['x'][2] for i in iter_res])[min_fun_idx]
  285. elif (not scale) and (not rotate):
  286. poss_translate_x_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  287. poss_translate_y_vals = np.array([i['x'][1] for i in iter_res])[min_fun_idx]
  288. poss_scale_vals = np.zeros(poss_translate_x_vals.shape)
  289. poss_rotate_vals = np.zeros(poss_translate_x_vals.shape)
  290. else:
  291. if (scale) and (rotate):
  292. poss_scale_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  293. poss_rotate_vals = np.array([i['x'][1] for i in iter_res])[min_fun_idx]
  294. poss_translate_x_vals = np.zeros(poss_scale_vals.shape)
  295. poss_translate_y_vals = np.zeros(poss_scale_vals.shape)
  296. elif (scale) and (not rotate):
  297. poss_scale_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  298. poss_rotate_vals = np.zeros(poss_scale_vals.shape)
  299. poss_translate_x_vals = np.zeros(poss_scale_vals.shape)
  300. poss_translate_y_vals = np.zeros(poss_scale_vals.shape)
  301. elif (not scale) and (rotate):
  302. poss_rotate_vals = np.array([i['x'][0] for i in iter_res])[min_fun_idx]
  303. poss_scale_vals = np.zeros(poss_rotate_vals.shape)
  304. poss_translate_x_vals = np.zeros(poss_rotate_vals.shape)
  305. poss_translate_y_vals = np.zeros(poss_rotate_vals.shape)
  306. # make sure the rotation values are all expressed within [-180, 180] instead of [0, inf]
  307. # (this is useful for minimising the angle when there are multiple identical solutions)
  308. poss_rotate_vals %= 360
  309. poss_rotate_vals_dir = np.matrix([poss_rotate_vals, poss_rotate_vals-360])
  310. poss_rotate_pw_idx = np.array(np.matrix.argmin(np.abs(poss_rotate_vals_dir), 0))[0]
  311. poss_rotate_vals = np.array([poss_rotate_vals_dir[poss_rotate_pw_idx[i], i] for i in range(poss_rotate_vals_dir.shape[1])])
  312. # next, get the solutions of these with the smallest absolute scale (i.e., closest to original log scale value of zero)
  313. min_abs_scale_idx = np.abs(poss_scale_vals) == np.min(np.abs(poss_scale_vals))
  314. poss_scale_vals = poss_scale_vals[min_abs_scale_idx]
  315. poss_rotate_vals = poss_rotate_vals[min_abs_scale_idx]
  316. # finally, get the solution, of those, with the smallest absolute rotation (i.e., closest to original rotation)
  317. min_abs_rotate_idx = np.abs(poss_rotate_vals) == np.min(np.abs(poss_rotate_vals))
  318. poss_scale_vals = poss_scale_vals[min_abs_rotate_idx]
  319. poss_rotate_vals = poss_rotate_vals[min_abs_rotate_idx]
  320. # replicate the optimal values to extract the translation values
  321. sim_res = text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=fliplr, flipud=flipud, translate_prop_val_x=utils.inv_logistic(poss_translate_x_vals[0]), translate_prop_val_y=utils.inv_logistic(poss_translate_y_vals[0]), scale_val=np.exp(poss_scale_vals[0]), rotate_val=poss_rotate_vals[0], size=size, plot=plot, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)
  322. res = {'a':a, 'b':b,
  323. 'font_a':font_a, 'font_b':font_b,
  324. # settings for optimisation
  325. 'translate': translate,
  326. 'scale': scale,
  327. 'rotate': rotate,
  328. 'fliplr': None,
  329. 'flipud': None,
  330. # results from optimisation
  331. measure: np.min(fun_vals),
  332. 'translate_prop_val_x': utils.inv_logistic(poss_translate_x_vals[0]),
  333. 'translate_prop_val_y': utils.inv_logistic(poss_translate_y_vals[0]),
  334. 'translate_val_x': sim_res['shift'][0],
  335. 'translate_val_y': sim_res['shift'][1],
  336. # the optimal scale and rotation values
  337. 'scale_val': np.exp(poss_scale_vals[0]),
  338. 'rotate_val': poss_rotate_vals[0],
  339. 'fliplr_val': fliplr,
  340. 'flipud_val': flipud}
  341. return(res)
  342. def opt_text_arr_sim(a='a', b=None, font_a='arial.ttf', font_b='arial.ttf', b_arr=None, measure='partial_wasserstein', translate=True, scale=True, rotate=True, fliplr=True, flipud=False, size=100, rotation_bounds=(-np.Infinity, np.Infinity), max_scale_change_factor=2, max_translation_factor=0.999, rotation_eval_n=9, scale_eval_n=9, translation_eval_n=9, solver='Nelder-Mead', search_method='grid', plot=False, partial_wasserstein_kwargs={'scale_mass':True, 'mass_normalise':True, 'distance_normalise': True, 'ins_weight':0.0, 'del_weight':0.0}, **kwargs):
  343. """Find parameters for geometric operations of translation, scale, rotation, and horizontal flipping that maximise overlap between two arrays of drawn text.
  344. Parameters
  345. ----------
  346. a : str
  347. b : str, optional
  348. Must be defined if `b_arr` is not. Ignored if `b_arr` is defined.
  349. font_a : str, optional
  350. `.ttf` font to use to build text array from `a`
  351. font_b : str, optional
  352. `.ttf` font to use to build text array from `b`
  353. b_arr : ndarray, optional
  354. Array that the array built from `a` will be compared to. This option is included as it is faster to pre-build the array that `a` will be compared to, if applying this function in a loop. If both `b` and `b_arr` are defined, `b` will be ignored.
  355. measure: str
  356. Argument not used. Included for consistency with `text_arr_sim.text_arr_sim()`.
  357. translate : bool
  358. Should the translation operation be optimised? If `False`, will always use default positions.
  359. scale: bool
  360. Should scale be optimised?
  361. rotate : bool
  362. Should rotation be optimised?
  363. fliplr : bool
  364. Should horizontal flipping (mirroring) be optimised?
  365. flipud : bool
  366. Should vertical flipping (mirroring) be optimised?
  367. size : int
  368. Size for the text (the scale parameter will be multiplied by this value).
  369. rotation_bounds : tuple
  370. Limits for optimising rotation in form `(lowerbound, upperbound)`. For example, `(-90, 90)` will limit rotation to 90 degrees in either direction.
  371. max_scale_change_factor : float
  372. Maximum value for the optimised scale parameter. `max_scale_change_factor=2` will permit 100% bigger or 50% smaller, i.e., twice as large or twice as small.
  373. max_translation_factor : float
  374. Maximum value for the optimised translation parameter. `max_translation_factor=0.5` will permit translation 50% of the maximum bounds in any direction, where bounds are determined by the maximum size of the arrays being compared. Must be <1, as it is optimised on a logistic scale.
  375. rotation_eval_n : int
  376. How many starting values should be tried for optimising rotation?
  377. scale_eval_n : int
  378. How many starting values should be tried for optimising scale?
  379. translation_eval_n : int
  380. How many starting values should be tried for optimising translation (x and y)?
  381. solver : str
  382. Which solver to use? Possible values are those available to `scipy.optimize.minimize()`.
  383. search_method : str
  384. Method for setting starting values. Options are:
  385. 'grid': set in equal steps from the lower to the upper bound
  386. 'random': set randomly between the lower and upper bound
  387. plot : bool
  388. Should the optimal solution be plotted?
  389. **kwargs
  390. Other arguments to pass to `text_arr_sim()`.
  391. Returns
  392. -------
  393. dict
  394. A dictionary containing the following values:
  395. 'translate': Whether translation was optimised
  396. 'scale': Whether scale was optimised
  397. 'rotate': Whether rotation was optimised
  398. 'flip': Whether flip was optimised
  399. 'intersection', 'union', 'overlap', 'jaccard', 'dice': Values from `arr_sim.arr_sim()`
  400. 'translate_val_x': Optimal shift value in x dimension
  401. 'translate_val_y': Optimal shift value in y dimension
  402. 'scale_val': Optimal scale coefficient
  403. 'rotate_val': Optimal rotation coefficient
  404. 'flip_val': Whether the optimal solution included flipping
  405. Examples
  406. --------
  407. >>> opt_text_arr_sim('e', 'o')
  408. """
  409. non_flipped = _opt_text_arr_sim_flip_manual(a=a, b=b, font_a=font_a, font_b=font_b, b_arr=b_arr, measure=measure, translate=translate, scale=scale, rotate=rotate, fliplr=False, flipud=False, size=size, rotation_bounds=rotation_bounds, max_scale_change_factor=max_scale_change_factor, max_translation_factor=max_translation_factor, rotation_eval_n=rotation_eval_n, scale_eval_n=scale_eval_n, translation_eval_n=translation_eval_n, solver=solver, search_method=search_method, plot=False, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)
  410. res = non_flipped
  411. res['fliplr'] = False
  412. res['flipud'] = False
  413. if fliplr:
  414. flipped_lr = _opt_text_arr_sim_flip_manual(a=a, b=b, font_a=font_a, font_b=font_b, b_arr=b_arr, measure=measure, translate=translate, scale=scale, rotate=rotate, fliplr=True, flipud=False, size=size, rotation_bounds=rotation_bounds,max_scale_change_factor=max_scale_change_factor, max_translation_factor=max_translation_factor, rotation_eval_n=rotation_eval_n, scale_eval_n=scale_eval_n, translation_eval_n=translation_eval_n, solver=solver, search_method=search_method, plot=False, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)
  415. if flipped_lr[measure] > res[measure] and np.abs(flipped_lr['rotate_val']) <= np.abs(res['rotate_val']):
  416. res = flipped_lr
  417. if flipud:
  418. flipped_ud = _opt_text_arr_sim_flip_manual(a=a, b=b, font_a=font_a, font_b=font_b, b_arr=b_arr, measure=measure, translate=translate, scale=scale, rotate=rotate, fliplr=False, flipud=True, size=size, rotation_bounds=rotation_bounds, max_scale_change_factor=max_scale_change_factor, max_translation_factor=max_translation_factor, rotation_eval_n=rotation_eval_n, scale_eval_n=scale_eval_n, translation_eval_n=translation_eval_n, solver=solver, search_method=search_method, plot=False, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)
  419. if flipped_ud[measure] > res[measure] and np.abs(flipped_ud['rotate_val']) <= np.abs(res['rotate_val']):
  420. res = flipped_ud
  421. if fliplr and flipud:
  422. flipped_lrud = _opt_text_arr_sim_flip_manual(a=a, b=b, font_a=font_a, font_b=font_b, b_arr=b_arr, measure=measure, translate=translate, scale=scale, rotate=rotate, fliplr=True, flipud=True, size=size, rotation_bounds=rotation_bounds, max_scale_change_factor=max_scale_change_factor, max_translation_factor=max_translation_factor, rotation_eval_n=rotation_eval_n, scale_eval_n=scale_eval_n, translation_eval_n=translation_eval_n, solver=solver, search_method=search_method, plot=False, partial_wasserstein_kwargs=partial_wasserstein_kwargs, **kwargs)
  423. if flipped_ud[measure] > res[measure] and np.abs(flipped_lrud['rotate_val']) <= np.abs(res['rotate_val']):
  424. res = flipped_lrud
  425. res['fliplr'] = fliplr
  426. res['flipud'] = flipud
  427. if plot:
  428. # replicate the optimal values to plot
  429. if np.all(b_arr==None):
  430. b_arr = draw.text_array(b, font=font_b, rotate=0, fliplr=False, flipud=False, size=size, **kwargs)
  431. partial_wasserstein_kwargs_pl = partial_wasserstein_kwargs.copy()
  432. partial_wasserstein_kwargs_pl['trans_manual'] = (res['translate_val_x'], res['translate_val_y'])
  433. _ = text_arr_sim(a, b_arr=b_arr, measure=measure, font_a=font_a, translate=None, fliplr=res['fliplr_val'], flipud=res['flipud_val'], scale_val=res['scale_val'], rotate_val=res['rotate_val'], size=size, plot=plot, partial_wasserstein_kwargs=partial_wasserstein_kwargs_pl, **kwargs)
  434. return(res)