plot_model_estimates.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. """
  4. @author: yannansu
  5. @created at: 03.08.21 14:07
  6. Generate publication-quality figures of modeling estimates.
  7. """
  8. import matplotlib as mpl
  9. import matplotlib.pyplot as plt
  10. import seaborn as sns
  11. from matplotlib.ticker import MultipleLocator, FormatStrFormatter, MaxNLocator
  12. import pandas as pd
  13. import numpy as np
  14. from data_analysis.color4plot import color4plot
  15. import pickle
  16. from data_analysis.estimate_likelihood import rect_sin
# Figure configuration
# Everything in this section runs at import time. The style sheet and the
# x_grid file are given as relative paths, so they are resolved against the
# current working directory of the importing process.
plt.style.use('data_analysis/figures/plot_style.txt')
sns.set_context('paper')
# mpl.rcParams['pdf.fonttype'] = 42
# mpl.rcParams['ps.fonttype'] = 42
# mpl.rcParams['font.family'] = 'Arial'
# `alpha` is not referenced by the plotting functions visible in this file
# (they pass alpha=0.2 literally); it may be read by other modules.
alpha = .6
# Passed as `capsize` to every errorbar() call below.
capsize = 4
# Only referenced by the commented-out `x_ticks` line below.
# NOTE(review): the name looks like a typo for `first_x`; it is left unchanged
# because other modules may import it under this name.
fist_x = 22.5
# x_ticks = np.linspace(0 + fist_x, 360 + fist_x, 8, endpoint=False)
# Major ticks every 90 deg; minor ticks every 45 deg (0, 45, ..., 315).
x_major_ticks = np.array([0, 90, 180, 270, 360])
x_minor_ticks = np.linspace(0, 360, 8, endpoint=False)
# x values at which the fitted curves (rect_sin) are evaluated when plotting.
x_grid = np.load("data_analysis/model_estimates/x_grid.npy")
  30. def rep_end(dat):
  31. """
  32. Duplicate the first and the last stimuli for visualization.
  33. """
  34. first = dat[dat['Hue Angle'] == dat['Hue Angle'].min()]
  35. first_copy = first.copy()
  36. first_copy['Hue Angle'] = first_copy['Hue Angle'].apply(lambda x: x + 360)
  37. last = dat[dat['Hue Angle'] == dat['Hue Angle'].max()]
  38. last_copy = last.copy()
  39. last_copy['Hue Angle'] = last_copy['Hue Angle'].apply(lambda x: x - 360)
  40. rep_dat = pd.concat([dat, first_copy, last_copy],
  41. ignore_index=False).sort_values('Hue Angle')
  42. return rep_dat
  43. def plot_jnd_fits(estm, lh_jnd_fits, save_pdf=None):
  44. d_l = rep_end(estm.query('condition == "LL"'))
  45. d_h = rep_end(estm.query('condition == "HH"'))
  46. hue_angles = d_l['Hue Angle'].values
  47. jnd_l = d_l['JND'].values
  48. jnd_h = d_h['JND'].values
  49. pickle_l, pickle_h = lh_jnd_fits
  50. with open(pickle_l, 'rb') as f:
  51. fit_jnd_l = pickle.load(f)
  52. with open(pickle_h, 'rb') as f:
  53. fit_jnd_h = pickle.load(f)
  54. fig, ax = plt.subplots(1, 2, figsize=[3.45 * 2, 3.45])
  55. ylim = 16
  56. # LL
  57. ax[0].set_title('L vs. L')
  58. ax[0].plot(hue_angles, jnd_l, 'o', markersize=8, color='gray')
  59. ax[0].errorbar(x=hue_angles, y=jnd_l, yerr=d_l['JND_err'], ls='none', capsize=capsize, color='gray')
  60. ax[0].plot(x_grid, rect_sin(x_grid, *fit_jnd_l['params']), color='gray')
  61. ax[0].fill_between(hue_angles, fit_jnd_l['bounds'][0], fit_jnd_l['bounds'][1], alpha=0.2, color='gray')
  62. ax[0].set_ylim([0, ylim])
  63. # ax[0].set_ylim([0, 15])
  64. # HH
  65. ax[1].set_title('H vs. H')
  66. ax[1].plot(hue_angles, jnd_h, 'o', markersize=8, color='gray')
  67. ax[1].errorbar(x=hue_angles, y=jnd_h, yerr=d_h['JND_err'], ls='none', capsize=capsize, color='gray')
  68. ax[1].plot(x_grid, rect_sin(x_grid, *fit_jnd_h['params']), color='gray')
  69. ax[1].fill_between(hue_angles, fit_jnd_h['bounds'][0], fit_jnd_h['bounds'][1], alpha=0.2, color='gray')
  70. # ax[1].set_ylim([0, ylim])
  71. ax[1].set_ylim([0, 25])
  72. # Set axis info
  73. for i in range(2):
  74. ax[i].set_xlim([0, 360])
  75. ax[i].xaxis.set_minor_locator(plt.FixedLocator(x_minor_ticks))
  76. ax[i].xaxis.set_major_locator(plt.FixedLocator(x_major_ticks))
  77. ax[i].set_xlabel('Hue Angle (deg)')
  78. ax[i].set_ylabel('JND (deg)')
  79. plt.tight_layout()
  80. if save_pdf is not None:
  81. plt.savefig('data_analysis/figures/modeling/' + save_pdf + '.pdf')
  82. plt.show()
  83. def plot_jnd_fits_both(estm, lh_jnd_fits, save_pdf=None):
  84. d_l = rep_end(estm.query('condition == "LL"'))
  85. d_h = rep_end(estm.query('condition == "HH"'))
  86. hue_angles = d_l['Hue Angle'].values
  87. jnd_l = d_l['JND'].values
  88. jnd_h = d_h['JND'].values
  89. pickle_l, pickle_h = lh_jnd_fits
  90. with open(pickle_l, 'rb') as f:
  91. fit_jnd_l = pickle.load(f)
  92. with open(pickle_h, 'rb') as f:
  93. fit_jnd_h = pickle.load(f)
  94. fig, ax = plt.subplots(1, 1, figsize=[3.45, 3.45])
  95. ylim = 25
  96. # LL
  97. ax.set_title('JND fitting')
  98. ax.scatter(hue_angles, jnd_l, marker='o', s=15, facecolors='none', edgecolors='gray', label='L vs. L')
  99. ax.errorbar(x=hue_angles, y=jnd_l, yerr=d_l['JND_err'], ls='none', capsize=capsize, color='gray')
  100. ax.plot(x_grid, rect_sin(x_grid, *fit_jnd_l['params']), color='gray')
  101. ax.fill_between(hue_angles, fit_jnd_l['bounds'][0], fit_jnd_l['bounds'][1], alpha=0.2, color='gray')
  102. # HH
  103. ax.scatter(hue_angles, jnd_h, marker='s', s=15, facecolors='none', edgecolors='gray', label='H vs. H')
  104. ax.errorbar(x=hue_angles, y=jnd_h, yerr=d_h['JND_err'], ls='none', capsize=capsize, color='gray')
  105. ax.plot(x_grid, rect_sin(x_grid, *fit_jnd_h['params']), color='gray')
  106. ax.fill_between(hue_angles, fit_jnd_h['bounds'][0], fit_jnd_h['bounds'][1], alpha=0.2, color='gray')
  107. ax.set_ylim([0, ylim])
  108. # ax.set_yscale('log')
  109. # ax.set_ylim([10**(-0.15), 10**1.5])
  110. ax.set_xlim([0, 360])
  111. ax.xaxis.set_minor_locator(plt.FixedLocator(x_minor_ticks))
  112. ax.xaxis.set_major_locator(plt.FixedLocator(x_major_ticks))
  113. ax.set_xlabel('Hue Angle (deg)')
  114. ax.set_ylabel('JND (deg)')
  115. leg = ax.legend(loc=(0.65, 0.85))
  116. leg.get_frame().set_linewidth(0.0)
  117. leg.get_frame().set_facecolor('none')
  118. plt.tight_layout()
  119. if save_pdf is not None:
  120. plt.savefig('data_analysis/figures/modeling/' + save_pdf + '.pdf')
  121. plt.show()
  122. def plot_matrix(lh_mat, lh_edges, save_pdf=None):
  123. """
  124. :param lh_mat:
  125. :param lh_edges:
  126. :param save_pdf:
  127. :return:
  128. """
  129. mat_l, mat_h = lh_mat
  130. edges_l, edges_h = lh_edges
  131. # fig, ax = plt.subplots(2, 1, figsize=[3.45, 3.45 * 2 + 0.5])
  132. fig, ax = plt.subplots(1, 2, figsize=[3.45*2 + 0.5, 3.45])
  133. extent_l = [edges_l[0][0], edges_l[0][-1], edges_l[1][0], edges_l[1][-1]]
  134. extent_h = [edges_h[0][0], edges_h[0][-1], edges_h[1][0], edges_h[1][-1]]
  135. ax[0].set_title('Low-noise')
  136. im_l = ax[0].imshow(mat_l, origin='lower', cmap='gray', extent=extent_l)
  137. cb_l = fig.colorbar(im_l, ax=ax[0])
  138. cb_l.formatter.set_powerlimits((0, 0)) # , shrink=0.65)
  139. ax[1].set_title('High-noise')
  140. im_h = ax[1].imshow(mat_h, origin='lower', cmap='gray', extent=extent_h)
  141. cb_h = fig.colorbar(im_h, ax=ax[1])
  142. cb_h.formatter.set_powerlimits((0, 0)) # , shrink=0.65)
  143. # Set axis info
  144. for i in range(2):
  145. ax[i].set_xlim([0, 360])
  146. ax[i].xaxis.set_minor_locator(plt.FixedLocator(x_minor_ticks))
  147. ax[i].xaxis.set_major_locator(plt.FixedLocator(x_major_ticks))
  148. ax[i].set_xlabel('true stimulus (deg)')
  149. ax[i].set_ylim([0, 360])
  150. ax[i].yaxis.set_minor_locator(plt.FixedLocator(x_minor_ticks))
  151. ax[i].yaxis.set_major_locator(plt.FixedLocator(x_major_ticks))
  152. ax[i].set_ylabel('measurement (deg)')
  153. plt.tight_layout()
  154. if save_pdf is not None:
  155. plt.savefig('data_analysis/figures/modeling/' + save_pdf + '.pdf')
  156. plt.show()
# The triple-quoted blocks below are bare string literals at module level:
# Python evaluates and discards them, so none of the code inside them runs.
# They hold driver snippets that are currently disabled.
"""
################# Running plotting ######################
"""

# Disabled snippet: select the averaged subject and load its estimates.
"""
# For average sub
sub = 'sAVG'
sub_estimates = pd.read_csv('data_analysis/pf_estimates/avg_estimates.csv')
"""

# Disabled snippet: select one subject's rows from the combined estimates file.
"""
# For a single subject
sub = 's1'
all_estimates = pd.read_csv('data_analysis/pf_estimates/all_estimates.csv')
sub_estimates = all_estimates.query('subject == @sub')
"""

# Disabled snippet: build the per-subject input paths/arrays and draw both
# figures. It relies on `sub` and `sub_estimates` from one of the snippets above.
"""
# Make plots
dir = 'data_analysis/model_estimates/'
sub_lh_jnd_fits = [dir + sub + '/' + sub + '_jnd_fit_l.pickle',
                   dir + sub + '/' + sub + '_jnd_fit_h.pickle']
sub_lh_mat = [np.load(dir + sub + '/' + sub + '_mat_l.npy'),
              np.load(dir + sub + '/' + sub + '_mat_h.npy')]
sub_lh_edges = [np.load(dir + sub + '/' + sub + '_mat_edges_l.npy'),
                np.load(dir + sub + '/' + sub + '_mat_edges_h.npy')]
plot_jnd_fits_both(sub_estimates, sub_lh_jnd_fits, save_pdf=sub + '_jnd_fits')
plot_matrix(sub_lh_mat, sub_lh_edges, save_pdf=sub + '_matrix')
"""

# Disabled snippet: the same steps repeated for every subject in `sub_list`,
# using the averaged estimates file when the subject is 'sAVG'.
"""
# Iterate for all subjects
sub_list = ['s1', 's2', 's3', 's4', 's5', 's6', 'sAVG']
for sub in sub_list:
    if sub == 'sAVG':
        sub_estimates = pd.read_csv('data_analysis/pf_estimates/avg_estimates.csv')
    else:
        all_estimates = pd.read_csv('data_analysis/pf_estimates/all_estimates.csv')
        sub_estimates = all_estimates.query('subject == @sub')
    dir = 'data_analysis/model_estimates/'
    sub_lh_jnd_fits = [dir + sub + '/' + sub + '_jnd_fit_l.pickle',
                       dir + sub + '/' + sub + '_jnd_fit_h.pickle']
    sub_lh_mat = [np.load(dir + sub + '/' + sub + '_mat_l.npy'),
                  np.load(dir + sub + '/' + sub + '_mat_h.npy')]
    sub_lh_edges = [np.load(dir + sub + '/' + sub + '_mat_edges_l.npy'),
                    np.load(dir + sub + '/' + sub + '_mat_edges_h.npy')]
    plot_jnd_fits_both(sub_estimates, sub_lh_jnd_fits, save_pdf=sub + '_jnd_fits')
    plot_matrix(sub_lh_mat, sub_lh_edges, save_pdf=sub + '_matrix')
"""