123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import numpy as np
- import matplotlib.pyplot as plt
- import pandas as pd
- from scipy.io import loadmat
- from pathlib import Path
- from c_inv.inference_paper_submission.util import (contrast_log_scale, replace_imshow_ticklabels, set_mpl_pars, cm2inch,
- JNEUROSCIPARS, excitcol, inhibcol, population_response, svd_spatial,
- DATA_DIR)
- from c_inv.inference_paper_submission.layout import Frame, Figure
# Figure-wide matplotlib styling: project defaults first, then local overrides
# (applied in this order: sans-serif font, title size, title padding).
set_mpl_pars()
plt.rcParams.update({
    'font.sans-serif': ['Arial'],
    'axes.titlesize': 'x-small',
    'axes.titlepad': 3,
})

# Top-level frame with two panels, 'A' and 'B', which differ only in their
# right margin.
layoutstring = 'AB'
layout = Frame(
    layoutstring,
    left=0.03, top=0.07, right=0.02, bot=0.15,
    h_space=0.03, w_space=0.015,
    panelpars={
        'A': {'left': 0.12, 'h_space': 0.07, 'right': 0.08},
        'B': {'left': 0.12, 'h_space': 0.07, 'right': 0.05},
    },
)

# Split each top-level panel into rows '1' and '2' (separated by '\n' in the
# layout string). A fresh panelpars dict is built for every call.
for column in 'AB':
    layout.panels[column].new_panels(
        '1\n2',
        panelpars={'1': {'w_space': 0.01}, '2': {'w_space': 0.01}},
    )

# Within every row: cells '1' and '2', where cell '1' (width_ratio 2.2) is
# further split into three sub-panels '1', '2', '3'; cell '2' is left whole.
for column in 'AB':
    for row in '12':
        row_frame = layout.panels[column].panels[row]
        row_frame.new_panels(
            '12',
            panelpars={'1': {'w_space': 0.01, 'right': 0.05, 'width_ratio': 2.2}},
        )
        row_frame.panels['1'].new_panels('123')

# Figure sized from the journal's double-column width (cm converted to inches).
figure = Figure(layout, cm2inch(JNEUROSCIPARS['doublecolumn']), 2.0)
figure.annotate(fontweight='bold')
###########
# load data
###########
# Panel A inputs: population contrast responses read from the 'Data' folder
# next to this script.
# NOTE(review): panel B below reads from the shared DATA_DIR instead; confirm
# that the two data roots are intended.
pkg_dir = Path(__file__).parent
data_dir = pkg_dir.joinpath('Data')
pop_e = pd.read_pickle(data_dir.joinpath('F2_population_contrast_exc.pkl'))
pop_i = pd.read_pickle(data_dir.joinpath('F2_population_contrast_inh.pkl'))

# Raw string: '\c' in the LaTeX is otherwise an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
_ORI_LABEL = r'Orientation ($^{\circ}$)'


def _plot_svd_row(pop, panel, labelpad, data_cbar_ticks=None):
    """Plot one figure row: data, its SVD reconstruction and the residual.

    Uses the axes ``panel + '11'`` (data), ``'12'`` (SVD reconstruction) and
    ``'13'`` (residual, on the data's colour range) plus ``panel + '2'``
    (residual again, on its own auto-scaled colour range).

    Parameters
    ----------
    pop : 2-D array
        Response map handed to ``svd_spatial`` and shown with ``imshow``.
    panel : str
        Layout key such as ``'A1'``; a ``'2'`` in it marks the bottom row,
        which gets x labels instead of titles.
    labelpad : float
        Label padding of the data colour bar.
    data_cbar_ticks : list, optional
        If given, used as both ticks and tick labels of the data colour bar.
    """
    svd, err, _, _ = svd_spatial(pop)
    ax_data = figure.axes[panel + '11']
    ax_svd = figure.axes[panel + '12']
    ax_res = figure.axes[panel + '13']
    ax_res_own = figure.axes[panel + '2']
    lo, hi = np.min(pop), np.max(pop)

    im_data = ax_data.imshow(pop, origin='lower')
    ax_svd.imshow(svd, origin='lower', vmin=lo, vmax=hi)
    # Residual drawn on the data's colour range so its size is comparable to the data...
    ax_res.imshow(err, origin='lower', vmin=lo, vmax=hi, aspect='auto')
    # ...and again with automatic scaling so its structure is visible.
    im_res = ax_res_own.imshow(err, origin='lower')

    # Shared data colour bar, attached next to the residual axes.
    cbar = plt.colorbar(im_data, ax=ax_res)
    if data_cbar_ticks is not None:
        cbar.set_ticks(data_cbar_ticks)
        cbar.set_ticklabels(data_cbar_ticks)
    cbar.set_label(label='Firing (Hz)', labelpad=labelpad)
    # NOTE(review): fixed ticks regardless of the residual range; ticks outside
    # the colour range are simply not drawn.
    cbar = plt.colorbar(im_res, ax=ax_res_own, label='Firing (Hz)')
    cbar.set_ticks([0, 0.1])

    row_axes = (ax_data, ax_svd, ax_res, ax_res_own)
    for ax in row_axes[1:]:
        ax.set_yticklabels([])
    if '2' in panel:
        for ax in row_axes:
            ax.set_xlabel('Contrast')
    else:
        for ax, title in zip(row_axes, ('Data', 'SVD', 'Residual', 'Residual')):
            ax.set_title(title)
            ax.set_xticklabels([])


def _label_orientation_contrast(prefix, contrasts):
    """Label the orientation (y) and contrast (x) axes of panel *prefix*.

    The left-most axes of both rows get an orientation label and y ticks at
    rows 0, 6, 12 labelled 0, 90, 180. The bottom-row axes get x tick
    labels ``0, contrasts[4] (rounded), 1``. These labels pair with x ticks
    that must already be set to ``[0, 4, 7]``.
    """
    for key in (prefix + '111', prefix + '211'):
        ax = figure.axes[key]
        ax.set_ylabel(_ORI_LABEL)
        ax.set_yticks([0, 6, 12])
        ax.set_yticklabels([0, 90, 180])
    for suffix in ('211', '212', '213', '22'):
        figure.axes[prefix + suffix].set_xticklabels([0, contrasts[4].round(2), 1])


panels = ['A1', 'A2']
pads = {'A1': 6.5, 'A2': 4}
for data, panel in zip([pop_e, pop_i], panels):
    pop = population_response(data)
    # Repeat the first row at the end so the map closes on itself (row 12 is
    # labelled 180 below); the *1000 presumably converts to Hz — TODO confirm units.
    pop = np.vstack([pop, pop[0, :]]) * 1000
    _plot_svd_row(pop, panel, pads[panel])

oris = np.linspace(0, 180, 13).astype(int)  # NOTE(review): not used in this script
cons = contrast_log_scale(10, 8)
# Contrast tick positions for every layout axes (panel B included, before it is drawn).
for ax in figure.axes.values():
    ax.set_xticks([0, 4, 7])
_label_orientation_contrast('A', cons)

NatE = loadmat(str(DATA_DIR.joinpath('S3_V1E_Sum_Fits')))['Sum_Fits'].T
NatI = loadmat(str(DATA_DIR.joinpath('S3_V1I_Sum_Fits')))['Sum_Fits'].T
panels = ['B1', 'B2']
pads = {'B1': 4, 'B2': 4.5}
for pop, panel in zip([NatE, NatI], panels):
    _plot_svd_row(pop, panel, pads[panel],
                  data_cbar_ticks=[8, 10, 12, 14] if '2' in panel else None)
_label_orientation_contrast('B', cons)

# Row markers placed in figure-fraction coordinates.
plt.annotate('E', [0.01, 0.74], xycoords='figure fraction', color=excitcol, fontweight='bold')
plt.annotate('I', [0.015, 0.31], xycoords='figure fraction', color=inhibcol, fontweight='bold')
|