123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605 |
# Code for the Figure 4 panels.

# --- standard library ---
import os
import pickle
from importlib import reload

# --- third-party ---
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.ticker import PercentFormatter
import numpy as np
import pandas
import seaborn as sns
from scipy import stats

# --- project-local ---
import spatint_utils

# Re-execute spatint_utils so that edits to it are picked up when this
# module is re-run (e.g. from an interactive session).
reload(spatint_utils)

# Colours used for the visTRN and dLGN populations in the panels below
# (third value returned by get_colors is not needed here).
trn_red, lgn_green, _ = spatint_utils.get_colors()
# NOTE(review): presumably applies global matplotlib style settings - confirm
# in spatint_utils.plot_params.
spatint_utils.plot_params()
class Fig4:
    """Load the Fig. 4 datasets and plot the individual Fig. 4 panels.

    All datasets are read from pickles when the class is instantiated.
    Each public method draws one panel (or panel group) and prints the
    numbers reported alongside it.
    """

    # Coefficients of the ANCOVA model (depth ~ visual angle x dimension),
    # copied from the R fit. The dimension dummy is 1 for azimuth and 0 for
    # elevation.
    _ANCOVA_INTERCEPT = 3044.511585
    _ANCOVA_VIS_ANGLE = -1.191659
    _ANCOVA_AZIM = 60.728882
    _ANCOVA_INTERACTION = -1.838795

    def __init__(self, datadir=None):
        """Load all data needed for the Fig. 4 panels.

        Parameters
        ----------
        datadir: str, optional
            Directory holding the pickled data files. Defaults to
            ``<parent of the current working directory>/data``.

        Notes
        -----
        The files are unpickled, which can execute arbitrary code: only
        point ``datadir`` at trusted data.
        """
        if datadir is None:
            datadir = os.path.join(os.path.dirname(os.getcwd()), 'data')
        # trn size tuning population dataframe
        self.trn_sztun_df = pandas.read_pickle(
            filepath_or_buffer=os.path.join(datadir, 'trn_sztun_df.pkl'))
        # dict with the data of the trn size tuning example cell
        with open(os.path.join(datadir, 'trn_sztun_ex_dict.pkl'), 'rb') as f:
            self.trn_sztun_ex_dict = pickle.load(f)
        # trn retinotopy dataframe
        self.trn_retino_df = pandas.read_pickle(
            filepath_or_buffer=os.path.join(datadir, 'trn_retino_df.pkl'))
        # rows of the trn retinotopy example series (mouse BL6_2018_0003, series 7)
        self.trn_retino_ex = self.trn_retino_df[
            (self.trn_retino_df.m == 'BL6_2018_0003') &
            (self.trn_retino_df.s == 7)]
        # trn/lgn rf area dataframe
        self.rf_area_df = pandas.read_pickle(
            filepath_or_buffer=os.path.join(datadir, 'rf_area_df.pkl'))
        # lgn size tuning dataframe
        self.lgn_sztun_df = pandas.read_pickle(
            filepath_or_buffer=os.path.join(datadir, 'lgn_sztun_df.pkl'))

    @staticmethod
    def _select_unit(df, key):
        """Return the rows of ``df`` whose m, s, e and u columns equal ``key``.

        Parameters
        ----------
        df: pandas.DataFrame
            Dataframe with columns ``m``, ``s``, ``e`` and ``u``.
        key: dict
            Mapping with the keys ``'m'``, ``'s'``, ``'e'`` and ``'u'``.

        Returns
        -------
        pandas.DataFrame
            The matching rows (possibly empty).
        """
        mask = ((df.m == key['m']) & (df.s == key['s']) &
                (df.e == key['e']) & (df.u == key['u']))
        return df[mask]

    def exrfs(self):
        """Plot example receptive fields (Fig. 4d,h).

        The first five examples are visTRN units, the last three dLGN units.
        Only the last panel gets a scale bar.

        Returns
        -------
        axs: array of matplotlib axes
            One axis per example rf.
        """
        keys = [
            # trn example rfs
            {'m': 'BL6_2018_0003', 's': 7, 'e': 1, 'u': 25},
            {'m': 'BL6_2018_0003', 's': 4, 'e': 1, 'u': 9},
            {'m': 'PVCre_2018_0009', 's': 4, 'e': 10, 'u': 10},
            {'m': 'BL6_2018_0003', 's': 2, 'e': 1, 'u': 17},
            {'m': 'BL6_2018_0003', 's': 3, 'e': 1, 'u': 12},
            # dlGN example rfs
            {'m': 'Ntsr1Cre_2019_0003', 's': 4, 'e': 1, 'u': 7},
            {'m': 'Ntsr1Cre_2018_0003', 's': 2, 'e': 1, 'u': 23},
            {'m': 'Ntsr1Cre_2018_0003', 's': 2, 'e': 1, 'u': 29}]
        n_trn = 5  # index of the first dLGN example
        f, axs = plt.subplots(1, len(keys), figsize=(7, 1.5))
        for i, (ax, key) in enumerate(zip(axs, keys)):
            ex_row = self._select_unit(self.rf_area_df, key)
            if i == len(keys) - 1:
                scalebar = {'dist': 5, 'width': 1, 'length': 20, 'col': 'w'}
            else:
                scalebar = {'width': None, 'dist': None, 'length': None, 'col': None}
            # NOTE(review): grat_width/grat_height are passed as one-row Series
            # (unlike means/ti_axes) - kept as-is; confirm what plot_rf expects.
            spatint_utils.plot_rf(means=ex_row.mean_fr.values[0],
                                  stim_param_values=ex_row.ti_axes.values[0],
                                  grat_width=ex_row.grat_width,
                                  grat_height=ex_row.grat_height,
                                  ax=ax,
                                  scalebar=scalebar)
            # label the first panel of each region
            if i == 0:
                ax.set_ylabel('visTRN')
            elif i == n_trn:
                ax.set_ylabel('dLGN')
        f.tight_layout()
        # describe two example rfs
        self.desc_exrfs(keys)
        return axs

    def desc_exrfs(self, exrfs):
        """Print area and rsq of two example RFs for Fig 4d.

        Parameters
        ----------
        exrfs: list of dict
            Example unit keys; entries 3 and 4 are reported as the 'small'
            and 'large' example respectively.
        """
        for i, label in zip((3, 4), ('small', 'large')):
            ex_area = self._select_unit(self.rf_area_df, exrfs[i])
            # take the scalar explicitly: float() on a one-row Series is
            # deprecated in pandas
            print('%s area= %0.3f \n'
                  '%s rsq = %0.3f \n'
                  % (label, ex_area['area'].iloc[0],
                     label, ex_area['add_rsq'].iloc[0]))

    def retino_ex(self, figsize=(2.5, 2.5), axs=None):
        """Plot RF outlines of the example trn recording (Fig. 4e).

        Parameters
        ----------
        figsize: tuple
            Figure size (width, height); only used if ``axs`` is None.
        axs: list
            Two axes used for plot and colorbar.

        Returns
        -------
        axs: list
            Two axes with plot and colorbar.
        """
        if axs is None:
            f = plt.figure(figsize=figsize)
            bottom, height = 0.15, 0.6
            axs = [f.add_axes([0.15, bottom, 0.6, height]),  # rf outlines
                   f.add_axes([0.8, bottom, 0.05, height])]  # colorbar
        exseries = self.trn_retino_ex
        print('number of units in example series:', len(exseries))
        # one colour per unit, evenly spaced along the upper part of 'Reds'
        # (assigned in row order)
        vmin = 0.3
        vmax = 0.95
        my_colors = plt.get_cmap("Reds")(np.linspace(vmin, vmax, len(exseries)))
        for color, params in zip(my_colors, exseries['params']):
            # params[2] is converted from radians to degrees
            angle = params[2] * 180 / np.pi
            fitpars_deg = spatint_utils.degdiff(180, angle, 180)
            x, y = spatint_utils.calculate_ellipse(params[0], params[1],
                                                   params[4], params[5],
                                                   fitpars_deg)
            axs[0].plot(x, y, c=color)
        # layout
        axs[0].set_xlim(-20, 70)
        axs[0].set_ylim(-20, 70)
        axs[0].set_xticks((0, 30, 60))
        axs[0].set_yticks((0, 30, 60))
        axs[0].spines['bottom'].set_bounds(-10, 70)
        axs[0].spines['left'].set_bounds(-10, 70)
        axs[0].set_ylabel(r'Elevation ($\degree$)')
        axs[0].set_xlabel(r'Azimuth ($\degree$)')
        # colorbar built from exactly the colours used for the outlines
        sm = matplotlib.colors.LinearSegmentedColormap.from_list('Reds', my_colors)
        norm = matplotlib.colors.Normalize(vmin=0.15, vmax=1)
        cbar = matplotlib.colorbar.ColorbarBase(axs[1], cmap=sm, norm=norm)
        cbar.set_label('Depth (μm)', rotation=270)
        cbar.ax.invert_yaxis()
        # map depth ticks linearly onto the bar, assuming rows are ordered by
        # depth from first to last row.
        # NOTE(review): ticks are placed on the [vmin, vmax] = [0.3, 0.95]
        # scale while the bar's norm spans [0.15, 1]; the tick positions look
        # offset from the colour mapping - confirm before changing.
        barunits = (vmax - vmin) / (exseries.depth.iloc[-1] - exseries.depth.iloc[0])
        ticklabels = np.array([3050, 3100, 3150])
        ticks = ((ticklabels - exseries.depth.iloc[0]) * barunits) + vmin
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(ticklabels)
        axs[0].figure.tight_layout()
        return axs

    @classmethod
    def _ancova_depth(cls, x, azim_dummy):
        """Evaluate the ANCOVA depth model at visual angles ``x``.

        Parameters
        ----------
        x: array-like
            Visual angles (deg).
        azim_dummy: int
            1 for the azimuth line, 0 for the elevation line.

        Returns
        -------
        Predicted depth for every value in ``x``.
        """
        return (cls._ANCOVA_INTERCEPT + x * cls._ANCOVA_VIS_ANGLE
                + azim_dummy * cls._ANCOVA_AZIM
                + azim_dummy * x * cls._ANCOVA_INTERACTION)

    def trn_retino(self, figsize=(6, 2.5), axs=None):
        """Plot azimuth and elevation against depth for the trn population (Fig. 4f-g).

        Each panel scatters the ``azim`` / ``elev`` column of
        ``trn_retino_df`` against ``depth`` and overlays the ANCOVA fit.

        Parameters
        ----------
        figsize: tuple
            Figure size (width, height); only used if ``axs`` is None.
        axs: list
            Two axes, one per dimension (azim, elev).

        Returns
        -------
        axs: list
            Two axes.
        """
        if axs is None:
            f, axs = plt.subplots(1, 2, figsize=figsize)
        trn_retino_df = self.trn_retino_df
        # (column, model dummy, count message, x label) per panel
        panels = (('azim', 1, 'n azimuth = %d', r'Azimuth ($\degree$)'),
                  ('elev', 0, 'n elev = %d', r'Elevation ($\degree$)'))
        for ax, (col, dummy, n_msg, xlabel) in zip(axs, panels):
            vals = trn_retino_df[col]
            ax.scatter(vals, trn_retino_df.depth, facecolors='none',
                       edgecolors='k', linewidth=0.5)
            print(n_msg % len(vals))
            # regression line from the ANCOVA model, 1-degree steps
            xeval = np.arange(np.min(vals), np.max(vals), 1)
            ax.plot(xeval, self._ancova_depth(xeval, dummy), 'r')
            # layout
            ax.set_xticks((0, 30, 60))
            ax.set_yticks((2500, 2900, 3300))
            ax.invert_yaxis()
            ax.set_xlabel(xlabel)
            ax.spines['bottom'].set_bounds(-20, 70)
            ax.spines['left'].set_bounds(2500, 3300)
        # depth axis is only labelled on the left panel
        axs[0].set_ylabel('Depth (μm)')
        axs[1].set_ylabel('')
        axs[1].set_yticklabels([])
        axs[0].figure.tight_layout()
        return axs

    def rf_area(self, figsize=(2, 2), ax=None):
        """Create violin plot for comparison of TRN and LGN RF sizes (Fig. 4i).

        Also prints Brown-Forsythe (dispersion) and Mann-Whitney U (central
        tendency) statistics for the two groups.

        Parameters
        ----------
        figsize: tuple
            Figure size (width, height); only used if ``ax`` is None.
        ax: instance of matplotlib.axes class
            Axis to use for plotting.

        Returns
        -------
        ax: mpl axis
            Axis with plot.
        """
        if ax is None:
            f, ax = plt.subplots(figsize=figsize)
        rf_area_df = self.rf_area_df
        trn_area = rf_area_df['area'][rf_area_df['region'] == 'PGN'].array
        lgn_area = rf_area_df['area'][rf_area_df['region'] == 'LGN'].array
        # violins of log(area); y tick labels below show the linear values
        sns.violinplot(data=[np.log(trn_area), np.log(lgn_area)],
                       palette=[trn_red, lgn_green], ax=ax, linewidth=1,
                       inner=None)
        # mark log of each group's (linear) mean
        ax.plot([0, 1], [np.log(trn_area.mean()), np.log(lgn_area.mean())],
                linestyle='', c='k', marker='.')
        # format plot
        ylabels = np.array([10, 100, 1000])
        ax.set_yticks(np.log(ylabels))
        ax.set_yticklabels(ylabels)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ylims = ax.get_ylim()
        ax.spines['bottom'].set_bounds(0, 1)
        ax.spines['left'].set_bounds(ylims)
        # set the labels first, then colour the labels of *this* axis
        ax.set_xticklabels(['visTRN', 'dLGN'])
        xticklabels = ax.get_xticklabels()
        xticklabels[0].set_color(trn_red)
        xticklabels[1].set_color(lgn_green)
        ax.set_ylabel('RF area (deg$^2$)')
        ax.grid(False)
        # mannwhitneyUtest
        u_stat, parea = stats.mannwhitneyu(trn_area, lgn_area)
        # test for differences in variance
        # (with center = median, levene's test = brown-forsythe test)
        f_stat, pvar = stats.levene(trn_area, lgn_area, center='median')
        ratio = trn_area.mean() / lgn_area.mean()
        print('dispersion stats: Brown–Forsythe test\n'
              'Fstat = %0.3f \n'
              'pval = 10**%0.3f\n\n'
              'central tendency stats\n'
              'Ustat = %0.3f\n'
              'pval area = 10**%0.3f \n'
              'N area visTRN = %d \n'
              'N area dLGN = %d\n'
              'visTRN mean area +- sem = %0.3f (+- %0.3f)\n'
              'dLGN mean area +- sem = %0.3f (+- %0.3f)\n'
              'visTRN rfs are on average %0.3f x larger than dLGN rfs'
              % (f_stat,
                 np.log10(pvar),
                 u_stat,
                 np.log10(parea),
                 len(trn_area),
                 len(lgn_area),
                 trn_area.mean(),
                 stats.sem(trn_area),
                 lgn_area.mean(),
                 stats.sem(lgn_area),
                 ratio))
        ax.figure.tight_layout()
        return ax

    @staticmethod
    def _norm_curve(x_eval, params):
        """Evaluate ``spatint_utils.rog_offset`` and normalise the curve.

        The value at ``x_eval[0]`` is subtracted, and the result is divided
        by its NaN-ignoring maximum.
        """
        y = spatint_utils.rog_offset(x_eval, *params)
        y_sub = y - y[0]
        return y_sub / np.nanmax(y_sub)

    def norm_szcurves(self, figsize=(3, 3), ax=None, eval_range=76, thres=128, mark_ex=True,
                      lw=0.5, xticks=(0, 25, 50, 75), colormap='Greys'):
        """Plot normalized fitted size tuning curves for visTRN population (Fig. 4l).

        Parameters
        ----------
        figsize: tuple
            Figure size (width, height); only used if ``ax`` is None.
        ax: instance of matplotlib.axes class
            Axis to use for plotting.
        eval_range: int
            Curves are evaluated at ``range(eval_range)``.
        thres: int
            Lower bound for the colormap lookup-table index of each line.
        mark_ex: bool
            If true plots the example neuron in trn red.
        lw: float
            Linewidth.
        xticks: tuple
            xticks.
        colormap: string
            Name of the matplotlib colormap for the population curves.

        Returns
        -------
        ax: mpl axis
            Axis with plot.
        """
        if ax is None:
            f, ax = plt.subplots(figsize=figsize)
        x_eval = range(eval_range)
        # plt.get_cmap works on all matplotlib versions (plt.cm.get_cmap was
        # removed in matplotlib 3.9)
        cmap = plt.get_cmap(colormap)
        for row in self.trn_sztun_df.itertuples():
            # integer LUT index that decreases with the suppression index,
            # floored at `thres`
            col = max(int(np.round((1 - row.si_76) * 255)), thres)
            ax.plot(x_eval, self._norm_curve(x_eval, row.tun_pars),
                    c=cmap(col), lw=lw)
        if mark_ex:
            ax.plot(x_eval,
                    self._norm_curve(x_eval, self.trn_sztun_ex_dict['tun_pars']),
                    c=trn_red, lw=lw)
        # layout
        ax.set_xlabel(r'Diameter ($\degree$)')
        ax.set_ylabel('Normalized firing rate')
        ax.set_xticks(xticks)
        ax.set_yticks((0, 0.5, 1))
        # spine spans the evaluated diameters (0-75 for the default range)
        ax.spines['bottom'].set_bounds(0, eval_range - 1)
        ax.spines['left'].set_bounds(0, 1)
        return ax

    def ex_sztun_curve(self, figsize=(4, 2), axs=None):
        """Plot example visTRN size-tuning raster and curve (Fig. 4jk).

        Also prints the preferred size and SI of the example cell.

        Parameters
        ----------
        figsize: tuple
            Figure size (width, height); only used if ``axs`` is None.
        axs: sequence of two matplotlib axes
            Axes for the raster (first) and the tuning curve (second).

        Returns
        -------
        axs: sequence of matplotlib axes
            The two axes with the plots.
        """
        if axs is None:
            f, axs = plt.subplots(1, 2, figsize=figsize)
        ex_data = self.trn_sztun_ex_dict
        # raster of the example unit
        spatint_utils.plot_raster(raster=ex_data['rasters'][ex_data['u']],
                                  tranges=ex_data['tranges'],
                                  opto=ex_data['opto'],
                                  ax=axs[0])
        # tuning curve with fit
        spatint_utils.plot_tun(means=ex_data['tun_mean'],
                               sems=ex_data['tun_sem'],
                               spons=ex_data['tun_spon_mean'],
                               xs=ex_data['ti_axes'],
                               params=ex_data['tun_pars'],
                               ax=axs[1])
        # format layout
        axs[1].set_xticks((0, 25, 50, 75))
        axs[1].set_yticks((0, 10, 20))
        axs[1].spines['bottom'].set_bounds(0, 75)
        axs[0].figure.tight_layout()
        print('size tuning example cell: \n'
              'Preferred size = %0.3f \n'
              'SI = %0.3f'
              % (ex_data['rfcs_76'], ex_data['si_76']))
        return axs

    def si(self, figsize=(2, 2), ax=None):
        """Plot histogram of suppression indices for visTRN and dLGN population (Fig. 4m).

        Also prints summary statistics per population and a Mann-Whitney U
        comparison between them.

        Parameters
        ----------
        figsize: tuple
            Figure size (width, height); only used if ``ax`` is None.
        ax: instance of matplotlib.axes class
            Axis to use for plotting.

        Returns
        -------
        ax: mpl axis
            Axis with plot.

        Raises
        ------
        AssertionError
            If any suppression index lies outside [0, 1].
        """
        if ax is None:
            f, ax = plt.subplots(figsize=figsize)
        trn_sis = self.trn_sztun_df.si_76.to_numpy()
        # lgn si_76 entries are sequences; the first element is used
        lgn_sis = self.lgn_sztun_df.si_76.str[0].to_numpy()
        # sanity check: histogram bins only cover [0, 1]
        assert not (np.any(trn_sis < 0) or np.any(trn_sis > 1)), 'Datapoints beyond bounds'
        assert not (np.any(lgn_sis < 0) or np.any(lgn_sis > 1)), 'Datapoints beyond bounds'
        # plot fraction of cells per bin, semi-transparent so overlap shows
        si_bins = np.arange(0, 1.05, 0.05)
        trn_red_trsp = (*trn_red[0:3], 0.5)
        lgn_green_trsp = (*lgn_green[0:3], 0.5)
        ax.hist(trn_sis, bins=si_bins, weights=np.ones(len(trn_sis)) / len(trn_sis), lw=0,
                fc=trn_red_trsp)
        ax.hist(lgn_sis, bins=si_bins, weights=np.ones(len(lgn_sis)) / len(lgn_sis), lw=0,
                fc=lgn_green_trsp)
        # layout
        ax.yaxis.set_major_formatter(PercentFormatter(1))
        ax.set_xlabel('Suppression index')
        ax.set_ylabel('neurons')
        ax.set_xticks((0, 0.5, 1))
        ax.set_yticks((0, 0.25, 0.5))
        ax.spines['bottom'].set_bounds(0, 1)
        ax.set_xlim((0, 1))
        ax.figure.tight_layout()
        # stats for trn
        n_trn = len(trn_sis)
        si_mean_trn = np.mean(trn_sis)
        si_sem_trn = stats.sem(trn_sis)
        si_med_trn = np.median(trn_sis)
        # percentage of cells with si smaller than 0.05
        lower05_trn = len(trn_sis[trn_sis < 0.05]) / len(trn_sis) * 100
        print('trn size tuning population: \n'
              'n = %d\n'
              'mean si +/- sem = %0.3f (+- %0.3f) \n'
              'median si = %0.3f \n'
              '%0.3f percent of pgn cells have si < 0.05\n'
              % (n_trn,
                 si_mean_trn,
                 si_sem_trn,
                 si_med_trn,
                 lower05_trn))
        # stats for dlgn
        n_lgn = len(lgn_sis)
        si_mean_lgn = np.mean(lgn_sis)
        si_sem_lgn = stats.sem(lgn_sis)
        si_med_lgn = np.median(lgn_sis)
        print('lgn size tuning population:\n'
              'n = %d\n'
              'mean si +/- sem = %0.3f (+- %0.3f)\n'
              'median si = %0.3f\n'
              % (n_lgn,
                 si_mean_lgn,
                 si_sem_lgn,
                 si_med_lgn))
        # compare the two populations
        ustat_sis, p_sis = stats.mannwhitneyu(trn_sis, lgn_sis)
        print('mannwhitneyu test to compare si in dlgn and trn:\n'
              'Ustat: %0.3f\n'
              'pvalue: 10**%0.3f\n'
              % (ustat_sis, np.log10(p_sis)))
        return ax
|