123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241 |
- import os
- from typing import List, Union
- from datetime import datetime as dt
- from helpers import data_management as dm
- import matplotlib.pyplot as plt
- import matplotlib
- import pandas as pd
- import numpy as np
- import yaml
- from helpers.data import DataNormalizer
- import re
- from helpers.nsp import fix_timestamps
- from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE
# Default plot window length in seconds (the x axis is labelled 'Time [s]'); used when
# the session entry below has no 'plot_win'.
plot_win_len = 120.0

# Session to analyse. 'day' / 'cfg_t' / 'data_t' select the files below, 's' / 'e' are
# passed to dm.get_session as t_lim_start / t_lim_end, 'title_str' heads every figure and
# 'plot_start' / 'plot_win' / 'plot_end' set the plotted time range.
s = dict(
    day='2019-07-05',
    cfg_t='14_42_36',
    data_t='14_42_36',
    s=np.datetime64('2019-07-05T14:40:00'),
    e=np.datetime64('2019-07-05T15:11:45'),
    title_str='KIAP Session day 108 14:42 Free Speller',
    plot_start=980,
    plot_win=90,
    plot_end=1070,
)

# All session files live in one per-day directory.
_sess_dir = BASE_PATH / 'KIAP_BCI_speller' / s["day"]
fn_sess = _sess_dir / f'data_{s["data_t"]}.bin'          # recorded data
fn_evs = _sess_dir / f'events_{s["data_t"]}.txt'         # timestamped event log
fn_spl = _sess_dir / f'debug_{s["cfg_t"]}.log'           # speller debug log
fn_cfg = _sess_dir / f'config_dump_{s["cfg_t"]}.yaml'    # dumped configuration
# Whole days between IMPLANT_DATE and the session day.
s_dt = dt.strptime(s["day"], '%Y-%m-%d')
days_post_implant = (s_dt - IMPLANT_DATE).days

# Load the dumped session configuration.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is acceptable only
# for this local config dump; params is used with attribute access later, which presumably
# relies on Python-object tags. Never point this at untrusted files.
with open(fn_cfg) as cfg_stream:
    params = yaml.load(cfg_stream, Loader=yaml.Loader)

# Session data limited by the 's' / 'e' entries of the session dict
# (presumably a time window — see dm.get_session).
datav, ts, ch_rec_list = dm.get_session(
    fn_sess,
    params=params,
    t_lim_start=s.get('s'),
    t_lim_end=s.get('e'),
)

# fix_timestamps also returns offsets, which are reused below to correct event timestamps
# after counter resets.
ts, offsets, _ = fix_timestamps(ts)

# Sample counts -> seconds (assumes a 30 kHz sample clock — TODO confirm).
tv = ts / 3e4
# Raw event-log lines and speller debug-log lines.
with open(fn_evs, 'r') as f:
    evs = f.read().splitlines()
with open(fn_spl, 'r') as f:
    splevs = f.read().splitlines()

# fix event timestamps: every event line starts with "<sample count>, ...". Any line that
# does not match raises AttributeError here.
tpat = re.compile(r"^(\d+)(, .*)$")
evs_time = np.fromiter(
    (int(tpat.match(line).group(1)) for line in evs),
    dtype=np.int64,
    count=len(evs),
)

# A decrease between consecutive timestamps marks a counter reset. The n-th reset shifts
# every later event by offsets[n], the offsets returned by fix_timestamps.
# NOTE(review): offsets come from the time-limited session data while the event log is
# unfiltered; confirm both see the same number of resets.
rollev = np.flatnonzero(np.diff(evs_time) < 0)
for n_reset, idx in enumerate(rollev):
    evs_time[idx + 1:] += offsets[n_reset]
# parse decoder decisions into pevs = [[time_s, decision], ...].
# Event lines look like "<sample count>, b'<message>'". Decisions come from
# "Decoder decision is: <decision>" messages. When the log contains none, every
# "..., response, stop" marker is used instead and labelled 'decision'.
# Lines that are not in "<count>, b'...'" form are skipped, as the other event parsers
# in this script do. Previously they crashed with AttributeError.
lpat = re.compile(r"^(\d+), b'(.*)'$")
decpat = re.compile(r".*Decoder decision is: (.*)$")
alt_resp_pat = re.compile(r".*, response, stop$")
pevs = []
_alt_pevs = []  # fallback entries, used only if no explicit decision was found
for ev, ev_time in zip(evs, evs_time):
    m = lpat.match(ev)
    if m is None:
        continue
    evstr = m.group(2)
    t_sec = float(ev_time) / 3e4
    m2 = decpat.match(evstr)
    if m2 is not None:
        pevs.append([t_sec, m2.group(1)])
    elif alt_resp_pat.match(evstr) is not None:
        # The fallback is only used when no line matched decpat, so 'elif' is equivalent
        # to checking every line independently.
        _alt_pevs.append([t_sec, 'decision'])
if not pevs:
    pevs = _alt_pevs
# parse events: pair "<kind>, start" / "<kind>, stop" markers from the event log into
# [t_start, t_stop] periods in seconds:
# evlist['bl'] = baseline, evlist['st'] = stimulus, evlist['re'] = response.
# A 'stop' closes the pending 'start' of the same kind. A 'stop' with no pending 'start'
# is ignored, and a repeated 'start' replaces the pending one, so every period has
# exactly two entries. (Previously an unmatched stop produced a one-element period that
# crashed the plotting code at p[1], and a repeated start produced a wrong interval.)
rp_ev_pat = re.compile(r"^(\d+), b'.*(stimulus|response|baseline), (start|stop)'$")
_period_keys = {'baseline': 'bl', 'stimulus': 'st', 'response': 're'}
evlist = {'bl': [], 'st': [], 're': []}
_pending_start = {}  # evlist key -> start time (s) of a period still awaiting its 'stop'
for ev_time, ev in zip(evs_time, evs):
    m = rp_ev_pat.match(ev)
    if m is None:
        continue
    key = _period_keys[m.group(2)]
    t_sec = float(ev_time) / 3e4
    if m.group(3) == 'start':
        _pending_start[key] = t_sec
    elif key in _pending_start:  # group(3) can only be 'stop' here
        evlist[key].append([_pending_start.pop(key), t_sec])
# Speller debug-log lines of the form "... <word> - ('<a>', '<b>')" become
# [<a>, <b>, <word>]; non-matching lines are skipped. The plotting code shows element 0
# as the spelled text and element 1 as the option, and compares element 2 with 'yes'/'no'.
spl_pat = re.compile(r"^.* (\w+) - \('(.*)', '(.*)'\)")
splpevs = [list(hit.group(2, 3, 1)) for hit in map(spl_pat.match, splevs) if hit is not None]
# Normalized rate for the loaded session, computed by the project's DataNormalizer.
dn = DataNormalizer(params)
fr = dn.calculate_norm_rate(datav)

# dn.norm_rate['ch_ids'] indexes the columns of datav. Adding 1 gives the channel numbers
# shown in the legend.
br_chnum = np.add(dn.norm_rate['ch_ids'], 1)

# Plot range in seconds on the tv axis. Session entries override the defaults, and the
# end is never later than the last sample.
i_plt = 0  # number of windows rendered so far
p_start, _requested_end, p_step = (
    s.get('plot_start', tv[0]),
    s.get('plot_end', tv[-1]),
    s.get('plot_win', plot_win_len),
)
p_end = min(_requested_end, tv[-1])
# Render the session in consecutive windows of p_step seconds from p_start up to p_end.
# Each window gets a two-panel figure:
#   - top: raw firing rates of the dn.norm_rate['ch_ids'] channels;
#   - bottom: normalized rate, yes/no thresholds, speller decisions and stimulus/response
#     periods.
# Each figure is saved as PDF/SVG/EPS under BASE_PATH_OUT, together with CSV exports of
# the plotted raw data, events and periods.
if p_step <= 0:
    # A non-positive step never advances the window, so the loop would run forever.
    raise ValueError(f'plot window length must be positive, got {p_step!r}')

_YES_COLOR = (0, .6, .1)
_NO_COLOR = (.9, 0, .1)
_OTHER_COLOR = (.7, .7, .7)
# Marker/label colour per speller decision; any other decision is drawn in grey.
_decision_colors = {'yes': _YES_COLOR, 'no': _NO_COLOR}
# (evlist key, label written to the periods CSV, shading colour). Baseline periods
# ('bl') are not shaded.
_period_styles = (
    ('st', 'stimulus', (198.0 / 255.0, 198.0 / 255.0, 198.0 / 255.0)),
    ('re', 'response', (87.0 / 255.0, 46.0 / 255.0, 136.0 / 255.0)),
)


def _style_spines(ax, x_left):
    """Hide the top/right spines; put the left spine at x=x_left and the bottom one at y=0."""
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')
    ax.spines['left'].set_position(('data', x_left))
    ax.spines['bottom'].set_position('zero')


while p_start + i_plt * p_step < p_end:
    i_plt += 1
    # Current window [start, end] in seconds; the last window is clipped to p_end.
    t_win = np.array([p_start + (i_plt - 1) * p_step, min(p_end, p_start + i_plt * p_step)])
    pl_idx = np.logical_and(t_win[0] <= tv, tv <= t_win[1])
    ch_ids = dn.norm_rate['ch_ids']
    raw_win = datav[pl_idx, :][:, ch_ids]  # read-only; shared by the plot and the CSV export

    # Figure number 34 is reused and cleared for every window.
    fig = plt.figure(34, figsize=(12, 6.22))
    fig.clf()
    # NOTE(review): tight_layout may not manage Axes created with add_axes (positions below
    # are set by hand) — confirm this call is still needed.
    fig.set_tight_layout(True)
    axs = [
        fig.add_axes((.04, .65, .93, .21)),  # top: raw firing rates
        fig.add_axes((.04, .1, .93, .35)),   # bottom: normalized rate + speller state
    ]

    # Top panel: raw rates of the channels used for the yes/no classification.
    raw_phs = axs[0].plot(tv[pl_idx], raw_win)
    df_raw_data = pd.DataFrame(index=tv[pl_idx], data=raw_win, columns=ch_ids)
    df_raw_data['normalized_Rate'] = fr[pl_idx]
    for ph, ch in zip(raw_phs, br_chnum):
        ph.set_label(f'Channel {ch}')
    axs[0].set_xlim(t_win[0], t_win[0] + p_step)
    _style_spines(axs[0], t_win[0] - 1)
    axs[0].set_ylabel('Firing Rate [Hz]')
    axs[0].legend(loc='upper right', ncol=2, fontsize='small')
    axs[0].set_title('Raw firing rates of channels used for “yes”/“no” classification')

    # Bottom panel: normalized rate plus the yes/no thresholds read from the config.
    axs[1].plot(tv[pl_idx], fr[pl_idx], label='Normalized Rate', color=(157.0/255, 157.0/255, 157.0/255))
    thr_top = params.paradigms.feedback.states.up[1]
    thr_bot = params.paradigms.feedback.states.down[2]
    axs[1].hlines(thr_top, t_win[0], t_win[1], color=_YES_COLOR, label='“Yes” threshold')
    axs[1].hlines(thr_bot, t_win[0], t_win[1], color=_NO_COLOR, label='“No” threshold')

    # Speller decisions inside the window: a dotted marker plus
    # "<spelled>\n“<option>”\n<decision>".
    # NOTE(review): pevs and splpevs are paired by position, which assumes one speller-log
    # entry per decoder decision — confirm against the logs.
    plot_speller_events = {'t': [], 'event': [], 'option': [], 'txt': []}
    pev: List[Union[float, str]]
    for pev, sp in zip(pevs, splpevs):
        t_ev = pev[0]
        if not t_win[0] <= t_ev <= t_win[1]:
            continue
        spelled, option, decision = sp
        # Spelled text longer than 5 characters is shown as "…" plus its last 4 characters.
        cur_spel_st = spelled if len(spelled) <= 5 else "…" + spelled[-4:]
        spel_ev_txt = f'{cur_spel_st}\n“{option}”\n{decision}'
        plot_speller_events['t'].append(t_ev)
        plot_speller_events['event'].append(decision)
        plot_speller_events['option'].append(option)
        plot_speller_events['txt'].append(cur_spel_st)
        ev_color = _decision_colors.get(decision, _OTHER_COLOR)
        axs[1].vlines(t_ev, 0, 1, color=ev_color, linestyles=':')
        axs[1].text(t_ev, 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=ev_color)
    df_speller_events = pd.DataFrame(plot_speller_events)

    # Stimulus and response periods that lie completely inside the window, lightly shaded.
    # NOTE(review): the PostScript (.eps) backend does not support transparency, so these
    # patches may come out opaque in the .eps export.
    sp_periods = {'t_begin': [], 't_end': [], 'event': []}
    for ev_key, ev_label, shade_color in _period_styles:
        for period in evlist[ev_key]:
            if t_win[0] <= period[0] <= t_win[1] and t_win[0] <= period[1] <= t_win[1]:
                axs[1].add_patch(matplotlib.patches.Rectangle(
                    (period[0], 0), period[1] - period[0], 1, color=shade_color, alpha=.15))
                sp_periods['t_begin'].append(period[0])
                sp_periods['t_end'].append(period[1])
                sp_periods['event'].append(ev_label)
    df_speller_periods = pd.DataFrame(sp_periods)

    axs[1].set_xlim(t_win[0], t_win[0] + p_step)
    axs[1].set_xlabel('Time [s]')
    axs[1].set_ylabel('Normalized Firing Rate')
    axs[1].legend(fontsize='small', loc='upper right')
    axs[1].set_title("Normalized rate and speller state", y=1.3)
    _style_spines(axs[1], t_win[0] - 1)

    # The title shows the last spelled text of the whole session, not of this window.
    # Without any speller entries, only the session title is shown (previously IndexError).
    if splpevs:
        fig.suptitle(f'{s["title_str"]} – “{splpevs[-1][0]}”', fontsize=15, wrap=True)
    else:
        fig.suptitle(s["title_str"], fontsize=15, wrap=True)
    fig.show()

    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    savename = BASE_PATH_OUT / f'Figure_4_SpellerProgress_plt{i_plt}'
    for fig_ext in ('.pdf', '.svg', '.eps'):
        fig.savefig(savename.with_suffix(fig_ext), transparent=True)
    df_raw_data.to_csv(savename.with_suffix('.raw.csv'))
    df_speller_events.to_csv(savename.with_suffix('.events.csv'))
    df_speller_periods.to_csv(savename.with_suffix('.periods.csv'))
|