# Fig_S3.py

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import pandas as pd
  4. from scipy.io import loadmat
  5. from pathlib import Path
  6. from c_inv.inference_paper_submission.util import (contrast_log_scale, replace_imshow_ticklabels, set_mpl_pars, cm2inch,
  7. JNEUROSCIPARS, excitcol, inhibcol, population_response, svd_spatial,
  8. DATA_DIR)
  9. from c_inv.inference_paper_submission.layout import Frame, Figure
  10. set_mpl_pars()
  11. plt.rcParams['font.sans-serif'] = ['Arial']
  12. plt.rcParams['axes.titlesize'] = 'x-small'
  13. plt.rcParams['axes.titlepad'] = 3
  14. layoutstring = 'AB'
  15. layout = Frame(layoutstring, left=0.03, top=0.07, right=0.02, bot=0.15, h_space=0.03,
  16. w_space=0.015, panelpars={'A': {'left': 0.12, 'h_space': 0.07,
  17. 'right': 0.08},
  18. 'B': {'left': 0.12, 'h_space': 0.07,
  19. 'right': 0.05}})
  20. layout.panels['A'].new_panels('1\n2', panelpars={'1': {'w_space': 0.01},
  21. '2': {'w_space': 0.01}})
  22. layout.panels['B'].new_panels('1\n2', panelpars={'1': {'w_space': 0.01},
  23. '2': {'w_space': 0.01}})
  24. layout.panels['A'].panels['1'].new_panels('12', panelpars={'1': {'w_space': 0.01, 'right': 0.05,
  25. 'width_ratio': 2.2}})
  26. layout.panels['A'].panels['1'].panels['1'].new_panels('123')
  27. layout.panels['A'].panels['2'].new_panels('12', panelpars={'1': {'w_space': 0.01, 'right': 0.05,
  28. 'width_ratio': 2.2}})
  29. layout.panels['A'].panels['2'].panels['1'].new_panels('123')
  30. layout.panels['B'].panels['1'].new_panels('12', panelpars={'1': {'w_space': 0.01, 'right': 0.05,
  31. 'width_ratio': 2.2}})
  32. layout.panels['B'].panels['1'].panels['1'].new_panels('123')
  33. layout.panels['B'].panels['2'].new_panels('12', panelpars={'1': {'w_space': 0.01, 'right': 0.05,
  34. 'width_ratio': 2.2}})
  35. layout.panels['B'].panels['2'].panels['1'].new_panels('123')
  36. figure = Figure(layout, cm2inch(JNEUROSCIPARS['doublecolumn']), 2.0)
  37. figure.annotate(fontweight='bold')
  38. ###########
  39. # load data
  40. ###########
  41. pkg_dir = Path(__file__).parent
  42. data_dir = pkg_dir.joinpath('Data')
  43. pop_e = pd.read_pickle(data_dir.joinpath('F2_population_contrast_exc.pkl'))
  44. pop_i = pd.read_pickle(data_dir.joinpath('F2_population_contrast_inh.pkl'))
  45. panels = ['A1', 'A2']
  46. pads = {'A1': 6.5, 'A2': 4}
  47. for i, (data, panel) in enumerate(zip([pop_e, pop_i], panels)):
  48. pop = population_response(data)
  49. pop = np.vstack([pop, pop[0, :]])*1000
  50. svd, err, power, g_z = svd_spatial(pop)
  51. ax1, ax2, ax3 = figure.axes[panel+'11'], figure.axes[panel+'12'], figure.axes[panel+'13']
  52. ax4 = figure.axes[panel+'2']
  53. i1 = ax1.imshow(pop, origin='lower')
  54. ax2.imshow(svd, origin='lower', vmin=np.min(pop), vmax=np.max(pop))
  55. ax3.imshow(err, origin='lower', vmin=np.min(pop), vmax=np.max(pop), aspect='auto')
  56. i3 = ax4.imshow(err, origin='lower')
  57. cbar = plt.colorbar(i1, ax=ax3)
  58. cbar.set_label(label='Firing (Hz)', labelpad=pads[panel])
  59. cbar = plt.colorbar(i3, ax=ax4, label='Firing (Hz)')
  60. cbar.set_ticks([0, 0.1])
  61. ax2.set_yticklabels([])
  62. ax3.set_yticklabels([])
  63. ax4.set_yticklabels([])
  64. if '2' in panel:
  65. ax1.set_xlabel('Contrast')
  66. ax2.set_xlabel('Contrast')
  67. ax3.set_xlabel('Contrast')
  68. ax4.set_xlabel('Contrast')
  69. else:
  70. ax1.set_title('Data')
  71. ax2.set_title('SVD')
  72. ax3.set_title('Residual')
  73. ax4.set_title('Residual')
  74. ax1.set_xticklabels([])
  75. ax2.set_xticklabels([])
  76. ax3.set_xticklabels([])
  77. ax4.set_xticklabels([])
  78. oris = np.linspace(0, 180, 13).astype(int)
  79. cons = contrast_log_scale(10, 8)
  80. figure.axes['A111'].set_ylabel('Orientation ($^{\circ}$)')
  81. figure.axes['A211'].set_ylabel('Orientation ($^{\circ}$)')
  82. for ax in figure.axes.values():
  83. ax.set_xticks([0, 4, 7])
  84. figure.axes['A111'].set_yticks([0, 6, 12])
  85. figure.axes['A111'].set_yticklabels([0, 90, 180])
  86. figure.axes['A211'].set_yticks([0, 6, 12])
  87. figure.axes['A211'].set_yticklabels([0, 90, 180])
  88. figure.axes['A211'].set_xticklabels([0, cons[4].round(2), 1])
  89. figure.axes['A212'].set_xticklabels([0, cons[4].round(2), 1])
  90. figure.axes['A213'].set_xticklabels([0, cons[4].round(2), 1])
  91. figure.axes['A22'].set_xticklabels([0, cons[4].round(2), 1])
  92. NatE = loadmat(str(DATA_DIR.joinpath('S3_V1E_Sum_Fits')))['Sum_Fits'].T
  93. NatI = loadmat(str(DATA_DIR.joinpath('S3_V1I_Sum_Fits')))['Sum_Fits'].T
  94. panels = ['B1', 'B2']
  95. pads = {'B1': 4, 'B2': 4.5}
  96. for i, (pop, panel) in enumerate(zip([NatE, NatI], panels)):
  97. svd, err, power, g_z = svd_spatial(pop)
  98. ax1, ax2, ax3 = figure.axes[panel+'11'], figure.axes[panel+'12'], figure.axes[panel+'13']
  99. ax4 = figure.axes[panel+'2']
  100. i1 = ax1.imshow(pop, origin='lower')
  101. ax2.imshow(svd, origin='lower', vmin=np.min(pop), vmax=np.max(pop))
  102. ax3.imshow(err, origin='lower', vmin=np.min(pop), vmax=np.max(pop), aspect='auto')
  103. i3 = ax4.imshow(err, origin='lower')
  104. cbar = plt.colorbar(i1, ax=ax3, label='Firing (Hz)')
  105. if '2' in panel:
  106. cbar.set_ticks([8, 10, 12, 14])
  107. cbar.set_ticklabels([8, 10, 12, 14])
  108. cbar.set_label(label='Firing (Hz)', labelpad=pads[panel])
  109. cbar = plt.colorbar(i3, ax=ax4, label='Firing (Hz)')
  110. cbar.set_ticks([0, 0.1])
  111. ax2.set_yticklabels([])
  112. ax3.set_yticklabels([])
  113. ax4.set_yticklabels([])
  114. if '2' in panel:
  115. ax1.set_xlabel('Contrast')
  116. ax2.set_xlabel('Contrast')
  117. ax3.set_xlabel('Contrast')
  118. ax4.set_xlabel('Contrast')
  119. else:
  120. ax1.set_title('Data')
  121. ax2.set_title('SVD')
  122. ax3.set_title('Residual')
  123. ax4.set_title('Residual')
  124. ax1.set_xticklabels([])
  125. ax2.set_xticklabels([])
  126. ax3.set_xticklabels([])
  127. ax4.set_xticklabels([])
  128. figure.axes['B111'].set_ylabel('Orientation ($^{\circ}$)')
  129. figure.axes['B211'].set_ylabel('Orientation ($^{\circ}$)')
  130. figure.axes['B111'].set_yticks([0, 6, 12])
  131. figure.axes['B111'].set_yticklabels([0, 90, 180])
  132. figure.axes['B211'].set_yticks([0, 6, 12])
  133. figure.axes['B211'].set_yticklabels([0, 90, 180])
  134. figure.axes['B211'].set_xticklabels([0, cons[4].round(2), 1])
  135. figure.axes['B212'].set_xticklabels([0, cons[4].round(2), 1])
  136. figure.axes['B213'].set_xticklabels([0, cons[4].round(2), 1])
  137. figure.axes['B22'].set_xticklabels([0, cons[4].round(2), 1])
  138. plt.annotate('E', [0.01, 0.74], xycoords='figure fraction', color=excitcol, fontweight='bold')
  139. plt.annotate('I', [0.015, 0.31], xycoords='figure fraction', color=inhibcol, fontweight='bold')