# Figure script: plots the mapped Wasserstein and Jaccard spaces over rotation and scale, example overlays of 'a' on 'c', and the distance matrices for each combination of geometric transformations.
- # %% import modules ======================
- import os
# Cap the number of threads the numerical backends may use. This is done
# ahead of the numpy import below so the limits are in place before any of
# those libraries are loaded.
n_threads = "14"
os.environ.update(dict.fromkeys(
    (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ),
    n_threads,
))
- import os.path as op
- import pandas as pd
- import numpy as np
- from tqdm import tqdm
- # from string import ascii_lowercase
- import sys
- import os.path as op
- sys.path.append(op.join(op.dirname(__file__), '..'))
- import scold
- from scold import draw, text_arr_sim, arr_sim, text_arr_sim_wasserstein, utils
- import ot
- import matplotlib.pyplot as plt
- from matplotlib import rc
- from mpl_toolkits.axes_grid1 import make_axes_locatable
- from tqdm import tqdm
# Typography shared by every figure in this script: plain (non-LaTeX) text
# in Helvetica at the base font size.
fig_font_size = fig_font_size_small = 8
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.size"] = fig_font_size
- # rc('text.latex', preamble='\n'.join([
- # r'\usepackage{tgheros}', # helvetica font
- # r'\renewcommand\familydefault{\sfdefault} ',
- # r'\usepackage[T1]{fontenc}'
- # ]))
- #%matplotlib qt
- # %% setup
# Rendering settings shared by all glyphs.
font = 'Arial-Lgt.ttf'
round_method = None
font_size, font_size_plot = 75, 25  # full size / base font size for plotting

# The pair of characters being compared; the second one is rendered once at
# each of the two font sizes (full size first, then plotting size).
char1, char2 = 'a', 'c'
char2_arr, char2_arr_pl = (
    draw.text_array(char2, font=font, size=glyph_size, method=round_method)
    for glyph_size in (font_size, font_size_plot)
)

# Sampling grids of the mapped spaces: log scale factors from log(0.5) to
# log(2) and rotations from -180 to 180, both inclusive of their end points.
n_scale = n_rotation = 250
logscale = np.linspace(np.log(0.5), np.log(2), num=n_scale)
rotation = np.linspace(-180, 180, num=n_rotation)
- # %%
- # get the optima
- # opt_w = text_arr_sim_wasserstein.opt_text_arr_sim(
- # char1, b_arr=char2_arr,
- # font_b=font, size=font_size, method=round_method,
- # measure='partial_wasserstein',
- # translate=True, translation_eval_n=5, max_translation_factor=0.99,
- # scale=True, scale_eval_n=5, max_scale_change_factor=2.0,
- # fliplr=False, flipud=False,
- # rotate=True, rotation_eval_n=5, rotation_bounds=(-np.inf, np.inf),
- # partial_wasserstein_kwargs={'scale_mass':True, 'scale_mass_method':'proportion', 'mass_normalise':False, 'distance_normalise':False, 'ins_weight':0.0, 'del_weight':0.0}
- # )
- # opt_j = text_arr_sim.opt_text_arr_sim(
- # char1, b_arr=char2_arr,
- # font_a=font, font_b=font, size=font_size, method=round_method,
- # measure='jaccard',
- # translate=True,
- # scale=True, scale_eval_n=9, max_scale_change_factor=2.0,
- # fliplr=False, flipud=False,
- # rotate=True, rotation_eval_n=9, rotation_bounds=(-np.inf, np.inf)
- # )
- # %% import the mapped spaces
space_j = np.load(op.join('fig_code', '99_estimate_spaces_jaccard.npy'))
space_w = np.load(op.join('fig_code', '99_estimate_spaces_wasserstein.npy'))

# %%
# plot


def _grid_position(values, grid):
    """Convert parameter values to (fractional) pixel indices along one axis.

    `grid` holds the parameter value sampled at every row/column of a mapped
    space and is assumed to be evenly spaced (it is built with np.linspace).
    The smallest grid value maps to index 0 and the largest to
    len(grid) - 1, so the mapping stays correct even if the grid bounds are
    not symmetric about zero.
    """
    lowest = grid.min()
    return (np.asarray(values) - lowest) / (grid.max() - lowest) * (len(grid) - 1)


fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))
im_interpolation = 'none'

# tick values (in parameter units) and where they fall on the image axes
xticks = np.array([-180, -90, 0, 90, 180])
xtick_locs = _grid_position(xticks, rotation)
yticks = np.linspace(np.log(0.5), np.log(2), num=5, endpoint=True)
ytick_locs = _grid_position(yticks, logscale)

# one panel per measure: (axes, mapped space, colour bar label, colour bar ticks)
panels = [
    # Wasserstein Distance
    (axs[0], space_w, 'Wasserstein Distance', [4, 6, 8, 10, 12, 14]),
    # Jaccard Distance
    (axs[1], space_j, 'Jaccard Distance', [0.5, 0.6, 0.7, 0.8]),
]
for ax, space, cbar_label, cbar_ticks in panels:
    im = ax.imshow(space, interpolation=im_interpolation, aspect='auto')
    ax.set_xticks(ticks=xtick_locs, labels=xticks)
    ax.set_xlabel('Rotation (°)')
    ax.set_yticks(ticks=ytick_locs, labels=np.round(yticks, 2))
    ax.invert_yaxis()
    ax.set_ylabel('Log Scale')
    # same tick positions on the right-hand side, labelled with the
    # untransformed scale factor
    ax_y2 = ax.secondary_yaxis('right')
    ax_y2.set_yticks(ytick_locs)
    ax_y2.set_yticklabels(np.round(np.exp(yticks), 2))
    ax_y2.set_ylabel('Scale', rotation=270, va='bottom')
    # horizontal colour bar in a thin strip above the panel
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('top', size='7.5%', pad=0.075)
    plt.colorbar(im, cax=cax, location='top', orientation='horizontal', label=cbar_label, ticks=cbar_ticks)

# example locations to plot in the space
# listed as rotation, scale (translation is optimised for each)
eg_locations = [
    # [-168.9920930404184, 0.8333333333333334],
    # [-171.32530120481928, 0.98617779803696],
    [-177.35775607466348, 1.0029500584943818],
    [-80, 0.75],
    [0, 1],
    [90, 1.41],
    [175, 1.6]
]

# mark every example location on both panels
for eg_rotation, eg_scale in eg_locations:
    x_pos = _grid_position(eg_rotation, rotation)
    y_pos = _grid_position(np.log(eg_scale), logscale)
    for ax in axs:
        ax.scatter(x_pos, y_pos, marker='+', color='r')

# finalise
fig.tight_layout()
fig.savefig(op.join('fig', 'optimisation_illustration', 'spaces.svg'))
- # %%
- # plot examples
eg_fig_size = (0.65, 0.65)


def _place_glyphs(char1_pad, char2_pad, shift):
    """Translate the padded first glyph by `shift` and crop both glyphs alike.

    The shift is rounded to whole pixels because it is applied to an image.
    Both arrays are cropped with utils.crop_zeros against the sum of the
    shifted first glyph and the second glyph.
    """
    char1_pad_trans = utils.pad_translate_mat(char1_pad, int(round(shift[0])), int(round(shift[1])))
    s = utils.crop_zeros(char1_pad_trans, char1_pad_trans+char2_pad)
    t = utils.crop_zeros(char2_pad, char1_pad_trans+char2_pad)
    return s, t


def _overlay_image(s, t):
    """Build the RGB image that overlays glyph `s` and glyph `t`.

    Each glyph is scaled to a maximum of 1 and written to its own colour
    channel (s to channel 0, t to channel 2); the result is inverted,
    transposed to swap the first two axes, and passed through
    utils.rotate_rgb_hue.
    """
    pl_rgb = np.zeros((s.shape[0], s.shape[1], 3))
    pl_rgb[:, :, 0] = s/s.max()
    pl_rgb[:, :, 2] = t/t.max()
    return utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5)


for eg_nr, eg_loc in enumerate(tqdm(eg_locations, desc='plotting examples')):
    eg_rotation, eg_scale = eg_loc

    # estimate the translation parameters at this resolution
    jacc_trans_opt = text_arr_sim.text_arr_sim(
        a=char1, b_arr=char2_arr_pl, font_a=font, size=font_size_plot,
        method=round_method, measure='jaccard', translate=True,
        scale_val=eg_scale, fliplr=False, flipud=False, rotate_val=eg_rotation,
    )
    wass_trans_opt = text_arr_sim.text_arr_sim(
        a=char1, b=None, font_a=font, font_b=font, b_arr=char2_arr_pl,
        measure='partial_wasserstein', translate=True, fliplr=False, flipud=False,
        size=font_size_plot, scale_val=eg_scale, rotate_val=eg_rotation, plot=False,
        partial_wasserstein_kwargs={
            'scale_mass': True, 'mass_normalise': False,
            'scale_mass_method': 'proportion', 'distance_normalise': False,
            'translation': 'opt', 'n_startvals': 5,
            'solver': 'Nelder-Mead', 'search_method': 'grid',
        },
    )

    # render the first character at this example's scale and rotation
    char1_arr_i = draw.text_array(char1, font=font, size=font_size_plot*eg_scale, method=round_method, rotate=eg_rotation)
    char1_pad, char2_pad = utils.pad_for_translation(char1_arr_i.T, char2_arr_pl.T)

    # plot Wasserstein positions
    # shift order of x and y to match the function
    # note: rounded to nearest integer for image, but remember that translation is optimised continuously for Wasserstein distance
    s, t = _place_glyphs(char1_pad, char2_pad, (wass_trans_opt['shift'][1], wass_trans_opt['shift'][0]))
    w_manual = arr_sim.partial_wasserstein(s, t, scale_mass=True, scale_mass_method='proportion', mass_normalise=False, distance_normalise=False, del_weight=0.0, ins_weight=0.0, trans_weight=1.0, return_res=True, trans_manual=(0,0))
    tp = w_manual['tp']

    # plot Wasserstein
    xs = np.transpose(np.array(np.where(s!=0)))
    xt = np.transpose(np.array(np.where(t!=0)))
    M = ot.dist(xs, xt, metric='Euclidean')
    fig, ax = plt.subplots(figsize=eg_fig_size)
    ax.imshow(_overlay_image(s, t), interpolation='none')
    # Draw one line per source/target pixel pair, with opacity proportional to
    # the transported mass. Pairs carrying no mass would be fully transparent,
    # so they are skipped rather than written to the figure; pairs at the same
    # position (zero distance) are skipped as before.
    tp_max = tp.max()
    carries_mass = np.asarray(tp)[:xs.shape[0], :xt.shape[0]] > 0
    for i, j in np.argwhere(carries_mass & (M > 0)):
        ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], alpha=tp[i, j] / tp_max, color='k', linewidth=0.33)
    ax.axis('off')
    fig.savefig(op.join('fig', 'optimisation_illustration', f'w_example_{eg_nr}.svg'))

    # plot Jaccard
    # note that here the x and y are never swapped
    s, t = _place_glyphs(char1_pad, char2_pad, (jacc_trans_opt['shift'][0], jacc_trans_opt['shift'][1]))
    # NOTE(review): the return value is discarded — presumably left over from
    # interactive inspection; kept in case the call is relied on elsewhere.
    arr_sim.arr_sim(s, t)
    fig, ax = plt.subplots(figsize=eg_fig_size)
    ax.imshow(_overlay_image(s, t), interpolation='none')
    ax.axis('off')
    fig.savefig(op.join('fig', 'optimisation_illustration', f'j_example_{eg_nr}.svg'))
- # %%
- # plot the optimised matrices
# figure size passed to plt.subplots for each measure's grid of matrices
arr_fig_size = (2.75, 0.75)
def rankify_vec(vec):
    """Rank the values of a 1-D array, starting from 1 for the smallest.

    Values that are tied all receive the mean of the ranks they occupy
    (ties have their ranks averaged, as in R). The result is a float array
    of the same length as `vec`.
    """
    # argsort of the argsort gives each element's 0-based position in sorted order
    ordinal = vec.argsort().argsort()
    ranks = (ordinal + 1).astype(float)
    for value in np.unique(vec):
        tied = vec == value
        if tied.sum() <= 1:
            continue
        # replace the arbitrary order among equal values by their shared mean rank
        ranks[tied] = ranks[tied].mean()
    return ranks
def rankify_mat(mat):
    """Return a rank-transformed copy of a square matrix.

    Only the entries strictly below the diagonal are ranked (with
    rankify_vec); the ranks are then mirrored into the upper triangle and
    the diagonal is set to NaN.
    """
    size = mat.shape[0]
    below = np.tril_indices(n=size, k=-1)
    above_and_diag = np.triu_indices(n=size, k=0)

    ranked = np.zeros(mat.shape)
    ranked[below] = rankify_vec(mat[below])
    # copy the lower triangle across the diagonal
    ranked[above_and_diag] = ranked.T[above_and_diag]
    ranked[np.diag_indices(n=size)] = np.nan
    return ranked
# Every combination of optimised transformations, one per column of the figure.
geom_combs = [
    # R, S, T
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 1, 0],
    [1, 0, 1],
    [0, 1, 1],
    [1, 1, 1]
]

# One figure per measure: raw distance matrices on the top row and their
# rank-transformed versions on the bottom row.
for measure in ['ot', 'jacc']:
    geom_loc = op.join('stim_sim', f'{measure}_geom')
    n_types = 9
    fig, axs = plt.subplots(2, n_types, figsize=arr_fig_size)
    raw_ims = []
    max_dist = 0  # largest non-NaN distance seen so far; used to set the colour bounds

    for x_sp, (r, s, t) in enumerate(geom_combs):
        arr = np.load(op.join(geom_loc, f'{measure}_T{t}_S{s}_R{r}_Flr0_Fud0.npy'))
        max_dist = max(max_dist, arr[~np.isnan(arr)].max())

        # title lists the active transformations, or '-' when there are none
        pl_title = ''.join(
            lab for lab, flag in (('R', r), ('S', s), ('T', t)) if flag == 1
        ) or '-'

        raw_ax, rank_ax = axs[:, x_sp]
        # raw
        raw_ims.append(raw_ax.imshow(arr, interpolation='none'))
        raw_ax.axis('off')
        raw_ax.set_title(pl_title, size=fig_font_size)
        # rank
        rank_ax.imshow(rankify_mat(arr), interpolation='none')
        rank_ax.axis('off')

    # the remaining (last) column comes after all the combinations
    x_sp = len(geom_combs)

    if measure == 'jacc':
        max_dist = 1

    # give all raw matrices of this measure the same colour limits
    for raw_im_i in raw_ims:
        raw_im_i.set_clim([0, max_dist])

    raw_ax, rank_ax = axs[:, x_sp]
    if measure == 'ot':
        arr = np.load(op.join(geom_loc, f'{measure}_pgw.npy'))
        # raw (keeps its own colour limits)
        raw_ax.imshow(arr, interpolation='none')
        raw_ax.axis('off')
        raw_ax.set_title('G-W', size=fig_font_size)
        # rank
        rank_ax.imshow(rankify_mat(arr), interpolation='none')
        rank_ax.axis('off')
    else:
        # no matrix for this column: leave both axes blank
        raw_ax.axis('off')
        rank_ax.axis('off')

    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.1, hspace=0.1)
    fig.savefig(op.join('fig', 'optimisation_illustration', f'mats_{measure}.svg'))
|