"""Plot the SVD decomposition of orientation x contrast response maps.

Panel A uses the population contrast responses loaded from
``Data/F2_population_contrast_exc.pkl`` and ``..._inh.pkl``; panel B uses the
``Sum_Fits`` arrays loaded from ``S3_V1E_Sum_Fits`` / ``S3_V1I_Sum_Fits``
in ``DATA_DIR``.  Sub-panels ``A1``/``B1`` show the excitatory data ('E')
and ``A2``/``B2`` the inhibitory data ('I').  Each sub-panel shows the data,
the 'SVD' map returned by ``svd_spatial``, and the residual twice: once on
the data's colour limits and once on its own colour limits.
"""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.io import loadmat
from pathlib import Path
from c_inv.inference_paper_submission.util import (
    contrast_log_scale, replace_imshow_ticklabels, set_mpl_pars, cm2inch,
    JNEUROSCIPARS, excitcol, inhibcol, population_response, svd_spatial,
    DATA_DIR)
from c_inv.inference_paper_submission.layout import Frame, Figure

# Raw string: '\c' is not a valid escape sequence, so the plain literal
# triggers a SyntaxWarning on Python >= 3.12 (DeprecationWarning before).
# The resulting string value is unchanged.
ORIENTATION_LABEL = r'Orientation ($^{\circ}$)'
FIRING_LABEL = 'Firing (Hz)'
CONTRAST_TICKS = [0, 4, 7]            # x (contrast) tick positions
ORIENTATION_TICKS = [0, 6, 12]        # y (orientation) tick positions
ORIENTATION_TICKLABELS = [0, 90, 180]  # ...and their labels in degrees
RESIDUAL_CBAR_TICKS = [0, 0.1]


def _plot_svd_row(axes, panel, pop, labelpad, data_cbar_ticks=None):
    """Draw the Data / SVD / Residual / Residual plots of one sub-panel.

    Parameters
    ----------
    axes : mapping
        The ``figure.axes`` mapping; the axes ``panel + '11'``,
        ``panel + '12'``, ``panel + '13'`` and ``panel + '2'`` are drawn into.
    panel : str
        Sub-panel key such as ``'A1'`` or ``'B2'``.  Keys containing ``'2'``
        get 'Contrast' x labels; all other keys get column titles and have
        their x tick labels removed.
    pop : array
        Response map shown with ``imshow`` and passed to ``svd_spatial``.
    labelpad : float
        Label padding of the data colour bar.
    data_cbar_ticks : list, optional
        If given, used as both ticks and tick labels of the data colour bar.
    """
    svd, err, _power, _g_z = svd_spatial(pop)
    ax_data = axes[panel + '11']
    ax_svd = axes[panel + '12']
    ax_res = axes[panel + '13']
    ax_res_own = axes[panel + '2']
    vmin, vmax = np.min(pop), np.max(pop)

    im_data = ax_data.imshow(pop, origin='lower')
    # The SVD map and the first residual use the data's colour limits, so
    # all three share the data colour bar drawn next to the residual.
    ax_svd.imshow(svd, origin='lower', vmin=vmin, vmax=vmax)
    ax_res.imshow(err, origin='lower', vmin=vmin, vmax=vmax, aspect='auto')
    # The second residual is autoscaled and gets its own colour bar.
    im_res = ax_res_own.imshow(err, origin='lower')

    cbar = plt.colorbar(im_data, ax=ax_res)
    if data_cbar_ticks is not None:
        cbar.set_ticks(data_cbar_ticks)
        cbar.set_ticklabels(data_cbar_ticks)
    cbar.set_label(label=FIRING_LABEL, labelpad=labelpad)
    cbar = plt.colorbar(im_res, ax=ax_res_own, label=FIRING_LABEL)
    cbar.set_ticks(RESIDUAL_CBAR_TICKS)

    for ax in (ax_svd, ax_res, ax_res_own):
        ax.set_yticklabels([])
    row_axes = (ax_data, ax_svd, ax_res, ax_res_own)
    if '2' in panel:
        for ax in row_axes:
            ax.set_xlabel('Contrast')
    else:
        for ax, title in zip(row_axes, ('Data', 'SVD', 'Residual', 'Residual')):
            ax.set_title(title)
            ax.set_xticklabels([])


def _label_panel_axes(axes, block, contrast_ticklabels):
    """Label the orientation axes and the contrast tick labels of ``block``.

    Sets the y label and y ticks of ``block + '111'`` and ``block + '211'``,
    and the x tick labels of every axes in sub-panel ``block + '2'``.  Must
    run after the x ticks have been set, so the labels map onto them.
    """
    for row in '12':
        ax = axes[block + row + '11']
        ax.set_ylabel(ORIENTATION_LABEL)
        ax.set_yticks(ORIENTATION_TICKS)
        ax.set_yticklabels(ORIENTATION_TICKLABELS)
    for key in ('211', '212', '213', '22'):
        axes[block + key].set_xticklabels(contrast_ticklabels)


set_mpl_pars()
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['axes.titlesize'] = 'x-small'
plt.rcParams['axes.titlepad'] = 3

###########
# layout
###########
layoutstring = 'AB'
layout = Frame(layoutstring, left=0.03, top=0.07, right=0.02, bot=0.15,
               h_space=0.03, w_space=0.015,
               panelpars={'A': {'left': 0.12, 'h_space': 0.07, 'right': 0.08},
                          'B': {'left': 0.12, 'h_space': 0.07, 'right': 0.05}})
# A and B each get sub-panels '1' and '2'.  Each of those is split into '1'
# (further split into '1', '2', '3': Data / SVD / Residual) and '2'
# (residual on its own colour scale).  Calls are issued in the same order
# as before, each with a fresh panelpars dict.
for block in 'AB':
    layout.panels[block].new_panels(
        '1\n2', panelpars={'1': {'w_space': 0.01}, '2': {'w_space': 0.01}})
for block in 'AB':
    for row in '12':
        row_frame = layout.panels[block].panels[row]
        row_frame.new_panels(
            '12', panelpars={'1': {'w_space': 0.01, 'right': 0.05,
                                   'width_ratio': 2.2}})
        row_frame.panels['1'].new_panels('123')

figure = Figure(layout, cm2inch(JNEUROSCIPARS['doublecolumn']), 2.0)
figure.annotate(fontweight='bold')

###########
# load data
###########
pkg_dir = Path(__file__).parent
data_dir = pkg_dir.joinpath('Data')
pop_e = pd.read_pickle(data_dir.joinpath('F2_population_contrast_exc.pkl'))
pop_i = pd.read_pickle(data_dir.joinpath('F2_population_contrast_inh.pkl'))

###########
# panel A
###########
pads = {'A1': 6.5, 'A2': 4}
for data, panel in zip([pop_e, pop_i], ['A1', 'A2']):
    pop = population_response(data)
    # Append the first row again and scale by 1000.  Presumably this closes
    # the orientation axis (row 12 is labelled 180 deg, row 0 is 0 deg) and
    # converts rates to Hz to match FIRING_LABEL -- confirm.
    pop = np.vstack([pop, pop[0, :]]) * 1000
    _plot_svd_row(figure.axes, panel, pop, pads[panel])

cons = contrast_log_scale(10, 8)
contrast_ticklabels = [0, cons[4].round(2), 1]
# Applied to every axes in the figure mapping, including panel B (drawn
# below), before any non-empty tick labels are set.
for ax in figure.axes.values():
    ax.set_xticks(CONTRAST_TICKS)
_label_panel_axes(figure.axes, 'A', contrast_ticklabels)

###########
# panel B
###########
NatE = loadmat(str(DATA_DIR.joinpath('S3_V1E_Sum_Fits')))['Sum_Fits'].T
NatI = loadmat(str(DATA_DIR.joinpath('S3_V1I_Sum_Fits')))['Sum_Fits'].T
pads = {'B1': 4, 'B2': 4.5}
data_cbar_ticks = {'B2': [8, 10, 12, 14]}
for pop, panel in zip([NatE, NatI], ['B1', 'B2']):
    _plot_svd_row(figure.axes, panel, pop, pads[panel],
                  data_cbar_ticks=data_cbar_ticks.get(panel))
_label_panel_axes(figure.axes, 'B', contrast_ticklabels)

# Row markers: 'E' next to the excitatory sub-panels, 'I' next to the
# inhibitory ones.
plt.annotate('E', [0.01, 0.74], xycoords='figure fraction', color=excitcol,
             fontweight='bold')
plt.annotate('I', [0.015, 0.31], xycoords='figure fraction', color=inhibcol,
             fontweight='bold')