"""Figure 2 plots, use run -i fig2.py""" fmodestxt = list(fmode2txt.values()) # 'all', 'nonburst', 'burst', 'nonrand' mi = pd.MultiIndex.from_product([mvimseustrs, fmodestxt], names=['mseu', 'fmode']) fig2 = pd.DataFrame(index=mi, columns=['slope', 'thresh', 'rsq', 'suppression_meanburstratio', 'suppression_snr']) # convert multi index to columns for easier to read restriction syntax: sampfitsr = sampfits.reset_index() mvirespr = mviresp.reset_index() linmin = 0 # plot single-sample V1 model-fit PSTH overtop of test feedback and suppression PSTHs: dt = 5 # stimulus duration sec #tmin, tmax = 0 + OFFSETS[0], dt + OFFSETS[1] tmin, tmax = 0, dt #figsize = 0.855 + tmax*1.5*RASTERSCALEX, PSTHHEIGHT psthfigsize = 0.855 + tmax*2.1*RASTERSCALEX, PSTHHEIGHT # medians: # n1: rsq=0.29, slope=0.37, thr=-0.14 # n2: rsq=0.90, slope=0.67, thr=1.58 #for seedi in [110, 199, 339, 388]: # seeds that result in rsq identical to median np.random.seed(199) # fix random seed for identical results from np.random.choice() on each run for mseustr, exmpli in mvimseu2exmpli.items(): #mvimseustrs: for kind in ['nat']:#MVIKINDS: for st8 in ['none']:#ALLST8S: # get one sample of trials, separate into fit and test data: ctrlraster = mviresp.loc[mseustr, kind, st8, False]['raster'] # narrow optoraster = mviresp.loc[mseustr, kind, st8, True]['raster'] # narrow optoburstraster = mviresp.loc[mseustr, kind, st8, True]['braster'] # narrow nctrltrials = len(ctrlraster) # typically 200 trials when st8 == 'none' noptotrials = len(optoraster) # typically 200 trials when st8 == 'none' for ntrials in [nctrltrials, noptotrials]: if ntrials < MINNTRIALSAMPLETHRESH: print('Not enough trials to sample:', mseustr, kind, st8) continue # don't bother sampling, leave entry in sampfits empty ctrltrialis = np.arange(nctrltrials) optotrialis = np.arange(noptotrials) ctrlsamplesize = intround(nctrltrials / 2) # half for fitting, half for testing optosamplesize = intround(noptotrials / 2) # half for fitting, half for testing # probably doesn't matter if use ctrl or opto bins, should be the same: bins = mviresp.loc[mseustr, kind, st8, False]['bins'] t = mviresp.loc[mseustr, kind, st8, False]['t'] # mid bins # randomly sample half the ctrl and opto trials, without replacement: ctrlfitis = np.sort(choice(ctrltrialis, size=ctrlsamplesize, replace=False)) optofitis = np.sort(choice(optotrialis, size=optosamplesize, replace=False)) ctrltestis = np.setdiff1d(ctrltrialis, ctrlfitis) # get the complement optotestis = np.setdiff1d(optotrialis, optofitis) # get the complement ctrlfitraster, ctrltestraster = ctrlraster[ctrlfitis], ctrlraster[ctrltestis] optofitraster, optotestraster = optoraster[optofitis], optoraster[optotestis] optotestburstraster = optoburstraster[optotestis] # calculate fit and test opto PSTHs, subsampled in time: ctrlfitpsth = raster2psth(ctrlfitraster, bins, binw, tres, kernel)[::ssx] optofitpsth = raster2psth(optofitraster, bins, binw, tres, kernel)[::ssx] ctrltestpsth = raster2psth(ctrltestraster, bins, binw, tres, kernel)[::ssx] optotestpsth = raster2psth(optotestraster, bins, binw, tres, kernel)[::ssx] optotestburstpsth = raster2psth(optotestburstraster, bins, binw, tres, kernel)[::ssx] mm, b, rsq = fitmodel(ctrlfitpsth, optofitpsth, ctrltestpsth, optotestpsth, model=model) #rsqstr = '%.2f' % np.round(rsq, 2) #exmpli2rsqstr = {1:'0.29', 2:'0.90'} #if rsqstr != exmpli2rsqstr[exmpli]: # continue #print(seedi, exmpli) th = -b / mm # threshold, i.e., x intercept ## plot the sample's PSTHs: f, a = plt.subplots(figsize=psthfigsize) title = '%s %s %s PSTH %s model' % (mseustr, kind, st8, model) wintitle(title, f) color = st82clr[st8] tss = t[::ssx] # subsampled in time # plot control and opto test PSTHs: opto2psth = {False:ctrltestpsth, True:optotestpsth} for opto in OPTOS: psth = opto2psth[opto] c = desat(color, opto2alpha[opto]) # do manual alpha mixing fb = opto2fb[opto].title() if opto == False: ls = '--' lw = 1 else: ls = '-' lw = 2 br = mviresp.loc[mseustr, kind, st8, opto]['meanburstratio'] print('%s exampli=%d %s %s meanburstratio=%g' % (mseustr, exmpli, st8, opto, br)) a.plot(tss, psth, ls=ls, lw=lw, marker='None', color=c, label=fb) # plot opto model: modeltestpsth = mm * ctrltestpsth + b a.plot(tss, modeltestpsth, '-', lw=2, color=optoblue, alpha=1, label='Suppression model') # plot opto burst test PSTH: bc = desat(burstclr, 0.3) # do manual alpha mixing #bc = desat(burstclr, opto2alpha[opto]) # do manual alpha mixing a.plot(tss, optotestburstpsth, ls=ls, lw=lw, marker='None', color=bc, label=fb+' burst', zorder=-99) l = a.legend(frameon=False) l.set_draggable(True) a.set_xlabel('Time (s)') a.set_ylabel('Firing rate (spk/s)') a.set_xlim(tmin, tmax) # plot horizontal line signifying stimulus period, just below data: #ymin, ymax = a.get_ylim() #a.hlines(-ymax*0.025, 0, dt, colors='k', lw=4, clip_on=False, in_layout=False) #a.set_ylim(ymin, ymax) # restore y limits from before plotting the hline ## scatter plot suppression vs. feedback for the sample's PSTH timepoints, ## plus model fit line: scale = 'linear' # log or linear logmin, logmax = -2, 2 log0min = logmin + 0.05 # for nat none all-spike plots for fig2: examplemaxrate = {'PVCre_2018_0003_s03_e03_u51':35, 'PVCre_2017_0008_s12_e06_u56':140} if kind == 'nat' and st8 == 'none': linmax = examplemaxrate[mseustr] else: linmax = np.vstack([ctrltestpsth, optotestpsth]).max() figsize = DEFAULTFIGURESIZE if linmax >= 100: figsize = figsize[0]*1.025, figsize[1] # tweak to make space extra y axis digit f, a = plt.subplots(figsize=figsize) titlestr = ('%s %s %s PSTH scatter %s' % (mseustr, kind, st8, model)) wintitle(titlestr, f) if scale == 'log': # replace off-scale low values with log0min, so the points remain visible: ctrltestpsth[ctrltestpsth <= 10**logmin] = 10**log0min optotestpsth[optotestpsth <= 10**logmin] = 10**log0min # plot y=x line: if scale == 'log': xyline = [10**logmin, 10**logmax], [10**logmin, 10**logmax] else: xyline = [linmin, linmax], [linmin, linmax] a.plot(xyline[0], xyline[1], '--', clip_on=True, color='gray', zorder=100) # plot testpsth vs control: texts = [] c = st82clr[st8] a.plot(ctrltestpsth, optotestpsth, '.', clip_on=True, ms=2, color=c, label=st8) x = np.array([ctrltestpsth.min(), ctrltestpsth.max()]) y = mm * x + b # model output line a.plot(x, y, '-', color=optoblue, clip_on=True, alpha=1, lw=2) # plot model fit # display the sample's parameters: txt = ('$\mathregular{R^{2}=%.2f}$\n' '$\mathregular{Slope=%.2f}$\n' '$\mathregular{Thr.=%.1f}$' % (rsq, mm, th)) a.add_artist(AnchoredText(txt, loc='upper left', frameon=False)) a.set_xlabel('Feedback rate (spk/s)') a.set_ylabel('Suppression rate (spk/s)') if scale == 'log': a.set_xscale('log') a.set_yscale('log') xymin, xymax = 10**logmin, 10**logmax ticks = 10**(np.arange(logmin, logmax+1, dtype=float)) else: xymin, xymax = linmin, linmax ticks = a.get_xticks() a.set_xticks(ticks) # seem to need to set both x and y to force them to match a.set_yticks(ticks) a.set_xlim(xymin, xymax) a.set_ylim(xymin, xymax) a.set_aspect('equal') a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4)) # scatter plot resampled model slope vs threshold, for only those fits with decent SNR # in at least one opto state: figsize = DEFAULTFIGURESIZE slopelogmin, slopelogmax = -4, 1 logyticks = np.logspace(slopelogmin, slopelogmax, num=(slopelogmax-slopelogmin+1), base=2) threshmin, threshmax = -25, 25 for kind in ['nat']:#MVIKINDS: for st8 in ['none']:#ALLST8S: for fmode in fmodes: # iterate over firing modes (all, non-burst, burst) print('fmode:', fmode) # boolean pandas Series: rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8) rows = sampfitsr[rowis] mmeds, thmeds, rsqmeds = [], [], [] # median fit values brs, snrs = [], [] # single burst ratio and SNR values matching each mseu sgnfis, insgnfis, exmplis, exmplmseustrs, normlis = [], [], [], [], [] keptmseui = 0 # manually init and increment instead of using enumerate() for i, row in rows.iterrows(): # for each mseu mseustr = row['mseu'] ms, ths, rsqs = row[fmode+'ms'], row[fmode+'ths'], row[fmode+'rsqs'] if np.isnan([ms, ths, rsqs]).any(): continue # skip this mseu # skip mseus that had no meaningful signal to fit in either condition: maxsnr = get_max_snr(mvirespr, mseustr, kind, st8) if maxsnr < SNRTHRESH: continue mmed, ml, mh = percentile_ci(ms) thmed, thl, thh = percentile_ci(ths) rsqmed = np.median(rsqs) if rsqmed < 0: # catastrophic fit failure print('catastrophic fit failure:', mseustr) continue # skip mmeds.append(mmed) thmeds.append(thmed) rsqmeds.append(rsqmed) # collect suppression mean burst ratio and SNR: mvirowis = ((mvirespr['mseu'] == mseustr) & (mvirespr['kind'] == kind) & (mvirespr['st8'] == st8) & (mvirespr['opto'] == True)) # only for suppression mvirow = mvirespr[mvirowis] assert len(mvirow) == 1 br = mvirow['meanburstratio'].iloc[0] snr = mvirow['snr'].iloc[0] brs.append(br) snrs.append(snr) fig2['slope'][mseustr, fmode2txt[fmode]] = mmed fig2['thresh'][mseustr, fmode2txt[fmode]] = thmed fig2['rsq'][mseustr, fmode2txt[fmode]] = rsqmed fig2['suppression_meanburstratio'][mseustr, fmode2txt[fmode]] = br fig2['suppression_snr'][mseustr, fmode2txt[fmode]] = snr # collect separate lists of significant and insignificant points: if not (ml < 1 < mh):# and not (thl < 0 < thh): sgnfis.append(keptmseui) # slope significantly different from 1 else: insgnfis.append(keptmseui) # slope not significantly different from 1 if mseustr in mvimseu2exmpli: exmplis.append(keptmseui) exmplmseustrs.append(mseustr) print('Example Neuron', mvimseu2exmpli[mseustr], ':', mseustr) print('mmed, ml, mh:', mmed, ml, mh) print('thmed, thl, thh:', thmed, thl, thh) else: normlis.append(keptmseui) keptmseui += 1 # manually increment mmeds = np.asarray(mmeds) thmeds = np.asarray(thmeds) rsqmeds = np.asarray(rsqmeds) brs = np.asarray(brs) snrs = np.asarray(snrs) ## plot median fit slope vs thresh: f, a = plt.subplots(figsize=(figsize[0]*1.11, figsize[1])) # extra space for ylabels titlestr = ('PSTH %s fit slope vs thresh %s %s snrthresh=%.3f %s' % (model, kind, st8, SNRTHRESH, fmode2txt[fmode])) wintitle(titlestr, f) normlinsgnfis = np.intersect1d(normlis, insgnfis) normlsgnfis = np.intersect1d(normlis, sgnfis) # plot x=0 and y=1 lines: a.axvline(x=0, ls='--', marker='', color='gray', zorder=-np.inf) a.axhline(y=1, ls='--', marker='', color='gray', zorder=-np.inf) # plot normal (non-example) insignificant points: c = desat(st82clr[st8], SGNF2ALPHA[False]) # do manual alpha mixing a.scatter(thmeds[normlinsgnfis], mmeds[normlinsgnfis], clip_on=False, marker='.', c='None', edgecolor=c, s=DEFSZ) # plot normal (non-example) significant points: c = desat(st82clr[st8], SGNF2ALPHA[True]) # do manual alpha mixing a.scatter(thmeds[normlsgnfis], mmeds[normlsgnfis], clip_on=False, marker='.', c='None', edgecolor=c, s=DEFSZ) exmplinsgnfis = np.intersect1d(exmplis, insgnfis) exmplsgnfis = np.intersect1d(exmplis, sgnfis) print('exmplinsgnfis', exmplinsgnfis) print('examplsgnfis', exmplsgnfis) # plot insignificant and significant example points, one at a time: for exmpli, mseustr in zip(exmplis, exmplmseustrs): if exmpli in exmplinsgnfis: alpha = SGNF2ALPHA[False] elif exmpli in exmplsgnfis: alpha = SGNF2ALPHA[True] else: raise RuntimeError("Some kind of exmpli set membership error") marker = exmpli2mrk[mvimseu2exmpli[mseustr]] c = exmpli2clr[mvimseu2exmpli[mseustr]] sz = exmpli2sz[mvimseu2exmpli[mseustr]] lw = exmpli2lw[mvimseu2exmpli[mseustr]] a.scatter(thmeds[exmpli], mmeds[exmpli], clip_on=False, marker=marker, c=c, s=sz, lw=lw, alpha=alpha) # plot mean median point: if fmode == '': # all spikes, plot LMM mean print('plotting LMM mean for fmode=%s' % fmode) a.scatter(-0.19, 0.75, # read off of stats/fig2*.pdf c='red', edgecolor='red', s=50, marker='^') elif fmode == 'nb': # non burst spikes, plot LMM mean print('plotting LMM mean for fmode=%s' % fmode) a.scatter(0.09, 0.74, # read off of stats/fig2*.pdf c='red', edgecolor='red', s=50, marker='^') else: a.scatter(np.mean(thmeds), gmean(mmeds), c='red', edgecolor='red', s=50, marker='^') # display median of median rsq sample values: txt = '$\mathregular{R^{2}_{med}=}$%.2f' % np.round(np.median(rsqmeds), 2) a.add_artist(AnchoredText(txt, loc='lower right', frameon=False)) #cbar = f.colorbar(path) #cbar.ax.set_xlabel('$\mathregular{R^{2}}$') #cbar.ax.xaxis.set_label_position('top') a.set_yscale('log', basey=2) a.set_xlabel('Threshold') a.set_ylabel('Slope') a.set_xlim(threshmin, threshmax) a.set_xticks([threshmin, 0, threshmax]) a.set_ylim(2**slopelogmin, 2**slopelogmax) a.set_yticks(logyticks) axes_disable_scientific(a) a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4)) ## plot fit rsq vs suppression meanburstratio: f, a = plt.subplots(figsize=figsize) titlestr = ('PSTH %s fit rsq vs suppression meanburstratio %s %s %s' % (model, kind, st8, fmode2txt[fmode])) wintitle(titlestr, f) a.scatter(brs[normlis], rsqmeds[normlis], clip_on=False, marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ) # plot example points, one at a time: for exmpli, mseustr in zip(exmplis, exmplmseustrs): marker = exmpli2mrk[mvimseu2exmpli[mseustr]] c = exmpli2clr[mvimseu2exmpli[mseustr]] sz = exmpli2sz[mvimseu2exmpli[mseustr]] lw = exmpli2lw[mvimseu2exmpli[mseustr]] a.scatter(brs[exmpli], rsqmeds[exmpli], clip_on=False, marker=marker, c=c, s=sz, lw=lw) # get fname of appropriate LMM .cvs file: if fmode == '': # use Steffen's LMM linregress fit, for fig2f fname = os.path.join('stats', 'figure_2f_coefs.csv') # fetch LMM linregress fit params from .csv: df = pd.read_csv(fname) mm = df['slope'][0] b = df['intercept'][0] x = np.array([np.nanmin(brs), np.nanmax(brs)]) y = mm * x + b a.plot(x, y, '-', color='red') # plot linregress fit a.set_xlabel('Suppression BR') a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title()) a.set_xlim(xmin=0) a.set_ylim(0, 1) a.set_yticks([0, 0.5, 1]) a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4)) ## plot fit rsq vs suppression SNR: f, a = plt.subplots(figsize=figsize) titlestr = ('PSTH %s fit rsq vs SNR %s %s %s' % (model, kind, st8, fmode2txt[fmode])) wintitle(titlestr, f) a.scatter(snrs[normlis], rsqmeds[normlis], clip_on=False, marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ) # plot example points, one at a time: for exmpli, mseustr in zip(exmplis, exmplmseustrs): marker = exmpli2mrk[mvimseu2exmpli[mseustr]] c = exmpli2clr[mvimseu2exmpli[mseustr]] sz = exmpli2sz[mvimseu2exmpli[mseustr]] lw = exmpli2lw[mvimseu2exmpli[mseustr]] a.scatter(snrs[exmpli], rsqmeds[exmpli], clip_on=False, marker=marker, c=c, s=sz, lw=lw) a.set_xlabel('Suppression SNR') a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title()) a.set_xlim(xmin=0) a.set_ylim(0, 1) a.set_yticks([0, 0.5, 1]) a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4)) ## plot suppression fit SNR vs meanburstratio: f, a = plt.subplots(figsize=figsize) titlestr = ('PSTH %s fit SNR vs meanburstratio %s %s %s' % (model, kind, st8, fmode2txt[fmode])) wintitle(titlestr, f) a.scatter(brs[normlis], snrs[normlis], clip_on=False, marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ) # plot example points, one at a time: for exmpli, mseustr in zip(exmplis, exmplmseustrs): marker = exmpli2mrk[mvimseu2exmpli[mseustr]] c = exmpli2clr[mvimseu2exmpli[mseustr]] sz = exmpli2sz[mvimseu2exmpli[mseustr]] lw = exmpli2lw[mvimseu2exmpli[mseustr]] a.scatter(brs[exmpli], snrs[exmpli], clip_on=False, marker=marker, c=c, s=sz, lw=lw) a.set_xlabel('Suppression BR') a.set_ylabel('Suppression SNR') a.set_xlim(xmin=0) a.set_ylim(ymin=0) a.set_yticks([0, 0.5, 1]) a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4)) # scatter plot resampled model fit rsq for nonburst vs all spikes, and nonrand vs all spikes: figsize = DEFAULTFIGURESIZE for kind in ['nat']:#MVIKINDS: for st8 in ['none']:#ALLST8S: # boolean pandas Series: rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8) rows = sampfitsr[rowis] allrsqmeds, nbrsqmeds, nrrsqmeds = [], [], [] # median fit rsq values exmplis, exmplmseustrs, normlis = [], [], [] keptmseui = 0 # manually init and increment instead of using enumerate() for i, row in rows.iterrows(): # for each mseu mseustr = row['mseu'] allrsqs, nbrsqs, nrrsqs = row['rsqs'], row['nbrsqs'], row['nrrsqs'] if np.isnan([allrsqs, nbrsqs, nrrsqs]).any(): continue # skip this mseu # skip mseus that had no meaningful signal to fit in either condition: maxsnr = get_max_snr(mvirespr, mseustr, kind, st8) if maxsnr < SNRTHRESH: continue allrsqmed = np.median(allrsqs) nbrsqmed = np.median(nbrsqs) nrrsqmed = np.median(nrrsqs) if allrsqmed < 0 or nbrsqmed < 0 or nrrsqmed < 0: # catastrophic fit failure print('catastrophic fit failure:', mseustr) continue # skip allrsqmeds.append(allrsqmed) nbrsqmeds.append(nbrsqmed) nrrsqmeds.append(nrrsqmed) if mseustr in mvimseu2exmpli: exmplis.append(keptmseui) exmplmseustrs.append(mseustr) else: normlis.append(keptmseui) if nbrsqmed > 0.4 and allrsqmed < 0.4: print(mseustr, allrsqmed, nbrsqmed) keptmseui += 1 # manually increment allrsqmeds = np.asarray(allrsqmeds) nbrsqmeds = np.asarray(nbrsqmeds) nrrsqmeds = np.asarray(nrrsqmeds) # plot nonburst vs all rsq: f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels titlestr = 'PSTH %s fit rsq nonburst vs all %s %s' % (model, kind, st8) wintitle(titlestr, f) linmax = np.vstack([allrsqmeds, nbrsqmeds]).max() xyline = [linmin, linmax], [linmin, linmax] a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf) a.scatter(allrsqmeds[normlis], nbrsqmeds[normlis], clip_on=False, marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ) # plot example points, one at a time: for exmpli, mseustr in zip(exmplis, exmplmseustrs): marker = exmpli2mrk[mvimseu2exmpli[mseustr]] c = exmpli2clr[mvimseu2exmpli[mseustr]] sz = exmpli2sz[mvimseu2exmpli[mseustr]] lw = exmpli2lw[mvimseu2exmpli[mseustr]] a.scatter(allrsqmeds[exmpli], nbrsqmeds[exmpli], clip_on=False, marker=marker, c=c, s=sz, lw=lw) a.set_xlabel('All spikes $\mathregular{R^{2}}$') a.set_ylabel('Non-burst spikes $\mathregular{R^{2}}$') a.set_xlim(linmin, 1) a.set_ylim(linmin, 1) a.set_xticks([0, 0.5, 1]) a.set_yticks([0, 0.5, 1]) a.set_aspect('equal') a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4)) # plot nr vs all rsq: f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels titlestr = 'PSTH %s fit rsq nr vs all %s %s' % (model, kind, st8) wintitle(titlestr, f) linmax = np.vstack([allrsqmeds, nrrsqmeds]).max() xyline = [linmin, linmax], [linmin, linmax] a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf) a.scatter(allrsqmeds[normlis], nrrsqmeds[normlis], clip_on=False, marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ) # plot example points, one at a time: for exmpli, mseustr in zip(exmplis, exmplmseustrs): marker = exmpli2mrk[mvimseu2exmpli[mseustr]] c = exmpli2clr[mvimseu2exmpli[mseustr]] sz = exmpli2sz[mvimseu2exmpli[mseustr]] lw = exmpli2lw[mvimseu2exmpli[mseustr]] a.scatter(allrsqmeds[exmpli], nrrsqmeds[exmpli], clip_on=False, marker=marker, c=c, s=sz, lw=lw) a.set_xlabel('All spikes $\mathregular{R^{2}}$') a.set_ylabel('Rand. rem. spikes $\mathregular{R^{2}}$') a.set_xlim(linmin, 1) a.set_ylim(linmin, 1) a.set_xticks([0, 0.5, 1]) a.set_yticks([0, 0.5, 1]) a.set_aspect('equal') a.spines['left'].set_position(('outward', 4)) a.spines['bottom'].set_position(('outward', 4))