fig2.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. """Figure 2 plots, use run -i fig2.py"""
  2. fmodestxt = list(fmode2txt.values()) # 'all', 'nonburst', 'burst', 'nonrand'
  3. mi = pd.MultiIndex.from_product([mvimseustrs, fmodestxt], names=['mseu', 'fmode'])
  4. fig2 = pd.DataFrame(index=mi, columns=['slope', 'thresh', 'rsq',
  5. 'suppression_meanburstratio', 'suppression_snr'])
  6. # convert multi index to columns for easier to read restriction syntax:
  7. sampfitsr = sampfits.reset_index()
  8. mvirespr = mviresp.reset_index()
  9. linmin = 0
  10. # plot single-sample V1 model-fit PSTH overtop of test feedback and suppression PSTHs:
  11. dt = 5 # stimulus duration sec
  12. #tmin, tmax = 0 + OFFSETS[0], dt + OFFSETS[1]
  13. tmin, tmax = 0, dt
  14. #figsize = 0.855 + tmax*1.5*RASTERSCALEX, PSTHHEIGHT
  15. psthfigsize = 0.855 + tmax*2.1*RASTERSCALEX, PSTHHEIGHT
  16. # medians:
  17. # n1: rsq=0.29, slope=0.37, thr=-0.14
  18. # n2: rsq=0.90, slope=0.67, thr=1.58
  19. #for seedi in [110, 199, 339, 388]: # seeds that result in rsq identical to median
  20. np.random.seed(199) # fix random seed for identical results from np.random.choice() on each run
  21. for mseustr, exmpli in mvimseu2exmpli.items(): #mvimseustrs:
  22. for kind in ['nat']:#MVIKINDS:
  23. for st8 in ['none']:#ALLST8S:
  24. # get one sample of trials, separate into fit and test data:
  25. ctrlraster = mviresp.loc[mseustr, kind, st8, False]['raster'] # narrow
  26. optoraster = mviresp.loc[mseustr, kind, st8, True]['raster'] # narrow
  27. optoburstraster = mviresp.loc[mseustr, kind, st8, True]['braster'] # narrow
  28. nctrltrials = len(ctrlraster) # typically 200 trials when st8 == 'none'
  29. noptotrials = len(optoraster) # typically 200 trials when st8 == 'none'
  30. for ntrials in [nctrltrials, noptotrials]:
  31. if ntrials < MINNTRIALSAMPLETHRESH:
  32. print('Not enough trials to sample:', mseustr, kind, st8)
  33. continue # don't bother sampling, leave entry in sampfits empty
  34. ctrltrialis = np.arange(nctrltrials)
  35. optotrialis = np.arange(noptotrials)
  36. ctrlsamplesize = intround(nctrltrials / 2) # half for fitting, half for testing
  37. optosamplesize = intround(noptotrials / 2) # half for fitting, half for testing
  38. # probably doesn't matter if use ctrl or opto bins, should be the same:
  39. bins = mviresp.loc[mseustr, kind, st8, False]['bins']
  40. t = mviresp.loc[mseustr, kind, st8, False]['t'] # mid bins
  41. # randomly sample half the ctrl and opto trials, without replacement:
  42. ctrlfitis = np.sort(choice(ctrltrialis, size=ctrlsamplesize, replace=False))
  43. optofitis = np.sort(choice(optotrialis, size=optosamplesize, replace=False))
  44. ctrltestis = np.setdiff1d(ctrltrialis, ctrlfitis) # get the complement
  45. optotestis = np.setdiff1d(optotrialis, optofitis) # get the complement
  46. ctrlfitraster, ctrltestraster = ctrlraster[ctrlfitis], ctrlraster[ctrltestis]
  47. optofitraster, optotestraster = optoraster[optofitis], optoraster[optotestis]
  48. optotestburstraster = optoburstraster[optotestis]
  49. # calculate fit and test opto PSTHs, subsampled in time:
  50. ctrlfitpsth = raster2psth(ctrlfitraster, bins, binw, tres, kernel)[::ssx]
  51. optofitpsth = raster2psth(optofitraster, bins, binw, tres, kernel)[::ssx]
  52. ctrltestpsth = raster2psth(ctrltestraster, bins, binw, tres, kernel)[::ssx]
  53. optotestpsth = raster2psth(optotestraster, bins, binw, tres, kernel)[::ssx]
  54. optotestburstpsth = raster2psth(optotestburstraster, bins, binw, tres, kernel)[::ssx]
  55. mm, b, rsq = fitmodel(ctrlfitpsth, optofitpsth, ctrltestpsth, optotestpsth,
  56. model=model)
  57. #rsqstr = '%.2f' % np.round(rsq, 2)
  58. #exmpli2rsqstr = {1:'0.29', 2:'0.90'}
  59. #if rsqstr != exmpli2rsqstr[exmpli]:
  60. # continue
  61. #print(seedi, exmpli)
  62. th = -b / mm # threshold, i.e., x intercept
  63. ## plot the sample's PSTHs:
  64. f, a = plt.subplots(figsize=psthfigsize)
  65. title = '%s %s %s PSTH %s model' % (mseustr, kind, st8, model)
  66. wintitle(title, f)
  67. color = st82clr[st8]
  68. tss = t[::ssx] # subsampled in time
  69. # plot control and opto test PSTHs:
  70. opto2psth = {False:ctrltestpsth, True:optotestpsth}
  71. for opto in OPTOS:
  72. psth = opto2psth[opto]
  73. c = desat(color, opto2alpha[opto]) # do manual alpha mixing
  74. fb = opto2fb[opto].title()
  75. if opto == False:
  76. ls = '--'
  77. lw = 1
  78. else:
  79. ls = '-'
  80. lw = 2
  81. br = mviresp.loc[mseustr, kind, st8, opto]['meanburstratio']
  82. print('%s exampli=%d %s %s meanburstratio=%g'
  83. % (mseustr, exmpli, st8, opto, br))
  84. a.plot(tss, psth, ls=ls, lw=lw, marker='None', color=c, label=fb)
  85. # plot opto model:
  86. modeltestpsth = mm * ctrltestpsth + b
  87. a.plot(tss, modeltestpsth, '-', lw=2, color=optoblue, alpha=1,
  88. label='Suppression model')
  89. # plot opto burst test PSTH:
  90. bc = desat(burstclr, 0.3) # do manual alpha mixing
  91. #bc = desat(burstclr, opto2alpha[opto]) # do manual alpha mixing
  92. a.plot(tss, optotestburstpsth, ls=ls, lw=lw, marker='None', color=bc,
  93. label=fb+' burst', zorder=-99)
  94. l = a.legend(frameon=False)
  95. l.set_draggable(True)
  96. a.set_xlabel('Time (s)')
  97. a.set_ylabel('Firing rate (spk/s)')
  98. a.set_xlim(tmin, tmax)
  99. # plot horizontal line signifying stimulus period, just below data:
  100. #ymin, ymax = a.get_ylim()
  101. #a.hlines(-ymax*0.025, 0, dt, colors='k', lw=4, clip_on=False, in_layout=False)
  102. #a.set_ylim(ymin, ymax) # restore y limits from before plotting the hline
  103. ## scatter plot suppression vs. feedback for the sample's PSTH timepoints,
  104. ## plus model fit line:
  105. scale = 'linear' # log or linear
  106. logmin, logmax = -2, 2
  107. log0min = logmin + 0.05
  108. # for nat none all-spike plots for fig2:
  109. examplemaxrate = {'PVCre_2018_0003_s03_e03_u51':35,
  110. 'PVCre_2017_0008_s12_e06_u56':140}
  111. if kind == 'nat' and st8 == 'none':
  112. linmax = examplemaxrate[mseustr]
  113. else:
  114. linmax = np.vstack([ctrltestpsth, optotestpsth]).max()
  115. figsize = DEFAULTFIGURESIZE
  116. if linmax >= 100:
  117. figsize = figsize[0]*1.025, figsize[1] # tweak to make space extra y axis digit
  118. f, a = plt.subplots(figsize=figsize)
  119. titlestr = ('%s %s %s PSTH scatter %s' % (mseustr, kind, st8, model))
  120. wintitle(titlestr, f)
  121. if scale == 'log':
  122. # replace off-scale low values with log0min, so the points remain visible:
  123. ctrltestpsth[ctrltestpsth <= 10**logmin] = 10**log0min
  124. optotestpsth[optotestpsth <= 10**logmin] = 10**log0min
  125. # plot y=x line:
  126. if scale == 'log':
  127. xyline = [10**logmin, 10**logmax], [10**logmin, 10**logmax]
  128. else:
  129. xyline = [linmin, linmax], [linmin, linmax]
  130. a.plot(xyline[0], xyline[1], '--', clip_on=True, color='gray', zorder=100)
  131. # plot testpsth vs control:
  132. texts = []
  133. c = st82clr[st8]
  134. a.plot(ctrltestpsth, optotestpsth, '.', clip_on=True, ms=2, color=c, label=st8)
  135. x = np.array([ctrltestpsth.min(), ctrltestpsth.max()])
  136. y = mm * x + b # model output line
  137. a.plot(x, y, '-', color=optoblue, clip_on=True, alpha=1, lw=2) # plot model fit
  138. # display the sample's parameters:
  139. txt = ('$\mathregular{R^{2}=%.2f}$\n'
  140. '$\mathregular{Slope=%.2f}$\n'
  141. '$\mathregular{Thr.=%.1f}$' % (rsq, mm, th))
  142. a.add_artist(AnchoredText(txt, loc='upper left', frameon=False))
  143. a.set_xlabel('Feedback rate (spk/s)')
  144. a.set_ylabel('Suppression rate (spk/s)')
  145. if scale == 'log':
  146. a.set_xscale('log')
  147. a.set_yscale('log')
  148. xymin, xymax = 10**logmin, 10**logmax
  149. ticks = 10**(np.arange(logmin, logmax+1, dtype=float))
  150. else:
  151. xymin, xymax = linmin, linmax
  152. ticks = a.get_xticks()
  153. a.set_xticks(ticks) # seem to need to set both x and y to force them to match
  154. a.set_yticks(ticks)
  155. a.set_xlim(xymin, xymax)
  156. a.set_ylim(xymin, xymax)
  157. a.set_aspect('equal')
  158. a.spines['left'].set_position(('outward', 4))
  159. a.spines['bottom'].set_position(('outward', 4))
  160. # scatter plot resampled model slope vs threshold, for only those fits with decent SNR
  161. # in at least one opto state:
  162. figsize = DEFAULTFIGURESIZE
  163. slopelogmin, slopelogmax = -4, 1
  164. logyticks = np.logspace(slopelogmin, slopelogmax, num=(slopelogmax-slopelogmin+1), base=2)
  165. threshmin, threshmax = -25, 25
  166. for kind in ['nat']:#MVIKINDS:
  167. for st8 in ['none']:#ALLST8S:
  168. for fmode in fmodes: # iterate over firing modes (all, non-burst, burst)
  169. print('fmode:', fmode)
  170. # boolean pandas Series:
  171. rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
  172. rows = sampfitsr[rowis]
  173. mmeds, thmeds, rsqmeds = [], [], [] # median fit values
  174. brs, snrs = [], [] # single burst ratio and SNR values matching each mseu
  175. sgnfis, insgnfis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
  176. keptmseui = 0 # manually init and increment instead of using enumerate()
  177. for i, row in rows.iterrows(): # for each mseu
  178. mseustr = row['mseu']
  179. ms, ths, rsqs = row[fmode+'ms'], row[fmode+'ths'], row[fmode+'rsqs']
  180. if np.isnan([ms, ths, rsqs]).any():
  181. continue # skip this mseu
  182. # skip mseus that had no meaningful signal to fit in either condition:
  183. maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
  184. if maxsnr < SNRTHRESH:
  185. continue
  186. mmed, ml, mh = percentile_ci(ms)
  187. thmed, thl, thh = percentile_ci(ths)
  188. rsqmed = np.median(rsqs)
  189. if rsqmed < 0: # catastrophic fit failure
  190. print('catastrophic fit failure:', mseustr)
  191. continue # skip
  192. mmeds.append(mmed)
  193. thmeds.append(thmed)
  194. rsqmeds.append(rsqmed)
  195. # collect suppression mean burst ratio and SNR:
  196. mvirowis = ((mvirespr['mseu'] == mseustr) &
  197. (mvirespr['kind'] == kind) &
  198. (mvirespr['st8'] == st8) &
  199. (mvirespr['opto'] == True)) # only for suppression
  200. mvirow = mvirespr[mvirowis]
  201. assert len(mvirow) == 1
  202. br = mvirow['meanburstratio'].iloc[0]
  203. snr = mvirow['snr'].iloc[0]
  204. brs.append(br)
  205. snrs.append(snr)
  206. fig2['slope'][mseustr, fmode2txt[fmode]] = mmed
  207. fig2['thresh'][mseustr, fmode2txt[fmode]] = thmed
  208. fig2['rsq'][mseustr, fmode2txt[fmode]] = rsqmed
  209. fig2['suppression_meanburstratio'][mseustr, fmode2txt[fmode]] = br
  210. fig2['suppression_snr'][mseustr, fmode2txt[fmode]] = snr
  211. # collect separate lists of significant and insignificant points:
  212. if not (ml < 1 < mh):# and not (thl < 0 < thh):
  213. sgnfis.append(keptmseui) # slope significantly different from 1
  214. else:
  215. insgnfis.append(keptmseui) # slope not significantly different from 1
  216. if mseustr in mvimseu2exmpli:
  217. exmplis.append(keptmseui)
  218. exmplmseustrs.append(mseustr)
  219. print('Example Neuron', mvimseu2exmpli[mseustr], ':', mseustr)
  220. print('mmed, ml, mh:', mmed, ml, mh)
  221. print('thmed, thl, thh:', thmed, thl, thh)
  222. else:
  223. normlis.append(keptmseui)
  224. keptmseui += 1 # manually increment
  225. mmeds = np.asarray(mmeds)
  226. thmeds = np.asarray(thmeds)
  227. rsqmeds = np.asarray(rsqmeds)
  228. brs = np.asarray(brs)
  229. snrs = np.asarray(snrs)
  230. ## plot median fit slope vs thresh:
  231. f, a = plt.subplots(figsize=(figsize[0]*1.11, figsize[1])) # extra space for ylabels
  232. titlestr = ('PSTH %s fit slope vs thresh %s %s snrthresh=%.3f %s'
  233. % (model, kind, st8, SNRTHRESH, fmode2txt[fmode]))
  234. wintitle(titlestr, f)
  235. normlinsgnfis = np.intersect1d(normlis, insgnfis)
  236. normlsgnfis = np.intersect1d(normlis, sgnfis)
  237. # plot x=0 and y=1 lines:
  238. a.axvline(x=0, ls='--', marker='', color='gray', zorder=-np.inf)
  239. a.axhline(y=1, ls='--', marker='', color='gray', zorder=-np.inf)
  240. # plot normal (non-example) insignificant points:
  241. c = desat(st82clr[st8], SGNF2ALPHA[False]) # do manual alpha mixing
  242. a.scatter(thmeds[normlinsgnfis], mmeds[normlinsgnfis], clip_on=False,
  243. marker='.', c='None', edgecolor=c, s=DEFSZ)
  244. # plot normal (non-example) significant points:
  245. c = desat(st82clr[st8], SGNF2ALPHA[True]) # do manual alpha mixing
  246. a.scatter(thmeds[normlsgnfis], mmeds[normlsgnfis], clip_on=False,
  247. marker='.', c='None', edgecolor=c, s=DEFSZ)
  248. exmplinsgnfis = np.intersect1d(exmplis, insgnfis)
  249. exmplsgnfis = np.intersect1d(exmplis, sgnfis)
  250. print('exmplinsgnfis', exmplinsgnfis)
  251. print('examplsgnfis', exmplsgnfis)
  252. # plot insignificant and significant example points, one at a time:
  253. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  254. if exmpli in exmplinsgnfis:
  255. alpha = SGNF2ALPHA[False]
  256. elif exmpli in exmplsgnfis:
  257. alpha = SGNF2ALPHA[True]
  258. else:
  259. raise RuntimeError("Some kind of exmpli set membership error")
  260. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  261. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  262. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  263. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  264. a.scatter(thmeds[exmpli], mmeds[exmpli], clip_on=False, marker=marker, c=c,
  265. s=sz, lw=lw, alpha=alpha)
  266. # plot mean median point:
  267. if fmode == '': # all spikes, plot LMM mean
  268. print('plotting LMM mean for fmode=%s' % fmode)
  269. a.scatter(-0.19, 0.75, # read off of stats/fig2*.pdf
  270. c='red', edgecolor='red', s=50, marker='^')
  271. elif fmode == 'nb': # non burst spikes, plot LMM mean
  272. print('plotting LMM mean for fmode=%s' % fmode)
  273. a.scatter(0.09, 0.74, # read off of stats/fig2*.pdf
  274. c='red', edgecolor='red', s=50, marker='^')
  275. else:
  276. a.scatter(np.mean(thmeds), gmean(mmeds),
  277. c='red', edgecolor='red', s=50, marker='^')
  278. # display median of median rsq sample values:
  279. txt = '$\mathregular{R^{2}_{med}=}$%.2f' % np.round(np.median(rsqmeds), 2)
  280. a.add_artist(AnchoredText(txt, loc='lower right', frameon=False))
  281. #cbar = f.colorbar(path)
  282. #cbar.ax.set_xlabel('$\mathregular{R^{2}}$')
  283. #cbar.ax.xaxis.set_label_position('top')
  284. a.set_yscale('log', basey=2)
  285. a.set_xlabel('Threshold')
  286. a.set_ylabel('Slope')
  287. a.set_xlim(threshmin, threshmax)
  288. a.set_xticks([threshmin, 0, threshmax])
  289. a.set_ylim(2**slopelogmin, 2**slopelogmax)
  290. a.set_yticks(logyticks)
  291. axes_disable_scientific(a)
  292. a.spines['left'].set_position(('outward', 4))
  293. a.spines['bottom'].set_position(('outward', 4))
  294. ## plot fit rsq vs suppression meanburstratio:
  295. f, a = plt.subplots(figsize=figsize)
  296. titlestr = ('PSTH %s fit rsq vs suppression meanburstratio %s %s %s' %
  297. (model, kind, st8, fmode2txt[fmode]))
  298. wintitle(titlestr, f)
  299. a.scatter(brs[normlis], rsqmeds[normlis], clip_on=False,
  300. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  301. # plot example points, one at a time:
  302. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  303. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  304. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  305. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  306. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  307. a.scatter(brs[exmpli], rsqmeds[exmpli], clip_on=False,
  308. marker=marker, c=c, s=sz, lw=lw)
  309. # get fname of appropriate LMM .cvs file:
  310. if fmode == '': # use Steffen's LMM linregress fit, for fig2f
  311. fname = os.path.join('stats', 'figure_2f_coefs.csv')
  312. # fetch LMM linregress fit params from .csv:
  313. df = pd.read_csv(fname)
  314. mm = df['slope'][0]
  315. b = df['intercept'][0]
  316. x = np.array([np.nanmin(brs), np.nanmax(brs)])
  317. y = mm * x + b
  318. a.plot(x, y, '-', color='red') # plot linregress fit
  319. a.set_xlabel('Suppression BR')
  320. a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title())
  321. a.set_xlim(xmin=0)
  322. a.set_ylim(0, 1)
  323. a.set_yticks([0, 0.5, 1])
  324. a.spines['left'].set_position(('outward', 4))
  325. a.spines['bottom'].set_position(('outward', 4))
  326. ## plot fit rsq vs suppression SNR:
  327. f, a = plt.subplots(figsize=figsize)
  328. titlestr = ('PSTH %s fit rsq vs SNR %s %s %s' %
  329. (model, kind, st8, fmode2txt[fmode]))
  330. wintitle(titlestr, f)
  331. a.scatter(snrs[normlis], rsqmeds[normlis], clip_on=False,
  332. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  333. # plot example points, one at a time:
  334. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  335. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  336. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  337. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  338. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  339. a.scatter(snrs[exmpli], rsqmeds[exmpli], clip_on=False,
  340. marker=marker, c=c, s=sz, lw=lw)
  341. a.set_xlabel('Suppression SNR')
  342. a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title())
  343. a.set_xlim(xmin=0)
  344. a.set_ylim(0, 1)
  345. a.set_yticks([0, 0.5, 1])
  346. a.spines['left'].set_position(('outward', 4))
  347. a.spines['bottom'].set_position(('outward', 4))
  348. ## plot suppression fit SNR vs meanburstratio:
  349. f, a = plt.subplots(figsize=figsize)
  350. titlestr = ('PSTH %s fit SNR vs meanburstratio %s %s %s' %
  351. (model, kind, st8, fmode2txt[fmode]))
  352. wintitle(titlestr, f)
  353. a.scatter(brs[normlis], snrs[normlis], clip_on=False,
  354. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  355. # plot example points, one at a time:
  356. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  357. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  358. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  359. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  360. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  361. a.scatter(brs[exmpli], snrs[exmpli], clip_on=False,
  362. marker=marker, c=c, s=sz, lw=lw)
  363. a.set_xlabel('Suppression BR')
  364. a.set_ylabel('Suppression SNR')
  365. a.set_xlim(xmin=0)
  366. a.set_ylim(ymin=0)
  367. a.set_yticks([0, 0.5, 1])
  368. a.spines['left'].set_position(('outward', 4))
  369. a.spines['bottom'].set_position(('outward', 4))
  370. # scatter plot resampled model fit rsq for nonburst vs all spikes, and nonrand vs all spikes:
  371. figsize = DEFAULTFIGURESIZE
  372. for kind in ['nat']:#MVIKINDS:
  373. for st8 in ['none']:#ALLST8S:
  374. # boolean pandas Series:
  375. rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
  376. rows = sampfitsr[rowis]
  377. allrsqmeds, nbrsqmeds, nrrsqmeds = [], [], [] # median fit rsq values
  378. exmplis, exmplmseustrs, normlis = [], [], []
  379. keptmseui = 0 # manually init and increment instead of using enumerate()
  380. for i, row in rows.iterrows(): # for each mseu
  381. mseustr = row['mseu']
  382. allrsqs, nbrsqs, nrrsqs = row['rsqs'], row['nbrsqs'], row['nrrsqs']
  383. if np.isnan([allrsqs, nbrsqs, nrrsqs]).any():
  384. continue # skip this mseu
  385. # skip mseus that had no meaningful signal to fit in either condition:
  386. maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
  387. if maxsnr < SNRTHRESH:
  388. continue
  389. allrsqmed = np.median(allrsqs)
  390. nbrsqmed = np.median(nbrsqs)
  391. nrrsqmed = np.median(nrrsqs)
  392. if allrsqmed < 0 or nbrsqmed < 0 or nrrsqmed < 0: # catastrophic fit failure
  393. print('catastrophic fit failure:', mseustr)
  394. continue # skip
  395. allrsqmeds.append(allrsqmed)
  396. nbrsqmeds.append(nbrsqmed)
  397. nrrsqmeds.append(nrrsqmed)
  398. if mseustr in mvimseu2exmpli:
  399. exmplis.append(keptmseui)
  400. exmplmseustrs.append(mseustr)
  401. else:
  402. normlis.append(keptmseui)
  403. if nbrsqmed > 0.4 and allrsqmed < 0.4:
  404. print(mseustr, allrsqmed, nbrsqmed)
  405. keptmseui += 1 # manually increment
  406. allrsqmeds = np.asarray(allrsqmeds)
  407. nbrsqmeds = np.asarray(nbrsqmeds)
  408. nrrsqmeds = np.asarray(nrrsqmeds)
  409. # plot nonburst vs all rsq:
  410. f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
  411. titlestr = 'PSTH %s fit rsq nonburst vs all %s %s' % (model, kind, st8)
  412. wintitle(titlestr, f)
  413. linmax = np.vstack([allrsqmeds, nbrsqmeds]).max()
  414. xyline = [linmin, linmax], [linmin, linmax]
  415. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
  416. a.scatter(allrsqmeds[normlis], nbrsqmeds[normlis], clip_on=False,
  417. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  418. # plot example points, one at a time:
  419. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  420. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  421. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  422. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  423. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  424. a.scatter(allrsqmeds[exmpli], nbrsqmeds[exmpli], clip_on=False,
  425. marker=marker, c=c, s=sz, lw=lw)
  426. a.set_xlabel('All spikes $\mathregular{R^{2}}$')
  427. a.set_ylabel('Non-burst spikes $\mathregular{R^{2}}$')
  428. a.set_xlim(linmin, 1)
  429. a.set_ylim(linmin, 1)
  430. a.set_xticks([0, 0.5, 1])
  431. a.set_yticks([0, 0.5, 1])
  432. a.set_aspect('equal')
  433. a.spines['left'].set_position(('outward', 4))
  434. a.spines['bottom'].set_position(('outward', 4))
  435. # plot nr vs all rsq:
  436. f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
  437. titlestr = 'PSTH %s fit rsq nr vs all %s %s' % (model, kind, st8)
  438. wintitle(titlestr, f)
  439. linmax = np.vstack([allrsqmeds, nrrsqmeds]).max()
  440. xyline = [linmin, linmax], [linmin, linmax]
  441. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
  442. a.scatter(allrsqmeds[normlis], nrrsqmeds[normlis], clip_on=False,
  443. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  444. # plot example points, one at a time:
  445. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  446. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  447. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  448. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  449. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  450. a.scatter(allrsqmeds[exmpli], nrrsqmeds[exmpli], clip_on=False,
  451. marker=marker, c=c, s=sz, lw=lw)
  452. a.set_xlabel('All spikes $\mathregular{R^{2}}$')
  453. a.set_ylabel('Rand. rem. spikes $\mathregular{R^{2}}$')
  454. a.set_xlim(linmin, 1)
  455. a.set_ylim(linmin, 1)
  456. a.set_xticks([0, 0.5, 1])
  457. a.set_yticks([0, 0.5, 1])
  458. a.set_aspect('equal')
  459. a.spines['left'].set_position(('outward', 4))
  460. a.spines['bottom'].set_position(('outward', 4))