"""Plot speller progress for one KIAP BCI speller session.

The script loads the recorded session data, the event file, the speller
debug log and the dumped configuration of a single session, derives decoder
decisions and stimulus/response periods from the event file, and renders one
figure per time window: raw firing rates of the channels used for
normalization on top, the normalized rate with thresholds and speller events
below.  Each window is saved as PDF/SVG/EPS together with CSV dumps of the
plotted data.
"""
import os
from typing import List, Union
from datetime import datetime as dt
from helpers import data_management as dm
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches  # used below; do not rely on pyplot importing it implicitly
import pandas as pd
import numpy as np
import yaml
from helpers.data import DataNormalizer
import re
from helpers.nsp import fix_timestamps
from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE

# Timestamps are divided by this value to obtain the seconds shown on the time axis.
TICKS_PER_SECOND = 3e4

# Colours shared by the threshold lines and the speller decision markers.
YES_COLOR = (0, .6, .1)
NO_COLOR = (.9, 0, .1)
OTHER_COLOR = (.7, .7, .7)
DECISION_COLORS = {'yes': YES_COLOR, 'no': NO_COLOR}

# (key in evlist, label written to the periods CSV, patch colour)
PERIOD_STYLES = (
    ('st', 'stimulus', (198.0/255.0, 198.0/255.0, 198.0/255.0)),
    ('re', 'response', (87.0/255.0, 46.0/255.0, 136.0/255.0)),
)

# Default window length (seconds) when the session does not define 'plot_win'.
plot_win_len = 120.0

# Session description: file name parts, recording time limits ('s'/'e'),
# figure title and the plotted time range / window length in seconds.
s = {'day': '2019-07-05', 'cfg_t': '14_42_36', 'data_t': '14_42_36',
     's': np.datetime64('2019-07-05T14:40:00'), 'e': np.datetime64('2019-07-05T15:11:45'),
     'title_str': 'KIAP Session day 108 14:42 Free Speller',
     'plot_start': 980, 'plot_win': 90, 'plot_end': 1070}

fn_sess = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'data_{s["data_t"]}.bin'
fn_evs = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'events_{s["data_t"]}.txt'
fn_spl = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'debug_{s["cfg_t"]}.log'
fn_cfg = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'config_dump_{s["cfg_t"]}.yaml'

s_dt = dt.strptime(s["day"], '%Y-%m-%d')
days_post_implant = (s_dt - IMPLANT_DATE).days

# SECURITY NOTE: yaml.Loader can construct arbitrary Python objects, so only
# load configuration dumps from trusted sources.
# NOTE(review): params is accessed via attributes below, so the dump presumably
# needs object construction and yaml.safe_load may not work - confirm.
with open(fn_cfg) as stream:
    params = yaml.load(stream, Loader=yaml.Loader)

datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params,
                                        t_lim_start=s.get('s'), t_lim_end=s.get('e'))
ts, offsets, _ = fix_timestamps(ts)
tv = ts / TICKS_PER_SECOND

with open(fn_evs, 'r') as f:
    evs = f.read().splitlines()

with open(fn_spl, 'r') as f:
    splevs = f.read().splitlines()

# --- fix event timestamps ---------------------------------------------------
# Keep only lines that start with an integer timestamp followed by ", ...".
# Blank or malformed lines are dropped here so that `evs` and `evs_time`
# stay aligned and the parsers below never dereference a failed match.
tpat = re.compile(r"^(\d+)(, .*)$")
timed_evs = []
evs_time = []
for ev in evs:
    m = tpat.match(ev)
    if m is None:
        continue
    timed_evs.append(ev)
    evs_time.append(np.int64(m.group(1)))
evs = timed_evs
evs_time = np.asarray(evs_time, dtype=np.int64)

# Wherever the raw event timestamps decrease, shift all later events by the
# matching offset returned by fix_timestamps.
# NOTE(review): assumes `offsets` has one entry per decrease, in order - confirm
# against fix_timestamps.
rollev, = np.where(np.diff(evs_time) < 0)
for j, i in enumerate(rollev):
    evs_time[(i + 1):] += offsets[j]

# --- parse decoder decisions ------------------------------------------------
lpat = re.compile(r"^(\d+), b'(.*)'$")
decpat = re.compile(r".*Decoder decision is: (.*)$")
alt_resp_pat = re.compile(r".*, response, stop$")

# (time in seconds, payload) for every event line of the form "<ticks>, b'<payload>'".
ev_payloads = []
for ev, ev_time in zip(evs, evs_time):
    m = lpat.match(ev)
    if m is not None:
        ev_payloads.append((float(ev_time) / TICKS_PER_SECOND, m.group(2)))

pevs = []
for t_ev, evstr in ev_payloads:
    m2 = decpat.match(evstr)
    if m2 is not None:
        pevs.append([t_ev, m2.group(1)])

# Fall back to the end of each response period when the event file contains
# no explicit decoder decisions.
if not pevs:
    pevs = [[t_ev, 'decision'] for t_ev, evstr in ev_payloads
            if alt_resp_pat.match(evstr) is not None]

# --- parse baseline / stimulus / response periods -----------------------------
# evlist[key] collects one list of timestamps (seconds) per period; a period is
# closed and stored when its 'stop' event is seen.
evlist = {'bl': [], 'st': [], 're': []}
rp_ev_pat = re.compile(r"^(\d+), b'.*(stimulus|response|baseline), (start|stop)'$")
period_keys = {'baseline': 'bl', 'stimulus': 'st', 'response': 're'}
open_periods = {key: [] for key in evlist}
for ev_time, ev in zip(evs_time, evs):
    m = rp_ev_pat.match(ev)
    if m is None:
        continue
    key = period_keys[m.group(2)]
    open_periods[key].append(float(ev_time) / TICKS_PER_SECOND)
    if m.group(3) == 'stop':
        evlist[key].append(open_periods[key])
        open_periods[key] = []

# --- parse speller log ------------------------------------------------------
# Each entry is [text, option, event], in the order used for the CSV columns
# 'txt', 'option' and 'event' below.
splpevs = []
spl_pat = re.compile(r"^.* (\w+) - \('(.*)', '(.*)'\)")
for sev in splevs:
    m = spl_pat.match(sev)
    if m is not None:
        splpevs.append([m.group(2), m.group(3), m.group(1)])

dn = DataNormalizer(params)
fr = dn.calculate_norm_rate(datav)
br_chnum = np.array(dn.norm_rate['ch_ids']) + 1  # channel numbers shown in the legend

# The figure title ends with the first field of the last speller log entry;
# without any parsed entry only the session title is shown.
fig_title = s["title_str"]
if splpevs:
    fig_title = f'{s["title_str"]} – “{splpevs[-1][0]}”'

# --- plot one figure per time window ------------------------------------------
i_plt = 0
p_start = s.get('plot_start', tv[0])
p_end = min(s.get('plot_end', tv[-1]), tv[-1])
p_step = s.get('plot_win', plot_win_len)
while p_start + i_plt * p_step < p_end:
    i_plt += 1
    t_win = np.array([p_start + (i_plt - 1) * p_step, min(p_end, p_start + i_plt * p_step)])
    pl_idx = np.logical_and(t_win[0] <= tv, tv <= t_win[1])

    # The same figure number is reused and cleared for every window.
    fig = plt.figure(34, figsize=(12, 6.22))
    fig.clf()
    fig.set_tight_layout(True)
    ax_raw = fig.add_axes((.04, .65, .93, .21))
    ax_norm = fig.add_axes((.04, .1, .93, .35))

    # Top panel: raw rates of the channels listed in dn.norm_rate['ch_ids'].
    raw_win = datav[pl_idx, :][:, dn.norm_rate['ch_ids']]
    raw_phs = ax_raw.plot(tv[pl_idx], raw_win)
    df_raw_data = pd.DataFrame(index=tv[pl_idx], data=raw_win, columns=dn.norm_rate['ch_ids'])
    df_raw_data['normalized_Rate'] = fr[pl_idx]
    for ph, ch in zip(raw_phs, br_chnum):
        ph.set_label(f'Channel {ch}')
    ax_raw.set_ylabel('Firing Rate [Hz]')
    ax_raw.legend(loc='upper right', ncol=2, fontsize='small')
    ax_raw.set_title('Raw firing rates of channels used for “yes”/“no” classification')

    # Bottom panel: normalized rate with both decision thresholds.
    ax_norm.plot(tv[pl_idx], fr[pl_idx], label='Normalized Rate',
                 color=(157.0/255, 157.0/255, 157.0/255))
    thr_top = params.paradigms.feedback.states.up[1]
    thr_bot = params.paradigms.feedback.states.down[2]
    ax_norm.hlines(thr_top, t_win[0], t_win[1], color=YES_COLOR, label='“Yes” threshold')
    ax_norm.hlines(thr_bot, t_win[0], t_win[1], color=NO_COLOR, label='“No” threshold')

    # Decoder decisions are paired with speller log entries by position.
    plot_speller_events = {'t': [], 'event': [], 'option': [], 'txt': []}
    decision: List[Union[float, str]]
    for decision, sp in zip(pevs, splpevs):
        t_dec = decision[0]
        if not (t_win[0] <= t_dec <= t_win[1]):
            continue
        # Long texts are shortened to an ellipsis plus their last four characters.
        cur_spel_st = sp[0]
        if len(cur_spel_st) > 5:
            cur_spel_st = "…" + sp[0][-4:]
        spel_ev_txt = f'{cur_spel_st}\n“{sp[1]}”\n{sp[2]}'
        plot_speller_events['t'].append(t_dec)
        plot_speller_events['event'].append(sp[2])
        plot_speller_events['option'].append(sp[1])
        plot_speller_events['txt'].append(cur_spel_st)

        ev_color = DECISION_COLORS.get(sp[2], OTHER_COLOR)
        ax_norm.vlines(t_dec, 0, 1, color=ev_color, linestyles=':')
        ax_norm.text(t_dec, 1, spel_ev_txt, horizontalalignment='center', va='bottom',
                     color=ev_color)
    df_speller_events = pd.DataFrame(plot_speller_events)

    # Shade stimulus and response periods that lie completely inside the
    # window.  Baseline periods (evlist['bl']) are parsed above but not drawn.
    sp_periods = {'t_begin': [], 't_end': [], 'event': []}
    for ev_key, ev_label, patch_color in PERIOD_STYLES:
        for period in evlist[ev_key]:
            if len(period) < 2:
                # A 'stop' without a preceding 'start' has no extent to draw.
                continue
            t_begin, t_end = period[0], period[1]
            if t_win[0] <= t_begin <= t_win[1] and t_win[0] <= t_end <= t_win[1]:
                ax_norm.add_patch(matplotlib.patches.Rectangle(
                    (t_begin, 0), t_end - t_begin, 1, color=patch_color, alpha=.15))
                sp_periods['t_begin'].append(t_begin)
                sp_periods['t_end'].append(t_end)
                sp_periods['event'].append(ev_label)
    df_speller_periods = pd.DataFrame(sp_periods)

    ax_norm.set_xlabel('Time [s]')
    ax_norm.set_ylabel('Normalized Firing Rate')
    ax_norm.legend(fontsize='small', loc='upper right')
    ax_norm.set_title("Normalized rate and speller state", y=1.3)

    # Common axis styling, applied after all artists have been added: every
    # window spans p_step seconds on the x axis, even if the data end earlier.
    for ax in (ax_raw, ax_norm):
        ax.set_xlim(t_win[0], t_win[0] + p_step)
        ax.spines['right'].set_color('none')
        ax.spines['top'].set_color('none')
        ax.spines['left'].set_position(('data', t_win[0] - 1))
        ax.spines['bottom'].set_position('zero')

    fig.suptitle(fig_title, fontsize=15, wrap=True)
    fig.show()

    # Save the figure in all vector formats and dump the plotted data as CSV.
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    savename = BASE_PATH_OUT / f'Figure_4_SpellerProgress_plt{i_plt}'
    for fig_suffix in ('.pdf', '.svg', '.eps'):
        fig.savefig(savename.with_suffix(fig_suffix), transparent=True)
    df_raw_data.to_csv(savename.with_suffix('.raw.csv'))
    df_speller_events.to_csv(savename.with_suffix('.events.csv'))
    df_speller_periods.to_csv(savename.with_suffix('.periods.csv'))