99_illustrate_control_rdms.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # %%
  2. # imports
  3. import os.path as op
  4. import pandas as pd
  5. import numpy as np
  6. from string import ascii_lowercase
  7. chars = [*ascii_lowercase, 'ä', 'ö', 'ü', 'ß']
  8. import matplotlib.pyplot as plt
  9. from mpl_toolkits.axes_grid1 import make_axes_locatable
  10. fig_font_size = 8
  11. fig_font_size_small = 8
  12. plt.rcParams.update({
  13. "text.usetex": False,
  14. "font.family": "Helvetica",
  15. 'font.size': fig_font_size
  16. })
  17. def rankify_vec(vec):
  18. # rank values in a vector
  19. # (ties have their ranks averaged, as in R)
  20. rank_vec = vec.argsort().argsort() + 1
  21. rank_vec = rank_vec.astype(float)
  22. for uv in np.unique(vec):
  23. uv_idx = vec==uv
  24. if np.sum(uv_idx)>1:
  25. rank_vec[uv_idx] = np.mean(rank_vec[uv_idx])
  26. return rank_vec
  27. def rankify_sq_mat(mat):
  28. lwr = mat[np.tril_indices(n=mat.shape[0], k=-1)]
  29. lwr_rank = rankify_vec(lwr)
  30. mat_rank = np.zeros(mat.shape)
  31. mat_rank[np.tril_indices(n=mat.shape[0], k=-1)] = lwr_rank
  32. mat_rank[np.triu_indices(n=mat.shape[0], k=0)] = mat_rank.T[np.triu_indices(n=mat.shape[0], k=0)]
  33. mat_rank[np.diag_indices(n=mat_rank.shape[0])] = np.nan
  34. return mat_rank
  35. #%matplotlib qt
  36. # %%
  37. # import the features
  38. sz_feat_df = pd.read_csv( op.join('..', 'stim_sim', 'complexity', 'complexity_features.csv') )
  39. f_feat_df = pd.read_csv( op.join('..', 'stim_sim', 'frequency', 'frequency_features.csv') )
  40. ph_feat_df = pd.read_csv( op.join('..', 'stim_sim', 'phonology', 'dominant_phonemes_features.csv') )
  41. phn_feat_df = pd.read_csv( op.join('..', 'stim_sim', 'phonology', 'letter_names_features.csv') )
  42. # %%
  43. # import the RDMs
  44. sz_df = pd.read_csv( op.join('..', 'stim_sim', 'complexity', 'complexity.csv') )
  45. sz_mat = np.load( op.join('..', 'stim_sim', 'complexity', 'complexity.npy') )
  46. f_df = pd.read_csv( op.join('..', 'stim_sim', 'frequency', 'frequency.csv') )
  47. f_mat = np.load( op.join('..', 'stim_sim', 'frequency', 'frequency.npy') )
  48. ph_df = pd.read_csv( op.join('..', 'stim_sim', 'phonology', 'dominant_phonemes.csv') )
  49. ph_mat = np.load( op.join('..', 'stim_sim', 'phonology', 'dominant_phonemes.npy') )
  50. phn_df = pd.read_csv( op.join('..', 'stim_sim', 'phonology', 'letter_names.csv') )
  51. phn_mat = np.load( op.join('..', 'stim_sim', 'phonology', 'letter_names.npy') )
  52. # %%
  53. # build the plot
  54. fig, axs = plt.subplots(4, 4, figsize=(6, 4.8))
  55. fig.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.75, hspace=0.55)
  56. # titles
  57. # axs[0, 0].set_ylabel('Letter Size\n(Pixel Sum)', rotation='horizontal', va='center')
  58. # axs[1, 0].set_ylabel('Letter Frequency\n(SUBTLEX-DE Count)')
  59. # axs[2, 0].set_ylabel('Dominant Phonemes\n(PanPhon Features)')
  60. # axs[3, 0].set_ylabel('Letter Name Phonemes\n(PanPhon Features)')
  61. titles = ['Letter Size\n(Pixel Sum)',
  62. 'Letter\nFrequency\n(SUBTLEX-DE\nCount)',
  63. 'Dominant\nPhonemes\n(PanPhon\nFeatures)',
  64. 'Letter Name\nPhonemes\n(PanPhon\nFeatures)']
  65. for i, t in enumerate(titles):
  66. axs[i, 0].text(0.75, 0.5, t,
  67. transform=axs[i, 0].transAxes,
  68. size=fig_font_size,
  69. verticalalignment='center',
  70. horizontalalignment='center')
  71. axs[i, 0].axis('off')
  72. axs[0, 1].set_title('Features', fontsize=fig_font_size)
  73. axs[0, 2].set_title('Raw RDMs', fontsize=fig_font_size)
  74. axs[0, 3].set_title('Rank RDMs', fontsize=fig_font_size)
  75. # size
  76. axs[0, 1].bar(sz_feat_df.char, sz_feat_df.pixel_sum, width=1.0, color='black')
  77. axs[0, 1].set_yticks([0, 2500, 5000])
  78. sz_rdm_im = axs[0, 2].imshow(sz_mat, interpolation='none')
  79. divider = make_axes_locatable(axs[0, 2])
  80. cax = divider.append_axes('right', size='5%', pad=0.05)
  81. fig.colorbar(sz_rdm_im, cax=cax, orientation='vertical', ticks=[0, 2000, 4000])
  82. sz_rank_im = axs[0, 3].imshow(rankify_sq_mat(sz_mat), interpolation='none')
  83. divider = make_axes_locatable(axs[0, 3])
  84. cax = divider.append_axes('right', size='5%', pad=0.05)
  85. fig.colorbar(sz_rank_im, cax=cax, orientation='vertical', ticks=[1, 435])
  86. # frequency
  87. axs[1, 1].bar(f_feat_df.char, f_feat_df.freq, width=1.0, color='black')
  88. axs[1, 1].set_yticks(np.array([0, 0.5, 1, 1.5])*1e7)
  89. f_rdm_im = axs[1, 2].imshow(f_mat, interpolation='none')
  90. divider = make_axes_locatable(axs[1, 2])
  91. cax = divider.append_axes('right', size='5%', pad=0.05)
  92. fig.colorbar(f_rdm_im, cax=cax, orientation='vertical', ticks=np.array([0, 0.4, 0.8, 1.2])*1e7)
  93. f_rank_im = axs[1, 3].imshow(rankify_sq_mat(f_mat), interpolation='none')
  94. divider = make_axes_locatable(axs[1, 3])
  95. cax = divider.append_axes('right', size='5%', pad=0.05)
  96. fig.colorbar(f_rank_im, cax=cax, orientation='vertical', ticks=[1, 435])
  97. # dominant phonemes
  98. ph_feat_mat = ph_feat_df.loc[:, ph_feat_df.columns!='char'].to_numpy().T
  99. ph_im = axs[2, 1].imshow(ph_feat_mat, interpolation='none', aspect='auto', cmap='coolwarm', vmin=-1, vmax=1)
  100. axs[2, 1].set_yticks([0, 5, 11, 17, 23])
  101. axs[2, 1].set_yticklabels([1, 6, 12, 18, 24])
  102. divider = make_axes_locatable(axs[2, 1])
  103. cax = divider.append_axes('right', size='5%', pad=0.05)
  104. fig.colorbar(ph_im, cax=cax, orientation='vertical')
  105. ph_rdm_im = axs[2, 2].imshow(ph_mat, interpolation='none')
  106. divider = make_axes_locatable(axs[2, 2])
  107. cax = divider.append_axes('right', size='5%', pad=0.05)
  108. fig.colorbar(ph_rdm_im, cax=cax, orientation='vertical', ticks=[0, 0.4, 0.8, 1.2])
  109. ph_rank_im = axs[2, 3].imshow(rankify_sq_mat(ph_mat), interpolation='none', vmin=1, vmax=435)
  110. divider = make_axes_locatable(axs[2, 3])
  111. cax = divider.append_axes('right', size='5%', pad=0.05)
  112. fig.colorbar(ph_rank_im, cax=cax, orientation='vertical', ticks=[1, 435])
  113. # letter names
  114. phn_feat_mat = phn_feat_df.loc[:, phn_feat_df.columns!='char'].to_numpy().T
  115. phn_im = axs[3, 1].imshow(phn_feat_mat, interpolation='none', aspect='auto', cmap='coolwarm', vmin=-1, vmax=1)
  116. axs[3, 1].set_yticks([0, 5, 11, 17, 23])
  117. axs[3, 1].set_yticklabels([1, 6, 12, 18, 24])
  118. divider = make_axes_locatable(axs[3, 1])
  119. cax = divider.append_axes('right', size='5%', pad=0.05)
  120. fig.colorbar(phn_im, cax=cax, orientation='vertical')
  121. phn_rdm_im = axs[3, 2].imshow(phn_mat, interpolation='none')
  122. divider = make_axes_locatable(axs[3, 2])
  123. cax = divider.append_axes('right', size='5%', pad=0.05)
  124. fig.colorbar(phn_rdm_im, cax=cax, orientation='vertical', ticks=[0, 0.5, 1])
  125. phn_rank_im = axs[3, 3].imshow(rankify_sq_mat(phn_mat), interpolation='none')
  126. divider = make_axes_locatable(axs[3, 3])
  127. cax = divider.append_axes('right', size='5%', pad=0.05)
  128. fig.colorbar(phn_rank_im, cax=cax, orientation='vertical', ticks=[1, 435])
  129. # axes labels for characters
  130. for i in range(4):
  131. axs[i, 1].set_xticks([0, 14, 29])
  132. axs[i, 1].set_xticklabels(['a', '...', 'ß'])
  133. xticks = axs[i, 1].xaxis.get_major_ticks()
  134. xticks[1].get_children()[0].set_visible(False)
  135. for i in range(4):
  136. for j in range(2, 4):
  137. axs[i, j].set_xticks([0, 14, 29])
  138. axs[i, j].set_yticks([0, 14, 29])
  139. axs[i, j].set_xticklabels(['a', '...', 'ß'])
  140. axs[i, j].set_yticklabels(['a', '...', 'ß'])
  141. xticks = axs[i, j].xaxis.get_major_ticks()
  142. xticks[1].get_children()[0].set_visible(False)
  143. yticks = axs[i, j].yaxis.get_major_ticks()
  144. yticks[1].get_children()[0].set_visible(False)
  145. yticks[1].get_children()[3].set_rotation(90)
  146. fig.savefig(op.join('..', 'fig', 'illustrate_controls.pdf'))
  147. fig.savefig(op.join('..', 'fig', 'illustrate_controls.png'))