123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468 |
- """Figure 2 plots, use run -i fig2.py"""
- fmodestxt = list(fmode2txt.values()) # 'all', 'nonburst', 'burst', 'nonrand'
- mi = pd.MultiIndex.from_product([mvimseustrs, fmodestxt], names=['mseu', 'fmode'])
- fig2 = pd.DataFrame(index=mi, columns=['slope', 'thresh', 'rsq',
- 'suppression_meanburstratio', 'suppression_snr'])
- # convert multi index to columns for easier to read restriction syntax:
- sampfitsr = sampfits.reset_index()
- mvirespr = mviresp.reset_index()
- linmin = 0
# plot single-sample V1 model-fit PSTH overtop of test feedback and suppression PSTHs:
dt = 5 # stimulus duration sec
tmin, tmax = 0, dt
psthfigsize = 0.855 + tmax*2.1*RASTERSCALEX, PSTHHEIGHT
# medians:
# n1: rsq=0.29, slope=0.37, thr=-0.14
# n2: rsq=0.90, slope=0.67, thr=1.58
# Seeds 110, 199, 339 and 388 all give single-sample rsq values identical (to 2 decimal
# places) to the median rsq values above; 199 is used:
np.random.seed(199) # fix random seed for identical results from np.random.choice() on each run
for mseustr, exmpli in mvimseu2exmpli.items(): #mvimseustrs:
    for kind in ['nat']:#MVIKINDS:
        for st8 in ['none']:#ALLST8S:
            # get one sample of trials, separate into fit and test data:
            ctrlraster = mviresp.loc[mseustr, kind, st8, False]['raster'] # narrow
            optoraster = mviresp.loc[mseustr, kind, st8, True]['raster'] # narrow
            optoburstraster = mviresp.loc[mseustr, kind, st8, True]['braster'] # narrow
            nctrltrials = len(ctrlraster) # typically 200 trials when st8 == 'none'
            noptotrials = len(optoraster) # typically 200 trials when st8 == 'none'
            # Skip this neuron/condition if either opto state has too few trials. The
            # `continue` has to act on this st8 loop. Inside a separate
            # `for ntrials in ...` loop it would only skip to the next ntrials value,
            # and sampling would proceed anyway:
            if min(nctrltrials, noptotrials) < MINNTRIALSAMPLETHRESH:
                print('Not enough trials to sample:', mseustr, kind, st8)
                continue # don't bother sampling or plotting
            ctrltrialis = np.arange(nctrltrials)
            optotrialis = np.arange(noptotrials)
            ctrlsamplesize = intround(nctrltrials / 2) # half for fitting, half for testing
            optosamplesize = intround(noptotrials / 2) # half for fitting, half for testing
            # probably doesn't matter if use ctrl or opto bins, should be the same:
            bins = mviresp.loc[mseustr, kind, st8, False]['bins']
            t = mviresp.loc[mseustr, kind, st8, False]['t'] # mid bins
            # randomly sample half the ctrl and opto trials, without replacement:
            ctrlfitis = np.sort(choice(ctrltrialis, size=ctrlsamplesize, replace=False))
            optofitis = np.sort(choice(optotrialis, size=optosamplesize, replace=False))
            ctrltestis = np.setdiff1d(ctrltrialis, ctrlfitis) # get the complement
            optotestis = np.setdiff1d(optotrialis, optofitis) # get the complement
            ctrlfitraster, ctrltestraster = ctrlraster[ctrlfitis], ctrlraster[ctrltestis]
            optofitraster, optotestraster = optoraster[optofitis], optoraster[optotestis]
            optotestburstraster = optoburstraster[optotestis]
            # calculate fit and test ctrl and opto PSTHs, subsampled in time by ssx:
            ctrlfitpsth = raster2psth(ctrlfitraster, bins, binw, tres, kernel)[::ssx]
            optofitpsth = raster2psth(optofitraster, bins, binw, tres, kernel)[::ssx]
            ctrltestpsth = raster2psth(ctrltestraster, bins, binw, tres, kernel)[::ssx]
            optotestpsth = raster2psth(optotestraster, bins, binw, tres, kernel)[::ssx]
            optotestburstpsth = raster2psth(optotestburstraster, bins, binw, tres, kernel)[::ssx]
            mm, b, rsq = fitmodel(ctrlfitpsth, optofitpsth, ctrltestpsth, optotestpsth,
                                  model=model)
            th = -b / mm # threshold, i.e., x intercept

            ## plot the sample's PSTHs:
            f, a = plt.subplots(figsize=psthfigsize)
            title = '%s %s %s PSTH %s model' % (mseustr, kind, st8, model)
            wintitle(title, f)
            color = st82clr[st8]
            tss = t[::ssx] # subsampled in time
            # plot control and opto test PSTHs:
            opto2psth = {False:ctrltestpsth, True:optotestpsth}
            opto2lslw = {False:('--', 1), True:('-', 2)} # (line style, line width)
            for opto in OPTOS:
                psth = opto2psth[opto]
                c = desat(color, opto2alpha[opto]) # do manual alpha mixing
                fb = opto2fb[opto].title()
                ls, lw = opto2lslw[opto]
                br = mviresp.loc[mseustr, kind, st8, opto]['meanburstratio']
                print('%s exampli=%d %s %s meanburstratio=%g'
                      % (mseustr, exmpli, st8, opto, br))
                a.plot(tss, psth, ls=ls, lw=lw, marker='None', color=c, label=fb)
            # plot opto model:
            modeltestpsth = mm * ctrltestpsth + b
            a.plot(tss, modeltestpsth, '-', lw=2, color=optoblue, alpha=1,
                   label='Suppression model')
            # plot opto burst test PSTH. It comes from the opto (True) test trials, so style
            # and label it as opto explicitly, instead of relying on the ls, lw and fb
            # values left over from the last iteration of the loop above:
            burstls, burstlw = opto2lslw[True]
            burstlabel = opto2fb[True].title() + ' burst'
            bc = desat(burstclr, 0.3) # do manual alpha mixing
            a.plot(tss, optotestburstpsth, ls=burstls, lw=burstlw, marker='None', color=bc,
                   label=burstlabel, zorder=-99)
            l = a.legend(frameon=False)
            l.set_draggable(True)
            a.set_xlabel('Time (s)')
            a.set_ylabel('Firing rate (spk/s)')
            a.set_xlim(tmin, tmax)

            ## scatter plot suppression vs. feedback for the sample's PSTH timepoints,
            ## plus model fit line:
            scale = 'linear' # log or linear
            logmin, logmax = -2, 2
            log0min = logmin + 0.05
            # fixed axis maxima for the example neurons' nat none all-spike plots in fig2.
            # Any other neuron falls back to the max rate of its test PSTHs instead of
            # raising a KeyError:
            examplemaxrate = {'PVCre_2018_0003_s03_e03_u51':35,
                              'PVCre_2017_0008_s12_e06_u56':140}
            if kind == 'nat' and st8 == 'none' and mseustr in examplemaxrate:
                linmax = examplemaxrate[mseustr]
            else:
                linmax = np.vstack([ctrltestpsth, optotestpsth]).max()
            figsize = DEFAULTFIGURESIZE
            if linmax >= 100:
                figsize = figsize[0]*1.025, figsize[1] # tweak to make space extra y axis digit
            f, a = plt.subplots(figsize=figsize)
            titlestr = ('%s %s %s PSTH scatter %s' % (mseustr, kind, st8, model))
            wintitle(titlestr, f)
            if scale == 'log':
                # replace off-scale low values with log0min, so the points remain visible:
                ctrltestpsth[ctrltestpsth <= 10**logmin] = 10**log0min
                optotestpsth[optotestpsth <= 10**logmin] = 10**log0min
                xyline = [10**logmin, 10**logmax], [10**logmin, 10**logmax]
            else:
                xyline = [linmin, linmax], [linmin, linmax]
            # plot y=x line:
            a.plot(xyline[0], xyline[1], '--', clip_on=True, color='gray', zorder=100)
            # plot opto test PSTH vs control test PSTH:
            c = st82clr[st8]
            a.plot(ctrltestpsth, optotestpsth, '.', clip_on=True, ms=2, color=c, label=st8)
            x = np.array([ctrltestpsth.min(), ctrltestpsth.max()])
            y = mm * x + b # model output line
            a.plot(x, y, '-', color=optoblue, clip_on=True, alpha=1, lw=2) # plot model fit
            # display the sample's parameters. The mathtext parts are raw strings, so their
            # backslashes aren't parsed as (invalid) escapes; the newlines stay real:
            txt = ((r'$\mathregular{R^{2}=%.2f}$' '\n'
                    r'$\mathregular{Slope=%.2f}$' '\n'
                    r'$\mathregular{Thr.=%.1f}$') % (rsq, mm, th))
            a.add_artist(AnchoredText(txt, loc='upper left', frameon=False))
            a.set_xlabel('Feedback rate (spk/s)')
            a.set_ylabel('Suppression rate (spk/s)')
            if scale == 'log':
                a.set_xscale('log')
                a.set_yscale('log')
                xymin, xymax = 10**logmin, 10**logmax
                ticks = 10**(np.arange(logmin, logmax+1, dtype=float))
            else:
                xymin, xymax = linmin, linmax
                ticks = a.get_xticks()
            a.set_xticks(ticks) # seem to need to set both x and y to force them to match
            a.set_yticks(ticks)
            a.set_xlim(xymin, xymax)
            a.set_ylim(xymin, xymax)
            a.set_aspect('equal')
            a.spines['left'].set_position(('outward', 4))
            a.spines['bottom'].set_position(('outward', 4))
# scatter plot resampled model slope vs threshold, for only those fits with decent SNR
# in at least one opto state:
figsize = DEFAULTFIGURESIZE
slopelogmin, slopelogmax = -4, 1
logyticks = np.logspace(slopelogmin, slopelogmax, num=(slopelogmax-slopelogmin+1), base=2)
threshmin, threshmax = -25, 25


def _scatter_exmpls(a, xs, ys, exmplis, exmplmseustrs, alphas=None):
    """Scatter example neuron points on axes `a`, one at a time, in their own style.

    `exmplis` are indices into `xs` and `ys`; `exmplmseustrs` are the matching mseu
    strings, used as keys into `mvimseu2exmpli` to look up each example's marker,
    color, size and line width. `alphas`, if given, holds one alpha per example
    point; otherwise the scatter default alpha is used.
    """
    if alphas is None:
        alphas = [None] * len(exmplis)
    for exmpli, mseustr, alpha in zip(exmplis, exmplmseustrs, alphas):
        exmplnum = mvimseu2exmpli[mseustr]
        a.scatter(xs[exmpli], ys[exmpli], clip_on=False, marker=exmpli2mrk[exmplnum],
                  c=exmpli2clr[exmplnum], s=exmpli2sz[exmplnum], lw=exmpli2lw[exmplnum],
                  alpha=alpha)


for kind in ['nat']:#MVIKINDS:
    for st8 in ['none']:#ALLST8S:
        for fmode in fmodes: # iterate over firing modes (all, non-burst, burst)
            print('fmode:', fmode)
            fmodetxt = fmode2txt[fmode]
            # boolean pandas Series:
            rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
            rows = sampfitsr[rowis]
            mmeds, thmeds, rsqmeds = [], [], [] # median fit values
            brs, snrs = [], [] # single burst ratio and SNR values matching each mseu
            sgnfis, insgnfis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
            keptmseui = 0 # manually init and increment instead of using enumerate()
            for i, row in rows.iterrows(): # for each mseu
                mseustr = row['mseu']
                ms, ths, rsqs = row[fmode+'ms'], row[fmode+'ths'], row[fmode+'rsqs']
                if np.isnan([ms, ths, rsqs]).any():
                    continue # skip this mseu
                # skip mseus that had no meaningful signal to fit in either condition:
                maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
                if maxsnr < SNRTHRESH:
                    continue
                mmed, ml, mh = percentile_ci(ms)
                thmed, thl, thh = percentile_ci(ths)
                rsqmed = np.median(rsqs)
                if rsqmed < 0: # catastrophic fit failure
                    print('catastrophic fit failure:', mseustr)
                    continue # skip
                mmeds.append(mmed)
                thmeds.append(thmed)
                rsqmeds.append(rsqmed)
                # collect suppression mean burst ratio and SNR:
                mvirowis = ((mvirespr['mseu'] == mseustr) &
                            (mvirespr['kind'] == kind) &
                            (mvirespr['st8'] == st8) &
                            (mvirespr['opto'] == True)) # only for suppression
                mvirow = mvirespr[mvirowis]
                assert len(mvirow) == 1
                br = mvirow['meanburstratio'].iloc[0]
                snr = mvirow['snr'].iloc[0]
                brs.append(br)
                snrs.append(snr)
                # store into fig2 with a single .loc call per cell. Chained assignment
                # (fig2['slope'][row] = ...) writes to an intermediate Series and, under
                # pandas Copy-on-Write, leaves fig2 itself unchanged:
                fig2row = (mseustr, fmodetxt)
                fig2.loc[fig2row, 'slope'] = mmed
                fig2.loc[fig2row, 'thresh'] = thmed
                fig2.loc[fig2row, 'rsq'] = rsqmed
                fig2.loc[fig2row, 'suppression_meanburstratio'] = br
                fig2.loc[fig2row, 'suppression_snr'] = snr
                # collect separate lists of significant and insignificant points:
                if not (ml < 1 < mh):# and not (thl < 0 < thh):
                    sgnfis.append(keptmseui) # slope significantly different from 1
                else:
                    insgnfis.append(keptmseui) # slope not significantly different from 1
                if mseustr in mvimseu2exmpli:
                    exmplis.append(keptmseui)
                    exmplmseustrs.append(mseustr)
                    print('Example Neuron', mvimseu2exmpli[mseustr], ':', mseustr)
                    print('mmed, ml, mh:', mmed, ml, mh)
                    print('thmed, thl, thh:', thmed, thl, thh)
                else:
                    normlis.append(keptmseui)
                keptmseui += 1 # manually increment
            mmeds = np.asarray(mmeds)
            thmeds = np.asarray(thmeds)
            rsqmeds = np.asarray(rsqmeds)
            brs = np.asarray(brs)
            snrs = np.asarray(snrs)
            # index arrays into the above. np.intersect1d() returns a float array when an
            # input is empty, and float arrays can't be used as indices, so force int:
            normlinsgnfis = np.intersect1d(normlis, insgnfis).astype(int)
            normlsgnfis = np.intersect1d(normlis, sgnfis).astype(int)
            exmplinsgnfis = np.intersect1d(exmplis, insgnfis).astype(int)
            exmplsgnfis = np.intersect1d(exmplis, sgnfis).astype(int)

            ## plot median fit slope vs thresh:
            f, a = plt.subplots(figsize=(figsize[0]*1.11, figsize[1])) # extra space for ylabels
            titlestr = ('PSTH %s fit slope vs thresh %s %s snrthresh=%.3f %s'
                        % (model, kind, st8, SNRTHRESH, fmodetxt))
            wintitle(titlestr, f)
            # plot x=0 and y=1 lines:
            a.axvline(x=0, ls='--', marker='', color='gray', zorder=-np.inf)
            a.axhline(y=1, ls='--', marker='', color='gray', zorder=-np.inf)
            # plot normal (non-example) insignificant points:
            c = desat(st82clr[st8], SGNF2ALPHA[False]) # do manual alpha mixing
            a.scatter(thmeds[normlinsgnfis], mmeds[normlinsgnfis], clip_on=False,
                      marker='.', c='None', edgecolor=c, s=DEFSZ)
            # plot normal (non-example) significant points:
            c = desat(st82clr[st8], SGNF2ALPHA[True]) # do manual alpha mixing
            a.scatter(thmeds[normlsgnfis], mmeds[normlsgnfis], clip_on=False,
                      marker='.', c='None', edgecolor=c, s=DEFSZ)
            print('exmplinsgnfis', exmplinsgnfis)
            print('examplsgnfis', exmplsgnfis)
            # plot insignificant and significant example points, one at a time, with
            # alpha set by significance:
            exmplalphas = []
            for exmpli in exmplis:
                if exmpli in exmplinsgnfis:
                    exmplalphas.append(SGNF2ALPHA[False])
                elif exmpli in exmplsgnfis:
                    exmplalphas.append(SGNF2ALPHA[True])
                else:
                    raise RuntimeError("Some kind of exmpli set membership error")
            _scatter_exmpls(a, thmeds, mmeds, exmplis, exmplmseustrs, alphas=exmplalphas)
            # plot mean median point:
            if fmode == '': # all spikes, plot LMM mean
                print('plotting LMM mean for fmode=%s' % fmode)
                a.scatter(-0.19, 0.75, # read off of stats/fig2*.pdf
                          c='red', edgecolor='red', s=50, marker='^')
            elif fmode == 'nb': # non burst spikes, plot LMM mean
                print('plotting LMM mean for fmode=%s' % fmode)
                a.scatter(0.09, 0.74, # read off of stats/fig2*.pdf
                          c='red', edgecolor='red', s=50, marker='^')
            else: # mean threshold and geometric mean slope of the per-mseu medians
                a.scatter(np.mean(thmeds), gmean(mmeds),
                          c='red', edgecolor='red', s=50, marker='^')
            # display median of median rsq sample values:
            txt = r'$\mathregular{R^{2}_{med}=}$%.2f' % np.round(np.median(rsqmeds), 2)
            a.add_artist(AnchoredText(txt, loc='lower right', frameon=False))
            # `basey` was removed in matplotlib 3.5; `base` is its replacement (>= 3.3):
            a.set_yscale('log', base=2)
            a.set_xlabel('Threshold')
            a.set_ylabel('Slope')
            a.set_xlim(threshmin, threshmax)
            a.set_xticks([threshmin, 0, threshmax])
            a.set_ylim(2**slopelogmin, 2**slopelogmax)
            a.set_yticks(logyticks)
            axes_disable_scientific(a)
            a.spines['left'].set_position(('outward', 4))
            a.spines['bottom'].set_position(('outward', 4))

            ## plot fit rsq vs suppression meanburstratio:
            f, a = plt.subplots(figsize=figsize)
            titlestr = ('PSTH %s fit rsq vs suppression meanburstratio %s %s %s' %
                        (model, kind, st8, fmodetxt))
            wintitle(titlestr, f)
            a.scatter(brs[normlis], rsqmeds[normlis], clip_on=False,
                      marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
            _scatter_exmpls(a, brs, rsqmeds, exmplis, exmplmseustrs)
            if fmode == '': # use Steffen's LMM linregress fit, for fig2f
                # fetch LMM linregress fit params from .csv:
                fname = os.path.join('stats', 'figure_2f_coefs.csv')
                df = pd.read_csv(fname)
                mm = df['slope'][0]
                b = df['intercept'][0]
                x = np.array([np.nanmin(brs), np.nanmax(brs)])
                y = mm * x + b
                a.plot(x, y, '-', color='red') # plot linregress fit
            a.set_xlabel('Suppression BR')
            a.set_ylabel(r'%s spikes $\mathregular{R^{2}}$' % fmodetxt.title())
            a.set_xlim(xmin=0)
            a.set_ylim(0, 1)
            a.set_yticks([0, 0.5, 1])
            a.spines['left'].set_position(('outward', 4))
            a.spines['bottom'].set_position(('outward', 4))

            ## plot fit rsq vs suppression SNR:
            f, a = plt.subplots(figsize=figsize)
            titlestr = ('PSTH %s fit rsq vs SNR %s %s %s' %
                        (model, kind, st8, fmodetxt))
            wintitle(titlestr, f)
            a.scatter(snrs[normlis], rsqmeds[normlis], clip_on=False,
                      marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
            _scatter_exmpls(a, snrs, rsqmeds, exmplis, exmplmseustrs)
            a.set_xlabel('Suppression SNR')
            a.set_ylabel(r'%s spikes $\mathregular{R^{2}}$' % fmodetxt.title())
            a.set_xlim(xmin=0)
            a.set_ylim(0, 1)
            a.set_yticks([0, 0.5, 1])
            a.spines['left'].set_position(('outward', 4))
            a.spines['bottom'].set_position(('outward', 4))

            ## plot suppression fit SNR vs meanburstratio:
            f, a = plt.subplots(figsize=figsize)
            titlestr = ('PSTH %s fit SNR vs meanburstratio %s %s %s' %
                        (model, kind, st8, fmodetxt))
            wintitle(titlestr, f)
            a.scatter(brs[normlis], snrs[normlis], clip_on=False,
                      marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
            _scatter_exmpls(a, brs, snrs, exmplis, exmplmseustrs)
            a.set_xlabel('Suppression BR')
            a.set_ylabel('Suppression SNR')
            a.set_xlim(xmin=0)
            a.set_ylim(ymin=0)
            a.set_yticks([0, 0.5, 1])
            a.spines['left'].set_position(('outward', 4))
            a.spines['bottom'].set_position(('outward', 4))
# scatter plot resampled model fit rsq for nonburst vs all spikes, and nonrand vs all spikes:
figsize = DEFAULTFIGURESIZE
for kind in ['nat']:#MVIKINDS:
    for st8 in ['none']:#ALLST8S:
        # gather per-mseu median rsq of each spike set, for mseus passing the filters:
        selrows = sampfitsr[(sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)]
        allrsqmeds, nbrsqmeds, nrrsqmeds = [], [], []
        exmplis, exmplmseustrs, normlis = [], [], []
        for _, row in selrows.iterrows(): # one row per mseu
            mseustr = row['mseu']
            rsqsets = [row['rsqs'], row['nbrsqs'], row['nrrsqs']]
            if np.isnan(rsqsets).any():
                continue # missing fit values, skip this mseu
            # skip mseus that had no meaningful signal to fit in either condition:
            if get_max_snr(mvirespr, mseustr, kind, st8) < SNRTHRESH:
                continue
            meds = [np.median(rsqs) for rsqs in rsqsets]
            if any(med < 0 for med in meds): # catastrophic fit failure
                print('catastrophic fit failure:', mseustr)
                continue
            allrsqmed, nbrsqmed, nrrsqmed = meds
            keptmseui = len(allrsqmeds) # index of this mseu in the arrays built below
            allrsqmeds.append(allrsqmed)
            nbrsqmeds.append(nbrsqmed)
            nrrsqmeds.append(nrrsqmed)
            if mseustr in mvimseu2exmpli:
                exmplis.append(keptmseui)
                exmplmseustrs.append(mseustr)
            else:
                normlis.append(keptmseui)
            if nbrsqmed > 0.4 and allrsqmed < 0.4:
                print(mseustr, allrsqmed, nbrsqmed)
        allrsqmeds = np.asarray(allrsqmeds)
        nbrsqmeds = np.asarray(nbrsqmeds)
        nrrsqmeds = np.asarray(nrrsqmeds)
        # one equal-aspect plot per spike subset, each against all spikes:
        # (subset median rsqs, window title format, y axis label)
        comparisons = [
            (nbrsqmeds, 'PSTH %s fit rsq nonburst vs all %s %s',
             r'Non-burst spikes $\mathregular{R^{2}}$'),
            (nrrsqmeds, 'PSTH %s fit rsq nr vs all %s %s',
             r'Rand. rem. spikes $\mathregular{R^{2}}$'),
        ]
        for subsetrsqmeds, titlefmt, ylabel in comparisons:
            f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
            wintitle(titlefmt % (model, kind, st8), f)
            # y=x line, spanning the data range:
            linmax = np.vstack([allrsqmeds, subsetrsqmeds]).max()
            xyline = [linmin, linmax], [linmin, linmax]
            a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
            # normal (non-example) points:
            a.scatter(allrsqmeds[normlis], subsetrsqmeds[normlis], clip_on=False,
                      marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
            # example points, one at a time, each in its own style:
            for exmpli, mseustr in zip(exmplis, exmplmseustrs):
                exmplnum = mvimseu2exmpli[mseustr]
                a.scatter(allrsqmeds[exmpli], subsetrsqmeds[exmpli], clip_on=False,
                          marker=exmpli2mrk[exmplnum], c=exmpli2clr[exmplnum],
                          s=exmpli2sz[exmplnum], lw=exmpli2lw[exmplnum])
            a.set_xlabel(r'All spikes $\mathregular{R^{2}}$')
            a.set_ylabel(ylabel)
            a.set_xlim(linmin, 1)
            a.set_ylim(linmin, 1)
            a.set_xticks([0, 0.5, 1])
            a.set_yticks([0, 0.5, 1])
            a.set_aspect('equal')
            for side in ('left', 'bottom'):
                a.spines[side].set_position(('outward', 4))
|