# create_supp_figures.py (18 KB)
  1. import os
  2. import pandas as pd
  3. import mrestimator as mre
  4. import numpy as np
  5. import matplotlib.pylab as plt
  6. import seaborn as sns
  7. import matplotlib.gridspec as gridspec
  8. from scipy import stats as sps
  9. import mr_utilities
  10. def cm2inch(value):
  11. return value/2.54
  12. mycol = ['darkgreen', 'mediumaquamarine', 'slategray', 'darkblue','steelblue',
  13. 'firebrick', 'purple', 'orange', "red", "violet", "darkred",
  14. "darkviolet", "green", "salmon", "black", 'darkgreen',
  15. 'mediumaquamarine', 'silver', 'darkblue','steelblue', 'firebrick',
  16. 'purple', 'orange', "red", "violet", "darkred", "darkviolet"]
  17. def distribution_A_recordings(resultpath):
  18. data_path = '../Data/preictal/'
  19. recordings = [dI for dI in os.listdir(data_path) if os.path.isdir(
  20. os.path.join(data_path, dI))]
  21. A_mean = []
  22. A_nonzero = []
  23. A_nonzero_frac = []
  24. for recording in recordings:
  25. outputfile = '{}{}/pkl/activity_SU_4ms.pkl'.format(data_path,
  26. recording)
  27. if not os.path.isfile(outputfile):
  28. print("No binning data for recording ", recording)
  29. continue
  30. data = pd.read_pickle(outputfile)
  31. binnings = data['binning'].tolist()
  32. for binning in binnings:
  33. activity = data[data['binning'] == binning]['activity'].tolist()[0]
  34. A_mean.append(np.mean(activity))
  35. A_nonzero.append(mr_utilities.count_nonzero(activity))
  36. A_nonzero_frac.append(mr_utilities.count_nonzero(activity) /
  37. len(activity) * 100)
  38. A_mean_Hz = [A_m*(1000/4) for A_m in A_mean]
  39. A_mean = A_mean_Hz
  40. quantile = 0.50
  41. A_mean_50 = np.quantile(A_mean, quantile)
  42. A_nonzero_50 = np.quantile(A_nonzero, quantile)
  43. A_nonzero_frac_50 = np.quantile(A_nonzero_frac, quantile)
  44. quantile = 0.25
  45. A_mean_25 = np.quantile(A_mean, quantile)
  46. A_nonzero_25 = np.quantile(A_nonzero, quantile)
  47. A_nonzero_frac_25 = np.quantile(A_nonzero_frac, quantile)
  48. sns.set(style='white')
  49. fig, ax = plt.subplots(1, 3, figsize=((cm2inch(21), cm2inch(2.8))))
  50. bins = np.arange(0, np.max(A_mean), 20)
  51. ax[0].hist(A_mean, alpha=0.6, color='darkblue')
  52. ax[0].set_xlabel(r'population firing rate R (Hz)')
  53. ax[0].set_ylabel(r'# recordings')
  54. ax[0].text(0.8, 0.8, r'$q_{50}$'+'={:.1f} Hz'.format(A_mean_50),
  55. horizontalalignment='center', verticalalignment='center',
  56. transform=ax[0].transAxes, color='green', fontsize=11)
  57. ax[0].text(0.8, 0.65, r'$q_{25}$'+'={:.1f} Hz'.format(A_mean_25),
  58. horizontalalignment='center', verticalalignment='center',
  59. transform=ax[0].transAxes, color='green', fontsize=11)
  60. bins = np.arange(0, np.max(A_nonzero), 20)
  61. ax[1].hist(A_nonzero, alpha=0.6, color='darkblue')
  62. ax[1].set_xlabel(r'$n_{A_t\neq 0}$')
  63. ax[1].text(0.8, 0.8, r'$q_{50}$'+'={:.0f}'.format(A_nonzero_50),
  64. horizontalalignment='center', verticalalignment='center',
  65. transform=ax[1].transAxes, color='green', fontsize=11)
  66. ax[1].text(0.8, 0.65, r'$q_{25}$'+'={:.0f}'.format(A_nonzero_25),
  67. horizontalalignment='center', verticalalignment='center',
  68. transform=ax[1].transAxes, color='green', fontsize=11)
  69. bins = np.arange(0, np.max(A_nonzero_frac), 20)
  70. ax[2].hist(A_nonzero_frac, alpha=0.6, color='darkblue')
  71. ax[2].set_xlabel(r'$n_{A_t\neq 0}$ / $n_{A_t}$ (%)')
  72. ax[2].text(0.7, 0.8, r'$q_{50}$'+'={:.1f} %'.format(A_nonzero_frac_50),
  73. horizontalalignment='center',
  74. verticalalignment='center', transform=ax[2].transAxes,
  75. color='green', fontsize=11)
  76. ax[2].text(0.7, 0.65, r'$q_{25}$'+'={:.1f} %'.format(A_nonzero_frac_25),
  77. horizontalalignment='center',
  78. verticalalignment='center', transform=ax[2].transAxes,
  79. color='green', fontsize=11)
  80. ax[0].tick_params(top='off', bottom='on', left='on', right='off',
  81. labelleft='on',labelbottom='on')
  82. ax[1].tick_params(top='off', bottom='on', left='on', right='off',
  83. labelleft='on', labelbottom='on')
  84. ax[2].tick_params(top='off', bottom='on', left='on', right='off',
  85. labelleft='on', labelbottom='on')
  86. sns.despine()
  87. plt.subplots_adjust(bottom=None, right=None,
  88. wspace=0.6, hspace=0.3, left=0.15, top=1.2)
  89. figfile = '{}S4_recording_statistics.pdf'.format(resultpath)
  90. plt.savefig(figfile, bbox_inches='tight')
  91. plt.show()
  92. def mt_all_supp(save_path,fit_method, kmin, kmax, windowsize, windowstep, dt):
  93. patients = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
  94. 18, 19, 20]
  95. fig = plt.figure(figsize=(cm2inch(2*13.5), cm2inch(1.2*10*5)))
  96. outer = gridspec.GridSpec(10, 2, wspace=0.2, hspace=0.9)
  97. laterality = ["SOZ", "nSOZ"]
  98. for ip, patient in enumerate(patients):
  99. inner = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[ip],
  100. wspace=0.1, hspace=0.9)
  101. for j, soz in enumerate(laterality):
  102. ax = plt.Subplot(fig, inner[j])
  103. ax = mt_plot(ax, patient, soz, save_path,fit_method, kmin, kmax,
  104. windowsize, windowstep, dt)
  105. if j == 0:
  106. ax.set_title('Patient {}'.format(int(patient)), fontsize=12)
  107. ax.tick_params(top='off', bottom='on', left='off',
  108. right='off', labelleft='on',
  109. labelbottom='off')
  110. ax.set_xlabel('')
  111. fig.add_subplot(ax)
  112. sns.despine(fig)
  113. plt.savefig('{}/S6_allmt.pdf'.format(save_path), bbox_inches='tight')
  114. plt.show()
  115. plt.close()
  116. def binning_label(patient, soz):
  117. focifile = '../Data/patients.txt'
  118. f = open(focifile, 'r')
  119. foci = [line.rstrip('\n').split('\t') for line in f.readlines()]
  120. focus = 'NA'
  121. for i, idf in enumerate(foci[1:]):
  122. if int(idf[0]) == int(patient):
  123. focus = idf[1]
  124. if soz == "SOZ":
  125. if focus == "L":
  126. binning = "left"
  127. elif focus == "R":
  128. binning = "right"
  129. else:
  130. raise Exception("No clear focus found.")
  131. elif soz == "nSOZ":
  132. if focus == "L":
  133. binning = "right"
  134. elif focus == "R":
  135. binning = "left"
  136. else:
  137. raise Exception("No clear focus found.")
  138. return binning
  139. def mt_plot(ax, patient, soz, save_path, fit_method, kmin, kmax,
  140. windowsize, windowstep, dt):
  141. recordings = mr_utilities.list_preictrecordings(patient)[0]
  142. binning = binning_label(patient, soz)
  143. for irec, rec in enumerate(recordings):
  144. mt_file = '{}{}/mt_results_{}_{}_kmin{}_kmax{}_winsize{}_winstep{}' \
  145. '_dt{}.pkl'.format(save_path, rec, binning, fit_method, kmin,
  146. kmax, windowsize, windowstep, dt)
  147. if not os.path.isfile(mt_file):
  148. print('{} not existing'.format(mt_file))
  149. continue
  150. # -----------------------------------------------------------------#
  151. # read m(t) data
  152. # -----------------------------------------------------------------#
  153. mt_frame = pd.read_pickle(mt_file)
  154. original_length = len(mt_frame.m.tolist())
  155. rejected_inconsistencies = ['H_nonzero', 'H_linear',
  156. 'not_converged']
  157. for incon in rejected_inconsistencies:
  158. if not mt_frame.empty:
  159. mt_frame = mt_frame[[incon not in mt_frame[
  160. 'inconsistencies'].tolist()[i] for i in range(
  161. len(mt_frame['inconsistencies'].tolist()))]]
  162. if len(mt_frame.m.tolist()) < 0.95 * original_length:
  163. continue
  164. kmin = mt_frame.kmin.unique()[0]
  165. kmax = mt_frame.kmax.unique()[0]
  166. fit_method = mt_frame.fit_method.unique()[0]
  167. winsize = mt_frame.windowsize.unique()[0]
  168. t0_list = mt_frame.t0.tolist()
  169. mt = mt_frame.m.tolist()
  170. t0_list = [t0_ms * 0.001 for t0_ms in t0_list]
  171. half_win_sec = int((winsize / 2) / 1000 * dt)
  172. t_middle = [t0 + half_win_sec for t0 in t0_list]
  173. # -----------------------------------------------------------------#
  174. # Plotting
  175. # -----------------------------------------------------------------#
  176. ax.plot(t_middle, mt, '.', label='Rec {}'.format(irec + 1),
  177. color=mycol[irec], markersize=0.5)
  178. ax.axhline(y=1, color='black', linestyle='--', alpha=0.3)
  179. ax.set_ylabel(r"$\hat{m}$")
  180. ax.set_xlim((t0_list[0] - 10, t0_list[-1] + 2 * half_win_sec + 10))
  181. binning_focus = mr_utilities.binning_focus(binning, int(patient))
  182. if binning_focus == 'focal':
  183. focus_label = 'ipsi'
  184. else:
  185. focus_label = 'contra'
  186. ax.set_title("{}".format(focus_label), loc='left')
  187. ax.set_xlim((-10, 610))
  188. handles, labels = ax.get_legend_handles_labels()
  189. if len(labels) == 0:
  190. if soz == 'SOZ':
  191. ax.set_title("ipsi", loc='left')
  192. if soz == 'nSOZ':
  193. ax.set_title("contra", loc='left')
  194. ax.axhline(y=1, color='black', linestyle='--', alpha=0.3)
  195. ax.set_ylabel(r"$\hat{m}$")
  196. ax.set_ylim((0.8, 1.05))
  197. ax.set_xlim((-10, 610))
  198. ax.xaxis.grid(False)
  199. ax.set_xlabel("time (s)")
  200. ax.tick_params(top='off', bottom='off', left='off', right='off',
  201. labelleft='on', labelbottom='off')
  202. ax.tick_params(top='off', bottom='on', left='off', right='off',
  203. labelleft='on', labelbottom='on')
  204. sns.despine()
  205. return ax
  206. def windowsize_analysis_recs(data_path, recording, binning, outputfilename,
  207. resultpath):
  208. outputfile = '{}{}/pkl/{}'.format(data_path, recording, outputfilename)
  209. if not os.path.isfile(outputfile):
  210. print("No binning data for recording ", recording)
  211. print(outputfile)
  212. return
  213. data = pd.read_pickle(outputfile)
  214. dt = data['dt'].unique()[0]
  215. activity = data[data['binning'] == binning]['activity'].tolist()[0]
  216. winstep = 100
  217. ksteps = (1, 400)
  218. windowsize = np.arange(2500, 40000, 2500)
  219. N_estimates = len(activity) - np.max(windowsize)
  220. tau_estimate = []
  221. tau_conf = []
  222. m_estimate = []
  223. m_conf = []
  224. for Lw in windowsize:
  225. tau_L = []
  226. m_L = []
  227. for iw in range(0, N_estimates, winstep):
  228. activity_window = activity[iw:iw+Lw]
  229. input = mre.input_handler(activity_window)
  230. rk = mre.coefficients(input, steps=ksteps, dt=dt)
  231. fitres_offset = mre.fit(rk, fitfunc=mre.f_exponential_offset)
  232. tau_L.append(fitres_offset.tau)
  233. m_L.append(fitres_offset.mre)
  234. conf_int_off = sps.mstats.mquantiles(tau_L, prob=[0.025, 0.975],
  235. alphap=0, betap=1, axis=None,
  236. limit=())
  237. conf_int_off = [abs(np.mean(tau_L) - ci) for ci in conf_int_off]
  238. tau_conf.append(conf_int_off)
  239. tau_estimate.append(np.mean(tau_L))
  240. conf_int_off = sps.mstats.mquantiles(m_L, prob=[0.025, 0.975],
  241. alphap=0, betap=1, axis=None,
  242. limit=())
  243. conf_int_off = [abs(np.mean(m_L) - ci) for ci in conf_int_off]
  244. m_conf.append(conf_int_off)
  245. m_estimate.append(np.mean(m_L))
  246. windowsize_seconds = [w * 4 / 1000 for w in windowsize]
  247. fig, ax = plt.subplots(1, 1, figsize=((cm2inch(9), cm2inch(0.9 * 4))))
  248. ax.scatter(windowsize_seconds, tau_estimate,
  249. label=r'$\hat{\tau}_{offset}$', c='steelblue',
  250. marker='.', s=9)
  251. ax.errorbar(windowsize_seconds, tau_estimate,
  252. yerr=np.transpose(np.array(tau_conf)),
  253. linestyle="None", c='steelblue', elinewidth=0.9,
  254. capsize=2.5)
  255. ax.set_ylim((-10, 800))
  256. T = [20, 40, 60, 80, 100, 120, 140]
  257. ax.set_xticks(T)
  258. ax.set_xlabel('window size (s)')
  259. ax.set_ylabel(r'$\hat{\tau}$ (ms)')
  260. sns.despine()
  261. plt.subplots_adjust(bottom=None, right=None,
  262. wspace=0.4, hspace=0.3, left=0.15, top=1.2)
  263. figfile = '{}S5_estimates_winsizes_rec_{}_{}.pdf'.format(
  264. resultpath, recording, binning)
  265. plt.savefig(figfile, bbox_inches='tight')
  266. fig, ax = plt.subplots(1, 1, figsize=((cm2inch(9), cm2inch(0.9 * 4))))
  267. ax.scatter(windowsize_seconds, m_estimate,
  268. label=r'$\hat{\tau}_{offset}$', c='royalblue',
  269. marker='.', s=9)
  270. ax.errorbar(windowsize_seconds, m_estimate,
  271. yerr=np.transpose(np.array(m_conf)),
  272. linestyle="None", c='royalblue', elinewidth=0.9,
  273. capsize=2.5)
  274. T = [20, 40, 60, 80, 100, 120, 140]
  275. ax.set_xticks(T)
  276. ax.set_ylim((0.5, 1.25))
  277. ax.set_xlabel('window size (s)')
  278. ax.set_ylabel(r'$\hat{m}$')
  279. sns.despine()
  280. plt.subplots_adjust(bottom=None, right=None,
  281. wspace=0.4, hspace=0.3, left=0.15, top=1.2)
  282. figfile = '{}S5_estimates_m_winsizes_rec_{}_{}.pdf'.format(
  283. resultpath, recording, binning)
  284. plt.savefig(figfile, bbox_inches='tight')
  285. plt.show()
  286. def windowsize_analysis():
  287. binning = 'left'
  288. outputfilename = 'activity_SU_4ms.pkl'
  289. resultpath = '../Results/preictal/singleunits/'
  290. distribution_A_recordings(resultpath)
  291. # warning: analysis below has long computation time
  292. # data_path = '../Data/interictal/'
  293. # recording = '13ref'
  294. # windowsize_analysis_recs(data_path, recording, binning, outputfilename,
  295. # resultpath)
  296. #
  297. # data_path = '../Data/preictal/'
  298. # recording = '13_02'
  299. # windowsize_analysis_recs(data_path, recording, binning, outputfilename,
  300. # resultpath)
  301. def all_mt():
  302. result_path = '../Results/preictal/singleunits/'
  303. ksteps = (1, 400)
  304. kmin = ksteps[0]
  305. kmax=ksteps[1]
  306. windowsize = 20000
  307. windowstep = 500
  308. dt = 4
  309. mt_all_supp(result_path, 'offset', kmin, kmax, windowsize, windowstep, dt)
  310. def example_autocorrelations(patient, pkl_file_inter, pkl_file_pre,
  311. resultpath):
  312. mr_results_pre = pd.read_pickle(pkl_file_pre)
  313. mr_results_inter = pd.read_pickle(pkl_file_inter)
  314. mr_results_pre['rectype'] = 'preictal'
  315. mr_results_inter['rectype'] = 'interictal'
  316. mr_results = pd.concat([mr_results_pre, mr_results_inter],
  317. ignore_index=True)
  318. mr_results = mr_results[mr_results['patient'] == patient]
  319. binning_foci = ['focal', 'contra-lateral']
  320. sns.set(style="white")
  321. fig, ax = plt.subplots(2, 4, figsize=(cm2inch(24), cm2inch(6)))
  322. fit_method = 'offset'
  323. mr_rec_pre = mr_results[mr_results['rectype'] == 'preictal']
  324. pre_recordings_choice = mr_rec_pre.recording.unique()[:3]
  325. for si, soz in enumerate(binning_foci):
  326. # plot interictal in first position
  327. mr_rec_inter = mr_results[mr_results['rectype'] == 'interictal']
  328. mr_rec_inter = mr_rec_inter[mr_rec_inter['binning_focus'] == soz]
  329. rk = mr_rec_inter['rk'].item()[0]
  330. k_steps = mr_rec_inter['rk_steps'].item()
  331. dt = mr_rec_inter['dt'].item()
  332. time_steps = np.array(k_steps[0]) * dt
  333. ax[si, 0].plot(time_steps, rk, '-', linewidth=0.8,
  334. c='slategray', alpha=0.4)
  335. fitfunc = mr_utilities.string_to_func(fit_method)
  336. fitparams = mr_rec_inter['fit_params'].item()
  337. fitparams = tuple(fitparams[0])
  338. m = mr_rec_inter['m'].item()
  339. ax[si, 0].plot(time_steps, fitfunc(time_steps, *fitparams),
  340. label=r'$m$ = {:.2f}'.format(m), color='slategray')
  341. ax[si, 0].tick_params(top='off', bottom='on', left='on',
  342. right='off', labelleft='on', labelbottom='on')
  343. ax[si, 0].set_xlabel(r"time lag $k$ (ms)")
  344. ax[si, 0].legend(loc=1, fontsize='x-small')
  345. ax[si, 0].ticklabel_format(axis='y', style='sci', scilimits=(-1, 1),
  346. useMathText=True)
  347. # plot 3 preictal recordings next to it
  348. mr_rec_pre = mr_results[mr_results['rectype'] == 'preictal']
  349. mr_rec_pre = mr_rec_pre[mr_rec_pre['binning_focus'] == soz]
  350. for irec, rec in enumerate(pre_recordings_choice):
  351. mr_rec = mr_rec_pre[mr_rec_pre['recording'] == rec]
  352. rk = mr_rec['rk'].item()[0]
  353. k_steps = mr_rec['rk_steps'].item()
  354. dt = mr_rec['dt'].item()
  355. time_steps = np.array(k_steps[0]) * dt
  356. ax[si, irec+1].plot(time_steps, rk, '-', linewidth=0.8,
  357. c='darkblue', alpha=0.4)
  358. fitfunc = mr_utilities.string_to_func(fit_method)
  359. fitparams = mr_rec['fit_params'].item()
  360. fitparams = tuple(fitparams[0])
  361. m = mr_rec['m'].item()
  362. ax[si, irec+1].plot(time_steps, fitfunc(time_steps, *fitparams),
  363. label=r'$m$ = {:.2f}'.format(m), color='darkblue')
  364. ax[si, irec+1].tick_params(top='off', bottom='on', left='on',
  365. right='off', labelleft='on',labelbottom='on')
  366. ax[si, irec+1].set_xlabel(r"time lag $k$ (ms)")
  367. ax[si, irec+1].legend(loc=1, fontsize='x-small')
  368. ax[si, irec+1].ticklabel_format(axis='y', style='sci',
  369. scilimits=(-1, 1), useMathText=True)
  370. ax[0, 0].set_ylabel("ACF")
  371. ax[1, 0].set_ylabel("ACF")
  372. plt.subplots_adjust(bottom=None, right=None,
  373. wspace=0.5, hspace=1, left=0.15, top=1.2)
  374. sns.despine(fig)
  375. resultpath_rec = '{}S1_autocorrelations_patient{}.pdf'.format(resultpath, patient)
  376. plt.savefig(resultpath_rec, bbox_inches='tight')
  377. plt.close()
  378. def create_example_ACF_plot():
  379. pkl_file_pre = '../Results/preictal/singleunits/mr_results.pkl'
  380. pkl_file_inter = '../Results/interictal/singleunits/mr_results.pkl'
  381. resultpath = '../Results/preictal/singleunits/'
  382. patient = 12
  383. example_autocorrelations(patient, pkl_file_inter, pkl_file_pre, resultpath)
  384. patient = 13
  385. example_autocorrelations(patient, pkl_file_inter, pkl_file_pre, resultpath)
  386. def main():
  387. create_example_ACF_plot()
  388. windowsize_analysis()
  389. all_mt()
  390. if __name__ == '__main__':
  391. main()