99_estimate_spaces_illustrate.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # %% import modules ======================
  2. import os
  3. n_threads = str(14)
  4. os.environ["OMP_NUM_THREADS"] = n_threads
  5. os.environ["OPENBLAS_NUM_THREADS"] = n_threads
  6. os.environ["MKL_NUM_THREADS"] = n_threads
  7. os.environ["VECLIB_MAXIMUM_THREADS"] = n_threads
  8. os.environ["NUMEXPR_NUM_THREADS"] = n_threads
  9. import os.path as op
  10. import pandas as pd
  11. import numpy as np
  12. from tqdm import tqdm
  13. # from string import ascii_lowercase
  14. import sys
  15. import os.path as op
  16. sys.path.append(op.join(op.dirname(__file__), '..'))
  17. import scold
  18. from scold import draw, text_arr_sim, arr_sim, text_arr_sim_wasserstein, utils
  19. import ot
  20. import matplotlib.pyplot as plt
  21. from matplotlib import rc
  22. from mpl_toolkits.axes_grid1 import make_axes_locatable
  23. from tqdm import tqdm
  24. fig_font_size = 8
  25. fig_font_size_small = 8
  26. plt.rcParams.update({
  27. "text.usetex": False,
  28. "font.family": "Helvetica",
  29. 'font.size': fig_font_size
  30. })
  31. # rc('text.latex', preamble='\n'.join([
  32. # r'\usepackage{tgheros}', # helvetica font
  33. # r'\renewcommand\familydefault{\sfdefault} ',
  34. # r'\usepackage[T1]{fontenc}'
  35. # ]))
  36. #%matplotlib qt
  37. # %% setup
  38. font = 'Arial-Lgt.ttf'
  39. font_size_plot = 25 # base font size for plotting
  40. font_size = 75
  41. round_method = None
  42. char1 = 'a'
  43. char2 = 'c'
  44. char2_arr = draw.text_array(char2, font=font, size=font_size, method=round_method)
  45. char2_arr_pl = draw.text_array(char2, font=font, size=font_size_plot, method=round_method)
  46. n_scale = 250
  47. n_rotation = 250
  48. logscale = np.linspace(np.log(0.5), np.log(2), num=n_scale, endpoint=True)
  49. rotation = np.linspace(-180, 180, num=n_rotation, endpoint=True)
  50. # %%
  51. # get the optima
  52. # opt_w = text_arr_sim_wasserstein.opt_text_arr_sim(
  53. # char1, b_arr=char2_arr,
  54. # font_b=font, size=font_size, method=round_method,
  55. # measure='partial_wasserstein',
  56. # translate=True, translation_eval_n=5, max_translation_factor=0.99,
  57. # scale=True, scale_eval_n=5, max_scale_change_factor=2.0,
  58. # fliplr=False, flipud=False,
  59. # rotate=True, rotation_eval_n=5, rotation_bounds=(-np.inf, np.inf),
  60. # partial_wasserstein_kwargs={'scale_mass':True, 'scale_mass_method':'proportion', 'mass_normalise':False, 'distance_normalise':False, 'ins_weight':0.0, 'del_weight':0.0}
  61. # )
  62. # opt_j = text_arr_sim.opt_text_arr_sim(
  63. # char1, b_arr=char2_arr,
  64. # font_a=font, font_b=font, size=font_size, method=round_method,
  65. # measure='jaccard',
  66. # translate=True,
  67. # scale=True, scale_eval_n=9, max_scale_change_factor=2.0,
  68. # fliplr=False, flipud=False,
  69. # rotate=True, rotation_eval_n=9, rotation_bounds=(-np.inf, np.inf)
  70. # )
  71. # %% import the mapped spaces
  72. space_j = np.load(op.join('fig_code', '99_estimate_spaces_jaccard.npy'))
  73. space_w = np.load(op.join('fig_code', '99_estimate_spaces_wasserstein.npy'))
  74. # %%
  75. # plot
  76. fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))
  77. im_interpolation = 'none'
  78. # Jaccard Distance
  79. im = axs[0].imshow(space_w, interpolation=im_interpolation, aspect='auto')
  80. xticks = np.array([-180, -90, 0, 90, 180])
  81. xtick_locs = 0.5 * (1 + xticks/rotation.max()) * (n_rotation-1)
  82. axs[0].set_xticks(ticks=xtick_locs, labels=xticks)
  83. axs[0].set_xlabel('Rotation (°)')
  84. yticks = np.linspace(np.log(0.5), np.log(2), num=5, endpoint=True)
  85. ytick_locs = 0.5 * (1 + yticks/logscale.max()) * (n_scale-1)
  86. axs[0].set_yticks(ticks=ytick_locs, labels=np.round(yticks, 2))
  87. axs[0].invert_yaxis()
  88. axs[0].set_ylabel('Log Scale')
  89. ax_y2 = axs[0].secondary_yaxis('right')
  90. ax_y2.set_yticks(ytick_locs)
  91. ax_y2.set_yticklabels(np.round(np.exp(yticks), 2))
  92. ax_y2.set_ylabel('Scale', rotation=270, va='bottom')
  93. divider = make_axes_locatable(axs[0])
  94. cax = divider.append_axes('top', size='7.5%', pad=0.075)
  95. plt.colorbar(im, cax=cax, location='top', orientation='horizontal', label='Wasserstein Distance', ticks = [4, 6, 8, 10, 12, 14])
  96. # Wasserstein Distance
  97. im = axs[1].imshow(space_j, interpolation=im_interpolation, aspect='auto')
  98. axs[1].set_xticks(ticks=xtick_locs, labels=xticks)
  99. axs[1].set_xlabel('Rotation (°)')
  100. axs[1].set_yticks(ticks=ytick_locs, labels=np.round(yticks, 2))
  101. axs[1].invert_yaxis()
  102. axs[1].set_ylabel('Log Scale')
  103. ax_y2 = axs[1].secondary_yaxis('right')
  104. ax_y2.set_yticks(ytick_locs)
  105. ax_y2.set_yticklabels(np.round(np.exp(yticks), 2))
  106. ax_y2.set_ylabel('Scale', rotation=270, va='bottom')
  107. divider = make_axes_locatable(axs[1])
  108. cax = divider.append_axes('top', size='7.5%', pad=0.075)
  109. plt.colorbar(im, cax=cax, location='top', orientation='horizontal', label='Jaccard Distance', ticks=[0.5, 0.6, 0.7, 0.8])
  110. # example locations to plot in the space
  111. # listed as rotation, scale (translation is optimised for each)
  112. eg_locations = [
  113. # [-168.9920930404184, 0.8333333333333334],
  114. # [-171.32530120481928, 0.98617779803696],
  115. [-177.35775607466348, 1.0029500584943818],
  116. [-80, 0.75],
  117. [0, 1],
  118. [90, 1.41],
  119. [175, 1.6]
  120. ]
  121. for eg_loc in eg_locations:
  122. x_prop = (eg_loc[0] - rotation.min()) / (rotation.max() - rotation.min())
  123. x_pos = x_prop * (len(rotation)-1)
  124. y_prop = (np.log(eg_loc[1]) - logscale.min()) / (logscale.max() - logscale.min())
  125. y_pos = y_prop * (len(logscale)-1)
  126. axs[0].scatter(x_pos, y_pos, marker='+', color='r')
  127. axs[1].scatter(x_pos, y_pos, marker='+', color='r')
  128. # finalise
  129. fig.tight_layout()
  130. fig.savefig(op.join('fig', 'optimisation_illustration', 'spaces.svg'))
  131. # %%
  132. # plot examples
  133. eg_fig_size = (0.65, 0.65)
  134. for eg_nr, eg_loc in enumerate(tqdm(eg_locations, desc='plotting examples')):
  135. # estimate the translation parameters at this resolution
  136. jacc_trans_opt = text_arr_sim.text_arr_sim(a=char1, b_arr=char2_arr_pl, font_a=font, size=font_size_plot, method=round_method, measure='jaccard', translate=True, scale_val=eg_loc[1], fliplr=False, flipud=False, rotate_val=eg_loc[0])
  137. wass_trans_opt = text_arr_sim.text_arr_sim(a=char1, b=None, font_a=font, font_b=font, b_arr=char2_arr_pl, measure='partial_wasserstein', translate=True, fliplr=False, flipud=False, size=font_size_plot, scale_val=eg_loc[1], rotate_val=eg_loc[0], plot=False, partial_wasserstein_kwargs={'scale_mass':True, 'mass_normalise':False, 'scale_mass_method':'proportion', 'distance_normalise':False, 'translation':'opt', 'n_startvals':5, 'solver':'Nelder-Mead', 'search_method':'grid'})
  138. # plot Wasserstein positions
  139. char1_arr_i = draw.text_array(char1, font=font, size=font_size_plot*eg_loc[1], method=round_method, rotate=eg_loc[0])
  140. # shift order of x and y to match the function
  141. trans_manual = [wass_trans_opt['shift'][1], wass_trans_opt['shift'][0]]
  142. char1_pad, char2_pad = utils.pad_for_translation(char1_arr_i.T, char2_arr_pl.T)
  143. # note: rounded to nearest integer for image, but remember that translation is optimised continuously for Wasserstein distance
  144. char1_pad_trans = utils.pad_translate_mat(char1_pad, int(round(trans_manual[0])), int(round(trans_manual[1])))
  145. s = utils.crop_zeros(char1_pad_trans, char1_pad_trans+char2_pad)
  146. t = utils.crop_zeros(char2_pad, char1_pad_trans+char2_pad)
  147. w_manual = arr_sim.partial_wasserstein(s, t, scale_mass=True, scale_mass_method='proportion', mass_normalise=False, distance_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, return_res=True, trans_manual=(0,0))
  148. tp = w_manual['tp']
  149. # plot Wasserstein
  150. xs = np.transpose(np.array(np.where(s!=0)))
  151. xt = np.transpose(np.array(np.where(t!=0)))
  152. M = ot.dist(xs, xt, metric='Euclidean')
  153. pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
  154. pl_rgb[:, :, 0] = s/s.max()
  155. pl_rgb[:, :, 2] = t/t.max()
  156. fig, ax = plt.subplots(figsize=eg_fig_size)
  157. ax.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
  158. for i in range(xs.shape[0]):
  159. for j in range(xt.shape[0]):
  160. if M[i, j] > 0:
  161. ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], alpha=tp[i, j] / tp.max(), color='k', linewidth=0.33)
  162. ax.axis('off')
  163. fig.savefig(op.join('fig', 'optimisation_illustration', f'w_example_{eg_nr}.svg'))
  164. # plot Jaccard
  165. # note that here the x and y are never swapped
  166. trans_manual = [jacc_trans_opt['shift'][0], jacc_trans_opt['shift'][1]]
  167. char1_pad_trans = utils.pad_translate_mat(char1_pad, int(round(trans_manual[0])), int(round(trans_manual[1])))
  168. s = utils.crop_zeros(char1_pad_trans, char1_pad_trans+char2_pad)
  169. t = utils.crop_zeros(char2_pad, char1_pad_trans+char2_pad)
  170. arr_sim.arr_sim(s, t)
  171. pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
  172. pl_rgb[:, :, 0] = s/s.max()
  173. pl_rgb[:, :, 2] = t/t.max()
  174. fig, ax = plt.subplots(figsize=eg_fig_size)
  175. ax.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
  176. ax.axis('off')
  177. fig.savefig(op.join('fig', 'optimisation_illustration', f'j_example_{eg_nr}.svg'))
  178. # %%
  179. # plot the optimised matrices
  180. arr_fig_size = (2.75, 0.75)
  181. def rankify_vec(vec):
  182. # rank values in a vector
  183. # (ties have their ranks averaged, as in R)
  184. rank_vec = vec.argsort().argsort() + 1
  185. rank_vec = rank_vec.astype(float)
  186. for uv in np.unique(vec):
  187. uv_idx = vec==uv
  188. if np.sum(uv_idx)>1:
  189. rank_vec[uv_idx] = np.mean(rank_vec[uv_idx])
  190. return rank_vec
  191. def rankify_mat(mat):
  192. lwr = mat[np.tril_indices(n=mat.shape[0], k=-1)]
  193. lwr_rank = rankify_vec(lwr)
  194. mat_rank = np.zeros(mat.shape)
  195. mat_rank[np.tril_indices(n=mat.shape[0], k=-1)] = lwr_rank
  196. mat_rank[np.triu_indices(n=mat.shape[0], k=0)] = mat_rank.T[np.triu_indices(n=mat.shape[0], k=0)]
  197. mat_rank[np.diag_indices(n=mat_rank.shape[0])] = np.nan
  198. return mat_rank
  199. geom_combs = [
  200. # R, S, T
  201. [0, 0, 0],
  202. [1, 0, 0],
  203. [0, 1, 0],
  204. [0, 0, 1],
  205. [1, 1, 0],
  206. [1, 0, 1],
  207. [0, 1, 1],
  208. [1, 1, 1]
  209. ]
  210. for measure in ['ot', 'jacc']:
  211. max_dist = 0 # used to set the colour bounds
  212. geom_loc = op.join('stim_sim', f'{measure}_geom')
  213. n_types = 9
  214. fig, axs = plt.subplots(2, n_types, figsize=arr_fig_size)
  215. x_sp = 0
  216. raw_ims = []
  217. for comb_i in geom_combs:
  218. r, s, t = comb_i
  219. t_lab = 'T' if t==1 else ''
  220. s_lab = 'S' if s==1 else ''
  221. r_lab = 'R' if r==1 else ''
  222. arr = np.load(op.join(geom_loc, f'{measure}_T{t}_S{s}_R{r}_Flr0_Fud0.npy'))
  223. if arr[~np.isnan(arr)].max() > max_dist:
  224. max_dist = arr[~np.isnan(arr)].max()
  225. pl_title = '-' if (r==0) and (s==0) and (t==0) else f'{r_lab}{s_lab}{t_lab}'
  226. # raw
  227. raw_ims.append(axs[0, x_sp].imshow(arr, interpolation='none'))
  228. axs[0, x_sp].axis('off')
  229. axs[0, x_sp].set_title(pl_title, size=fig_font_size)
  230. # rank
  231. axs[1, x_sp].imshow(rankify_mat(arr), interpolation='none')
  232. axs[1, x_sp].axis('off')
  233. x_sp += 1
  234. if measure=='jacc':
  235. max_dist = 1
  236. for raw_im_i in raw_ims:
  237. raw_im_i.set_clim([0, max_dist])
  238. if measure == 'ot':
  239. arr = np.load(op.join(geom_loc, f'{measure}_pgw.npy'))
  240. # raw
  241. axs[0, x_sp].imshow(arr, interpolation='none')
  242. axs[0, x_sp].axis('off')
  243. axs[0, x_sp].set_title('G-W', size=fig_font_size)
  244. # rank
  245. axs[1, x_sp].imshow(rankify_mat(arr), interpolation='none')
  246. axs[1, x_sp].axis('off')
  247. else:
  248. axs[0, x_sp].axis('off')
  249. axs[1, x_sp].axis('off')
  250. fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.1, hspace=0.1)
  251. fig.savefig(op.join('fig', 'optimisation_illustration', f'mats_{measure}.svg'))