panels.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. """Create panels for figures. Helper module."""
  2. import numpy as np
  3. from .util import (stackmean, get_idx_match, contrast_log_scale, replace_imshow_ticklabels,
  4. ori_con_2d, circ_gaussian, hyper_ratio, wrapped_gaussian, von_mises)
  5. import matplotlib.pyplot as plt
  6. from cycler import cycler
  7. import datajoint as dj
  8. from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
  9. from statsmodels.distributions.empirical_distribution import ECDF
  10. from scipy.stats import sem
  def kernels(kerns, dtopt, dtrange, ax=None, key=None, prop='grat_orientation'):
      """Show averaged temporal kernels as an image, one row per value of `prop`.

      Parameters
      ----------
      kerns : DataFrame
          Needs a 't_kernel' column and an index level named `prop`; rows are
          grouped by `prop` and reduced with `stackmean`.
      dtopt : float
          Time lag marked with a vertical line (located in `dtrange` via np.isclose).
      dtrange : array-like
          Time lags of the kernel columns; only the first 30 are drawn.
      ax : matplotlib Axes, optional
          Target axes; a new figure is created when omitted.
      key : optional
          Accepted for call compatibility; not used.
      prop : str
          Property to group by. For 'grat_orientation' the first row is repeated
          at the end and labelled 180.

      Returns
      -------
      The colorbar attached to the kernel image.
      """
      if ax is None:
          _, ax = plt.subplots()
      n_shown = 30  # number of leading time lags drawn
      rows = np.stack(kerns['t_kernel'].groupby(prop).apply(stackmean))
      row_vals = np.unique(kerns.index.get_level_values(prop))
      if prop == 'grat_orientation':
          # close the orientation circle: duplicate the first row as 180 deg
          rows = np.concatenate((rows, rows[:1]))
          row_vals = np.concatenate((row_vals, np.array([180]))).astype(int)
      # x1000 to match the 'Firing (Hz)' colorbar label (presumably per-ms -> Hz)
      image = ax.imshow(rows[:, :n_shown] * 1000, origin='lower', aspect='auto')
      opt_cols = np.where(np.isclose(dtrange, dtopt))[0]
      ax.vlines(opt_cols, -0.5, rows.shape[0] - 0.5)
      # e.g. 'grat_orientation' -> 'Orientation' (assumes a 5-character prefix)
      ax.set_ylabel(prop[5].capitalize() + prop[6:])
      ax.set_xlabel(r'$\delta t$ (s)' if plt.rcParams['text.usetex'] else r'dt (s)')
      replace_imshow_ticklabels(ax, dtrange[:n_shown], row_vals)
      colorbar = plt.colorbar(image, ax=ax)
      colorbar.set_label('Firing (Hz)', rotation=270, labelpad=7)
      return colorbar
  def opt_dt_distr_exc_inh(exc_t_opts, inh_t_opts, t_opts, ax=None, excitcol='darkorange',
                           inhibcol='teal'):
      """Histogram optimal time lags of excitatory and inhibitory units, plus an ECDF.

      Both histograms share the bin edges chosen for `exc_t_opts` (Doane's rule).
      The empirical CDF of `t_opts` is drawn in black on a twin y-axis.
      Returns None.
      """
      if ax is None:
          _, ax = plt.subplots()
      edges = ax.hist(exc_t_opts, color=excitcol, bins='doane')[1]
      ax.hist(inh_t_opts, color=inhibcol, bins=edges)
      _, right = ax.get_xlim()
      ax.set_xlim([0, right])  # start the x-axis at zero lag

      cdf_ax = ax.twinx()
      cdf = ECDF(t_opts)
      cdf_ax.plot(cdf.x, cdf.y, color='k')
      cdf_ax.set_ylim([0, 1.01])

      usetex = plt.rcParams['text.usetex']
      ax.set_xlabel(r'Optimal $\delta t$ (s)' if usetex else 'Optimal dt (s)')
      ax.set_ylabel('N')
      cdf_ax.set_ylabel('Proportion')
      cdf_ax.set_yticks([0, 0.5, 1])
      right_spine = cdf_ax.spines['right']
      right_spine.set_bounds(0, 1)
      right_spine.set_visible(True)
      ax.spines['bottom'].set_bounds(*ax.get_xlim())
  def oricon_matrix(ax, data, cbar=False):
      """Draw an orientation x contrast response matrix as an image.

      Parameters
      ----------
      ax : matplotlib Axes
          Axes to draw into.
      data : DataFrame
          Column 'p_spike_stim' with 'grat_orientation' and 'grat_contrast'
          available to `pivot_table` (index levels or columns).
      cbar : bool
          If True, attach a colorbar labelled 'Firing rate (Hz)'.

      Returns
      -------
      (ax, bar), where bar is the colorbar or None.
      """
      resp = data.pivot_table('p_spike_stim', 'grat_orientation', 'grat_contrast')
      # wrap the orientation axis: repeat the first row so the image ends at 180 deg
      m = np.concatenate((resp.values, resp.values[:1])) * 1000  # convert to Hz
      im = ax.imshow(m, origin='lower')
      bar = None
      if cbar:
          bar = plt.colorbar(im, ax=ax)
          bar.set_label('Firing rate (Hz)')
      ax.set_xlabel('Contrast')
      # raw string: '\c' is an invalid escape in a normal literal
      ax.set_ylabel(r'Orientation ($^{\circ}$)')
      # Tick labels come from the pivoted table itself, so they always match the
      # image rows/columns (rather than assuming 12 orientations 15 deg apart).
      contrasts = np.asarray(resp.columns, dtype=float).round(2)
      orientations = np.rint(
          np.append(np.asarray(resp.index, dtype=float), 180)).astype(int)
      replace_imshow_ticklabels(ax, contrasts, orientations)
      return ax, bar
  def wave_shapes(wforms, ax=None, excitcol='darkorange', inhibcol='teal', annotate=False):
      """Overlay aligned, normalised spike waveforms of excitatory and inhibitory units.

      Parameters
      ----------
      wforms : DataFrame
          Needs columns 'u_wave' (per-channel waveforms), 'u_chans', 'u_maxchan'
          and 'unit_type'. Rows with unit_type 'excit' or 'inhib' are drawn. The
          intermediate columns 'wave', 'normwave', 'ip_wave' and 'al_wave' are
          added to `wforms` in place.
      ax : matplotlib Axes, optional
          Target axes; a new figure is created when omitted.
      excitcol, inhibcol : color
          Colours for 'excit' and 'inhib' units respectively.
      annotate : bool
          If True (and inhibitory units exist), mark the trough-to-peak interval
          and the span where the mean inhibitory waveform is >= half its maximum.
      """
      if ax is None:
          _, ax = plt.subplots()
      # waveform recorded on the unit's max channel
      wforms['wave'] = wforms.apply(
          lambda x: x['u_wave'][np.where(x['u_chans'] == x['u_maxchan'])[0][0]], axis=1)
      # divide by |min|, which puts a negative trough at -1
      wforms['normwave'] = wforms.wave.apply(lambda x: x / np.abs(np.min(x)))
      # Resample to 120 points. NOTE(review): the grid ends at len(x), one step past
      # the last sample, so np.interp repeats the final value there -- confirm intent.
      wforms['ip_wave'] = wforms.normwave.apply(
          lambda x: np.interp(np.linspace(0, len(x), 120), np.arange(len(x)), x))
      # circularly shift each waveform so its trough sits at index 29
      wforms['al_wave'] = wforms.ip_wave.apply(lambda x: np.roll(x, 29 - np.argmin(x)))
      wtime = np.linspace(0, 120 * 0.017, 120)  # hard-coded 0.017 ms per resampled point
      # Colour by unit-type name, not by position in groupby order: zipping colours
      # with sorted groups mis-colours units when a type is absent or another sorts first.
      colors = {'excit': excitcol, 'inhib': inhibcol}
      means = {}
      for unit_type, color in colors.items():
          waves = wforms.loc[wforms['unit_type'] == unit_type, 'al_wave']
          if len(waves) == 0:
              continue
          stacked = np.stack(waves.values)
          ax.plot(wtime, stacked.T, color=color, alpha=0.1)
          means[unit_type] = stacked.mean(axis=0)
      # means are drawn after all individual traces so they sit on top
      for unit_type, mean in means.items():
          ax.plot(wtime, mean, color=colors[unit_type])
      ax.set_xlabel('Time (ms)')
      ax.set_ylabel('Voltage')
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      if annotate and 'inhib' in means:
          inmean = means['inhib']
          line_kw = dict(color='k', zorder=3, linewidth=1)
          trough_t = wtime[np.argmin(inmean)]
          peak_t = wtime[np.argmax(inmean)]
          top = 1.2 * np.max(inmean)
          # trough-to-peak bracket above the waveform
          ax.vlines(trough_t, np.min(inmean), top, linestyles='dashed', **line_kw)
          ax.vlines(peak_t, np.max(inmean), top, linestyles='dashed', **line_kw)
          ax.hlines(top, trough_t, peak_t, **line_kw)
          # bracket at y=-0.5 spanning samples at or above half the peak
          half = 0.5 * np.max(inmean)
          above = np.where(inmean >= half)[0]
          ax.vlines(wtime[above[0]], half, -0.5, linestyles='dashed', **line_kw)
          ax.vlines(wtime[above[-1]], half, -0.5, linestyles='dashed', **line_kw)
          ax.hlines(-0.5, wtime[above[0]], wtime[above[-1]], **line_kw)
  def wave_shapes_lgn(wforms, ax=None, color='k', annotate=False):
      """Overlay aligned, normalised spike waveforms of LGN units in one colour.

      `wforms` needs columns 'u_wave', 'u_chans' and 'u_maxchan'; the columns
      'wave', 'normwave', 'ip_wave' and 'al_wave' are added to it in place.
      `annotate` is accepted for signature parity with `wave_shapes` but unused.
      """
      if ax is None:
          _, ax = plt.subplots()
      n_pts = 120

      def max_channel_wave(row):
          # waveform on the channel listed as the unit's max channel
          chan = np.where(row['u_chans'] == row['u_maxchan'])[0][0]
          return row['u_wave'][chan]

      wforms['wave'] = wforms.apply(max_channel_wave, axis=1)
      wforms['normwave'] = wforms['wave'].apply(lambda w: w / np.abs(np.min(w)))
      wforms['ip_wave'] = wforms['normwave'].apply(
          lambda w: np.interp(np.linspace(0, len(w), n_pts), np.arange(len(w)), w))
      # shift so each trough lands at index 29
      wforms['al_wave'] = wforms['ip_wave'].apply(lambda w: np.roll(w, 29 - np.argmin(w)))
      wtime = np.linspace(0, n_pts * 0.017, n_pts)
      ax.plot(wtime, np.stack(wforms['al_wave']).T, color=color, alpha=0.1)
      ax.plot(wtime, np.mean(wforms['al_wave'], axis=0), color=color)
      ax.set_xlabel('Time (ms)')
      ax.set_ylabel('Voltage')
      box = ax.dataLim
      ax.spines['bottom'].set_bounds(box.xmin, box.xmax)
      ax.spines['left'].set_bounds(box.ymin, box.ymax)
  def wave_props(non_rc_props, rc_props, ax=None, excitcol='darkorange', inhibcol='teal', markeredgewidth=0.2,
                 **plotkwargs):
      """Scatter peak FWHM against trough-to-peak time, coloured by unit type.

      `non_rc_props` are drawn as small faint dots, `rc_props` as larger, more
      opaque dots with edge width `markeredgewidth`. Both need the columns
      'fwhm_peak', 't2p_time' and 'unit_type'; only 'excit' and 'inhib' rows are
      drawn. Extra keyword arguments are forwarded to every `ax.plot` call.
      """
      if ax is None:
          _, ax = plt.subplots()
      # Colour by unit-type name rather than by position in groupby order, which
      # mis-colours units whenever one type is absent or another type sorts first.
      colors = {'excit': excitcol, 'inhib': inhibcol}
      layers = (
          (non_rc_props, {'alpha': 0.25, 'markersize': 4}),
          (rc_props, {'markeredgewidth': markeredgewidth, 'alpha': 0.5, 'markersize': 7}),
      )
      for props, style in layers:
          for unit_type, color in colors.items():
              units = props[props['unit_type'] == unit_type]
              if units.empty:
                  continue
              ax.plot(units['fwhm_peak'], units['t2p_time'], '.', color=color,
                      **style, **plotkwargs)
      ax.set_xlabel('Fwhm_peak (ms)')
      ax.set_ylabel('T2p_time (ms)')
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
  def wave_props_lgn(non_rc_props, rc_props, ax=None, color='k', markeredgewidth=0.2,
                     **plotkwargs):
      """Scatter peak FWHM against trough-to-peak time for LGN units in one colour.

      `non_rc_props` are drawn as small faint dots, `rc_props` as larger, more
      opaque dots with edge width `markeredgewidth`; both need 'fwhm_peak' and
      't2p_time' columns. Extra keyword arguments go to each `ax.plot` call.
      """
      if ax is None:
          _, ax = plt.subplots()
      layers = [
          (non_rc_props, {'alpha': 0.25, 'markersize': 4}),
          (rc_props, {'markeredgewidth': markeredgewidth, 'alpha': 0.5, 'markersize': 7}),
      ]
      for props, style in layers:
          ax.plot(props['fwhm_peak'], props['t2p_time'], '.', color=color,
                  **style, **plotkwargs)
      ax.set_xlabel('Fwhm_peak (ms)')
      ax.set_ylabel('T2p_time (ms)')
      box = ax.dataLim
      ax.spines['bottom'].set_bounds(box.xmin, box.xmax)
      ax.spines['left'].set_bounds(box.ymin, box.ymax)
  def daterrsvd(datdf, svddf, dtrange, axs=None, cbar=False, cbaraxs=None, key=None):
      """Show data, SVD reconstruction and SVD error as orientation x contrast images.

      Parameters
      ----------
      datdf : DataFrame
          'p_spike_stim' with 'grat_orientation' / 'grat_contrast' (the contrast
          values are read from an index level of that name).
      svddf : DataFrame
          Rows with 'svd_s_kernel', 'err_s_kernel' and 'dt_maxresponse'. For each
          row the kernel entry at `get_idx_match(dt_maxresponse, dtrange)` is
          shown. Columns 'p_svd' and 'p_err' are added to `svddf` in place.
      dtrange : array-like
          Time lags passed to `get_idx_match`.
      axs : sequence of 3 Axes, optional
          Created in a new 1x3 figure when omitted.
      cbar : bool
          Add colorbars for the SVD and error images.
      cbaraxs : pair of Axes, optional
          Axes to draw those colorbars into; otherwise space is taken from the
          image axes.
      key : optional
          Figure title, used only when a new figure is created here.

      Returns
      -------
      (cb1, cb2) colorbars when `cbar` is True, otherwise None.
      """
      if axs is None:
          fig, axs = plt.subplots(1, 3)
          # only title the figure when the caller supplied one (`key` was an
          # undefined name here before, raising NameError)
          if key is not None:
              fig.suptitle(key)
      else:
          assert len(axs) == 3
      ax1, ax2, ax3 = axs
      # prepare data
      dat = datdf.pivot_table('p_spike_stim', 'grat_orientation', 'grat_contrast').values
      # prepare svd: pick each kernel's value at its max-response lag
      svddf['p_svd'] = svddf.apply(
          lambda x: x.svd_s_kernel[get_idx_match(x.dt_maxresponse, dtrange)], axis=1)
      svddf['p_err'] = svddf.apply(
          lambda x: x.err_s_kernel[get_idx_match(x.dt_maxresponse, dtrange)], axis=1)
      svd = svddf.pivot_table('p_svd', 'grat_orientation', 'grat_contrast').values
      svderr = svddf.pivot_table('p_err', 'grat_orientation', 'grat_contrast').values
      # wrap orientation domain: repeat the first row at the end
      dat = np.vstack((dat, dat[0, :]))
      svd = np.vstack((svd, svd[0, :]))
      svderr = np.vstack((svderr, svderr[0, :]))
      # prepare ticklabels
      oris = np.linspace(0, 180, 13).astype(int)
      cons = np.unique(datdf.index.get_level_values('grat_contrast'))
      # plot
      ims = []
      for ax, mat in zip((ax1, ax2, ax3), (dat, svd, svderr)):
          ims.append(ax.imshow(mat * 1000, origin='lower'))  # go from probability to Hz
          replace_imshow_ticklabels(ax, cons, oris)
      # adjust labels
      ax2.set_yticklabels([])
      ax3.set_yticklabels([])
      ax1.set_ylabel('Orientation')
      for ax in (ax1, ax2, ax3):
          ax.set_xlabel('Contrast')
      if cbar:
          if cbaraxs is None:
              # use the unpacked axes so this also works when `axs` was created here
              cb1 = plt.colorbar(ims[1], ax=ax2)
              cb2 = plt.colorbar(ims[2], ax=ax3)
          else:
              cb1 = plt.colorbar(ims[1], cax=cbaraxs[0])
              cb2 = plt.colorbar(ims[2], cax=cbaraxs[1])
          return cb1, cb2
  def svd_vs_spatcorr(df, ax=None, exc_color='darkorange',
                      inh_color='teal', draw_c_inv_borders=True):
      """Scatter g_z_opt against 1 - power_opt for excitatory and inhibitory units.

      Units with g_z_opt > 1.96 and power_opt <= 0.95 are drawn at alpha 0.5;
      all others opaque. With `draw_c_inv_borders` (default True) the borders of
      that region (x = 0.05, y = 1.96) are drawn dashed.
      """
      if ax is None:
          _, ax = plt.subplots()
      inv_df = df.query('g_z_opt<=1.96 or power_opt>0.95')
      dep_df = df.query('g_z_opt>1.96 and power_opt<=0.95')
      # opaque group first, then translucent; excit before inhib within each
      for subset, alpha_kw in ((inv_df, {}), (dep_df, {'alpha': 0.5})):
          for unit_type, color in (('excit', exc_color), ('inhib', inh_color)):
              units = subset.query(f'unit_type=="{unit_type}"')
              ax.scatter(1 - units['power_opt'], units['g_z_opt'], color=color, **alpha_kw)
      if draw_c_inv_borders:
          # the region begins at x = 1 - 0.95 = 0.05, so the horizontal border starts
          # there as well (same as svd_vs_spatcorr_lgn)
          ax.hlines(1.96, 0.05, 0.95, linestyles='dashed', alpha=0.5)
          ax.vlines(0.05, 1.96, np.max(df['g_z_opt']), linestyles='dashed',
                    alpha=0.5)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.set_ylabel('g_z_opt')
      ax.set_xlabel('svd_power')
  def svd_vs_spatcorr_lgn(df, ax=None, color='k', draw_c_inv_borders=True):
      """Scatter g_z_opt against 1 - power_opt for LGN units in a single colour.

      Units with g_z_opt > 1.96 and power_opt <= 0.95 are drawn at alpha 0.5;
      all others opaque. With `draw_c_inv_borders` the borders of that region
      (x = 0.05, y = 1.96) are drawn dashed.
      """
      if ax is None:
          _, ax = plt.subplots()
      inv_df = df.query('g_z_opt<=1.96 or power_opt>0.95')
      dep_df = df.query('g_z_opt>1.96 and power_opt<=0.95')
      ax.scatter(1 - inv_df['power_opt'], inv_df['g_z_opt'], color=color)
      ax.scatter(1 - dep_df['power_opt'], dep_df['g_z_opt'], color=color, alpha=0.5)
      if draw_c_inv_borders:
          ax.hlines(1.96, 0.05, 0.95, linestyles='dashed', alpha=0.5)
          ax.vlines(0.05, 1.96, np.max(df['g_z_opt']), linestyles='dashed',
                    alpha=0.5)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.set_ylabel('g_z_opt')  # was the typo 'g_z_opz'
      ax.set_xlabel('svd_power')
  def contrast_amp_distr(exc_data, inh_data, ax=None, force_align_bins=True,
                         excit_col='darkorange', inhib_col='teal'):
      """Histogram response amplitudes (rmax - r0, times 1000) of two unit groups.

      Both groups share one set of equal-width bin edges derived from `exc_data`
      with the Freedman-Diaconis rule. With `force_align_bins` the edges are
      shifted so that one edge lies exactly at 0. The edges are then extended by
      whole bins where needed so no finite amplitude of either group falls outside
      them. Bins entirely below 0 are drawn at half opacity, and a dashed line
      marks 0.

      Parameters
      ----------
      exc_data, inh_data : DataFrame
          Need numeric 'rmax' and 'r0' columns.
      ax : matplotlib Axes, optional
          Target axes; a new figure is created when omitted.
      force_align_bins : bool
          Shift the bin edges so that 0 is an edge.
      excit_col, inhib_col : color
          Colours for the two groups.
      """
      if ax is None:
          _, ax = plt.subplots()
      amps = [(data.rmax.values - data.r0.values) * 1000 for data in (exc_data, inh_data)]
      edges = np.histogram_bin_edges(amps[0], bins='fd')
      if force_align_bins:
          edges = edges - edges[np.argmin(np.abs(edges))]
      # Shifting the edges, and reusing the excitatory edges for the inhibitory
      # group, can leave values outside the outermost edges, which np.histogram
      # silently drops. Pad with extra bins of the same width instead.
      finite = np.concatenate(amps)
      finite = finite[np.isfinite(finite)]
      if finite.size:
          width = edges[1] - edges[0]
          n_lo = int(np.ceil(max(edges[0] - finite.min(), 0) / width))
          n_hi = int(np.ceil(max(finite.max() - edges[-1], 0) / width))
          edges = np.concatenate((edges[0] - width * np.arange(n_lo, 0, -1),
                                  edges,
                                  edges[-1] + width * np.arange(1, n_hi + 1)))
      nbins = len(edges) - 1
      # Index of the last edge <= 0: bins before it lie entirely below zero.
      # Clamp at 0 so an all-positive range does not produce split == -1.
      split = max(int(np.sum(edges <= 0)) - 1, 0)
      counts = []
      for amp, color in zip(amps, (excit_col, inhib_col)):
          cnt, _ = np.histogram(amp, bins=edges)
          if split > 0:
              ax.hist(edges[:split], edges[:split + 1], weights=cnt[:split],
                      alpha=0.5, color=color)
          if split < nbins:
              ax.hist(edges[split:-1], edges[split:], weights=cnt[split:],
                      alpha=1, color=color)
          counts.append(cnt)
      ax.vlines(0, 0, np.max(counts), linestyles='dashed', alpha=0.5)
      ax.set_xlabel('Response amplitude (Hz)')
      ax.set_ylabel('N')
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
  def marginal_tuning(datdf, pars, axs=None, tuning_model='wrap_gauss'):
      """Plot an orientation x contrast response matrix with data and fitted marginals.

      `ax2` shows the data matrix, with the 0 deg row repeated as 180 deg.
      `ax1` (above) shows the contrast marginal and `ax3` (right) the orientation
      marginal. Dots are data means; lines are means of the fitted surface
      `ori_con_2d(grid, *pars, orifunc=tuning_model)` on a 100x100 grid.
      Marginals are multiplied by 1000; the image is not.

      Parameters
      ----------
      datdf : DataFrame
          'p_spike_stim' with 'grat_orientation' and 'grat_contrast' (contrast
          values are read from an index level of that name).
      pars : sequence
          Fit parameters forwarded positionally to `ori_con_2d`.
      axs : sequence of 3 Axes, optional
          (contrast marginal, image, orientation marginal); created when omitted.
      tuning_model : str
          Forwarded to `ori_con_2d` as `orifunc`.
      """
      import pandas as pd  # DataFrame.append was removed in pandas 2.0; use pd.concat

      if axs is None:
          gridspec_kw = {'width_ratios': [2, 1], 'height_ratios': [1, 3]}
          fig, [[ax1, delax], [ax2, ax3]] = plt.subplots(2, 2, gridspec_kw=gridspec_kw,
                                                         tight_layout=True)
          fig.delaxes(delax)
      else:
          assert len(axs) == 3
          ax1, ax2, ax3 = axs
      # data matrix, with the 0 deg row appended again as 180 deg
      datmat = datdf.pivot_table('p_spike_stim', 'grat_orientation', 'grat_contrast')
      firstrow = datmat.loc[:0, :].rename(index={0: 180.0})
      datmat = pd.concat([datmat, firstrow])
      ax2.imshow(datmat.values, origin='lower', aspect='auto')
      # smooth fit on a nfitpoints x nfitpoints orientation/contrast grid
      nfitpoints = 100
      smoothoris = np.linspace(0, 180, nfitpoints)
      smoothcons = contrast_log_scale(10, nfitpoints)
      smoothstim = np.meshgrid(smoothoris, smoothcons)
      fitmat = ori_con_2d(smoothstim, *pars, orifunc=tuning_model)
      # NOTE(review): np.meshgrid returns arrays of shape (len(smoothcons),
      # len(smoothoris)); whether axis 0 of `fitmat` is contrast depends on
      # ori_con_2d -- confirm that mean(axis=0) below is the contrast marginal.
      # contrast marginal data points
      cons = np.unique(datdf.index.get_level_values('grat_contrast'))
      # NOTE(review): datmat includes the duplicated 180 deg row, so the 0 deg
      # responses are weighted twice in this average -- confirm intent.
      ax1.plot(np.arange(len(cons)), datmat.mean(axis=0) * 1000, '.', color='k')
      # contrast marginal of the fit, mapped onto the data's contrast indices
      con_idcs = np.linspace(0, len(cons) - 1, nfitpoints)
      ax1.plot(con_idcs, fitmat.mean(axis=0) * 1000, color='k')
      # orientation marginal data points (includes the wrapped 180 deg row)
      oris = np.unique(datmat.index.get_level_values('grat_orientation'))
      ax3.plot(datmat.mean(axis=1) * 1000, np.arange(len(oris)), '.', color='k')
      # orientation marginal of the fit
      oriidcs = np.linspace(0, len(oris) - 1, nfitpoints)
      ax3.plot(fitmat.mean(axis=1) * 1000, oriidcs, color='k')
      # fix labels
      ax1.set_xticklabels([])
      ax1.set_ylabel('Response')
      ax3.set_yticklabels([])
      ax3.set_xlabel('Response')
      replace_imshow_ticklabels(ax2, cons, oris.astype(int))
      # raw string: '\c' is an invalid escape in a normal literal
      ax2.set_ylabel(r'Orientation ($^{\circ}$)')
      ax2.set_xlabel('Contrast')
  def goodness_of_fit(v1_r2s, ax=None, excit_col='darkorange', inhib_col='teal'):
      """Histogram fit r2 of excitatory and inhibitory units, marking r2 = 0.4.

      Bin edges come from the excitatory r2 values (Freedman-Diaconis rule) and are
      extended by whole bins, if needed, to also cover the finite inhibitory values.
      Bins whose left edge is <= 0.4 are drawn at half opacity.

      Parameters
      ----------
      v1_r2s : DataFrame
          Columns 'unit_type' ('excit' / 'inhib') and numeric 'r2'.
      ax : matplotlib Axes, optional
          Target axes; a new figure is created when omitted.
      excit_col, inhib_col : color
          Colours for the two unit types.
      """
      if ax is None:
          _, ax = plt.subplots()
      exc_r2 = v1_r2s.query('unit_type=="excit"')['r2'].values
      inh_r2 = v1_r2s.query('unit_type=="inhib"')['r2'].values
      edges = np.histogram_bin_edges(exc_r2, bins='fd')
      # inhibitory values outside the excitatory edges would be silently dropped
      finite = np.concatenate((exc_r2, inh_r2))
      finite = finite[np.isfinite(finite)]
      if finite.size:
          width = edges[1] - edges[0]
          n_lo = int(np.ceil(max(edges[0] - finite.min(), 0) / width))
          n_hi = int(np.ceil(max(finite.max() - edges[-1], 0) / width))
          edges = np.concatenate((edges[0] - width * np.arange(n_lo, 0, -1),
                                  edges,
                                  edges[-1] + width * np.arange(1, n_hi + 1)))
      nbins = len(edges) - 1
      # Number of faded bins. Clip to nbins: when every edge is <= 0.4 the raw count
      # is nbins + 1 and the x / weights slices below would differ in length.
      split = min(int(np.sum(edges <= 0.4)), nbins)
      counts = []
      for r2, color in ((exc_r2, excit_col), (inh_r2, inhib_col)):
          cnt, _ = np.histogram(r2, bins=edges)
          if split > 0:
              ax.hist(edges[:split], edges[:split + 1], weights=cnt[:split],
                      alpha=0.5, color=color)
          if split < nbins:
              ax.hist(edges[split:-1], edges[split:], weights=cnt[split:],
                      alpha=1, color=color)
          counts.append(cnt)
      ax.vlines(0.4, 0, np.max(counts), linestyles='dashed', alpha=0.5)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      ax.set_xlabel('Goodness of fit (r²)')
      ax.set_ylabel('N')
  def goodness_of_fit_lgn(v1_r2s, ax=None, color='k'):
      """Histogram the 'r2' column of `v1_r2s`, marking r2 = 0.4.

      Bins follow the Freedman-Diaconis rule; bins whose left edge is <= 0.4 are
      drawn at half opacity.
      """
      if ax is None:
          _, ax = plt.subplots()
      counts, edges = np.histogram(v1_r2s['r2'].values, bins='fd')
      nbins = len(counts)
      # Number of faded bins. Clip to nbins: when every edge is <= 0.4 the raw count
      # is nbins + 1 and the x / weights slices below would differ in length.
      split = min(int(np.sum(edges <= 0.4)), nbins)
      if split > 0:
          ax.hist(edges[:split], edges[:split + 1], weights=counts[:split],
                  alpha=0.5, color=color)
      if split < nbins:
          ax.hist(edges[split:-1], edges[split:], weights=counts[split:],
                  alpha=1, color=color)
      ax.vlines(0.4, 0, np.max(counts), linestyles='dashed', alpha=0.5)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      ax.set_xlabel('Goodness of fit (r²)')
      ax.set_ylabel('N')
  def population_tuning(tuning_widths, ax=None, color='darkorange'):
      """Overlay one wrapped-Gaussian tuning curve per width in `tuning_widths`.

      Each curve is `wrapped_gaussian(oris, sigma=width)` on 100 orientations in
      [0, 180] deg, drawn against x values relabelled to [-90, 90] deg.
      """
      if ax is None:
          _, ax = plt.subplots()
      oris = np.linspace(0, 180, 100)
      labeloris = np.linspace(-90, 90, 100)  # same spacing as oris, shifted by -90
      for tuning_width in tuning_widths:
          ax.plot(labeloris, wrapped_gaussian(oris, sigma=tuning_width), color=color, alpha=0.2)
      # raw string: '\c' is an invalid escape in a normal literal (SyntaxWarning on 3.12+)
      ax.set_xlabel(r'Orientation ($^{\circ}$)')
      ax.set_ylabel('Response')
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
  def population_tuning_von_mises(tuning_widths, ax=None, color='darkorange'):
      """Overlay one von Mises tuning curve per concentration in `tuning_widths`.

      Each curve is `von_mises(oris, kappa=width)` on 100 orientations in
      [0, 180] deg, drawn against x values relabelled to [-90, 90] deg.
      """
      if ax is None:
          _, ax = plt.subplots()
      oris = np.linspace(0, 180, 100)
      labeloris = np.linspace(-90, 90, 100)  # same spacing as oris, shifted by -90
      for tuning_width in tuning_widths:
          ax.plot(labeloris, von_mises(oris, kappa=tuning_width), color=color, alpha=0.2)
      # raw string: '\c' is an invalid escape in a normal literal (SyntaxWarning on 3.12+)
      ax.set_xlabel(r'Orientation ($^{\circ}$)')
      ax.set_ylabel('Response')
      ax.spines['left'].set_bounds(ax.dataLim.ymin, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
  def population_contrast(pardf, ax=None, color='k', alpha=0.5):
      """Draw each unit's peak-normalised contrast response curve.

      For every row of `pardf` (columns 'c50', 'n', 'sup'), `hyper_ratio` is
      evaluated on 100 contrasts in [0, 1] with its 2nd and 5th arguments fixed
      at 1 and 0 (the slots `mean_population_contrast` fills with rmax and r0),
      then divided by its maximum.
      """
      if ax is None:
          _, ax = plt.subplots()
      contrasts = np.linspace(0, 1, 100)
      for _, unit_pars in pardf.iterrows():
          curve = hyper_ratio(contrasts, 1, unit_pars['c50'], unit_pars['n'], 0,
                              unit_pars['sup'])
          ax.plot(contrasts, curve / np.max(curve), color=color, alpha=alpha)
      ax.set_xlabel('Contrast')
      ax.set_ylabel('Response')
      box = ax.dataLim
      ax.spines['left'].set_bounds(box.ymin, box.ymax)
      ax.spines['bottom'].set_bounds(box.xmin, box.xmax)
  def osi_cumulative(osidf, ax=None, unit_type='excit', region='V1',
                     excit_col='darkorange', inhib_col='teal'):
      """Plot the empirical CDF of 'opt_osi' for one region and unit type.

      Rows are selected by `s_region == region` and `unit_type == unit_type`. The
      curve is coloured `excit_col` or `inhib_col`; any other `unit_type` raises
      KeyError.
      """
      if ax is None:
          _, ax = plt.subplots()
      # plot OSI distribution
      osis = osidf.query(f's_region=="{region}" and unit_type=="{unit_type}"').opt_osi
      colors = {'excit': excit_col, 'inhib': inhib_col}
      ecdf = ECDF(osis)
      ax.step(ecdf.x, ecdf.y, color=colors[unit_type])
      ax.set_xlabel('Orientation selectivity index')
      ax.set_ylabel('Fraction of units')
      # margins() returns (xmargin, ymargin); pad both x-limits with the x margin
      xmargin = ax.margins()[0]
      ax.set_xlim(0 - xmargin, 1 + xmargin)
      ax.spines['bottom'].set_bounds(0, 1)
      ax.spines['left'].set_bounds(0, 1)
  def osi_cumulative_lgn(osidf, ax=None, color='k'):
      """Plot the empirical CDF of the 'opt_osi' column of `osidf`."""
      if ax is None:
          _, ax = plt.subplots()
      # plot OSI distribution
      ecdf = ECDF(osidf.opt_osi)
      ax.step(ecdf.x, ecdf.y, color=color)
      ax.set_xlabel('Orientation selectivity index')
      ax.set_ylabel('Fraction of units')
      # margins() returns (xmargin, ymargin); pad both x-limits with the x margin
      xmargin = ax.margins()[0]
      ax.set_xlim(0 - xmargin, 1 + xmargin)
      ax.spines['bottom'].set_bounds(0, 1)
      ax.spines['left'].set_bounds(0, 1)
  def osi_hist(osidf, ax=None, unit_type='excit', region='V1', keys=None,
               excit_col='darkorange', inhib_col='teal'):
      """Draw a normalised step histogram of 'opt_osi' for one region and unit type.

      Rows are selected by `s_region == region` and `unit_type == unit_type`.
      `keys` is accepted for call compatibility but unused. Any `unit_type` other
      than 'excit' / 'inhib' raises KeyError.
      """
      if ax is None:
          _, ax = plt.subplots()
      # plot OSI distribution
      osis = osidf.query(f's_region=="{region}" and unit_type=="{unit_type}"').opt_osi
      colors = {'excit': excit_col, 'inhib': inhib_col}
      ax.hist(osis, bins='fd', color=colors[unit_type], density=True, histtype='step')
      ax.set_xlabel('Orientation selectivity index')
      # density=True normalises the histogram, so the y axis is a density, not a count
      ax.set_ylabel('Density')
  def osi_hist_lgn(osidf, ax=None, color='k'):
      """Draw a normalised step histogram of the 'opt_osi' column of `osidf`."""
      if ax is None:
          _, ax = plt.subplots()
      # plot OSI distribution
      ax.hist(osidf.opt_osi, bins='fd', color=color, density=True, histtype='step')
      ax.set_xlabel('Orientation selectivity index')
      # density=True normalises the histogram, so the y axis is a density, not a count
      ax.set_ylabel('Density')
  def contrast_cumulative(pardf, ax=None, color='darkorange'):
      """Plot the empirical CDF, across units, of the contrast at half-height.

      For each row of `pardf` ('c50', 'n', 'sup'), `hyper_ratio` is evaluated on
      100 contrasts in [0, 1] (2nd/5th arguments fixed at 1 and 0), normalised by
      its maximum. The first grid contrast where it exceeds 0.5 is recorded; a
      curve that never does raises IndexError.
      """
      if ax is None:
          _, ax = plt.subplots()
      cons = np.linspace(0, 1, 100)
      c50s = []
      for _, pars in pardf.iterrows():
          conres = hyper_ratio(cons, 1, pars['c50'], pars['n'], 0, pars['sup'])
          conres = conres / np.max(conres)
          # first grid point above half-height (grid resolution ~0.01 contrast)
          c50idx = np.where(conres > 0.5)[0][0]
          c50s.append(cons[c50idx])
      ecdf = ECDF(c50s)
      ax.step(ecdf.x, ecdf.y, color=color)
      ax.set_xlabel('Contrast at half-height')
      ax.set_ylabel('Fraction of units')
      # margins() returns (xmargin, ymargin); pad both x-limits with the x margin
      xmargin = ax.margins()[0]
      ax.set_xlim(0 - xmargin, 1 + xmargin)
      ax.spines['bottom'].set_bounds(0, 1)
      ax.spines['left'].set_bounds(0, 1)
  def mean_population_contrast(pardf, ax=None, color='k'):
      """Plot the population-mean contrast response curve with a +/- SEM band.

      Each row of `pardf` supplies 'rmax', 'c50', 'n', 'r0' and 'sup' for
      `hyper_ratio`, evaluated on 100 contrasts in [0, 1]. Curves are multiplied
      by 1000 before averaging; the y-axis is forced to start at 0.
      """
      if ax is None:
          _, ax = plt.subplots()
      cons = np.linspace(0, 1, 100)
      curves = np.stack([
          hyper_ratio(cons, p['rmax'], p['c50'], p['n'], p['r0'], p['sup'])
          for _, p in pardf.iterrows()
      ]) * 1000
      mean_curve = curves.mean(axis=0)
      band = sem(curves, axis=0)
      ax.plot(cons, mean_curve, color=color)
      ax.fill_between(cons, mean_curve + band, mean_curve - band, color=color,
                      alpha=0.5)
      ax.set_xlabel('Contrast')
      ax.set_ylabel('Response')
      ax.spines['left'].set_bounds(0, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      ax.set_ylim([0, ax.get_ylim()[1]])
  def mean_norm_population_tuning(tuning_widths, ax=None, color='darkorange'):
      """Plot the mean wrapped-Gaussian tuning curve across widths, with a +/- SEM band.

      Each curve is `wrapped_gaussian(oris, sigma=width)` on 100 orientations in
      [0, 180] deg, drawn against x values relabelled to [-90, 90] deg. The
      y-axis is forced to start at 0.
      """
      if ax is None:
          _, ax = plt.subplots()
      oris = np.linspace(0, 180, 100)
      labeloris = np.linspace(-90, 90, 100)  # same spacing as oris, shifted by -90
      resmat = np.stack([wrapped_gaussian(oris, sigma=tw) for tw in tuning_widths])
      mean_res = resmat.mean(axis=0)
      sem_res = sem(resmat, axis=0)
      ax.plot(labeloris, mean_res, color=color)
      ax.fill_between(labeloris, mean_res + sem_res, mean_res - sem_res,
                      color=color, alpha=0.5)
      # raw string: '\c' is an invalid escape in a normal literal (SyntaxWarning on 3.12+)
      ax.set_xlabel(r'Orientation ($^{\circ}$)')
      ax.set_ylabel('Response')
      ax.spines['left'].set_bounds(0, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      _, ymax = ax.get_ylim()
      ax.set_ylim([0, ymax])
  def mean_norm_population_tuning_von_mises(tuning_widths, ax=None, color='darkorange'):
      """Plot the mean von Mises tuning curve across concentrations, with a +/- SEM band.

      Each curve is `von_mises(oris, kappa=width)` on 100 orientations in
      [0, 180] deg, drawn against x values relabelled to [-90, 90] deg. The
      y-axis is forced to start at 0.
      """
      if ax is None:
          _, ax = plt.subplots()
      oris = np.linspace(0, 180, 100)
      labeloris = np.linspace(-90, 90, 100)  # same spacing as oris, shifted by -90
      resmat = np.stack([von_mises(oris, kappa=tw) for tw in tuning_widths])
      mean_res = resmat.mean(axis=0)
      sem_res = sem(resmat, axis=0)
      ax.plot(labeloris, mean_res, color=color)
      ax.fill_between(labeloris, mean_res + sem_res, mean_res - sem_res,
                      color=color, alpha=0.5)
      # raw string: '\c' is an invalid escape in a normal literal (SyntaxWarning on 3.12+)
      ax.set_xlabel(r'Orientation ($^{\circ}$)')
      ax.set_ylabel('Response')
      ax.spines['left'].set_bounds(0, ax.dataLim.ymax)
      ax.spines['bottom'].set_bounds(ax.dataLim.xmin, ax.dataLim.xmax)
      _, ymax = ax.get_ylim()
      ax.set_ylim([0, ymax])