import matplotlib matplotlib.use('TkAgg') from datetime import datetime as dt import matplotlib.pyplot as plt import pandas as pd import numpy as np import re import helpers.sessions as hs from pathlib import Path import yaml import munch from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE, FEEDBACK_CHANGE_DATE, ARRAY_MAPS import logging from logging.handlers import TimedRotatingFileHandler from helpers.tsdumper import TSDumper from matplotlib.colors import ListedColormap import matplotlib as mpl import itertools import scipy.stats as stats from scipy.interpolate import UnivariateSpline logger = logging.getLogger("KIAP") logger.handlers.clear() logger.setLevel(logging.DEBUG) ch = logging.StreamHandler() ch.setLevel(logging.DEBUG) formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') ch.setFormatter(formatter) logger.addHandler(ch) fh = TimedRotatingFileHandler(BASE_PATH_OUT/'kiap_figures.log', when='D', backupCount=5) fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) fh.setLevel(logging.INFO) logger.addHandler(fh) def string_count_eq(str1, str2): """Counts how many characters are the same from the beginning, between two strings""" n_match = 0 for (c1, c2) in zip(str1, str2): if c1 == c2: n_match += 1 else: break return n_match def string_dist(str1, str2): """Helper function to calculate string distance. In our case, we want to know how many characters were added and removed between two strings, the one a speller session started with, and the one it ended with. If these share n characters at beginning, then the first one has m and the second one has another k characters, then the difference in character changes is (m+k).""" return len(str1) + len(str2) - 2 * string_count_eq(str1, str2) T_RE = re.compile(r"^(\d+-\d+-\d+T\d+:\d+:\d+\.\d+)\s+color.*\('(.*)', '(.*)'\)$") def extract_session_data(session_log_file_name): """ given a session log file name for a colour speller session, of the format info_XX_XX_XX.log, this function returns a dictionary with the following keys: phrase_start: speller started with this phrase phrase: speller ended with this phrase n: number of characters in (phrase - phrase_start) ch_per_min: characters spelled per minute """ with open(session_log_file_name, 'r') as f: evs = f.read().splitlines() d = {'n': 0, 'ch_per_min': 0, 'phrase_start': '', 'phrase': ''} m01 = T_RE.match(evs[0]) m02 = T_RE.match(evs[-1]) start_dt = dt.strptime(m01.group(1), '%Y-%m-%dT%H:%M:%S.%f') end_dt = dt.strptime(m02.group(1), '%Y-%m-%dT%H:%M:%S.%f') duration_min = (end_dt - start_dt).total_seconds() / 60.0 d['phrase_start'] = m01.group(2) d['phrase'] = m02.group(2) d['n'] = string_dist(d['phrase_start'], d['phrase']) if duration_min > 0: d['ch_per_min'] = d['n'] / duration_min return d def precompile_sessions(pth=BASE_PATH, start=0, n=None, use_cache=True): """ Finds recorded sessions at pth and loads information into a pandas DataFrame with columns: mode: mode of BCI session (feedback, color, exploration, question) cfg: relative path to the configuration file events: relative path to the event log file data: relative path to the binary data file log: relative path to the debug log file duration_s: duration of session in seconds (last data timestamp - first data timestamp) duration_min: duration of session in minutes start_dt: start of session as datetime, from time in config file end_dt: end of session as datetime d_since_impl: days since implantation The session data will be saved in a cache file and reloaded from there by default, as it takes some time to read each individual data file. This reading is necessary to get the most accurate estimate for the length of a session. Since on some occasions the saving of data files was not ended automatically at the end of a session, the events file will also be read to determine the length. Params: pth: base path for data folders start: skip this many sessions at start n: read this many sessions use_cache: try to read cache file if True """ cache_file_name = Path(pth, 'session_cache.pkl') if use_cache and cache_file_name.exists(): logger.info(f"Using cache file {cache_file_name}") cfgs = pd.read_pickle(cache_file_name) return cfgs (cfgs, _) = hs.get_sessions(pth, start=start, n=n) n_total = len(cfgs) t_diffs = [] start_dts = [] for i, s in cfgs.iterrows(): logger.debug(f"loading {i + 1} of {n_total}: {s['cfg']}") try: (ts, _, _, evts) = hs.get_session_data(pth, s) t_diff = evts[-1] - evts[0] logger.info(f"Loaded session of {t_diff} s.") t_diffs.append(t_diff) except FileNotFoundError: logger.warning(f"Data file not found for {pth}") t_diffs.append(0) cfgstr = s['cfg'] cfgstr = cfgstr.replace('/', '_') cfgstr = cfgstr.replace('\\', '_') start_dt = dt.strptime(cfgstr, '%Y-%m-%d_config_dump_%H_%M_%S.yaml') start_dts.append(start_dt) cfgs['duration_s'] = t_diffs cfgs['duration_min'] = cfgs['duration_s'] / 60.0 cfgs['start_dt'] = start_dts cfgs['end_dt'] = cfgs['start_dt'] + pd.to_timedelta(cfgs['duration_s'], unit='seconds') cfgs['d_since_impl'] = cfgs['start_dt'].apply(lambda x: (x.to_pydatetime() - IMPLANT_DATE).days) cfgs.to_pickle(cache_file_name) return cfgs def add_session_info(d): """ Compute particular bits of information given data sessions. Params: d: DataFrame containing session information, as generated by precompile_sessions() Adds to DataFrame: cum_dur: cumulative session duration per day ecol: edge color for plots col: color for plots based on session mode Returns: DataFrame """ d.set_index(['d_since_impl'], inplace=True) d["cum_dur"] = d.groupby(level='d_since_impl')['duration_min'].cumsum() - d['duration_min'] d.reset_index(inplace=True) # add colour information to session entries d['ecol'] = [(1, 1, 1, 1)] * len(d) d['col'] = [(0.1, 0.1, 0.1)] * len(d) idx = d['mode'] == 'feedback' d.loc[idx, 'col'] = pd.Series([(0, 0.447, 0.741)] * idx.sum(), index=d.loc[idx].index) idx = d['mode'] == 'question' d.loc[idx, 'col'] = pd.Series([(0.494, 0.184, 0.556)] * idx.sum(), index=d.loc[idx].index) idx = d['mode'] == 'color' d.loc[idx, 'col'] = pd.Series([(0.85, 0.325, 0.098)] * idx.sum(), index=d.loc[idx].index) idx = d['mode'] == 'exploration' d.loc[idx, 'col'] = pd.Series([(0.301, 0.745, 0.933)] * idx.sum(), index=d.loc[idx].index) return d def add_speller_data(d, pth=BASE_PATH): """ Reads speller results for every speller session and adds that information to the dataframe. Params: d: DataFrame containing session information, as generated by precompile_sessions() Adds to DataFrame: 'n': length of communication for speller block 'ch_per_min': number of characters spelled per minute 'phrase_start': phrase at begin of speller block 'phrase': phrase at end of speller block 'cum_n': cumulative length of phrases per day Returns: DataFrame """ d['n'] = pd.array([None] * len(d), dtype='Int32') d['ch_per_min'] = None d['phrase_start'] = None d['phrase'] = None for ix, rw in d.loc[d['mode'] == 'color'].iterrows(): sp_dict = extract_session_data(Path(pth, rw['log'])) spr = pd.Series(sp_dict, name=ix) d.loc[ix, spr.index] = spr d.set_index(['d_since_impl'], inplace=True) ix = d['mode'] == 'color' d.loc[ix, "cum_n"] = d.loc[ix].groupby(level='d_since_impl')['n'].cumsum() - d.loc[ix, 'n'] d.reset_index(inplace=True) return d def extract_trials_old(filename): """ filename - path to an info_*.log file 'old': log pattern before 20 July 2019 ('up' / 'down' trials, 'yes' response for correct, 'unclassified' for timeout) """ with open(filename, 'r') as f: evs = f.read().splitlines() # fix event timestamps decpat = re.compile(r".*\sfeedback - Decoder decision: (\w*) - \('feedback', '(\w+)'\)$") contingency = pd.DataFrame imap = {'up': 'down', 'down': 'up'} conditions = [] responses = [] for ev in evs: m = decpat.match(ev) if m is not None: condition = m.group(2) response = m.group(1) if condition == 'baseline': continue if response == 'yes': response = condition elif response == 'unclassified': response = imap.get(condition, 'unclassified') conditions.append(condition) responses.append(response) all_condition = pd.Categorical(conditions, categories=['up', 'down']) all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified']) # print(evs) ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response']) return ct def extract_trials_new(filename): """ filename - path to an info_*.log file 'new': log pattern on and after 20 July 2019: ('up' / 'down' trials, 'yes' / 'no' / 'unclassified' response) """ with open(filename, 'r') as f: evs = f.read().splitlines() # fix event timestamps decpat = re.compile(r".*\sfeedback - Decoder decision: (\w*) - \('feedback', '(\w+)'\)$") contingency = pd.DataFrame conditions = [] responses = [] for ev in evs: m = decpat.match(ev) if m is not None: condition = m.group(2) response = m.group(1) if response == 'yes': response = 'up' elif response == 'no': response = 'down' conditions.append(condition) responses.append(response) all_condition = pd.Categorical(conditions, categories=['up', 'down']) all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified']) # dft = pd.DataFrame({'condition': conditions, 'response':responses}) # ct = pd.crosstab(dft.condition, dft.response) ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response']) return ct def extract_trials_q(filename): """ filename - path to an info_*.log file log pattern: 2019-10-02T21:20:05.053678 question - Decoder decision: no - ('No question', '002_11038.wav') """ with open(filename, 'r') as f: evs = f.read().splitlines() # fix event timestamps decpat = re.compile(r".*\squestion - Decoder decision: (\w*) - \('(Yes|No) question', '(.+)'\)$") contingency = pd.DataFrame conditions = [] responses = [] for ev in evs: m = decpat.match(ev) if m is not None: condition = m.group(2) response = m.group(1) if response == 'yes': response = 'up' elif response == 'no': response = 'down' if condition == 'Yes': condition = 'up' elif condition == 'No': condition = 'down' conditions.append(condition) responses.append(response) all_condition = pd.Categorical(conditions, categories=['up', 'down']) all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified']) # dft = pd.DataFrame({'condition': conditions, 'response':responses}) # ct = pd.crosstab(dft.condition, dft.response) ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response']) return ct def add_feedback_info(d, pth=BASE_PATH): """ For all feedback sessions, read log file and save contingency table. """ if 'ct' not in d.columns: d['ct'] = [None for _ in range(len(d))] for ix, rw in d.iterrows(): if rw['mode'] != 'feedback': continue if FEEDBACK_CHANGE_DATE > rw.start_dt: ct = extract_trials_old(pth / rw.log) else: ct = extract_trials_new(pth / rw.log) cond_sums = ct[ct.index != 'unclassified'].sum() tpr = ct.loc['up', 'up'] / cond_sums['up'] fpr = ct.loc['up', 'down'] / cond_sums['down'] acc = (ct.loc['up', 'up'] + ct.loc['down', 'down']) / (cond_sums['up'] + cond_sums['down']) d.at[ix, 'ct'] = ct d.at[ix, 'tpr'] = tpr d.at[ix, 'fpr'] = fpr d.at[ix, 'acc'] = acc d.at[ix, 'n_trials'] = ct.sum().sum() return d def calculate_feedback_before_speller(d): """For each day, find feedback sessions before the first speller session and add up trials.""" u_dsi = d['d_since_impl'].unique() day_n_fb = [(di, d_day[(d_day['mode'] == 'feedback') & (d_day.index < ix)]['ct'].sum().sum().sum()) for di in u_dsi for d_day in [d[d['d_since_impl'] == di]] for ix in [d_day[d_day['mode'] == 'color'].first_valid_index()] if ix is not None ] return day_n_fb def add_question_info(d, pth=BASE_PATH): """ For all feedback sessions, read log file and save contingency table. """ if 'ct' not in d.columns: d['ct'] = [None for _ in range(len(d))] for ix, rw in d.iterrows(): if rw['mode'] != 'question': continue ct = extract_trials_q(pth / rw.log) cond_sums = ct[ct.index != 'unclassified'].sum() tpr = ct.loc['up', 'up'] / cond_sums['up'] fpr = ct.loc['up', 'down'] / cond_sums['down'] acc = (ct.loc['up', 'up'] + ct.loc['down', 'down']) / (cond_sums['up'] + cond_sums['down']) d.at[ix, 'ct'] = ct d.at[ix, 'tpr'] = tpr d.at[ix, 'fpr'] = fpr d.at[ix, 'acc'] = acc return d def prepare_for_annotation_export(d): """ Prepare a yaml file for annotation of speller sessions. Will be written at [BASE_PATH_OUT]/records_for_annotation.yml """ d2 = d.reset_index().set_index('start_dt', drop=False) d2 = d2.loc[d2['mode'] == 'color', ['data', 'd_since_impl', 'start_dt', 'phrase_start', 'phrase']] d2['intelligible'] = None d2['start'] = d2['start_dt'].map(lambda x: x.strftime('%Y-%m-%d %H:%M:%S')) d2 = d2[['data', 'start', 'd_since_impl', 'phrase_start', 'phrase', 'intelligible']] d_dict = d2.to_dict(orient='index') with open(BASE_PATH_OUT / 'records_for_annotation.yml', 'w') as f: f.write("""# Rating scheme for intelligibility of patient's communications. # # Instructions for filling this file: # # We consider the result of one total session / speller run. # # 1. Find session in log by date of data file (also given in plain text) # 2. Look up speller start (key 'phrase_start') and final output (key 'phrase') # and look up speller output / session remarks. # 3. Rate speller output. # For copy spelling sessions: # 0 – completely wrong # 1 – up to 20% of characters wrong or missing # 2 - no mistake # For free spelling: # 0 - incomprehensible speller output # 1 - partially understandable, but with doubts due to spelling mistakes # 2 - unambiguous to family / experimenter (even if single letters are # wrong or missing; even if words are incomplete) # where one session's output could be counted into several categories, # category 1 is likely appropriate. # 4. Find the record for the session in list below and replace the 'null' entry # under the 'intelligible' key with your rating. """) yaml.dump(d_dict, f, default_flow_style=False, Dumper=TSDumper, sort_keys=False) def add_annotation(d): """ Read annotations and add them to DataFrame. Params: d: DataFrame containing session information, as generated by precompile_sessions() Adds to DataFrame: 'intelligible': rating of intelligibility of a spelled session 'col': updated color based on intelligibility Returns: DataFrame """ # with open(Path('annotations', 'speller', 'records_for_annotation_full.yml'), 'r') as f: with open(Path('annotations', 'speller', 'records_for_annotation_consolidated.yml'), 'r') as f: anno = munch.Munch.fromYAML(f, Loader=yaml.Loader) adict = pd.DataFrame.from_dict(anno, orient='index') d.reset_index(inplace=True) d.set_index('start_dt', drop=False, inplace=True) d.loc[adict.index, 'intelligible'] = adict['intelligible'] d.loc[adict.index, 'rating'] = adict['rating'] d.loc[(d['intelligible'] == 2) & (d['mode'] == 'color'), 'col'] = [[(0.318, 0.039, 0.090)]] d.loc[(d['intelligible'] == 1) & (d['mode'] == 'color'), 'col'] = [[(0.635, 0.078, 0.184)]] d.loc[(d['intelligible'] == 0) & (d['mode'] == 'color'), 'col'] = [[(1.000, 0.737, 0.843)]] d.loc[(d['intelligible'] == 0) & (d['mode'] == 'color'), 'ecol'] = [[(.2, .2, .2)]] d.set_index('index', inplace=True) return d def add_channel_info(d, pth=BASE_PATH): """ For all sessions, load config files and extract information about channels being used in sessions. Channels are in 1-base number, corresponding to Blackrock's numbering scheme. """ d.loc[:, 'channels'] = None d.loc[:, 'use_all'] = None # d.loc[:, 'norm'] = None d.loc[:, 'submode'] = None for ix, rw in d.iterrows(): with open(Path(pth, rw.cfg), 'r') as f: cfg = munch.Munch.fromYAML(f, Loader=yaml.Loader) d.loc[ix, 'channels'] = [[ch.id + 1] for ch in cfg.daq.normalization.channels] d.loc[ix, 'use_all'] = cfg.daq.normalization.use_all_channels # d.loc[ix, 'norm'] = cfg.daq.normalization # test if paradigms key in config file exists. if not, load paradigms file if cfg.get('paradigms') is None: p_fn = rw.cfg.replace('config_dump', 'paradigm') with open(Path(pth, p_fn), 'r') as f: cfg.paradigms = munch.Munch.fromYAML(f, Loader=yaml.Loader) try: if rw['mode'] == 'question': d.loc[ix, 'submode'] = cfg.paradigms.question.mode[cfg.paradigms.question.selected_mode] elif rw['mode'] == 'color': d.loc[ix, 'submode'] = cfg.paradigms.color.mode[cfg.paradigms.color.selected_mode] elif rw['mode'] == 'feedback': d.loc[ix, 'submode'] = cfg.paradigms.feedback.mode[cfg.paradigms.feedback.selected_mode] except Exception: logger.warning(f"Exception reading mode for row {ix}", exc_info=True) logger.debug(f"Loading normalization info for session {ix} ({d.loc[ix, 'mode']}, {d.loc[ix, 'submode']}):" f" {d.loc[ix, 'channels']}") return d def get_indexer_and_color(d, mode, intelligible=None): """ Given a session DataFrame, the session mode, and optionally an intelligibility rating, return indexer for rows matching that criterion, colour, edge colour, and label for plotting. Params: d: session DataFrame mode: one of 'color', 'feedback', 'exploration', 'question' intelligible: if mode == 'color', this should be 0, 1, or 2 """ ix = d['mode'] == mode ecol = (1, 1, 1, 1) col = (.1, .1, .1) label = "" if mode == 'color': if intelligible is None: ix = ix & pd.isna(d['intelligible']) else: ix = ix & (d['intelligible'] == intelligible) label = "Speller" if intelligible == 2: col = (0.635, 0.078, 0.184) label = "Speller clear" elif intelligible == 1: col = (0.85, 0.325, 0.098) label = "Speller ambiguous" elif intelligible == 0: col = (0.929, 0.694, 0.125) # ecol = (.2, .2, .2) label = "Speller unintelligible" else: col = (0.929, 0.894, 0.825) # ecol = (.2, .2, .2) label = "Speller not rated" elif mode == 'feedback': col = (0.466, 0.674, 0.188) label = "Feedback Training" elif mode == 'exploration': col = (0, 0.447, 0.741) label = "Exploration" elif mode == 'question': col = (0.301, 0.745, 0.933) label = "Questions" return dict(ix=ix, col=col, ec=ecol, label=label) def plot_bars_for_sel(ax, d, mode, intelligible=None): """ Plot bars into axis. Params: ax: axis to plot into d: Pandas DataFrame containing session info mode: BCI mode, one of 'color', 'feedback', 'exploration', 'question' intelligible: if mode == 'color', this should be 0, 1, or 2 """ ic = get_indexer_and_color(d, mode, intelligible) return ax.bar(x=d.loc[ic['ix'], 'd_since_impl'], height=d.loc[ic['ix'], 'duration_min'], color=ic['col'], ec=ic['ec'], label=ic['label'], bottom=d.loc[ic['ix'], 'cum_dur_min'], lw=.5, width=.8) def prepare_sessions(pth=BASE_PATH): """Combine a few steps to load session data""" d = precompile_sessions(pth=pth) add_session_info(d) add_speller_data(d, pth=pth) return d def plot_sessions(d): """Load sessions, plot, and print summary""" from brokenaxes import brokenaxes n_days = len(d['d_since_impl'].unique()) x_ranges = ((105, 126), (146, 470)) # ((105, 126), (146, 163), (174, 212), (223, 227), (238, 465)) d2 = d.set_index(['d_since_impl']) other_m_d = d2[d2['mode'] != 'color'].groupby(['d_since_impl', 'mode'])[['duration_min']].sum() other_m_d.reset_index('mode', inplace=True) sp_d = d2[d2['mode'] == 'color'][ ['start_dt', 'mode', 'duration_min', 'cfg', 'events', 'data', 'log', 'phrase', 'intelligible', 'n', 'ch_per_min']] df_dur_plot = pd.concat([other_m_d, sp_d]) df_dur_plot.sort_index(inplace=True, kind='mergesort') df_dur_plot['cum_dur_min'] = df_dur_plot.groupby(level=0)['duration_min'].cumsum() - df_dur_plot['duration_min'] df_dur_plot.reset_index(inplace=True) # save number of letters spelled in 'intelligible'==2 group, per day n_per_day_intell = d[(d['mode'] == 'color') & (d['intelligible'] == 2)].groupby('d_since_impl')['n'].sum() n_per_day_intell.name = 'n_per_day' cpm_intell = d[(d['mode'] == 'color') & (d['intelligible'] == 2)].groupby('d_since_impl').apply( lambda x: x['n'].sum() / (x['duration_s'].sum() / 60.0)) cpm_intell.name = 'cpm_per_day' spl = UnivariateSpline(n_per_day_intell.index, n_per_day_intell.values, s=850000, k=3) tx = np.arange(n_per_day_intell.index[0], n_per_day_intell.index[-1]) smoothv = spl(tx) spl_cpm = UnivariateSpline(cpm_intell.index, cpm_intell.values, s=850000, k=3) tx_cpm = np.arange(cpm_intell.index[0], cpm_intell.index[-1]) smoothv_cpm = spl_cpm(tx_cpm) spell_int = d.groupby('d_since_impl').apply(lambda x: ((x['mode'] == 'color') & (x['intelligible'] >= 2)).any()) spell_int_days = spell_int[spell_int].index spell_not_int = d.groupby('d_since_impl').apply( lambda x: ((x['mode'] == 'color') & (x['intelligible'] < 2)).any() & ~ ( (x['mode'] == 'color') & (x['intelligible'] >= 2)).any()) spell_not_int_days = spell_not_int[spell_not_int].index no_spell = d.groupby('d_since_impl').apply(lambda x: ~(x['mode'] == 'color').any()) no_spell_days = no_spell[no_spell].index save_name = BASE_PATH_OUT / "Figure_2_SessionSummary" BASE_PATH_OUT.mkdir(parents=True, exist_ok=True) fig = plt.figure(2, figsize=(15, 6)) fig.clf() gs = fig.add_gridspec(3, hspace=0, height_ratios=[4,4,1]) axs = gs.subplots(sharex=True) no_spell_color = (0.93725, 0.92941, 0.96078) not_intell_color = (0.73725, 0.74118, 0.86275) intell_color = (0.45882, 0.41961, 0.69412) bax = axs[2] bax.bar(no_spell_days, 1, width=1, color=no_spell_color) bax.bar(spell_not_int_days, 1, width=1, color=not_intell_color) bax.bar(spell_int_days, 1, width=1, color=intell_color) bax.set_ylim([0, 1]) bax.set_xlabel('days since implantation') bax.set_yticks([]) bax = axs[0] bax.plot(n_per_day_intell.index, n_per_day_intell.values, color=intell_color, marker='o', linestyle='none') bax.plot(tx, smoothv, color=intell_color) bax.set_ylabel('number of characters') ax2 = axs[1] ax2.set_ylabel('characters per minute') # we already handled the x-label with ax1 ax2.plot(cpm_intell.index, cpm_intell.values, color=intell_color, marker='o', linestyle='none') ax2.plot(tx_cpm, smoothv_cpm, color=intell_color) ax2.tick_params(axis='y') fig.show() fig.savefig(save_name.with_suffix(".pdf")) logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>") fig.savefig(save_name.with_suffix(".svg")) logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>") fig.savefig(save_name.with_suffix(".eps")) logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>") ## Aggregate sessions. # First, for each day, how much time was spent with feedback, questions, exploration fbdesc = other_m_d[other_m_d['mode'] == 'feedback'].describe() qdesc = other_m_d[other_m_d['mode'] == 'question'].describe() exdesc = other_m_d[other_m_d['mode'] == 'exploration'].describe() # Sum up all speller sessions aggr_sp_d = sp_d.groupby(['d_since_impl'])[['duration_min']].sum() spdesc = aggr_sp_d.describe() # Only take speller sessions where the message could be understood at least partially. aggr_sp_int_d = sp_d[sp_d['intelligible'] > 0].groupby(['d_since_impl'])[['duration_min', 'n']].sum() int_spdesc = aggr_sp_int_d.describe() aggr_sum = aggr_sp_int_d.sum() aggr_sp_clear_d = sp_d[sp_d['intelligible'] > 1].groupby(['d_since_impl'])[['duration_min', 'n']].sum() clear_spdesc = aggr_sp_clear_d.describe() aggr_clear_sum = aggr_sp_clear_d.sum() ch_per_min = sp_d[(sp_d['mode'] == 'color') & (sp_d['intelligible'] == 2)]['ch_per_min'] desc_str = f""" This analysis covers visits on {n_days} days. On average, {fbdesc.loc['mean', 'duration_min']:4.1f} minutes were spent in feedback training (min/25%/50%/75%/max: {fbdesc.loc['min', 'duration_min']:4.1f}, {fbdesc.loc['25%', 'duration_min']:4.1f}, {fbdesc.loc['50%', 'duration_min']:4.1f}, {fbdesc.loc['75%', 'duration_min']:4.1f}, {fbdesc.loc['max', 'duration_min']:4.1f}). On {qdesc.loc['count', 'duration_min']:n} days, the question paradigm was performed. On average, {qdesc.loc['mean', 'duration_min']:4.1f} minutes were spent in the question paradigm (min/25%/50%/75%/max: {qdesc.loc['min', 'duration_min']:4.1f}, {qdesc.loc['25%', 'duration_min']:4.1f}, {qdesc.loc['50%', 'duration_min']:4.1f}, {qdesc.loc['75%', 'duration_min']:4.1f}, {qdesc.loc['max', 'duration_min']:4.1f}). On {exdesc.loc['count', 'duration_min']:n} days, the exploration paradigm was performed. On average, {exdesc.loc['mean', 'duration_min']:4.1f} minutes were spent in the exploration paradigm (min/25%/50%/75%/max: {exdesc.loc['min', 'duration_min']:4.1f}, {exdesc.loc['25%', 'duration_min']:4.1f}, {exdesc.loc['50%', 'duration_min']:4.1f}, {exdesc.loc['75%', 'duration_min']:4.1f}, {exdesc.loc['max', 'duration_min']:4.1f}). On {spdesc.loc['count', 'duration_min']:n} days, we attempted to use the speller. On {int_spdesc.loc['count', 'duration_min']:n} days, the patient used the speller to generate at least partially understandable output. On average, {int_spdesc.loc['mean', 'duration_min']:4.1f} minutes were spent spelling (min/25%/50%/75%/max: {int_spdesc.loc['min', 'duration_min']:4.1f}, {int_spdesc.loc['25%', 'duration_min']:4.1f}, {int_spdesc.loc['50%', 'duration_min']:4.1f}, {int_spdesc.loc['75%', 'duration_min']:4.1f}, {int_spdesc.loc['max', 'duration_min']:4.1f}). On average, the daily output was {int_spdesc.loc['mean', 'n']:4.1f} characters (min/25%/50%/75%/max: {int_spdesc.loc['min', 'n']:n}, {int_spdesc.loc['25%', 'n']:n}, {int_spdesc.loc['50%', 'n']:n}, {int_spdesc.loc['75%', 'n']:n}, {int_spdesc.loc['max', 'n']:n}). Overall, the patient's at least partially comprehensible utterances comprised {aggr_sum['n']:n} characters produced over {aggr_sum['duration_min']:n} minutes, corresponding to an grand average rate of {aggr_sum['n'] / aggr_sum['duration_min']:4.2f} characters per minute. On {clear_spdesc.loc['count', 'duration_min']:n} days, the patient used the speller to generate clearly understandable output. On average, {clear_spdesc.loc['mean', 'duration_min']:4.1f} minutes were spent spelling (min/25%/50%/75%/max: {clear_spdesc.loc['min', 'duration_min']:4.1f}, {clear_spdesc.loc['25%', 'duration_min']:4.1f}, {clear_spdesc.loc['50%', 'duration_min']:4.1f}, {clear_spdesc.loc['75%', 'duration_min']:4.1f}, {clear_spdesc.loc['max', 'duration_min']:4.1f}). On average, the daily output was {clear_spdesc.loc['mean', 'n']:4.1f} characters (min/25%/50%/75%/max: {clear_spdesc.loc['min', 'n']:n}, {clear_spdesc.loc['25%', 'n']:n}, {clear_spdesc.loc['50%', 'n']:n}, {clear_spdesc.loc['75%', 'n']:n}, {clear_spdesc.loc['max', 'n']:n}). Overall, the patient's clearly intelligible communications comprised {aggr_clear_sum['n']:n} characters produced over {aggr_clear_sum['duration_min']:n} minutes, corresponding to an grand average rate of {aggr_clear_sum['n'] / aggr_clear_sum['duration_min']:4.2f} characters per minute. Per-session spelling rate was min/median/max: {ch_per_min.min():.1f}/{ch_per_min.median():.1f}/{ch_per_min.max():.1f} characters per minute. On {len(d.groupby('d_since_impl')) - len(d[d['mode']=='color'].groupby('d_since_impl'))} days, use of the speller was not attempted because criterion was not reached or because of other circumstances. """ logger.info(desc_str) days_df = pd.concat([pd.DataFrame(index=spell_int_days, columns=['intell'], data='intelligible'), pd.DataFrame(index=spell_not_int_days, columns=['intell'], data='not_intelligible'), pd.DataFrame(index=no_spell_days, columns=['intell'], data='no_speller')]).sort_index() n_df = pd.concat([n_per_day_intell, cpm_intell], axis=1) session_summary_df = pd.concat([n_df, days_df], axis=1) session_summary_df.to_csv(save_name.with_suffix(".csv")) return d, fig, bax def get_fb_sessions_before_speller(d, ignore_sessions_before_fb_change=True): """ Finds neurofeedback sessions before speller :param d: DataFrame of sessions :param ignore_sessions_before_fb_change: if True (default), ignore sessions before logging scheme was changed. :return: DataFrame, list of indices, list of pairs of indices of NF blocks and corresponding speller blocks. """ if ignore_sessions_before_fb_change: color_ix = d[d['mode'].isin(['color']) & (d['start_dt'] >= FEEDBACK_CHANGE_DATE)].index fb_ix = d[(d['mode'] == 'feedback') & (d['start_dt'] >= FEEDBACK_CHANGE_DATE)].index else: color_ix = d[d['mode'].isin(['color'])].index fb_ix = d[(d['mode'] == 'feedback')].index fbix = np.unique([fb_ix[np.where(fb_ix < ci)[0][-1]] for ci in color_ix]) fb_ci_pairs = [(fb_ix[np.where(fb_ix < ci)[0][-1]], ci) for ci in color_ix] return d, fbix, fb_ci_pairs def generate_fb_session_list(d): CFG_RE = re.compile(r"^(\d+-\d+-\d+)/config_dump_(\d+_\d+_\d+)\.yaml$") DTA_RE = re.compile(r"^\d+-\d+-\d+/data_(\d+_\d+_\d+)\.bin$") (_, fbix, _) = get_fb_sessions_before_speller(d) fb_list_d = [] for fbi in fbix: s = d.loc[fbi] cfm = CFG_RE.match(d.loc[fbi].cfg) dtm = DTA_RE.match(d.loc[fbi].data) line_d = {'day': cfm[1], 'cfg_t': cfm[2], 'data_t': dtm[1]} fb_list_d.append(line_d) return fb_list_d def rand_jitter(arr): stdev = .005 * (max(arr) - min(arr)) return arr + np.random.randn(len(arr)) * stdev def speller_performance_by_nf(d): *_, fb_ci_pairs = get_fb_sessions_before_speller(d) # acc_n = list(map(lambda p: (d.iloc[p[0]]['acc'], d.iloc[p[1]]['n']), fb_ci_pairs)) acc_i_n = [(d.iloc[p[0]]['d_since_impl'], d.iloc[p[0]]['acc'], int(d.iloc[p[1]]['intelligible']), d.iloc[p[1]]['n']) for p in fb_ci_pairs] df = pd.DataFrame(acc_i_n, columns=['DSI', 'Acc', 'Int', 'N']) spc_acc_int = stats.spearmanr(df['Acc'], df['Int']) spc_acc_n = stats.spearmanr(df['Acc'], df['N']) res_text = f""" Correlation between Neurofeedback task accuracy and subsequent spelling. There were {len(df)} pairs of speller blocks and preceding neurofeedback blocks. The speller output was rated 0 for unintelligble, 1 for partially intelligible, 2 for intelligible. The Spearman correlation between Neurofeedback task accuracy and speller intelligibility was {spc_acc_int.correlation:4.3f} (p={spc_acc_int.pvalue:4.3e}). The Spearman correlation between Neurofeedback accuracy and number of letters spelled was {spc_acc_n.correlation:4.3f} (p={spc_acc_n.pvalue:4.3e}). """ logger.info(res_text) return res_text def plot_audio_feedback_tpfp(d): """ Finds all audio feedback sessions before speller. It then calculates true positive rates and false positive rates and plots each of these sessions in a scatter plot. """ (_, fbix, _) = get_fb_sessions_before_speller(d) BASE_PATH_OUT.mkdir(parents=True, exist_ok=True) save_name = BASE_PATH_OUT / "Figure_3B_TPFP" fig = plt.figure(2, figsize=(15, 6)) fig.clf() ax = fig.subplots() ax.scatter(rand_jitter(d.loc[fbix, 'fpr']), rand_jitter(d.loc[fbix, 'tpr'])) ax.plot([0, 1], [0, 1], 'k:') ax.set_ylabel('True Positive Rate') ax.set_xlabel('False Positive Rate') ax.set_aspect('equal') fig.savefig(save_name.with_suffix(".pdf")) logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>") fig.savefig(save_name.with_suffix(".eps")) logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>") fig.savefig(save_name.with_suffix(".svg")) logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>") ctsum = d.loc[fbix, 'ct'].sum() cond_sums = ctsum[ctsum.index != 'unclassified'].sum() n_trials = ctsum.sum().sum() n_timeout = ctsum[ctsum.index == 'unclassified'].sum().sum() n_correct = (ctsum.loc['up', 'up'] + ctsum.loc['down', 'down']) n_up_incorrect = ctsum.loc['down', 'up'] n_down_incorrect = ctsum.loc['up', 'down'] n_up = ctsum.sum()['up'] n_down = ctsum.sum()['down'] r_up_incorrect = n_up_incorrect / n_up r_down_incorrect = n_down_incorrect / n_down tpr = ctsum.loc['up', 'up'] / cond_sums['up'] fpr = ctsum.loc['up', 'down'] / cond_sums['down'] acc = n_correct / n_trials acc_thr = 0.8 fraction_less_than_80 = (d.loc[fbix, 'acc'] < acc_thr).sum() / len(fbix) acc_min = 100.0 * d.loc[fbix, 'acc'].min() acc_max = 100.0 * d.loc[fbix, 'acc'].max() res_str = f"""Contingency table (columns = conditions, rows = observations):\n{ctsum}\n There were {len(fbix)} sessions. In total, there were {n_trials} trials. The accuracy over all trials was {100.0 * acc:4.1f}% (n={n_correct}). There were {n_timeout} timeout trials ({100.0 * n_timeout / n_trials:4.1f}%). There were {n_up_incorrect} ({100.0 * r_up_incorrect:4.1f}%) incorrect 'up' trials and {n_down_incorrect} ({100.0 * r_down_incorrect:4.1f}%) incorrect 'down' trials.\n In the last feedback sessions before speller sessions, the median accuracy was {100.0 * d.loc[fbix, 'acc'].median():4.1f}%, the minimum was {acc_min:4.1f}%. In {100.0 * fraction_less_than_80:4.1f}% of the sessions, accuracy was below {100.0 * acc_thr:4.1f}%. """ logger.info(res_str) # Export as CSV d_filt = d.loc[fbix].copy() for i, row in d_filt.iterrows(): d_filt.at[i, 'up_up'] = row.ct.loc['up', 'up'] d_filt.at[i, 'up_down'] = row.ct.loc['down', 'up'] d_filt.at[i, 'up_unclassified'] = row.ct.loc['unclassified', 'up'] d_filt.at[i, 'down_up'] = row.ct.loc['up', 'down'] d_filt.at[i, 'down_down'] = row.ct.loc['down', 'down'] d_filt.at[i, 'down_unclassified'] = row.ct.loc['unclassified', 'down'] d_filt[ ['d_since_impl', 'start_dt', 'duration_s', 'channels', 'mode', 'data', 'tpr', 'fpr', 'acc', 'up_up', 'up_down', 'up_unclassified', 'down_up', 'down_down', 'down_unclassified']].to_csv(save_name.with_suffix(".csv")) def plot_fb_accuracy(d): """ Plots all feedback sessions' accuracy as function of day. This generates Supplementary Figure 2. """ from brokenaxes import brokenaxes from matplotlib.gridspec import GridSpec BASE_PATH_OUT.mkdir(parents=True, exist_ok=True) save_name = BASE_PATH_OUT / "Figure_S2_fb_acc" d_filt = d[d['mode'] == 'feedback'].copy() d_filt['ddiff'] = (d_filt['start_dt'] - IMPLANT_DATE) / pd.to_timedelta(1, unit='D') # accv = d_filt[['d_since_impl', 'acc', 'start_dt', 'ddiff']] n_fb_sessions = d_filt.groupby('d_since_impl').count()['acc'] n_fb_sessions.name = 'n_fb' # Get fb before speller (_, fbix, _) = get_fb_sessions_before_speller(d, ignore_sessions_before_fb_change=False) # d_nf = d.loc[fbix].copy() # d_nf['ddiff'] = (d_nf['start_dt'] - IMPLANT_DATE) / pd.to_timedelta(1, unit='D') d_filt['b4sp'] = 0 d_filt.loc[fbix, 'b4sp'] = 1 fig = plt.figure(22, figsize=(15, 8), constrained_layout=False) fig.clf() gs = GridSpec(ncols=2, nrows=2, figure=fig, hspace=.25, top=.9, bottom=.1) bax = fig.add_subplot(gs[0, :]) # bax = brokenaxes(xlims=x_ranges, hspace=.05, tilt=65, d=.005, subplot_spec=gs[0, :]) bax.plot(d_filt['ddiff'] - .5, d_filt['acc'], 'o', color=(.5, .5, .5), ms=2, label="NF Trial sets") accvmax = d_filt.groupby('d_since_impl')[['acc']].max() # bax.plot(accvmax.index, accvmax['acc'], '-', color=(0, 0, 0), lw=2, label="Daily maximum") d_b4sp = d_filt[d_filt['b4sp'] > 0] bax.plot(d_b4sp['ddiff'] - .5, d_b4sp['acc'], 'o', color=(1, 0, 0), ms=3, lw=2, label="NF before speller") bax.set_ylim(-0.02, 1.02) bax.set_xlabel("Days after implantation") bax.set_ylabel("Accuracy") bax.set_title("Accuracy of Feedback trials") bax.legend(loc="lower right") dsi = 113 d_sub = d[(d['d_since_impl'] == dsi) & (d['mode'] == 'feedback')] d_sub.reset_index(inplace=True) d_sub.index = d_sub.index + 1 f_ax1 = fig.add_subplot(gs[1, 0]) f_ax1.plot((d_sub['start_dt'] - d_sub.loc[1, 'start_dt']).astype('timedelta64[s]') / 60, d_sub['acc'], 'o-k') f_ax1.set_ylim(-0.05, 1.05) f_ax1.set_xlabel("Time since start [min]") f_ax1.set_ylabel("Accuracy") f_ax1.set_title(f"Feedback trial sets on day {dsi} post-implantation") dsi = 197 d_sub = d[(d['d_since_impl'] == dsi) & (d['mode'] == 'feedback')] d_sub.reset_index(inplace=True) d_sub.index = d_sub.index + 1 f_ax2 = fig.add_subplot(gs[1, 1]) f_ax2.plot((d_sub['start_dt'] - d_sub.loc[1, 'start_dt']).astype('timedelta64[s]') / 60, d_sub['acc'], 'o-k') f_ax2.set_ylim(-0.05, 1.05) f_ax2.set_xlabel("Time since start [min]") f_ax2.set_ylabel("Accuracy") f_ax2.set_title(f"Feedback trial sets on day {dsi} post-implantation") fig.savefig(save_name.with_suffix(".pdf")) logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>") fig.savefig(save_name.with_suffix(".eps")) logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>") fig.savefig(save_name.with_suffix(".svg")) logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>") thr = .9 descstr = f""" The best run per day had an accuracy of {thr:.1f} in {(accvmax['acc'] > thr).sum() / len(accvmax) * 100.0:.1f} % of days. Over the reported period, there were {n_fb_sessions.sum()} feedback sessions, min/median/maximum per day: {n_fb_sessions.min()} / {n_fb_sessions.median()} / {n_fb_sessions.max()}. Over all {d.loc[d['mode'] == 'feedback', 'acc'].count()} feedback sessions, the accuracy was {100*d.loc[d['mode'] == 'feedback', 'acc'].mean():.1f}%. """ logger.info(descstr) # Export as CSV d_filt.loc[:, 'channels1'] = None for i, row in d_filt.iterrows(): d_filt.at[i, 'up_up'] = row.ct.loc['up', 'up'] d_filt.at[i, 'up_down'] = row.ct.loc['down', 'up'] d_filt.at[i, 'up_unclassified'] = row.ct.loc['unclassified', 'up'] d_filt.at[i, 'down_up'] = row.ct.loc['up', 'down'] d_filt.at[i, 'down_down'] = row.ct.loc['down', 'down'] d_filt.at[i, 'down_unclassified'] = row.ct.loc['unclassified', 'down'] d_filt.at[i, 'channels1'] = ", ".join([str(x) for x in row.channels]) d_filt[ ['d_since_impl', 'start_dt', 'channels1', 'data', 'acc', 'b4sp', 'up_up', 'up_down', 'up_unclassified', 'down_up', 'down_down', 'down_unclassified']].to_csv(save_name.with_suffix(".csv")) return fig, bax, f_ax1, f_ax2, d_filt def plot_channels_used(d): """Plots for each speller block the channel that was used""" BASE_PATH_OUT.mkdir(parents=True, exist_ok=True) save_name = BASE_PATH_OUT / "Figure_S_channel_use" color_list = [(0.8, .8, .8), (.55, .55, .55), (.3, .3, .3), (0, 0, 0)] cmap = ListedColormap(color_list, len(color_list)) df_blocks = d[d.channels.notna() & (d['mode'] == 'color')] ch_list = list(df_blocks.channels) ch_list_merged = list(itertools.chain(*ch_list)) unique_channels = sorted(list(set(ch_list_merged))) ch_use_matrix = np.empty((len(unique_channels), len(df_blocks))) ch_use_matrix[:] = np.nan ch_count = np.zeros((len(unique_channels), len(df_blocks))) c_idx = {r: i for i, r in enumerate(unique_channels)} day_blocks = pd.DataFrame(index=df_blocks.d_since_impl.unique(), columns=unique_channels) day_blocks.fillna(0, inplace=True) b_d_s_i = df_blocks.d_since_impl b_d_s_i_change = (b_d_s_i.diff()).isna() | (b_d_s_i.diff() > 0) b_d_s_i_change = b_d_s_i_change.reset_index() tick_loc = b_d_s_i_change.index[b_d_s_i_change.d_since_impl].to_list() tick_lab = list(b_d_s_i[list(b_d_s_i_change.d_since_impl)]) j = 0 for i, r in df_blocks.iterrows(): col_val = len(r['channels']) # if r['mode'] == 'color': # pass # col_val += 5 for ch in r['channels']: ch_use_matrix[c_idx[ch], j] = col_val day_blocks.loc[r['d_since_impl'], ch] = 1 ch_count[c_idx[ch], j] = 1 j += 1 # This is sorting by number of blocks. For consistency, number of days is better (below) # sort_idx = np.argsort(ch_count.sum(1))[::-1] channel_days = day_blocks.sum() sort_idx = list(channel_days.argsort()[::-1]) sorted_channels = np.asarray(unique_channels)[sort_idx] channel_days.sort_values(ascending=False, inplace=True) # create array map with channel use counts array_map = ARRAY_MAPS['K01'] map_sma = array_map[0] amesh_sma = np.empty((np.max(map_sma['x']) + 1, np.max(map_sma['y']) + 1)) amesh_sma[:] = np.nan for index, el_ix in enumerate(map_sma['ix']): if el_ix + 1 in channel_days.index: amesh_sma[7 - map_sma['y'][index], map_sma['x'][index]] = channel_days[el_ix + 1] norm = mpl.colors.BoundaryNorm(np.arange(-.5+1, cmap.N+1), cmap.N) my_cmap = plt.get_cmap("plasma") rescale = lambda y: (y - np.min(y)) / (np.max(y) - np.min(y)) # plot number of channels used per block, over days fig = plt.figure(constrained_layout=True) gs = fig.add_gridspec(2, 3) # fig, axs = plt.subplots(2, 1, constrained_layout=True) ax = fig.add_subplot(gs[0, :-1]) imgplt = ax.pcolormesh(ch_use_matrix[sort_idx, :], cmap=cmap, norm=norm) ax.set_yticks(np.arange(0, len(unique_channels)) + .5) ax.set_yticklabels(sorted_channels) ax.set_ylabel('Channel ID') ax.set_xlabel('Days after implantation') ax.yaxis.set_inverted(True) ax.xaxis.set_ticks(tick_loc, minor=True) ax.set_xticks(tick_loc[::20]) ax.set_xticklabels([f"{x}" for x in tick_lab[::20]]) cbar = fig.colorbar(imgplt, ax=ax, location='right', ticks=range(1, cmap.N + 1)) # Plotting the histogram for channels over days ax = fig.add_subplot(gs[1, :-1]) ax.bar(range(len(channel_days)), channel_days, color=my_cmap(rescale(channel_days)), linewidth=.3, edgecolor='k')#, color='k') ax.set_xticks(range(len(channel_days))) ax.set_xticklabels(channel_days.index) ax.set_ylabel('used on number of days') ax.set_xlabel('Channel ID') ax = fig.add_subplot(gs[-1, -1]) rect = plt.Rectangle((0, 0), 8, 8, fill=True, facecolor=(.3, .3, .3), edgecolor=(0, 0, 0), alpha=0.2, zorder=-1) ax.add_patch(rect) ax.pcolormesh(amesh_sma, cmap=my_cmap) ax.set_xlim(-1, 9) ax.set_ylim(-1, 9) ax.set_aspect(1) ax.axis('off') fig.savefig(save_name.with_suffix(".pdf")) logger.info(f"Plot saved at <{save_name.with_suffix('.pdf')}>") fig.savefig(save_name.with_suffix(".eps")) logger.info(f"Plot saved at <{save_name.with_suffix('.eps')}>") fig.savefig(save_name.with_suffix(".svg")) logger.info(f"Plot saved at <{save_name.with_suffix('.svg')}>") # save as csv ch_exp = ch_use_matrix[sort_idx, :] ch_exp[ch_exp > 0] = 1 df_ch_exp = pd.DataFrame(data=ch_exp.T, columns=[f"ch_{c}" for c in sorted_channels], index=df_blocks.index) df_ch_exp['n_ch_used'] = df_ch_exp.sum(axis=1) df_ch_exp['d_since_impl'] = df_blocks.d_since_impl df_ch_exp.reset_index(drop=True, inplace=True) df_ch_exp[df_ch_exp.isna()] = 0 df_ch_exp = df_ch_exp.astype('int') print(f"Number of channels used in speller blocks:\n{df_ch_exp[['n_ch_used']].groupby(['n_ch_used']).apply(lambda x: len(x))}") df_ch_exp.to_csv(save_name.with_suffix('._panel_a.csv')) data_panel_b = pd.DataFrame.from_dict({'days_used': channel_days}) data_panel_b.to_csv(save_name.with_suffix('._panel_b.csv')) def calculate_ITR(d): # define ITR calculation helper def itr_helper(r): # number of selectable characters (26 letters, space, delete, question mark, end program) N = 30 # remove substring from rating corresponding to the phrase_start rate_str = r.rating[len(r.phrase_start):] # calculate rate of correct characters P = rate_str.count('T') / (rate_str.count('T') + rate_str.count('F')) # calculate ITR in bits / min return (np.log2(N) + P * np.log2(P) + (1 - P) * np.log2((1 - P + 1e-6) / (N - 1))) * r.ch_per_min # select all rows with non-empty ratings dfwr = d.loc[(~d.rating.isna()) ] itr = dfwr.apply(itr_helper, axis=1) res = f"""There were {len(itr)} rated speller sessions. The min/mean/median/max ITR was {itr.min():.3f} / {itr.mean():.3f} / {itr.median():.3f} / {itr.max():.3f} bit per minute.\n""" print(res) if __name__ == '__main__': pth = BASE_PATH/'KIAP_BCI_neurofeedback' df = prepare_sessions(pth=pth) df = add_feedback_info(df, pth=pth) df = add_channel_info(df, pth=pth) pths = BASE_PATH/'KIAP_BCI_speller' ds = prepare_sessions(pth=pths) ds = add_feedback_info(ds, pth=pths) ds = add_channel_info(ds, pth=pths) d = pd.concat((df, ds), ignore_index=True) d.sort_values('start_dt', inplace=True) d.reset_index(drop=True, inplace=True) add_annotation(d) plot_sessions(d) plot_audio_feedback_tpfp(d) plot_fb_accuracy(d) plot_channels_used(d) speller_performance_by_nf(d) calculate_ITR(d)