"""Plot the preregistered letter-stimulus RDMs and how the two distance measures relate.

Loads the Jaccard (``jacc.npy``) and Wasserstein (``ot.npy``) dissimilarity
matrices from ``../stim_sim/preregistered`` and draws a two-panel figure:

* **a** - the raw matrices, plus a scatter of their lower-triangle entries;
* **b** - the same entries converted to ranks, plus a rank-vs-rank scatter.

A few example letter pairs are highlighted in red in both scatters. The figure
is written to ``../fig/intro_stim_rdms.{pdf,png,svg}``. All paths are relative
to the current working directory. Rendering needs a LaTeX installation with the
``tgheros`` package, because ``text.usetex`` is enabled below.
"""
import os
import os.path as op
from string import ascii_lowercase

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib.gridspec import GridSpec

# Letters in the order used to index the RDM rows/columns (see chars.index below).
chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']

fig_font_size = 8
fig_font_size_small = 8

plt.rcParams.update({
    "text.usetex": True,
    "font.family": "Helvetica",
    'font.size': fig_font_size
})
rc('text.latex', preamble='\n'.join([
    r'\usepackage{tgheros}',  # helvetica font (TeX Gyre Heros)
    r'\renewcommand\familydefault{\sfdefault} ',
    r'\usepackage[T1]{fontenc}'
]))


def _rank_rdm(rdm):
    """Rank-transform the strictly-lower triangle of a square dissimilarity matrix.

    Parameters
    ----------
    rdm : np.ndarray
        Square matrix; only the entries below the diagonal are used.

    Returns
    -------
    lower : np.ndarray
        The lower-triangle entries, in ``np.tril_indices(k=-1)`` order.
    lower_rank : np.ndarray
        1-based ranks of ``lower`` (1 = smallest value).
    rank_matrix : np.ndarray
        Symmetric float matrix holding ``lower_rank`` in both triangles and
        NaN on the diagonal.

    Raises
    ------
    ValueError
        If ``lower`` contains tied values. The double-argsort ranking would
        order ties arbitrarily instead of giving them a shared rank.
    """
    n = rdm.shape[0]
    lower_idx = np.tril_indices(n=n, k=-1)
    lower = rdm[lower_idx]
    if len(np.unique(lower)) != len(lower):
        raise ValueError('RDM lower triangle contains tied values; '
                         'argsort-based ranking requires all values to be unique.')
    # argsort of argsort yields each element's position in sorted order.
    lower_rank = lower.argsort().argsort() + 1
    rank_matrix = np.zeros(rdm.shape)
    rank_matrix[lower_idx] = lower_rank
    # The upper triangle (incl. diagonal) is still zero, so adding the
    # transpose mirrors the lower triangle into it.
    rank_matrix = rank_matrix + rank_matrix.T
    rank_matrix[np.diag_indices(n)] = np.nan
    return lower, lower_rank, rank_matrix


def _style_letter_axes(ax, labels):
    """Put first-letter / '...' / last-letter ticks on both axes of an RDM image.

    The middle tick mark is hidden on both axes. The y-axis '...' is rotated
    by 90 degrees so that it runs along the axis.
    """
    last = len(labels) - 1
    ticks = [0, last // 2, last]
    tick_labels = [labels[0], '...', labels[-1]]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)
    ax.xaxis.get_major_ticks()[1].tick1line.set_visible(False)
    y_mid = ax.yaxis.get_major_ticks()[1]
    y_mid.tick1line.set_visible(False)
    y_mid.label1.set_rotation(90)


def _plot_letter_rdm(ax, rdm, title, labels):
    """Show ``rdm`` as an image on ``ax`` with a title and letter tick labels."""
    ax.imshow(rdm, interpolation='none')
    ax.set_title(title, size=fig_font_size)
    _style_letter_axes(ax, labels)


def _plain_tick_label(value, _pos):
    """Format a tick value as plain text (e.g. ``0.2``, ``20``).

    With ``text.usetex`` the default ScalarFormatter wraps numbers in TeX math
    (``$...$``). A plain string is typeset in the text font instead.
    """
    return f'{value:g}'


def _annotate_example_pairs(ax, x_rdm, y_rdm, pairs, text_locs, labels,
                            point_size, line_width):
    """Highlight letter pairs in red on a scatter of ``x_rdm`` vs ``y_rdm`` entries.

    Each pair's point is redrawn in red and joined by a red line to its text
    label. The label is anchored at ``text_locs[i]`` in data coordinates, or
    at the point itself when that entry is ``None``. The label is shifted
    downward when the anchor lies below the point, and upward otherwise.

    Raises
    ------
    ValueError
        If ``pairs`` and ``text_locs`` differ in length.
    """
    if len(text_locs) != len(pairs):
        raise ValueError(f'got {len(pairs)} example pairs but '
                         f'{len(text_locs)} text locations')
    for (a, b), loc in zip(pairs, text_locs):
        a_i, b_i = labels.index(a), labels.index(b)
        x, y = x_rdm[a_i, b_i], y_rdm[a_i, b_i]
        ax.scatter(x, y, color='r', s=point_size)
        if loc is None:
            loc = (x, y)
        ax.plot((x, loc[0]), (y, loc[1]), 'r-', lw=line_width)
        offset = (-0.7, -0.75) if loc[1] < y else (-0.7, 0.5)
        ax.annotate(f'{a}-{b}', loc, xytext=offset, textcoords='offset fontsize',
                    color='r', size=fig_font_size_small)


data_dir = op.join('..', 'stim_sim', 'preregistered')
m_j = np.load(op.join(data_dir, 'jacc.npy'))
m_ot = np.load(op.join(data_dir, 'ot.npy'))

# Blank the diagonal so it does not dominate the image colour scale.
m_j[np.diag_indices(m_j.shape[0])] = np.nan
m_ot[np.diag_indices(m_ot.shape[0])] = np.nan

j_lwr, j_lwr_rank, j_rank = _rank_rdm(m_j)
ot_lwr, ot_lwr_rank, ot_rank = _rank_rdm(m_ot)

# %%
# plot
example_pairs = [['b', 'p'], ['k', 't'], ['ß', 'o'], ['j', 'w'], ['f', 'b'], ['h', 'k']]
# Label anchor positions (data coordinates) for each example pair, per scatter.
example_text_locs_raw = [(0.125, 22), (0.35, 38), (0.3, 52), (0.65, 70), (0.525, 50), (0.22, 32)]
example_text_locs_rank = [(15, 250), (80, 310), (360, 50), (420, 100), (100, 400), (230, 12)]

fig = plt.figure(layout='constrained', figsize=(4.25, 3))
gs = GridSpec(3, 4, figure=fig)

# Panel a: raw distances.
axr0 = fig.add_subplot(gs[0, 0])
_plot_letter_rdm(axr0, m_j, 'Jaccard', chars)
axr0.text(-0.725, 1.2, r'\textbf{a}', transform=axr0.transAxes, size=fig_font_size)

axr1 = fig.add_subplot(gs[0, 1])
_plot_letter_rdm(axr1, m_ot, 'Wasserstein', chars)

axr2 = fig.add_subplot(gs[1:, :2])
axr2.scatter(j_lwr, ot_lwr, color='k', s=3)
axr2.set_xlabel('Jaccard Distance')
axr2.set_ylabel('Wasserstein Distance')
axr2.spines[['right', 'top']].set_visible(False)
axr2.set_xticks([0.2, 0.4, 0.6, 0.8])
# Use formatters rather than set_*ticklabels. The y ticks are auto-located
# and can move when the example-pair lines below extend the axis limits.
# Fixed label lists would then be attached to the wrong ticks.
axr2.xaxis.set_major_formatter(_plain_tick_label)
axr2.yaxis.set_major_formatter(_plain_tick_label)
_annotate_example_pairs(axr2, m_j, m_ot, example_pairs, example_text_locs_raw,
                        chars, point_size=3.1, line_width=0.5)

# Panel b: ranked distances.
ax0 = fig.add_subplot(gs[0, 2])
_plot_letter_rdm(ax0, j_rank, 'Jaccard', chars)
ax0.text(-0.85, 1.2, r'\textbf{b}', transform=ax0.transAxes, size=fig_font_size)

ax1 = fig.add_subplot(gs[0, 3])
_plot_letter_rdm(ax1, ot_rank, 'Wasserstein', chars)

# Ticks every 100 ranks, up to (but excluding) the number of letter pairs.
ax2_ticks = np.arange(0, len(j_lwr_rank), 100)
ax2 = fig.add_subplot(gs[1:, 2:])
ax2.scatter(j_lwr_rank, ot_lwr_rank, color='k', s=3)
ax2.set_xlabel('Rank Jaccard Distance')
ax2.set_ylabel('\nRank Wasserstein Distance')
ax2.spines[['right', 'top']].set_visible(False)
ax2.set_xticks(ax2_ticks)
ax2.set_yticks(ax2_ticks)
ax2.set_xticklabels(ax2_ticks)
ax2.set_yticklabels(ax2_ticks)
_annotate_example_pairs(ax2, j_rank, ot_rank, example_pairs, example_text_locs_rank,
                        chars, point_size=3, line_width=1)

fig_dir = op.join('..', 'fig')
os.makedirs(fig_dir, exist_ok=True)
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(op.join(fig_dir, f'intro_stim_rdms.{ext}'))