import matplotlib
matplotlib.use('TkAgg')
from datetime import datetime as dt
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
import helpers.sessions as hs
from pathlib import Path
import yaml
import munch
from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE, FEEDBACK_CHANGE_DATE, ARRAY_MAPS
import logging
from logging.handlers import TimedRotatingFileHandler
from helpers.tsdumper import TSDumper
from matplotlib.colors import ListedColormap
import matplotlib as mpl
import itertools
import scipy.stats as stats
from scipy.interpolate import UnivariateSpline

logger = logging.getLogger("KIAP")
logger.handlers.clear()
logger.setLevel(logging.DEBUG)

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

fh = TimedRotatingFileHandler(BASE_PATH_OUT / 'kiap_figures.log', when='D', backupCount=5)
fh.setFormatter(formatter)
fh.setLevel(logging.INFO)
logger.addHandler(fh)

def string_count_eq(str1, str2):
    """Count how many characters two strings share at the beginning."""
    n_match = 0
    for (c1, c2) in zip(str1, str2):
        if c1 == c2:
            n_match += 1
        else:
            break
    return n_match

def string_dist(str1, str2):
    """Helper function to calculate string distance.

    In our case, we want to know how many characters were added and removed between
    two strings: the one a speller session started with and the one it ended with.
    If the strings share n characters at the beginning, and beyond that the first has
    another m and the second another k characters, then m + k characters were changed."""
    return len(str1) + len(str2) - 2 * string_count_eq(str1, str2)

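# Worked example (hypothetical strings, not from the data): for
# string_dist('HELLO W', 'HELLO MY') the shared prefix is 'HELLO ' (6 characters),
# beyond which the first string has 1 character and the second has 2, so the
# distance is 7 + 8 - 2 * 6 = 3 characters added or removed.
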
T_RE = re.compile(r"^(\d+-\d+-\d+T\d+:\d+:\d+\.\d+)\s+color.*\('(.*)', '(.*)'\)$")

def extract_session_data(session_log_file_name):
    """
    Given a session log file name for a colour speller session, of the format info_XX_XX_XX.log,
    this function returns a dictionary with the following keys:
    phrase_start: speller started with this phrase
    phrase: speller ended with this phrase
    n: number of characters in (phrase - phrase_start)
    ch_per_min: characters spelled per minute
    """
    with open(session_log_file_name, 'r') as f:
        evs = f.read().splitlines()
    d = {'n': 0, 'ch_per_min': 0, 'phrase_start': '', 'phrase': ''}
    m01 = T_RE.match(evs[0])
    m02 = T_RE.match(evs[-1])
    start_dt = dt.strptime(m01.group(1), '%Y-%m-%dT%H:%M:%S.%f')
    end_dt = dt.strptime(m02.group(1), '%Y-%m-%dT%H:%M:%S.%f')
    duration_min = (end_dt - start_dt).total_seconds() / 60.0
    d['phrase_start'] = m01.group(2)
    d['phrase'] = m02.group(2)
    d['n'] = string_dist(d['phrase_start'], d['phrase'])
    if duration_min > 0:
        d['ch_per_min'] = d['n'] / duration_min
    return d

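# The first and last lines of a speller log are expected to match T_RE. Two
# hypothetical matching lines (the second tuple element stands in for whatever
# else the speller logs; only the first element, the current phrase, is used):
#   2019-08-01T10:00:00.000000 color ... ('HELLO', 'x')
#   2019-08-01T10:05:00.000000 color ... ('HELLO WORLD', 'x')
# These would give n = string_dist('HELLO', 'HELLO WORLD') = 6 characters over
# 5 minutes, i.e. ch_per_min = 1.2.
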
def precompile_sessions(pth=BASE_PATH, start=0, n=None, use_cache=True):
    """
    Finds recorded sessions at pth and loads information into a pandas DataFrame
    with columns:

    mode: mode of BCI session (feedback, color, exploration, question)
    cfg: relative path to the configuration file
    events: relative path to the event log file
    data: relative path to the binary data file
    log: relative path to the debug log file
    duration_s: duration of session in seconds (last data timestamp - first data timestamp)
    duration_min: duration of session in minutes
    start_dt: start of session as datetime, from time in config file
    end_dt: end of session as datetime
    d_since_impl: days since implantation

    The session data will be saved in a cache file and reloaded from there by default,
    as it takes some time to read each individual data file.
    This reading is necessary to get the most accurate estimate of a session's length.
    Since on some occasions the saving of data files was not ended automatically at the end
    of a session, the events file will also be read to determine the length.

    Params:
        pth: base path for data folders
        start: skip this many sessions at start
        n: read this many sessions
        use_cache: try to read cache file if True
    """
    cache_file_name = Path(pth, 'session_cache.pkl')
    if use_cache and cache_file_name.exists():
        logger.info(f"Using cache file {cache_file_name}")
        cfgs = pd.read_pickle(cache_file_name)
        return cfgs
    (cfgs, _) = hs.get_sessions(pth, start=start, n=n)
    n_total = len(cfgs)
    t_diffs = []
    start_dts = []
    for i, s in cfgs.iterrows():
        logger.debug(f"loading {i + 1} of {n_total}: {s['cfg']}")
        try:
            (ts, _, _, evts) = hs.get_session_data(pth, s)
            t_diff = evts[-1] - evts[0]
            logger.info(f"Loaded session of {t_diff} s.")
            t_diffs.append(t_diff)
        except FileNotFoundError:
            logger.warning(f"Data file not found for {s['cfg']}")
            t_diffs.append(0)
        # the session start time is encoded in the config file name
        cfgstr = s['cfg'].replace('/', '_').replace('\\', '_')
        start_dt = dt.strptime(cfgstr, '%Y-%m-%d_config_dump_%H_%M_%S.yaml')
        start_dts.append(start_dt)
    cfgs['duration_s'] = t_diffs
    cfgs['duration_min'] = cfgs['duration_s'] / 60.0
    cfgs['start_dt'] = start_dts
    cfgs['end_dt'] = cfgs['start_dt'] + pd.to_timedelta(cfgs['duration_s'], unit='seconds')
    cfgs['d_since_impl'] = cfgs['start_dt'].apply(lambda x: (x.to_pydatetime() - IMPLANT_DATE).days)
    cfgs.to_pickle(cache_file_name)
    return cfgs

def add_session_info(d):
    """
    Compute particular bits of information given data sessions.

    Params:
        d: DataFrame containing session information, as generated by precompile_sessions()

    Adds to DataFrame:
        cum_dur: cumulative session duration per day
        ecol: edge color for plots
        col: color for plots based on session mode

    Returns: DataFrame
    """
    d.set_index(['d_since_impl'], inplace=True)
    d["cum_dur"] = d.groupby(level='d_since_impl')['duration_min'].cumsum() - d['duration_min']
    d.reset_index(inplace=True)
    # add colour information to session entries
    d['ecol'] = [(1, 1, 1, 1)] * len(d)
    d['col'] = [(0.1, 0.1, 0.1)] * len(d)
    idx = d['mode'] == 'feedback'
    d.loc[idx, 'col'] = pd.Series([(0, 0.447, 0.741)] * idx.sum(), index=d.loc[idx].index)
    idx = d['mode'] == 'question'
    d.loc[idx, 'col'] = pd.Series([(0.494, 0.184, 0.556)] * idx.sum(), index=d.loc[idx].index)
    idx = d['mode'] == 'color'
    d.loc[idx, 'col'] = pd.Series([(0.85, 0.325, 0.098)] * idx.sum(), index=d.loc[idx].index)
    idx = d['mode'] == 'exploration'
    d.loc[idx, 'col'] = pd.Series([(0.301, 0.745, 0.933)] * idx.sum(), index=d.loc[idx].index)
    return d

def add_speller_data(d, pth=BASE_PATH):
    """
    Reads speller results for every speller session and adds that information to
    the DataFrame.

    Params:
        d: DataFrame containing session information, as generated by precompile_sessions()

    Adds to DataFrame:
        n: length of communication for speller block
        ch_per_min: number of characters spelled per minute
        phrase_start: phrase at beginning of speller block
        phrase: phrase at end of speller block
        cum_n: cumulative length of phrases per day

    Returns: DataFrame
    """
    d['n'] = pd.array([None] * len(d), dtype='Int32')
    d['ch_per_min'] = None
    d['phrase_start'] = None
    d['phrase'] = None
    for ix, rw in d.loc[d['mode'] == 'color'].iterrows():
        sp_dict = extract_session_data(Path(pth, rw['log']))
        spr = pd.Series(sp_dict, name=ix)
        d.loc[ix, spr.index] = spr
    d.set_index(['d_since_impl'], inplace=True)
    ix = d['mode'] == 'color'
    d.loc[ix, "cum_n"] = d.loc[ix].groupby(level='d_since_impl')['n'].cumsum() - d.loc[ix, 'n']
    d.reset_index(inplace=True)
    return d

def extract_trials_old(filename):
    """
    filename - path to an info_*.log file
    'old': log pattern before 20 July 2019 ('up' / 'down' trials, 'yes' response for
    correct, 'unclassified' for timeout)
    """
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    decpat = re.compile(r".*\sfeedback - Decoder decision: (\w*) - \('feedback', '(\w+)'\)$")
    # In the old scheme, 'yes' means the cued condition was matched, and
    # 'unclassified' (timeout) trials are counted as the opposite direction.
    imap = {'up': 'down', 'down': 'up'}
    conditions = []
    responses = []
    for ev in evs:
        m = decpat.match(ev)
        if m is not None:
            condition = m.group(2)
            response = m.group(1)
            if condition == 'baseline':
                continue
            if response == 'yes':
                response = condition
            elif response == 'unclassified':
                response = imap.get(condition, 'unclassified')
            conditions.append(condition)
            responses.append(response)
    all_condition = pd.Categorical(conditions, categories=['up', 'down'])
    all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified'])
    ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response'])
    return ct

def extract_trials_new(filename):
    """
    filename - path to an info_*.log file
    'new': log pattern on and after 20 July 2019 ('up' / 'down' trials, 'yes' / 'no' / 'unclassified' response)
    """
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    decpat = re.compile(r".*\sfeedback - Decoder decision: (\w*) - \('feedback', '(\w+)'\)$")
    conditions = []
    responses = []
    for ev in evs:
        m = decpat.match(ev)
        if m is not None:
            condition = m.group(2)
            response = m.group(1)
            if response == 'yes':
                response = 'up'
            elif response == 'no':
                response = 'down'
            conditions.append(condition)
            responses.append(response)
    all_condition = pd.Categorical(conditions, categories=['up', 'down'])
    all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified'])
    ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response'])
    return ct

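# A hypothetical feedback log line in the format decpat expects (the actual
# logs follow this pattern; timestamp and values here are made up):
#   2019-08-01T10:00:00.000000 feedback - Decoder decision: yes - ('feedback', 'up')
# would be tallied by extract_trials_new as condition 'up', response 'up'.
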
def extract_trials_q(filename):
    """
    filename - path to an info_*.log file
    log pattern: 2019-10-02T21:20:05.053678 question - Decoder decision: no - ('No question', '002_11038.wav')
    """
    with open(filename, 'r') as f:
        evs = f.read().splitlines()
    decpat = re.compile(r".*\squestion - Decoder decision: (\w*) - \('(Yes|No) question', '(.+)'\)$")
    conditions = []
    responses = []
    for ev in evs:
        m = decpat.match(ev)
        if m is not None:
            condition = m.group(2)
            response = m.group(1)
            if response == 'yes':
                response = 'up'
            elif response == 'no':
                response = 'down'
            if condition == 'Yes':
                condition = 'up'
            elif condition == 'No':
                condition = 'down'
            conditions.append(condition)
            responses.append(response)
    all_condition = pd.Categorical(conditions, categories=['up', 'down'])
    all_response = pd.Categorical(responses, categories=['up', 'down', 'unclassified'])
    ct = pd.crosstab(all_response, all_condition, dropna=False, colnames=['condition'], rownames=['response'])
    return ct

def add_feedback_info(d, pth=BASE_PATH):
    """
    For all feedback sessions, read the log file and save the contingency table,
    along with true positive rate, false positive rate, accuracy, and trial count.
    """
    if 'ct' not in d.columns:
        d['ct'] = [None for _ in range(len(d))]
    for ix, rw in d.iterrows():
        if rw['mode'] != 'feedback':
            continue
        if FEEDBACK_CHANGE_DATE > rw.start_dt:
            ct = extract_trials_old(pth / rw.log)
        else:
            ct = extract_trials_new(pth / rw.log)
        cond_sums = ct[ct.index != 'unclassified'].sum()
        tpr = ct.loc['up', 'up'] / cond_sums['up']
        fpr = ct.loc['up', 'down'] / cond_sums['down']
        acc = (ct.loc['up', 'up'] + ct.loc['down', 'down']) / (cond_sums['up'] + cond_sums['down'])
        d.at[ix, 'ct'] = ct
        d.at[ix, 'tpr'] = tpr
        d.at[ix, 'fpr'] = fpr
        d.at[ix, 'acc'] = acc
        d.at[ix, 'n_trials'] = ct.sum().sum()
    return d

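# Orientation of the contingency tables, with a made-up example: rows are the
# decoded response, columns the cued condition, so for
#
#               condition:  up  down
#   response up              8     1
#   response down            1     7
#   response unclassified    1     2
#
# cond_sums (classified trials only) is up=9, down=8, giving tpr = 8/9,
# fpr = 1/8, and acc = (8 + 7) / (9 + 8).
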
def calculate_feedback_before_speller(d):
    """For each day, find feedback sessions before the first speller session and add up their trials."""
    day_n_fb = []
    for di in d['d_since_impl'].unique():
        d_day = d[d['d_since_impl'] == di]
        ix = d_day[d_day['mode'] == 'color'].first_valid_index()
        if ix is None:
            continue
        # sum the contingency tables of all feedback blocks preceding the first
        # speller block, then sum over both axes to get the trial count
        n_fb = d_day[(d_day['mode'] == 'feedback') & (d_day.index < ix)]['ct'].sum().sum().sum()
        day_n_fb.append((di, n_fb))
    return day_n_fb

def add_question_info(d, pth=BASE_PATH):
    """
    For all question sessions, read the log file and save the contingency table.
    """
    if 'ct' not in d.columns:
        d['ct'] = [None for _ in range(len(d))]
    for ix, rw in d.iterrows():
        if rw['mode'] != 'question':
            continue
        ct = extract_trials_q(pth / rw.log)
        cond_sums = ct[ct.index != 'unclassified'].sum()
        tpr = ct.loc['up', 'up'] / cond_sums['up']
        fpr = ct.loc['up', 'down'] / cond_sums['down']
        acc = (ct.loc['up', 'up'] + ct.loc['down', 'down']) / (cond_sums['up'] + cond_sums['down'])
        d.at[ix, 'ct'] = ct
        d.at[ix, 'tpr'] = tpr
        d.at[ix, 'fpr'] = fpr
        d.at[ix, 'acc'] = acc
    return d

def prepare_for_annotation_export(d):
    """
    Prepare a yaml file for annotation of speller sessions.

    Will be written to [BASE_PATH_OUT]/records_for_annotation.yml
    """
    d2 = d.reset_index().set_index('start_dt', drop=False)
    d2 = d2.loc[d2['mode'] == 'color', ['data', 'd_since_impl', 'start_dt', 'phrase_start', 'phrase']]
    d2['intelligible'] = None
    d2['start'] = d2['start_dt'].map(lambda x: x.strftime('%Y-%m-%d %H:%M:%S'))
    d2 = d2[['data', 'start', 'd_since_impl', 'phrase_start', 'phrase', 'intelligible']]
    d_dict = d2.to_dict(orient='index')
    with open(BASE_PATH_OUT / 'records_for_annotation.yml', 'w') as f:
        f.write("""# Rating scheme for intelligibility of patient's communications.
#
# Instructions for filling this file:
#
# We consider the result of one total session / speller run.
#
# 1. Find session in log by date of data file (also given in plain text)
# 2. Look up speller start (key 'phrase_start') and final output (key 'phrase')
#    and look up speller output / session remarks.
# 3. Rate speller output.
#    For copy spelling sessions:
#    0 – completely wrong
#    1 – up to 20% of characters wrong or missing
#    2 – no mistake
#    For free spelling:
#    0 – incomprehensible speller output
#    1 – partially understandable, but with doubts due to spelling mistakes
#    2 – unambiguous to family / experimenter (even if single letters are
#        wrong or missing; even if words are incomplete)
#    Where one session's output could be counted into several categories,
#    category 1 is likely appropriate.
# 4. Find the record for the session in the list below and replace the 'null' entry
#    under the 'intelligible' key with your rating.
""")
        yaml.dump(d_dict, f, default_flow_style=False, Dumper=TSDumper, sort_keys=False)

def add_annotation(d):
    """
    Read annotations and add them to the DataFrame.

    Params:
        d: DataFrame containing session information, as generated by precompile_sessions()

    Adds to DataFrame:
        intelligible: rating of intelligibility of a spelled session
        rating: per-character rating string for a spelled session
        col: updated color based on intelligibility

    Returns: DataFrame
    """
    # with open(Path('annotations', 'speller', 'records_for_annotation_full.yml'), 'r') as f:
    with open(Path('annotations', 'speller', 'records_for_annotation_consolidated.yml'), 'r') as f:
        anno = munch.Munch.fromYAML(f, Loader=yaml.Loader)
    adict = pd.DataFrame.from_dict(anno, orient='index')
    d.reset_index(inplace=True)
    d.set_index('start_dt', drop=False, inplace=True)
    d.loc[adict.index, 'intelligible'] = adict['intelligible']
    d.loc[adict.index, 'rating'] = adict['rating']
    # assign one RGB tuple per selected row; building a Series keeps pandas
    # from trying to broadcast the tuple across the selection
    sel = (d['intelligible'] == 2) & (d['mode'] == 'color')
    d.loc[sel, 'col'] = pd.Series([(0.318, 0.039, 0.090)] * sel.sum(), index=d.loc[sel].index)
    sel = (d['intelligible'] == 1) & (d['mode'] == 'color')
    d.loc[sel, 'col'] = pd.Series([(0.635, 0.078, 0.184)] * sel.sum(), index=d.loc[sel].index)
    sel = (d['intelligible'] == 0) & (d['mode'] == 'color')
    d.loc[sel, 'col'] = pd.Series([(1.000, 0.737, 0.843)] * sel.sum(), index=d.loc[sel].index)
    d.loc[sel, 'ecol'] = pd.Series([(.2, .2, .2)] * sel.sum(), index=d.loc[sel].index)
    d.set_index('index', inplace=True)
    return d

def add_channel_info(d, pth=BASE_PATH):
    """
    For all sessions, load the config files and extract information about the channels
    used in each session. Channels are numbered 1-based, corresponding to Blackrock's
    numbering scheme.
    """
    d.loc[:, 'channels'] = None
    d.loc[:, 'use_all'] = None
    d.loc[:, 'submode'] = None
    for ix, rw in d.iterrows():
        with open(Path(pth, rw.cfg), 'r') as f:
            cfg = munch.Munch.fromYAML(f, Loader=yaml.Loader)
        # use .at to store a list in a single cell without index alignment
        d.at[ix, 'channels'] = [ch.id + 1 for ch in cfg.daq.normalization.channels]
        d.at[ix, 'use_all'] = cfg.daq.normalization.use_all_channels
        # if the paradigms key does not exist in the config file, load the paradigm file
        if cfg.get('paradigms') is None:
            p_fn = rw.cfg.replace('config_dump', 'paradigm')
            with open(Path(pth, p_fn), 'r') as f:
                cfg.paradigms = munch.Munch.fromYAML(f, Loader=yaml.Loader)
        try:
            if rw['mode'] == 'question':
                d.loc[ix, 'submode'] = cfg.paradigms.question.mode[cfg.paradigms.question.selected_mode]
            elif rw['mode'] == 'color':
                d.loc[ix, 'submode'] = cfg.paradigms.color.mode[cfg.paradigms.color.selected_mode]
            elif rw['mode'] == 'feedback':
                d.loc[ix, 'submode'] = cfg.paradigms.feedback.mode[cfg.paradigms.feedback.selected_mode]
        except Exception:
            logger.warning(f"Exception reading mode for row {ix}", exc_info=True)
        logger.debug(f"Loading normalization info for session {ix} ({d.loc[ix, 'mode']}, {d.loc[ix, 'submode']}):"
                     f" {d.loc[ix, 'channels']}")
    return d

def get_indexer_and_color(d, mode, intelligible=None):
    """
    Given a session DataFrame, the session mode, and optionally an intelligibility rating,
    return an indexer for rows matching that criterion, plus colour, edge colour, and
    label for plotting.

    Params:
        d: session DataFrame
        mode: one of 'color', 'feedback', 'exploration', 'question'
        intelligible: if mode == 'color', this should be 0, 1, or 2
    """
    ix = d['mode'] == mode
    ecol = (1, 1, 1, 1)
    col = (.1, .1, .1)
    label = ""
    if mode == 'color':
        if intelligible is None:
            ix = ix & pd.isna(d['intelligible'])
        else:
            ix = ix & (d['intelligible'] == intelligible)
        label = "Speller"
        if intelligible == 2:
            col = (0.635, 0.078, 0.184)
            label = "Speller clear"
        elif intelligible == 1:
            col = (0.85, 0.325, 0.098)
            label = "Speller ambiguous"
        elif intelligible == 0:
            col = (0.929, 0.694, 0.125)
            label = "Speller unintelligible"
        else:
            col = (0.929, 0.894, 0.825)
            label = "Speller not rated"
    elif mode == 'feedback':
        col = (0.466, 0.674, 0.188)
        label = "Feedback Training"
    elif mode == 'exploration':
        col = (0, 0.447, 0.741)
        label = "Exploration"
    elif mode == 'question':
        col = (0.301, 0.745, 0.933)
        label = "Questions"
    return dict(ix=ix, col=col, ec=ecol, label=label)

def plot_bars_for_sel(ax, d, mode, intelligible=None):
    """
    Plot bars into an axis.

    Params:
        ax: axis to plot into
        d: pandas DataFrame containing session info
        mode: BCI mode, one of 'color', 'feedback', 'exploration', 'question'
        intelligible: if mode == 'color', this should be 0, 1, or 2
    """
    ic = get_indexer_and_color(d, mode, intelligible)
    return ax.bar(x=d.loc[ic['ix'], 'd_since_impl'], height=d.loc[ic['ix'], 'duration_min'], color=ic['col'],
                  ec=ic['ec'], label=ic['label'], bottom=d.loc[ic['ix'], 'cum_dur_min'], lw=.5, width=.8)

def prepare_sessions(pth=BASE_PATH):
    """Combine a few steps to load session data."""
    d = precompile_sessions(pth=pth)
    add_session_info(d)
    add_speller_data(d, pth=pth)
    return d

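# Minimal usage sketch (assuming the folder layout the helpers expect under
# BASE_PATH; the __main__ block below runs the full pipeline over both the
# neurofeedback and speller data trees):
#     d = prepare_sessions()
#     d = add_feedback_info(d)
#     d = add_channel_info(d)
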
def plot_sessions(d):
    """Plot the session summary and print a textual summary. This generates Figure 2."""
    n_days = len(d['d_since_impl'].unique())
    # day ranges previously used for broken x-axes:
    # x_ranges = ((105, 126), (146, 470))  # ((105, 126), (146, 163), (174, 212), (223, 227), (238, 465))
    d2 = d.set_index(['d_since_impl'])
    other_m_d = d2[d2['mode'] != 'color'].groupby(['d_since_impl', 'mode'])[['duration_min']].sum()
    other_m_d.reset_index('mode', inplace=True)
    sp_d = d2[d2['mode'] == 'color'][
        ['start_dt', 'mode', 'duration_min', 'cfg', 'events', 'data', 'log', 'phrase', 'intelligible', 'n',
         'ch_per_min']]
    df_dur_plot = pd.concat([other_m_d, sp_d])
    df_dur_plot.sort_index(inplace=True, kind='mergesort')
    df_dur_plot['cum_dur_min'] = df_dur_plot.groupby(level=0)['duration_min'].cumsum() - df_dur_plot['duration_min']
    df_dur_plot.reset_index(inplace=True)
    # number of letters spelled in the 'intelligible' == 2 group, per day
    n_per_day_intell = d[(d['mode'] == 'color') & (d['intelligible'] == 2)].groupby('d_since_impl')['n'].sum()
    n_per_day_intell.name = 'n_per_day'
    cpm_intell = d[(d['mode'] == 'color') & (d['intelligible'] == 2)].groupby('d_since_impl').apply(
        lambda x: x['n'].sum() / (x['duration_s'].sum() / 60.0))
    cpm_intell.name = 'cpm_per_day'
    # smooth both daily series with a cubic spline for display
    spl = UnivariateSpline(n_per_day_intell.index, n_per_day_intell.values, s=850000, k=3)
    tx = np.arange(n_per_day_intell.index[0], n_per_day_intell.index[-1])
    smoothv = spl(tx)
    spl_cpm = UnivariateSpline(cpm_intell.index, cpm_intell.values, s=850000, k=3)
    tx_cpm = np.arange(cpm_intell.index[0], cpm_intell.index[-1])
    smoothv_cpm = spl_cpm(tx_cpm)
    # classify each day: clearly intelligible spelling, spelling without clear
    # intelligibility, or no spelling attempted
    spell_int = d.groupby('d_since_impl').apply(lambda x: ((x['mode'] == 'color') & (x['intelligible'] >= 2)).any())
    spell_int_days = spell_int[spell_int].index
    spell_not_int = d.groupby('d_since_impl').apply(
        lambda x: ((x['mode'] == 'color') & (x['intelligible'] < 2)).any() & ~(
            (x['mode'] == 'color') & (x['intelligible'] >= 2)).any())
    spell_not_int_days = spell_not_int[spell_not_int].index
    no_spell = d.groupby('d_since_impl').apply(lambda x: ~(x['mode'] == 'color').any())
    no_spell_days = no_spell[no_spell].index
    save_name = BASE_PATH_OUT / "Figure_2_SessionSummary"
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(2, figsize=(15, 6))
    fig.clf()
    gs = fig.add_gridspec(3, hspace=0, height_ratios=[4, 4, 1])
    axs = gs.subplots(sharex=True)
    no_spell_color = (0.93725, 0.92941, 0.96078)
    not_intell_color = (0.73725, 0.74118, 0.86275)
    intell_color = (0.45882, 0.41961, 0.69412)
    bax = axs[2]
    bax.bar(no_spell_days, 1, width=1, color=no_spell_color)
    bax.bar(spell_not_int_days, 1, width=1, color=not_intell_color)
    bax.bar(spell_int_days, 1, width=1, color=intell_color)
    bax.set_ylim([0, 1])
    bax.set_xlabel('days since implantation')
    bax.set_yticks([])
    bax = axs[0]
    bax.plot(n_per_day_intell.index, n_per_day_intell.values, color=intell_color, marker='o', linestyle='none')
    bax.plot(tx, smoothv, color=intell_color)
    bax.set_ylabel('number of characters')
    ax2 = axs[1]
    ax2.set_ylabel('characters per minute')
    ax2.plot(cpm_intell.index, cpm_intell.values, color=intell_color, marker='o', linestyle='none')
    ax2.plot(tx_cpm, smoothv_cpm, color=intell_color)
    ax2.tick_params(axis='y')
    fig.show()
    for suffix in ('.pdf', '.svg', '.eps'):
        fig.savefig(save_name.with_suffix(suffix))
        logger.info(f"Plot saved at <{save_name.with_suffix(suffix)}>")
    ## Aggregate sessions.
    # First, for each day, how much time was spent with feedback, questions, exploration
    fbdesc = other_m_d[other_m_d['mode'] == 'feedback'].describe()
    qdesc = other_m_d[other_m_d['mode'] == 'question'].describe()
    exdesc = other_m_d[other_m_d['mode'] == 'exploration'].describe()
    # Sum up all speller sessions
    aggr_sp_d = sp_d.groupby(['d_since_impl'])[['duration_min']].sum()
    spdesc = aggr_sp_d.describe()
    # Only take speller sessions where the message could be understood at least partially.
    aggr_sp_int_d = sp_d[sp_d['intelligible'] > 0].groupby(['d_since_impl'])[['duration_min', 'n']].sum()
    int_spdesc = aggr_sp_int_d.describe()
    aggr_sum = aggr_sp_int_d.sum()
    aggr_sp_clear_d = sp_d[sp_d['intelligible'] > 1].groupby(['d_since_impl'])[['duration_min', 'n']].sum()
    clear_spdesc = aggr_sp_clear_d.describe()
    aggr_clear_sum = aggr_sp_clear_d.sum()
    ch_per_min = sp_d[(sp_d['mode'] == 'color') & (sp_d['intelligible'] == 2)]['ch_per_min']
    desc_str = f"""
This analysis covers visits on {n_days} days. On average, {fbdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent in feedback training (min/25%/50%/75%/max: {fbdesc.loc['min', 'duration_min']:4.1f}, {fbdesc.loc['25%', 'duration_min']:4.1f}, {fbdesc.loc['50%', 'duration_min']:4.1f}, {fbdesc.loc['75%', 'duration_min']:4.1f}, {fbdesc.loc['max', 'duration_min']:4.1f}).

On {qdesc.loc['count', 'duration_min']:n} days, the question paradigm was performed. On average, {qdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent in the question paradigm (min/25%/50%/75%/max: {qdesc.loc['min', 'duration_min']:4.1f}, {qdesc.loc['25%', 'duration_min']:4.1f}, {qdesc.loc['50%', 'duration_min']:4.1f}, {qdesc.loc['75%', 'duration_min']:4.1f}, {qdesc.loc['max', 'duration_min']:4.1f}).

On {exdesc.loc['count', 'duration_min']:n} days, the exploration paradigm was performed. On average, {exdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent in the exploration paradigm (min/25%/50%/75%/max: {exdesc.loc['min', 'duration_min']:4.1f}, {exdesc.loc['25%', 'duration_min']:4.1f}, {exdesc.loc['50%', 'duration_min']:4.1f}, {exdesc.loc['75%', 'duration_min']:4.1f}, {exdesc.loc['max', 'duration_min']:4.1f}).

On {spdesc.loc['count', 'duration_min']:n} days, we attempted to use the speller.

On {int_spdesc.loc['count', 'duration_min']:n} days, the patient used the speller to generate at least partially understandable output. On average, {int_spdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent spelling (min/25%/50%/75%/max: {int_spdesc.loc['min', 'duration_min']:4.1f}, {int_spdesc.loc['25%', 'duration_min']:4.1f}, {int_spdesc.loc['50%', 'duration_min']:4.1f}, {int_spdesc.loc['75%', 'duration_min']:4.1f}, {int_spdesc.loc['max', 'duration_min']:4.1f}).
On average, the daily output was {int_spdesc.loc['mean', 'n']:4.1f} characters (min/25%/50%/75%/max: {int_spdesc.loc['min', 'n']:n}, {int_spdesc.loc['25%', 'n']:n}, {int_spdesc.loc['50%', 'n']:n}, {int_spdesc.loc['75%', 'n']:n}, {int_spdesc.loc['max', 'n']:n}).
Overall, the patient's at least partially comprehensible utterances comprised {aggr_sum['n']:n} characters
produced over {aggr_sum['duration_min']:n} minutes, corresponding to a grand average rate of {aggr_sum['n'] / aggr_sum['duration_min']:4.2f} characters per minute.

On {clear_spdesc.loc['count', 'duration_min']:n} days, the patient used the speller to generate clearly understandable output. On average, {clear_spdesc.loc['mean', 'duration_min']:4.1f} minutes
were spent spelling (min/25%/50%/75%/max: {clear_spdesc.loc['min', 'duration_min']:4.1f}, {clear_spdesc.loc['25%', 'duration_min']:4.1f}, {clear_spdesc.loc['50%', 'duration_min']:4.1f}, {clear_spdesc.loc['75%', 'duration_min']:4.1f}, {clear_spdesc.loc['max', 'duration_min']:4.1f}).
On average, the daily output was {clear_spdesc.loc['mean', 'n']:4.1f} characters (min/25%/50%/75%/max: {clear_spdesc.loc['min', 'n']:n}, {clear_spdesc.loc['25%', 'n']:n}, {clear_spdesc.loc['50%', 'n']:n}, {clear_spdesc.loc['75%', 'n']:n}, {clear_spdesc.loc['max', 'n']:n}).
Overall, the patient's clearly intelligible communications comprised {aggr_clear_sum['n']:n} characters
produced over {aggr_clear_sum['duration_min']:n} minutes, corresponding to a grand average rate of {aggr_clear_sum['n'] / aggr_clear_sum['duration_min']:4.2f} characters per minute.

Per-session spelling rate was min/median/max: {ch_per_min.min():.1f}/{ch_per_min.median():.1f}/{ch_per_min.max():.1f} characters per minute.

On {len(d.groupby('d_since_impl')) - len(d[d['mode'] == 'color'].groupby('d_since_impl'))} days, use of the speller was not attempted because the criterion was not reached or because of other circumstances.
"""
    logger.info(desc_str)
    days_df = pd.concat([pd.DataFrame(index=spell_int_days, columns=['intell'], data='intelligible'),
                         pd.DataFrame(index=spell_not_int_days, columns=['intell'], data='not_intelligible'),
                         pd.DataFrame(index=no_spell_days, columns=['intell'], data='no_speller')]).sort_index()
    n_df = pd.concat([n_per_day_intell, cpm_intell], axis=1)
    session_summary_df = pd.concat([n_df, days_df], axis=1)
    session_summary_df.to_csv(save_name.with_suffix(".csv"))
    return d, fig, bax

def get_fb_sessions_before_speller(d, ignore_sessions_before_fb_change=True):
    """
    Finds the neurofeedback session immediately preceding each speller session.

    :param d: DataFrame of sessions
    :param ignore_sessions_before_fb_change: if True (default), ignore sessions before the logging scheme was changed.
    :return: DataFrame, array of unique NF block indices, list of (NF block index, speller block index) pairs.
    """
    if ignore_sessions_before_fb_change:
        color_ix = d[d['mode'].isin(['color']) & (d['start_dt'] >= FEEDBACK_CHANGE_DATE)].index
        fb_ix = d[(d['mode'] == 'feedback') & (d['start_dt'] >= FEEDBACK_CHANGE_DATE)].index
    else:
        color_ix = d[d['mode'].isin(['color'])].index
        fb_ix = d[d['mode'] == 'feedback'].index
    # for each speller block, take the last feedback block that precedes it
    fbix = np.unique([fb_ix[np.where(fb_ix < ci)[0][-1]] for ci in color_ix])
    fb_ci_pairs = [(fb_ix[np.where(fb_ix < ci)[0][-1]], ci) for ci in color_ix]
    return d, fbix, fb_ci_pairs

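# Sketch of the return values (made-up row indices): if feedback blocks sit at
# rows 3 and 7 and speller blocks at rows 5 and 9, fbix would be array([3, 7])
# and fb_ci_pairs would be [(3, 5), (7, 9)].
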
def generate_fb_session_list(d):
    """Generate a list of day / config / data timestamps for feedback sessions preceding speller blocks."""
    CFG_RE = re.compile(r"^(\d+-\d+-\d+)/config_dump_(\d+_\d+_\d+)\.yaml$")
    DTA_RE = re.compile(r"^\d+-\d+-\d+/data_(\d+_\d+_\d+)\.bin$")
    (_, fbix, _) = get_fb_sessions_before_speller(d)
    fb_list_d = []
    for fbi in fbix:
        s = d.loc[fbi]
        cfm = CFG_RE.match(s.cfg)
        dtm = DTA_RE.match(s.data)
        line_d = {'day': cfm[1], 'cfg_t': cfm[2], 'data_t': dtm[1]}
        fb_list_d.append(line_d)
    return fb_list_d

def rand_jitter(arr):
    """Add a small amount of Gaussian jitter, scaled to the data range, for scatter plots."""
    stdev = .005 * (max(arr) - min(arr))
    return arr + np.random.randn(len(arr)) * stdev

def speller_performance_by_nf(d):
    """Correlate neurofeedback accuracy with intelligibility and length of the following speller block."""
    *_, fb_ci_pairs = get_fb_sessions_before_speller(d)
    acc_i_n = [(d.iloc[p[0]]['d_since_impl'], d.iloc[p[0]]['acc'], int(d.iloc[p[1]]['intelligible']),
                d.iloc[p[1]]['n']) for p in fb_ci_pairs]
    df = pd.DataFrame(acc_i_n, columns=['DSI', 'Acc', 'Int', 'N'])
    spc_acc_int = stats.spearmanr(df['Acc'], df['Int'])
    spc_acc_n = stats.spearmanr(df['Acc'], df['N'])
    res_text = f"""
Correlation between neurofeedback task accuracy and subsequent spelling.
There were {len(df)} pairs of speller blocks and preceding neurofeedback blocks. The speller output was rated
0 for unintelligible, 1 for partially intelligible, 2 for intelligible. The Spearman correlation between neurofeedback
task accuracy and speller intelligibility was {spc_acc_int.correlation:4.3f} (p={spc_acc_int.pvalue:4.3e}).
The Spearman correlation between neurofeedback accuracy and the number of letters spelled was
{spc_acc_n.correlation:4.3f} (p={spc_acc_n.pvalue:4.3e}).
"""
    logger.info(res_text)
    return res_text

def plot_audio_feedback_tpfp(d):
    """
    Finds all audio feedback sessions directly preceding speller sessions,
    calculates true and false positive rates, and plots each of these sessions
    in a scatter plot. This generates Figure 3B.
    """
    (_, fbix, _) = get_fb_sessions_before_speller(d)
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    save_name = BASE_PATH_OUT / "Figure_3B_TPFP"
    fig = plt.figure(2, figsize=(15, 6))
    fig.clf()
    ax = fig.subplots()
    ax.scatter(rand_jitter(d.loc[fbix, 'fpr']), rand_jitter(d.loc[fbix, 'tpr']))
    ax.plot([0, 1], [0, 1], 'k:')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlabel('False Positive Rate')
    ax.set_aspect('equal')
    for suffix in ('.pdf', '.eps', '.svg'):
        fig.savefig(save_name.with_suffix(suffix))
        logger.info(f"Plot saved at <{save_name.with_suffix(suffix)}>")
    # aggregate statistics over the summed contingency tables
    ctsum = d.loc[fbix, 'ct'].sum()
    cond_sums = ctsum[ctsum.index != 'unclassified'].sum()
    n_trials = ctsum.sum().sum()
    n_timeout = ctsum[ctsum.index == 'unclassified'].sum().sum()
    n_correct = ctsum.loc['up', 'up'] + ctsum.loc['down', 'down']
    n_up_incorrect = ctsum.loc['down', 'up']
    n_down_incorrect = ctsum.loc['up', 'down']
    n_up = ctsum.sum()['up']
    n_down = ctsum.sum()['down']
    r_up_incorrect = n_up_incorrect / n_up
    r_down_incorrect = n_down_incorrect / n_down
    tpr = ctsum.loc['up', 'up'] / cond_sums['up']
    fpr = ctsum.loc['up', 'down'] / cond_sums['down']
    acc = n_correct / n_trials
    acc_thr = 0.8
    fraction_less_than_80 = (d.loc[fbix, 'acc'] < acc_thr).sum() / len(fbix)
    acc_min = 100.0 * d.loc[fbix, 'acc'].min()
    acc_max = 100.0 * d.loc[fbix, 'acc'].max()
    res_str = f"""Contingency table (columns = conditions, rows = observations):\n{ctsum}\n
There were {len(fbix)} sessions. In total, there were {n_trials} trials.
The accuracy over all trials was {100.0 * acc:4.1f}% (n={n_correct}). There were {n_timeout} timeout trials ({100.0 * n_timeout / n_trials:4.1f}%).
There were {n_up_incorrect} ({100.0 * r_up_incorrect:4.1f}%) incorrect 'up' trials and {n_down_incorrect} ({100.0 * r_down_incorrect:4.1f}%) incorrect 'down' trials.\n
In the last feedback sessions before speller sessions, the median accuracy was {100.0 * d.loc[fbix, 'acc'].median():4.1f}%,
the minimum was {acc_min:4.1f}%. In {100.0 * fraction_less_than_80:4.1f}% of the sessions, accuracy was below {100.0 * acc_thr:4.1f}%.
"""
    logger.info(res_str)
    # Export as CSV
    d_filt = d.loc[fbix].copy()
    for i, row in d_filt.iterrows():
        d_filt.at[i, 'up_up'] = row.ct.loc['up', 'up']
        d_filt.at[i, 'up_down'] = row.ct.loc['down', 'up']
        d_filt.at[i, 'up_unclassified'] = row.ct.loc['unclassified', 'up']
        d_filt.at[i, 'down_up'] = row.ct.loc['up', 'down']
        d_filt.at[i, 'down_down'] = row.ct.loc['down', 'down']
        d_filt.at[i, 'down_unclassified'] = row.ct.loc['unclassified', 'down']
    d_filt[
        ['d_since_impl', 'start_dt', 'duration_s', 'channels', 'mode', 'data', 'tpr', 'fpr', 'acc', 'up_up', 'up_down',
         'up_unclassified', 'down_up', 'down_down', 'down_unclassified']].to_csv(save_name.with_suffix(".csv"))

def plot_fb_accuracy(d):
    """
    Plots all feedback sessions' accuracy as a function of day.
    This generates Supplementary Figure 2.
    """
    from matplotlib.gridspec import GridSpec
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    save_name = BASE_PATH_OUT / "Figure_S2_fb_acc"
    d_filt = d[d['mode'] == 'feedback'].copy()
    d_filt['ddiff'] = (d_filt['start_dt'] - IMPLANT_DATE) / pd.to_timedelta(1, unit='D')
    n_fb_sessions = d_filt.groupby('d_since_impl').count()['acc']
    n_fb_sessions.name = 'n_fb'
    # mark feedback sessions directly preceding speller sessions
    (_, fbix, _) = get_fb_sessions_before_speller(d, ignore_sessions_before_fb_change=False)
    d_filt['b4sp'] = 0
    d_filt.loc[fbix, 'b4sp'] = 1
    fig = plt.figure(22, figsize=(15, 8), constrained_layout=False)
    fig.clf()
    gs = GridSpec(ncols=2, nrows=2, figure=fig, hspace=.25, top=.9, bottom=.1)
    bax = fig.add_subplot(gs[0, :])
    bax.plot(d_filt['ddiff'] - .5, d_filt['acc'], 'o', color=(.5, .5, .5), ms=2, label="NF Trial sets")
    accvmax = d_filt.groupby('d_since_impl')[['acc']].max()
    d_b4sp = d_filt[d_filt['b4sp'] > 0]
    bax.plot(d_b4sp['ddiff'] - .5, d_b4sp['acc'], 'o', color=(1, 0, 0), ms=3, lw=2, label="NF before speller")
    bax.set_ylim(-0.02, 1.02)
    bax.set_xlabel("Days after implantation")
    bax.set_ylabel("Accuracy")
    bax.set_title("Accuracy of Feedback trials")
    bax.legend(loc="lower right")
    dsi = 113
    d_sub = d[(d['d_since_impl'] == dsi) & (d['mode'] == 'feedback')]
    d_sub.reset_index(inplace=True)
    d_sub.index = d_sub.index + 1
    f_ax1 = fig.add_subplot(gs[1, 0])
    f_ax1.plot((d_sub['start_dt'] - d_sub.loc[1, 'start_dt']).dt.total_seconds() / 60, d_sub['acc'], 'o-k')
    f_ax1.set_ylim(-0.05, 1.05)
    f_ax1.set_xlabel("Time since start [min]")
    f_ax1.set_ylabel("Accuracy")
    f_ax1.set_title(f"Feedback trial sets on day {dsi} post-implantation")
    dsi = 197
    d_sub = d[(d['d_since_impl'] == dsi) & (d['mode'] == 'feedback')]
    d_sub.reset_index(inplace=True)
    d_sub.index = d_sub.index + 1
    f_ax2 = fig.add_subplot(gs[1, 1])
    f_ax2.plot((d_sub['start_dt'] - d_sub.loc[1, 'start_dt']).dt.total_seconds() / 60, d_sub['acc'], 'o-k')
    f_ax2.set_ylim(-0.05, 1.05)
    f_ax2.set_xlabel("Time since start [min]")
    f_ax2.set_ylabel("Accuracy")
    f_ax2.set_title(f"Feedback trial sets on day {dsi} post-implantation")
    for suffix in ('.pdf', '.eps', '.svg'):
        fig.savefig(save_name.with_suffix(suffix))
        logger.info(f"Plot saved at <{save_name.with_suffix(suffix)}>")
    thr = .9
    descstr = f"""

The best run per day had an accuracy above {thr:.1f} on {(accvmax['acc'] > thr).sum() / len(accvmax) * 100.0:.1f}% of days.

Over the reported period, there were {n_fb_sessions.sum()} feedback sessions, min/median/maximum per day:
{n_fb_sessions.min()} / {n_fb_sessions.median()} / {n_fb_sessions.max()}.

Over all {d.loc[d['mode'] == 'feedback', 'acc'].count()} feedback sessions, the mean accuracy was {100 * d.loc[d['mode'] == 'feedback', 'acc'].mean():.1f}%.
"""
    logger.info(descstr)
    # Export as CSV
    d_filt.loc[:, 'channels1'] = None
    for i, row in d_filt.iterrows():
        d_filt.at[i, 'up_up'] = row.ct.loc['up', 'up']
        d_filt.at[i, 'up_down'] = row.ct.loc['down', 'up']
        d_filt.at[i, 'up_unclassified'] = row.ct.loc['unclassified', 'up']
        d_filt.at[i, 'down_up'] = row.ct.loc['up', 'down']
        d_filt.at[i, 'down_down'] = row.ct.loc['down', 'down']
        d_filt.at[i, 'down_unclassified'] = row.ct.loc['unclassified', 'down']
        d_filt.at[i, 'channels1'] = ", ".join([str(x) for x in row.channels])
    d_filt[
        ['d_since_impl', 'start_dt', 'channels1', 'data', 'acc', 'b4sp',
         'up_up', 'up_down', 'up_unclassified', 'down_up', 'down_down',
         'down_unclassified']].to_csv(save_name.with_suffix(".csv"))
    return fig, bax, f_ax1, f_ax2, d_filt

def plot_channels_used(d):
    """Plots, for each speller block, the channels that were used."""
    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
    save_name = BASE_PATH_OUT / "Figure_S_channel_use"
    color_list = [(0.8, .8, .8), (.55, .55, .55), (.3, .3, .3), (0, 0, 0)]
    cmap = ListedColormap(color_list, N=len(color_list))
    df_blocks = d[d.channels.notna() & (d['mode'] == 'color')]
    ch_list = list(df_blocks.channels)
    ch_list_merged = list(itertools.chain(*ch_list))
    unique_channels = sorted(set(ch_list_merged))
    ch_use_matrix = np.empty((len(unique_channels), len(df_blocks)))
    ch_use_matrix[:] = np.nan
    ch_count = np.zeros((len(unique_channels), len(df_blocks)))
    c_idx = {r: i for i, r in enumerate(unique_channels)}
    day_blocks = pd.DataFrame(index=df_blocks.d_since_impl.unique(), columns=unique_channels)
    day_blocks.fillna(0, inplace=True)
    # x tick locations: the first block of each day
    b_d_s_i = df_blocks.d_since_impl
    b_d_s_i_change = (b_d_s_i.diff()).isna() | (b_d_s_i.diff() > 0)
    b_d_s_i_change = b_d_s_i_change.reset_index()
    tick_loc = b_d_s_i_change.index[b_d_s_i_change.d_since_impl].to_list()
    tick_lab = list(b_d_s_i[list(b_d_s_i_change.d_since_impl)])
    j = 0
    for i, r in df_blocks.iterrows():
        col_val = len(r['channels'])
        for ch in r['channels']:
            ch_use_matrix[c_idx[ch], j] = col_val
            day_blocks.loc[r['d_since_impl'], ch] = 1
            ch_count[c_idx[ch], j] = 1
        j += 1
    # This would sort by number of blocks; for consistency, number of days is better (below):
    # sort_idx = np.argsort(ch_count.sum(1))[::-1]
    channel_days = day_blocks.sum()
    sort_idx = list(channel_days.argsort()[::-1])
    sorted_channels = np.asarray(unique_channels)[sort_idx]
    channel_days.sort_values(ascending=False, inplace=True)
    # create array map with channel use counts
    array_map = ARRAY_MAPS['K01']
    map_sma = array_map[0]
    amesh_sma = np.empty((np.max(map_sma['x']) + 1, np.max(map_sma['y']) + 1))
    amesh_sma[:] = np.nan
    for index, el_ix in enumerate(map_sma['ix']):
        if el_ix + 1 in channel_days.index:
            amesh_sma[7 - map_sma['y'][index], map_sma['x'][index]] = channel_days[el_ix + 1]
    norm = mpl.colors.BoundaryNorm(np.arange(-.5 + 1, cmap.N + 1), cmap.N)
    my_cmap = plt.get_cmap("plasma")

    def rescale(y):
        return (y - np.min(y)) / (np.max(y) - np.min(y))

    # plot number of channels used per block, over days
    fig = plt.figure(constrained_layout=True)
    gs = fig.add_gridspec(2, 3)
    ax = fig.add_subplot(gs[0, :-1])
    imgplt = ax.pcolormesh(ch_use_matrix[sort_idx, :], cmap=cmap, norm=norm)
    ax.set_yticks(np.arange(0, len(unique_channels)) + .5)
    ax.set_yticklabels(sorted_channels)
    ax.set_ylabel('Channel ID')
    ax.set_xlabel('Days after implantation')
    ax.yaxis.set_inverted(True)
    ax.xaxis.set_ticks(tick_loc, minor=True)
    ax.set_xticks(tick_loc[::20])
    ax.set_xticklabels([f"{x}" for x in tick_lab[::20]])
    cbar = fig.colorbar(imgplt, ax=ax, location='right', ticks=range(1, cmap.N + 1))
    # histogram of the number of days each channel was used
    ax = fig.add_subplot(gs[1, :-1])
    ax.bar(range(len(channel_days)), channel_days, color=my_cmap(rescale(channel_days)), linewidth=.3, edgecolor='k')
    ax.set_xticks(range(len(channel_days)))
    ax.set_xticklabels(channel_days.index)
    ax.set_ylabel('used on number of days')
    ax.set_xlabel('Channel ID')
    # array map with per-electrode day counts
    ax = fig.add_subplot(gs[-1, -1])
    rect = plt.Rectangle((0, 0), 8, 8, fill=True, facecolor=(.3, .3, .3), edgecolor=(0, 0, 0), alpha=0.2, zorder=-1)
    ax.add_patch(rect)
    ax.pcolormesh(amesh_sma, cmap=my_cmap)
    ax.set_xlim(-1, 9)
    ax.set_ylim(-1, 9)
    ax.set_aspect(1)
    ax.axis('off')
    for suffix in ('.pdf', '.eps', '.svg'):
        fig.savefig(save_name.with_suffix(suffix))
        logger.info(f"Plot saved at <{save_name.with_suffix(suffix)}>")
    # save as csv
    ch_exp = ch_use_matrix[sort_idx, :]
    ch_exp[ch_exp > 0] = 1
    df_ch_exp = pd.DataFrame(data=ch_exp.T, columns=[f"ch_{c}" for c in sorted_channels], index=df_blocks.index)
    df_ch_exp['n_ch_used'] = df_ch_exp.sum(axis=1)
    df_ch_exp['d_since_impl'] = df_blocks.d_since_impl
    df_ch_exp.reset_index(drop=True, inplace=True)
    df_ch_exp[df_ch_exp.isna()] = 0
    df_ch_exp = df_ch_exp.astype('int')
    print(f"Number of channels used in speller blocks:\n"
          f"{df_ch_exp[['n_ch_used']].groupby(['n_ch_used']).apply(lambda x: len(x))}")
    df_ch_exp.to_csv(save_name.with_suffix('._panel_a.csv'))
    data_panel_b = pd.DataFrame.from_dict({'days_used': channel_days})
    data_panel_b.to_csv(save_name.with_suffix('._panel_b.csv'))

def calculate_ITR(d):
    """Calculate the information transfer rate (ITR) for all rated speller sessions."""
    def itr_helper(r):
        # number of selectable characters (26 letters, space, delete, question mark, end program)
        N = 30
        # remove the substring of the rating corresponding to phrase_start
        rate_str = r.rating[len(r.phrase_start):]
        # rate of correctly selected characters
        P = rate_str.count('T') / (rate_str.count('T') + rate_str.count('F'))
        # ITR in bits / min; the 1e-6 guards the logarithm when P == 1
        return (np.log2(N) + P * np.log2(P) + (1 - P) * np.log2((1 - P + 1e-6) / (N - 1))) * r.ch_per_min
    # select all rows with non-empty ratings
    dfwr = d.loc[~d.rating.isna()]
    itr = dfwr.apply(itr_helper, axis=1)
    res = f"""There were {len(itr)} rated speller sessions. The min/mean/median/max ITR was {itr.min():.3f} / {itr.mean():.3f} / {itr.median():.3f} / {itr.max():.3f} bit per minute.\n"""
    print(res)

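# The expression in itr_helper is the standard Wolpaw ITR: with N selectable
# items and per-character accuracy P (here estimated from the 'T'/'F' marks in
# the rating string), the information per selection is
#   B = log2(N) + P * log2(P) + (1 - P) * log2((1 - P) / (N - 1))
# bits, and multiplying by the selection rate (here ch_per_min) gives bits per
# minute. As a made-up example, N = 30 and P = 0.9 give B ≈ 3.95 bits per
# selection, i.e. ≈ 3.95 bits/min at one character per minute.
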
if __name__ == '__main__':
    pth = BASE_PATH / 'KIAP_BCI_neurofeedback'
    df = prepare_sessions(pth=pth)
    df = add_feedback_info(df, pth=pth)
    df = add_channel_info(df, pth=pth)
    pths = BASE_PATH / 'KIAP_BCI_speller'
    ds = prepare_sessions(pth=pths)
    ds = add_feedback_info(ds, pth=pths)
    ds = add_channel_info(ds, pth=pths)
    d = pd.concat((df, ds), ignore_index=True)
    d.sort_values('start_dt', inplace=True)
    d.reset_index(drop=True, inplace=True)
    d = add_annotation(d)
    plot_sessions(d)
    plot_audio_feedback_tpfp(d)
    plot_fb_accuracy(d)
    plot_channels_used(d)
    speller_performance_by_nf(d)
    calculate_ITR(d)