plot_figures_part_C.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import os
  2. from typing import List, Union
  3. from datetime import datetime as dt
  4. from helpers import data_management as dm
  5. import matplotlib.pyplot as plt
  6. import matplotlib
  7. import pandas as pd
  8. import numpy as np
  9. import yaml
  10. from helpers.data import DataNormalizer
  11. import re
  12. from helpers.nsp import fix_timestamps
  13. from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE
  14. plot_win_len = 120.0
  15. s = {'day': '2019-07-05', 'cfg_t': '14_42_36', 'data_t': '14_42_36', 's': np.datetime64('2019-07-05T14:40:00'),
  16. 'e': np.datetime64('2019-07-05T15:11:45'), 'title_str':'KIAP Session day 108 14:42 Free Speller',
  17. 'plot_start': 980, 'plot_win': 90, 'plot_end': 1070}
  18. fn_sess = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'data_{s["data_t"]}.bin'
  19. fn_evs = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'events_{s["data_t"]}.txt'
  20. fn_spl = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'debug_{s["cfg_t"]}.log'
  21. fn_cfg = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'config_dump_{s["cfg_t"]}.yaml'
  22. s_dt = dt.strptime(s["day"], '%Y-%m-%d')
  23. days_post_implant = (s_dt - IMPLANT_DATE).days
  24. with open(fn_cfg) as stream:
  25. params = yaml.load(stream, Loader=yaml.Loader)
  26. datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params, t_lim_start=s.get('s'), t_lim_end=s.get('e'))
  27. ts, offsets,_ = fix_timestamps(ts)
  28. tv = ts / 3e4
  29. with open(fn_evs, 'r') as f:
  30. evs = f.read().splitlines()
  31. with open(fn_spl, 'r') as f:
  32. splevs = f.read().splitlines()
  33. # fix event timestamps
  34. tpat = re.compile(r"^(\d+)(, .*)$")
  35. evs_time = []
  36. for ev in evs:
  37. m = tpat.match(ev)
  38. evs_time.append(np.int64(m.group(1)))
  39. evs_time = np.asarray(evs_time, dtype=np.int64)
  40. rollev, = np.where(np.diff(evs_time) < 0)
  41. for j, i in enumerate(rollev):
  42. evs_time[(i + 1):] += offsets[j]
  43. # parse decoder decisions
  44. pevs = []
  45. lpat = re.compile(r"^(\d+), b'(.*)'$")
  46. decpat = re.compile(r".*Decoder decision is: (.*)$")
  47. alt_resp_pat = re.compile(r".*, response, stop$")
  48. for ev, ev_time in zip(evs, evs_time):
  49. m = lpat.match(ev)
  50. # evts = m.group(1)
  51. evstr = m.group(2)
  52. m2 = decpat.match(evstr)
  53. if m2 is not None:
  54. pevs.append([float(ev_time) / 3e4, m2.group(1)])
  55. if len(pevs) == 0:
  56. for ev, ev_time in zip(evs, evs_time):
  57. m = lpat.match(ev)
  58. # evts = m.group(1)
  59. evstr = m.group(2)
  60. m2 = alt_resp_pat.match(evstr)
  61. if m2 is not None:
  62. pevs.append([float(ev_time) / 3e4, 'decision'])
  63. # parse events
  64. evlist = {}
  65. evlist['bl'] = []
  66. evlist['st'] = []
  67. evlist['re'] = []
  68. rp_ev_pat = re.compile(r"^(\d+), b'.*(stimulus|response|baseline), (start|stop)'$")
  69. # rp_rst_pat = re.compile(r"^(\d+), b'.*response, start'$")
  70. tmp_itb = []
  71. tmp_its = []
  72. tmp_itr = []
  73. for ev_time, ev in zip(evs_time, evs):
  74. m = rp_ev_pat.match(ev)
  75. if m is not None:
  76. if m.group(2) == 'baseline':
  77. if m.group(3) == 'start':
  78. tmp_itb.append(float(ev_time) / 3e4)
  79. continue
  80. elif m.group(3) == 'stop':
  81. tmp_itb.append(float(ev_time) / 3e4)
  82. evlist['bl'].append(tmp_itb)
  83. tmp_itb = []
  84. continue
  85. continue
  86. elif m.group(2) == 'stimulus':
  87. if m.group(3) == 'start':
  88. tmp_its.append(float(ev_time) / 3e4)
  89. continue
  90. elif m.group(3) == 'stop':
  91. tmp_its.append(float(ev_time) / 3e4)
  92. evlist['st'].append(tmp_its)
  93. tmp_its = []
  94. continue
  95. continue
  96. elif m.group(2) == 'response':
  97. if m.group(3) == 'start':
  98. tmp_itr.append(float(ev_time) / 3e4)
  99. continue
  100. elif m.group(3) == 'stop':
  101. tmp_itr.append(float(ev_time) / 3e4)
  102. evlist['re'].append(tmp_itr)
  103. tmp_itr = []
  104. continue
  105. continue
  106. continue
  107. splpevs = []
  108. spl_pat = re.compile(r"^.* (\w+) - \('(.*)', '(.*)'\)")
  109. for sev in splevs:
  110. m = spl_pat.match(sev)
  111. if m is not None:
  112. splpevs.append([m.group(2), m.group(3), m.group(1)])
  113. dn = DataNormalizer(params)
  114. fr = dn.calculate_norm_rate(datav)
  115. br_chnum = np.array(dn.norm_rate['ch_ids']) + 1
  116. i_plt = 0
  117. p_start = s.get('plot_start', tv[0])
  118. p_end = min(s.get('plot_end', tv[-1]), tv[-1])
  119. p_step = s.get('plot_win', plot_win_len)
  120. while p_start + i_plt * p_step < p_end:
  121. i_plt += 1
  122. t_win = np.array([p_start + (i_plt - 1) * p_step, min(p_end, p_start + i_plt * p_step)])
  123. pl_idx = np.logical_and(t_win[0] <= tv, tv <= t_win[1])
  124. fig = plt.figure(34, figsize=(12, 6.22))
  125. fig.clf()
  126. fig.set_tight_layout(True)
  127. axs = [None, None] # fig.subplots(2,1)
  128. axs[0] = fig.add_axes((.04, .65, .93, .21))
  129. axs[1] = fig.add_axes((.04, .1, .93, .35))
  130. raw_phs = axs[0].plot(tv[pl_idx], datav[pl_idx, :][:, dn.norm_rate['ch_ids']])
  131. df_raw_data = pd.DataFrame(index=tv[pl_idx], data=datav[pl_idx, :][:, dn.norm_rate['ch_ids']], columns=dn.norm_rate['ch_ids'])
  132. df_raw_data['normalized_Rate'] = fr[pl_idx]
  133. for ph, ch in zip(raw_phs, br_chnum):
  134. ph.set_label(f'Channel {ch}')
  135. axs[0].set_xlim(t_win[0], t_win[0] + p_step)
  136. axs[0].spines['right'].set_color('none')
  137. axs[0].spines['top'].set_color('none')
  138. axs[0].spines['left'].set_position(('data', t_win[0] - 1))
  139. axs[0].spines['bottom'].set_position('zero')
  140. axs[0].set_ylabel('Firing Rate [Hz]')
  141. axs[0].legend(loc='upper right', ncol=2, fontsize='small')
  142. axs[0].set_title('Raw firing rates of channels used for “yes”/“no” classification')
  143. frph, = axs[1].plot(tv[pl_idx], fr[pl_idx], label='Normalized Rate', color=(157.0/255, 157.0/255, 157.0/255))
  144. thr_top = params.paradigms.feedback.states.up[1]
  145. thr_bot = params.paradigms.feedback.states.down[2]
  146. topl = axs[1].hlines(thr_top, t_win[0], t_win[1], color=(0, .6, .1), label='“Yes” threshold')
  147. botl = axs[1].hlines(thr_bot, t_win[0], t_win[1], color=(.9, 0, .1), label='“No” threshold')
  148. plot_speller_events = {'t': [], 'event': [], 'option': [], 'txt': []}
  149. p: List[Union[float, str]]
  150. for p, sp in zip(pevs, splpevs):
  151. if t_win[0] <= p[0] <= t_win[1]:
  152. cur_spel_st = sp[0]
  153. if len(cur_spel_st) > 5:
  154. cur_spel_st = "…" + sp[0][-4:]
  155. spel_ev_txt = f'{cur_spel_st}\n“{sp[1]}”\n{sp[2]}'
  156. plot_speller_events['t'].append(p[0])
  157. plot_speller_events['event'].append(sp[2])
  158. plot_speller_events['option'].append(sp[1])
  159. plot_speller_events['txt'].append(cur_spel_st)
  160. if sp[2] == 'yes':
  161. axs[1].vlines(p[0], 0, 1, color=(0, .6, .1), linestyles=':')
  162. axs[1].text(p[0], 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=(0, .6, .1))
  163. elif sp[2] == 'no':
  164. axs[1].vlines(p[0], 0, 1, color=(.9, 0, .1), linestyles=':')
  165. axs[1].text(p[0], 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=(.9, 0, .1))
  166. else:
  167. axs[1].vlines(p[0], 0, 1, color=(.7, .7, .7), linestyles=':')
  168. axs[1].text(p[0], 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=(.7, .7, .7))
  169. df_speller_events = pd.DataFrame(plot_speller_events)
  170. sp_periods = {'t_begin': [], 't_end': [], 'event': []}
  171. # for p in evlist['bl']:
  172. # axs[1].add_patch(matplotlib.patches.Rectangle((p[0],0), p[1]-p[0], 1, color=(.9, .9, .9)))
  173. for p in evlist['st']:
  174. if t_win[0] <= p[0] <= t_win[1] and t_win[0] <= p[1] <= t_win[1]:
  175. axs[1].add_patch(matplotlib.patches.Rectangle((p[0], 0), p[1] - p[0], 1, color=(198.0/255.0, 198.0/255.0, 198.0/255.0), alpha=.15))
  176. sp_periods['t_begin'].append(p[0])
  177. sp_periods['t_end'].append(p[1])
  178. sp_periods['event'].append('stimulus')
  179. for p in evlist['re']:
  180. if t_win[0] <= p[0] <= t_win[1] and t_win[0] <= p[1] <= t_win[1]:
  181. axs[1].add_patch(matplotlib.patches.Rectangle((p[0], 0), p[1] - p[0], 1, color=(87.0/255.0, 46.0/255.0, 136.0/255.0), alpha=.15))
  182. sp_periods['t_begin'].append(p[0])
  183. sp_periods['t_end'].append(p[1])
  184. sp_periods['event'].append('response')
  185. df_speller_periods = pd.DataFrame(sp_periods)
  186. axs[1].set_xlim(t_win[0], t_win[0] + p_step)
  187. axs[1].set_xlabel('Time [s]')
  188. axs[1].set_ylabel('Normalized Firing Rate')
  189. axs[1].legend(fontsize='small', loc='upper right')
  190. axs[1].set_title("Normalized rate and speller state", y=1.3)
  191. axs[1].spines['right'].set_color('none')
  192. axs[1].spines['top'].set_color('none')
  193. axs[1].spines['left'].set_position(('data', t_win[0] - 1))
  194. axs[1].spines['bottom'].set_position('zero')
  195. fig.suptitle(f'{s["title_str"]} – “{splpevs[-1][0]}”', fontsize=15, wrap=True)#, fontweight='medium')
  196. fig.show()
  197. BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
  198. savename = BASE_PATH_OUT / f'Figure_4_SpellerProgress_plt{i_plt}'
  199. fig.savefig(savename.with_suffix('.pdf'), transparent=True)
  200. fig.savefig(savename.with_suffix('.svg'), transparent=True)
  201. fig.savefig(savename.with_suffix('.eps'), transparent=True)
  202. df_raw_data.to_csv(savename.with_suffix('.raw.csv'))
  203. df_speller_events.to_csv(savename.with_suffix('.events.csv'))
  204. df_speller_periods.to_csv(savename.with_suffix('.periods.csv'))