fbplot.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. import matplotlib.pyplot as plt
  2. from matplotlib.lines import Line2D
  3. import numpy as np
  4. from munch import Munch
  5. from .kaux import mergemunch
  6. def plot_df(df, plot_order, plot_t, fb_states, fig=None, options=None):
  7. """
  8. options: Munch dictionary with keys:
  9. show_beginning: 0: don't show beginning of trial trace
  10. 1: mark beginning by a symbol
  11. > 1: show first number of samples with a heavy line
  12. show_end: 0: don't show end of trial trace
  13. 1: mark end by a symbol
  14. > 1: show last number of samples with a heavy line
  15. show_thresholds: T: will use thresholds in fb_states to plot horizontal lines
  16. show_median: T: will show median lines for groups
  17. """
  18. default_options = Munch({'show_beginning': 1, 'show_end': 2, 'show_thresholds': True, 'show_median':True})
  19. if options is None:
  20. options = default_options
  21. else:
  22. options = Munch(mergemunch(default_options, options))
  23. if fig is None:
  24. fig = plt.figure(34, figsize=(16, 4))
  25. fig.clf()
  26. n_rows = np.max([x.r for x in plot_order.p]) + 1
  27. n_c = np.max([x.c for x in plot_order.p]) + 1
  28. n_cols = len(plot_order.s) * n_c
  29. axs = fig.subplots(n_rows, n_cols, False, True, squeeze=False)
  30. for i, s in enumerate(plot_order.s):
  31. good_sample_idx = f'{s.dcol}_good'
  32. xtr_n = plot_t[s.dcol].twin
  33. xtr_n = xtr_n - xtr_n[0]
  34. xtr_n = xtr_n / xtr_n[-1]
  35. i_c = i
  36. i_r = 0
  37. if s.dcol == 'stimulus_start_samples':
  38. stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
  39. # rect = patches.Rectangle((0.0,0.0), stim_end_t, 1, linewidth=1, edgecolor='none', facecolor=(.01,.01,.01,.14))
  40. # axs[i_c].add_patch(rect)
  41. for (st, vals) in fb_states.items():
  42. axs[i_r, i_c].plot([0, stim_end_t], [vals[0], vals[0]], lw=4, ls='-', c=plot_order.state[st].c,
  43. alpha=.99, clip_on=False)
  44. axs[i_r, i_c].text(0, vals[0] + .02, plot_order.state[st].d, c=plot_order.state[st].c, alpha=1)
  45. for jj, pp in enumerate(plot_order.p):
  46. pdd = df[pp.sel(df)]
  47. if len(pdd) == 0:
  48. continue
  49. i_c = i + pp.c * len(plot_order.s)
  50. i_r = pp.r
  51. axs[i_r, i_c].spines['top'].set_visible(False)
  52. axs[i_r, i_c].spines['right'].set_visible(False)
  53. # if s.dcol == 'stimulus_start_samples':
  54. # stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
  55. #
  56. # axs[i_r, i_c].plot([0, stim_end_t], [fb_states[pp.t][0], fb_states[pp.t][0]], lw=4, ls='-',
  57. # c=plot_order.state[pp.t].c, alpha=.99, clip_on=False)
  58. # axs[i_r, i_c].text(0, fb_states[pp.t][0] + .02, plot_order.state[pp.t].d, c=plot_order.state[pp.t].c,
  59. # alpha=1)
  60. for ix, pd_row in pdd.iterrows():
  61. if pd_row[s.dcol] is None:
  62. continue
  63. y = pd_row[s.dcol][pd_row[good_sample_idx]]
  64. xtr = plot_t[s.dcol].twin[pd_row[good_sample_idx]]
  65. if s.norm:
  66. xtr = xtr - xtr[0]
  67. xtr = xtr / xtr[-1]
  68. axs[i_r, i_c].plot(xtr, y, c=pp.col, alpha=.3, lw=.8)
  69. if options.show_beginning == 1:
  70. axs[i_r, i_c].plot(xtr[0], y[0], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
  71. elif options.show_beginning > 1:
  72. axs[i_r, i_c].plot(xtr[:(options.show_beginning - 1)], y[:(options.show_beginning - 1)], c=pp.col, alpha=.8, lw=1.5)
  73. if options.show_end == 1:
  74. axs[i_r, i_c].plot(xtr[-1], y[-1], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
  75. elif options.show_end > 1:
  76. axs[i_r, i_c].plot(xtr[(-options.show_end ):], y[(-options.show_end ):], c=pp.col, alpha=.8, lw=1.5)
  77. # axs[i_r, i_c].plot(xtr_n, pd_row['response_start_samples_norm'], c=plot_order.cols[i_col], alpha=.3, lw=1)
  78. if options.show_median and not pdd[s.dcol].empty:
  79. if s.norm:
  80. all_traces = np.vstack(pdd[f'{s.dcol}_norm'].to_numpy())
  81. mean_tr = np.median(all_traces, 0)
  82. axs[i_r, i_c].plot(xtr_n, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls)
  83. else:
  84. all_traces = np.vstack(pdd[s.dcol].to_numpy())
  85. mean_tr = np.nanmedian(all_traces, 0)
  86. axs[i_r, i_c].plot(plot_t[s.dcol].twin, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls)
  87. axs[i_r, i_c].set_title(f'{s.title} {pp.desc} n={len(pdd)}')
  88. if options.show_thresholds:
  89. thr = fb_states.up
  90. axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[1], thr[1]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
  91. thr = fb_states.down
  92. axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[2], thr[2]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
  93. axs[i_r, i_c].set_xlim(plot_t[s.dcol].t)
  94. return fig
  95. def plot_df_combined(df, plot_order, plot_t, fb_states, fig=None, options=None):
  96. """
  97. options: Munch dictionary with keys:
  98. show_beginning: 0: don't show beginning of trial trace
  99. 1: mark beginning by a symbol
  100. > 1: show first number of samples with a heavy line
  101. show_end: 0: don't show end of trial trace
  102. 1: mark end by a symbol
  103. > 1: show last number of samples with a heavy line
  104. show_thresholds: T: will use thresholds in fb_states to plot horizontal lines
  105. show_median: T: will show median lines for groups
  106. """
  107. default_options = Munch({'show_beginning': 1, 'show_end': 2, 'show_thresholds': True, 'show_median':True})
  108. if options is None:
  109. options = default_options
  110. else:
  111. options = Munch(mergemunch(default_options, options))
  112. if fig is None:
  113. fig = plt.figure(35, figsize=(16, 4))
  114. fig.clf()
  115. n_cols = len(plot_order.s)
  116. axs = fig.subplots(1, n_cols, False, True, squeeze=False)
  117. for i, s in enumerate(plot_order.s):
  118. good_sample_idx = f'{s.dcol}_good'
  119. xtr_n = plot_t[s.dcol].twin
  120. xtr_n = xtr_n - xtr_n[0]
  121. xtr_n = xtr_n / xtr_n[-1]
  122. i_c = i
  123. i_r = 0
  124. axs[i_r, i_c].set_title(f'{s.title}')
  125. if options.show_thresholds:
  126. thr = fb_states.up
  127. axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[1], thr[1]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
  128. thr = fb_states.down
  129. axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[2], thr[2]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
  130. axs[i_r, i_c].set_xlim(plot_t[s.dcol].t)
  131. axs[i_r, i_c].spines['top'].set_visible(False)
  132. axs[i_r, i_c].spines['right'].set_visible(False)
  133. for jj, pp in enumerate(plot_order.p):
  134. pdd = df[pp.sel(df)]
  135. if s.dcol == 'stimulus_start_samples':
  136. stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
  137. axs[i_r, i_c].plot([0, stim_end_t], [fb_states[pp.t][0], fb_states[pp.t][0]], lw=4, ls='-',
  138. c=plot_order.state[pp.t].c, alpha=.99, clip_on=False)
  139. axs[i_r, i_c].text(0, fb_states[pp.t][0] + .02, plot_order.state[pp.t].d, c=plot_order.state[pp.t].c,
  140. alpha=1)
  141. my_label = f"{pp.desc} (n = {len(pdd)})"
  142. for ix, pd_row in pdd.iterrows():
  143. y = pd_row[s.dcol][pd_row[good_sample_idx]]
  144. xtr = plot_t[s.dcol].twin[pd_row[good_sample_idx]]
  145. if s.norm:
  146. xtr = xtr - xtr[0]
  147. xtr = xtr / xtr[-1]
  148. axs[i_r, i_c].plot(xtr, y, c=pp.col, alpha=.3, lw=.8, ls=pp.ls, label=my_label)
  149. my_label = None
  150. if options.show_beginning == 1:
  151. axs[i_r, i_c].plot(xtr[0], y[0], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
  152. elif options.show_beginning > 1:
  153. axs[i_r, i_c].plot(xtr[:(options.show_beginning - 1)], y[:(options.show_beginning - 1)], c=pp.col, alpha=.8, lw=1.5)
  154. if options.show_end == 1:
  155. axs[i_r, i_c].plot(xtr[-1], y[-1], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
  156. elif options.show_end > 1:
  157. axs[i_r, i_c].plot(xtr[(-options.show_end ):], y[(-options.show_end ):], c=pp.col, alpha=.8, lw=1.5)
  158. # axs[i_r, i_c].plot(xtr_n, pd_row['response_start_samples_norm'], c=plot_order.cols[i_col], alpha=.3, lw=1)
  159. if options.show_median and not pdd[s.dcol].empty:
  160. if s.norm:
  161. all_traces = np.vstack(pdd[f'{s.dcol}_norm'].to_numpy())
  162. mean_tr = np.median(all_traces, 0)
  163. axs[i_r, i_c].plot(xtr_n, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls)
  164. else:
  165. all_traces = np.vstack(pdd[s.dcol].to_numpy())
  166. mean_tr = np.nanmedian(all_traces, 0)
  167. axs[i_r, i_c].plot(plot_t[s.dcol].twin, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls, label=pp.desc)
  168. axs[i_r, i_c].legend()
  169. return fig
  170. def plot_df_avg(df, plot_order, plot_t, fb_states, fig=None, show_stimulus_start=False):
  171. if fig is None:
  172. fig = plt.figure(35, figsize=(16, 4))
  173. fig.clf()
  174. n_rows = 1
  175. # n_c = len(plot_order.c)
  176. n_cols = len(plot_order.s)
  177. axs = fig.subplots(n_rows, n_cols, False, True, squeeze=False)
  178. i_r = 0
  179. for i, s in enumerate(plot_order.s):
  180. good_sample_idx = f'{s.dcol}_good'
  181. xtr_n = plot_t[s.dcol].twin
  182. xtr_n = xtr_n - xtr_n[0]
  183. xtr_n = xtr_n / xtr_n[-1]
  184. i_c = i
  185. i_r = 0
  186. axs[i_r, i_c].spines['top'].set_visible(False)
  187. axs[i_r, i_c].spines['right'].set_visible(False)
  188. if s.dcol == 'stimulus_start_samples':
  189. stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
  190. # rect = patches.Rectangle((0.0,0.0), stim_end_t, 1, linewidth=1, edgecolor='none', facecolor=(.01,.01,.01,.14))
  191. # axs[i_c].add_patch(rect)
  192. for (st, vals) in fb_states.items():
  193. axs[i_r, i_c].plot([0, stim_end_t], [vals[0], vals[0]], lw=4, ls='-', c=plot_order.state[st].c,
  194. alpha=.99, clip_on=False)
  195. axs[i_r, i_c].text(0, vals[0] + .02, plot_order.state[st].d, c=plot_order.state[st].c, alpha=1)
  196. for jj, pp in enumerate(plot_order.p):
  197. pdd = df[pp.sel(df)]
  198. i_r = 0
  199. if not pdd[s.dcol].empty:
  200. if s.norm:
  201. all_traces = np.vstack(pdd[f'{s.dcol}_norm'].to_numpy())
  202. avg_x = xtr_n
  203. else:
  204. all_traces = np.vstack(pdd[s.dcol].to_numpy())
  205. avg_x = plot_t[s.dcol].twin
  206. mean_tr = np.nanmedian(all_traces, 0)
  207. perct_tr = np.nanpercentile(all_traces, [25, 75], axis=0)
  208. axs[i_r, i_c].plot(avg_x, mean_tr, c=pp.col, alpha=1, lw=pp.lw, clip_on=False, ls=pp.ls)
  209. if pp.show_var:
  210. axs[i_r, i_c].plot(avg_x, perct_tr[0, :], c=pp.col, alpha=.1, lw=pp.lw, clip_on=False, ls=pp.ls)
  211. axs[i_r, i_c].plot(avg_x, perct_tr[1, :], c=pp.col, alpha=.1, lw=pp.lw, clip_on=False, ls=pp.ls)
  212. axs[i_r, i_c].fill_between(avg_x, perct_tr[0, :], perct_tr[1, :],
  213. alpha=0.20, facecolor=pp.col, edgecolor='none', ls=pp.ls, lw=1,
  214. antialiased=True, clip_on=False)
  215. if show_stimulus_start and s.dcol == 'stimulus_stop_samples':
  216. stim_start_offset = (pdd.stimulus_start - pdd.stimulus_stop) / 3e4
  217. for x in stim_start_offset:
  218. axs[i_r, i_c].axvline(x, c=[0, 0, 0], alpha=.1)
  219. thr = fb_states.up
  220. # axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[1], thr[1]], linestyle='--', c="#333333")
  221. thr = fb_states.down
  222. # axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[2], thr[2]], linestyle='--', c="#333333")
  223. # axs[i_c].set_ylim([0,1])
  224. if s.norm:
  225. axs[i_r, i_c].set_xlim([0, 1])
  226. axs[i_r, i_c].set_xlabel('t (normalized)')
  227. else:
  228. axs[i_r, i_c].set_xlim(plot_t[s.dcol].t)
  229. axs[i_r, i_c].set_xlabel('t [s]')
  230. if i_c == 0:
  231. axs[i_r, i_c].set_ylabel('normalized neural activity')
  232. axs[i_r, i_c].set_title(f'{s.title}')
  233. custom_lines = [Line2D([0], [0], color=(.3, .3, .3), lw=2, ls='-'),
  234. Line2D([0], [0], color=(.3, .3, .3), lw=2, ls=':')]
  235. axs[i_r, n_cols - 1].legend(custom_lines, ['Correct Trials', 'Error Trials'])
  236. fig.show()
  237. return fig