123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- import matplotlib.pyplot as plt
- from matplotlib import rc
- from matplotlib.gridspec import GridSpec
- import numpy as np
- import os.path as op
- from string import ascii_lowercase
# Character set used to index the distance matrices: the 26 lowercase ASCII
# letters followed by the German umlauts and eszett (30 entries in total).
chars = list(ascii_lowercase) + ['ä', 'ö', 'ü', 'ß']

# Font sizes (in points) shared by every text element of the figure.
fig_font_size = 8
fig_font_size_small = 8

# Typeset all figure text with LaTeX, using the base font size above.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'Helvetica'
plt.rcParams['font.size'] = fig_font_size

# LaTeX preamble: load the font package, make sans-serif the default family
# and select the T1 font encoding.
rc(
    'text.latex',
    preamble='\n'.join((
        r'\usepackage{tgheros}',  # helvetica font
        r'\renewcommand\familydefault{\sfdefault} ',
        r'\usepackage[T1]{fontenc}',
    )),
)
def _rank_lower_triangle(m, name):
    """Extract and rank the strictly-lower-triangle entries of ``m``.

    Parameters
    ----------
    m : numpy.ndarray
        2-D matrix. Only ``m.shape[0]`` is used to build the triangle
        indices, so the matrix is expected to be square.
    name : str
        Label for the matrix, used only in the error message.

    Returns
    -------
    values : numpy.ndarray
        1-D array of the entries below the diagonal, in
        ``np.tril_indices`` order.
    ranks : numpy.ndarray
        1-based ascending ranks of ``values`` (smallest value -> rank 1).

    Raises
    ------
    ValueError
        If ``values`` contains duplicates. The double-argsort ranking used
        here would hand tied values arbitrary distinct ranks, so ties are
        rejected outright. An explicit exception is raised instead of an
        ``assert`` so the check still runs under ``python -O``.
    """
    values = m[np.tril_indices(n=m.shape[0], k=-1)]
    if len(np.unique(values)) != len(values):
        raise ValueError(
            f'{name} distances contain tied values; '
            'the argsort-based ranking needs every value to be unique'
        )
    # argsort of argsort turns each value into its 0-based sort position.
    ranks = values.argsort().argsort() + 1
    return values, ranks


def _symmetric_rank_matrix(lower_ranks, n):
    """Return an ``n`` x ``n`` symmetric rank matrix with a NaN diagonal.

    ``lower_ranks`` must be in ``np.tril_indices(n, k=-1)`` order; it fills
    the lower triangle and is mirrored into the upper triangle.
    """
    ranks = np.zeros((n, n))
    ranks[np.tril_indices(n=n, k=-1)] = lower_ranks
    # Everything on and above the diagonal is still zero, so adding the
    # transpose mirrors the lower triangle without disturbing it.
    ranks = ranks + ranks.T
    ranks[np.diag_indices(n)] = np.nan
    return ranks


# Pairwise distance matrices (Jaccard and Wasserstein, going by how they are
# labelled when plotted); rows/columns are looked up via ``chars``.
m_j = np.load(op.join('..', 'stim_sim', 'preregistered', 'jacc.npy'))
m_ot = np.load(op.join('..', 'stim_sim', 'preregistered', 'ot.npy'))

# Blank out the self-comparisons on the diagonal.
m_j[np.diag_indices(m_j.shape[0])] = np.nan
m_ot[np.diag_indices(m_ot.shape[0])] = np.nan

# One value (and one rank) per unordered pair.
j_lwr, j_lwr_rank = _rank_lower_triangle(m_j, 'Jaccard')
ot_lwr, ot_lwr_rank = _rank_lower_triangle(m_ot, 'Wasserstein')

# Full symmetric rank matrices for the heatmaps.
j_rank = _symmetric_rank_matrix(j_lwr_rank, m_j.shape[0])
ot_rank = _symmetric_rank_matrix(ot_lwr_rank, m_ot.shape[0])
# %%
# plot


def _draw_rdm(ax, matrix, title, panel_label=None, panel_label_x=0.0):
    """Draw one dissimilarity matrix as a heatmap with sparse character ticks.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    matrix : numpy.ndarray
        Square 2-D array; assumed to be ordered like ``chars``.
    title : str
        Axes title.
    panel_label : str, optional
        Text drawn above-left of the axes (e.g. a bold panel letter).
        Nothing is drawn when ``None``.
    panel_label_x : float
        x position of ``panel_label`` in axes-fraction coordinates.
    """
    ax.imshow(matrix, interpolation='none')

    # Label only the first and last character; the middle tick carries an
    # ellipsis standing in for everything in between.
    last = matrix.shape[0] - 1
    tick_positions = [0, last // 2, last]
    tick_labels = [chars[0], '...', chars[-1]]
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)

    # Hide the tick mark of the ellipsis on both axes, and rotate the y-axis
    # ellipsis so that it runs along the axis.
    x_mid = ax.xaxis.get_major_ticks()[1]
    y_mid = ax.yaxis.get_major_ticks()[1]
    x_mid.tick1line.set_visible(False)
    y_mid.tick1line.set_visible(False)
    y_mid.label1.set_rotation(90)

    if panel_label is not None:
        ax.text(panel_label_x, 1.2, panel_label,
                transform=ax.transAxes, size=fig_font_size)
    ax.set_title(title, size=fig_font_size)


def _annotate_example_pairs(ax, x_matrix, y_matrix, pairs, label_locs,
                            marker_size, leader_width):
    """Highlight example character pairs in red on a scatter plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Scatter axes to annotate.
    x_matrix, y_matrix : numpy.ndarray
        Matrices providing the x and y coordinate of each pair, indexed by
        the positions of the two characters in ``chars``.
    pairs : sequence of (str, str)
        Character pairs to highlight.
    label_locs : sequence of (float, float) or None
        Data-coordinate label position per pair. ``None`` places the label
        at the point itself and draws no leader line.
    marker_size : float
        Scatter marker size (``s``) of the highlighted point.
    leader_width : float
        Line width of the leader from the point to its label.
    """
    for (a, b), label_loc in zip(pairs, label_locs):
        a_i = chars.index(a)
        b_i = chars.index(b)
        point = (x_matrix[a_i, b_i], y_matrix[a_i, b_i])
        ax.scatter(point[0], point[1], color='r', s=marker_size)

        if label_loc is None:
            label_loc = point
        else:
            ax.plot((point[0], label_loc[0]), (point[1], label_loc[1]),
                    'r-', lw=leader_width)

        # Nudge the text away from the leader: downwards when the label sits
        # below its point, upwards otherwise.
        if label_loc[1] < point[1]:
            xytext_offset = (-0.7, -0.75)
        else:
            xytext_offset = (-0.7, 0.5)
        ax.annotate(f'{a}-{b}', label_loc, xytext=xytext_offset,
                    textcoords='offset fontsize', color='r',
                    size=fig_font_size_small)


# Pairs to call out in both scatter plots, with hand-placed label positions
# (data coordinates) for the raw-distance and the rank plot respectively.
example_pairs = [['b', 'p'], ['k', 't'], ['ß', 'o'], ['j', 'w'], ['f', 'b'], ['h', 'k']]
example_text_locs_raw = [(0.125, 22), (0.35, 38), (0.3, 52), (0.65, 70), (0.525, 50), (0.22, 32)]
example_text_locs_rank = [(15, 250), (80, 310), (360, 50), (420, 100), (100, 400), (230, 12)]

fig = plt.figure(layout='constrained', figsize=(4.25, 3))
gs = GridSpec(3, 4, figure=fig)

# --- Panel a: raw distances -------------------------------------------------
axr0 = fig.add_subplot(gs[0, 0])
_draw_rdm(axr0, m_j, 'Jaccard',
          panel_label=r'\textbf{a}', panel_label_x=-0.725)

axr1 = fig.add_subplot(gs[0, 1])
_draw_rdm(axr1, m_ot, 'Wasserstein')

axr2 = fig.add_subplot(gs[1:, :2])
axr2.scatter(j_lwr, ot_lwr, color='k', s=3)
axr2.set_xlabel('Jaccard Distance')
axr2.set_ylabel('Wasserstein Distance')
axr2.spines[['right', 'top']].set_visible(False)
axr2.set_xticks([0.2, 0.4, 0.6, 0.8])
axr2.set_xticklabels(axr2.get_xticks())
# The y ticks stay automatic, so label them through a formatter rather than
# a frozen list of strings: the text is the same plain ``str`` of each tick
# value, but it cannot fall out of step with the tick positions when the
# annotations below rescale the axis.
axr2.yaxis.set_major_formatter('{x}')
_annotate_example_pairs(axr2, m_j, m_ot, example_pairs, example_text_locs_raw,
                        marker_size=3.1, leader_width=0.5)

# --- Panel b: rank-transformed distances ------------------------------------
ax0 = fig.add_subplot(gs[0, 2])
_draw_rdm(ax0, j_rank, 'Jaccard',
          panel_label=r'\textbf{b}', panel_label_x=-0.85)

ax1 = fig.add_subplot(gs[0, 3])
_draw_rdm(ax1, ot_rank, 'Wasserstein')

# A tick every 100 ranks, up to the number of ranked pairs.
ax2_ticks = np.arange(0, len(j_lwr_rank), 100)
ax2 = fig.add_subplot(gs[1:, 2:])
ax2.scatter(j_lwr_rank, ot_lwr_rank, color='k', s=3)
ax2.set_xlabel('Rank Jaccard Distance')
ax2.set_ylabel('\nRank Wasserstein Distance')
ax2.spines[['right', 'top']].set_visible(False)
ax2.set_xticks(ax2_ticks)
ax2.set_yticks(ax2_ticks)
ax2.set_xticklabels(ax2_ticks)
ax2.set_yticklabels(ax2_ticks)
_annotate_example_pairs(ax2, j_rank, ot_rank, example_pairs, example_text_locs_rank,
                        marker_size=3, leader_width=1)

# Save the figure in every format needed.
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(op.join('..', 'fig', f'intro_stim_rdms.{ext}'))
|