# 99_jaccard_distance_figure.py
  1. import sys
  2. import os.path as op
  3. sys.path.append(op.join(op.dirname(__file__), '..'))
  4. from scold import arr_sim, draw, utils
  5. import matplotlib.pyplot as plt
  6. from matplotlib.gridspec import GridSpec
  7. from mpl_toolkits.axes_grid1.inset_locator import inset_axes
  8. from matplotlib import rc
  9. import numpy as np
  10. import ot
  11. import os.path as op
  12. fig_font_size = 8
  13. plt.rcParams.update({
  14. "text.usetex": True,
  15. "font.family": "Helvetica",
  16. 'font.size': fig_font_size
  17. })
  18. rc('text.latex', preamble='\n'.join([
  19. r'\usepackage{tgheros}', # helvetica font
  20. r'\renewcommand\familydefault{\sfdefault} ',
  21. r'\usepackage[T1]{fontenc}'
  22. ]))
  23. # %%
  24. # letters
  25. font = 'arial.ttf'
  26. font_size = 20
  27. scale_mass = True
  28. k = draw.text_array('k', font=font, size=font_size, method=None)
  29. h = draw.text_array('h', font=font, size=font_size, method=None)
  30. # pad and assign to source and target arrays
  31. k_pad, h_pad = utils.pad_for_translation(h.T, k.T, pad=False, constant_values=0.0)
  32. # %%
  33. # plot
  34. fig = plt.figure(layout = 'constrained', figsize = (1.5, 0.8))
  35. gs = GridSpec(1, 3, figure=fig)
  36. axa = fig.add_subplot(gs[0])
  37. pl_rgb = np.zeros((k_pad.shape[0], k_pad.shape[1], 3))
  38. pl_rgb[:, :, 0] = k_pad/k_pad.max()
  39. pl_rgb[:, :, 2] = h_pad/h_pad.max()
  40. axa.imshow(utils.rotate_rgb_hue(1-pl_rgb.transpose((1, 0, 2)), 0.5), interpolation='none')
  41. axa.set_xticks([0,4,8])
  42. axa.set_yticks([0,4,8,12])
  43. axa.set_xticklabels(axa.get_xticks())
  44. axa.set_yticklabels(axa.get_yticks())
  45. axa.spines[['right', 'top']].set_visible(False)
  46. axb = fig.add_subplot(gs[1:])
  47. axb.axis('off')
  48. axb.text(0.1, 1, r'\[\frac{\Sigma\ \ }{\Sigma\ \ }\]', horizontalalignment='left', verticalalignment='top', fontsize=24)
  49. ia_upper = inset_axes(axb, height=0.25, width=1, loc=2, borderpad=0, bbox_to_anchor=(-0.14, 0.027, 1, 1), bbox_transform=axb.transAxes)
  50. ia_upper.axis('off')
  51. intersection = np.transpose(np.minimum(k_pad, h_pad) / k_pad.max())
  52. intersection_pad = np.pad(intersection, pad_width=((1,1), (1,1)), mode='constant', constant_values=0.0)
  53. union = np.transpose(np.maximum(k_pad, h_pad))
  54. union_pad = np.pad(union, pad_width=((1,1), (1,1)), mode='constant', constant_values=0.0)
  55. ia_int_im = ia_upper.imshow(intersection_pad, cmap='gist_gray', interpolation='none', vmin=0.0, vmax=1.0)
  56. # cb = plt.colorbar(ia_int_im)
  57. ia_lower = inset_axes(axb, height=0.25, width=1, loc=3, borderpad=0, bbox_to_anchor=(-0.14, -0.39, 1, 1), bbox_transform=axb.transAxes)
  58. ia_lower.axis('off')
  59. ia_union_im = ia_lower.imshow(union_pad, cmap='gist_gray', interpolation='none', vmin=0.0, vmax=1.0)
  60. fig.savefig(op.join('..', 'fig', 'intro_jaccard_examples.pdf'))
  61. fig.savefig(op.join('..', 'fig', 'intro_jaccard_examples.png'))
  62. fig.savefig(op.join('..', 'fig', 'intro_jaccard_examples.svg'))
  63. print(f'Jaccard = {np.sum(intersection) / np.sum(union)}')