fig1.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import matplotlib
  4. import pandas as pd
  5. idx = pd.IndexSlice
  6. import pickle as pkl
  7. import os
  8. import spatint_utils as su
  9. datapath = os.getcwd()+'/../data/'
  10. with open(datapath+'lgn_spat_profile.pkl','rb') as r:
  11. lgn_spat_profile = pkl.load(r)
  12. with open(datapath+'lgn_ori_examples.pkl','rb') as r:
  13. lgn_ori_examples = pkl.load(r)
  14. with open(datapath + 'v1_raw_muaRFs.pkl','rb') as r:
  15. v1_raw_muaRFs = pkl.load(r)
  16. with open(datapath + 'v1_muaRFs.pkl','rb') as r:
  17. v1_muaRFs = pkl.load(r)
  18. with open(datapath + 'lgn_raw_muaRFs.pkl','rb') as r:
  19. lgn_raw_muaRFs = pkl.load(r)
  20. with open(datapath + 'lgn_muaRFs.pkl','rb') as r:
  21. lgn_muaRFs = pkl.load(r)
  22. def agg_overlap_bins(data, start, end, nbins, binsize, agg_op, boot_value, filtr = False):
  23. """computes aggregate values for bins along a given axis, aggregate operation is defined by agg_op"""
  24. bins_data = []
  25. print('Averaging and bootstrapping in overlapping bins')
  26. for bin_start in np.linspace(start, end, nbins):
  27. bin_data = data.loc[(data['d']>bin_start) & (data['d']<(bin_start+binsize))]
  28. if filtr != False and len(bin_data)<filtr:
  29. continue
  30. booties = []
  31. for i in range(1000):
  32. bt_bin = bin_data[boot_value][np.random.randint(0,len(bin_data),len(bin_data))]
  33. booties.append(np.mean(bt_bin))
  34. booties = np.sort(booties)
  35. lower = booties[24]
  36. upper = booties[974]
  37. boot_std = np.std(booties)
  38. bin_agg = bin_data.agg(agg_op)
  39. bin_agg['lower'] = lower
  40. bin_agg['upper'] = upper
  41. bin_agg['boot'] = boot_std
  42. bin_agg['d'] = bin_start
  43. bins_data.append(bin_agg)
  44. return pd.concat(bins_data, axis =1)
  45. #bin_mean = agg_overlap_bins(lgn_spat_profile,0,100,31,15,agg_op='mean',boot_value='gain',filtr=5)
  46. def plot_fold_change(ax=None, data=lgn_spat_profile):#, bin_mean=bin_mean):
  47. """Plot mean fold change values of all units as scatter plot"""
  48. if ax is None:
  49. fig, ax = plt.subplots()
  50. # Clip low and high values for better visibility
  51. data.loc[data['gain']>2,'gain'] = 2.35
  52. data.loc[data['gain']<-2,'gain'] = -2.35
  53. ax.plot(data['d'],data['gain'],c='k', alpha=0.7, linestyle='', fillstyle='none', clip_on = False, marker='o',mec='k',markersize=1)
  54. examples = data.reindex([('Ntsr1Cre_2015_0080',3,4027),
  55. ('Ntsr1Cre_2018_0003',2,22,3),
  56. ('Ntsr1Cre_2015_0080',4,3049)])
  57. ax.plot([-100,100],[0,0],c='k', lw=0.35)
  58. ax.plot(bin_mean.loc['d'][:15]+7.5,bin_mean.loc['gain'][:15], lw=3,alpha=0.5, c='#1090cfff')
  59. ax.plot(bin_mean.loc["d"][8:15]+7.5,bin_mean.loc['gain'][8:15], lw=3,alpha=1, c ='#1090cfff')
  60. ax.set(xlim=[0,60], ylim=[-2.5,2.5])
  61. ax.set_xlabel('Distance of dLGN RFs to V1 RFs at injection site $(^{\circ})$',labelpad=5)
  62. ax.set_xticks([0,20,40,60])
  63. ax.set_ylabel('Fold change',labelpad=0)
  64. ax.minorticks_off()
  65. ax.set_yticks(list(np.linspace(-2,2,5))+[-2.35,2.35])
  66. ax.set_yticklabels([str("{0:.2f}".format(round(ytick,2)))for ytick in np.geomspace(0.25,4,5)]+['< 0.25','> 4.00'])
  67. ax.spines['left'].set_bounds(-2,2)
  68. ax.spines['right'].set_visible(False)
  69. ax.spines['top'].set_visible(False)
  70. def plot_modulation_histogram(ax = None, data = lgn_spat_profile):
  71. if ax is None:
  72. fig = plt.figure()
  73. fig.set_figheight(2.576)
  74. fig.set_figwidth(4.342)
  75. ax = fig.add_axes((0.15,0.25,0.75,0.75))
  76. bins = np.linspace(0,60,13)
  77. data = data.sort_values('d')
  78. data['bin'] = data.apply(lambda x: np.searchsorted(bins, x['d'], side='right')-1, axis=1)
  79. ratios = data.groupby('bin').apply(lambda x: pd.DataFrame({'sup' :[len(x[x['modlab']==-1])/len(x)],
  80. 'fac' :[len(x[x['modlab']== 1])/len(x)],
  81. 'snull':[len(x[(x['modlab']==0) & (x['gain']<0)])/len(x)],
  82. 'fnull' :[len(x[(x['modlab']==0) & (x['gain']>0)])/len(x)]}))
  83. ratios.index = ratios.index.set_names('dummy', level=1)
  84. ratios = ratios.reset_index("dummy").drop(columns='dummy')
  85. ratios['binpos'] = bins[ratios.index]+2.5
  86. ax.bar(ratios['binpos'],-ratios['sup'], bottom = -ratios['snull'], color = 'blue', width=4)
  87. ax.bar(ratios['binpos'], ratios['fac'], bottom = ratios['fnull'], color = 'orange',width = 4)
  88. ax.bar(ratios['binpos'], -ratios['snull'], bottom = 0, color='lightskyblue', width=4)
  89. ax.bar(ratios['binpos'], ratios['fnull'], bottom = 0, color = 'moccasin', width=4)
  90. ax.set_xlabel('Distance of dLGN RF to \n mean V1 RF at injection site ($\degree$)')
  91. ax.spines['top'].set_visible(False)
  92. ax.spines['right'].set_visible(False)
  93. ax.spines['left'].set_bounds(high=0.75,low=-0.75)
  94. ax.plot([0,70],[0,0],c='k',lw=0.35)
  95. ax.set(xlim=[2.5,60],ylim=[-0.9,0.9])
  96. ax.set_yticks([-0.75,-0.25,0,0.25,0.75])
  97. ax.set_yticklabels(['75','25','','25','75'])
  98. ax.set_xticks([0,20,40,60])
  99. ax.set_ylabel('Proportion (%)')
  100. ax.yaxis.set_label_coords(-0.1,0.55)
  101. ax.annotate('enhanced', xy=(1,0.15), xytext=(1,0.15),xycoords='data',rotation='vertical',color='darkorange')
  102. ax.annotate('suppressed', xy=(1,-0.8), xytext=(1,-0.8),xycoords='data',rotation='vertical',color='blue')
  103. def plot_ori_examples(ax = None, data = lgn_ori_examples):
  104. def compute_ori_curve(df,opto,xs):
  105. return su.sum_of_gaussians(xs,*df['tun_pars'][opto])
  106. xs = np.linspace(0,360,361)
  107. ctrl_curves = data.apply(compute_ori_curve, axis=1, opto=0,xs=xs)
  108. opto_curves = data.apply(compute_ori_curve, axis=1, opto=1,xs=xs)
  109. if ax is None:
  110. gridspec_kw = {'left':0.175,'right':0.95,'bottom':0.225,'top':0.95,'hspace':0.1}
  111. fig, axes = plt.subplots(1,3,figsize=(6.327,1.971), gridspec_kw = gridspec_kw)
  112. for i,ax in enumerate(axes):
  113. ax.plot(xs, ctrl_curves.iloc[i],c='k',ms=2)
  114. ax.plot(xs, opto_curves.iloc[i],c='#1090cfff',ms=2)
  115. ax.plot(np.linspace(0,330,12),data.iloc[i]['tun_mean'][:,0],mfc='k',ms=8,ls='',marker='.',mew=0)
  116. ax.plot(np.linspace(0,330,12),data.iloc[i]['tun_mean'][:,1],mfc='#1090cfff',ms=8,ls='',marker='.',mew=0)
  117. ax.errorbar(np.linspace(0,330,12), data.iloc[i]['tun_mean'][:,0], yerr=data.iloc[i]['tun_sem'][:,0],fmt='none', ecolor='k', elinewidth=1.5)
  118. ax.errorbar(np.linspace(0,330,12), data.iloc[i]['tun_mean'][:,1], yerr=data.iloc[i]['tun_sem'][:,1],fmt='none', ecolor='#1090cfff', elinewidth=1.5)
  119. ax.plot(np.linspace(0,330,12), 12*[data.iloc[i]['tun_spon_mean'][0]],c='k',lw=0.5)
  120. ax.plot(np.linspace(0,330,12), 12*[data.iloc[i]['tun_spon_mean'][1]],c='#1090cfff',lw=0.5)
  121. ax.spines['top'].set_visible(False)
  122. ax.spines['right'].set_visible(False)
  123. ax.spines['bottom'].set_visible(True)
  124. ax.spines['bottom'].set_bounds(0,360)
  125. xticks = [0,180,360]
  126. ax.set_xticks(xticks)
  127. ax.set_yticks([])
  128. ax.set_xticklabels([])
  129. axes[0].spines['left'].set_visible(True)
  130. axes[0].spines['bottom'].set_visible(True)
  131. axes[0].set(ylim=[0,100])
  132. axes[0].spines['left'].set_bounds(0,80)
  133. axes[0].set_yticks([0,40,80])
  134. axes[1].set(ylim=[0,62.5])
  135. axes[1].spines['left'].set_bounds(0,50)
  136. axes[1].set_yticks([0,25,50])
  137. axes[2].set(ylim=[0,62.5])
  138. axes[2].spines['left'].set_bounds(0,50)
  139. axes[2].set_yticks([0,25,50])
  140. axes[0].set_ylabel('Firing rate (sp/s)')
  141. axes[0].yaxis.set_label_coords(-0.33,0.4)
  142. axes[0].set_xlabel('Direction ($\degree$)',labelpad=0)
  143. axes[0].set_xticks([0,180,360])
  144. axes[0].set_xticklabels(['0','180','360'])
  145. axes[0].xaxis.set_tick_params(length=1)
  146. def plot_v1_raw_muaRFs(ax = None, data = v1_raw_muaRFs):
  147. if ax is None:
  148. gridspec_kw = {'left':0.1,'right':0.975,'bottom':0.15,'top':0.9}
  149. fig,ax = plt.subplots(3,1,gridspec_kw=gridspec_kw)
  150. fig.set_figheight(1.984)
  151. fig.set_figwidth(1.332)
  152. cmap = matplotlib.cm.get_cmap('Greens')
  153. cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
  154. [(0,'white'),
  155. (1,'green')])
  156. for axi, (index, row) in zip(ax,data.iterrows()):
  157. left = row['ti_axes'][:,0].min()-2.4
  158. right = row['ti_axes'][:,0].max()+2.5
  159. bottom = row['ti_axes'][:,1].min()-2.5
  160. top = row['ti_axes'][:,1].max()+2.5
  161. extent = (left,right,bottom,top)
  162. axi.imshow(np.rot90(row['rfs'].mean(axis=(0,3))),cmap='gray', extent = extent)
  163. plot_ellipse(row,axi,cmap, linewidth=1.5)
  164. axi.axis('scaled')
  165. axi.set(xlim=[left,right],ylim=[bottom,top])
  166. axi.set_xticks([])
  167. axi.set_yticks([])
  168. else:
  169. xticks = [left,right]
  170. yticks = [bottom,top]
  171. axi.set_xticks(xticks)
  172. axi.set_yticks([])
  173. ax[0].set_yticks(yticks)
  174. axi.set_xticklabels([str("{0:.0f}".format(round(xtick,0))) +'$\degree$' for xtick in xticks])
  175. ax[0].set_yticklabels([str("{0:.0f}".format(round(ytick,0))) +'$\degree$' for ytick in yticks],
  176. rotation=90)
  177. plt.setp(ax[0].yaxis.get_majorticklabels(), va='center')
  178. plt.setp(axi.xaxis.get_majorticklabels(), ha='center')
  179. axi.tick_params(length=2)
  180. ax[0].tick_params(axis='y', pad=0, length=2)
  181. axi.tick_params(axis='x', pad=1)
  182. def plot_v1_rf_ellipses(ax = None, data = v1_muaRFs):
  183. if ax is None:
  184. fig = plt.figure(figsize=(7,5))
  185. ax = fig.add_axes((0.15,0.15,0.7,0.7))
  186. cmap = matplotlib.cm.get_cmap('Greens')
  187. data = data.reset_index('ch')
  188. one_series = data.loc['Ntsr1Cre_2015_0080',2]
  189. one_series = one_series[one_series['sigma_x_mix']<15]
  190. maxchan,minchan = one_series['ch'].max(), one_series['ch'].min()
  191. V1_x = data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['x_mix']
  192. V1_y = data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['y_mix']
  193. one_series.apply(lambda x: plot_ellipse(x,ax=ax,cmap=cmap,linewidth=1,maxchan=maxchan),axis=1)
  194. ax.scatter(V1_x,V1_y,marker='+',s=100,lw=2,c='k',zorder=3)
  195. ax.set(xlim=[-20,80],ylim=[-35,50])
  196. ax.set_xlabel('Azimuth ($\degree$)'
  197. ,labelpad=0)
  198. ax.set_ylabel('Elevation ($\degree$)'
  199. ,labelpad=0)
  200. xticks = [tick.get_text() for tick in ax.get_xticklabels()]
  201. ax.spines['top'].set_visible(False)
  202. ax.spines['right'].set_visible(False)
  203. def plot_lgn_raw_muaRFs(ax = None, data = lgn_raw_muaRFs):
  204. if ax is None:
  205. gridspec_kw = {'left':0.1,'right':0.975,'bottom':0.1,'top':0.93}
  206. fig,ax = plt.subplots(5,2,gridspec_kw=gridspec_kw)
  207. fig.set_figheight(3.552)
  208. fig.set_figwidth(2.190)
  209. ax_column1 = ax[:,0]
  210. ax_column2 = ax[:,1]
  211. cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
  212. [(0,'xkcd:off white'),
  213. (1,'xkcd:neon purple')])
  214. cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
  215. [(0,'xkcd:off white'),
  216. (1,'xkcd:saffron')]) #deep tea
  217. cols = {5:'#00555a',9:'blue'}
  218. ser1 = data.loc['Ntsr1Cre_2015_0080',9].iloc[14:19]
  219. ser2 = data.loc['Ntsr1Cre_2015_0080',5].iloc[7:12]
  220. for axi, (index, row) in zip(ax_column1,ser1.iterrows()):
  221. left = row['ti_axes'][:,0].min()-2.4
  222. right = row['ti_axes'][:,0].max()+2.5
  223. bottom = row['ti_axes'][:,1].min()-2.5
  224. top = row['ti_axes'][:,1].max()+2.5
  225. extent = (left,right,bottom,top)
  226. axi.imshow(np.rot90(row['rfs'].mean(axis=(0,3))),cmap='gray', extent = extent)
  227. plot_ellipse(row,axi,cmap, linewidth =1.5)
  228. axi.axis('scaled')
  229. axi.set(xlim=[left,right],ylim=[bottom,top])
  230. axi.set_xticks([])
  231. axi.set_yticks([])
  232. else:
  233. xticks = [left,right]
  234. yticks = [bottom,top]
  235. axi.set_xticks(xticks)
  236. axi.set_yticks([])
  237. ax_column1[0].set_yticks(yticks)
  238. axi.set_xticklabels([str("{0:.0f}".format(round(xtick,0))) +'$\degree$' for xtick in xticks])
  239. ax_column1[0].set_yticklabels([str("{0:.0f}".format(round(ytick,0))) +'$\degree$' for ytick in yticks],
  240. rotation=90)
  241. plt.setp(ax_column1[0].yaxis.get_majorticklabels(), va='center')
  242. plt.setp(axi.xaxis.get_majorticklabels(), ha='center')
  243. ax_column1[0].tick_params(axis='y', pad=0)
  244. for axi, (index, row) in zip(ax_column2,ser2.iterrows()):
  245. left = row['ti_axes'][:,0].min()-2.4
  246. right = row['ti_axes'][:,0].max()+2.5
  247. bottom = row['ti_axes'][:,1].min()-2.5
  248. top = row['ti_axes'][:,1].max()+2.5
  249. extent = (left,right,bottom,top)
  250. axi.imshow(np.rot90(row['rfs'].mean(axis=(0,3))),cmap='gray', extent = extent)
  251. plot_ellipse(row,axi,cmap2, linewidth=1.5)
  252. axi.axis('scaled')
  253. axi.set(xlim=[left,right],ylim=[bottom,top])
  254. axi.set_xticks([])
  255. axi.set_yticks([])
  256. else:
  257. xticks = [left,right]
  258. yticks = [bottom,top]
  259. axi.set_xticks(xticks)
  260. axi.set_yticks([])
  261. ax_column2[0].set_yticks(yticks)
  262. axi.set_xticklabels([str("{0:.0f}".format(round(xtick,0))) +'$\degree$' for xtick in xticks]
  263. )
  264. ax_column2[0].set_yticklabels([str("{0:.0f}".format(round(ytick,0))) +'$\degree$' for ytick in yticks],
  265. rotation=90)
  266. plt.setp(ax_column2[0].yaxis.get_majorticklabels(), va='center')
  267. plt.setp(axi.xaxis.get_majorticklabels(), ha='center')
  268. ax_column2[0].tick_params(axis='y', pad=0)
  269. def plot_lgn_rf_ellipses(ax = None, data = lgn_muaRFs, v1data=v1_muaRFs):
  270. if ax is None:
  271. fig = plt.figure(figsize=(2.995,3.531))
  272. ax = fig.add_axes((0.275,0.2,0.65,0.725))
  273. cmap1 = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
  274. [(0,'xkcd:neon purple'),
  275. (1,'xkcd:neon purple')])
  276. cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list('custom',
  277. [(0,'xkcd:saffron'),
  278. (1,'xkcd:saffron')])
  279. cmaps = {9:cmap1,5:cmap2}
  280. two_series = data.loc[idx['Ntsr1Cre_2015_0080',(5,9),:]]
  281. V1_x = v1data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['x_mix']
  282. V1_y = v1data.groupby('m').agg('mean').loc['Ntsr1Cre_2015_0080']['y_mix']
  283. two_series.groupby('s').apply(lambda x: plot_RF_ellipses(x, ax, cmaps[x.index.unique(level=1)[0]],linewidth=0.5))
  284. ax.scatter(V1_x,V1_y,marker='+',s=100,lw=2,c='k',zorder=3)
  285. ax.set(xlim=[-5,70],ylim=[-25,60])
  286. ax.set_xlabel('Azimuth ($\degree$)'
  287. ,labelpad=0)
  288. ax.set_ylabel('Elevation ($\degree$)'
  289. ,labelpad=0)
  290. xticks = [tick.get_text() for tick in ax.get_xticklabels()]
  291. xticks = [0,40,80]
  292. yticks=[-20,0,30,60]
  293. xticklabels = [str(xtick) for xtick in xticks]
  294. yticklabels = [str(ytick) for ytick in yticks]
  295. ax.yaxis.set_tick_params(pad=0)
  296. ax.set_xticks(xticks)
  297. ax.set_xticklabels(xticklabels)
  298. ax.set_yticks(yticks)
  299. ax.set_yticklabels(yticklabels)
  300. ax.spines['top'].set_visible(False)
  301. ax.spines['right'].set_visible(False)
  302. ax.spines['left'].set_bounds(-20,60)
  303. ax.spines['bottom'].set_bounds(0,80)
  304. norm= matplotlib.colors.Normalize(vmin=0.5,vmax=1)
  305. def plot_RF_ellipses(data, ax, cmap,linewidth=1):
  306. """Plotting RF ellipses from fitted parameters."""
  307. data.apply(lambda x: plot_ellipse(x,ax,cmap,linewidth=linewidth),axis=1)
  308. def plot_ellipse(data, ax, cmap, linewidth=1, maxchan=60):
  309. """Plot from single df row"""
  310. ddiff = su.degdiff(180,(data['theta_mix']*180/np.pi),180)
  311. params = data[['x_mix',
  312. 'y_mix',
  313. 'sigma_x_mix',
  314. 'sigma_y_mix']]
  315. x_ellipse, y_ellipse = su.calculate_ellipse(*params,ddiff)
  316. col = cmap(data['rsq_mix'])
  317. ax.plot(x_ellipse,y_ellipse, lw=linewidth,c=col)
  318. ax.axis('scaled')