fig4.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. # code for Figure 4 panels
  2. # import libs
  3. import matplotlib
  4. from matplotlib import pyplot as plt
  5. from matplotlib.ticker import PercentFormatter
  6. import numpy as np
  7. import pandas
  8. import seaborn as sns
  9. from scipy import stats
  10. from importlib import reload
  11. import pickle
  12. import spatint_utils
  13. import os
  14. # reload modules
  15. reload(spatint_utils)
  16. # define variables
  17. trn_red, lgn_green, _ = spatint_utils.get_colors()
  18. spatint_utils.plot_params()
  19. class Fig4:
  20. """Class to for plotting panels for Fig. 4"""
  21. def __init__(self):
  22. """Init class"""
  23. parentdir = os. path. dirname(os. getcwd())
  24. filename = parentdir + '/data/params_mouse.yaml'
  25. # read trn size tuning dataframe
  26. self.trn_sztun_df = pandas.read_pickle(
  27. filepath_or_buffer= parentdir + '/data/trn_sztun_df.pkl')
  28. # read dict for trn_sztun_ex
  29. with open(parentdir + '/data/trn_sztun_ex_dict.pkl', 'rb') as f:
  30. self.trn_sztun_ex_dict = pickle.load(f)
  31. # read trn retinotopy dataframe
  32. self.trn_retino_df = pandas.read_pickle(
  33. filepath_or_buffer=parentdir + '/data/trn_retino_df.pkl')
  34. # get data for trn retinotopy example series
  35. self.trn_retino_ex = self.trn_retino_df[(self.trn_retino_df.m == 'BL6_2018_0003') &
  36. (self.trn_retino_df.s == 7)]
  37. # read trn/lgn rf area dataframe
  38. self.rf_area_df = pandas.read_pickle(
  39. filepath_or_buffer=parentdir + '/data/rf_area_df.pkl')
  40. # read lgn size tuning dataframe
  41. self.lgn_sztun_df = pandas.read_pickle(
  42. filepath_or_buffer=parentdir + '/data/lgn_sztun_df.pkl')
  43. def exrfs(self):
  44. """Plot example receptive fields (Fig. 4d,h)
  45. Returns
  46. -------
  47. axs: list
  48. list with axes for example rfs
  49. """
  50. # define keys for example rfs
  51. exrfs = [
  52. # trn example rfs
  53. {'m': 'BL6_2018_0003', 's': 7, 'e': 1, 'u': 25},
  54. {'m': 'BL6_2018_0003', 's': 4, 'e': 1, 'u': 9},
  55. {'m': 'PVCre_2018_0009', 's': 4, 'e': 10, 'u': 10},
  56. {'m': 'BL6_2018_0003', 's': 2, 'e': 1, 'u': 17},
  57. {'m': 'BL6_2018_0003', 's': 3, 'e': 1, 'u': 12},
  58. # dlGN example rfs
  59. {'m': 'Ntsr1Cre_2019_0003', 's': 4, 'e': 1, 'u': 7},
  60. {'m': 'Ntsr1Cre_2018_0003', 's': 2, 'e': 1, 'u': 23},
  61. {'m': 'Ntsr1Cre_2018_0003', 's': 2, 'e': 1, 'u': 29}]
  62. # init figure
  63. f, axs = plt.subplots(1, 8, figsize=(7, 1.5))
  64. for i, (ax, exrf) in enumerate(zip(axs, exrfs)):
  65. # get target row in df
  66. ex_row = self.rf_area_df[(self.rf_area_df.m == exrf['m']) &
  67. (self.rf_area_df.s == exrf['s']) &
  68. (self.rf_area_df.e == exrf['e']) &
  69. (self.rf_area_df.u == exrf['u'])]
  70. if i == len(exrfs) - 1:
  71. scalebar = {'dist': 5, 'width': 1, 'length': 20, 'col': 'w'}
  72. else:
  73. scalebar = {'width': None, 'dist': None, 'length': None, 'col': None}
  74. # plot rfs
  75. spatint_utils.plot_rf(means=ex_row.mean_fr.values[0],
  76. stim_param_values=ex_row.ti_axes.values[0],
  77. grat_width=ex_row.grat_width,
  78. grat_height=ex_row.grat_height,
  79. ax=ax,
  80. scalebar=scalebar)
  81. # add labels
  82. if i == 0:
  83. ax.set_ylabel('visTRN')
  84. elif i == 5:
  85. ax.set_ylabel('dLGN')
  86. f = plt.gcf()
  87. f.tight_layout()
  88. # describe two example rfs
  89. self.desc_exrfs(exrfs)
  90. return axs
  91. def desc_exrfs(self, exrfs):
  92. """Describe two example RFs for Fig 4d
  93. Parameters
  94. -----
  95. exrfs: list
  96. example keys
  97. """
  98. labels = ['small', 'large']
  99. for i, label in zip(np.array([3, 4]), labels):
  100. exrf = exrfs[i]
  101. ex_area = self.rf_area_df[(self.rf_area_df.m == exrf['m']) &
  102. (self.rf_area_df.s == exrf['s']) &
  103. (self.rf_area_df.e == exrf['e']) &
  104. (self.rf_area_df.u == exrf['u'])]
  105. print('%s area= %0.3f \n'
  106. '%s rsq = %0.3f \n'
  107. % (label, ex_area.area, label, ex_area.add_rsq))
  108. def retino_ex(self, figsize=(2.5, 2.5), axs=None):
  109. """Plot RFs map of example trn recording (Fig. 4e)
  110. Parameters
  111. -------
  112. figsize: tuple
  113. Figure size (width, height)
  114. axs: list
  115. two axes used for plot and colorbar
  116. Returns
  117. -------
  118. axs: list
  119. two axes with plot and colorbar
  120. """
  121. if axs is None:
  122. # create figure
  123. f = plt.figure(figsize=figsize)
  124. axs = []
  125. # axis for plot
  126. l = 0.15
  127. b = 0.15
  128. w = 0.6
  129. h = 0.6
  130. axs.append(f.add_axes([l, b, w, h]))
  131. # axis for colorbar
  132. l = 0.8
  133. w_cbar = 0.05
  134. axs.append(f.add_axes([l, b, w_cbar, h]))
  135. # print number of units in series
  136. exseries = self.trn_retino_ex
  137. print('number of units in example series:', len(exseries))
  138. # define colors
  139. vmin = 0.3
  140. vmax = 0.95
  141. colors = np.linspace(vmin, vmax, len(exseries))
  142. mymap = plt.cm.get_cmap("Reds")
  143. my_colors = mymap(colors)
  144. # iterate over units
  145. for urowi, (_, urow) in enumerate(exseries.iterrows()):
  146. # compute parameters
  147. params = urow.params
  148. angle = params[2] * 180 / np.pi
  149. fitpars_deg = spatint_utils.degdiff(180, angle, 180)
  150. # calculate ellipse
  151. x, y = spatint_utils.calculate_ellipse(params[0], params[1], params[4], params[5],
  152. fitpars_deg)
  153. # plot
  154. axs[0].plot(x, y, c=my_colors[urowi])
  155. # layout
  156. axs[0].set_xlim(-20, 70)
  157. axs[0].set_ylim(-20, 70)
  158. axs[0].set_xticks((0, 30, 60))
  159. axs[0].set_yticks((0, 30, 60))
  160. axs[0].spines['bottom'].set_bounds(-10, 70)
  161. axs[0].spines['left'].set_bounds(-10, 70)
  162. axs[0].set_ylabel('Elevation ($\degree$)')
  163. axs[0].set_xlabel('Azimuth ($\degree$)')
  164. # draw colorbar
  165. sm = matplotlib.colors.LinearSegmentedColormap.from_list('Reds', my_colors)
  166. vmin_plot = 0.15
  167. vmax_plot = 1
  168. norm = matplotlib.colors.Normalize(vmin=vmin_plot, vmax=vmax_plot)
  169. cbar = matplotlib.colorbar.ColorbarBase(axs[1], cmap=sm, norm=norm)
  170. cbar.set_label('Depth (μm)', rotation=270)
  171. cbar.ax.invert_yaxis()
  172. barunits = (vmax - vmin) / (exseries.depth.iloc[-1] - exseries.depth.iloc[0])
  173. ticklabels = np.array([3050, 3100, 3150])
  174. ticks = ((ticklabels - exseries.depth.iloc[0]) * barunits) + vmin
  175. cbar.set_ticks(ticks)
  176. cbar.set_ticklabels(ticklabels)
  177. f = plt.gcf()
  178. f.tight_layout()
  179. return axs
  180. def trn_retino(self, figsize=(6, 2.5), axs=None):
  181. """Plot RFs map of example trn recording (Fig. 4f-g)
  182. Parameters
  183. -------
  184. figsize: tuple
  185. Figure size (width, height)
  186. axs: list
  187. two axes, one per dimension (azim, elev)
  188. Returns
  189. -------
  190. axs: list
  191. two axes
  192. """
  193. if axs is None:
  194. # create figure
  195. f, axs = plt.subplots(1, 2, figsize=figsize)
  196. # get data
  197. trn_retino_df = self.trn_retino_df
  198. # plot azimuth against depth
  199. axs[0].scatter(trn_retino_df.azim, trn_retino_df.depth, facecolors='none',
  200. edgecolors='k', linewidth=0.5)
  201. # print n
  202. n = len(trn_retino_df.azim)
  203. print('n azimuth = %d' % n)
  204. # plot regression line from ancova model
  205. xeval = np.arange(np.min(trn_retino_df.azim), np.max(trn_retino_df.azim), 1)
  206. # parameters for the model copied from R
  207. intercept = 3044.511585
  208. vis_angle = -1.191659
  209. azim = 60.728882 # for category 1: azim 0: elev
  210. interaction = -1.838795
  211. ymodel_azim = intercept + xeval * vis_angle + 1 * azim + 1 * xeval * interaction
  212. axs[0].plot(xeval, ymodel_azim, 'r')
  213. # layout
  214. axs[0].set_xticks((0, 30, 60))
  215. axs[0].set_yticks((2500, 2900, 3300))
  216. axs[0].invert_yaxis()
  217. axs[0].set_ylabel('Depth (μm)')
  218. axs[0].set_xlabel('Azimuth ($\degree$)')
  219. axs[0].spines['bottom'].set_bounds(-20, 70)
  220. axs[0].spines['left'].set_bounds(2500, 3300)
  221. # plot elevation against depth
  222. axs[1].scatter(trn_retino_df.elev, trn_retino_df.depth, facecolors='none',
  223. edgecolors='k', linewidth=0.5)
  224. # print n
  225. n = len(trn_retino_df.elev)
  226. print('n elev = %d' % n)
  227. # plot regression line from model
  228. xeval = np.arange(np.min(trn_retino_df.elev), np.max(trn_retino_df.elev), 1)
  229. ymodel_elev = intercept + xeval * vis_angle + 0 * azim + 0 * xeval * interaction
  230. axs[1].plot(xeval, ymodel_elev, 'r')
  231. # layout
  232. axs[1].set_xticks((0, 30, 60))
  233. axs[1].set_yticks((2500, 2900, 3300))
  234. axs[1].invert_yaxis()
  235. axs[1].set_ylabel('')
  236. axs[1].set_yticklabels([])
  237. axs[1].set_xlabel('Elevation ($\degree$)')
  238. axs[1].spines['bottom'].set_bounds(-20, 70)
  239. axs[1].spines['left'].set_bounds(2500, 3300)
  240. f = plt.gcf()
  241. f.tight_layout()
  242. return axs
  243. def rf_area(self, figsize=(2, 2), ax=None):
  244. """Create violin plot for comparison of TRN and LGN RF sizes (Fig. 4i)
  245. Parameters
  246. -------
  247. figsize: tuple
  248. figure size (width, height)
  249. ax: instance of matplotlib.axes class
  250. axis to use for plotting
  251. Returns
  252. -------
  253. ax: mpl axis
  254. Axis with plot
  255. """
  256. if ax is None:
  257. # make figure
  258. f, ax = plt.subplots(figsize=figsize)
  259. # get data
  260. rf_area_df = self.rf_area_df
  261. # split data
  262. trn_area = rf_area_df['area'][rf_area_df['region'] == 'PGN'].array
  263. lgn_area = rf_area_df['area'][rf_area_df['region'] == 'LGN'].array
  264. # plot data
  265. sns.violinplot(data=[np.log(trn_area), np.log(lgn_area)], palette=[trn_red,
  266. lgn_green], ax=ax, linewidth=1, inner=None)
  267. # plot mean
  268. ax.plot([0, 1], [np.log(trn_area.mean()), np.log(lgn_area.mean())], linestyle='',
  269. c='k', marker='.')
  270. # format plot
  271. ylabels = np.array([10, 100, 1000])
  272. ax.set_yticks(np.log(ylabels))
  273. ax.set_yticklabels(ylabels)
  274. ax.spines['right'].set_visible(False)
  275. ax.spines['top'].set_visible(False)
  276. ylims = ax.get_ylim()
  277. ax.spines['bottom'].set_bounds(0, 1)
  278. ax.spines['left'].set_bounds(ylims)
  279. plt.gca().get_xticklabels()[0].set_color(trn_red)
  280. plt.gca().get_xticklabels()[1].set_color(lgn_green)
  281. ax.set_xticklabels(['visTRN', 'dLGN'])
  282. ax.set_ylabel('RF area (deg$^2$)')
  283. ax.grid(False)
  284. # mannwhitneyUtest
  285. u_stat, parea = stats.mannwhitneyu(trn_area, lgn_area)
  286. # test for differences in variance
  287. # (with center = median, levene's test = brown-forsythe test)
  288. f_stat, pvar = stats.levene(trn_area, lgn_area, center='median')
  289. # ratio
  290. ratio = trn_area.mean() / lgn_area.mean()
  291. # print stats
  292. print('dispersion stats: Brown–Forsythe test\n'
  293. 'Fstat = %0.3f \n'
  294. 'pval = 10**%0.3f\n\n'
  295. 'central tendency stats\n'
  296. 'Ustat = %0.3f\n'
  297. 'pval area = 10**%0.3f \n'
  298. 'N area visTRN = %d \n'
  299. 'N area dLGN = %d\n'
  300. 'visTRN mean area +- sem = %0.3f (+- %0.3f)\n'
  301. 'dLGN mean area +- sem = %0.3f (+- %0.3f)\n'
  302. 'visTRN rfs are on average %0.3f x larger than dLGN rfs'
  303. % (f_stat,
  304. np.log10(pvar),
  305. u_stat,
  306. np.log10(parea),
  307. len(trn_area),
  308. len(lgn_area),
  309. trn_area.mean(),
  310. stats.sem(trn_area),
  311. lgn_area.mean(),
  312. stats.sem(lgn_area),
  313. ratio))
  314. f = plt.gcf()
  315. f.tight_layout()
  316. return ax
  317. def norm_szcurves(self, figsize=(3, 3), ax=None, eval_range=76, thres=128, mark_ex=True,
  318. lw=0.5, xticks=(0, 25, 50, 75), colormap='Greys'):
  319. """Plot normalized fitted size tuning curves for visTRN population (Fig. 4l)
  320. Parameters
  321. -------
  322. figsize: tuple
  323. Figure size (width, height)
  324. ax: instance of matplotlib.axes class
  325. Axis to use for plotting
  326. eval_range: int
  327. Range over which to evaluate model
  328. thres: int
  329. Lower threshold for darkness of line
  330. mark_ex: bool
  331. If true plots example neuron in different color
  332. lw: float
  333. Linewidth
  334. xticks: tuple
  335. xticks
  336. colormap: string
  337. Colormap
  338. Returns
  339. -------
  340. ax: mpl axis
  341. Axis with plot
  342. """
  343. if ax is None:
  344. # create figure
  345. f, ax = plt.subplots(figsize=figsize)
  346. # define range over which to evaluate model
  347. x_eval = range(eval_range)
  348. # define colormap
  349. cmap = plt.cm.get_cmap(colormap)
  350. # plot all tuning curves
  351. for row in self.trn_sztun_df.itertuples():
  352. # get y data
  353. params = row.tun_pars
  354. y = spatint_utils.rog_offset(x_eval, *params)
  355. # subtract offset
  356. y_sub = y - y[0]
  357. # normalize
  358. y_norm = y_sub / np.nanmax(y_sub)
  359. # define color
  360. si = row.si_76
  361. col = int(np.round((1 - si) * 255))
  362. col = np.max((col, thres))
  363. # plot
  364. ax.plot(x_eval, y_norm, c=cmap(col), lw=lw)
  365. if mark_ex:
  366. # plot example session in red
  367. params = self.trn_sztun_ex_dict['tun_pars']
  368. y = spatint_utils.rog_offset(x_eval, *params)
  369. y_sub = y - y[0]
  370. y_norm = y_sub / np.nanmax(y_sub)
  371. ax.plot(x_eval, y_norm, c=trn_red, lw=lw)
  372. # layout
  373. ax.set_xlabel('Diameter ($\degree$)')
  374. ax.set_ylabel('Normalized firing rate')
  375. ax.set_xticks(xticks)
  376. ax.set_yticks((0, 0.5, 1))
  377. ax.spines['bottom'].set_bounds(0, 75)
  378. ax.spines['left'].set_bounds(0, 1)
  379. return ax
  380. def ex_sztun_curve(self, figsize=(4, 2), axs=None):
  381. """Plot example visTRN size-tuning curve and raster (Fig. 4jk)
  382. Parameters
  383. -------
  384. figsize: tuple
  385. figure size (width, height)
  386. axs: instance of matplotlib.axes class
  387. axis to use for plotting
  388. Returns
  389. -------
  390. ax: mpl axis
  391. Axis with plot
  392. """
  393. if axs is None:
  394. # create figure
  395. f, axs = plt.subplots(1, 2, figsize=figsize)
  396. # get data
  397. ex_data = self.trn_sztun_ex_dict
  398. # plot raster
  399. spatint_utils.plot_raster(raster=ex_data['rasters'][ex_data['u']],
  400. tranges=ex_data['tranges'],
  401. opto=ex_data['opto'],
  402. ax=axs[0])
  403. # plot curve
  404. spatint_utils.plot_tun(means=ex_data['tun_mean'],
  405. sems=ex_data['tun_sem'],
  406. spons=ex_data['tun_spon_mean'],
  407. xs=ex_data['ti_axes'],
  408. params=ex_data['tun_pars'],
  409. ax=axs[1])
  410. # format layout
  411. axs[1].set_xticks((0, 25, 50, 75))
  412. axs[1].set_yticks((0, 10, 20))
  413. axs[1].spines['bottom'].set_bounds(0, 75)
  414. f = plt.gcf()
  415. f.tight_layout()
  416. # print info for example cell
  417. sz_ex_si = self.trn_sztun_ex_dict['si_76']
  418. sz_ex_rfcs = self.trn_sztun_ex_dict['rfcs_76']
  419. print('size tuning example cell: \n'
  420. 'Preferred size = %0.3f \n'
  421. 'SI = %0.3f'
  422. % (sz_ex_rfcs, sz_ex_si))
  423. def si(self, figsize=(2, 2), ax=None):
  424. """Plot histogram of suppression indices for visTRN and dLGN population (Fig. 4m)
  425. Parameters
  426. -------
  427. figsize: tuple
  428. figure size (width, height)
  429. ax: instance of matplotlib.axes class
  430. axis to use for plotting
  431. Returns
  432. -------
  433. ax: mpl axis
  434. Axis with plot
  435. """
  436. if ax is None:
  437. # create figure
  438. f, ax = plt.subplots(figsize=figsize)
  439. # get data
  440. trn_sis = self.trn_sztun_df.si_76.to_numpy()
  441. lgn_sis = self.lgn_sztun_df.si_76.str[0].to_numpy()
  442. # assert that no data is beyond limits
  443. assert (~np.any(trn_sis < 0)) & (~np.any(trn_sis > 1)), 'Datapoints beyond bounds'
  444. assert (~np.any(lgn_sis < 0)) & (~np.any(lgn_sis > 1)), 'Datapoints beyond bounds'
  445. # plot
  446. si_bins = np.arange(0, 1.05, 0.05) # define bins
  447. trn_red_trsp = (*trn_red[0:3], 0.5) # define color for trn
  448. lgn_green_trsp = (*lgn_green[0:3], 0.5) # define color for dlgn
  449. ax.hist(trn_sis, bins=si_bins, weights=np.ones(len(trn_sis)) / len(trn_sis), lw=0,
  450. fc=trn_red_trsp)
  451. ax.hist(lgn_sis, bins=si_bins, weights=np.ones(len(lgn_sis)) / len(lgn_sis), lw=0,
  452. fc=lgn_green_trsp)
  453. # layout
  454. ax.yaxis.set_major_formatter(PercentFormatter(1))
  455. ax.set_xlabel('Suppression index')
  456. ax.set_ylabel('neurons')
  457. ax.set_xticks((0, 0.5, 1))
  458. ax.set_yticks((0, 0.25, 0.5))
  459. ax.spines['bottom'].set_bounds(0, 1)
  460. ax.set_xlim((0, 1))
  461. f = plt.gcf()
  462. f.tight_layout()
  463. # compute and print stats for trn
  464. n_trn = len(trn_sis)
  465. si_mean_trn = np.mean(trn_sis)
  466. si_sem_trn = stats.sem(trn_sis)
  467. si_med_trn = np.median(trn_sis)
  468. lower05_trn = (len(trn_sis[trn_sis < 0.05]) /
  469. len(trn_sis) * 100) # percentage of cells with si smaller 0.05
  470. print('trn size tuning population: \n'
  471. 'n = %d\n'
  472. 'mean si +/- sem = %0.3f (+- %0.3f) \n'
  473. 'median si = %0.3f \n'
  474. '%0.3f percent of pgn cells have si < 0.05\n'
  475. % (n_trn,
  476. si_mean_trn,
  477. si_sem_trn,
  478. si_med_trn,
  479. lower05_trn))
  480. # compute and print stats for dlgn
  481. n_lgn = len(lgn_sis)
  482. si_mean_lgn = np.mean(lgn_sis)
  483. si_sem_lgn = stats.sem(lgn_sis)
  484. si_med_lgn = np.median(lgn_sis)
  485. print('lgn size tuning population:\n'
  486. 'n = %d\n'
  487. 'mean si +/- sem = %0.3f (+- %0.3f)\n'
  488. 'median si = %0.3f\n'
  489. % (n_lgn,
  490. si_mean_lgn,
  491. si_sem_lgn,
  492. si_med_lgn))
  493. # compare the two
  494. ustat_sis, p_sis = stats.mannwhitneyu(trn_sis, lgn_sis)
  495. print('mannwhitneyu test to compare si in dlgn and trn:\n'
  496. 'Ustat: %0.3f\n'
  497. 'pvalue: 10**%0.3f\n'
  498. % (ustat_sis, np.log10(p_sis)))
  499. return ax