Parcourir la source

gin commit from gb-laptop

New files: 26
born il y a 2 ans
Parent
commit
8b50e6cdc6

BIN
data/chr2LGN.mat


BIN
data/chr2V1.mat


BIN
data/lgn_sztun_df.pkl


+ 78 - 0
data/params_mouse.yaml

@@ -0,0 +1,78 @@
+name: "size_tuning"
+
+# Integrator settings
+grid:
+    nr: 7
+    nt: 1
+    dt:
+        unit: "ms"
+        value: 1.0
+    dr:
+        unit: "deg"
+        value: 1
+
+# Stimulus settings
+stimulus:
+    # index into all angular spatial frequencies
+    k_id: 1
+
+    w_id:  0
+    patch_diameter:
+           unit: "deg"
+           start: 0
+           stop: 75
+           count: 76
+
+# Neuron settings
+ganglion:
+    # DOG
+    A: 1
+    a:
+       unit: "deg"
+       #2.77
+       value: 3.33
+    # 0.05
+    B: 0.2
+    b:
+       unit: "deg"
+       #6.36
+       value: 7.67
+
+relay:
+    weight: &weight
+        start: 0.
+        stop: 1
+        count: 2
+
+    Krg:
+        w: 1.0
+        # Gaussian
+        A: 1
+        a:
+           unit: "deg"
+           value: 1
+    Krig:
+        w: 0.5
+
+        # Gaussian
+        A: -1
+        a:
+           unit: "deg"
+           value: 3
+
+    Krc_mix:
+        w: *weight
+
+        Krc_ex:
+          # Gaussian
+          A: 0.3
+          a:
+              unit: "deg"
+              value: 1
+
+        Krc_in:
+           # Gaussian
+           A: -0.6
+           a:
+               unit: "deg"
+               value: 6

BIN
data/pvV1.mat


BIN
data/pvdLGN.mat


BIN
data/rf_area_df.pkl


+ 381 - 0
data/sim_edog.py

@@ -0,0 +1,381 @@
+# code for simulating size tuning curves using the edog model
+
+# import
+import matplotlib.pyplot as plt
+from operator import itemgetter
+from edog.tools import*
+import spatint_utils
+from importlib import reload
+
+# reload modules
+reload(spatint_utils)
+
+# load params
+spatint_utils.plot_params()
+_, _, optocolor = spatint_utils.get_colors()
+
+
+class SimEdog:
+
+    def __init__(self, start=1, stop=41):
+        """Init class
+
+        Parameters
+        -----
+        start: int
+            starting width for inhibitory feedback kernel
+        stop: int
+            end width for inhibitory feedback kernel width + 1
+        """
+
+        self.start = start
+        self.stop = stop
+        self.step = 1
+        self.lowbounds = None
+        self.highbounds = None
+        self.sids = None
+        self.rfcsds = None
+        self.smallrateds = None
+        self.largerateds = None
+        self.siss = None
+        self.rfcsss = None
+        self.wavenumber = None
+        self.patch_diameter = None
+        self.w_rc_mix = None
+
+    def describe(self):
+        """Get parameter boundaries in which experimental results are qualitatively replicated.
+        Creates two dictionaries, one for the low boundaries, one for the high boundaries. Each
+        dict contains one key for each parameter (smallrateds: response to preferred size;
+        largerateds: response to largest stimulus; sids: surround suppression; rfcsds:
+        preferred size) plus one key containing the intersection of all parameter boundaries
+        (all)."""
+
+        if self.sids is None:
+            # run simulations if this has not been done
+            self.iterate()
+
+        # init boundary dicts
+        lowbounds = {}
+        highbounds = {}
+
+        # check for which kernel size CT feedback increases responses to optimal stimulus size
+        smallrateds_increaseis = (np.where(np.array(self.smallrateds) > 0))[0] + self.start
+        lowbounds['smallrateds'] = np.min(smallrateds_increaseis)
+        highbounds['smallrateds'] = np.max(smallrateds_increaseis)
+        print('Responses to optimal stimulus diameter inreases with feedback for  kernel '
+              'widths between %d and %d degree.'
+              % (lowbounds['smallrateds'], highbounds['smallrateds']))
+
+        # check for which kernel size CT feedback decreases responses to largest stimulus size
+        largerateds_decreaseis = (np.where(np.array(self.largerateds) < 0))[0] + self.start
+        lowbounds['largerateds'] = np.min(largerateds_decreaseis)
+        highbounds['largerateds'] = np.max(largerateds_decreaseis)
+        print('Responses to largest stimulus diameter decreases with feedback for kernel '
+              'widths between %d and %d degree.'
+              % (lowbounds['largerateds'], highbounds['largerateds']))
+
+        # check for which kernel size CT feedback increases surround suppression
+        sids_increaseis = (np.where(np.array(self.sids) > 0))[0] + self.start
+        lowbounds['sids'] = np.min(sids_increaseis)
+        highbounds['sids'] = np.max(sids_increaseis)
+        print('SI increases with feedback for kernel widths between %d and %d degree.'
+              % (lowbounds['sids'], highbounds['sids']))
+
+        # check for which kernel size CT feedback decreases receptive field center size
+        rfcsds_decreaseis = (np.where(np.array(self.rfcsds) < 0))[0] + self.start
+        lowbounds['rfcsds'] = np.min(rfcsds_decreaseis)
+        highbounds['rfcsds'] = np.max(rfcsds_decreaseis)
+        print('RFCS decreases with feedback for kernel widths between %d and %d degree.'
+              % (lowbounds['rfcsds'], highbounds['rfcsds']))
+
+        # check range for all paramters
+        all_lowi = max(lowbounds.items(), key=itemgetter(1))[0]
+        all_low = lowbounds[all_lowi]
+        all_highi = min(highbounds.items(), key=itemgetter(1))[0]
+        all_high = highbounds[all_highi]
+        lowbounds['all'] = all_low
+        highbounds['all'] = all_high
+        print('Qualitatively matching range between %d and %d degree.'
+              % (lowbounds['all'], highbounds['all']))
+
+        # store dicts as attributes
+        self.lowbounds = lowbounds
+        self.highbounds = highbounds
+
+    def iterate(self):
+        """Run size tuning simulation for range of inhibitory FB kernel widths specified for
+         conditions in which feedback is intact (weight: 1) and feedback is abolished (
+         weight: 0). For each stimulation store percent change in surround suppression (
+         sids); preferred size (rfcsds), response to preferred size (smallrateds),
+         and response to the largest stimulus size (largerateds), as well as absolute values
+         for suppression index (siss) and preferred size (rfcsss) in both conditions"""
+
+        # init lists to store values
+        sids = []
+        rfcsds = []
+        smallrateds = []
+        largerateds = []
+        siss = []
+        rfcsss = []
+
+        # create vector of inhibitory feedback widths
+        inhib_fbs = np.arange(self.start, self.stop, self.step)
+
+        # iterate over widths
+        for inhib_fb in inhib_fbs:
+
+            # run simulation
+            curves = self.simulate(inhib_fb=inhib_fb)
+
+            # compute target values
+            sid, rfcsd, smallrated, largerated, sis, rfcss = self.characterize(curves=curves)
+
+            # store values in list
+            sids.append(sid)
+            rfcsds.append(rfcsd)
+            smallrateds.append(smallrated)
+            largerateds.append(largerated)
+            siss.append(sis)
+            rfcsss.append(rfcss)
+
+        # store as attributes
+        self.sids = sids
+        self.rfcsds = rfcsds
+        self.smallrateds = smallrateds
+        self.largerateds = largerateds
+        self.siss = siss
+        self.rfcsss = rfcsss
+
+    def simulate(self, inhib_fb=None):
+        """Run the edog model and return the two tuning curves
+
+        Parameters
+        ----
+        inhib_fb: int
+            width of inhibitory feedback kernel
+        """
+
+        # get parameters
+        filename = "./to_publish/data/params_mouse.yaml"
+        params = parse_parameters(filename)
+
+        nt, nr, dt, dr = itemgetter("nt", "nr", "dt", "dr")(params["grid"])
+        k_id, w_id, patch_diameter = itemgetter("k_id", "w_id", "patch_diameter")(
+            params["stimulus"])
+        A_g, a_g, B_g, b_g = itemgetter("A", "a", "B", "b")(params["ganglion"])
+        w_rg, A_rg, a_rg = itemgetter("w", "A", "a")(params["relay"]["Krg"])
+        w_rig, A_rig, a_rig = itemgetter("w", "A", "a")(params["relay"]["Krig"])
+        w_rc_mix = itemgetter("w")(params["relay"]["Krc_mix"])
+        A_rc_mix_in, a_rc_mix_in = itemgetter("A", "a")(params["relay"]["Krc_mix"]["Krc_in"])
+        A_rc_mix_ex, a_rc_mix_ex = itemgetter("A", "a")(params["relay"]["Krc_mix"]["Krc_ex"])
+
+        if inhib_fb is not None:
+            # if input to inhib_fb is not none replace default value of inhibitory FB kernel
+            # width
+            a_rc_mix_in = inhib_fb * pq.deg
+
+        # initiate variables
+        tuning_curve = np.zeros([len(w_rc_mix), len(patch_diameter)])
+        cen_size = np.zeros(len(w_rc_mix))
+        supp_index = np.zeros(len(w_rc_mix))
+
+        tuning_curve[:] = np.nan
+        cen_size[:] = np.nan
+        supp_index[:] = np.nan
+
+        # run model
+        for i, w in enumerate(w_rc_mix):
+            network = create_spatial_network(nt=nt, nr=nr, dt=dt, dr=dr,
+                                             A_g=A_g, a_g=a_g, B_g=B_g, b_g=b_g,
+                                             w_rg=w_rg, A_rg=A_rg, a_rg=a_rg,
+                                             w_rig=w_rig, A_rig=A_rig, a_rig=a_rig,
+                                             w_rc_in=w, A_rc_in=A_rc_mix_in,
+                                             a_rc_in=a_rc_mix_in, w_rc_ex=w,
+                                             A_rc_ex=A_rc_mix_ex, a_rc_ex=a_rc_mix_ex)
+
+            angular_freq = network.integrator.temporal_angular_freqs[int(w_id)]
+            wavenumber = network.integrator.spatial_angular_freqs[int(k_id)]
+            spatiotemporal_tuning = spatiotemporal_size_tuning(network=network,
+                                                               angular_freq=angular_freq,
+                                                               wavenumber=wavenumber,
+                                                               patch_diameter=patch_diameter)
+
+            # store tuning curve, preferred size and suppression index
+            tuning_curve[i, :] = spatiotemporal_tuning[0, :]
+
+        # normalize by no feedback condition
+        curves = tuning_curve / tuning_curve[0].max()
+
+        # store anglular frequencies, FB weighs and patch diameters
+        self.wavenumber = network.integrator.spatial_angular_freqs[int(k_id)]
+        self.patch_diameter = patch_diameter
+        self.w_rc_mix = w_rc_mix
+
+        return curves
+
+    def characterize(self, curves):
+        """Compute differences between size tuning curves with and without feedback. Return
+         percent change in surround suppression (sid); preferred size (rfcsd), response to
+        preferred size (smallrated), and response to the largest stimulus size (largerated),
+        as well as absolute values for suppression index (sis) and preferred size (rfcss)
+
+        Parameters
+        ----
+        curves: ndarray
+            2D array containing size tuning curves in both conditions
+
+        Returns
+        ----
+        sid: float
+            percent change in suppression index
+        rfcsd: float
+            percent change in preferred size
+        smallrated: float
+            percent change in response to small stimulus
+        largerated: float
+            percent change in response to large stimulus
+        sis: list
+            suppression indices under both condtitions
+        rfcss: list
+            preferred sizes under both conditions
+        """
+
+        # get patch_diameters
+        patch_diameter = self.patch_diameter
+
+        # init lists to store target values
+        sis = []
+        rfcss = []
+        smallrates = []
+        largerates = []
+
+        # preferred size is defined by stimulus where an increase in one degree does not
+        # increase response by 0.05%. This is analogous to how visTRN size tuning is analyzed
+        perc_thres = 0.05
+
+        # iterate over the two conditions
+        for icurve in range(curves.shape[0]):
+
+            # get x and y
+            y = curves[1-icurve, :]  # start with control condition
+            x = patch_diameter
+
+            # compute RF center size: check where 1deg in increase of visual stim fails to
+            # increase resp by perc_thres (output 2 to deal with divide by 0 -> will set
+            # first increase to 100%)
+            rel_change = np.divide(y[1:], y[:-1], out=np.ones_like(y[1:])*2, where=y[:-1] != 0)
+
+            # compute percent change
+            perc_change = (rel_change - 1) * 100
+
+            # check where percent change is lower than perc_thres
+            max_resp = np.where(perc_change < perc_thres)[0]
+
+            # receptive field center size is either first max_resp or last stimulus size
+            if any(max_resp):
+                rfcs = max_resp[0]
+            else:
+                rfcs = x[-1]
+
+            if icurve == 0:
+                # store rfcs of control curve
+                cont_rfcs = rfcs
+
+            # compute surround size
+            rfss = int(x[-1].item())
+
+            # determine modeled firing rates
+            smallrate_cont = y[cont_rfcs]
+            smallrate = y[rfcs]
+            largerate = y[rfss]
+
+            # compute suppression index, if si is smaller 0 because of small increase for large
+            # stimuli, set it to 0
+            si = max(((smallrate - largerate) / smallrate), 0)
+
+            # store values
+            sis.append(si)
+            rfcss.append(rfcs)
+            smallrates.append(smallrate_cont)
+            largerates.append(largerate)
+
+        # compute percent change
+        sid = ((sis[0] / sis[1]) - 1) * 100
+        rfcsd = ((rfcss[0] / rfcss[1]) - 1) * 100
+        smallrated = ((smallrates[0] / smallrates[1]) - 1) * 100
+        largerated = ((largerates[0] / largerates[1]) - 1) * 100
+
+        return sid, rfcsd, smallrated, largerated, sis, rfcss
+
+    def plot(self, figsize=(6, 6), inhib_fb=None, quant=False, ax=None):
+        """Plot the two tuning curves for inhibitory feedback kernel width
+
+        Parameters
+        ----
+        figsize: tuple
+            figure size (width, height)
+        inhib_fb: int
+            inhibitory feedback kernel width
+        quant: bool
+            if true print info regarding
+        ax: mpl axis
+            axis for plotting
+        """
+
+        # simulate data
+        curves = self.simulate(inhib_fb=inhib_fb)
+
+        # get xvalues
+        patch_diameter = self.patch_diameter
+
+        # set colors
+        colors = [optocolor, 'k']
+
+        if ax is None:
+            # init fig if doesnt exist
+            fig, ax = plt.subplots(figsize=figsize)
+
+        # plot curves
+        for curve, color in zip(curves, colors):
+            ax.plot(patch_diameter, curve, '-', color=color)
+
+        # layout
+        ax.set_ylabel("Normalized response")
+        ax.set_xlabel("Diameter ($\degree$)")
+        ax.set_xticks((0, 25, 50, 75))
+        ax.set_xticklabels((0, 25, 50, 75))
+        ax.set_xlim([0, 75])
+        ax.set_ylim([0, 1.5])
+        ax.spines['bottom'].set_bounds(0, 75)
+
+        if quant:
+            # plot information about differences
+
+            # init variables
+            title = "Mixed feedback"
+            wavenumber = self.wavenumber
+
+            # get quant measures
+            sid, rfcsd, smallrated, largerated, sis, rfcss = self.characterize(curves=curves)
+
+            # insert model results
+            ax.text(50, 0.1, 'opto si = %0.2f' % sis[1])
+            ax.text(50, 0.14, 'control si = %0.2f' % sis[0])
+            ax.text(50, 0.18, 'perc change si = %0.2f' % sid)
+            ax.text(50, 0.22, 'opto rfcs = %0.2f' % rfcss[1])
+            ax.text(50, 0.26, 'control rfcs = %0.2f' % rfcss[0])
+            ax.text(50, 0.3, 'perc change rfcs = %0.2f' % rfcsd)
+            ax.text(50, 0.34, 'perc change smallrate = %0.2f' % smallrated)
+            ax.text(50, 0.38, 'perc change largerate = %0.2f' % largerated)
+            ax.text(50, 0.42, 'model')
+
+            # add legends and titles
+            ax.set_title(title)
+            fig.text(0.39, 0.97, "Patch grating ({})".format(round(wavenumber.item(), 2)),
+                     fontsize=12)
+
+            # plot preferred size
+            for rfcs, color in zip(rfcss[::-1], colors):
+                ax.axvline(rfcs, color=color)
+

BIN
data/trn_retino_df.pkl


BIN
data/trn_sztun_df.pkl


BIN
data/trn_sztun_ex_dict.pkl


BIN
data/trn_sztun_opto_df.pkl


BIN
data/trn_sztun_opto_ex_dict.pkl


BIN
data/v1_spat_int.mat


+ 395 - 0
figs/fig1.py

@@ -0,0 +1,395 @@
+import numpy as np 
+import matplotlib.pyplot as plt
+import matplotlib
+import pandas as pd
+idx = pd.IndexSlice
+import pickle as pkl
+import os
+import spatint_utils as su
+
+datapath = os.getcwd()+'/../data/'
+with open(datapath+'lgn_spat_profile.pkl','rb') as r:
+    lgn_spat_profile = pkl.load(r)
+    
+with open(datapath+'lgn_ori_examples.pkl','rb') as r:
+    lgn_ori_examples = pkl.load(r)
+    
+with open(datapath + 'v1_raw_muaRFs.pkl','rb') as r:
+    v1_raw_muaRFs = pkl.load(r)
+    
+with open(datapath + 'v1_muaRFs.pkl','rb') as r:
+    v1_muaRFs = pkl.load(r)
+    
+with open(datapath + 'lgn_raw_muaRFs.pkl','rb') as r:
+    lgn_raw_muaRFs = pkl.load(r)
+    
+with open(datapath + 'lgn_muaRFs.pkl','rb') as r:
+    lgn_muaRFs = pkl.load(r)
+
+def agg_overlap_bins(data, start, end, nbins, binsize, agg_op, boot_value, filtr = False):
+    """computes aggregate values for bins along a given axis, aggregate operation is defined by agg_op"""
+    
+    bins_data = []
+    print('Averaging and bootstrapping in overlapping bins')
+    for bin_start in np.linspace(start, end, nbins):
+        bin_data = data.loc[(data['d']>bin_start) & (data['d']<(bin_start+binsize))]
+        
+        if filtr != False and len(bin_data)<filtr:
+            continue
+        booties = []
+        for i in range(1000):
+            bt_bin = bin_data[boot_value][np.random.randint(0,len(bin_data),len(bin_data))]
+            booties.append(np.mean(bt_bin))
+        booties = np.sort(booties)
+        lower = booties[24]
+        upper = booties[974]
+        boot_std = np.std(booties)
+        bin_agg = bin_data.agg(agg_op)
+        bin_agg['lower'] = lower
+        bin_agg['upper'] = upper
+        bin_agg['boot'] = boot_std
+        bin_agg['d'] = bin_start
+        bins_data.append(bin_agg)
+    return pd.concat(bins_data, axis =1)
+
+
+#bin_mean = agg_overlap_bins(lgn_spat_profile,0,100,31,15,agg_op='mean',boot_value='gain',filtr=5)
+
+
+def plot_fold_change(ax=None, data=lgn_spat_profile):#, bin_mean=bin_mean):
+    """Plot mean fold change values of all units as scatter plot"""
+    
+    if ax is None:
+        fig, ax = plt.subplots()
+    
+    # Clip low and high values for better visibility
+    data.loc[data['gain']>2,'gain'] = 2.35
+    data.loc[data['gain']<-2,'gain'] = -2.35
+    
+    ax.plot(data['d'],data['gain'],c='k', alpha=0.7, linestyle='', fillstyle='none', clip_on = False, marker='o',mec='k',markersize=1)
+    examples = data.reindex([('Ntsr1Cre_2015_0080',3,4027),
+                             ('Ntsr1Cre_2018_0003',2,22,3),
+                            ('Ntsr1Cre_2015_0080',4,3049)])
+    ax.plot([-100,100],[0,0],c='k', lw=0.35)
+    ax.plot(bin_mean.loc['d'][:15]+7.5,bin_mean.loc['gain'][:15], lw=3,alpha=0.5, c='#1090cfff')
+    ax.plot(bin_mean.loc["d"][8:15]+7.5,bin_mean.loc['gain'][8:15], lw=3,alpha=1, c ='#1090cfff')
+    ax.set(xlim=[0,60], ylim=[-2.5,2.5])
+    ax.set_xlabel('Distance of dLGN RFs to V1 RFs at injection site $(^{\circ})$',labelpad=5)
+    ax.set_xticks([0,20,40,60])
+    ax.set_ylabel('Fold change',labelpad=0)
+    ax.minorticks_off()
+    ax.set_yticks(list(np.linspace(-2,2,5))+[-2.35,2.35])
+    ax.set_yticklabels([str("{0:.2f}".format(round(ytick,2)))for ytick in np.geomspace(0.25,4,5)]+['< 0.25','> 4.00'])
+    ax.spines['left'].set_bounds(-2,2)
+    ax.spines['right'].set_visible(False)
+    ax.spines['top'].set_visible(False)
+    
+def plot_modulation_histogram(ax = None, data = lgn_spat_profile):
+    
+    if ax is None:
+        fig = plt.figure()
+        fig.set_figheight(2.576)
+        fig.set_figwidth(4.342)
+        ax = fig.add_axes((0.15,0.25,0.75,0.75))
+        
+    
+    bins = np.linspace(0,60,13)
+    data = data.sort_values('d')
+    data['bin'] = data.apply(lambda x: np.searchsorted(bins, x['d'], side='right')-1, axis=1)
+
+    ratios = data.groupby('bin').apply(lambda x: pd.DataFrame({'sup' :[len(x[x['modlab']==-1])/len(x)],
+                                                      'fac' :[len(x[x['modlab']== 1])/len(x)],
+                                                     'snull':[len(x[(x['modlab']==0) & (x['gain']<0)])/len(x)],
+                                                     'fnull' :[len(x[(x['modlab']==0) & (x['gain']>0)])/len(x)]}))
+    ratios.index = ratios.index.set_names('dummy', level=1)
+    ratios = ratios.reset_index("dummy").drop(columns='dummy')
+    ratios['binpos'] = bins[ratios.index]+2.5
+
+
+    ax.bar(ratios['binpos'],-ratios['sup'], bottom = -ratios['snull'], color = 'blue', width=4)
+    ax.bar(ratios['binpos'], ratios['fac'], bottom = ratios['fnull'], color = 'orange',width = 4)
+    ax.bar(ratios['binpos'], -ratios['snull'], bottom = 0, color='lightskyblue', width=4)
+    ax.bar(ratios['binpos'], ratios['fnull'], bottom = 0, color = 'moccasin', width=4)
+    ax.set_xlabel('Distance of dLGN RF to \n mean V1 RF at injection site ($\degree$)')
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    ax.spines['left'].set_bounds(high=0.75,low=-0.75)
+    ax.plot([0,70],[0,0],c='k',lw=0.35)
+    ax.set(xlim=[2.5,60],ylim=[-0.9,0.9])
+    ax.set_yticks([-0.75,-0.25,0,0.25,0.75])
+    ax.set_yticklabels(['75','25','','25','75'])
+    ax.set_xticks([0,20,40,60])
+    ax.set_ylabel('Proportion (%)')
+    ax.yaxis.set_label_coords(-0.1,0.55)
+    ax.annotate('enhanced', xy=(1,0.15), xytext=(1,0.15),xycoords='data',rotation='vertical',color='darkorange')
+    ax.annotate('suppressed', xy=(1,-0.8), xytext=(1,-0.8),xycoords='data',rotation='vertical',color='blue')
+    
+def plot_ori_examples(ax = None, data = lgn_ori_examples):
+    
+    def compute_ori_curve(df,opto,xs):
+    
+        return su.sum_of_gaussians(xs,*df['tun_pars'][opto])
+    
+    xs = np.linspace(0,360,361)
+    ctrl_curves = data.apply(compute_ori_curve, axis=1, opto=0,xs=xs)
+    opto_curves = data.apply(compute_ori_curve, axis=1, opto=1,xs=xs)
+    
+    if ax is None:
+        gridspec_kw = {'left':0.175,'right':0.95,'bottom':0.225,'top':0.95,'hspace':0.1}
+        fig, axes = plt.subplots(1,3,figsize=(6.327,1.971), gridspec_kw = gridspec_kw) 
+    
+    for i,ax in enumerate(axes):
+        ax.plot(xs, ctrl_curves.iloc[i],c='k',ms=2)
+        ax.plot(xs, opto_curves.iloc[i],c='#1090cfff',ms=2)
+        ax.plot(np.linspace(0,330,12),data.iloc[i]['tun_mean'][:,0],mfc='k',ms=8,ls='',marker='.',mew=0)
+        ax.plot(np.linspace(0,330,12),data.iloc[i]['tun_mean'][:,1],mfc='#1090cfff',ms=8,ls='',marker='.',mew=0)
+        ax.errorbar(np.linspace(0,330,12), data.iloc[i]['tun_mean'][:,0], yerr=data.iloc[i]['tun_sem'][:,0],fmt='none', ecolor='k', elinewidth=1.5)
+        ax.errorbar(np.linspace(0,330,12), data.iloc[i]['tun_mean'][:,1], yerr=data.iloc[i]['tun_sem'][:,1],fmt='none', ecolor='#1090cfff', elinewidth=1.5)
+        ax.plot(np.linspace(0,330,12), 12*[data.iloc[i]['tun_spon_mean'][0]],c='k',lw=0.5)
+        ax.plot(np.linspace(0,330,12), 12*[data.iloc[i]['tun_spon_mean'][1]],c='#1090cfff',lw=0.5)
+        ax.spines['top'].set_visible(False)
+        ax.spines['right'].set_visible(False)
+        ax.spines['bottom'].set_visible(True)
+        ax.spines['bottom'].set_bounds(0,360)
+        xticks = [0,180,360]
+        ax.set_xticks(xticks)
+        ax.set_yticks([]) 
+        ax.set_xticklabels([])
+        
+    axes[0].spines['left'].set_visible(True)
+    axes[0].spines['bottom'].set_visible(True)
+    axes[0].set(ylim=[0,100])
+    axes[0].spines['left'].set_bounds(0,80)
+    axes[0].set_yticks([0,40,80])
+    
+    axes[1].set(ylim=[0,62.5])
+    axes[1].spines['left'].set_bounds(0,50)
+    axes[1].set_yticks([0,25,50])
+
+    axes[2].set(ylim=[0,62.5])
+    axes[2].spines['left'].set_bounds(0,50)
+    axes[2].set_yticks([0,25,50])
+
+    axes[0].set_ylabel('Firing rate (sp/s)')
+    axes[0].yaxis.set_label_coords(-0.33,0.4)
+    axes[0].set_xlabel('Direction ($\degree$)',labelpad=0)
+    axes[0].set_xticks([0,180,360])
+    axes[0].set_xticklabels(['0','180','360'])
+    axes[0].xaxis.set_tick_params(length=1)
+    
+    
+def plot_v1_raw_muaRFs(ax = None, data = v1_raw_muaRFs):
+    
+    if ax is None:
+        gridspec_kw = {'left':0.1,'right':0.975,'bottom':0.15,'top':0.9}
+        fig,ax = plt.subplots(3,1,gridspec_kw=gridspec_kw)
+        fig.set_figheight(1.984)
+        fig.set_figwidth(1.332)
+
+    cmap = matplotlib.cm.get_cmap('Greens')
+    cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
+                                                                [(0,'white'),
+                                                                (1,'green')])
+    for axi, (index, row) in zip(ax,data.iterrows()):
+
+        left = row['ti_axes'][:,0].min()-2.4
+        right = row['ti_axes'][:,0].max()+2.5
+        bottom = row['ti_axes'][:,1].min()-2.5
+        top = row['ti_axes'][:,1].max()+2.5
+        extent = (left,right,bottom,top)
+
+        axi.imshow(np.rot90(row['rfs'].mean(axis=(0,3))),cmap='gray', extent = extent)
+        plot_ellipse(row,axi,cmap, linewidth=1.5)
+        axi.axis('scaled')
+        axi.set(xlim=[left,right],ylim=[bottom,top])
+        axi.set_xticks([])
+        axi.set_yticks([])
+
+    else:
+        xticks = [left,right]
+        yticks = [bottom,top]
+        axi.set_xticks(xticks)
+        axi.set_yticks([])
+        ax[0].set_yticks(yticks)
+        axi.set_xticklabels([str("{0:.0f}".format(round(xtick,0))) +'$\degree$' for xtick in xticks])
+        ax[0].set_yticklabels([str("{0:.0f}".format(round(ytick,0))) +'$\degree$' for ytick in yticks],
+                            rotation=90)
+        plt.setp(ax[0].yaxis.get_majorticklabels(), va='center')
+        plt.setp(axi.xaxis.get_majorticklabels(), ha='center')
+        axi.tick_params(length=2)
+        ax[0].tick_params(axis='y', pad=0, length=2)
+        axi.tick_params(axis='x', pad=1)
+        
+def plot_v1_rf_ellipses(ax = None, data = v1_muaRFs):
+    
+    
+    if ax is None:
+        fig = plt.figure(figsize=(7,5))
+        ax = fig.add_axes((0.15,0.15,0.7,0.7))
+    
+    cmap = matplotlib.cm.get_cmap('Greens')
+
+    data = data.reset_index('ch')
+    
+    one_series = data.loc['Ntsr1Cre_2015_0080',2]
+    one_series = one_series[one_series['sigma_x_mix']<15]
+    maxchan,minchan = one_series['ch'].max(), one_series['ch'].min()
+    V1_x = data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['x_mix']
+    V1_y = data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['y_mix']
+
+    one_series.apply(lambda x: plot_ellipse(x,ax=ax,cmap=cmap,linewidth=1,maxchan=maxchan),axis=1)
+
+    ax.scatter(V1_x,V1_y,marker='+',s=100,lw=2,c='k',zorder=3)
+
+    ax.set(xlim=[-20,80],ylim=[-35,50])
+    ax.set_xlabel('Azimuth ($\degree$)'
+                 ,labelpad=0)
+    ax.set_ylabel('Elevation ($\degree$)'
+                 ,labelpad=0)
+
+    xticks = [tick.get_text() for tick in ax.get_xticklabels()]
+
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    
+def plot_lgn_raw_muaRFs(ax = None, data = lgn_raw_muaRFs):
+    
+    if ax is None:
+        gridspec_kw = {'left':0.1,'right':0.975,'bottom':0.1,'top':0.93}
+        fig,ax = plt.subplots(5,2,gridspec_kw=gridspec_kw)
+        fig.set_figheight(3.552)
+        fig.set_figwidth(2.190)
+        ax_column1 = ax[:,0]
+        ax_column2 = ax[:,1]
+    
+    cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
+                                                            [(0,'xkcd:off white'),
+                                                            (1,'xkcd:neon purple')])
+    cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
+                                                               [(0,'xkcd:off white'),
+                                                               (1,'xkcd:saffron')]) #deep tea
+    cols = {5:'#00555a',9:'blue'}
+
+    ser1 = data.loc['Ntsr1Cre_2015_0080',9].iloc[14:19]
+    ser2 = data.loc['Ntsr1Cre_2015_0080',5].iloc[7:12]
+
+    for axi, (index, row) in zip(ax_column1,ser1.iterrows()):
+
+        left = row['ti_axes'][:,0].min()-2.4
+        right = row['ti_axes'][:,0].max()+2.5
+        bottom = row['ti_axes'][:,1].min()-2.5
+        top = row['ti_axes'][:,1].max()+2.5
+        extent = (left,right,bottom,top)
+
+        axi.imshow(np.rot90(row['rfs'].mean(axis=(0,3))),cmap='gray', extent = extent)
+        plot_ellipse(row,axi,cmap, linewidth =1.5)
+        axi.axis('scaled')
+        axi.set(xlim=[left,right],ylim=[bottom,top])
+        axi.set_xticks([])
+        axi.set_yticks([])
+
+    else:
+        xticks = [left,right]
+        yticks = [bottom,top]
+        axi.set_xticks(xticks)
+        axi.set_yticks([])
+        ax_column1[0].set_yticks(yticks)
+        axi.set_xticklabels([str("{0:.0f}".format(round(xtick,0))) +'$\degree$' for xtick in xticks])
+        ax_column1[0].set_yticklabels([str("{0:.0f}".format(round(ytick,0))) +'$\degree$' for ytick in yticks],
+                            rotation=90)
+        plt.setp(ax_column1[0].yaxis.get_majorticklabels(), va='center')
+        plt.setp(axi.xaxis.get_majorticklabels(), ha='center')
+        ax_column1[0].tick_params(axis='y', pad=0)
+
+    for axi, (index, row) in zip(ax_column2,ser2.iterrows()):
+
+        left = row['ti_axes'][:,0].min()-2.4
+        right = row['ti_axes'][:,0].max()+2.5
+        bottom = row['ti_axes'][:,1].min()-2.5
+        top = row['ti_axes'][:,1].max()+2.5
+        extent = (left,right,bottom,top)
+
+        axi.imshow(np.rot90(row['rfs'].mean(axis=(0,3))),cmap='gray', extent = extent)
+        plot_ellipse(row,axi,cmap2, linewidth=1.5)
+        axi.axis('scaled')
+        axi.set(xlim=[left,right],ylim=[bottom,top])
+
+        axi.set_xticks([])
+        axi.set_yticks([])
+
+    else:
+        xticks = [left,right]
+        yticks = [bottom,top]
+        axi.set_xticks(xticks)
+        axi.set_yticks([])
+        ax_column2[0].set_yticks(yticks)
+        axi.set_xticklabels([str("{0:.0f}".format(round(xtick,0))) +'$\degree$' for xtick in xticks]
+            )
+        ax_column2[0].set_yticklabels([str("{0:.0f}".format(round(ytick,0))) +'$\degree$' for ytick in yticks],
+                            rotation=90)
+        plt.setp(ax_column2[0].yaxis.get_majorticklabels(), va='center')
+        plt.setp(axi.xaxis.get_majorticklabels(), ha='center')
+        ax_column2[0].tick_params(axis='y', pad=0)
+        
+def plot_lgn_rf_ellipses(ax = None, data = lgn_muaRFs, v1data=v1_muaRFs):
+    
+    if ax is None:
+        fig = plt.figure(figsize=(2.995,3.531))
+        ax = fig.add_axes((0.275,0.2,0.65,0.725))
+
+    cmap1 = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
+                                                                [(0,'xkcd:neon purple'),
+                                                                (1,'xkcd:neon purple')])
+    cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
+                                                               [(0,'xkcd:saffron'),
+                                                               (1,'xkcd:saffron')]) 
+    cmaps = {9:cmap1,5:cmap2}
+    two_series = data.loc[idx['Ntsr1Cre_2015_0080',(5,9),:]]
+    V1_x = v1data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['x_mix']
+    V1_y = v1data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['y_mix']
+
+
+    two_series.groupby('s').apply(lambda x: plot_RF_ellipses(x, ax, cmaps[x.index.unique(level=1)[0]],linewidth=0.5))
+
+
+    ax.scatter(V1_x,V1_y,marker='+',s=100,lw=2,c='k',zorder=3)
+    ax.set(xlim=[-5,70],ylim=[-25,60])
+    ax.set_xlabel('Azimuth ($\degree$)'
+                 ,labelpad=0)
+    ax.set_ylabel('Elevation ($\degree$)'
+                 ,labelpad=0)
+    xticks = [tick.get_text() for tick in ax.get_xticklabels()]
+    xticks = [0,40,80]
+    yticks=[-20,0,30,60]
+    xticklabels = [str(xtick) for xtick in xticks]
+    yticklabels = [str(ytick) for ytick in yticks]
+
+    ax.yaxis.set_tick_params(pad=0)
+
+    ax.set_xticks(xticks)
+    ax.set_xticklabels(xticklabels)
+    ax.set_yticks(yticks)
+    ax.set_yticklabels(yticklabels)
+
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    ax.spines['left'].set_bounds(-20,60)
+    ax.spines['bottom'].set_bounds(0,80)
+
+    norm= matplotlib.colors.Normalize(vmin=0.5,vmax=1)
+        
+def plot_RF_ellipses(data, ax, cmap,linewidth=1):
+    """Plotting RF ellipses from fitted parameters."""
+    data.apply(lambda x: plot_ellipse(x,ax,cmap,linewidth=linewidth),axis=1)
+    
+def plot_ellipse(data, ax, cmap, linewidth=1, maxchan=60):
+    """Plot from single df row"""
+    ddiff = su.degdiff(180,(data['theta_mix']*180/np.pi),180)
+    params = data[['x_mix',
+                   'y_mix',
+                   'sigma_x_mix',
+                   'sigma_y_mix']]
+    x_ellipse, y_ellipse = su.calculate_ellipse(*params,ddiff)
+    col = cmap(data['rsq_mix'])
+    ax.plot(x_ellipse,y_ellipse, lw=linewidth,c=col)
+    ax.axis('scaled')

+ 302 - 0
figs/fig2.py

@@ -0,0 +1,302 @@
+# code for Figure 2 panels
+
+# import
+from matplotlib import pyplot as plt
+from matplotlib.ticker import ScalarFormatter, NullFormatter
+import numpy as np
+import pandas
+from scipy import stats
+from importlib import reload
+import spatint_utils
+
+# reload
+reload(spatint_utils)
+
+spatint_utils.plot_params()
+_, _, optocolor = spatint_utils.get_colors()
+
+
+class Fig2:
+    """Class to for plotting panels for Fig.2"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read lgn size tuning dataframe
+        self.lgn_sztun_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/lgn_sztun_df.pkl')
+
+    def ex_sztun_curve(self, figsize=(2.5, 2.5), ax=None):
+        """Plot example dLGN size-tuning curve (Fig. 2h)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # create figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # define example index
+        ex_idx = 22
+
+        # plot curves
+        spatint_utils.plot_tun(means=self.lgn_sztun_df.loc[ex_idx]['tun_mean'],
+                               sems=self.lgn_sztun_df.loc[ex_idx]['tun_sem'],
+                               spons=self.lgn_sztun_df.loc[ex_idx]['tun_spon_mean'],
+                               xs=self.lgn_sztun_df.loc[ex_idx]['ti_axes'],
+                               c_fit=self.lgn_sztun_df.loc[ex_idx]['c_sz_fit'],
+                               op_fit=self.lgn_sztun_df.loc[ex_idx]['op_sz_fit'],
+                               c_prefsz=self.lgn_sztun_df.loc[ex_idx]['rfcs_76'][0],
+                               op_prefsz=self.lgn_sztun_df.loc[ex_idx]['rfcs_76'][1],
+                               ax=ax)
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def fit_norm_curves(self, ax=None, figsize=(2.5, 2.5)):
+        """Plots normalized dLGN size-tuning curves (Fig. 2i)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # create figure if ax is none
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get ydata
+        yconts = self.lgn_sztun_df.c_sz_fit.to_list()
+        yoptos = self.lgn_sztun_df.op_sz_fit.to_list()
+
+        # normalize
+        ycont_norm = np.vstack([cont / np.nanmax(np.concatenate((cont, opto))) for cont,
+                                opto in zip(yconts, yoptos)])
+        yopto_norm = np.vstack([opto / np.nanmax(np.concatenate((cont, opto))) for cont,
+                               opto in zip(yconts, yoptos)])
+
+        # compute mean for both conditions
+        cont_mean = np.nanmean(ycont_norm, axis=0)
+        opto_mean = np.nanmean(yopto_norm, axis=0)
+
+        # compute sem for both conditions
+        cont_sem = stats.sem(ycont_norm, axis=0)
+        opto_sem = stats.sem(yopto_norm, axis=0)
+
+        # plot curves and sem
+        x_eval = np.arange(76)
+        ax.plot(x_eval, cont_mean, color='k', linestyle='-')
+        ax.fill_between(x_eval, cont_mean - cont_sem, cont_mean + cont_sem, color='k',
+                        alpha=0.5, linewidth=0)
+        ax.plot(x_eval, opto_mean, color=optocolor, linestyle='-')
+        ax.fill_between(x_eval, opto_mean - opto_sem, opto_mean + opto_sem, color=optocolor,
+                        alpha=0.5, linewidth=0)
+
+        # layout
+        ax.set_ylabel('Normalized firing rate')
+        ax.set_xlabel('Diameter ($\degree$)')
+        ax.set_xticks((0, 25, 50, 75))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['bottom'].set_bounds(0, 75)
+        ax.spines['left'].set_bounds(0, 1)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def scatter(self, figsize=(2.5, 2.5), ax=None, alys=None):
+        """Plot scatterplots to compare dLGN spatial integration with V1 intact vs suppressed (
+        Fig. 2j-m)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+        alys: string: 'ropt', 'rsupp', 'rfcs', 'si'
+            determines which parameter to analyze: modelled response to optimal stimulus (
+            ropt), modelled response to large stimulus (rsupp), preferred size (rfcs),
+            suppression index (si)
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # create figure if ax is none
+            f, ax = plt.subplots(figsize=figsize)
+
+        if alys == 'ropt':
+            # compare modelled response to optimal stimulus
+
+            # get data
+            cont = self.lgn_sztun_df['r_opt_c_76'].values
+            supp = self.lgn_sztun_df['r_opt_op_crfcs_76'].values
+
+            # compute statistics
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # set threshold for plotting
+            min_thres = 4
+            min_dist = 1
+            max_dist = 43
+            cont_plot = cont
+            cont_plot[cont_plot > 30] = max_dist
+            cont_plot[cont_plot < min_thres] = min_thres - min_dist
+            supp_plot = supp
+            supp_plot[supp_plot > 30] = max_dist
+            supp_plot[supp_plot < min_thres] = min_thres - min_dist
+
+            # layout
+            titlestr = 'Small size\nresponse (sp/s)'
+            ax.set_yscale('log')
+            ax.set_xscale('log')
+            for axis in [ax.xaxis, ax.yaxis]:
+                axis.set_major_formatter(ScalarFormatter())
+                axis.set_minor_formatter(NullFormatter())
+            ax.set_title(titlestr)
+            ax.set_xlim(min_thres - (min_dist * 2), max_dist)
+            ax.set_ylim(min_thres - (min_dist * 2), max_dist)
+            ax.plot((min_thres, 30), (min_thres, 30), linestyle='-', color='grey', zorder=-1,
+                    linewidth=0.35)
+            ax.spines['left'].set_bounds(min_thres, 30)
+            ax.spines['bottom'].set_bounds(min_thres, 30)
+            ax.set_xticks((min_thres - min_dist, 10, 30, max_dist))
+            ax.set_yticks((min_thres - min_dist, 10, 30, max_dist))
+            ax.set_yticklabels(('<' + str(min_thres), 10, 30, '>30'))
+            ax.set_xticklabels(('<' + str(min_thres), 10, 30, '   >30'))
+
+        elif alys == 'rsupp':
+            # compare modelled response to largest stimulus
+
+            # get data
+            cont = self.lgn_sztun_df['r_supp_201'].str[0].values
+            supp = self.lgn_sztun_df['r_supp_201'].str[1].values
+
+            # compute statistics
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # set threshold for plotting
+            min_thres = 4
+            min_dist = 1
+            max_dist = 43
+            cont_plot = cont
+            cont_plot[cont_plot > 30] = max_dist
+            cont_plot[cont_plot < min_thres] = min_thres - min_dist
+            supp_plot = supp
+            supp_plot[supp_plot > 30] = max_dist
+            supp_plot[supp_plot < min_thres] = min_thres - min_dist
+
+            # layout
+            titlestr = 'Large size\nresponse (sp/s)'
+            ax.set_yscale('log')
+            ax.set_xscale('log')
+            for axis in [ax.xaxis, ax.yaxis]:
+                axis.set_major_formatter(ScalarFormatter())
+                axis.set_minor_formatter(NullFormatter())
+            ax.set_title(titlestr)
+            ax.set_xlim(min_thres - (min_dist * 2), max_dist)
+            ax.set_ylim(min_thres - (min_dist * 2), max_dist)
+            ax.plot((min_thres, 30), (min_thres, 30), linestyle='-', color='grey', zorder=-1,
+                    linewidth=0.35)
+            ax.spines['left'].set_bounds(min_thres, 30)
+            ax.spines['bottom'].set_bounds(min_thres, 30)
+            ax.set_xticks((min_thres - min_dist, 10, 30, max_dist))
+            ax.set_yticks((min_thres - min_dist, 10, 30, max_dist))
+            ax.set_yticklabels(('<' + str(min_thres), 10, 30, '>30'))
+            ax.set_xticklabels(('<' + str(min_thres), 10, 30, '   >30'))
+
+        elif alys == 'rfcs':
+            # compare preferred size
+
+            # get data
+            cont = self.lgn_sztun_df['rfcs_76'].str[0].values
+            supp = self.lgn_sztun_df['rfcs_76'].str[1].values
+
+            # compute and calculate stats
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # set threshold for plotting
+            cont_plot = cont
+            cont_plot[cont_plot > 30] = 35
+            supp_plot = supp
+            supp_plot[supp_plot > 30] = 35
+
+            # layout
+            titlestr = 'Preferred size ($\degree$)'
+            ax.set_title(titlestr)
+            ax.set_xlim(-1.75, 35)
+            ax.set_ylim(-1.75, 35)
+            ax.plot((0, 30), (0, 30), linestyle='-', color='grey', zorder=-1, linewidth=0.35)
+            ax.spines['left'].set_bounds(0, 30)
+            ax.spines['bottom'].set_bounds(0, 30)
+            ax.set_xticks((0, 15, 30, 35))
+            ax.set_yticks((0, 15, 30, 35))
+            ax.set_yticklabels((0, 15, 30, '>30'))
+            ax.set_xticklabels((0, 15, 30, '   >30'))
+
+        elif alys == 'si':
+            # compare suppression indices
+
+            # get data
+            cont_plot = self.lgn_sztun_df['si_76'].str[0].values
+            supp_plot = self.lgn_sztun_df['si_76'].str[1].values
+
+            # compute and calculate stats
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont_plot, supp=supp_plot,
+                                                               alys=alys)
+
+            # layout
+            titlestr = 'SI'
+            ax.set_title(titlestr)
+            ax.set_xlim(-0.05, 1.05)
+            ax.set_ylim(-0.05, 1.05)
+            ax.set_xticks((0, 0.5, 1))
+            ax.set_yticks((0, 0.5, 1))
+            ax.plot((0, 1), (0, 1), linestyle='-', color='grey', linewidth=0.35, zorder=-1)
+            ax.spines['left'].set_bounds(0, 1)
+            ax.spines['bottom'].set_bounds(0, 1)
+
+        else:
+            print('No proper analysis selected')
+            return
+
+        # general layout
+        ax.set_title(titlestr)
+        ax.scatter(cont_plot, supp_plot, s=15, facecolors='none', edgecolors='k',
+                   linewidth=0.5, clip_on=False)
+        ax.plot(cont_mean, supp_mean, linestyle='', marker='.', color='goldenrod', ms=15)
+        # plot example
+        ax.plot(cont_plot[22], supp_plot[22], linestyle='', marker='.', color='deeppink',
+                ms=15)
+        ax.set_ylabel('V1 suppression')
+        ax.yaxis.label.set_color(optocolor)
+        ax.set_xlabel('Control')
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax

+ 358 - 0
figs/fig3.py

@@ -0,0 +1,358 @@
+# code to plot modelled data in figure 3
+
+# import libs
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.patches as patch
+import scipy.stats as stats
+from importlib import reload
+import spatint_utils
+import sim_edog
+
+# reload modules
+reload(spatint_utils)
+
+# load params
+spatint_utils.plot_params()
+_, _, optocolor = spatint_utils.get_colors()
+cmpin = 2.54
+
+
+class Fig3:
+    """Class for plotting Figure 3"""
+
+    def __init__(self, figsize=np.array((18.3, 8)), edog=None, inhib_fb=(1, 3, 9, 40)):
+        """Init class
+
+        Parameters
+        -----
+        figsize: tuple len 2
+            figuresize (width, heigth)
+        edog: class object
+            edog with simulated data
+        inhib_fb: tuple
+            example inhibitory feedback kernel widths to be plotted
+        """
+
+        # store init parameters
+        self.figsize = figsize / cmpin                     # figure size
+        self.sids = edog.sids                              # change in suppression indices
+        self.rfcsds = edog.rfcsds                          # change in preferred size
+        self.smallrateds = edog.smallrateds                # change in response to small stim
+        self.largerateds = edog.largerateds                # change in response to large stim
+        self.lowbounds = edog.lowbounds                    # dict with low bounds for matching
+        self.highbounds = edog.highbounds                  # dict with high bounds for matching
+        self.inhib_fb = inhib_fb                           # ex inhibitory fb kernel widths
+        self.start = edog.start                            # first inhib fb kernel width
+        self.stop = edog.stop                              # last inhib fb kernel width
+        self.inhib_fbi = np.array(inhib_fb) - edog.start   # index for example kernels
+
+    def plot(self):
+        """Plot figure"""
+
+        # init figure
+        f = plt.figure(figsize=self.figsize)
+
+        # init plotting variables
+        interplotspace = 0.5 / cmpin / self.figsize[0]
+        axdict = {}
+        axdict['b'] = 3 / cmpin / self.figsize[1]
+        axdict['h'] = 1.5 / cmpin / self.figsize[1]
+        axdict['w'] = 1.5 / cmpin / self.figsize[0]
+        l = 5 / cmpin / self.figsize[0]
+
+        # plot modelled size tuning curves
+        inhib_fbs = self.inhib_fb
+        inhib_fbi = self.inhib_fbi
+        start = self.start
+        stop = self.stop
+        edog_model = sim_edog.SimEdog()
+        sb_label = False
+
+        # get color for inhib kernels
+        cmap = plt.cm.get_cmap('Reds_r')
+        coli = np.linspace(0, 0.5, 4)
+
+        # loop over example inhibitory width
+        for i, inhib_fb in enumerate(inhib_fbs):
+
+            # add axis
+            ax = f.add_axes([l, axdict['b'], axdict['w'], axdict['h']])
+            # add curves
+            edog_model.plot(ax=ax, inhib_fb=inhib_fb)
+            # add kernel schemas
+            self._plot_kernel(inhib_fb=inhib_fb, l=l, f=f, col=cmap(coli[i]), i=i,
+                              sb_label=sb_label)
+
+            # label according to panel position
+            if i > 0:
+                # remove labels
+                ax.set_ylabel('')
+                ax.set_yticklabels('')
+                ax.set_xlabel('')
+                ax.set_xticklabels('')
+
+            if i < (len(inhib_fbs) - 1):
+                # add length
+                l += interplotspace + axdict['w']
+
+            if i == (len(inhib_fbs) - 2):
+                # add sb_label
+                sb_label = True
+
+            elif i == (len(inhib_fbs) - 1):
+                ax.text(10, 0.35, 'FB weight = 1', color='k')
+                ax.text(10, 0.15, 'FB weight = 0', color=optocolor)
+
+        # get x values
+        x = np.arange(start, stop, 1)
+
+        # define markersize
+        ms = 8
+
+        # add change in receptive field center size
+        # prepare axis
+        interplotspace = 1.5 / cmpin / self.figsize[0]
+        axdict['h'] = 1.2 / cmpin / self.figsize[1]
+        axdict['w'] = 1.2 / cmpin / self.figsize[0]
+        l += axdict['w'] + interplotspace + 0.02
+        ax = f.add_axes([l, axdict['b'], axdict['w'], axdict['h']])
+
+        # add plot
+        ax.plot(x, self.rfcsds, c='k')
+
+        # edit layout
+        xlims = ax.get_xlim()
+        ax.hlines(0, xlims[0], xlims[1], colors='grey', linestyle='--', linewidth=0.5)
+        ax.set_xticks((1, 20, 40))
+        ax.set_yticks((0, -5, -10))
+        ax.spines['bottom'].set_bounds(1, 40)
+        ax.spines['left'].set_bounds(-10, 0)
+        ax.set_ylabel('$\Delta$ preferred\nsize (%)')
+        ax.set_xlabel('Inh FB kernel width ($\degree$)')
+
+        # add quality rectangles
+        ylims = ax.get_ylim()
+        bottom = ylims[0]
+        height = ylims[1] - bottom
+        width_qual = self.highbounds['all'] - self.lowbounds['all']
+
+        # overall
+        rect_all = patch.Rectangle((self.lowbounds['all'], bottom), width_qual, height,
+                                   facecolor='gold', alpha=0.7, zorder=0)
+        ax.add_patch(rect_all)
+        # preferred size specific
+        rfcs_width = self.highbounds['rfcsds'] - self.lowbounds['rfcsds']
+        rect_rfcs = patch.Rectangle((self.lowbounds['rfcsds'], bottom), rfcs_width, height,
+                                    facecolor='gold', alpha=0.2, zorder=0)
+        ax.add_patch(rect_rfcs)
+
+        # add example points
+        ax.scatter(inhib_fbs, np.array(self.rfcsds)[inhib_fbi], c=np.array(cmap(coli)), s=ms,
+                   zorder=3)
+
+        # add changes in suppression index
+        # prepare axis
+        l += axdict['w'] + interplotspace
+        ax = f.add_axes([l, axdict['b'], axdict['w'], axdict['h']])
+
+        # add plot
+        ax.plot(x, self.sids, c='k')
+
+        # edit layout
+        xlims = ax.get_xlim()
+        ax.hlines(0, xlims[0], xlims[1], colors='grey', linestyle='--', linewidth=0.5)
+        ax.set_xticks((1, 20, 40))
+        ax.set_yticks((0, 100, 200))
+        ax.spines['bottom'].set_bounds(1, 40)
+        ax.spines['left'].set_bounds(0, 200)
+        ax.set_ylabel('$\Delta$ suppression\nindex (%)')
+
+        # add quality rectangle
+        ylims = ax.get_ylim()
+        bottom = ylims[0]
+        height = ylims[1] - bottom
+
+        # overall
+        rect_all = patch.Rectangle((self.lowbounds['all'], bottom), width_qual, height,
+                                   facecolor='gold', alpha=0.7, zorder=0)
+        ax.add_patch(rect_all)
+        # si specific
+        si_width = self.highbounds['sids'] - self.lowbounds['sids']
+        rect_si = patch.Rectangle((self.lowbounds['sids'], bottom), si_width, height,
+                                  facecolor='gold', alpha=0.2, zorder=0)
+        ax.add_patch(rect_si)
+
+        # add example points
+        ax.scatter(inhib_fbs, np.array(self.sids)[inhib_fbi], c=np.array(cmap(coli)), s=ms,
+                   zorder=3)
+
+        # add changes in smallrates
+        # prepare axis
+        axdict['b'] += axdict['h'] + (1.5 / cmpin / self.figsize[0])
+        l -= axdict['w'] + interplotspace
+        ax = f.add_axes([l, axdict['b'], axdict['w'], axdict['h']])
+
+        # add plot
+        ax.plot(x, self.smallrateds, c='k')
+
+        # layout
+        xlims = ax.get_xlim()
+        ax.hlines(0, xlims[0], xlims[1], colors='grey', linestyle='--', linewidth=0.5)
+        ax.set_xticks((1, 20, 40))
+        ax.set_xticklabels('')
+        ax.set_yticks((-25, 0, 25))
+        ax.spines['bottom'].set_bounds(1, 40)
+        ylims = ax.get_ylim()
+        ax.spines['left'].set_bounds(-25, ylims[1])
+        ax.set_ylabel('$\Delta$ response to\npreferred size (%)')
+
+        # add quality rectangle
+        ylims = ax.get_ylim()
+        bottom = ylims[0]
+        height = ylims[1] - bottom
+        # overall
+        rect_all = patch.Rectangle((self.lowbounds['all'], bottom), width_qual, height,
+                                   facecolor='gold', alpha=0.7, zorder=0)
+        ax.add_patch(rect_all)
+        # smallrate specific
+        smallrate_width = self.highbounds['smallrateds'] - self.lowbounds['smallrateds']
+        rect_smallrate = patch.Rectangle((self.lowbounds['smallrateds'], bottom),
+                                         smallrate_width, height, facecolor='gold',
+                                         alpha=0.2, zorder=0)
+        ax.add_patch(rect_smallrate)
+        # add example points
+        ax.scatter(inhib_fbs, np.array(self.smallrateds)[inhib_fbi], c=np.array(cmap(coli)),
+                   s=ms, zorder=3)
+
+        # add changes in large rates
+        # add axis
+        l += axdict['w'] + interplotspace
+        ax = f.add_axes([l, axdict['b'], axdict['w'], axdict['h']])
+
+        # add plot
+        ax.plot(x, self.largerateds, c='k')
+
+        # edit layout
+        xlims = ax.get_xlim()
+        ax.hlines(0, xlims[0], xlims[1], colors='grey', linestyle='--', linewidth=0.5)
+        ax.set_xticks((1, 20, 40))
+        ax.set_xticklabels('')
+        ax.set_yticks((-20, 0))
+        ax.spines['bottom'].set_bounds(1, 40)
+        ylims = ax.get_ylim()
+        ax.spines['left'].set_bounds(ylims)
+        ax.set_ylabel('$\Delta$ response to\n largest size (%)')
+
+        # add quality rectangle
+        ylims = ax.get_ylim()
+        bottom = ylims[0]
+        height = ylims[1] - bottom
+        # overall
+        rect_all = patch.Rectangle((self.lowbounds['all'], bottom), width_qual, height,
+                                   facecolor='gold', alpha=0.7, zorder=0)
+        ax.add_patch(rect_all)
+        # largerate specific
+        largerate_width = self.highbounds['largerateds'] - self.lowbounds['largerateds']
+        rect_largerate = patch.Rectangle((self.lowbounds['largerateds'], bottom),
+                                         largerate_width, height, facecolor='gold', alpha=0.2,
+                                         zorder=0)
+        ax.add_patch(rect_largerate)
+        # add example points
+        ax.scatter(inhib_fbs, np.array(self.largerateds)[inhib_fbi], c=np.array(cmap(coli)),
+                   s=ms, zorder=3)
+
+    def _plot_kernel(self, inhib_fb, l, f, col, i, sb_label=True):
+        """plot schematic kernels
+
+        Parameters
+        -----
+        inhib_fb: int
+            widht of inhibitory feedback kernel
+        l: float
+            left edge of plot
+        f: mpl figure
+            figure
+        col: tuple
+            color for inhibitory kernel
+        i: int
+            panel index
+        sb_label: bool
+            if true plot scalebar
+        """
+
+        # define axis for inhibitory gaussian
+        b = 5.65 / cmpin / self.figsize[1]
+        h = 0.2 / cmpin / self.figsize[1]
+        w = 2 / cmpin / self.figsize[0]
+        l -= 0.2 / cmpin / self.figsize[0]
+        zz = 0.1
+        xlims = (-40, 40)
+
+        # compute inhibtory gaussian
+        x = np.arange(-30, 31, 1)
+        mu = 0
+        sigma_c = inhib_fb
+        amp_c = -0.6
+        normdist = stats.norm.pdf(x, mu, sigma_c)
+        y_inhib = normdist * amp_c
+
+        # plot inhibitory gaussian
+        ax = f.add_axes([l, b, w, h])
+        ax.plot(x,y_inhib, color=col, clip_on=False)
+        ax.set_ylim((0, zz))
+        ax.set_xlim(xlims)
+        plt.axis('off')
+
+        # add x axis label
+        if i == 0:
+            ax.text(-45, -0.1, 'inh', rotation='vertical', c='r')
+
+        # define axis for excitatory gaussian
+        b = 6 / cmpin / self.figsize[1]
+
+        # compute excitatory gaussian
+        mu = 0
+        sigma_c = 1
+        amp_c = 0.3
+        normdist = stats.norm.pdf(x, mu, sigma_c)
+        y_ex = normdist * amp_c
+
+        # plot excitatory gaussian
+        ax = f.add_axes([l, b, w, h])
+        ax.plot(x, y_ex, color='g', clip_on=False)
+        ax.set_ylim((0, zz))
+        ax.set_xlim(xlims)
+        plt.axis('off')
+
+        # add x axis label
+        if i == 0:
+            ax.text(-45, 0, 'exc', rotation='vertical', c='g')
+
+        # add scalebar
+        x1 = 20
+        x2 = 30
+        xdiff = np.abs(x2 - x1)
+        ax.plot((x1, x2), (0.13, 0.13), color='k', clip_on=False)
+
+        if sb_label:
+            scalebar = '%d$\degree$' % xdiff
+            ax.text(x1, 0.18, scalebar, color='k', clip_on=False)
+
+        # define axis for sum of gaussians
+        b = 5 / cmpin / self.figsize[1]
+
+        # compute sum of two gaussians
+        y_sum = y_ex + y_inhib
+
+        # plot sum of gaussians
+        ax = f.add_axes([l, b, w, h])
+        ax.plot(x, y_sum, color='k', clip_on=False)
+        ax.set_ylim((0, zz))
+        ax.set_xlim(xlims)
+        plt.axis('off')
+
+        # add x axis label
+        if i == 0:
+            ax.text(-45, 0, 'sum', rotation='vertical')
+            ax.text(-70, 0, 'CT feedback\n kernel ($K_{RC}$)', rotation='vertical')

+ 601 - 0
figs/fig4.py

@@ -0,0 +1,601 @@
+# code for Figure 4 panels
+
+# import libs
+import matplotlib
+from matplotlib import pyplot as plt
+from matplotlib.ticker import PercentFormatter
+import numpy as np
+import pandas
+import seaborn as sns
+from scipy import stats
+from importlib import reload
+import pickle
+import spatint_utils
+
+# reload modules
+reload(spatint_utils)
+
+# define variables
+trn_red, lgn_green, _ = spatint_utils.get_colors()
+spatint_utils.plot_params()
+
+
+class Fig4:
+    """Class to for plotting panels for Fig. 4"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read trn size tuning dataframe
+        self.trn_sztun_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/trn_sztun_df.pkl')
+
+        # read dict for trn_sztun_ex
+        with open('./to_publish/data/trn_sztun_ex_dict.pkl', 'rb') as f:
+            self.trn_sztun_ex_dict = pickle.load(f)
+
+        # read trn retinotopy dataframe
+        self.trn_retino_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/trn_retino_df.pkl')
+
+        # get data for trn retinotopy example series
+        self.trn_retino_ex = self.trn_retino_df[(self.trn_retino_df.m == 'BL6_2018_0003') &
+                                                (self.trn_retino_df.s == 7)]
+
+        # read trn/lgn rf area dataframe
+        self.rf_area_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/rf_area_df.pkl')
+
+        # read lgn size tuning dataframe
+        self.lgn_sztun_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/lgn_sztun_df.pkl')
+
+    def exrfs(self):
+        """Plot example receptive fields (Fig. 4d,h)
+
+        Returns
+        -------
+        axs: list
+            list with axes for example rfs
+        """
+
+        # define keys for example rfs
+        exrfs = [
+
+            # trn example rfs
+            {'m': 'BL6_2018_0003', 's': 7, 'e': 1, 'u': 25},
+            {'m': 'BL6_2018_0003', 's': 4, 'e': 1, 'u': 9},
+            {'m': 'PVCre_2018_0009', 's': 4, 'e': 10, 'u': 10},
+            {'m': 'BL6_2018_0003', 's': 2, 'e': 1, 'u': 17},
+            {'m': 'BL6_2018_0003', 's': 3, 'e': 1, 'u': 12},
+
+            # dlGN example rfs
+            {'m': 'Ntsr1Cre_2019_0003', 's': 4, 'e': 1, 'u': 7},
+            {'m': 'Ntsr1Cre_2018_0003', 's': 2, 'e': 1, 'u': 23},
+            {'m': 'Ntsr1Cre_2018_0003', 's': 2, 'e': 1, 'u': 29}]
+
+        # init figure
+        f, axs = plt.subplots(1, 8, figsize=(7, 1.5))
+
+        for i, (ax, exrf) in enumerate(zip(axs, exrfs)):
+
+            # get target row in df
+            ex_row = self.rf_area_df[(self.rf_area_df.m == exrf['m']) &
+                                     (self.rf_area_df.s == exrf['s']) &
+                                     (self.rf_area_df.e == exrf['e']) &
+                                     (self.rf_area_df.u == exrf['u'])]
+            if i == len(exrfs) - 1:
+                scalebar = {'dist': 5, 'width': 1, 'length': 20, 'col': 'w'}
+            else:
+                scalebar = {'width': None, 'dist': None, 'length': None, 'col': None}
+            # plot rfs
+            spatint_utils.plot_rf(means=ex_row.mean_fr.values[0],
+                                  stim_param_values=ex_row.ti_axes.values[0],
+                                  grat_width=ex_row.grat_width,
+                                  grat_height=ex_row.grat_height,
+                                  ax=ax,
+                                  scalebar=scalebar)
+            # add labels
+            if i == 0:
+                ax.set_ylabel('visTRN')
+            elif i == 5:
+                ax.set_ylabel('dLGN')
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        # describe two example rfs
+        self.desc_exrfs(exrfs)
+
+        return axs
+
+    def desc_exrfs(self, exrfs):
+        """Describe two example RFs for Fig 4d
+
+        Parameters
+        -----
+        exrfs: list
+            example keys
+        """
+
+        labels = ['small', 'large']
+
+        for i, label in zip(np.array([3, 4]), labels):
+            exrf = exrfs[i]
+            ex_area = self.rf_area_df[(self.rf_area_df.m == exrf['m']) &
+                                      (self.rf_area_df.s == exrf['s']) &
+                                      (self.rf_area_df.e == exrf['e']) &
+                                      (self.rf_area_df.u == exrf['u'])]
+            print('%s area= %0.3f \n'
+                  '%s rsq = %0.3f \n'
+                  % (label, ex_area.area, label, ex_area.add_rsq))
+
+    def retino_ex(self, figsize=(2.5, 2.5), axs=None):
+        """Plot RFs map of example trn recording (Fig. 4e)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        axs: list
+            two axes used for plot and colorbar
+
+        Returns
+        -------
+        axs: list
+            two axes with plot and colorbar
+        """
+
+        if axs is None:
+            # create figure
+            f = plt.figure(figsize=figsize)
+            axs = []
+
+            # axis for plot
+            l = 0.15
+            b = 0.15
+            w = 0.6
+            h = 0.6
+            axs.append(f.add_axes([l, b, w, h]))
+
+            # axis for colorbar
+            l = 0.8
+            w_cbar = 0.05
+            axs.append(f.add_axes([l, b, w_cbar, h]))
+
+        # print number of units in series
+        exseries = self.trn_retino_ex
+        print('number of units in example series:', len(exseries))
+
+        # define colors
+        vmin = 0.3
+        vmax = 0.95
+        colors = np.linspace(vmin, vmax, len(exseries))
+        mymap = plt.cm.get_cmap("Reds")
+        my_colors = mymap(colors)
+
+        # iterate over units
+        for urowi, (_, urow) in enumerate(exseries.iterrows()):
+
+            # compute parameters
+            params = urow.params
+            angle = params[2] * 180 / np.pi
+            fitpars_deg = spatint_utils.degdiff(180, angle, 180)
+
+            # calculate ellipse
+            x, y = spatint_utils.calculate_ellipse(params[0], params[1], params[4], params[5],
+                                                   fitpars_deg)
+            # plot
+            axs[0].plot(x, y, c=my_colors[urowi])
+
+        # layout
+        axs[0].set_xlim(-20, 70)
+        axs[0].set_ylim(-20, 70)
+        axs[0].set_xticks((0, 30, 60))
+        axs[0].set_yticks((0, 30, 60))
+        axs[0].spines['bottom'].set_bounds(-10, 70)
+        axs[0].spines['left'].set_bounds(-10, 70)
+        axs[0].set_ylabel('Elevation ($\degree$)')
+        axs[0].set_xlabel('Azimuth ($\degree$)')
+
+        # draw colorbar
+        sm = matplotlib.colors.LinearSegmentedColormap.from_list('Reds', my_colors)
+        vmin_plot = 0.15
+        vmax_plot = 1
+        norm = matplotlib.colors.Normalize(vmin=vmin_plot, vmax=vmax_plot)
+        cbar = matplotlib.colorbar.ColorbarBase(axs[1], cmap=sm, norm=norm)
+        cbar.set_label('Depth (μm)', rotation=270)
+        cbar.ax.invert_yaxis()
+        barunits = (vmax - vmin) / (exseries.depth.iloc[-1] - exseries.depth.iloc[0])
+        ticklabels = np.array([3050, 3100, 3150])
+        ticks = ((ticklabels - exseries.depth.iloc[0]) * barunits) + vmin
+        cbar.set_ticks(ticks)
+        cbar.set_ticklabels(ticklabels)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return axs
+
+    def trn_retino(self, figsize=(6, 2.5), axs=None):
+        """Plot RFs map of example trn recording (Fig. 4f-g)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        axs: list
+            two axes, one per dimension (azim, elev)
+
+        Returns
+        -------
+        axs: list
+            two axes
+        """
+
+        if axs is None:
+            # create figure
+            f, axs = plt.subplots(1, 2, figsize=figsize)
+
+        # get data
+        trn_retino_df = self.trn_retino_df
+
+        # plot azimuth against depth
+        axs[0].scatter(trn_retino_df.azim, trn_retino_df.depth, facecolors='none',
+                       edgecolors='k', linewidth=0.5)
+        # print n
+        n = len(trn_retino_df.azim)
+        print('n azimuth = %d' % n)
+
+        # plot regression line from ancova model
+        xeval = np.arange(np.min(trn_retino_df.azim), np.max(trn_retino_df.azim), 1)
+        # parameters for the model copied from R
+        intercept = 3044.511585
+        vis_angle = -1.191659
+        azim = 60.728882  # for category 1: azim 0: elev
+        interaction = -1.838795
+        ymodel_azim = intercept + xeval * vis_angle + 1 * azim + 1 * xeval * interaction
+        axs[0].plot(xeval, ymodel_azim, 'r')
+
+        # layout
+        axs[0].set_xticks((0, 30, 60))
+        axs[0].set_yticks((2500, 2900, 3300))
+        axs[0].invert_yaxis()
+        axs[0].set_ylabel('Depth (μm)')
+        axs[0].set_xlabel('Azimuth ($\degree$)')
+        axs[0].spines['bottom'].set_bounds(-20, 70)
+        axs[0].spines['left'].set_bounds(2500, 3300)
+
+        # plot elevation against depth
+        axs[1].scatter(trn_retino_df.elev, trn_retino_df.depth, facecolors='none',
+                       edgecolors='k', linewidth=0.5)
+
+        # print n
+        n = len(trn_retino_df.elev)
+        print('n elev = %d' % n)
+
+        # plot regression line from model
+        xeval = np.arange(np.min(trn_retino_df.elev), np.max(trn_retino_df.elev), 1)
+        ymodel_elev = intercept + xeval * vis_angle + 0 * azim + 0 * xeval * interaction
+        axs[1].plot(xeval, ymodel_elev, 'r')
+
+        # layout
+        axs[1].set_xticks((0, 30, 60))
+        axs[1].set_yticks((2500, 2900, 3300))
+        axs[1].invert_yaxis()
+        axs[1].set_ylabel('')
+        axs[1].set_yticklabels([])
+        axs[1].set_xlabel('Elevation ($\degree$)')
+        axs[1].spines['bottom'].set_bounds(-20, 70)
+        axs[1].spines['left'].set_bounds(2500, 3300)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return axs
+
+    def rf_area(self, figsize=(2, 2), ax=None):
+        """Create violin plot for comparison of TRN and LGN RF sizes (Fig. 4i)
+
+        Parameters
+        -------
+        figsize: tuple
+            figure size (width, height)
+        ax: instance of matplotlib.axes class
+            axis to use for plotting
+
+        Returns
+        -------
+        ax: mpl axis
+            Axis with plot
+        """
+
+        if ax is None:
+            # make figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        rf_area_df = self.rf_area_df
+
+        # split data
+        trn_area = rf_area_df['area'][rf_area_df['region'] == 'PGN'].array
+        lgn_area = rf_area_df['area'][rf_area_df['region'] == 'LGN'].array
+
+        # plot data
+        sns.violinplot(data=[np.log(trn_area), np.log(lgn_area)], palette=[trn_red,
+                       lgn_green], ax=ax, linewidth=1, inner=None)
+        # plot mean
+        ax.plot([0, 1], [np.log(trn_area.mean()), np.log(lgn_area.mean())], linestyle='',
+                c='k', marker='.')
+
+        # format plot
+        ylabels = np.array([10, 100, 1000])
+        ax.set_yticks(np.log(ylabels))
+        ax.set_yticklabels(ylabels)
+        ax.spines['right'].set_visible(False)
+        ax.spines['top'].set_visible(False)
+        ylims = ax.get_ylim()
+        ax.spines['bottom'].set_bounds(0, 1)
+        ax.spines['left'].set_bounds(ylims)
+        plt.gca().get_xticklabels()[0].set_color(trn_red)
+        plt.gca().get_xticklabels()[1].set_color(lgn_green)
+        ax.set_xticklabels(['visTRN', 'dLGN'])
+        ax.set_ylabel('RF area (deg$^2$)')
+        ax.grid(False)
+
+        # mannwhitneyUtest
+        u_stat, parea = stats.mannwhitneyu(trn_area, lgn_area)
+
+        # test for differences in variance
+        # (with center = median, levene's test = brown-forsythe test)
+        f_stat, pvar = stats.levene(trn_area, lgn_area, center='median')
+
+        # ratio
+        ratio = trn_area.mean() / lgn_area.mean()
+
+        # print stats
+        print('dispersion stats: Brown–Forsythe test\n'
+              'Fstat = %0.3f \n'
+              'pval = 10**%0.3f\n\n'
+              'central tendency stats\n'
+              'Ustat = %0.3f\n'
+              'pval area = 10**%0.3f \n'
+              'N area visTRN = %d \n'
+              'N area dLGN = %d\n'
+              'visTRN mean area  +- sem = %0.3f (+- %0.3f)\n'
+              'dLGN mean area  +- sem = %0.3f (+- %0.3f)\n'
+              'visTRN rfs are on average %0.3f x larger than dLGN rfs'
+              % (f_stat,
+                 np.log10(pvar),
+                 u_stat,
+                 np.log10(parea),
+                 len(trn_area),
+                 len(lgn_area),
+                 trn_area.mean(),
+                 stats.sem(trn_area),
+                 lgn_area.mean(),
+                 stats.sem(lgn_area),
+                 ratio))
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def norm_szcurves(self, figsize=(3, 3), ax=None, eval_range=76, thres=128, mark_ex=True,
+                      lw=0.5, xticks=(0, 25, 50, 75), colormap='Greys'):
+        """Plot normalized fitted size tuning curves for visTRN population (Fig. 4l)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        eval_range: int
+            Range over which to evaluate model
+        thres: int
+            Lower threshold for darkness of line
+        mark_ex: bool
+            If true plots example neuron in different color
+        lw: float
+            Linewidth
+        xticks: tuple
+            xticks
+        colormap: string
+            Colormap
+
+        Returns
+        -------
+        ax: mpl axis
+            Axis with plot
+        """
+
+        if ax is None:
+            # create figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # define range over which to evaluate model
+        x_eval = range(eval_range)
+
+        # define colormap
+        cmap = plt.cm.get_cmap(colormap)
+
+        # plot all tuning curves
+        for row in self.trn_sztun_df.itertuples():
+            # get y data
+            params = row.tun_pars
+            y = spatint_utils.rog_offset(x_eval, *params)
+            # subtract offset
+            y_sub = y - y[0]
+            # normalize
+            y_norm = y_sub / np.nanmax(y_sub)
+            # define color
+            si = row.si_76
+            col = int(np.round((1 - si) * 255))
+            col = np.max((col, thres))
+            # plot
+            ax.plot(x_eval, y_norm, c=cmap(col), lw=lw)
+
+        if mark_ex:
+            # plot example session in red
+            params = self.trn_sztun_ex_dict['tun_pars']
+            y = spatint_utils.rog_offset(x_eval, *params)
+            y_sub = y - y[0]
+            y_norm = y_sub / np.nanmax(y_sub)
+            ax.plot(x_eval, y_norm, c=trn_red, lw=lw)
+
+        # layout
+        ax.set_xlabel('Diameter ($\degree$)')
+        ax.set_ylabel('Normalized firing rate')
+        ax.set_xticks(xticks)
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['bottom'].set_bounds(0, 75)
+        ax.spines['left'].set_bounds(0, 1)
+
+        return ax
+
+    def ex_sztun_curve(self, figsize=(4, 2), axs=None):
+        """Plot example visTRN size-tuning curve and raster (Fig. 4jk)
+
+        Parameters
+        -------
+        figsize: tuple
+            figure size (width, height)
+        axs: instance of matplotlib.axes class
+            axis to use for plotting
+
+        Returns
+        -------
+        ax: mpl axis
+            Axis with plot
+        """
+
+        if axs is None:
+            # create figure
+            f, axs = plt.subplots(1, 2, figsize=figsize)
+
+        # get data
+        ex_data = self.trn_sztun_ex_dict
+
+        # plot raster
+        spatint_utils.plot_raster(raster=ex_data['rasters'][ex_data['u']],
+                                  tranges=ex_data['tranges'],
+                                  opto=ex_data['opto'],
+                                  ax=axs[0])
+
+        # plot curve
+        spatint_utils.plot_tun(means=ex_data['tun_mean'],
+                               sems=ex_data['tun_sem'],
+                               spons=ex_data['tun_spon_mean'],
+                               xs=ex_data['ti_axes'],
+                               params=ex_data['tun_pars'],
+                               ax=axs[1])
+
+        # format layout
+        axs[1].set_xticks((0, 25, 50, 75))
+        axs[1].set_yticks((0, 10, 20))
+        axs[1].spines['bottom'].set_bounds(0, 75)
+        f = plt.gcf()
+        f.tight_layout()
+
+        # print info for example cell
+        sz_ex_si = self.trn_sztun_ex_dict['si_76']
+        sz_ex_rfcs = self.trn_sztun_ex_dict['rfcs_76']
+
+        print('size tuning example cell: \n'
+              'Preferred size = %0.3f \n'
+              'SI = %0.3f'
+              % (sz_ex_rfcs, sz_ex_si))
+
+    def si(self, figsize=(2, 2), ax=None):
+        """Plot histogram of suppression indices for visTRN and dLGN population (Fig. 4m)
+
+        Parameters
+        -------
+        figsize: tuple
+            figure size (width, height)
+        ax: instance of matplotlib.axes class
+            axis to use for plotting
+
+        Returns
+        -------
+        ax: mpl axis
+            Axis with plot
+        """
+
+        if ax is None:
+            # create figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        trn_sis = self.trn_sztun_df.si_76.to_numpy()
+        lgn_sis = self.lgn_sztun_df.si_76.str[0].to_numpy()
+
+        # assert that no data is beyond limits
+        assert (~np.any(trn_sis < 0)) & (~np.any(trn_sis > 1)), 'Datapoints beyond bounds'
+        assert (~np.any(lgn_sis < 0)) & (~np.any(lgn_sis > 1)), 'Datapoints beyond bounds'
+
+        # plot
+        si_bins = np.arange(0, 1.05, 0.05)  # define bins
+        trn_red_trsp = (*trn_red[0:3], 0.5)  # define color for trn
+        lgn_green_trsp = (*lgn_green[0:3], 0.5)  # define color for dlgn
+        ax.hist(trn_sis, bins=si_bins, weights=np.ones(len(trn_sis)) / len(trn_sis), lw=0,
+                fc=trn_red_trsp)
+        ax.hist(lgn_sis, bins=si_bins, weights=np.ones(len(lgn_sis)) / len(lgn_sis), lw=0,
+                fc=lgn_green_trsp)
+
+        # layout
+        ax.yaxis.set_major_formatter(PercentFormatter(1))
+        ax.set_xlabel('Suppression index')
+        ax.set_ylabel('neurons')
+        ax.set_xticks((0, 0.5, 1))
+        ax.set_yticks((0, 0.25, 0.5))
+        ax.spines['bottom'].set_bounds(0, 1)
+        ax.set_xlim((0, 1))
+        f = plt.gcf()
+        f.tight_layout()
+
+        # compute and print stats for trn
+        n_trn = len(trn_sis)
+        si_mean_trn = np.mean(trn_sis)
+        si_sem_trn = stats.sem(trn_sis)
+        si_med_trn = np.median(trn_sis)
+        lower05_trn = (len(trn_sis[trn_sis < 0.05]) /
+                       len(trn_sis) * 100)  # percentage of cells with si smaller 0.05
+
+        print('trn size tuning population: \n'
+              'n = %d\n'
+              'mean si +/- sem = %0.3f (+- %0.3f) \n'
+              'median si = %0.3f \n'
+              '%0.3f percent of pgn cells have si < 0.05\n'
+              % (n_trn,
+                 si_mean_trn,
+                 si_sem_trn,
+                 si_med_trn,
+                 lower05_trn))
+
+        # compute and print stats for dlgn
+        n_lgn = len(lgn_sis)
+        si_mean_lgn = np.mean(lgn_sis)
+        si_sem_lgn = stats.sem(lgn_sis)
+        si_med_lgn = np.median(lgn_sis)
+
+        print('lgn size tuning population:\n'
+              'n = %d\n'
+              'mean si +/- sem = %0.3f (+- %0.3f)\n'
+              'median si = %0.3f\n'
+              % (n_lgn,
+                 si_mean_lgn,
+                 si_sem_lgn,
+                 si_med_lgn))
+
+        # compare the two
+        ustat_sis, p_sis = stats.mannwhitneyu(trn_sis, lgn_sis)
+
+        print('mannwhitneyu test to compare si in dlgn and trn:\n'
+              'Ustat: %0.3f\n'
+              'pvalue: 10**%0.3f\n'
+              % (ustat_sis, np.log10(p_sis)))
+
+        return ax

+ 758 - 0
figs/fig5.py

@@ -0,0 +1,758 @@
+# code for Figure 5 panels
+
+# import libs
+from matplotlib import pyplot as plt
+import numpy as np
+import pandas
+from scipy import stats
+from importlib import reload
+import pickle
+from scipy.optimize import curve_fit
+import spatint_utils
+
+
+# reload module
+reload(spatint_utils)
+
+spatint_utils.plot_params()
+_, _, optocolor = spatint_utils.get_colors()
+
+
+class Fig5:
+    """Class for plotting panels for Fig.5"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read trn size tuning dataframe
+        self.trn_sztun_opto_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/trn_sztun_opto_df.pkl')
+
+        # read dict for trn sz tuning example
+        with open('./to_publish/data/trn_sztun_opto_ex_dict.pkl', 'rb') as f:
+            self.trn_sztun_opto_ex_dict = pickle.load(f)
+
+    def ex_sztun_curve(self, figsize=(5, 2.5), ax=None):
+        """Plot example dLGN size-tuning raster plot and curves (Fig. 5bc)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # create figure
+            f = plt.figure(figsize=figsize)
+            axs = []
+            axs.append(f.add_axes([0.1, 0.2, 0.35, 0.6]))
+            axs.append(f.add_axes([0.6, 0.2, 0.35, 0.6]))
+
+        # get data for example trn neuron
+        ex_data = self.trn_sztun_opto_ex_dict
+
+        # plot raster
+        spatint_utils.plot_raster(raster=ex_data['rasters'][ex_data['u']],
+                                  tranges=ex_data['tranges'],
+                                  opto=ex_data['opto'],
+                                  opto_ranges=ex_data['opto_ranges'],
+                                  ax=axs[0])
+
+        # plot curves
+        spatint_utils.plot_tun(means=ex_data['tun_mean'],
+                               sems=ex_data['tun_sem'],
+                               spons=ex_data['tun_spon_mean'],
+                               xs=ex_data['ti_axes'],
+                               params=ex_data['tun_pars'],
+                               ax=axs[1],
+                               sponline='-',
+                               sponlinewidth=0.5,
+                               ms=3.4)
+
+        # format layout
+        axs[1].set_xticks((0, 25, 50, 75))
+        axs[1].set_yticks((0, 25, 50))
+        axs[1].spines['bottom'].set_bounds(0, 75)
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def fit_norm_curves(self, ax=None, figsize=(2.5, 2.5)):
+        """Plot mean normalized model fit for size tuning in TRN (Fig. 5d)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # create figure if ax is none
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        trn_sztun_opto_df = self.trn_sztun_opto_df
+
+        # restrict to units fitted in both conditions
+        trn_sztun_opto_df = trn_sztun_opto_df[
+            trn_sztun_opto_df.tun_pars.apply(lambda x: any(~np.isnan(x[1])))]
+        trn_sztun_opto_df.reset_index(inplace=True)
+
+        # create array to store all generated curves in both conditions
+        x_eval = np.arange(76)
+        ys = np.zeros((len(trn_sztun_opto_df), len(x_eval), 2))
+        ys[:] = np.nan
+
+        # iterate over units
+        for bestexp in trn_sztun_opto_df.itertuples():
+            # get size tun parameters
+            params = bestexp.tun_pars
+
+            # evaluate model
+            ycont = spatint_utils.rog_offset(x_eval, *params[0])
+            yopto = spatint_utils.rog_offset(x_eval, *params[1])
+
+            # normalize by largest value
+            maxy = np.nanmax(np.concatenate((ycont, yopto)))
+            ycont_norm = ycont / maxy
+            yopto_norm = yopto / maxy
+
+            # store curves in array
+            ys[bestexp.Index, :, 0] = ycont_norm
+            ys[bestexp.Index, :, 1] = yopto_norm
+
+        # compute mean for both conditions
+        cont_mean = np.nanmean(ys[:, :, 0], axis=0)
+        opto_mean = np.nanmean(ys[:, :, 1], axis=0)
+
+        # plot curves and sem
+        ax.plot(x_eval, cont_mean, color='k', linestyle='-')
+        cont_sem = stats.sem(ys[:, :, 0], axis=0)
+        ax.fill_between(x_eval, cont_mean - cont_sem, cont_mean + cont_sem, color='k',
+                        alpha=0.5, linewidth=0)
+        ax.plot(x_eval, opto_mean, color=optocolor, linestyle='-')
+        opto_sem = stats.sem(ys[:, :, 1], axis=0)
+        ax.fill_between(x_eval, opto_mean - opto_sem, opto_mean + opto_sem,
+                        color=optocolor, alpha=0.5, linewidth=0)
+
+        # layout
+        ax.set_ylabel('Normalized firing rate')
+        ax.set_xlabel('Diameter ($\degree$)')
+        ax.set_xticks((0, 25, 50, 75))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['bottom'].set_bounds(0, 75)
+        ax.spines['left'].set_bounds(0, 1)
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def scatter(self, figsize=(2.5, 2.5), ax=None, alys=None):
+        """Plot scatterplots to compare features of size tuning curves with and without V1
+        suppression (Fig. 5e-h)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+        alys: string
+            analysis to compute: responsiveness (all_stims), burst ratio (bratio),
+            preferred size (rfcs), surround suppression (si)
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # create figure if ax is none
+            f, ax = plt.subplots(figsize=figsize)
+
+        # determine data set
+        if (alys == 'all_stims') or (alys == 'bratio'):
+            # for responsiveness and burst ratio we consider also units that fired little in
+            # opto condition or were not well fit
+            trn_sztun_opto_df = self.trn_sztun_opto_df
+
+        elif (alys == 'rfcs') or (alys == 'si'):
+            # to interpret preferred size  and si units must be fit in both conditions and
+            # must not be completely suppressed by V1 suppression
+            trn_sztun_opto_df = self.trn_sztun_opto_df[
+                (self.trn_sztun_opto_df.tun_rsq.apply(lambda x: ~np.isnan(x[1]))) &
+                (self.trn_sztun_opto_df.tun_mean.apply(lambda x: np.mean(x[1])) >= 0.1)]
+
+        # get index of example key for plotting
+        ex_row = trn_sztun_opto_df[
+                          (trn_sztun_opto_df['m'] == self.trn_sztun_opto_ex_dict['m']) &
+                          (trn_sztun_opto_df['s'] == self.trn_sztun_opto_ex_dict['s']) &
+                          (trn_sztun_opto_df['e'] == self.trn_sztun_opto_ex_dict['e']) &
+                          (trn_sztun_opto_df['u'] == self.trn_sztun_opto_ex_dict['u'])
+                          ].index.values
+        all_indic = trn_sztun_opto_df.index.values
+        ex_index = np.where(ex_row == all_indic)[0]
+
+        if alys == 'all_stims':
+            # compare mean response to all stimuli (panel e)
+
+            # get mean responses to all stims from data
+            fr_lst = trn_sztun_opto_df.tun_mean.tolist()
+            cont = np.asarray([np.nanmean(fr[:, 0]) for fr in fr_lst])
+            supp = np.asarray([np.nanmean(fr[:, 1]) for fr in fr_lst])
+
+            # get stats for example unit
+            ex_cont = cont[ex_index]
+            ex_supp = supp[ex_index]
+            ex_diff = ((ex_supp / ex_cont) - 1) * 100
+            print('example cell mean response: \n'
+                  'control: %.3f \n'
+                  'suppressed: %.3f \n'
+                  'perc change: %.3f\n'
+                  % (ex_cont, ex_supp, ex_diff))
+
+            # compute stats for population
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # layout
+            titlestr = 'Responsiveness (sp/s)'
+            ax.set_title(titlestr)
+            uplim = 85
+            assert not max(max(cont), max(supp)) > uplim, 'found resp > uplim'
+            ax.plot((0, uplim), (0, uplim), linestyle='-', color='grey', zorder=-1,
+                    linewidth=0.35)
+            ax.set_xlim(-3, uplim + 3)
+            ax.set_ylim(-3, uplim + 3)
+            ax.set_xticks((0, 40, 80))
+            ax.set_yticks((0, 40, 80))
+            xlims = ax.get_xlim()
+            ylims = ax.get_ylim()
+            ax.spines['left'].set_bounds(0, ylims[1])
+            ax.spines['bottom'].set_bounds(0, xlims[1])
+
+        elif alys == 'bratio':
+            # compare burst ratios (panel f)
+
+            # compute per condition
+            cont = trn_sztun_opto_df.bratio_c.apply(lambda x: np.nanmean(x))
+            supp = trn_sztun_opto_df.bratio_op.apply(lambda x: np.nanmean(x))
+
+            # compute and calculate stats
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # layout
+            titlestr = 'Burst ratio'
+            ax.set_title(titlestr)
+            ax.set_xlim(0.003, 1.05)
+            ax.set_ylim(0.003, 1.05)
+            ax.plot((0.0035, 1), (0.0035, 1), linestyle='-', color='grey', linewidth=0.35,
+                    zorder=-1)
+            ax.spines['left'].set_bounds(0.0035, 1)
+            ax.spines['bottom'].set_bounds(0.0035, 1)
+
+            # threshold for plotting
+            cont[cont == 0] = 0.0035
+            supp[supp == 0] = 0.0035
+            assert not max(max(cont), max(supp)) > 1, 'found bratio > 1'
+            assert not min(min(cont), min(supp)) < 0.0035, 'found bratio < 0.0035'
+
+            # print burst ratios for example neuron
+            ex_cont = cont[ex_index]
+            ex_supp = supp[ex_index]
+            print('example neuron burst ratio: \n'
+                  'control: %.3f \n'
+                  'suppressed: %.3f \n'
+                  % (ex_cont, ex_supp))
+
+        elif alys == 'rfcs':
+            # compare preferred size (panel g)
+
+            # get data
+            rfcs_int = np.vstack(trn_sztun_opto_df.rfcs_76.values)
+            cont = rfcs_int[:, 0]
+            supp = rfcs_int[:, 1]
+
+            # print preferred size for example neuron
+            ex_cont = rfcs_int[ex_index][0][0]
+            ex_supp = rfcs_int[ex_index][0][1]
+            print('example cell rfcs: \n'
+                  'control: %.3f \n'
+                  'suppressed: %.3f \n'
+                  % (ex_cont, ex_supp))
+
+            # compute stats
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # layout
+            titlestr = 'Preferred size'
+            ax.set_title(titlestr)
+            ax.set_xlim(-3, np.nanmax(cont) + 3)
+            ax.set_ylim(-3, np.nanmax(supp) + 3)
+            ax.plot((0, 75), (0, 75), linestyle='-', color='grey', zorder=-1, linewidth=0.35)
+            ax.spines['left'].set_bounds(0, 75)
+            ax.spines['bottom'].set_bounds(0, 75)
+
+        elif alys == 'si':
+            # compare suppression indices (panel h)
+
+            # get data
+            si_int = np.vstack(trn_sztun_opto_df.si_76.values)
+            assert not any(si_int[si_int > 1] or si_int[si_int < 0]), 'si > 1 or < 0, check!'
+            cont = si_int[:, 0]
+            supp = si_int[:, 1]
+
+            # get example si
+            ex_cont = si_int[ex_index][0][0]
+            ex_supp = si_int[ex_index][0][1]
+            print('example cell si: \n'
+                  'control: %.3f \n'
+                  'suppressed: %.3f \n'
+                  % (ex_cont, ex_supp))
+
+            # compute stats
+            cont_mean, supp_mean = spatint_utils.compute_stats(cont=cont, supp=supp, alys=alys)
+
+            # layout
+            titlestr = 'si'
+            ax.set_title(titlestr)
+            ax.set_xlim(-0.05, 1.05)
+            ax.set_ylim(-0.05, 1.05)
+            ax.set_xticks((0, 0.5, 1))
+            ax.set_yticks((0, 0.5, 1))
+            ax.plot((0, 1), (0, 1), linestyle='-', color='grey', linewidth=0.35, zorder=-1)
+            ax.spines['left'].set_bounds(0, 1)
+            ax.spines['bottom'].set_bounds(0, 1)
+
+        else:
+            print('No proper analysis selected')
+            return
+
+        # general layout
+        ax.set_ylabel('V1 suppression')
+        ax.yaxis.label.set_color(optocolor)
+        ax.set_xlabel('Control')
+
+        if alys == 'bratio':
+            # plot on log scale
+
+            ax.scatter(cont, supp, s=15, facecolors='none', edgecolors='k', linewidth=0.5)
+            ax.plot(cont_mean, supp_mean, linestyle='', marker='.', color='goldenrod', ms=15)
+            ax.set_yscale('log')
+            ax.set_xscale('log')
+            ax.set_yticks((0.01, 0.1, 1))
+            ax.set_xticks((0.01, 0.1, 1))
+            ax.set_xticklabels((0.01, 0.1, 1))
+            ax.set_yticklabels((0.01, 0.1, 1))
+
+        else:
+            ax.scatter(cont, supp, s=15, facecolors='none', edgecolors='k', linewidth=0.5)
+            ax.plot(cont_mean, supp_mean, linestyle='', marker='.', color='goldenrod', ms=15)
+
+        # plot example neuron
+        ax.plot(ex_cont, ex_supp, linestyle='', marker='.', color='deeppink', ms=15)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def resp_differences(self, figsize=(2.5, 2.5), ax=None, alpha=0.4, itnum=1000,
+                         sig_alpha=0.05, stepsize=1):
+        """Plots difference between modelled trn responses under V1 suppression and control
+        condition (Fig. 5i)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+        alpha: float
+            level of transparency
+        itnum: int
+            number of bootstraps
+        sig_alpha: float
+            significance level
+        stepsize: int
+            stepssize to evalute consecutive difference between conditions
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        bestexps = self.trn_sztun_opto_df
+
+        # filter out all units that were not fit in opto conditon
+        bestexps = bestexps[bestexps.tun_pars.apply(lambda x: any(~np.isnan(x[1])))]
+        bestexps.reset_index(inplace=True)
+
+        # define eval space
+        x_eval = np.arange(0, 76, 1)
+
+        # init lists to store rog diffs
+        diff_lst = []
+
+        for bestexp in bestexps.itertuples():
+            # loop over units
+
+            # evaluate rog models for both conditions
+            cont_fr = spatint_utils.rog_offset(x_eval, *bestexp.tun_pars[0])
+            opto_fr = spatint_utils.rog_offset(x_eval, *bestexp.tun_pars[1])
+
+            # normalize by max fr in control
+            cont_max = np.max(cont_fr)
+            cont_fr_norm = cont_fr / cont_max
+            opto_fr_norm = opto_fr / cont_max
+
+            # calculate difference
+            diff_fr = opto_fr_norm - cont_fr_norm
+
+            # collect fr diffs in list
+            diff_lst.append(diff_fr)
+
+            # plot single trace
+            ax.plot(x_eval, diff_fr, color='grey', alpha=alpha, linewidth=0.5)
+
+        # print number of plotted curves
+        print('plotted n = %d individual curves' % len(bestexps))
+
+        # plot mean diff_fr
+        cent_diff = np.mean(np.array(diff_lst), axis=0)
+        ax.plot(x_eval[0::stepsize], cent_diff[0::stepsize], linestyle='-', color='k')
+
+        # bootstrap rog diffs
+        # init list to collect bootstrapped diffs
+        bootdiffs_lst = []
+        # set seed
+        np.random.seed(0)
+        for _ in range(itnum):
+            # bootstrap rog differences
+            # get random indices
+            randidc = np.array(
+                [np.random.randint(0, len(diff_lst)) for x in range(len(diff_lst))])
+            # get random diffs
+            bootdiffs = [diff_lst[randidx] for randidx in randidc]
+            # take diffs between consecutive size steps
+            bootchanges = np.diff(np.array(bootdiffs), axis=1)
+            # compute mean
+            bootcent = np.mean(bootchanges, axis=0)
+            # only get each xth value
+            bootcent = bootcent[0::stepsize]
+            # take diff and store in list
+            bootdiffs_lst.append(bootcent)
+
+        # convert boot diff list to array
+        bootdiffs_arr = np.vstack(bootdiffs_lst)
+        # init list to store ci percentils for each point
+        percentils = []
+        for steps in range(bootdiffs_arr.shape[1]):
+            # loop over each sizestep and get 95% ci
+            percentils.append([np.percentile(bootdiffs_arr[:, steps], sig_alpha * 100 / 2),
+                               np.percentile(bootdiffs_arr[:, steps],
+                                             100 - (sig_alpha * 100 / 2))])
+        # check if slope is sig different from 0
+        sig_inc = [percentil[0] < percentil[1] < 0 for percentil in percentils]
+        sig_dec = [percentil[1] > percentil[0] > 0 for percentil in percentils]
+
+        # print report
+        nstimincreases = len(sig_inc)
+        nsig_dec = np.sum(sig_dec)
+        nsig_inc = np.sum(sig_inc)
+
+        # get indices of sig dec/inc
+        sig_dec_idc = np.where(sig_dec)[0]
+        sig_inc_idc = np.where(sig_inc)[0]
+
+        print(
+            '%d out of %d increases in diameter led to a significant increase in response'
+            ' reduction' % (nsig_inc, nstimincreases))
+        print(
+            '%d out of %d increases in diameter led to a significant decrease in response'
+            ' reduction' % (nsig_dec, nstimincreases))
+        print(
+            'Stim sizes for which reduction was sig greater than for the next smaller size:',
+            sig_inc_idc)
+        print(
+            'Stim sizes for which reduction was sig smaller than for the next smaller size:',
+            sig_dec_idc)
+
+        # layout
+        ax.set_ylabel('Normalized rate $\Delta$')
+        ax.set_xlabel('Diameter ($\degree$)')
+        ax.plot(x_eval[0::stepsize][sig_inc_idc], cent_diff[0::stepsize][sig_inc_idc],
+                linestyle='-', color='mediumseagreen')
+        ax.plot(x_eval[0::stepsize][sig_dec_idc], cent_diff[0::stepsize][sig_dec_idc],
+                linestyle='-', color='blue')
+        ax.set_yticks((0, -0.5, -1))
+        ax.set_xticks((0, 25, 50, 75))
+        ylims = ax.get_ylim()
+        ax.spines['left'].set_bounds(-1, ylims[1])
+        ax.spines['bottom'].set_bounds(0, 75)
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def threshlin_ex(self, figsize=(2.5, 2.5), ax=None):
+        """Plot threshold linear model fit to repsonses of example neuron (Fig. 5j)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        bestexps = self.trn_sztun_opto_df
+
+        # get data for example neuron
+        ex_cell = bestexps[
+                          (bestexps['m'] == self.trn_sztun_opto_ex_dict['m']) &
+                          (bestexps['s'] == self.trn_sztun_opto_ex_dict['s']) &
+                          (bestexps['e'] == self.trn_sztun_opto_ex_dict['e']) &
+                          (bestexps['u'] == self.trn_sztun_opto_ex_dict['u'])]
+
+        # evaluate rog models
+        xeval = range(76)
+        cont_rog = spatint_utils.rog_offset(xeval, *ex_cell.tun_pars.values[0][0])
+        opto_rog = spatint_utils.rog_offset(xeval, *ex_cell.tun_pars.values[0][1])
+
+        # normalize to max in control
+        cont_norm = cont_rog / np.nanmax(cont_rog)
+        opto_norm = opto_rog / np.nanmax(cont_rog)
+
+        # define start parameters for threshlin fit
+        startparams = (1, 0)
+
+        # fit threshold linear model
+        tlparams, _ = curve_fit(spatint_utils.threshlin, cont_norm, opto_norm, p0=startparams)
+
+        # evaluate threshlin model
+        xeval = np.arange(0, 1.01, 0.01)
+        tl_fit = spatint_utils.threshlin(xeval, *tlparams)
+
+        # compute threshold
+        thres = (tlparams[1] * -1) / tlparams[0]
+
+        # plot
+        ax.plot((0, 1), (0, 1), linestyle='-', color='grey', zorder=-1, linewidth=0.35)
+        ax.scatter(cont_norm, opto_norm, edgecolors='k', facecolors='none', linewidth=0.5)
+        ax.plot(xeval, tl_fit, color=optocolor)
+        ax.plot((thres, thres), (-0.2, 0.75), '--', color='k', clip_on=False, linewidth=0.5)
+
+        # label threshold for panel
+        ax.text(thres - 0.2, 0.8, s='Threshold', fontsize=4)
+
+        # layout
+        ax.set_xlim((-0.05, 1))
+        ax.set_ylim((-0.05, 1))
+        ax.set_xticks((0, 0.5, 1))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['bottom'].set_bounds(0, 1)
+        ax.spines['left'].set_bounds(0, 1)
+        ax.set_ylabel('V1 suppression', color=optocolor)
+        ax.set_xlabel('Control')
+        ax.set_title('Normalized\nRoG-model rates')
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def threshlin(self, figsize=(2.5, 2.5), ax=None, rsq_thres=0.8, outlier=1.1,
+                  xbounds=(-1, 1), ybounds=(-2, 2)):
+        """Plot threshold and slope parameter for trn population (Fig. 5k)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+        rsq_thres: float
+            Rsquared threshold for threshold linear model
+        outlier: float
+            factor for placing outliers at the boundary of plot
+        xbounds: tuple of len 2
+            plotting bounds for x axis
+        ybounds: tuple of len 2
+            plotting bounds for y axis
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        bestexps = self.trn_sztun_opto_df
+
+        # filter out all units that were not fit in opto conditon
+        bestexps = bestexps[bestexps.tun_pars.apply(lambda x: any(~np.isnan(x[1])))]
+        bestexps.reset_index(inplace=True)
+
+        # define eval space
+        x_eval = range(76)
+
+        # init list to store xintercepts, slopes, and rsqs
+        xinters = []
+        slopes = []
+        rsqs = []
+        bad_indices = []
+        badfits = 0
+
+        for bestexp in bestexps.itertuples():
+            # loop over units
+
+            # evaluate models
+            cont_rog = spatint_utils.rog_offset(x_eval, *bestexp.tun_pars[0])
+            opto_rog = spatint_utils.rog_offset(x_eval, *bestexp.tun_pars[1])
+
+            # normalize to max in control
+            cont_norm = cont_rog / np.nanmax(cont_rog)
+            opto_norm = opto_rog / np.nanmax(cont_rog)
+
+            # define start parameters for threshlin fit
+            startparams = (1, 0)
+
+            # fit threshold linear model
+            tlparams, _ = curve_fit(spatint_utils.threshlin, cont_norm, opto_norm,
+                                    p0=startparams)
+
+            # get rsq for fit
+            rsq = spatint_utils.rsquared(opto_norm, spatint_utils.threshlin(cont_norm,
+                                                                            *tlparams))
+
+            # compute and store xintercept and slope
+            if rsq < rsq_thres:
+                # nan if fit below threshold
+                badfits += 1
+                xinters.append(np.nan)
+                slopes.append(np.nan)
+                rsqs.append(np.nan)
+                bad_indices.append(bestexp.index)
+
+            else:
+                xinter = -tlparams[1] / tlparams[0]
+                xinters.append(xinter)
+                slopes.append(tlparams[0])
+                rsqs.append(rsq)
+
+        # transform lists to arrays
+        xinters = np.array(xinters)
+        slopes = np.array(slopes)
+        rsqs = np.array(rsqs)
+
+        # get index of example key for plotting
+        ex_row = bestexps[
+                          (bestexps['m'] == self.trn_sztun_opto_ex_dict['m']) &
+                          (bestexps['s'] == self.trn_sztun_opto_ex_dict['s']) &
+                          (bestexps['e'] == self.trn_sztun_opto_ex_dict['e']) &
+                          (bestexps['u'] == self.trn_sztun_opto_ex_dict['u'])
+                          ].index.values
+        all_indic = bestexps.index.values
+        ex_index = np.where(ex_row == all_indic)[0]
+        ex_slope = slopes[ex_index]
+        ex_xinter = xinters[ex_index]
+        ex_rsq = rsqs[ex_index]
+        print('threshold linear paramters example cell \n'
+              'slope:       %0.3f \n'
+              'threshold: %0.3f \n'
+              'rsq: %0.3f \n'
+              % (ex_slope, ex_xinter, ex_rsq))
+
+        # remove nans
+        xinters = xinters[~np.isnan(xinters)]
+        slopes = slopes[~np.isnan(slopes)]
+
+        # convert to log2
+        slopes_log2 = np.log2(slopes)
+
+        # compute stats
+        nsamples = len(slopes_log2)
+        W_xint, p_xint = stats.wilcoxon(xinters)
+        W_slopes, p_slopes = stats.wilcoxon(slopes_log2)
+        mean_slopes = 2 ** np.mean(slopes_log2)
+        mean_xint = np.mean(xinters)
+        sem_xint = stats.sem(xinters)
+        sem_slopes = 2 ** stats.sem(slopes_log2)
+
+        print('population stats: \n'
+              'n = %d \n'
+              'threshold mean +- sem = %.3f +- %.3f \n'
+              'W_thres = %0.3f \n'
+              'p_thres = %0.3f \n'
+              'slope mean +- sem = %.3f +- %.3f \n'
+              'W_slope = %0.3f \n'
+              'p_slope = 10 ** %0.3f \n'
+
+              % (nsamples, mean_xint, sem_xint, W_xint, p_xint, mean_slopes, sem_slopes,
+                 W_slopes, np.log10(p_slopes)))
+
+        # move outlier
+        xinters_plot = np.copy(xinters)
+        slopes_plot = np.copy(slopes_log2)
+        slopes_plot[np.where(slopes_plot > ybounds[1])[0]] = ybounds[1] * outlier
+        slopes_plot[np.where(slopes_plot < ybounds[0])[0]] = ybounds[0] * outlier
+
+        # plot
+        ax.plot(xbounds, [0, 0], linestyle='-', color='grey', zorder=-1, linewidth=0.35, ms=15)
+        ax.plot([0, 0], ybounds, linestyle='-', color='grey', zorder=-1, linewidth=0.35, ms=15)
+        # all datapoints
+        ax.scatter(xinters_plot, slopes_plot, facecolors='none', edgecolors='k', s=15,
+                   linewidth=0.5)
+        ax.plot(np.mean(xinters), np.mean(slopes_log2), '.', color='goldenrod', ms=15)  # mean
+        ax.plot(ex_xinter, np.log2(ex_slope), marker='.', color='deeppink', ms=15)  # example
+
+        # layout
+        ax.set_yticks([ybounds[0] * outlier, ybounds[0] / 2, 0, ybounds[1] / 2,
+                       ybounds[1] * outlier])
+        ax.set_xticks([-1, 0, 1])
+        ax.set_yticklabels([('<%.2f' % 2 ** ybounds[0]), 2 ** (ybounds[0] / 2), 2 ** 0,
+                            int(2 ** (ybounds[1] / 2)), ('>%d' % 2 ** ybounds[1])])
+        ax.spines['left'].set_bounds(ybounds[0], ybounds[1])
+        ax.spines['bottom'].set_bounds(xbounds[0], xbounds[1])
+        ax.set_ylabel('Slope')
+        ax.set_xlabel('Threshold')
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+

+ 462 - 0
figs/figS10.py

@@ -0,0 +1,462 @@
+# code for panels for Figure S10
+
+# import libs
+from matplotlib import pyplot as plt
+import numpy as np
+import pandas
+from scipy import stats
+from importlib import reload
+import spatint_utils
+
+# reload modules
+reload(spatint_utils)
+
+optocolor = spatint_utils.get_colors()
+
+
+class FigS10:
+    """Class to for plotting panels for Fig. S10"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read lgn size tuning dataframe
+        self.lgn_sztun_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/lgn_sztun_df.pkl')
+
+        # read trn size tuning dataframe
+        self.trn_sztun_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/trn_sztun_df.pkl')
+
+        # read trn/lgn rf area dataframe
+        self.rf_area_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/rf_area_df.pkl')
+
+    def drfstim_x_si(self, figsize=(2.5, 2.5), ax=None, ms=15, fs=6, lw=0.5):
+        """Plots for each unit normalized distance between stimulus center and
+         RF center against SI (Fig. S10a)
+
+        Parameters
+        ----------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        ms: float
+            Markersize
+        fs: float
+            Fontsize
+        lw: float
+            Linewidth
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        # init figure
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        dstimrf = self.trn_sztun_df.dstimrf
+        sis = self.trn_sztun_df.si_76
+
+        # compute linear regression
+        slope, intercept, r, p, _ = stats.linregress(dstimrf.astype('float'), sis.astype(
+            'float'))
+        xmodel = np.arange(0, 1.1, 0.1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+        rsq = r ** 2  # compute rsq
+
+        # plot
+        ax.scatter(dstimrf, sis, facecolors='none', edgecolors='k', s=ms, color='k', lw=lw)
+        ax.plot(xmodel, ymodel, color='k')
+
+        # layout
+        ax.set_ylabel('Suppression index')
+        ax.set_xlabel('Normalized distance between\nstimulus center and RF center')
+        xlim = (np.nanmin((np.nanmin(dstimrf), -0.05)), np.nanmax((np.nanmax(dstimrf), 1)))
+        ylim = (np.nanmin((np.nanmin(sis), -0.05)), np.nanmax((np.nanmax(sis), 1)))
+        ax.set_xlim(xlim)
+        ax.set_ylim(ylim)
+        ax.set_xticks((0, 0.5, 1))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['left'].set_bounds(0, 1)
+        ax.spines['bottom'].set_bounds(0, 1)
+
+        # print info to plot
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.9, 'R² = %.2e' % rsq, fontsize=fs)
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.84, 'b = %.2f' % slope,
+                fontsize=fs)
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.78, 'p = %.2f' % p, fontsize=fs)
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.72, 'n = %d' % len(sis),
+                fontsize=fs)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def drfmon_x_si(self, figsize=(2.5, 2.5), ax=None, ms=15, fs=6, lw=0.5):
+        """Plots for each unit distance of monitor center to RF center against SI (Fig. S10b)
+
+        Parameters
+        ----------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        ms: float
+            Markersize
+        fs: float
+            Fontsize
+        lw: float
+            Linewidth
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        # init figure
+        if ax is None:
+            f, ax = plt.subplots(figsize=figsize)
+
+        # get data
+        dmonrf = self.trn_sztun_df.dmonrf.values
+        sis = self.trn_sztun_df.si_76.values
+
+        # compute linear regression
+        slope, intercept, r, p, _ = stats.linregress(dmonrf.astype('float'), sis.astype(
+            'float'))
+        xmodel = np.arange(np.nanmin(dmonrf), np.nanmax(dmonrf), 1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+        rsq = r ** 2  # compute rsq
+
+        # plot
+        ax.scatter(dmonrf, sis, facecolors='none', edgecolors='k', s=ms, color='k', lw=lw)
+        ax.plot(xmodel, ymodel, color='k')
+
+        # layout
+        ax.set_ylabel('Suppression index')
+        ax.set_xlabel('Distance between\nmonitor center and RF center ($\degree$)')
+        xlim = (np.nanmin((np.nanmin(dmonrf), -2.5)), np.nanmax((np.nanmax(dmonrf) + 1, 1)))
+        ylim = (np.nanmin((np.nanmin(sis), -0.05)), np.nanmax((np.nanmax(sis), 1)))
+        ax.set_xlim(xlim)
+        ax.set_ylim(ylim)
+        ax.set_xticks((0, 20, 40))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['left'].set_bounds(0, 1)
+        ax.spines['bottom'].set_bounds(0, 40)
+
+        # print info to plot
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.9, 'R² = %.2e' % rsq, fontsize=fs)
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.84, 'b = %.2e' % slope,
+                fontsize=fs)
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.78, 'p = %.2f' % p, fontsize=fs)
+        ax.text(ax.get_xlim()[1] * 0.5, ax.get_ylim()[1] * 0.72, 'n = %d' % len(sis),
+                fontsize=fs)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def trn_rfarea_x_si(self, figsize=(2.5, 2.5), ax=None, fs=6, lw=0.5):
+        """Plots for each trn unit RF area against SI (Fig. S10c)
+
+        Parameters
+        ----------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        fs: float
+            Fontsize
+        lw: float
+            Linewidth
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        # init figure
+        if ax is None:
+            f, ax = plt.subplots(figsize=figsize)
+            f.tight_layout()
+
+        # get si data
+        sis = self.trn_sztun_df.si_201.values
+
+        # merge dataframes
+        dfmerge = pandas.merge(self.trn_sztun_df, self.rf_area_df[['m', 's', 'u', 'area']],
+                               on=['m', 's', 'u'], how='left')
+        rfarea = dfmerge.area.values
+
+        # restrict selection to exclude outliers
+        supp_idc = sis > 0.01
+        sis_restr = sis[supp_idc]
+        rfarea_restr = rfarea[supp_idc]
+
+        rfarea_idc = rfarea_restr < 2000
+        rfarea_restr = rfarea_restr[rfarea_idc]
+        sis_restr = sis_restr[rfarea_idc]
+
+        # plot entire selection
+        ax.scatter(rfarea, sis, facecolors='none', edgecolors='k', lw=lw)
+
+        # plot restricted selection
+        ax.scatter(rfarea_restr, sis_restr, facecolors='grey', edgecolors='k', lw=lw)
+
+        # compute linear regression for restricted selection
+        slope_restr, intercept_restr, r_restr, p_restr, _ = stats.linregress(
+            rfarea_restr.astype('float'), sis_restr.astype('float'))
+        xmodel = np.arange(np.nanmin(rfarea_restr), np.nanmax(rfarea_restr), 1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope_restr, intercept_restr))
+        rsq_restr = r_restr ** 2  # compute rsq
+        # plot
+        ax.plot(xmodel, ymodel, color='grey')
+
+        # compute linear regression for entire selection
+        slope, intercept, r, p, _ = stats.linregress(rfarea.astype('float'), sis.astype(
+            'float'))
+        xmodel = np.arange(np.nanmin(rfarea), np.nanmax(rfarea), 1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+        rsq = r ** 2  # compute rsq
+        # plot
+        ax.plot(xmodel, ymodel, color='k')
+
+        # layout
+        ax.set_ylabel('Suppression index')
+        ax.set_xlabel('RF area ($\deg^2$)')
+        xlim = (np.nanmin((np.nanmin(rfarea), -50)), np.nanmax((np.nanmax(rfarea) + 1, 3000)))
+        ylim = (np.nanmin((np.nanmin(sis), -0.05)), np.nanmax((np.nanmax(sis), 1)))
+        ax.set_xlim(xlim)
+        ax.set_ylim(ylim)
+        ax.set_xticks((0, 1500, 3000))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['left'].set_bounds(0, 1)
+        ax.spines['bottom'].set_bounds(0, 3000)
+
+        # add stats to plot
+        xmax = max(ax.get_xlim())
+        ymax = max(ax.get_ylim())
+        ax.text(xmax * 0.6, ymax * 0.90, 'R² = %.2e' % rsq, color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.84, 'b = %.2e' % slope, color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.78, 'p = %.2f' % p, color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.72, 'n = %d' % len(sis), color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.66, 'R² = %.2e' % rsq_restr, color='grey', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.60, 'b = %.2e' % slope_restr, color='grey', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.54, 'p = %.2f' % p_restr, color='grey', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.48, 'n = %d' % len(sis_restr), color='grey', fontsize=fs)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def trn_rfcs_x_si(self, figsize=(2.5, 2.5), ax=None, fs=6, lw=0.5):
+        """Plots for each trn unit preferred size against SI (Fig. S10d)
+
+        Parameters
+        ----------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        fs: float
+            Fontsize
+        lw: float
+            Linewidth
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        # init figure
+        if ax is None:
+            f, ax = plt.subplots(figsize=figsize)
+            f.tight_layout()
+
+        # get data
+        sis = self.trn_sztun_df.si_201.values
+        rfcss = self.trn_sztun_df.rfcs_201.values
+
+        # plot
+        ax.scatter(rfcss, sis, facecolors='none', edgecolors='k', lw=lw)
+
+        # compute linear regression for entire selection
+        slope, intercept, r, p, _ = stats.linregress(rfcss.astype('float'),
+                                                     sis.astype('float'))
+        xmodel = np.arange(np.nanmin(rfcss), np.nanmax(rfcss), 1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+        rsq = r ** 2  # compute rsq
+        # plot
+        ax.plot(xmodel, ymodel, color='k')
+
+        # layout
+        ax.set_ylabel('Suppression index')
+        ax.set_xlabel('Preferred size ($\degree$)')
+        xlim = (np.nanmin((np.nanmin(rfcss), -2)), np.nanmax((np.nanmax(rfcss), 200 + 2)))
+        ylim = (np.nanmin((np.nanmin(sis), -0.05)), np.nanmax((np.nanmax(sis), 1)))
+        ax.set_xlim(xlim)
+        ax.set_ylim(ylim)
+        ax.set_xticks((0, 100, 200))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['left'].set_bounds(0, 1)
+        ax.spines['bottom'].set_bounds(0, 200)
+
+        # add stats to plot
+        xmax = max(ax.get_xlim())
+        ymax = max(ax.get_ylim())
+        ax.text(xmax * 0.6, ymax * 0.90, 'R² = %.2f' % rsq, color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.84, 'b = %.2e' % slope, color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.78, 'p = %0.2e' % p, color='k', fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.72, 'n = %d' % len(sis), color='k', fontsize=fs)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def lgn_rfarea_x_si(self, figsize=(2.5, 2.5), ax=None,  fs=6, lw=0.5):
+        """Plot for each dLGN unit rf area against suppression index (Fig. S10e)
+
+        Parameters
+        ----------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        fs: float
+            Fontsize
+        lw: float
+            Linewidth
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+            f.tight_layout()
+
+        # restrict selection to single unit receptive fields
+        su_hits = self.lgn_sztun_df.loc[self.lgn_sztun_df.rf_type == 'su']
+        # compute rf size
+        rfareas = su_hits.rf_pars.apply(lambda x: x[4] * x[5] * np.pi).values
+        # get suppression indices
+        sis = su_hits.si_201.apply(lambda x: x[0]).values
+
+        # compute linear regression
+        slope, intercept, r, p, _ = stats.linregress(rfareas, sis)
+        xmodel = np.arange(np.nanmin(rfareas), np.nanmax(rfareas), 1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+        # compute rsq
+        rsq = r ** 2
+
+        # plot scatter
+        ax.scatter(rfareas, sis, facecolors='none', edgecolors='k', lw=lw)
+        ax.plot(xmodel, ymodel, color='k')
+
+        # layout
+        ax.set_ylabel('Suppression index')
+        ax.set_xlabel('RF area ($\deg^2$)')
+        xlim = (np.nanmin((np.nanmin(rfareas), -5)), np.nanmax((np.nanmax(rfareas) + 1, 250)))
+        ylim = (np.nanmin((np.nanmin(sis), -0.05)), np.nanmax((np.nanmax(sis), 1)))
+        ax.set_xlim(xlim)
+        ax.set_ylim(ylim)
+        ax.set_xticks((0, 150, 300))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['left'].set_bounds(0, 1)
+        ax.spines['bottom'].set_bounds(0, 300)
+
+        # add stats to plot
+        xmax = max(ax.get_xlim())
+        ymax = max(ax.get_ylim())
+        ax.text(xmax * 0.6, ymax * 0.90, 'R² = %.2f' % rsq, fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.84, 'b = %.2e' % slope, fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.78, 'p = %.2f' % p, fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.72, 'n = %d' % len(sis), fontsize=fs)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+    def lgn_rfcs_x_si(self, figsize=(2.5, 2.5), ax=None, fs=6, lw=0.5):
+        """Plot for each dLGN unit preferred size against suppression index (Fig. S10f)
+
+        Parameters
+        ----------
+        figsize: tuple
+            Figure size (width, height)
+        ax: instance of matplotlib.axes class
+            Axis to use for plotting
+        fs: float
+            Fontsize
+        lw: float
+            Linewidth
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+            f.tight_layout()
+
+        # get data
+        sis = self.lgn_sztun_df.si_201.apply(lambda x: x[0]).values
+        rfcss = self.lgn_sztun_df.rfcs_201.apply(lambda x: x[0]).values
+
+        # plot
+        ax.scatter(rfcss, sis, facecolors='none', edgecolors='k', lw=lw)
+
+        # compute linear regression
+        slope, intercept, r, p, _ = stats.linregress(rfcss, sis)
+        xmodel = np.arange(np.nanmin(rfcss), np.nanmax(rfcss), 1)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+        # compute rsq
+        rsq = r ** 2
+        # plot
+        ax.plot(xmodel, ymodel, color='k')
+
+        # layout
+        ax.set_ylabel('Suppression index')
+        ax.set_xlabel('Preferred size ($\degree$)')
+        xlim = (np.nanmin((np.nanmin(rfcss), -2)), np.nanmax((np.nanmax(rfcss), 60 + 2)))
+        ylim = (np.nanmin((np.nanmin(sis), -0.05)), np.nanmax((np.nanmax(sis), 1)))
+        ax.set_xlim(xlim)
+        ax.set_ylim(ylim)
+        ax.set_xticks((0, 30, 60))
+        ax.set_yticks((0, 0.5, 1))
+        ax.spines['left'].set_bounds(0, 1)
+        ax.spines['bottom'].set_bounds(0, 60)
+
+        # add stats to plot
+        xmax = max(ax.get_xlim())
+        ymax = max(ax.get_ylim())
+        ax.text(xmax * 0.6, ymax * 0.90, 'R² = %.2f' % rsq, fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.84, 'b = %.2f' % slope, fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.78, 'p = %.2f' % p, fontsize=fs)
+        ax.text(xmax * 0.6, ymax * 0.72, 'n = %d' % len(sis), fontsize=fs)
+
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax

+ 125 - 0
figs/figS4.py

@@ -0,0 +1,125 @@
+# code to plot Fig. S4
+
+# import
+import matplotlib.pyplot as plt
+import scipy
+import numpy as np
+import spatint_utils
+
+# load plot params
+spatint_utils.plot_params()
+
+
+class FigS4:
+    """Class for plotting Fig. S4"""
+
+    def __init__(self):
+        """Init class"""
+
+        # load data from matfile
+        data = scipy.io.loadmat('./to_publish/data/v1_spat_int.mat')
+
+        # unfold data
+        self.layerDepth = data['layerDepth'][0]
+        self.center_sz_c_smooth = np.array([x[0] for x in data['center_sz_c_smooth']])
+        self.si_c_smooth = np.array([x[0] for x in data['si_c_smooth']])
+        self.layerY = data['layerY'][0]
+        self.max_sz_plot = data['max_sz_plot'][0][0]
+        self.rel_layerDepth = data['rel_layerDepth'][0]
+        self.si_c = data['si_c'][0]
+        self.relative_depth = data['relative_depth'][0]
+        self.center_sz_c = data['center_sz_c'][0]
+
+    def rfcs_x_depth(self, ax=None, figsize=(2.5, 2.5)):
+        """Plot preferred size against cortical depth in V1 (Fig. S4a)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # plot
+        self.center_sz_c[self.center_sz_c > self.max_sz_plot] = self.max_sz_plot + 20
+        ax.scatter(self.center_sz_c, self.relative_depth, color='k', facecolor='none')
+        ax.plot(self.center_sz_c_smooth, self.relative_depth, color='r')
+        for layer in self.layerY:
+            ax.hlines(layer, 0, self.max_sz_plot, color='k', linestyle='--')
+
+        # layout
+        ax.set_xlabel('Preferred size ($\degree$)')
+        ax.set_ylabel('Cortical depth to\nbase of L4 (mm)')
+        ax.set_xlim(-2, self.max_sz_plot+25)
+        ax.set_ylim([(self.rel_layerDepth[0] - self.layerDepth[0]) / 1000, self.rel_layerDepth[
+            -1] / 1000])
+        ax.invert_yaxis()
+        ax.set_yticks(())
+        ax.set_xticks((0, 25, 50, 75, 95))
+        ax.set_xticklabels((0, 25, 50, 75, '>75'))
+        ax.set_yticks((-0.4, 0, 0.4))
+        ax.set_yticklabels(())
+        ax.spines['bottom'].set_bounds(0, self.max_sz_plot)
+        f = plt.gcf()
+        f.tight_layout()
+
+        # print n
+        print('n = %d' % len(self.relative_depth))
+
+        return ax
+
+    def si_x_depth(self, ax=None, figsize=(2.5, 2.5)):
+        """Plot suppression index against cortical depth in V1 (Fig. S4b)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axis for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if ax is None:
+            # init figure
+            f, ax = plt.subplots(figsize=figsize)
+
+        # plot si vs depth
+        ax.scatter(self.si_c, self.relative_depth, color='k', s=15, facecolor='none')
+        ax.plot(self.si_c_smooth, self.relative_depth, color='r')
+        for layer in self.layerY:
+            ax.hlines(layer, 0, 1, color='k', linestyle='--')
+
+        # layout
+        ax.set_xlabel('Suppression index (SI)')
+        ax.set_ylabel('Cortical depth to\nbase of L4 (mm)')
+        ax.set_xlim(-0.05, 1.05)
+        ax.set_ylim([(self.rel_layerDepth[0] - self.layerDepth[0]) / 1000, self.rel_layerDepth[
+            -1] / 1000])
+        ax.invert_yaxis()
+        ax.set_yticks((-0.4, 0, 0.4))
+        ax.set_xticks((0, 0.5, 1))
+        ax.spines['bottom'].set_bounds(0, 1)
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+
+
+
+
+

+ 94 - 0
figs/figS6.py

@@ -0,0 +1,94 @@
+# code to plot figure S6
+
+# import
+import matplotlib.pyplot as plt
+from importlib import reload
+import pandas
+import spatint_utils
+
+# reload modules
+reload(spatint_utils)
+
+# load plot params
+spatint_utils.plot_params()
+
+
+class FigS6:
+    """Class for plotting panels of Fig. S6"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read lgn size tuning dataframe
+        self.lgn_sztun_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/lgn_sztun_df.pkl')
+
+    def ex_curves(self, figsize=(5, 2), axs=None):
+        """Plot example dLGN size tuning curves (Fig. S6 a-j)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        axs: mpl axis
+            axes for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if axs is None:
+            # init figure
+            f, axs = plt.subplots(2, 5, figsize=figsize)
+
+        # determine example keys
+        exkeys = [{'m': 'PVCre_2013_0046', 's': 8, 'e': 9, 'u': 3038},   # a
+                  {'m': 'PVCre_2013_0046', 's': 6, 'e': 6, 'u': 4042},   # b
+                  {'m': 'PVCre_2013_0046', 's': 8, 'e': 12, 'u': 3030},  # c
+                  {'m': 'PVCre_2013_0054', 's': 10, 'e': 4, 'u': 2073},  # d
+                  {'m': 'PVCre_2013_0046', 's': 8, 'e': 14, 'u': 3039},  # e
+                  {'m': 'PVCre_2013_0054', 's': 10, 'e': 5, 'u': 1048},  # f
+                  {'m': 'PVCre_2013_0046', 's': 5, 'e': 7, 'u': 3037},   # g
+                  {'m': 'PVCre_2013_0054', 's': 11, 'e': 4, 'u': 3016},  # h
+                  {'m': 'PVCre_2020_0002', 's': 6, 'e': 5, 'u': 4},      # i
+                  {'m': 'PVCre_2013_0046', 's': 6, 'e': 6, 'u': 4040}]   # j
+
+        # loop over example keys and plot
+        for ax, exkey in zip(axs.flatten(), exkeys):
+
+            ex_row = self.lgn_sztun_df[(self.lgn_sztun_df.m == exkey['m']) &
+                                       (self.lgn_sztun_df.s == exkey['s']) &
+                                       (self.lgn_sztun_df.e == exkey['e']) &
+                                       (self.lgn_sztun_df.u == exkey['u'])].iloc[0]
+
+            spatint_utils.plot_tun(means=ex_row['tun_mean'],
+                                   sems=ex_row['tun_sem'],
+                                   spons=ex_row['tun_spon_mean'],
+                                   xs=ex_row['ti_axes'],
+                                   c_fit=ex_row['c_sz_fit'],
+                                   op_fit=ex_row['op_sz_fit'],
+                                   c_prefsz=ex_row['rfcs_76'][0],
+                                   op_prefsz=ex_row['rfcs_76'][1],
+                                   ax=ax)
+
+            # layout
+            max_x = 76
+            ax.set_xlim(-3, max_x)
+            ax.set_xticks([0, 25, 50, 75])
+            ax.spines['bottom'].set_bounds(0, max_x)
+            yticks = ax.get_yticks()
+            ax.set_yticks((yticks[0], int((yticks[-1] - 5) / 2), yticks[-1] - 5))
+            ax.set_ylabel('')
+            ax.set_xlabel('')
+            ax.set_xticklabels([])
+
+        axs.flatten()[5].set_ylabel('Firing rate (sp/s)')
+        axs.flatten()[5].set_xlabel('Diameter ($\degree$)')#
+        axs.flatten()[5].set_xticklabels([0, 25, 50, 75])
+        f = plt.gcf()
+        f.tight_layout()
+
+        return axs
+

+ 143 - 0
figs/figS8.py

@@ -0,0 +1,143 @@
+# plots panels for Fig S8
+
+# import libs
+import matplotlib.pyplot as plt
+from importlib import reload
+import pandas
+import spatint_utils
+import numpy as np
+
+# reload modules
+reload(spatint_utils)
+
+# load plot params
+spatint_utils.plot_params()
+_, _, optocolor = spatint_utils.get_colors()
+
+
+class FigS8:
+    """Class to for plotting panels for Fig. S8"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read lgn size tuning dataframe
+        self.trn_sztun_opto_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/trn_sztun_opto_df.pkl')
+
+    def ex_curves(self, axs=None, figsize=(5, 2)):
+        """Plot example dLGN size tuning curves (Fig. S6 a-j)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        axs: mpl axes
+            axes for plot
+
+        Returns
+        -------
+        axs: mpl axes
+            axis for plot
+        """
+
+        if axs is None:
+            # init figure
+            f, axs = plt.subplots(2, 5, figsize=figsize)
+
+        # determine keys for example neurons
+        exkeys = [{'m': 'PVCre_2018_0009', 's': 4, 'e': 11, 'u': 8},   # a
+                  {'m': 'PVCre_2018_0009', 's': 4, 'e': 14, 'u': 3},   # b
+                  {'m': 'PVCre_2018_0009', 's': 4, 'e': 11, 'u': 5},   # c
+                  {'m': 'PVCre_2019_0006', 's': 5, 'e': 14, 'u': 6},   # d
+                  {'m': 'PVCre_2019_0007', 's': 8, 'e': 9, 'u': 20},   # e
+                  {'m': 'PVCre_2018_0009', 's': 4, 'e': 18, 'u': 24},  # f
+                  {'m': 'PVCre_2019_0006', 's': 5, 'e': 3, 'u': 8},    # g
+                  {'m': 'PVCre_2019_0006', 's': 5, 'e': 9, 'u': 7},    # h
+                  {'m': 'PVCre_2019_0006', 's': 7, 'e': 10, 'u': 9},   # i
+                  {'m': 'PVCre_2019_0007', 's': 8, 'e': 12, 'u': 8}]   # j
+
+        # loop over example neurons
+        for ax, exkey in zip(axs.flatten(), exkeys):
+
+            ex_row = self.trn_sztun_opto_df[(self.trn_sztun_opto_df.m == exkey['m']) &
+                                            (self.trn_sztun_opto_df.s == exkey['s']) &
+                                            (self.trn_sztun_opto_df.e == exkey['e']) &
+                                            (self.trn_sztun_opto_df.u == exkey['u'])].iloc[0]
+
+            spatint_utils.plot_tun(means=ex_row['tun_mean'],
+                                   sems=ex_row['tun_sem'],
+                                   spons=ex_row['tun_spon_mean'],
+                                   params=ex_row['tun_pars'],
+                                   xs=ex_row['ti_axes'],
+                                   c_prefsz=ex_row['rfcs_76'][0],
+                                   op_prefsz=ex_row['rfcs_76'][1],
+                                   ax=ax)
+
+            # layout
+            max_x = 76
+            ax.set_xlim(-3, max_x)
+            ax.set_xticks([0, 25, 50, 75])
+            ax.spines['bottom'].set_bounds(0, max_x)
+            yticks = ax.get_yticks()
+            ax.set_yticks((yticks[0], int((yticks[-1] - 10) / 2), yticks[-1] - 10))
+            ylims = ax.get_ylim()
+            ax.set_ylim((ylims[1]*-0.05, ylims[1]))
+            ax.spines['left'].set_bounds(0, ylims[1])
+            ax.set_ylabel('')
+            ax.set_xlabel('')
+            ax.set_xticklabels([])
+
+        axs.flatten()[5].set_ylabel('Firing rate (sp/s)')
+        axs.flatten()[5].set_xlabel('Diameter ($\degree$)')
+        axs.flatten()[5].set_xticklabels([0, 25, 50, 75])
+        f = plt.gcf()
+        f.tight_layout()
+
+        return axs
+
+    def si_hist(self, figsize=(2.5, 2.5), ax=None):
+        """Plot histogram with suppression indices under both conditions
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        ax: mpl axis
+            axes for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        # get data
+        trn_sztun_opto_df = self.trn_sztun_opto_df[
+            (self.trn_sztun_opto_df.tun_rsq.apply(lambda x: ~np.isnan(x[1]))) &
+            (self.trn_sztun_opto_df.tun_mean.apply(lambda x: np.mean(x[1])) >= 0.1)]
+
+        # get sis
+        si_int = np.vstack(trn_sztun_opto_df.si_76.values)
+        assert not any(si_int[si_int > 1] or si_int[si_int < 0]), 'si > 1 or < 0, check!'
+        cont = si_int[:, 0]
+        supp = si_int[:, 1]
+
+        # plot
+        if ax is None:
+            f, ax = plt.subplots(figsize=figsize)
+        si_bins = np.arange(0, 1.05, 0.05)
+        ax.hist(cont, bins=si_bins, color='k')
+        ax.hist(supp, bins=si_bins, color=optocolor, weights=np.ones(len(supp)) * -1)
+        ax.set_yticks([-20, 0, 20])
+        ax.set_yticklabels([20, 0, 20])
+        ax.set_xticks([0, 0.5, 1])
+        ax.set_ylabel('# of neurons')
+        ax.set_xlabel('SI')
+        ax.spines['left'].set_bounds(-30, 30)
+        ax.spines['bottom'].set_bounds(0, 1)
+        f = plt.gcf()
+        f.tight_layout()
+
+        return ax
+

+ 405 - 0
figs/figS9.py

@@ -0,0 +1,405 @@
+# code for plotting panels for figS9
+
+# import libs
+from matplotlib import pyplot as plt
+import numpy as np
+import pandas
+from scipy import stats
+from importlib import reload
+import spatint_utils
+
+# reload modules
+reload(spatint_utils)
+
+spatint_utils.plot_params()
+
+
+class FigS9:
+    """Class to for plotting panels for Fig. S9"""
+
+    def __init__(self):
+        """Init class"""
+
+        # read lgn size tuning dataframe
+        self.trn_sztun_opto_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/trn_sztun_opto_df.pkl')
+
+        # read trn/lgn rf area dataframe
+        self.rf_area_df = pandas.read_pickle(
+            filepath_or_buffer='./to_publish/data/rf_area_df.pkl')
+
+    def trnchange_x_props(self, figsize=(8, 2.7), axs=None):
+        """Plot example dLGN size tuning curves (Fig. S6 a-j)
+
+        Parameters
+        -------
+        figsize: tuple
+            Figure size (width, height)
+        axs: mpl axis
+            axes for plot
+
+        Returns
+        -------
+        ax: mpl axis
+            axis for plot
+        """
+
+        if axs is None:
+            # create figure
+            f, axs = plt.subplots(3, 9, figsize=figsize)
+
+        # get data
+        trn_data = self.trn_sztun_opto_df
+        rf_area = self.rf_area_df
+
+        # add percent change in mean response
+        perc_change_resp = trn_data.tun_mean.apply(
+            lambda x: ((x[:, 1].mean() / x[:, 0].mean()) - 1) * 100)
+        trn_data['perc_change_resp'] = perc_change_resp.to_list()
+
+        # add difference in burst ratio
+        diff_bratio = trn_data.bratio_c.apply(
+            lambda x: np.nanmean(x)).values - trn_data.bratio_op.apply(
+            lambda x: np.nanmean(x)).values
+        trn_data['diff_bratio'] = diff_bratio.tolist()
+
+        # add rf area
+        trn_data = pandas.merge(trn_data, rf_area[['m', 's', 'u', 'area']],
+                                on=['m', 's', 'u'], how='left')
+
+        # add percent change in preferred size
+        perc_change_rfcs = trn_data.rfcs_76.apply(lambda x: ((x[1] / x[0]) - 1) * 100)
+        trn_data['perc_change_rfcs'] = perc_change_rfcs.to_list()
+
+        # add si diff
+        diff_si = trn_data.si_76.apply(lambda x: (x[1] - x[0]))
+        trn_data['diff_si'] = diff_si.to_list()
+
+        # restrict data sets
+        # fit: fit in both size tuning conditions and sufficient mean firing during V1
+        # suppression
+        trn_data_fit = trn_data[(trn_data.tun_rsq.apply(lambda x: ~np.isnan(x[1]))) &
+                                (trn_data.tun_mean.apply(lambda x: np.mean(x[1])) >= 0.1)]
+        # hyp: contrast tuning exp well fit with hyper ratio model
+        trn_data_hyp = trn_data.loc[trn_data.hyp_r2 > 0.8]
+        # fit hyp: combination of both
+        trn_data_fit_hyp = trn_data_fit.loc[trn_data_fit.hyp_r2 > 0.8]
+
+        # ------------------------- plot perc change response ---------------------------------
+        # against suppression index
+        si_cont = trn_data.si_76.apply(lambda x: x[0]).values
+        self.makeplot(x=si_cont, y=trn_data.perc_change_resp, ax=axs[0, 0])
+
+        # against preferred size
+        rfcs_cont = trn_data.rfcs_76.apply(lambda x: x[0]).values
+        self.makeplot(x=rfcs_cont, y=trn_data.perc_change_resp, ax=axs[0, 1])
+
+        # against receptive field area
+        self.makeplot(x=trn_data.area, y=trn_data.perc_change_resp, ax=axs[0, 2])
+
+        # against contrast sensitivity
+        self.makeplot(x=trn_data_hyp.hyp_sigma, y=trn_data_hyp.perc_change_resp, ax=axs[0, 3])
+
+        # against contrast component
+        self.makeplot(x=trn_data_hyp.hyp_n, y=trn_data_hyp.perc_change_resp, ax=axs[0, 4])
+
+        # against spontaneous activity
+        spon_c = trn_data.tun_spon_mean.apply(lambda x: x[0]).values
+        self.makeplot(x=spon_c, y=trn_data.perc_change_resp, ax=axs[0, 5])
+
+        # against mean firing rate
+        meanfr_c = trn_data.tun_mean.apply(lambda x: np.mean(x[:, 0])).values
+        self.makeplot(x=meanfr_c, y=trn_data.perc_change_resp, ax=axs[0, 6])
+
+        # against burst ratio
+        bratio_c = trn_data.bratio_c.apply(lambda x: np.nanmean(x)).values
+        self.makeplot(x=bratio_c, y=trn_data.perc_change_resp, ax=axs[0, 7])
+
+        # against burst length
+        blen_c = trn_data.blen_c.apply(
+            lambda x: np.nanmean(np.array([item for sublist in x for item in sublist])))
+        blen_c_exnan = blen_c[~np.isnan(blen_c)]
+        change_resp_exnan = trn_data.perc_change_resp[~np.isnan(blen_c)]
+        self.makeplot(x=blen_c_exnan, y=change_resp_exnan, ax=axs[0, 8])
+
+        # -------------------- plot perc change preferred size --------------------------------
+
+        # against suppression index
+        # plot outlier
+        axs[1, 0].scatter(trn_data_fit.loc[trn_data_fit.perc_change_rfcs > 600].si_76.apply(
+                lambda x: x[0]).values, 350, facecolors='none', edgecolors='grey', s=3.4,
+                          lw=0.5)
+        change_rfcs_exoutlier = trn_data_fit.loc[
+            trn_data_fit.perc_change_rfcs < 600].perc_change_rfcs
+        si_cont_exoutlier = trn_data_fit.si_76.loc[
+            trn_data_fit.perc_change_rfcs < 600].apply(lambda x: x[0]).values
+        self.makeplot(x=si_cont_exoutlier, y=change_rfcs_exoutlier, ax=axs[1, 0])
+
+        # against preferred size
+        axs[1, 1].scatter(
+            trn_data_fit.loc[trn_data_fit.perc_change_rfcs > 600].rfcs_76.apply(
+                lambda x: x[0]).values, 350, facecolors='none', edgecolors='grey', s=3.4,
+            lw=0.5)
+        rfcs_cont_exoutlier = trn_data_fit.loc[
+            trn_data_fit.perc_change_rfcs < 600].rfcs_76.apply(lambda x: x[0]).values
+        self.makeplot(x=rfcs_cont_exoutlier, y=change_rfcs_exoutlier, ax=axs[1, 1])
+
+        # against receptive field area
+        axs[1, 2].scatter(trn_data_fit.loc[trn_data_fit.perc_change_rfcs > 600].area, 350,
+                          facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        area_ex_outlier = trn_data_fit.loc[trn_data_fit.perc_change_rfcs < 600].area
+        self.makeplot(x=area_ex_outlier, y=change_rfcs_exoutlier, ax=axs[1, 2])
+
+        # against contrast sensitivity
+        axs[1, 3].scatter(
+            trn_data_fit_hyp.loc[trn_data_fit_hyp.perc_change_rfcs > 600].hyp_sigma, 350,
+            facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        change_rfcs_hyp_exoutlier = trn_data_fit_hyp.loc[
+            trn_data_fit_hyp.perc_change_rfcs < 600].perc_change_rfcs
+        sig_exoutlier = trn_data_fit_hyp.loc[
+            trn_data_fit_hyp.perc_change_rfcs < 600].hyp_sigma
+        self.makeplot(x=sig_exoutlier, y=change_rfcs_hyp_exoutlier, ax=axs[1, 3])
+
+        # against contrast component
+        axs[1, 4].scatter(
+            trn_data_fit_hyp.loc[trn_data_fit_hyp.perc_change_rfcs > 600].hyp_n, 350,
+            facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        n_exoutlier = trn_data_fit_hyp.loc[trn_data_fit_hyp.perc_change_rfcs < 600].hyp_n
+        self.makeplot(x=n_exoutlier, y=change_rfcs_hyp_exoutlier, ax=axs[1, 4])
+
+        # against spontaneous firing rate
+        spon_c = trn_data_fit.tun_spon_mean.apply(lambda x: x[0]).values
+        axs[1, 5].scatter(spon_c[trn_data_fit.perc_change_rfcs > 600], 350,
+                          facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        spon_c_exoutlier = spon_c[trn_data_fit.perc_change_rfcs < 600]
+        self.makeplot(x=spon_c_exoutlier, y=change_rfcs_exoutlier, ax=axs[1, 5])
+
+        # against mean firing rate
+        meanfr_c = trn_data_fit.tun_mean.apply(lambda x: np.mean(x[:, 0]))
+        axs[1, 6].scatter(meanfr_c[trn_data_fit.perc_change_rfcs > 600], 350,
+                          facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        meanfr_c_exoutlier = meanfr_c[trn_data_fit.perc_change_rfcs < 600]
+        self.makeplot(x=meanfr_c_exoutlier, y=change_rfcs_exoutlier, ax=axs[1, 6])
+
+        # against burst ratio
+        bratio_c = trn_data_fit.bratio_c.apply(lambda x: np.nanmean(x)).values
+        axs[1, 7].scatter(bratio_c[trn_data_fit.perc_change_rfcs > 600], 350,
+                          facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        bratio_c_exoutlier = bratio_c[trn_data_fit.perc_change_rfcs < 600]
+        self.makeplot(x=bratio_c_exoutlier, y=change_rfcs_exoutlier, ax=axs[1, 7])
+
+        # against burst length
+        blen_c = trn_data_fit.blen_c.apply(
+            lambda x: np.nanmean(np.array([item for sublist in x for item in sublist])))
+        blen_c_exnan = blen_c[~np.isnan(blen_c)]
+        change_rfcs_exnan = trn_data_fit.perc_change_rfcs[~np.isnan(blen_c)]
+        axs[1, 8].scatter(blen_c_exnan[change_rfcs_exnan > 600], 350,
+                          facecolors='none', edgecolors='grey', s=3.4, lw=0.5)
+        blen_c_exoutlier = blen_c_exnan[change_rfcs_exnan < 600]
+        change_rfcs_exoutlier = change_rfcs_exnan[change_rfcs_exnan < 600]
+        self.makeplot(x=blen_c_exoutlier, y=change_rfcs_exoutlier, ax=axs[1, 8])
+
+        # -------------------- plot change in surround suppression ----------------------------
+
+        # against surround suppression
+        si_cont_fit = trn_data_fit.si_76.apply(lambda x: x[0]).values
+        self.makeplot(x=si_cont_fit, y=trn_data_fit.diff_si, ax=axs[2, 0])
+
+        # against preferred size
+        rfcs_cont_fit = trn_data_fit.rfcs_76.apply(lambda x: x[0]).values
+        self.makeplot(x=rfcs_cont_fit, y=trn_data_fit.diff_si, ax=axs[2, 1])
+
+        # against receptive field area
+        self.makeplot(x=trn_data_fit.area, y=trn_data_fit.diff_si, ax=axs[2, 2])
+
+        # against contrast sensitivity
+        self.makeplot(x=trn_data_fit_hyp.hyp_sigma, y=trn_data_fit_hyp.diff_si,ax=axs[2, 3])
+
+        # against contrast component
+        self.makeplot(x=trn_data_fit_hyp.hyp_n, y=trn_data_fit_hyp.diff_si, ax=axs[2, 4])
+
+        # against spontaneous activity
+        spon_c = trn_data_fit.tun_spon_mean.apply(lambda x: x[0]).values
+        self.makeplot(x=spon_c, y=trn_data_fit.diff_si, ax=axs[2, 5])
+
+        # against mean firing rate
+        meanfr_c = trn_data_fit.tun_mean.apply(lambda x: np.mean(x[:, 0]))
+        self.makeplot(x=meanfr_c, y=trn_data_fit.diff_si, ax=axs[2, 6])
+
+        # against burst ratio
+        bratio_c = trn_data_fit.bratio_c.apply(lambda x: np.nanmean(x)).values
+        self.makeplot(x=bratio_c, y=trn_data_fit.diff_si, ax=axs[2, 7])
+
+        # against burst length
+        blen_c = trn_data_fit.blen_c.apply(
+            lambda x: np.nanmean(np.array([item for sublist in x for item in sublist])))
+        blen_c_exnan = blen_c[~np.isnan(blen_c)]
+        diff_si_exnan = trn_data_fit.diff_si[~np.isnan(blen_c)]
+        self.makeplot(x=blen_c_exnan, y=diff_si_exnan, ax=axs[2, 8])
+
+        # format plot
+        # rows
+        for i, ax in enumerate(axs[0, :]):
+            ax.set_yticks((-100, -50, 0))
+            ax.spines['left'].set_bounds(-100, 30)
+            ax.set_ylabel('%$\Delta$ responsiveness')
+            if i > 0:
+                ax.set_yticklabels(())
+                ax.set_ylabel('')
+
+        for i, ax in enumerate(axs[1, :]):
+            ax.set_yticks((0, 150, 300, 350))
+            ax.set_yticklabels((0, 150, 300, '> 300'))
+            ax.spines['left'].set_bounds(-60, 300)
+            ax.set_ylabel('%$\Delta$ preferred size')
+            if i > 0:
+                ax.set_yticklabels(())
+                ax.set_ylabel('')
+
+        for i, ax in enumerate(axs[2, :]):
+            ax.set_yticks((-0.5, 0, 0.5))
+            ax.spines['left'].set_bounds(-0.5, 0.5)
+            ax.set_ylabel('$\Delta$ SI')
+            if i > 0:
+                ax.set_yticklabels(())
+                ax.set_ylabel('')
+
+        # columns
+        for i, ax in enumerate(axs[:, 0]):
+            ax.set_xticks((0, 0.5, 1))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 1)
+            ax.set_xlabel('SI')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 1]):
+            ax.set_xticks((0, 25, 50, 75))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 75)
+            ax.set_xlabel('Preferred size $\degree$')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 2]):
+            ax.set_xticks((0, 1500, 3000))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 3000)
+            ax.set_xlabel('RF area (deg²)')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 3]):
+            ax.set_xticks((0, 0.5, 1))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 1)
+            ax.set_xlabel('Contrast\nsensitivity (c_50)')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 4]):
+            ax.set_xticks((0, 5, 10))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 10)
+            ax.set_xlabel('Contrast exponent')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 5]):
+            ax.set_xticks((0, 30, 60))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 60)
+            ax.set_xlabel('Spontaneous\nfiring rate (sp/s)')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 6]):
+            ax.set_xticks((0, 40, 80))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 80)
+            ax.set_xlabel('Mean response\n(sp/s)')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 7]):
+            ax.set_xticks((0, 0.2, 0.4))
+            xlims = ax.get_xlim()
+            ax.set_xlim(xlims[1] * - 0.05, xlims[1])
+            ax.spines['bottom'].set_bounds(0, 0.4)
+            ax.set_xlabel('Burst ratio')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        for i, ax in enumerate(axs[:, 8]):
+            ax.set_xticks((2, 6, 10))
+            xlims = ax.get_xlim()
+            ax.set_xlim(1, xlims[1])
+            ax.spines['bottom'].set_bounds(2, 10)
+            ax.set_xlabel('Burst length')
+            if i < 2:
+                ax.set_xticklabels(())
+                ax.set_xlabel('')
+
+        f = plt.gcf()
+        f.tight_layout()
+
+    def makeplot(self, x, y, ax):
+        """Create plot
+
+        Parameters
+        -------
+        x: array
+            x data
+        y: array
+            y data
+        ax: mpl axis
+            axis for plot
+        """
+
+        # compute parameters for linear regression and r squared
+        slope, intercept, r, p, _ = stats.linregress(x, y)
+        rsq = r ** 2
+        n = len(x)
+
+        # evaluate model
+        xmodel = np.arange(np.nanmin(x), np.nanmax(x), 0.01)
+        ymodel = spatint_utils.linreg(xmodel, *(slope, intercept))
+
+        # plot
+        ax.scatter(x, y, facecolors='none', edgecolors='k', s=3.4, lw=0.5)
+        ax.plot(xmodel, ymodel, color='k')
+
+        # add info as text
+        xpos_txt = ax.get_xlim()[0] + ((ax.get_xlim()[1] - ax.get_xlim()[0]) * 0.5)
+        ypos_txt = ax.get_ylim()[0] + (ax.get_ylim()[1] - ax.get_ylim()[0]) * np.array(
+            [0.7, 0.8, 0.9, 1])
+        fs = 4
+        if rsq < 0.01:
+            ax.text(xpos_txt, ypos_txt[3], 'R² = %.2e' % rsq, fontsize=fs)
+        else:
+            ax.text(xpos_txt, ypos_txt[3], 'R² = %.2f' % rsq, fontsize=fs)
+        if (slope < 0.01) & (slope > -0.01):
+            ax.text(xpos_txt, ypos_txt[2], 'b = %.2e' % slope, fontsize=fs)
+        else:
+            ax.text(xpos_txt, ypos_txt[2], 'b = %.2f' % slope, fontsize=fs)
+        if p < 0.01:
+            ax.text(xpos_txt, ypos_txt[1], 'p = %.2e' % p, fontsize=fs)
+        else:
+            ax.text(xpos_txt, ypos_txt[1], 'p = %.2f' % p, fontsize=fs)
+        ax.text(xpos_txt, ypos_txt[0], 'n = %d' % n, fontsize=fs)
+

+ 637 - 0
figs/spatint_utils.py

@@ -0,0 +1,637 @@
+# helper functions to publish
+
+# import libs
+import numpy as np
+from scipy.special import erf
+from scipy import stats
+from matplotlib import colors, cm
+import matplotlib.lines as mlines
+import matplotlib as plt
+
+
+def plot_params():
+    """Determine plotting parameters"""
+
+    params = {
+        'xtick.labelsize': 6,
+        'ytick.labelsize': 6,
+        'axes.linewidth': 0.35,
+        'xtick.major.width': 0.35,
+        'ytick.major.width': 0.35,
+        'ytick.major.size': 1.8,
+        'xtick.major.size': 1.8,
+        'lines.linewidth': 1,
+        'lines.markersize': 3.4,
+        'font.family': 'sans-serif',
+        'font.sans-serif': 'FreeSans',
+        'font.weight': 'normal',
+        'font.size': 6,
+        'axes.titlesize': 6,
+        'pdf.fonttype': 42,
+        'ps.fonttype': 42,
+        'savefig.dpi': 300,
+        'mathtext.default': 'regular'
+        }
+
+    plt.rcParams.update(params)
+
+
+def get_colors():
+    """Get colors for plotting
+
+    Returns
+    -------
+    trn_red: tuple
+        rgba values for trn red
+    lgn_green: tuple
+        rgba values for lgn green
+    optocolor: string
+        color for V1 suppression condition
+    """
+
+    trn_red = (255/255, 11/255, 0/255, 1)
+    lgn_green = (0/255, 243/255, 76/255, 1)
+    optocolor = '#1090cfff'
+
+    return trn_red, lgn_green, optocolor
+
+
+def rsquared(targets, predictions):
+    """Return the r-squared value for the fit
+
+    Parameters
+    ----------
+    targets: ndarray
+        data to predict
+    predictions: ndarray
+        predicted data
+
+    Returns
+    -------
+    rsq: float
+        R squared value
+    """
+
+    residuals = targets - predictions
+    residual_variance = np.sum(residuals**2)
+    variance_of_targets = np.sum((targets - np.mean(targets))**2)
+    if variance_of_targets == 0:
+        rsq = np.nan
+    else:
+        rsq = 1 - (residual_variance / variance_of_targets)
+    return rsq
+
+
+def threshlin(x, m, b):
+    """Return threshold linear model
+
+    Parameters
+    ----------
+    x: ndarray
+        x data
+    m: float
+        slope
+    b: float
+        offset
+
+    Returns
+    -------
+    y: ndarray
+        ydata
+    """
+
+    y = m * x + b
+    y[y < 0] = 0
+    return y
+
+
+def rog_offset(x, kc, wc, ks, ws, offset):
+    """Ratio of Gaussians model (Cavanaugh et al. 2002 J Neurophysiol) with offset and
+    rectification.
+
+    Parameters
+    ----------
+    x: ndarray
+        x data
+    kc : float
+        gain center
+    wc : float
+        width center (sigma)
+    ks : float
+        gain surround
+    ws : float
+        width surround (sigma)
+    offset: float
+        offset
+
+    Returns
+    -------
+    y : ndarray
+        y data
+    """
+
+    lc = erf(x / wc) ** 2
+    ls = erf(x / ws) ** 2
+    y = (kc * lc / (1 + ks * ls)) + offset
+    if np.any(~np.isnan(y)):
+        # make sure that not all y are nans
+        zer_idx = y[y < 0]
+        if len(zer_idx) > 0:
+            y[y < 0] = 0
+    return y
+
+def sum_of_gaussians(xs, dp, rp, rn, r0, sigma):
+    """ORITUNE  sum of two gaussians living on a circle, for orientation tuning
+
+        x are the orientations, bet 0 and 360.
+        dp is the preferred direction (bet 0 and 360)
+        rp is the response to the preferred direction;
+        rn is the response to the opposite direction;
+        r0 is the background response (useful only in some cases)
+        sigma is the tuning width;
+    """
+    angles_p = 180 / np.pi * np.angle(np.exp(1j*(xs-dp) * np.pi / 180))
+    angles_n = 180 / np.pi * np.angle(np.exp(1j*(xs-dp+180) * np.pi / 180))
+    y = (r0 + rp*np.exp(-angles_p**2 / (2*sigma**2)) + rn*np.exp(-angles_n**2 / (2*sigma**2)))
+    return y
+
+
+def linreg(x, m, p):
+    """Return linear regression model
+
+    Parameters
+    -------
+    x: ndarray
+        x data
+    m: float
+        slope
+    p: float
+        offset
+
+    Returns
+    -------
+    y : ndarray
+        y data
+    """
+    y = x * m + p
+    return y
+
+
+def degdiff(d1, d2, circledeg):
+    """Compute the difference in degrees between two angles.
+
+    Parameters
+    -------
+    d1: float
+        first angle in degree
+    d2: float
+        second angle in degree
+    circledeg: int
+        if 180 indicates that you want to be on the half circle ie -90 and 90
+
+    Returns
+    -------
+    d : float
+        difference between two angles in degree
+    """
+
+    if circledeg == 180:
+        d1 = 2 * d1
+        d2 = 2 * d2
+    d = 180 / np.pi * np.angle(np.exp(1j * (d1-d2) * np.pi / 180))
+    if circledeg == 180:
+        d = d / 2
+    return d
+
+
+def calculate_ellipse(x, y, a, b, angle, steps=36):
+    """Computes points to draw an ellipse.
+
+    Parameters
+    -------
+    x: float
+        x center
+    y: float
+        y center
+    a: float
+        semimajor axis
+    b: float
+        semiminor axis
+    angle:
+        angle of the ellipse in degree
+    steps: int
+        number of points to compute
+
+    Returns
+    -------
+    x_coord: float
+        x coordinates
+    y_coord: float
+        y coordinates
+    """
+
+    # do the calculateEllipse
+    beta = angle * (np.pi / 180)
+    sinbeta = np.sin(beta)
+    cosbeta = np.cos(beta)
+    alpha = np.linspace(0, 360, steps).T * (np.pi / 180)
+    sinalpha = np.sin(alpha)
+    cosalpha = np.cos(alpha)
+    x_coord = x + (a * cosalpha * cosbeta - b * sinalpha * sinbeta)
+    y_coord = y + (a * cosalpha * sinbeta + b * sinalpha * cosbeta)
+
+    return x_coord, y_coord
+
+
+def plot_rf(means, stim_param_values, grat_width, grat_height, ax=None, interp='spline16',
+            scalebar=None):
+    """Plot receptive field
+
+    Parameters
+    ----------
+    means: ndarray
+        responses to sparse noise stimulus
+    stim_param_values: ndarray
+        center coordinates of squares
+    grat_width: ndarray
+        grating widths
+    grat_height:
+        grating heights
+    ax: ax
+        axes to plot rf on
+    interp: string
+        interpolation method for imshow
+        uses addition of rgb arrays instead of overlaying to rfs with alpha value
+    scalebar: dict
+        dictionary to determine scalebar parameters ('dist': distance from edge, 'width':
+        width, 'length': length in degree, 'col': color)
+
+    Returns
+    -------
+    ax : matplotlib axis
+    """
+
+    if scalebar is None:
+        # set all dict values to none
+        scalebar = {'width': None, 'dist': None, 'length': None, 'col': None}
+
+    # normalize RFs:
+    normalized_means = np.zeros(means.shape)
+    mean_range = np.nanmax(means) - np.nanmin(means)  # peak-to-peak, i.e. max range of
+    # values
+    if mean_range != 0:
+        normalized_means = (means - np.nanmin(means)) / mean_range
+
+    # get half edge length of sparse noise squares
+    edge_length = np.unique(np.concatenate((grat_width, grat_height)))
+    assert len(edge_length) == 1, "Stimuli have unequal edge lengths"
+    half_edge = edge_length[0] / 2
+
+    # extent indicates the limits of the actual data coordinates. This enables plotting
+    # the RFs in the correct space.
+    extent = (np.nanmin(stim_param_values[:, 0]) - half_edge,
+              np.nanmax(stim_param_values[:, 0]) + half_edge,
+              np.nanmin(stim_param_values[:, 1]) - half_edge,
+              np.nanmax(stim_param_values[:, 1]) + half_edge)
+
+    # split normalized rfs into on and off and rotate
+    off = (normalized_means[:, :, 0]).T
+    on = (normalized_means[:, :, 1]).T
+
+    # create colormaps
+    stop_off_cm = {'r': 223, 'g': 112, 'b': 0}
+    stop_on_cm = {'r': 125, 'g': 0, 'b': 255}
+    start_cm = {'r': 0, 'g': 0, 'b': 0}
+    off_cm = create_cm(start_cm, stop_off_cm)
+    on_cm = create_cm(start_cm, stop_on_cm)
+
+    # get rgb values
+    off_rgb = cm.ScalarMappable(cmap=off_cm).to_rgba(off)
+    on_rgb = cm.ScalarMappable(cmap=on_cm).to_rgba(on)
+
+    # get rid of alpha value
+    off_rgb = off_rgb[:, :, :3]
+    on_rgb = on_rgb[:, :, :3]
+
+    # add rgb arrays
+    rgb_add = np.add(off_rgb, on_rgb)
+
+    # normalize RFs:
+    normalized_rgb = np.zeros(rgb_add.shape)
+    rgb_range = np.nanmax(rgb_add) - np.nanmin(rgb_add)  # peak-to-peak
+    if rgb_range != 0:
+        normalized_rgb = (rgb_add - np.nanmin(rgb_add)) / rgb_range
+
+    # plot
+    ax.imshow(normalized_rgb, origin='lower', interpolation=interp, extent=extent)
+
+    # disable all ticks, tick labels, and spines:
+    ax.set_xticks([])
+    ax.set_xticklabels([])
+    ax.set_yticks([])
+    ax.set_yticklabels([])
+    for spine in ax.spines.values():
+        spine.set_visible(False)
+
+    if all(scalebar.values()):
+        # if values specified plot scalebar
+        yscalebar = ax.get_ylim()[0] + 5
+        xscalebar_max = ax.get_xlim()[1] - scalebar['dist']
+        xscalebar_min = ax.get_xlim()[1] - scalebar['dist'] - scalebar['length']
+        ax.plot(([xscalebar_min, xscalebar_max]), ([yscalebar, yscalebar]),
+                linewidth=scalebar['width'], color=scalebar['col'])
+
+    return ax
+
+
+def create_cm(start, stop):
+    """Create customized colormap that is a linear interpolation between two colors.
+
+    Parameters
+    ----------
+    start : dict with keys r, g, b, and values between 0 and 255
+        rgb values of start color
+    stop : dict with keys r, g, b, and values between 0 and 255
+        rgb values of stop color
+
+    Returns
+    -------
+    cm : matplotlib colormap
+    """
+
+    rgb_dict = {'red':  ((0.0, 1.0, start['r']/255),
+                         (1.0, stop['r']/255, 1.0)),
+                'green': ((0.0, 1.0, start['g']/255),
+                          (1.0, stop['g']/255, 1.0)),
+                'blue': ((0.0, 1.0, start['b']/255),
+                         (1.0, stop['b']/255, 1.0))}
+
+    colormap = colors.LinearSegmentedColormap('Cm', rgb_dict)
+
+    return colormap
+
+
+def plot_tun(means, sems, spons, xs, params=(None, None), c_fit=None, op_fit=None,
+             c_prefsz=None, op_prefsz=None, ax=None, fmt='.', mfc=None, ms=3.4, lw=1,
+             cs=('k', '#1090cfff'), sponline='-', sponlinewidth=0.5):
+
+    """
+    Plot tuning curve
+
+    Parameters
+    ----------
+    means: ndarray
+        responses to stimuli
+    sems: ndarray
+        standard error of the mean for responses
+    spons: ndarray
+        spontaneous firing rates
+    xs: ndarray
+        independent variable
+    params: ndarray
+        model parameter
+    c_fit: ndarray
+        model fit for control condition
+    op_fit: ndarray
+        model fit for suppression condition
+    c_prefsz: float
+        preferred size under control condition
+    op_prefsz: float
+        preferred size under V1 suppression
+    ax: axis
+        mpl axis
+    fmt: string
+        formatting information
+    mfc: string
+        markerfacecolor
+    ms: float
+        markersize
+    lw: float
+        linewidth
+    cs: tuple
+        colors
+    sponline: string
+        linestyle for horizontal line depicting spontaenous activity
+    sponlinewidth: float
+        line width for spon firing rate
+
+    Returns
+    -------
+    ax : mpl axis
+
+    """
+
+    # for convenience while iterating, swap dimensions from nvals x nconds to
+    # nconds x nvals
+    ys = means.T
+    yerrs = sems.T
+
+    # to iterate for plotting make sure that ys and yerrs are 2D
+    if ys.ndim == 1:
+        ys = np.expand_dims(ys, axis=0)
+        yerrs = np.expand_dims(yerrs, axis=0)
+        params = np.expand_dims(params, axis=0)
+
+    # plot data points and model:
+    modelxs = []
+    zorders = np.arange(1, len(ys) + 1, 1)
+    fits = [c_fit, op_fit]
+
+    for i, args in enumerate(zip(fits, spons, cs, ys, yerrs, zorders, params)):
+        fit, spon, c, y, yerr, zorder, param = args
+        ax.errorbar(xs, y, yerr=yerr, fmt=fmt, mfc=mfc, ms=ms, c=c, zorder=zorder, lw=lw)
+        if param is not None:
+            modelx = np.arange(76)
+            yfit = rog_offset(modelx, *param)
+            ax.plot(modelx, yfit, '-', c=c, zorder=zorder)
+        elif param is None:
+            modelx = np.arange(76)
+            ax.plot(modelx, fit, '-', c=c, zorder=zorder)
+        modelxs.append(modelx)
+        # xmin and xmax can differ between data and model:
+        xmin = min(xs.min(), modelx.min())
+        xmax = max(xs.max(), modelx.max())
+        # plot horizontal line for spontaneous activity
+        ax.plot([xmin, xmax], [spon, spon], linestyle=sponline, zorder=-1, marker='', c=c,
+                linewidth=sponlinewidth)
+    ax.set_ylim(bottom=0)
+    modelx = np.unique(np.concatenate(modelxs))  # superset of all model x vals
+
+    # calc xmin and xmax one last time, this time across data and all model points:
+    xmin = min(xs.min(), modelx.min())
+    xmax = max(xs.max(), modelx.max())
+    xmid = (xmax - xmin) / 2  # xmid isn't necessarily a data or model xval
+    ax.set_xticks([xmin, xmid, xmax])
+    ax.set_ylabel('Firing rate (sp/s)')
+    ax.set_xlabel('Diameter ($\degree$)')
+    ax.spines['bottom'].set_bounds(0, 75)
+    # hide the right and top spines:
+
+    if c_prefsz is not None:
+        # plot vertical bar for preferres size
+        ylims = ax.get_ylim()
+        ax.set_ylim((0, ylims[1]))
+        ax.plot([c_prefsz, c_prefsz], [0, ylims[1] / 10], zorder=-1, marker='', c=cs[0],
+                linestyle='-', linewidth=0.5)
+        ax.plot([op_prefsz, op_prefsz], [0, ylims[1] / 10], zorder=-1, marker='', c=cs[1],
+                linestyle='-', linewidth=0.5)
+
+    return ax
+
+
+def plot_raster(raster, tranges, opto=None, opto_ranges=None, offsets=(-0.15, 0.15),
+                s=1, l=1, color='k', optocolor='#1090cfff', barwidth=1, bardistance=30,
+                axisbg='w', ax=None):
+    """
+    Create trial raster plots, one per unit in an experiment. Allows specifying
+    multiple units as self.
+
+    Parameters
+    ----------
+    raster: ndarray
+        array of arrays holding spike times, one array per trial
+    tranges: list of tuples
+        stimulus time ranges
+    opto: ndarray
+        array with booleans, one for each trial indicating if control or V1 suppression
+        condition
+    opto_ranges: ndarray
+        opto on time ranges
+    offsets: tuple
+        Set time offsets relative to trial for spike event inclusion
+    s: scalar
+        marker width
+    l: float
+        linelength
+    color: str or tuple of RGB values
+        marker color
+    optocolor: str or tuple of RGB values
+        marker color for opto spikes
+    barwidth: scalar
+        width of event bar
+    bardistance: scalar
+        position of event bar relative to plot
+    axisbg: str
+        axis background color
+    ax: axis object
+        Optionally provide target axes to plot raster into.
+
+    Returns
+    -------
+    axs : axis
+        axis
+    """
+
+    # get event times for bars
+    # get stimulus duration
+    stim_durs = np.round((tranges[:, 1] - tranges[:, 0]), decimals=2)
+    stim_dur, stim_dur_counts = np.unique(stim_durs, return_counts=True)
+
+    if type(opto) is np.ndarray:
+        # get opto duration
+        opto_durs = np.round(opto_ranges[:, 1] - opto_ranges[:, 0], decimals=2)
+        opto_dur, opto_dur_counts = np.unique(opto_durs, return_counts=True)
+
+        # get opto start relative to stim onset
+        tranges_sort = np.sort(tranges[opto], axis=0)
+        opto_starts = np.round(opto_ranges[:, 0] - tranges_sort[:, 0], decimals=2)
+        opto_start, opto_start_counts = np.unique(opto_starts, return_counts=True)
+
+    sortedtrialis, c = [], []  # lists of sorted trial IDs (not original) and colors
+    # collect 1-based sorted trial info, one entry per spike.
+    # sortedtrialis is used as y coords in scatter plot.
+    for sortedtriali, trialspikes in enumerate(raster):
+        nspikes = len(trialspikes)
+        sortedtrialis.append(np.tile(sortedtriali + 1, nspikes))  # 1-based
+        if opto is not None and opto[sortedtriali]:
+            spikecolors = np.tile(optocolor, nspikes)
+        else:
+            spikecolors = np.tile(color, nspikes)  # black
+        c.append(spikecolors)
+
+    # convert each list of arrays to a single flat array:
+    raster = np.hstack(raster)
+    sortedtrialis = np.hstack(sortedtrialis)
+    c = np.hstack(c)
+
+    dtmax = tranges.ptp(axis=1).max()  # max tranges duration
+    xmin, xmax = offsets[0], dtmax + offsets[1]
+
+    raster = raster[:, np.newaxis]
+    ax.eventplot(raster, lineoffsets=sortedtrialis, colors=c, linewidth=s,
+                 linelengths=l, zorder=1)
+    ax.set_xlim(xmin, xmax)
+
+    line = mlines.Line2D([0, stim_dur], [-bardistance, -bardistance],
+                         linewidth=barwidth, color='k', clip_on=False)
+    ax.add_line(line)
+    if type(opto) is np.ndarray:
+        line = mlines.Line2D([opto_start, opto_dur + opto_start],
+                             [- bardistance * 2, -bardistance * 2], linewidth=barwidth,
+                             color=optocolor, clip_on=False)
+        ax.add_line(line)
+
+    # -1 inverts the y axis, +1 ensures last trial is fully visible:
+    ntrials = len(tranges)
+    ax.set_ylim(ntrials + 1, -1)
+
+    ax.set_facecolor(axisbg)
+    ax.set_xlabel('Time (s)')
+    if opto is None:
+        ax.set_ylabel('Trial\n(by diameter)')
+    else:
+        ax.set_ylabel('Trial\n(by diameter and condition)')
+
+    return ax
+
+
+def compute_stats(cont, supp, alys):
+    """Compute and print population statistics for scatter plots
+
+    Parameters
+    -------
+    cont: ndarray
+        values in control condition
+    supp: ndarray
+        values in condition in which V1 is suppressed
+    alys: string: 'ropt', 'rsupp', 'rfcs', 'si'
+        determines which parameter to analyze: modelled response to optimal stimulus (
+        ropt), modelled response to large stimulus (rsupp), preferred size (rfcs),
+        suppression index (si)
+
+    Returns
+    -------
+    cont_mean: float
+        mean for control condition
+    supp_mean: float
+        mean for V1 suppression condition
+    """
+
+    w_stat, pval = stats.wilcoxon(cont, supp)
+    nsamples = np.unique([len(cont), len(supp)])[0]
+    cont_mean = np.mean(cont)
+    cont_sem = stats.sem(cont)
+    supp_mean = np.mean(supp)
+    supp_sem = stats.sem(supp)
+
+    if alys == 'si' or 'bratio':
+        diff = np.mean(supp - cont)
+        diff_str = 'mean diff: np.mean(supp - cont)'
+        if alys == 'si':
+            print('%d/%d cells still have SI >= 0.1 during suppression' % (np.sum(supp >= 0.1),
+                                                                           len(supp)))
+    else:
+        diff = ((2 ** np.mean(np.log2(supp / cont))) - 1) * 100
+        diff_str = 'perc change: ((2**np.mean(np.log2(supp / cont))) - 1) * 100'
+    print('stats %s: \n'
+          'cont mean +- sem = %.3f +- %.3f \n'
+          'supp mean +- sem = %.3f +- %.3f \n'
+          'Wstat = %0.3f \n'
+          'p = 10**%0.3f \n'
+          'n = %d \n'
+          '%.3f %s\n'
+          % (alys, float(cont_mean), cont_sem, float(supp_mean), supp_sem, w_stat,
+             np.log10(pval), float(nsamples), diff, diff_str))
+
+    return cont_mean, supp_mean
+

+ 79 - 0
figs/wrap.py

@@ -0,0 +1,79 @@
+import fig2
+import fig3
+import sim_edog
+import fig4
+import fig5
+import figS4
+import figS6
+import figS8
+import figS9
+import figS10
+
+
+# panels for figure 2
+f2 = fig2.Fig2()
+f2.ex_sztun_curve()
+f2.fit_norm_curves()
+f2.scatter(alys='ropt')
+f2.scatter(alys='rsupp')
+f2.scatter(alys='rfcs')
+f2.scatter(alys='si')
+
+# panels for figure 3
+edog = sim_edog.SimEdog()
+edog.describe()
+f3 = fig3.Fig3(edog=edog)
+f3.plot()
+
+# panels for figure 4
+f4 = fig4.Fig4()
+f4.exrfs()
+f4.retino_ex()
+f4.trn_retino()
+f4.rf_area()
+f4.norm_szcurves()
+f4.ex_sztun_curve()
+f4.si()
+
+# panels for figure 5
+f5 = fig5.Fig5()
+f5.ex_sztun_curve()
+f5.fit_norm_curves()
+f5.scatter(alys='all_stims')
+f5.scatter(alys='bratio')
+f5.scatter(alys='rfcs')
+f5.scatter(alys='si')
+f5.resp_differences()
+f5.threshlin_ex()
+f5.threshlin()
+
+# panels for figure S4
+fS4 = figS4.FigS4()
+fS4.rfcs_x_depth()
+fS4.si_x_depth()
+
+# panels for figure S6
+fS6 = figS6.FigS6()
+fS6.ex_curves()
+
+# panels for figure S8
+fS8 = figS8.FigS8()
+fS8.ex_curves()
+fS8.si_hist()
+
+# panels for figure
+fS9 = figS9.FigS9()
+fS9.trnchange_x_props()
+
+# panels for figure S10
+fS10 = figS10.FigS10()
+fS10.drfstim_x_si()
+fS10.drfmon_x_si()
+fS10.trn_rfarea_x_si()
+fS10.trn_rfcs_x_si()
+fS10.lgn_rfarea_x_si()
+fS10.lgn_rfcs_x_si()
+
+
+
+