fig2.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. """Figure 2 plots, use run -i fig2.py"""
  2. fmodestxt = list(fmode2txt.values()) # 'all', 'nonburst', 'burst', 'nonrand'
  3. mi = pd.MultiIndex.from_product([mvimseustrs, fmodestxt], names=['mseu', 'fmode'])
  4. fig2 = pd.DataFrame(index=mi, columns=['slope', 'thresh', 'rsq',
  5. 'suppression_meanburstratio', 'suppression_snr'])
  6. # convert multi index to columns for easier to read restriction syntax:
  7. sampfitsr = sampfits.reset_index()
  8. mvirespr = mviresp.reset_index()
  9. linmin = 0
  10. # plot single-sample V1 model-fit PSTH overtop of test feedback and suppression PSTHs:
  11. ## TODO: correlate burst PSTHs with (model-empirical)**2, scatter plot them against each other,
  12. ## or just plot distribution of corrcoefs
  13. dt = 5 # stimulus duration sec
  14. #tmin, tmax = 0 + OFFSETS[0], dt + OFFSETS[1]
  15. tmin, tmax = 0, dt
  16. #figsize = 0.855 + tmax*1.5*RASTERSCALEX, PSTHHEIGHT
  17. psthfigsize = 0.855 + tmax*2.1*RASTERSCALEX, PSTHHEIGHT
  18. # medians:
  19. # n1: rsq=0.29, slope=0.37, thr=-0.14
  20. # n2: rsq=0.90, slope=0.67, thr=1.58
  21. #for seedi in [110, 199, 339, 388]: # seeds that result in rsq identical to median
  22. np.random.seed(199) # fix random seed for identical results from np.random.choice() on each run
  23. for mseustr, exmpli in mvimseu2exmpli.items(): #mvimseustrs:
  24. for kind in ['nat']:#MVIKINDS:
  25. for st8 in ['none']:#ALLST8S:
  26. # get one sample of trials, separate into fit and test data:
  27. ctrlraster = mviresp.loc[mseustr, kind, st8, False]['raster'] # narrow
  28. optoraster = mviresp.loc[mseustr, kind, st8, True]['raster'] # narrow
  29. optoburstraster = mviresp.loc[mseustr, kind, st8, True]['braster'] # narrow
  30. nctrltrials = len(ctrlraster) # typically 200 trials when st8 == 'none'
  31. noptotrials = len(optoraster) # typically 200 trials when st8 == 'none'
  32. for ntrials in [nctrltrials, noptotrials]:
  33. if ntrials < MINNTRIALSAMPLETHRESH:
  34. print('Not enough trials to sample:', mseustr, kind, st8)
  35. continue # don't bother sampling, leave entry in sampfits empty
  36. ctrltrialis = np.arange(nctrltrials)
  37. optotrialis = np.arange(noptotrials)
  38. ctrlsamplesize = intround(nctrltrials / 2) # half for fitting, half for testing
  39. optosamplesize = intround(noptotrials / 2) # half for fitting, half for testing
  40. # probably doesn't matter if use ctrl or opto bins, should be the same:
  41. bins = mviresp.loc[mseustr, kind, st8, False]['bins']
  42. t = mviresp.loc[mseustr, kind, st8, False]['t'] # mid bins
  43. # randomly sample half the ctrl and opto trials, without replacement:
  44. ctrlfitis = np.sort(choice(ctrltrialis, size=ctrlsamplesize, replace=False))
  45. optofitis = np.sort(choice(optotrialis, size=optosamplesize, replace=False))
  46. ctrltestis = np.setdiff1d(ctrltrialis, ctrlfitis) # get the complement
  47. optotestis = np.setdiff1d(optotrialis, optofitis) # get the complement
  48. ctrlfitraster, ctrltestraster = ctrlraster[ctrlfitis], ctrlraster[ctrltestis]
  49. optofitraster, optotestraster = optoraster[optofitis], optoraster[optotestis]
  50. optotestburstraster = optoburstraster[optotestis]
  51. # calculate fit and test opto PSTHs, subsampled in time:
  52. ctrlfitpsth = raster2psth(ctrlfitraster, bins, binw, tres, kernel)[::ssx]
  53. optofitpsth = raster2psth(optofitraster, bins, binw, tres, kernel)[::ssx]
  54. ctrltestpsth = raster2psth(ctrltestraster, bins, binw, tres, kernel)[::ssx]
  55. optotestpsth = raster2psth(optotestraster, bins, binw, tres, kernel)[::ssx]
  56. optotestburstpsth = raster2psth(optotestburstraster, bins, binw, tres, kernel)[::ssx]
  57. mm, b, rsq = fitmodel(ctrlfitpsth, optofitpsth, ctrltestpsth, optotestpsth,
  58. model=model)
  59. #rsqstr = '%.2f' % np.round(rsq, 2)
  60. #exmpli2rsqstr = {1:'0.29', 2:'0.90'}
  61. #if rsqstr != exmpli2rsqstr[exmpli]:
  62. # continue
  63. #print(seedi, exmpli)
  64. th = -b / mm # threshold, i.e., x intercept
  65. ## plot the sample's PSTHs:
  66. f, a = plt.subplots(figsize=psthfigsize)
  67. title = '%s %s %s PSTH %s model' % (mseustr, kind, st8, model)
  68. wintitle(title, f)
  69. color = st82clr[st8]
  70. tss = t[::ssx] # subsampled in time
  71. # plot control and opto test PSTHs:
  72. opto2psth = {False:ctrltestpsth, True:optotestpsth}
  73. for opto in OPTOS:
  74. psth = opto2psth[opto]
  75. c = desat(color, opto2alpha[opto]) # do manual alpha mixing
  76. fb = opto2fb[opto].title()
  77. if opto == False:
  78. ls = '--'
  79. lw = 1
  80. else:
  81. ls = '-'
  82. lw = 2
  83. br = mviresp.loc[mseustr, kind, st8, opto]['meanburstratio']
  84. print('%s exampli=%d %s %s meanburstratio=%g'
  85. % (mseustr, exmpli, st8, opto, br))
  86. a.plot(tss, psth, ls=ls, lw=lw, marker='None', color=c, label=fb)
  87. # plot opto model:
  88. modeltestpsth = mm * ctrltestpsth + b
  89. a.plot(tss, modeltestpsth, '-', lw=2, color=optoblue, alpha=1,
  90. label='Suppression model')
  91. # plot opto burst test PSTH:
  92. bc = desat(burstclr, 0.3) # do manual alpha mixing
  93. #bc = desat(burstclr, opto2alpha[opto]) # do manual alpha mixing
  94. a.plot(tss, optotestburstpsth, ls=ls, lw=lw, marker='None', color=bc,
  95. label=fb+' burst', zorder=-99)
  96. l = a.legend(frameon=False)
  97. l.set_draggable(True)
  98. a.set_xlabel('Time (s)')
  99. a.set_ylabel('Firing rate (spk/s)')
  100. a.set_xlim(tmin, tmax)
  101. # plot horizontal line signifying stimulus period, just below data:
  102. #ymin, ymax = a.get_ylim()
  103. #a.hlines(-ymax*0.025, 0, dt, colors='k', lw=4, clip_on=False, in_layout=False)
  104. #a.set_ylim(ymin, ymax) # restore y limits from before plotting the hline
  105. ## scatter plot suppression vs. feedback for the sample's PSTH timepoints,
  106. ## plus model fit line:
  107. scale = 'linear' # log or linear
  108. logmin, logmax = -2, 2
  109. log0min = logmin + 0.05
  110. # for nat none all-spike plots for fig2:
  111. examplemaxrate = {'PVCre_2018_0003_s03_e03_u51':35,
  112. 'PVCre_2017_0008_s12_e06_u56':140}
  113. if kind == 'nat' and st8 == 'none':
  114. linmax = examplemaxrate[mseustr]
  115. else:
  116. linmax = np.vstack([ctrltestpsth, optotestpsth]).max()
  117. figsize = DEFAULTFIGURESIZE
  118. if linmax >= 100:
  119. figsize = figsize[0]*1.025, figsize[1] # tweak to make space extra y axis digit
  120. f, a = plt.subplots(figsize=figsize)
  121. titlestr = ('%s %s %s PSTH scatter %s' % (mseustr, kind, st8, model))
  122. wintitle(titlestr, f)
  123. if scale == 'log':
  124. # replace off-scale low values with log0min, so the points remain visible:
  125. ctrltestpsth[ctrltestpsth <= 10**logmin] = 10**log0min
  126. optotestpsth[optotestpsth <= 10**logmin] = 10**log0min
  127. # plot y=x line:
  128. if scale == 'log':
  129. xyline = [10**logmin, 10**logmax], [10**logmin, 10**logmax]
  130. else:
  131. xyline = [linmin, linmax], [linmin, linmax]
  132. a.plot(xyline[0], xyline[1], '--', clip_on=True, color='gray', zorder=100)
  133. # plot testpsth vs control:
  134. texts = []
  135. c = st82clr[st8]
  136. a.plot(ctrltestpsth, optotestpsth, '.', clip_on=True, ms=2, color=c, label=st8)
  137. x = np.array([ctrltestpsth.min(), ctrltestpsth.max()])
  138. y = mm * x + b # model output line
  139. a.plot(x, y, '-', color=optoblue, clip_on=True, alpha=1, lw=2) # plot model fit
  140. # display the sample's parameters:
  141. txt = ('$\mathregular{R^{2}=%.2f}$\n'
  142. '$\mathregular{Slope=%.2f}$\n'
  143. '$\mathregular{Thr.=%.1f}$' % (rsq, mm, th))
  144. a.add_artist(AnchoredText(txt, loc='upper left', frameon=False))
  145. a.set_xlabel('Feedback rate (spk/s)')
  146. a.set_ylabel('Suppression rate (spk/s)')
  147. if scale == 'log':
  148. a.set_xscale('log')
  149. a.set_yscale('log')
  150. xymin, xymax = 10**logmin, 10**logmax
  151. ticks = 10**(np.arange(logmin, logmax+1, dtype=float))
  152. else:
  153. xymin, xymax = linmin, linmax
  154. ticks = a.get_xticks()
  155. a.set_xticks(ticks) # seem to need to set both x and y to force them to match
  156. a.set_yticks(ticks)
  157. a.set_xlim(xymin, xymax)
  158. a.set_ylim(xymin, xymax)
  159. a.set_aspect('equal')
  160. a.spines['left'].set_position(('outward', 4))
  161. a.spines['bottom'].set_position(('outward', 4))
# DISABLED: the triple-quoted string below holds an analysis that is no longer run. It is
# only a no-op string expression. When enabled, it strip-plots per-unit linear fit
# parameters (slope, intercept, rsq) by state (run/sit/none), keeping only fits with
# rsq >= RSQTHRESH. It runs a paired t-test of run vs. sit, one-sample t-tests of slope
# against 1 and intercept against 0, and histograms of the 'none' values for nat movies.
# NOTE(review): it reads `fits` (and `ttest_rel`, `ttest_1samp`, `sns`), none of which are
# defined in this script; presumably they come from the interactive namespace. Confirm
# before re-enabling.
'''
## TODO: plot intercept normalized by ctrl peak, or somehow, for unitless intercepts distrib
## TODO: highlight 3rd example neuron in all scatter plots
## TODO: add burst distrib as well
# strip plot distributions of linear fit params, keep only those fits with a decent rsq:
RSQTHRESH = 0.2#0.4
keepis = (fits['rsq'] >= RSQTHRESH).values
keepfits = fits.loc[keepis]
resetkeepfits = keepfits.reset_index() # convert mi to columns for sns
figsize = DEFAULTFIGURESIZE
for kind in MVIKINDS:
    for par in ['slope', 'intercept', 'rsq']:
        rowis = (resetkeepfits['kind'] == kind) # boolean pandas Series
        if not rowis.any():
            print('No data to plot for par, kind = %s, %s, skipping' % (par, kind))
            continue
        data = resetkeepfits[rowis]
        # do paired t-test for par:
        runvals, sitvals, nonevals = [], [], []
        submseustrs = data['mseu'].unique()
        for mseustr in submseustrs:
            rows = keepfits.loc[mseustr, kind]
            if not np.all([ st8 in rows.index for st8 in ALLST8S ]): # not all st8s in rows
                continue
            if rows.isna().any().any(): # two any()'s: one for st8, one for fit param
                continue
            # rows has non-NaN entries for run, sit and none
            runvals.append(rows.loc['run'][par])
            sitvals.append(rows.loc['sit'][par])
            nonevals.append(rows.loc['none'][par])
        nunits = len(runvals) # number of units that survived
        assert nunits == len(sitvals) == len(nonevals)
        if nunits == 0:
            print('No data to plot for par, kind = %s, %s, skipping' % (par, kind))
            continue
        t, p = ttest_rel(runvals, sitvals) # paired t-test
        # make a strip plot:
        f, a = plt.subplots(figsize=figsize)
        titlestr = 'PSTH %s fit %s %s rsqthresh=%.1f' % (model, kind, par, RSQTHRESH)
        wintitle(titlestr, f)
        sns.stripplot(x="st8", y=par, data=data, palette=st82clr, jitter=True, size=3)
        a.set_xlabel('')
        ylabel = par
        if ylabel == 'intercept':
            ylabel += ' (spk/s)'
        ylabel = ylabel.capitalize()
        a.set_ylabel(ylabel)
        a.add_artist(AnchoredText('p=%.1e' % p, loc='upper left', frameon=False))
        # connect the dots:
        x = np.array([[0]*nunits, [1]*nunits, [2]*nunits])
        y = np.array([runvals, sitvals, nonevals])
        a.plot(x, y, '-', c='k', alpha=0.2, lw=1)
        # due to jitter, dots don't perfectly connect. Can get actual data using:
        #runxy, sitxy = a.collections[0].get_offsets(), a.collections[1].get_offsets()
        # but some of those points aren't paired, which makes it complicated
        # plot mean with short horizontal lines:
        meanrun, meansit, meannone = y.mean(axis=1)
        a.plot([-0.25, 0.25], [meanrun, meanrun], '-', lw=2, c=st82clr['run'], zorder=np.inf)
        a.plot([0.75, 1.25], [meansit, meansit], '-', lw=2, c=st82clr['sit'], zorder=np.inf)
        a.plot([1.75, 2.25], [meannone, meannone], '-', lw=2, c=st82clr['none'], zorder=np.inf)
        # test if slope and intercept are significantly different from 1 and 0, respectively:
        if par == 'slope':
            a.axhline(y=1, color='gray', ls='--', marker='')
            runt1samp, runp1samp = ttest_1samp(runvals, 1)
            sitt1samp, sitp1samp = ttest_1samp(sitvals, 1)
            nonet1samp, nonep1samp = ttest_1samp(nonevals, 1)
            print('%s slope != 1 test: runp1samp=%.2e, sitp1samp=%.2e, nonep1samp=%.2e'
                  % (kind, runp1samp, sitp1samp, nonep1samp))
            a.set_ylim(top=1.6)
        elif par == 'intercept':
            a.axhline(y=0, color='gray', ls='--', marker='')
            runt1samp, runp1samp = ttest_1samp(runvals, 0)
            sitt1samp, sitp1samp = ttest_1samp(sitvals, 0)
            nonet1samp, nonep1samp = ttest_1samp(nonevals, 0)
            print('%s intercept != 0 test: runp1samp=%.2e, sitp1samp=%.2e, nonep1samp=%.2e'
                  % (kind, runp1samp, sitp1samp, nonep1samp))
            a.set_ylim(top=7, bottom=-7)
        # plot PDFs for nat movie none st8 only:
        if kind == 'nat':
            f, pdfa = plt.subplots(figsize=figsize)
            titlestr = 'PSTH %s fit %s none %s PDF rsqthresh=%.1f' % (model, kind, par, RSQTHRESH)
            wintitle(titlestr, f)
            pdfa.hist(nonevals, bins=20, density=False, color=st82clr['none'])
            pdfa.set_xlabel(ylabel)
            pdfa.set_ylabel('Unit count')
'''
  248. # scatter plot resampled model slope vs threshold, for only those fits with decent SNR
  249. # in at least one opto state:
  250. figsize = DEFAULTFIGURESIZE
  251. slopelogmin, slopelogmax = -4, 1
  252. logyticks = np.logspace(slopelogmin, slopelogmax, num=(slopelogmax-slopelogmin+1), base=2)
  253. threshmin, threshmax = -25, 25
  254. for kind in ['nat']:#MVIKINDS:
  255. for st8 in ['none']:#ALLST8S:
  256. for fmode in fmodes: # iterate over firing modes (all, non-burst, burst)
  257. print('fmode:', fmode)
  258. # boolean pandas Series:
  259. rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
  260. rows = sampfitsr[rowis]
  261. mmeds, thmeds, rsqmeds = [], [], [] # median fit values
  262. brs, snrs = [], [] # single burst ratio and SNR values matching each mseu
  263. sgnfis, insgnfis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
  264. keptmseui = 0 # manually init and increment instead of using enumerate()
  265. for i, row in rows.iterrows(): # for each mseu
  266. mseustr = row['mseu']
  267. ms, ths, rsqs = row[fmode+'ms'], row[fmode+'ths'], row[fmode+'rsqs']
  268. if np.isnan([ms, ths, rsqs]).any(): ## TODO: does this throw out too many units?
  269. continue # skip this mseu
  270. # skip mseus that had no meaningful signal to fit in either condition:
  271. maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
  272. if maxsnr < SNRTHRESH:
  273. continue
  274. mmed, ml, mh = percentile_ci(ms)
  275. thmed, thl, thh = percentile_ci(ths)
  276. rsqmed = np.median(rsqs)
  277. #if rsqmed < RSQTHRESH: # skip mseus with really bad median fits
  278. # continue # skip
  279. if rsqmed < 0: # catastrophic fit failure
  280. print('catastrophic fit failure:', mseustr)
  281. continue # skip
  282. mmeds.append(mmed)
  283. thmeds.append(thmed)
  284. rsqmeds.append(rsqmed)
  285. # collect suppression mean burst ratio and SNR:
  286. mvirowis = ((mvirespr['mseu'] == mseustr) &
  287. (mvirespr['kind'] == kind) &
  288. (mvirespr['st8'] == st8) &
  289. (mvirespr['opto'] == True)) # only for suppression
  290. mvirow = mvirespr[mvirowis]
  291. assert len(mvirow) == 1
  292. br = mvirow['meanburstratio'].iloc[0]
  293. snr = mvirow['snr'].iloc[0]
  294. brs.append(br)
  295. snrs.append(snr)
  296. fig2['slope'][mseustr, fmode2txt[fmode]] = mmed
  297. fig2['thresh'][mseustr, fmode2txt[fmode]] = thmed
  298. fig2['rsq'][mseustr, fmode2txt[fmode]] = rsqmed
  299. fig2['suppression_meanburstratio'][mseustr, fmode2txt[fmode]] = br
  300. fig2['suppression_snr'][mseustr, fmode2txt[fmode]] = snr
  301. # collect separate lists of significant and insignificant points:
  302. if not (ml < 1 < mh):# and not (thl < 0 < thh):
  303. sgnfis.append(keptmseui) # slope significantly different from 1
  304. else:
  305. insgnfis.append(keptmseui) # slope not significantly different from 1
  306. if mseustr in mvimseu2exmpli:
  307. exmplis.append(keptmseui)
  308. exmplmseustrs.append(mseustr)
  309. print('Example Neuron', mvimseu2exmpli[mseustr], ':', mseustr)
  310. print('mmed, ml, mh:', mmed, ml, mh)
  311. print('thmed, thl, thh:', thmed, thl, thh)
  312. else:
  313. normlis.append(keptmseui)
  314. keptmseui += 1 # manually increment
  315. mmeds = np.asarray(mmeds)
  316. thmeds = np.asarray(thmeds)
  317. rsqmeds = np.asarray(rsqmeds)
  318. brs = np.asarray(brs)
  319. snrs = np.asarray(snrs)
  320. ## plot median fit slope vs thresh:
  321. f, a = plt.subplots(figsize=(figsize[0]*1.11, figsize[1])) # extra space for ylabels
  322. titlestr = ('PSTH %s fit slope vs thresh %s %s snrthresh=%.3f %s'
  323. % (model, kind, st8, SNRTHRESH, fmode2txt[fmode]))
  324. wintitle(titlestr, f)
  325. normlinsgnfis = np.intersect1d(normlis, insgnfis)
  326. normlsgnfis = np.intersect1d(normlis, sgnfis)
  327. # plot x=0 and y=1 lines:
  328. a.axvline(x=0, ls='--', marker='', color='gray', zorder=-np.inf)
  329. a.axhline(y=1, ls='--', marker='', color='gray', zorder=-np.inf)
  330. # plot normal (non-example) insignificant points:
  331. c = desat(st82clr[st8], SGNF2ALPHA[False]) # do manual alpha mixing
  332. a.scatter(thmeds[normlinsgnfis], mmeds[normlinsgnfis], clip_on=False,
  333. marker='.', c='None', edgecolor=c, s=DEFSZ)
  334. # plot normal (non-example) significant points:
  335. c = desat(st82clr[st8], SGNF2ALPHA[True]) # do manual alpha mixing
  336. a.scatter(thmeds[normlsgnfis], mmeds[normlsgnfis], clip_on=False,
  337. marker='.', c='None', edgecolor=c, s=DEFSZ)
  338. exmplinsgnfis = np.intersect1d(exmplis, insgnfis)
  339. exmplsgnfis = np.intersect1d(exmplis, sgnfis)
  340. print('exmplinsgnfis', exmplinsgnfis)
  341. print('examplsgnfis', exmplsgnfis)
  342. # plot insignificant and significant example points, one at a time:
  343. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  344. if exmpli in exmplinsgnfis:
  345. alpha = SGNF2ALPHA[False]
  346. elif exmpli in exmplsgnfis:
  347. alpha = SGNF2ALPHA[True]
  348. else:
  349. raise RuntimeError("Some kind of exmpli set membership error")
  350. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  351. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  352. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  353. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  354. a.scatter(thmeds[exmpli], mmeds[exmpli], clip_on=False, marker=marker, c=c,
  355. s=sz, lw=lw, alpha=alpha)
  356. # plot mean median point:
  357. if fmode == '': # all spikes, plot LMM mean
  358. print('plotting LMM mean for fmode=%s' % fmode)
  359. a.scatter(-0.19, 0.75, # read off of stats/fig2*.pdf
  360. c='red', edgecolor='red', s=50, marker='^')
  361. elif fmode == 'nb': # non burst spikes, plot LMM mean
  362. print('plotting LMM mean for fmode=%s' % fmode)
  363. a.scatter(0.09, 0.74, # read off of stats/fig2*.pdf
  364. c='red', edgecolor='red', s=50, marker='^')
  365. else:
  366. a.scatter(np.mean(thmeds), gmean(mmeds),
  367. c='red', edgecolor='red', s=50, marker='^')
  368. # display median of median rsq sample values:
  369. txt = '$\mathregular{R^{2}_{med}=}$%.2f' % np.round(np.median(rsqmeds), 2)
  370. a.add_artist(AnchoredText(txt, loc='lower right', frameon=False))
  371. #cbar = f.colorbar(path)
  372. #cbar.ax.set_xlabel('$\mathregular{R^{2}}$')
  373. #cbar.ax.xaxis.set_label_position('top')
  374. a.set_yscale('log', basey=2)
  375. a.set_xlabel('Threshold')
  376. a.set_ylabel('Slope')
  377. a.set_xlim(threshmin, threshmax)
  378. a.set_xticks([threshmin, 0, threshmax])
  379. a.set_ylim(2**slopelogmin, 2**slopelogmax)
  380. a.set_yticks(logyticks)
  381. axes_disable_scientific(a)
  382. a.spines['left'].set_position(('outward', 4))
  383. a.spines['bottom'].set_position(('outward', 4))
  384. ## plot fit rsq vs suppression meanburstratio:
  385. f, a = plt.subplots(figsize=figsize)
  386. titlestr = ('PSTH %s fit rsq vs suppression meanburstratio %s %s %s' %
  387. (model, kind, st8, fmode2txt[fmode]))
  388. wintitle(titlestr, f)
  389. a.scatter(brs[normlis], rsqmeds[normlis], clip_on=False,
  390. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  391. # plot example points, one at a time:
  392. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  393. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  394. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  395. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  396. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  397. a.scatter(brs[exmpli], rsqmeds[exmpli], clip_on=False,
  398. marker=marker, c=c, s=sz, lw=lw)
  399. # get fname of appropriate LMM .cvs file:
  400. if fmode == '': # use Steffen's LMM linregress fit, for fig2f
  401. fname = os.path.join(PAPERPATH, 'stats', 'figure_2f_coefs.csv')
  402. # fetch LMM linregress fit params from .csv:
  403. df = pd.read_csv(fname)
  404. mm = df['slope'][0]
  405. b = df['intercept'][0]
  406. x = np.array([np.nanmin(brs), np.nanmax(brs)])
  407. y = mm * x + b
  408. a.plot(x, y, '-', color='red') # plot linregress fit
  409. a.set_xlabel('Suppression BR')
  410. a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title())
  411. a.set_xlim(xmin=0)
  412. a.set_ylim(0, 1)
  413. a.set_yticks([0, 0.5, 1])
  414. a.spines['left'].set_position(('outward', 4))
  415. a.spines['bottom'].set_position(('outward', 4))
  416. ## plot fit rsq vs suppression SNR:
  417. f, a = plt.subplots(figsize=figsize)
  418. titlestr = ('PSTH %s fit rsq vs SNR %s %s %s' %
  419. (model, kind, st8, fmode2txt[fmode]))
  420. wintitle(titlestr, f)
  421. a.scatter(snrs[normlis], rsqmeds[normlis], clip_on=False,
  422. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  423. # plot example points, one at a time:
  424. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  425. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  426. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  427. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  428. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  429. a.scatter(snrs[exmpli], rsqmeds[exmpli], clip_on=False,
  430. marker=marker, c=c, s=sz, lw=lw)
  431. a.set_xlabel('Suppression SNR')
  432. a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title())
  433. a.set_xlim(xmin=0)
  434. a.set_ylim(0, 1)
  435. a.set_yticks([0, 0.5, 1])
  436. a.spines['left'].set_position(('outward', 4))
  437. a.spines['bottom'].set_position(('outward', 4))
  438. ## plot suppression fit SNR vs meanburstratio:
  439. f, a = plt.subplots(figsize=figsize)
  440. titlestr = ('PSTH %s fit SNR vs meanburstratio %s %s %s' %
  441. (model, kind, st8, fmode2txt[fmode]))
  442. wintitle(titlestr, f)
  443. a.scatter(brs[normlis], snrs[normlis], clip_on=False,
  444. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  445. # plot example points, one at a time:
  446. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  447. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  448. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  449. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  450. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  451. a.scatter(brs[exmpli], snrs[exmpli], clip_on=False,
  452. marker=marker, c=c, s=sz, lw=lw)
  453. a.set_xlabel('Suppression BR')
  454. a.set_ylabel('Suppression SNR')
  455. a.set_xlim(xmin=0)
  456. a.set_ylim(ymin=0)
  457. a.set_yticks([0, 0.5, 1])
  458. a.spines['left'].set_position(('outward', 4))
  459. a.spines['bottom'].set_position(('outward', 4))
  460. # scatter plot resampled model fit rsq for nonburst vs all spikes, and nonrand vs all spikes:
  461. figsize = DEFAULTFIGURESIZE
  462. for kind in ['nat']:#MVIKINDS:
  463. for st8 in ['none']:#ALLST8S:
  464. # boolean pandas Series:
  465. rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
  466. rows = sampfitsr[rowis]
  467. allrsqmeds, nbrsqmeds, nrrsqmeds = [], [], [] # median fit rsq values
  468. exmplis, exmplmseustrs, normlis = [], [], []
  469. keptmseui = 0 # manually init and increment instead of using enumerate()
  470. for i, row in rows.iterrows(): # for each mseu
  471. mseustr = row['mseu']
  472. allrsqs, nbrsqs, nrrsqs = row['rsqs'], row['nbrsqs'], row['nrrsqs']
  473. ## TODO: does this throw out too many units?:
  474. if np.isnan([allrsqs, nbrsqs, nrrsqs]).any():
  475. continue # skip this mseu
  476. # skip mseus that had no meaningful signal to fit in either condition:
  477. maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
  478. if maxsnr < SNRTHRESH:
  479. continue
  480. allrsqmed = np.median(allrsqs)
  481. nbrsqmed = np.median(nbrsqs)
  482. nrrsqmed = np.median(nrrsqs)
  483. #if rsqmed < RSQTHRESH: # skip mseus with really bad median fits
  484. # continue # skip
  485. if allrsqmed < 0 or nbrsqmed < 0 or nrrsqmed < 0: # catastrophic fit failure
  486. print('catastrophic fit failure:', mseustr)
  487. continue # skip
  488. allrsqmeds.append(allrsqmed)
  489. nbrsqmeds.append(nbrsqmed)
  490. nrrsqmeds.append(nrrsqmed)
  491. if mseustr in mvimseu2exmpli:
  492. exmplis.append(keptmseui)
  493. exmplmseustrs.append(mseustr)
  494. else:
  495. normlis.append(keptmseui)
  496. if nbrsqmed > 0.4 and allrsqmed < 0.4:
  497. print(mseustr, allrsqmed, nbrsqmed)
  498. keptmseui += 1 # manually increment
  499. allrsqmeds = np.asarray(allrsqmeds)
  500. nbrsqmeds = np.asarray(nbrsqmeds)
  501. nrrsqmeds = np.asarray(nrrsqmeds)
  502. # plot nonburst vs all rsq:
  503. f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
  504. titlestr = 'PSTH %s fit rsq nonburst vs all %s %s' % (model, kind, st8)
  505. wintitle(titlestr, f)
  506. linmax = np.vstack([allrsqmeds, nbrsqmeds]).max()
  507. xyline = [linmin, linmax], [linmin, linmax]
  508. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
  509. a.scatter(allrsqmeds[normlis], nbrsqmeds[normlis], clip_on=False,
  510. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  511. # plot example points, one at a time:
  512. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  513. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  514. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  515. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  516. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  517. a.scatter(allrsqmeds[exmpli], nbrsqmeds[exmpli], clip_on=False,
  518. marker=marker, c=c, s=sz, lw=lw)
  519. a.set_xlabel('All spikes $\mathregular{R^{2}}$')
  520. a.set_ylabel('Non-burst spikes $\mathregular{R^{2}}$')
  521. a.set_xlim(linmin, 1)
  522. a.set_ylim(linmin, 1)
  523. a.set_xticks([0, 0.5, 1])
  524. a.set_yticks([0, 0.5, 1])
  525. a.set_aspect('equal')
  526. a.spines['left'].set_position(('outward', 4))
  527. a.spines['bottom'].set_position(('outward', 4))
  528. # plot nr vs all rsq:
  529. f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
  530. titlestr = 'PSTH %s fit rsq nr vs all %s %s' % (model, kind, st8)
  531. wintitle(titlestr, f)
  532. linmax = np.vstack([allrsqmeds, nrrsqmeds]).max()
  533. xyline = [linmin, linmax], [linmin, linmax]
  534. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
  535. a.scatter(allrsqmeds[normlis], nrrsqmeds[normlis], clip_on=False,
  536. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  537. # plot example points, one at a time:
  538. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  539. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  540. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  541. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  542. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  543. a.scatter(allrsqmeds[exmpli], nrrsqmeds[exmpli], clip_on=False,
  544. marker=marker, c=c, s=sz, lw=lw)
  545. a.set_xlabel('All spikes $\mathregular{R^{2}}$')
  546. a.set_ylabel('Rand. rem. spikes $\mathregular{R^{2}}$')
  547. a.set_xlim(linmin, 1)
  548. a.set_ylim(linmin, 1)
  549. a.set_xticks([0, 0.5, 1])
  550. a.set_yticks([0, 0.5, 1])
  551. a.set_aspect('equal')
  552. a.spines['left'].set_position(('outward', 4))
  553. a.spines['bottom'].set_position(('outward', 4))