# fig2.py (25 KB)
  1. """Figure 2 plots, use run -i fig2.py"""
  2. fmodestxt = list(fmode2txt.values()) # 'all', 'nonburst', 'burst', 'nonrand'
  3. mi = pd.MultiIndex.from_product([mvimseustrs, fmodestxt], names=['mseu', 'fmode'])
  4. fig2 = pd.DataFrame(index=mi, columns=['slope', 'thresh', 'rsq',
  5. 'suppression_meanburstratio', 'suppression_snr'])
  6. # convert multi index to columns for easier to read restriction syntax:
  7. sampfitsr = sampfits.reset_index()
  8. mvirespr = mviresp.reset_index()
  9. linmin = 0
  10. # plot single-sample V1 model-fit PSTH overtop of test feedback and suppression PSTHs:
  11. ## TODO: correlate burst PSTHs with (model-empirical)**2, scatter plot them against each other,
  12. ## or just plot distribution of corrcoefs
  13. dt = 5 # stimulus duration sec
  14. #tmin, tmax = 0 + OFFSETS[0], dt + OFFSETS[1]
  15. tmin, tmax = 0, dt
  16. #figsize = 0.855 + tmax*1.5*RASTERSCALEX, PSTHHEIGHT
  17. psthfigsize = 0.855 + tmax*2.1*RASTERSCALEX, PSTHHEIGHT
  18. # medians:
  19. # n1: rsq=0.29, slope=0.37, thr=-0.14
  20. # n2: rsq=0.90, slope=0.67, thr=1.58
  21. #for seedi in [110, 199, 339, 388]: # seeds that result in rsq identical to median
  22. np.random.seed(199) # fix random seed for identical results from np.random.choice() on each run
  23. for mseustr, exmpli in mvimseu2exmpli.items(): #mvimseustrs:
  24. for kind in ['nat']:#MVIKINDS:
  25. for st8 in ['none']:#ALLST8S:
  26. # get one sample of trials, separate into fit and test data:
  27. ctrlraster = mviresp.loc[mseustr, kind, st8, False]['raster'] # narrow
  28. optoraster = mviresp.loc[mseustr, kind, st8, True]['raster'] # narrow
  29. optoburstraster = mviresp.loc[mseustr, kind, st8, True]['braster'] # narrow
  30. nctrltrials = len(ctrlraster) # typically 200 trials when st8 == 'none'
  31. noptotrials = len(optoraster) # typically 200 trials when st8 == 'none'
  32. for ntrials in [nctrltrials, noptotrials]:
  33. if ntrials < MINNTRIALSAMPLETHRESH:
  34. print('Not enough trials to sample:', mseustr, kind, st8)
  35. continue # don't bother sampling, leave entry in sampfits empty
  36. ctrltrialis = np.arange(nctrltrials)
  37. optotrialis = np.arange(noptotrials)
  38. ctrlsamplesize = intround(nctrltrials / 2) # half for fitting, half for testing
  39. optosamplesize = intround(noptotrials / 2) # half for fitting, half for testing
  40. # probably doesn't matter if use ctrl or opto bins, should be the same:
  41. bins = mviresp.loc[mseustr, kind, st8, False]['bins']
  42. t = mviresp.loc[mseustr, kind, st8, False]['t'] # mid bins
  43. # randomly sample half the ctrl and opto trials, without replacement:
  44. ctrlfitis = np.sort(choice(ctrltrialis, size=ctrlsamplesize, replace=False))
  45. optofitis = np.sort(choice(optotrialis, size=optosamplesize, replace=False))
  46. ctrltestis = np.setdiff1d(ctrltrialis, ctrlfitis) # get the complement
  47. optotestis = np.setdiff1d(optotrialis, optofitis) # get the complement
  48. ctrlfitraster, ctrltestraster = ctrlraster[ctrlfitis], ctrlraster[ctrltestis]
  49. optofitraster, optotestraster = optoraster[optofitis], optoraster[optotestis]
  50. optotestburstraster = optoburstraster[optotestis]
  51. # calculate fit and test opto PSTHs, subsampled in time:
  52. ctrlfitpsth = raster2psth(ctrlfitraster, bins, binw, tres, kernel)[::ssx]
  53. optofitpsth = raster2psth(optofitraster, bins, binw, tres, kernel)[::ssx]
  54. ctrltestpsth = raster2psth(ctrltestraster, bins, binw, tres, kernel)[::ssx]
  55. optotestpsth = raster2psth(optotestraster, bins, binw, tres, kernel)[::ssx]
  56. optotestburstpsth = raster2psth(optotestburstraster, bins, binw, tres, kernel)[::ssx]
  57. mm, b, rsq = fitmodel(ctrlfitpsth, optofitpsth, ctrltestpsth, optotestpsth,
  58. model=model)
  59. #rsqstr = '%.2f' % np.round(rsq, 2)
  60. #exmpli2rsqstr = {1:'0.29', 2:'0.90'}
  61. #if rsqstr != exmpli2rsqstr[exmpli]:
  62. # continue
  63. #print(seedi, exmpli)
  64. th = -b / mm # threshold, i.e., x intercept
  65. ## plot the sample's PSTHs:
  66. f, a = plt.subplots(figsize=psthfigsize)
  67. title = '%s %s %s PSTH %s model' % (mseustr, kind, st8, model)
  68. wintitle(title, f)
  69. color = st82clr[st8]
  70. tss = t[::ssx] # subsampled in time
  71. # plot control and opto test PSTHs:
  72. opto2psth = {False:ctrltestpsth, True:optotestpsth}
  73. for opto in OPTOS:
  74. psth = opto2psth[opto]
  75. c = desat(color, opto2alpha[opto]) # do manual alpha mixing
  76. fb = opto2fb[opto].title()
  77. if opto == False:
  78. ls = '--'
  79. lw = 1
  80. else:
  81. ls = '-'
  82. lw = 2
  83. br = mviresp.loc[mseustr, kind, st8, opto]['meanburstratio']
  84. print('%s exampli=%d %s %s meanburstratio=%g'
  85. % (mseustr, exmpli, st8, opto, br))
  86. a.plot(tss, psth, ls=ls, lw=lw, marker='None', color=c, label=fb)
  87. # plot opto model:
  88. modeltestpsth = mm * ctrltestpsth + b
  89. a.plot(tss, modeltestpsth, '-', lw=2, color=optoblue, alpha=1,
  90. label='Suppression model')
  91. # plot opto burst test PSTH:
  92. bc = desat(burstclr, 0.3) # do manual alpha mixing
  93. #bc = desat(burstclr, opto2alpha[opto]) # do manual alpha mixing
  94. a.plot(tss, optotestburstpsth, ls=ls, lw=lw, marker='None', color=bc,
  95. label=fb+' burst', zorder=-99)
  96. l = a.legend(frameon=False)
  97. l.set_draggable(True)
  98. a.set_xlabel('Time (s)')
  99. a.set_ylabel('Firing rate (spk/s)')
  100. a.set_xlim(tmin, tmax)
  101. # plot horizontal line signifying stimulus period, just below data:
  102. #ymin, ymax = a.get_ylim()
  103. #a.hlines(-ymax*0.025, 0, dt, colors='k', lw=4, clip_on=False, in_layout=False)
  104. #a.set_ylim(ymin, ymax) # restore y limits from before plotting the hline
  105. ## scatter plot suppression vs. feedback for the sample's PSTH timepoints,
  106. ## plus model fit line:
  107. scale = 'linear' # log or linear
  108. logmin, logmax = -2, 2
  109. log0min = logmin + 0.05
  110. # for nat none all-spike plots for fig2:
  111. examplemaxrate = {'PVCre_2018_0003_s03_e03_u51':35,
  112. 'PVCre_2017_0008_s12_e06_u56':140}
  113. if kind == 'nat' and st8 == 'none':
  114. linmax = examplemaxrate[mseustr]
  115. else:
  116. linmax = np.vstack([ctrltestpsth, optotestpsth]).max()
  117. figsize = DEFAULTFIGURESIZE
  118. if linmax >= 100:
  119. figsize = figsize[0]*1.025, figsize[1] # tweak to make space extra y axis digit
  120. f, a = plt.subplots(figsize=figsize)
  121. titlestr = ('%s %s %s PSTH scatter %s' % (mseustr, kind, st8, model))
  122. wintitle(titlestr, f)
  123. if scale == 'log':
  124. # replace off-scale low values with log0min, so the points remain visible:
  125. ctrltestpsth[ctrltestpsth <= 10**logmin] = 10**log0min
  126. optotestpsth[optotestpsth <= 10**logmin] = 10**log0min
  127. # plot y=x line:
  128. if scale == 'log':
  129. xyline = [10**logmin, 10**logmax], [10**logmin, 10**logmax]
  130. else:
  131. xyline = [linmin, linmax], [linmin, linmax]
  132. a.plot(xyline[0], xyline[1], '--', clip_on=True, color='gray', zorder=100)
  133. # plot testpsth vs control:
  134. texts = []
  135. c = st82clr[st8]
  136. a.plot(ctrltestpsth, optotestpsth, '.', clip_on=True, ms=2, color=c, label=st8)
  137. x = np.array([ctrltestpsth.min(), ctrltestpsth.max()])
  138. y = mm * x + b # model output line
  139. a.plot(x, y, '-', color=optoblue, clip_on=True, alpha=1, lw=2) # plot model fit
  140. # display the sample's parameters:
  141. txt = ('$\mathregular{R^{2}=%.2f}$\n'
  142. '$\mathregular{Slope=%.2f}$\n'
  143. '$\mathregular{Thr.=%.1f}$' % (rsq, mm, th))
  144. a.add_artist(AnchoredText(txt, loc='upper left', frameon=False))
  145. a.set_xlabel('Feedback rate (spk/s)')
  146. a.set_ylabel('Suppression rate (spk/s)')
  147. if scale == 'log':
  148. a.set_xscale('log')
  149. a.set_yscale('log')
  150. xymin, xymax = 10**logmin, 10**logmax
  151. ticks = 10**(np.arange(logmin, logmax+1, dtype=float))
  152. else:
  153. xymin, xymax = linmin, linmax
  154. ticks = a.get_xticks()
  155. a.set_xticks(ticks) # seem to need to set both x and y to force them to match
  156. a.set_yticks(ticks)
  157. a.set_xlim(xymin, xymax)
  158. a.set_ylim(xymin, xymax)
  159. a.set_aspect('equal')
  160. a.spines['left'].set_position(('outward', 4))
  161. a.spines['bottom'].set_position(('outward', 4))
  162. # scatter plot resampled model slope vs threshold, for only those fits with decent SNR
  163. # in at least one opto state:
  164. figsize = DEFAULTFIGURESIZE
  165. slopelogmin, slopelogmax = -4, 1
  166. logyticks = np.logspace(slopelogmin, slopelogmax, num=(slopelogmax-slopelogmin+1), base=2)
  167. threshmin, threshmax = -25, 25
  168. for kind in ['nat']:#MVIKINDS:
  169. for st8 in ['none']:#ALLST8S:
  170. for fmode in fmodes: # iterate over firing modes (all, non-burst, burst)
  171. print('fmode:', fmode)
  172. # boolean pandas Series:
  173. rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
  174. rows = sampfitsr[rowis]
  175. mmeds, thmeds, rsqmeds = [], [], [] # median fit values
  176. brs, snrs = [], [] # single burst ratio and SNR values matching each mseu
  177. sgnfis, insgnfis, exmplis, exmplmseustrs, normlis = [], [], [], [], []
  178. keptmseui = 0 # manually init and increment instead of using enumerate()
  179. for i, row in rows.iterrows(): # for each mseu
  180. mseustr = row['mseu']
  181. ms, ths, rsqs = row[fmode+'ms'], row[fmode+'ths'], row[fmode+'rsqs']
  182. if np.isnan([ms, ths, rsqs]).any(): ## TODO: does this throw out too many units?
  183. continue # skip this mseu
  184. # skip mseus that had no meaningful signal to fit in either condition:
  185. maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
  186. if maxsnr < SNRTHRESH:
  187. continue
  188. mmed, ml, mh = percentile_ci(ms)
  189. thmed, thl, thh = percentile_ci(ths)
  190. rsqmed = np.median(rsqs)
  191. #if rsqmed < RSQTHRESH: # skip mseus with really bad median fits
  192. # continue # skip
  193. if rsqmed < 0: # catastrophic fit failure
  194. print('catastrophic fit failure:', mseustr)
  195. continue # skip
  196. mmeds.append(mmed)
  197. thmeds.append(thmed)
  198. rsqmeds.append(rsqmed)
  199. # collect suppression mean burst ratio and SNR:
  200. mvirowis = ((mvirespr['mseu'] == mseustr) &
  201. (mvirespr['kind'] == kind) &
  202. (mvirespr['st8'] == st8) &
  203. (mvirespr['opto'] == True)) # only for suppression
  204. mvirow = mvirespr[mvirowis]
  205. assert len(mvirow) == 1
  206. br = mvirow['meanburstratio'].iloc[0]
  207. snr = mvirow['snr'].iloc[0]
  208. brs.append(br)
  209. snrs.append(snr)
  210. fig2['slope'][mseustr, fmode2txt[fmode]] = mmed
  211. fig2['thresh'][mseustr, fmode2txt[fmode]] = thmed
  212. fig2['rsq'][mseustr, fmode2txt[fmode]] = rsqmed
  213. fig2['suppression_meanburstratio'][mseustr, fmode2txt[fmode]] = br
  214. fig2['suppression_snr'][mseustr, fmode2txt[fmode]] = snr
  215. # collect separate lists of significant and insignificant points:
  216. if not (ml < 1 < mh):# and not (thl < 0 < thh):
  217. sgnfis.append(keptmseui) # slope significantly different from 1
  218. else:
  219. insgnfis.append(keptmseui) # slope not significantly different from 1
  220. if mseustr in mvimseu2exmpli:
  221. exmplis.append(keptmseui)
  222. exmplmseustrs.append(mseustr)
  223. print('Example Neuron', mvimseu2exmpli[mseustr], ':', mseustr)
  224. print('mmed, ml, mh:', mmed, ml, mh)
  225. print('thmed, thl, thh:', thmed, thl, thh)
  226. else:
  227. normlis.append(keptmseui)
  228. keptmseui += 1 # manually increment
  229. mmeds = np.asarray(mmeds)
  230. thmeds = np.asarray(thmeds)
  231. rsqmeds = np.asarray(rsqmeds)
  232. brs = np.asarray(brs)
  233. snrs = np.asarray(snrs)
  234. ## plot median fit slope vs thresh:
  235. f, a = plt.subplots(figsize=(figsize[0]*1.11, figsize[1])) # extra space for ylabels
  236. titlestr = ('PSTH %s fit slope vs thresh %s %s snrthresh=%.3f %s'
  237. % (model, kind, st8, SNRTHRESH, fmode2txt[fmode]))
  238. wintitle(titlestr, f)
  239. normlinsgnfis = np.intersect1d(normlis, insgnfis)
  240. normlsgnfis = np.intersect1d(normlis, sgnfis)
  241. # plot x=0 and y=1 lines:
  242. a.axvline(x=0, ls='--', marker='', color='gray', zorder=-np.inf)
  243. a.axhline(y=1, ls='--', marker='', color='gray', zorder=-np.inf)
  244. # plot normal (non-example) insignificant points:
  245. c = desat(st82clr[st8], SGNF2ALPHA[False]) # do manual alpha mixing
  246. a.scatter(thmeds[normlinsgnfis], mmeds[normlinsgnfis], clip_on=False,
  247. marker='.', c='None', edgecolor=c, s=DEFSZ)
  248. # plot normal (non-example) significant points:
  249. c = desat(st82clr[st8], SGNF2ALPHA[True]) # do manual alpha mixing
  250. a.scatter(thmeds[normlsgnfis], mmeds[normlsgnfis], clip_on=False,
  251. marker='.', c='None', edgecolor=c, s=DEFSZ)
  252. exmplinsgnfis = np.intersect1d(exmplis, insgnfis)
  253. exmplsgnfis = np.intersect1d(exmplis, sgnfis)
  254. print('exmplinsgnfis', exmplinsgnfis)
  255. print('examplsgnfis', exmplsgnfis)
  256. # plot insignificant and significant example points, one at a time:
  257. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  258. if exmpli in exmplinsgnfis:
  259. alpha = SGNF2ALPHA[False]
  260. elif exmpli in exmplsgnfis:
  261. alpha = SGNF2ALPHA[True]
  262. else:
  263. raise RuntimeError("Some kind of exmpli set membership error")
  264. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  265. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  266. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  267. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  268. a.scatter(thmeds[exmpli], mmeds[exmpli], clip_on=False, marker=marker, c=c,
  269. s=sz, lw=lw, alpha=alpha)
  270. # plot mean median point:
  271. if fmode == '': # all spikes, plot LMM mean
  272. print('plotting LMM mean for fmode=%s' % fmode)
  273. a.scatter(-0.19, 0.75, # read off of stats/fig2*.pdf
  274. c='red', edgecolor='red', s=50, marker='^')
  275. elif fmode == 'nb': # non burst spikes, plot LMM mean
  276. print('plotting LMM mean for fmode=%s' % fmode)
  277. a.scatter(0.09, 0.74, # read off of stats/fig2*.pdf
  278. c='red', edgecolor='red', s=50, marker='^')
  279. else:
  280. a.scatter(np.mean(thmeds), gmean(mmeds),
  281. c='red', edgecolor='red', s=50, marker='^')
  282. # display median of median rsq sample values:
  283. txt = '$\mathregular{R^{2}_{med}=}$%.2f' % np.round(np.median(rsqmeds), 2)
  284. a.add_artist(AnchoredText(txt, loc='lower right', frameon=False))
  285. #cbar = f.colorbar(path)
  286. #cbar.ax.set_xlabel('$\mathregular{R^{2}}$')
  287. #cbar.ax.xaxis.set_label_position('top')
  288. a.set_yscale('log', basey=2)
  289. a.set_xlabel('Threshold')
  290. a.set_ylabel('Slope')
  291. a.set_xlim(threshmin, threshmax)
  292. a.set_xticks([threshmin, 0, threshmax])
  293. a.set_ylim(2**slopelogmin, 2**slopelogmax)
  294. a.set_yticks(logyticks)
  295. axes_disable_scientific(a)
  296. a.spines['left'].set_position(('outward', 4))
  297. a.spines['bottom'].set_position(('outward', 4))
  298. ## plot fit rsq vs suppression meanburstratio:
  299. f, a = plt.subplots(figsize=figsize)
  300. titlestr = ('PSTH %s fit rsq vs suppression meanburstratio %s %s %s' %
  301. (model, kind, st8, fmode2txt[fmode]))
  302. wintitle(titlestr, f)
  303. a.scatter(brs[normlis], rsqmeds[normlis], clip_on=False,
  304. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  305. # plot example points, one at a time:
  306. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  307. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  308. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  309. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  310. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  311. a.scatter(brs[exmpli], rsqmeds[exmpli], clip_on=False,
  312. marker=marker, c=c, s=sz, lw=lw)
  313. # get fname of appropriate LMM .cvs file:
  314. if fmode == '': # use Steffen's LMM linregress fit, for fig2f
  315. fname = os.path.join('stats', 'figure_2f_coefs.csv')
  316. # fetch LMM linregress fit params from .csv:
  317. df = pd.read_csv(fname)
  318. mm = df['slope'][0]
  319. b = df['intercept'][0]
  320. x = np.array([np.nanmin(brs), np.nanmax(brs)])
  321. y = mm * x + b
  322. a.plot(x, y, '-', color='red') # plot linregress fit
  323. a.set_xlabel('Suppression BR')
  324. a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title())
  325. a.set_xlim(xmin=0)
  326. a.set_ylim(0, 1)
  327. a.set_yticks([0, 0.5, 1])
  328. a.spines['left'].set_position(('outward', 4))
  329. a.spines['bottom'].set_position(('outward', 4))
  330. ## plot fit rsq vs suppression SNR:
  331. f, a = plt.subplots(figsize=figsize)
  332. titlestr = ('PSTH %s fit rsq vs SNR %s %s %s' %
  333. (model, kind, st8, fmode2txt[fmode]))
  334. wintitle(titlestr, f)
  335. a.scatter(snrs[normlis], rsqmeds[normlis], clip_on=False,
  336. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  337. # plot example points, one at a time:
  338. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  339. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  340. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  341. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  342. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  343. a.scatter(snrs[exmpli], rsqmeds[exmpli], clip_on=False,
  344. marker=marker, c=c, s=sz, lw=lw)
  345. a.set_xlabel('Suppression SNR')
  346. a.set_ylabel('%s spikes $\mathregular{R^{2}}$' % fmode2txt[fmode].title())
  347. a.set_xlim(xmin=0)
  348. a.set_ylim(0, 1)
  349. a.set_yticks([0, 0.5, 1])
  350. a.spines['left'].set_position(('outward', 4))
  351. a.spines['bottom'].set_position(('outward', 4))
  352. ## plot suppression fit SNR vs meanburstratio:
  353. f, a = plt.subplots(figsize=figsize)
  354. titlestr = ('PSTH %s fit SNR vs meanburstratio %s %s %s' %
  355. (model, kind, st8, fmode2txt[fmode]))
  356. wintitle(titlestr, f)
  357. a.scatter(brs[normlis], snrs[normlis], clip_on=False,
  358. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  359. # plot example points, one at a time:
  360. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  361. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  362. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  363. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  364. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  365. a.scatter(brs[exmpli], snrs[exmpli], clip_on=False,
  366. marker=marker, c=c, s=sz, lw=lw)
  367. a.set_xlabel('Suppression BR')
  368. a.set_ylabel('Suppression SNR')
  369. a.set_xlim(xmin=0)
  370. a.set_ylim(ymin=0)
  371. a.set_yticks([0, 0.5, 1])
  372. a.spines['left'].set_position(('outward', 4))
  373. a.spines['bottom'].set_position(('outward', 4))
  374. # scatter plot resampled model fit rsq for nonburst vs all spikes, and nonrand vs all spikes:
  375. figsize = DEFAULTFIGURESIZE
  376. for kind in ['nat']:#MVIKINDS:
  377. for st8 in ['none']:#ALLST8S:
  378. # boolean pandas Series:
  379. rowis = (sampfitsr['kind'] == kind) & (sampfitsr['st8'] == st8)
  380. rows = sampfitsr[rowis]
  381. allrsqmeds, nbrsqmeds, nrrsqmeds = [], [], [] # median fit rsq values
  382. exmplis, exmplmseustrs, normlis = [], [], []
  383. keptmseui = 0 # manually init and increment instead of using enumerate()
  384. for i, row in rows.iterrows(): # for each mseu
  385. mseustr = row['mseu']
  386. allrsqs, nbrsqs, nrrsqs = row['rsqs'], row['nbrsqs'], row['nrrsqs']
  387. ## TODO: does this throw out too many units?:
  388. if np.isnan([allrsqs, nbrsqs, nrrsqs]).any():
  389. continue # skip this mseu
  390. # skip mseus that had no meaningful signal to fit in either condition:
  391. maxsnr = get_max_snr(mvirespr, mseustr, kind, st8)
  392. if maxsnr < SNRTHRESH:
  393. continue
  394. allrsqmed = np.median(allrsqs)
  395. nbrsqmed = np.median(nbrsqs)
  396. nrrsqmed = np.median(nrrsqs)
  397. #if rsqmed < RSQTHRESH: # skip mseus with really bad median fits
  398. # continue # skip
  399. if allrsqmed < 0 or nbrsqmed < 0 or nrrsqmed < 0: # catastrophic fit failure
  400. print('catastrophic fit failure:', mseustr)
  401. continue # skip
  402. allrsqmeds.append(allrsqmed)
  403. nbrsqmeds.append(nbrsqmed)
  404. nrrsqmeds.append(nrrsqmed)
  405. if mseustr in mvimseu2exmpli:
  406. exmplis.append(keptmseui)
  407. exmplmseustrs.append(mseustr)
  408. else:
  409. normlis.append(keptmseui)
  410. if nbrsqmed > 0.4 and allrsqmed < 0.4:
  411. print(mseustr, allrsqmed, nbrsqmed)
  412. keptmseui += 1 # manually increment
  413. allrsqmeds = np.asarray(allrsqmeds)
  414. nbrsqmeds = np.asarray(nbrsqmeds)
  415. nrrsqmeds = np.asarray(nrrsqmeds)
  416. # plot nonburst vs all rsq:
  417. f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
  418. titlestr = 'PSTH %s fit rsq nonburst vs all %s %s' % (model, kind, st8)
  419. wintitle(titlestr, f)
  420. linmax = np.vstack([allrsqmeds, nbrsqmeds]).max()
  421. xyline = [linmin, linmax], [linmin, linmax]
  422. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
  423. a.scatter(allrsqmeds[normlis], nbrsqmeds[normlis], clip_on=False,
  424. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  425. # plot example points, one at a time:
  426. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  427. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  428. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  429. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  430. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  431. a.scatter(allrsqmeds[exmpli], nbrsqmeds[exmpli], clip_on=False,
  432. marker=marker, c=c, s=sz, lw=lw)
  433. a.set_xlabel('All spikes $\mathregular{R^{2}}$')
  434. a.set_ylabel('Non-burst spikes $\mathregular{R^{2}}$')
  435. a.set_xlim(linmin, 1)
  436. a.set_ylim(linmin, 1)
  437. a.set_xticks([0, 0.5, 1])
  438. a.set_yticks([0, 0.5, 1])
  439. a.set_aspect('equal')
  440. a.spines['left'].set_position(('outward', 4))
  441. a.spines['bottom'].set_position(('outward', 4))
  442. # plot nr vs all rsq:
  443. f, a = plt.subplots(figsize=(figsize[0]*1.05, figsize[1])) # extra space for ylabels
  444. titlestr = 'PSTH %s fit rsq nr vs all %s %s' % (model, kind, st8)
  445. wintitle(titlestr, f)
  446. linmax = np.vstack([allrsqmeds, nrrsqmeds]).max()
  447. xyline = [linmin, linmax], [linmin, linmax]
  448. a.plot(xyline[0], xyline[1], '--', color='gray', zorder=-np.inf)
  449. a.scatter(allrsqmeds[normlis], nrrsqmeds[normlis], clip_on=False,
  450. marker='.', c='None', edgecolor=st82clr[st8], s=DEFSZ)
  451. # plot example points, one at a time:
  452. for exmpli, mseustr in zip(exmplis, exmplmseustrs):
  453. marker = exmpli2mrk[mvimseu2exmpli[mseustr]]
  454. c = exmpli2clr[mvimseu2exmpli[mseustr]]
  455. sz = exmpli2sz[mvimseu2exmpli[mseustr]]
  456. lw = exmpli2lw[mvimseu2exmpli[mseustr]]
  457. a.scatter(allrsqmeds[exmpli], nrrsqmeds[exmpli], clip_on=False,
  458. marker=marker, c=c, s=sz, lw=lw)
  459. a.set_xlabel('All spikes $\mathregular{R^{2}}$')
  460. a.set_ylabel('Rand. rem. spikes $\mathregular{R^{2}}$')
  461. a.set_xlim(linmin, 1)
  462. a.set_ylim(linmin, 1)
  463. a.set_xticks([0, 0.5, 1])
  464. a.set_yticks([0, 0.5, 1])
  465. a.set_aspect('equal')
  466. a.spines['left'].set_position(('outward', 4))
  467. a.spines['bottom'].set_position(('outward', 4))