99_rdms_corr.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import matplotlib.pyplot as plt
  2. from matplotlib import rc
  3. from matplotlib.gridspec import GridSpec
  4. import numpy as np
  5. import os.path as op
  6. from string import ascii_lowercase
  7. chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
  8. fig_font_size = 8
  9. fig_font_size_small = 8
  10. plt.rcParams.update({
  11. "text.usetex": True,
  12. "font.family": "Helvetica",
  13. 'font.size': fig_font_size
  14. })
  15. rc('text.latex', preamble='\n'.join([
  16. r'\usepackage{tgheros}', # helvetica font
  17. r'\renewcommand\familydefault{\sfdefault} ',
  18. r'\usepackage[T1]{fontenc}'
  19. ]))
  20. m_j = np.load(op.join('..', 'stim_sim', 'preregistered', 'jacc.npy'))
  21. m_ot = np.load(op.join('..', 'stim_sim', 'preregistered', 'ot.npy'))
  22. m_j[np.diag_indices(m_j.shape[0])] = np.nan
  23. m_ot[np.diag_indices(m_ot.shape[0])] = np.nan
  24. j_lwr = m_j[np.tril_indices(n=m_j.shape[0], k=-1)]
  25. ot_lwr = m_ot[np.tril_indices(n=m_ot.shape[0], k=-1)]
  26. # check the number of unique values (need to all be unique for the argsort ranking method to work)
  27. assert(len(np.unique(j_lwr)) == len(j_lwr))
  28. assert(len(np.unique(ot_lwr)) == len(ot_lwr))
  29. j_lwr_rank = j_lwr.argsort().argsort() + 1
  30. ot_lwr_rank = ot_lwr.argsort().argsort() + 1
  31. j_rank = np.zeros(m_j.shape)
  32. ot_rank = np.zeros(m_ot.shape)
  33. j_rank[np.tril_indices(n=m_j.shape[0], k=-1)] = j_lwr_rank
  34. ot_rank[np.tril_indices(n=m_ot.shape[0], k=-1)] = ot_lwr_rank
  35. # j_rank[np.triu_indices(n=j.shape[0], k=0)] = np.nan
  36. # ot_rank[np.triu_indices(n=ot.shape[0], k=0)] = np.nan
  37. j_rank[np.triu_indices(n=m_j.shape[0], k=0)] = j_rank.T[np.triu_indices(n=m_j.shape[0], k=0)]
  38. ot_rank[np.triu_indices(n=m_ot.shape[0], k=0)] = ot_rank.T[np.triu_indices(n=m_ot.shape[0], k=0)]
  39. j_rank[np.diag_indices(n=j_rank.shape[0])] = np.nan
  40. ot_rank[np.diag_indices(n=ot_rank.shape[0])] = np.nan
  41. # %%
  42. # plot
  43. example_pairs = [['b', 'p'], ['k', 't'], ['ß', 'o'], ['j', 'w'], ['f', 'b'], ['h', 'k']]
  44. example_text_locs_raw = [(0.125, 22), (0.35, 38), (0.3, 52), (0.65, 70), (0.525, 50), (0.22, 32)]
  45. example_text_locs_rank = [(15, 250), (80, 310), (360, 50), (420, 100), (100, 400), (230, 12)]
  46. fig = plt.figure(layout = 'constrained', figsize = (4.25, 3))
  47. gs = GridSpec(3, 4, figure=fig)
  48. axr0 = fig.add_subplot(gs[0, 0])
  49. plt.imshow(m_j, interpolation='none')
  50. # cb = plt.colorbar()
  51. # cb.set_ticks([0.2, 0.4, 0.6, 0.8])
  52. # cb.set_ticklabels(cb.get_ticks())
  53. axr0.set_xticks([0, 14, 29])
  54. axr0.set_yticks([0, 14, 29])
  55. axr0.set_xticklabels(['a', '...', 'ß'])
  56. axr0.set_yticklabels(['a', '...', 'ß'])
  57. xticks = axr0.xaxis.get_major_ticks()
  58. xticks[1].get_children()[0].set_visible(False)
  59. yticks = axr0.yaxis.get_major_ticks()
  60. yticks[1].get_children()[0].set_visible(False)
  61. yticks[1].get_children()[3].set_rotation(90)
  62. axr0.text(-0.725, 1.2, r'\textbf{a}', transform=axr0.transAxes, size=fig_font_size)
  63. axr0.set_title('Jaccard', size=fig_font_size)
  64. axr1 = fig.add_subplot(gs[0, 1])
  65. plt.imshow(m_ot, interpolation='none')
  66. # cb = plt.colorbar()
  67. # cb.set_ticks(np.arange(0, 20, 5))
  68. # cb.set_ticklabels(cb.get_ticks())
  69. axr1.set_xticks([0, 14, 29])
  70. axr1.set_yticks([0, 14, 29])
  71. axr1.set_xticklabels(['a', '...', 'ß'])
  72. axr1.set_yticklabels(['a', '...', 'ß'])
  73. xticks = axr1.xaxis.get_major_ticks()
  74. xticks[1].get_children()[0].set_visible(False)
  75. yticks = axr1.yaxis.get_major_ticks()
  76. yticks[1].get_children()[0].set_visible(False)
  77. yticks[1].get_children()[3].set_rotation(90)
  78. # axr1.text(-0.625, 1.2, r'\textbf{a2}', transform=axr1.transAxes, size=12)
  79. axr1.set_title('Wasserstein', size=fig_font_size)
  80. axr2 = fig.add_subplot(gs[1:, :2])
  81. axr2.scatter(j_lwr, ot_lwr, color='k', s=3)
  82. # axr2.set_aspect('equal')
  83. axr2.set_xlabel('Jaccard Distance')
  84. axr2.set_ylabel('Wasserstein Distance')
  85. axr2.spines[['right', 'top']].set_visible(False)
  86. axr2.set_xticks([0.2, 0.4, 0.6, 0.8])
  87. # axr2.set_yticks([20, 40, 60, 80, 100])
  88. axr2.set_xticklabels(axr2.get_xticks())
  89. axr2.set_yticklabels(axr2.get_yticks())
  90. for i, (a, b) in enumerate(example_pairs):
  91. a_i = chars.index(a)
  92. b_i = chars.index(b)
  93. axr2.scatter(m_j[a_i, b_i], m_ot[a_i, b_i], color='r', s=3.1)
  94. if example_text_locs_raw[i] is None:
  95. text_locs = (m_j[a_i, b_i], m_ot[a_i, b_i])
  96. else:
  97. text_locs = example_text_locs_raw[i]
  98. axr2.plot((m_j[a_i, b_i], text_locs[0]), (m_ot[a_i, b_i], text_locs[1]), 'r-', lw=0.5)
  99. xytext_offset = (-0.7, -0.75) if text_locs[1] < m_ot[a_i, b_i] else (-0.7, 0.5)
  100. axr2.annotate(f'{a}-{b}', text_locs, xytext=xytext_offset, textcoords='offset fontsize', color='r', size=fig_font_size_small)
  101. ax0 = fig.add_subplot(gs[0, 2])
  102. plt.imshow(j_rank, interpolation='none')
  103. ax0.text(-0.85, 1.2, r'\textbf{b}', transform=ax0.transAxes, size=fig_font_size)
  104. # cb = plt.colorbar()
  105. # cb.set_ticks([0, 434])
  106. # cb.set_ticklabels(['0', '434'])
  107. ax0.set_title('Jaccard', size=fig_font_size)
  108. ax0.set_xticks([0, 14, 29])
  109. ax0.set_yticks([0, 14, 29])
  110. ax0.set_xticklabels(['a', '...', 'ß'])
  111. ax0.set_yticklabels(['a', '...', 'ß'])
  112. xticks = ax0.xaxis.get_major_ticks()
  113. xticks[1].get_children()[0].set_visible(False)
  114. yticks = ax0.yaxis.get_major_ticks()
  115. yticks[1].get_children()[0].set_visible(False)
  116. yticks[1].get_children()[3].set_rotation(90)
  117. ax1 = fig.add_subplot(gs[0, 3])
  118. plt.imshow(ot_rank, interpolation='none')
  119. # ax1.text(-0.625, 1.2, r'\textbf{b2}', transform=ax1.transAxes, size=12)
  120. # cb = plt.colorbar()
  121. # # cb.set_label('Rank', rotation=270, labelpad=-7.5)
  122. # cb.set_ticks([0, 434])
  123. # cb.set_ticklabels(['0', '434'])
  124. ax1.set_title('Wasserstein', size=fig_font_size)
  125. ax1.set_xticks([0, 14, 29])
  126. ax1.set_yticks([0, 14, 29])
  127. ax1.set_xticklabels(['a', '...', 'ß'])
  128. ax1.set_yticklabels(['a', '...', 'ß'])
  129. xticks = ax1.xaxis.get_major_ticks()
  130. xticks[1].get_children()[0].set_visible(False)
  131. yticks = ax1.yaxis.get_major_ticks()
  132. yticks[1].get_children()[0].set_visible(False)
  133. yticks[1].get_children()[3].set_rotation(90)
  134. ax2_ticks = np.arange(0, 434, 100)
  135. ax2 = fig.add_subplot(gs[1:, 2:])
  136. ax2.scatter(j_lwr_rank, ot_lwr_rank, color='k', s=3)
  137. # ax2.set_aspect('equal')
  138. ax2.set_xlabel('Rank Jaccard Distance')
  139. ax2.set_ylabel('\nRank Wasserstein Distance')
  140. # ax2.text(-0.225, 1.05, r'\textbf{b3}', transform=ax2.transAxes, size=12)
  141. # ax2.set_title('Rank Correlation', size=10)
  142. ax2.spines[['right', 'top']].set_visible(False)
  143. ax2.set_xticks(ax2_ticks)
  144. ax2.set_yticks(ax2_ticks)
  145. ax2.set_xticklabels(ax2_ticks)
  146. ax2.set_yticklabels(ax2_ticks)
  147. for i, (a, b) in enumerate(example_pairs):
  148. a_i = chars.index(a)
  149. b_i = chars.index(b)
  150. ax2.scatter(j_rank[a_i, b_i], ot_rank[a_i, b_i], color='r', s=3)
  151. if example_text_locs_rank[i] is None:
  152. text_locs = (j_rank[a_i, b_i], ot_rank[a_i, b_i])
  153. else:
  154. text_locs = example_text_locs_rank[i]
  155. ax2.plot((j_rank[a_i, b_i], text_locs[0]), (ot_rank[a_i, b_i], text_locs[1]), 'r-', lw=1)
  156. xytext_offset = (-0.7, -0.75) if text_locs[1] < ot_rank[a_i, b_i] else (-0.7, 0.5)
  157. ax2.annotate(f'{a}-{b}', text_locs, xytext=xytext_offset, textcoords='offset fontsize', color='r', size=fig_font_size_small)
  158. fig.savefig(op.join('..', 'fig', 'intro_stim_rdms.pdf'))
  159. fig.savefig(op.join('..', 'fig', 'intro_stim_rdms.png'))
  160. fig.savefig(op.join('..', 'fig', 'intro_stim_rdms.svg'))