import re
from pathlib import Path

import matplotlib
# matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import munch
import numpy as np
import pandas as pd
import scipy
import scipy.interpolate
import yaml
from scipy import stats

from basics import BASE_PATH, BASE_PATH_OUT
from helpers import data_management as dm
from helpers.data import DataNormalizer
from helpers.fbplot import plot_df_combined
from helpers.nsp import fix_timestamps
# Plot window in seconds around the response-start event; converted to
# sample indices (.s/.swin) and a matching time axis (.twin) further below.
plot_t = munch.munchify({'response_start_samples': {'t': [-1.5, 3.0]}})

# Layout/styling for the combined 2x2 trial figure: one panel per
# (target, decision) combination, colour-coded by target tone.
plot_order_comb = munch.munchify({
    'p': [
        # Correct "up" trials: top-left, red, solid, with variance band.
        {'r': 0, 'c': 0, 'col': (0.635, 0.078, 0.184, .2), 'ls': '-', 'lw': 2, 'show_var': True,
         'sel': lambda df: ((df.target == 'up') | (df.target == 'yes')) & (df.decision == 'up'), 'desc': 'Target: up'},
        # Correct "down" trials: top-right, blue, solid, with variance band.
        {'r': 0, 'c': 1, 'col': (0, 0.447, 0.741, .2), 'ls': '-', 'lw': 2, 'show_var': True,
         'sel': lambda df: ((df.target == 'down') | (df.target == 'no')) & (df.decision == 'down'),
         'desc': 'Target: down'},
        # Error trials, target up / decision down: bottom-left, dotted.
        {'r': 1, 'c': 0, 'col': (0.635, 0.078, 0.184, .2), 'ls': ':', 'lw': 2,
         'show_var': False, 'sel': lambda df: (df.target == 'up') & (df.decision == 'down'),
         'desc': 'Target: up, decision: down'},
        # Error trials, target down / decision up: bottom-right, dotted.
        {'r': 1, 'c': 1, 'col': (0, 0.447, 0.741, .2), 'ls': ':', 'lw': 2, 'show_var': False,
         'sel': lambda df: (df.target == 'down') & (df.decision == 'up'), 'desc': 'Target: down, decision: up'}
    ],
    's': [
        {'dcol': 'response_start_samples', 'title': 'Response Period', 'norm': False}],
    # Per-state colour and legend text; the 'yes'/'no' paradigm maps onto
    # the same colours as 'up'/'down'.
    'state': {'down': {'c': (0, 0.447, 0.741, .2), 'd': 'Low Frequency Target Tone'},
              'up': {'c': (0.635, 0.078, 0.184, .2), 'd': 'High Frequency Target Tone'},
              'no': {'c': (0, 0.447, 0.741, .2), 'd': 'Low Frequency Target Tone'},
              'yes': {'c': (0.635, 0.078, 0.184, .2), 'd': 'High Frequency Target Tone'}}},
)

# Base name for the exported figure/CSV files (plain string: the original
# f-string had no placeholders).
SAVENAME = 'Figure_3A_FBTrials'
def extract_trials(filename, offsets=None):
    """Parse a neurofeedback events file into one row per trial.

    Parameters
    ----------
    filename : str or Path
        Path to the ``events_*.txt`` file: one event per line, each line
        starting with an integer sample timestamp.
    offsets : sequence of int, optional
        Timestamp corrections to add after each detected counter
        roll-over (one entry per roll-over, as produced by
        ``fix_timestamps``). Defaults to no corrections.

    Returns
    -------
    pandas.DataFrame
        One row per trial with the stage timestamps (``baseline_start``
        ... ``response_stop``, -1 when the stage never occurred), the
        ``target`` cue and the decoder ``decision`` ('yes'/'no' are
        mapped onto 'up'/'down'). The ``response_start_samples*``
        columns are initialised to ``None`` and filled in later by the
        calling script.
    """
    if offsets is None:
        offsets = []
    with open(filename, 'r') as f:
        evs = f.read().splitlines()

    # Leading integer timestamp of every event line.
    tpat = re.compile(r"^(\d+)(, .*)$")
    # Stage transitions, e.g. "123, b'feedback, x, up, baseline, start'".
    stage_pat = re.compile(r"(\d+), b'(feedback|question), (\w+), (\w+), (\w+), (\w+)'$")
    # Decoder output lines, e.g. "... up, Decoder decision is: yes'".
    decpat = re.compile(r".*\s(\w+), Decoder decision is: (.*)'$")

    # Collect the raw timestamps, then undo counter roll-overs by adding
    # the per-roll-over offsets to everything after each negative jump.
    evs_time = np.asarray([np.int64(tpat.match(ev).group(1)) for ev in evs],
                          dtype=np.int64)
    rollev, = np.where(np.diff(evs_time) < 0)
    for j, i in enumerate(rollev):
        evs_time[(i + 1):] += offsets[j]

    trials = []
    this_trial = {}
    for ev, ev_time in zip(evs, evs_time):
        m = stage_pat.match(ev)
        if m is not None:
            if m.group(5) == 'Block':
                continue
            if m.group(5) == 'baseline' and m.group(6) == 'start':
                # A new trial begins: flush the previous one (if any).
                if len(this_trial) > 0:
                    trials.append(this_trial)
                this_trial = {'baseline_start': -1, 'baseline_stop': -1, 'stimulus_start': -1,
                              'stimulus_stop': -1, 'response_start': -1, 'response_stop': -1,
                              'target': m.group(4),
                              'response_start_samples': None, 'response_start_samples_wnan': None,
                              'response_start_samples_norm': None, 'response_start_samples_good': None}
            this_trial[f'{m.group(5)}_{m.group(6)}'] = ev_time
            continue
        m = decpat.match(ev)
        if m is not None:
            # Normalise yes/no decisions onto the up/down vocabulary.
            if m.group(2) == 'yes':
                this_trial['decision'] = 'up'
            elif m.group(2) == 'no':
                this_trial['decision'] = 'down'
            else:
                this_trial['decision'] = m.group(2)
    # Flush the final trial. Use .get() so a file whose last trial never
    # reached a decoder decision does not raise KeyError.
    if this_trial.get('decision'):
        trials.append(this_trial)
    return pd.DataFrame(trials)
# Session to plot: recording day and config timestamp. 'data_t' falls back
# to the config timestamp when not given separately.
s = {'day': '2019-11-21', 'cfg_t': '15_23_18'}
day_str = s['day']
cfg_t_str = s['cfg_t']
data_t_str = s.get('data_t', s['cfg_t'])

fn_cfgdump = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'config_dump_{cfg_t_str}.yaml'
try:
    with open(fn_cfgdump) as stream:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # the config dump is trusted local data, so this is acceptable here.
        params = yaml.load(stream, Loader=yaml.Loader)
except Exception as e:
    # Without the config nothing below can run: report the problem and
    # re-raise instead of silently continuing into a NameError on `params`.
    print(e)
    raise

fb_states = params.paradigms.feedback.states

# Convert the plot windows from seconds into sample indices (one sample per
# DAQ loop interval, given in ms) and build the matching time axis.
for k in plot_t:
    plot_t[k].s = [int(plot_t[k].t[0] * 1000 / params.daq.spike_rates.loop_interval),
                   int(plot_t[k].t[1] * 1000 / params.daq.spike_rates.loop_interval)]
    plot_t[k].swin = np.arange(plot_t[k].s[0], plot_t[k].s[1] + 1)
    plot_t[k].twin = plot_t[k].swin / 1000.0 * params.daq.spike_rates.loop_interval
# Raw neural data and event files for this session.
fn_sess = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'data_{data_t_str}.bin'
fn_evs = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'events_{data_t_str}.txt'
datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params)
# Repair timestamp roll-overs; the same `offsets` are reused below to fix
# the event timestamps in extract_trials.
ts, offsets, _ = fix_timestamps(ts)
tv = ts / 3e4  # presumably a 30 kHz sample clock -> seconds; TODO confirm
trs = extract_trials(fn_evs, offsets=offsets)
# Normalized firing rate, i.e. the signal that drove the feedback.
dn = DataNormalizer(params)
fr = dn.calculate_norm_rate(datav)
# One column per plotted signal; currently only the normalized rate 'fr'.
plot_data = np.reshape(fr, (-1, 1))
labels = ('fr',)
sess_info = 'Channels used for control: [' + ' '.join([f"{ch.id}" for ch in params.daq.normalization.channels]) + ']'
# One pass per plotted channel/label (currently only the normalized rate).
for label, dv in zip(labels, np.hsplit(plot_data, plot_data.shape[1])):
    dv = np.reshape(dv, (-1))

    # Cut out the per-trial sample window around the response start.
    for ii, row in trs.iterrows():
        t_stop = row['response_stop']
        # Fall back to the last sample when the recording ends before
        # response_stop; the original code left t_stop_off unset (or stale
        # from the previous trial) in that case.
        t_stop_off = len(ts) - 1
        if len(np.where(ts >= t_stop)[0]) > 0:
            t_stop_off = np.where(ts >= t_stop)[0][0] - 1
        t_rstart = row['response_start']
        if len(np.where(ts >= t_rstart)[0]) > 0:
            t_rstart_off = np.where(ts >= t_rstart)[0][0] - 1
            # Sample indices of the plot window, anchored at response start.
            resp_start_idx = plot_t.response_start_samples.swin + t_rstart_off
            if np.all(resp_start_idx < len(dv)):
                trs.at[ii, 'response_start_samples'] = dv[resp_start_idx]
            else:
                # Window reaches past the end of the data: pad with NaN.
                trs.at[ii, 'response_start_samples'] = np.empty(resp_start_idx.shape)
                trs.at[ii, 'response_start_samples'][:] = np.nan
                good_vals = resp_start_idx < len(dv)
                trs.at[ii, 'response_start_samples'][good_vals] = dv[resp_start_idx[good_vals]]
            # Misalignment (s) between the event time and the chosen bin.
            trs.at[ii, 'response_start_offset'] = (ts[t_rstart_off] - t_rstart) / 3e4
            # Samples at/after response_stop belong to the next trial stage.
            trs.at[ii, 'response_start_samples_good'] = resp_start_idx < t_stop_off - 1
            # NOTE(review): this aliases the same ndarray, so the NaN write
            # below also mutates 'response_start_samples' — preserved as-is.
            trs.at[ii, 'response_start_samples_wnan'] = trs.at[ii, 'response_start_samples']
            trs.at[ii, 'response_start_samples_wnan'][~trs.at[ii, 'response_start_samples_good']] = np.nan
            # Resample the valid samples onto the full time axis. np.interp
            # is linear; an unused cubic interp1d object was removed here.
            fp = trs.at[ii, 'response_start_samples'][trs.at[ii, 'response_start_samples_good']]
            xp = plot_t.response_start_samples.twin[trs.at[ii, 'response_start_samples_good']]
            x = np.linspace(xp[0], xp[-1], len(plot_t.response_start_samples.twin))
            trs.at[ii, 'response_start_samples_norm'] = np.interp(x, xp, fp)

    # Skip plotting when no trial produced any data for this label.
    if np.all(np.vstack(trs['response_start_samples']) == 0):
        continue

    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(35, figsize=(16, 9))
    fig.clf()
    plot_df_combined(trs, plot_order_comb, plot_t, fb_states, fig=fig,
                     options=munch.Munch({'show_thresholds': label == 'fr', 'show_end': 0, 'show_median': False}))
    fig.suptitle(f"{day_str} {cfg_t_str} ch:{label}\n{sess_info} n={len(trs)}")
    fig.show()
    savename = BASE_PATH_OUT / SAVENAME
    fig.savefig(savename.with_suffix('.pdf'))
    fig.savefig(savename.with_suffix('.eps'))
    fig.savefig(savename.with_suffix('.svg'))
def _count_trials(target, decision):
    """Number of trials with the given target/decision combination."""
    return ((trs.target == target) & (trs.decision == decision)).sum()

# Per-category trial counts for the session summary.
n_correct_up = _count_trials('up', 'up')
n_correct_down = _count_trials('down', 'down')
n_error_up = _count_trials('up', 'down')
n_error_down = _count_trials('down', 'up')
n_timeout_up = _count_trials('up', 'unclassified')
n_timeout_down = _count_trials('down', 'unclassified')
n_total = len(trs)
print(
    f"Total trials: {n_total}\nTarget up, decision up: {n_correct_up} trials.\nTarget up, decision down: {n_error_up} trials.\nTarget down, decision up: {n_error_down} trials.\nTarget down, decision down: {n_correct_down} trials.\nCorrect: {(n_correct_up + n_correct_down) / n_total}. Correct up: {n_correct_up / (n_correct_up + n_error_up)}\nCorrect down: {n_correct_down / (n_correct_down + n_error_down)}\nTime-out up target: {n_timeout_up}; Time-out down target: {n_timeout_down}")
# Export one row per trial as CSV: the valid response-period samples, the
# matching time points, and the trial's target/decision labels.
_twin = plot_t['response_start_samples']['twin']
trs_export = pd.DataFrame({
    'trial_data': [row.response_start_samples[row.response_start_samples_good]
                   for _, row in trs.iterrows()],
    'trial_times': [_twin[row.response_start_samples_good]
                    for _, row in trs.iterrows()],
    'target': trs['target'],
    'decision': trs['decision'],
})
trs_export.to_csv(savename.with_suffix('.csv'))