Commit 27f81fbfcb: gin commit from L-1010036236
Author: Jonas Zimmermann, 2 years ago
New files: 17

+ 33 - 0
scripts_for_figures/README.md

@@ -0,0 +1,33 @@
+---
+title: "KIAP Paper figure scripts"
+author:
+    - Jonas B. Zimmermann
+date: 18 April 2020
+---
+
+# KIAP Paper figure scripts #
+
+## Purpose ##
+
+This repository contains Python scripts to reproduce the figures published in the KIAP paper.
+
+## Prerequisites ##
+
+1. A copy of the KIAP BCI dataset. By default, these data should be in a folder `../data` relative to this file. The location can be changed in `basics.py` (see the example at the end of this README).
+2. The output folder (by default `../out`) has to exist and be writable.
+3. Install conda (e.g. https://docs.conda.io/en/latest/miniconda.html)
+4. Open a terminal and navigate to the folder containing this file.
+5. Create a conda environment: `conda env create -f environment.yml`
+6. Activate the new environment: `conda activate kiap_paper`
+7. Run `ipython`
+8. From within IPython, run the figure producing scripts:
+
+    run plot_figures_part_A.py
+    run plot_figure_3A.py
+    run plot_figure_4.py
+
+  * `plot_figures_part_A.py` contains functions to produce figures 2, 3B, and S2.
+  * `plot_figures_part_B.py` contains functions to produce figure 3A.
+  * `plot_figures_part_C.py` contains functions to produce figure 4.
+
+The figures will be saved to the output directory, as eps, pdf, and svg files.
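
The data and output locations from steps 1 and 2 are two `Path` constants at the top of `basics.py` (included later in this commit). A minimal sketch of pointing them at a custom setup; the paths below are placeholders:

    # basics.py: adjust these two constants before running the scripts
    from pathlib import Path

    BASE_PATH = Path('/path/to/kiap_data')          # folder containing the KIAP BCI dataset
    BASE_PATH_OUT = Path('/path/to/figures_out')    # output folder; must exist and be writable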

+ 2527 - 0
scripts_for_figures/annotations/speller/records_for_annotation_consolidated.yml
File diff suppressed because it is too large


+ 28 - 0
scripts_for_figures/basics.py

@@ -0,0 +1,28 @@
+from datetime import datetime as dt
+from pathlib import Path
+
+BASE_PATH = Path('..')
+BASE_PATH_OUT = Path('..', 'figures_out')
+
+# DO NOT CHANGE BELOW THIS LINE #
+IMPLANT_DATE = dt.strptime('2019-03-19', '%Y-%m-%d')
+FEEDBACK_CHANGE_DATE = dt.strptime('2019-07-20', '%Y-%m-%d')
+
+ARRAY_MAPS = dict(K01=[
+    dict(name='SMA', y=[7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0,
+                        7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0],
+         x=[7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
+            3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+         ix=[1, 0, 2, 3, 6, 4, 5, 8, 16, 14, 7, 10, 12, 18, 9, 11, 22, 24, 20, 13, 15, 19, 26, 28, 30, 17, 21, 23, 25,
+             27, 29, 31, 96, 98, 97, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 111, 114, 116, 118,
+             120, 122, 124, 126, 127, 113, 115, 117, 119, 121, 123, 125]),
+    dict(name='M1', x=[7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4,
+                       4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
+                       0, 0, 0, 0],
+         y=[7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4,
+            3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0],
+         ix=[64, 66, 68, 70, 72, 74, 76, 78, 65, 67, 69, 71, 73, 75, 77, 79, 80, 81, 83, 82, 85, 84, 86, 87, 88, 89, 90,
+             91, 92, 93, 95, 94, 32, 34, 36, 38, 40, 44, 33, 35, 37, 39, 42, 46, 48, 47, 41, 43, 50, 52, 49, 45, 55, 54,
+             53, 51, 57, 56, 61, 59, 58, 63, 60, 62])
+]
+)
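
`ARRAY_MAPS` describes the two electrode arrays ('SMA' and 'M1'): for each of the 64 electrodes it lists a grid position (`x`, `y`) and a channel index (`ix`). An illustration of how such a map can be used to arrange per-channel values on the 8x8 grid; the firing-rate vector here is made up:

    import numpy as np
    from basics import ARRAY_MAPS

    rates = np.random.rand(128)                     # placeholder: one value per recorded channel
    for arr in ARRAY_MAPS['K01']:
        grid = np.full((8, 8), np.nan)
        grid[np.asarray(arr['y']), np.asarray(arr['x'])] = rates[np.asarray(arr['ix'])]
        print(arr['name'], np.nanmean(grid))        # grid can now be passed to e.g. plt.imshow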

+ 20 - 0
scripts_for_figures/environment.yml

@@ -0,0 +1,20 @@
+name: kiap_paper
+dependencies:
+  - python>=3.8.1
+  - numpy
+  - scipy
+  - matplotlib==3.3.2
+  - pandas==1.0.5
+  - munch
+  - tabulate
+  - pip
+  - tk
+  - ipython
+  - pip:
+    - PyYAML
+    - pyaml
+    - colorlog
+    - cerberus
+    - pyfiglet
+    - brokenaxes
+

+ 0 - 0
scripts_for_figures/helpers/__init__.py


+ 113 - 0
scripts_for_figures/helpers/data.py

@@ -0,0 +1,113 @@
+import datetime
+import glob
+import logging
+import os
+import pathlib
+import pickle
+import random
+import time
+from datetime import datetime as dt
+
+import matplotlib.pyplot as plt
+import munch
+import numpy as np
+from scipy import io
+
+import yaml
+
+from .kaux import log
+
+from .ringbuffer import RingBuffer
+
+class DataNormalizer:
+    def __init__(self, params, initial_data=None):
+        self.params = params
+        self.norm_rate = {}
+        self.norm_rate['ch_ids'] = [ch.id for ch in self.params.daq.normalization.channels]
+        self.norm_rate['bottoms'] = np.asarray([ch.bottom for ch in self.params.daq.normalization.channels])
+        self.norm_rate['tops'] = np.asarray([ch.top for ch in self.params.daq.normalization.channels])
+        self.norm_rate['invs'] = [ch.invert for ch in self.params.daq.normalization.channels]
+        
+        n_norm_buffer = int(self.params.daq.normalization.len * (1000.0 / self.params.daq.spike_rates.loop_interval))
+        self.norm_buffer = RingBuffer(n_norm_buffer, dtype=(float, self.params.daq.n_channels), allow_overwrite=True)
+        self.last_update = time.time()
+        if initial_data is not None:
+            self.norm_buffer.extend(initial_data)
+    
+    def _update_norm_range(self):
+        buf_vals = self.norm_buffer[:, self.norm_rate['ch_ids']]
+        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
+        self.norm_rate['bottoms'] = centiles[0, :]
+        self.norm_rate['tops'] = centiles[1, :]
+        log.info(f"Updated normalization ranges for channels {self.norm_rate['ch_ids']} to bottoms: {self.norm_rate['bottoms']}, tops: {self.norm_rate['tops']}")
+
+    def _update_norm_range_all(self):
+        buf_vals = np.mean(self.norm_buffer, axis=1)
+        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
+        # log.info(f"Centiles: {centiles}")
+        
+        self.params.daq.normalization.all_channels.bottom = centiles[0]
+        self.params.daq.normalization.all_channels.top = centiles[1]
+        log.info(f"Updated normalization range for all channels to [{self.params.daq.normalization.all_channels.bottom}, {self.params.daq.normalization.all_channels.top}]")
+    
+    def update_norm_range(self, data=None, force=False):
+        if data is not None and data.size > 0:
+            self.norm_buffer.extend(data)
+        if (self.params.daq.normalization.do_update and (time.time() - self.last_update >= self.params.daq.normalization.update_interval)) or force:
+            if self.params.daq.normalization.use_all_channels:
+                self._update_norm_range_all()
+            else:
+                self._update_norm_range()
+            self.last_update = time.time()
+            log.info(f"New channel normalization setting: {yaml.dump(self._format_current_config(), sort_keys=False, default_flow_style=None)}")
+
+    def _format_current_config(self):
+        if self.params.daq.normalization.use_all_channels:
+            out_dict = {'all_channels': {'bottom': float(self.params.daq.normalization.all_channels.bottom), 'top': float(self.params.daq.normalization.all_channels.top),
+                'invert': bool(self.params.daq.normalization.all_channels.invert)}}
+        else:
+            out_dict = {'channels': []}
+            for ii in range(len(self.norm_rate['ch_ids'])):
+                out_dict['channels'].append({'id': int(self.norm_rate['ch_ids'][ii]), 
+                     'bottom': float(self.norm_rate['bottoms'][ii]),
+                     'top': float(self.norm_rate['tops'][ii]),
+                     'invert': self.norm_rate['invs'][ii]}
+                     )
+        return out_dict
+            
+    
+    def _calculate_all_norm_rate(self, buf_item):
+        avg_r = np.mean(buf_item, axis=1)
+        if self.params.daq.normalization.clamp_firing_rates:
+             avg_r = np.maximum(np.minimum(avg_r, self.params.daq.normalization.all_channels.top), self.params.daq.normalization.all_channels.bottom)
+        norm_rate = (avg_r - self.params.daq.normalization.all_channels.bottom) / (self.params.daq.normalization.all_channels.top - self.params.daq.normalization.all_channels.bottom)
+        if self.params.daq.normalization.all_channels.invert:
+            norm_rate = 1 - norm_rate
+        return norm_rate
+        
+    def _calculate_individual_norm_rate(self, buf_items):
+        """Calculate normalized firing rate, determined by feedback settings"""
+        if self.params.daq.normalization.clamp_firing_rates:
+            clamped_rates = np.maximum(np.minimum(buf_items[:, self.norm_rate['ch_ids']], self.norm_rate['tops']), self.norm_rate['bottoms'])
+        else:
+            clamped_rates = buf_items[:, self.norm_rate['ch_ids']]
+        denom = self.norm_rate['tops'] - self.norm_rate['bottoms']
+        if np.all(denom==0):
+            denom[:] = 1
+        norm_rates = (clamped_rates - self.norm_rate['bottoms']) / denom
+        norm_rates[:, self.norm_rate['invs']] = 1 - norm_rates[:, self.norm_rate['invs']]
+        norm_rate = np.nanmean(norm_rates, axis=1)
+        if not self.params.daq.normalization.clamp_firing_rates:
+            norm_rate = np.maximum(norm_rate, 0.0)
+        return norm_rate
+                   
+        
+    def calculate_norm_rate(self, buf_item):
+        if buf_item.ndim == 1:
+            buf_item.shape = (1, buf_item.shape[0])
+        if self.params.daq.normalization.use_all_channels:
+            return self._calculate_all_norm_rate(buf_item)
+        else:
+            return self._calculate_individual_norm_rate(buf_item)
+        
+
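
`DataNormalizer` maps per-channel firing rates into [0, 1]: rates are clamped to each channel's `bottom`/`top` percentile range, rescaled, optionally inverted, and averaged over the selected channels. A self-contained sketch of that core computation with made-up numbers (the class itself takes these values from `params.daq.normalization`):

    import numpy as np

    rates   = np.array([[12.0, 40.0, 3.0]])    # one time step, three channels (made-up values)
    bottoms = np.array([5.0, 10.0, 0.0])       # per-channel lower percentile
    tops    = np.array([20.0, 50.0, 10.0])     # per-channel upper percentile
    invert  = np.array([False, True, False])

    clamped = np.clip(rates, bottoms, tops)
    norm = (clamped - bottoms) / (tops - bottoms)
    norm[:, invert] = 1 - norm[:, invert]
    print(np.nanmean(norm, axis=1))            # single control value in [0, 1]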

+ 521 - 0
scripts_for_figures/helpers/data_management.py

@@ -0,0 +1,521 @@
+'''
+description: script to read header and data from  data.bin
+author: Ioannis Vlachos
+date:   02.11.18
+
+Copyright (c) 2018 Ioannis Vlachos.
+All rights reserved.
+
+HEADER OF BINARY FILE
+---------------------
+laptop timestamp        np.datetime64       bytes: 8
+NSP timestamp           np.int64            bytes: 8
+number of bytes         np.int64            bytes: 8
+number of samples       np.int64            bytes: 8
+number of channels      np.int64            bytes: 8
+
+'''
+
+import csv
+import datetime as dt
+import glob
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+from numpy import datetime64 as dt64
+from tabulate import tabulate
+import munch                   # needed by read_config()
+
+from . import kaux as aux
+from .kaux import log
+
+# params = aux.load_config()
+
+
+# def get_event_name(filename, events_file_names)
+
+
+def get_raw(verbose=0, n_triggers=2, exploration=False, feedback=False, triggers_all=False, fids=[], trigger_pos='start'):
+
+    '''read raw data from one or more binary files
+    trigger_pos: ['start, 'stop']'''
+    params = aux.load_config()
+    n_channels = params.daq.n_channels
+
+
+    file_names = []
+    # events_file_names = []
+    print('Available binary data files:\n')
+    for ii, file_name in enumerate(sorted(glob.iglob(params.file_handling.data_path + '**/*.bin', recursive=True))): 
+        print(f'{ii}, {file_name}, {os.path.getsize(file_name)//1000}K')
+        file_names.append(file_name)
+
+    # for ii, file_name in enumerate(sorted(glob.iglob(params.file_handling.data_path + '**/*.txt', recursive=True))): 
+        # events_file_names.append(file_name)
+    
+    if fids ==[]:
+        fids = [int(x) for x in input('\nSelect file ids (separated by space): ').split()]        # select file ids
+        if fids == []:
+            fids = [len(file_names)-1]
+            print(f'Selected: {fids}')
+
+    data_tot = np.empty((len(fids), 1), dtype=object)
+    time_stamps_tot = np.empty((len(fids), 1), dtype=object)
+    if triggers_all:
+        triggers_tot = np.empty((len(fids), 6), dtype=object)
+    elif exploration:
+        triggers_tot = np.empty((len(fids), 1), dtype=object)
+    elif feedback:
+        triggers_tot = np.empty((len(fids), 2), dtype=object)       # up and down
+    else:
+        triggers_tot = np.empty((len(fids), n_triggers), dtype=object)
+
+    ch_rec_list_tot = np.empty((len(fids), 1), dtype=object)
+
+    for ii, fid in enumerate(fids):     # go through all sessions
+        data, time_stamps, ch_rec_list = get_session(file_names[fid], verbose=verbose)
+        data = np.delete(data, params.classifier.exclude_data_channels, axis=1)
+        # triggers = get_triggers(os.path.dirname(file_names[fid])+'/events.txt', time_stamps, n_triggers)
+        event_file_name = file_names[fid].replace('data_','events_').replace('.bin','.txt')
+        info_file_name = file_names[fid].replace('data_','info_').replace('.bin','.log')
+        
+        if triggers_all:
+            triggers = get_triggers_all(event_file_name, time_stamps, trigger_pos)
+        elif exploration:
+            triggers = get_triggers_exploration(event_file_name, time_stamps)
+        elif feedback:
+            triggers = get_triggers_feedback(event_file_name, time_stamps, trigger_pos, n_triggers)
+        else:
+            triggers = get_triggers(event_file_name, time_stamps, trigger_pos, n_triggers)
+
+            # yes_mask, no_mask = get_accuracy(info_file_name,n_triggers)
+            # triggers[0] = np.where(yes_mask,triggers[0],'')
+            # triggers[1] = np.where(no_mask,triggers[1],'')
+
+        data_tot[ii, 0] = data
+        time_stamps_tot[ii, 0] = time_stamps
+        if exploration:
+            triggers_tot[ii,0] = triggers
+        else:
+            triggers_tot[ii, :] = triggers
+        ch_rec_list_tot[ii, 0] = ch_rec_list
+        # triggers_tot[ii,1] = triggers[1]
+        print(f'\nRead binary  neural data from file {file_names[fid]}')
+        print(f'Read trigger info events from file {event_file_name}')
+        print(f'Session {ii}, data shape: {data.shape}')
+
+
+    config = read_config('paradigm.yaml')
+    states = config.exploration.states
+
+    for jj, _ in enumerate(fids):
+        print()
+        if triggers_all:
+            for ii in range(6):
+                print(f'Session {fids[jj]}, Class {ii}, trigger #: {triggers_tot[jj,ii].shape}')
+        elif exploration:
+            for ii in range(len(triggers_tot[jj,0][0])):
+                print(f'Session {fids[jj]}, State {states[ii]}, trigger #: {triggers_tot[jj,0][0][ii].shape}')
+        else:
+            for ii in range(n_triggers):
+                print(f'Session {fids[jj]}, Class {ii}, trigger #: {triggers_tot[jj,ii].shape}')
+    
+    file_names = [file_names[ii] for ii in fids]
+    return  data_tot, time_stamps_tot, triggers_tot, ch_rec_list_tot, file_names
+
+
+def get_session(file_name, verbose=0, t_lim_start=None, t_lim_end=None, params=None):
+
+    if params is None:          # get_raw() calls this without params, so fall back to the config file
+        params = aux.load_config()
+
+    ii = 0
+    # data = np.empty((0, params.daq.n_channels-len(params.daq.exclude_channels)))
+    data = np.empty((0, params.daq.n_channels_max))
+
+    log.info(f'Data shape: {data.shape}')
+
+    time_stamps = []
+    date_times = []
+    time_stamps_rcv = []
+
+    with open(file_name, 'rb') as fh:
+        log.info(f'\nReading binary file {file_name}...\n')
+        while True:        
+            tmp = np.frombuffer(fh.read(8), dtype='datetime64[us]')          # laptop timestamp
+
+            if tmp.size == 0:
+                break
+                # return data, np.array(time_stamps, dtype=np.uint), ch_rec_list
+            else:
+                t_now1 = tmp
+
+            if (t_lim_end is not None) and (t_lim_end < t_now1[0]):
+                break
+
+            t_now2 = np.frombuffer(fh.read(8), dtype=np.int64)[0]               # NSP timestamp
+
+            n_bytes = int.from_bytes(fh.read(8), byteorder='little')            # number of bytes
+            n_samples = int.from_bytes(fh.read(8), byteorder='little')          # number of samples
+            n_ch = int.from_bytes(fh.read(8), byteorder='little')               # number of channels
+            ch_rec_list_len = int.from_bytes(fh.read(2), byteorder='little')
+            ch_rec_list = np.frombuffer(fh.read(ch_rec_list_len), dtype=np.uint16) # detailed channel list
+            # log.info(f'recorded channels: {ch_rec_list}')
+
+            d = fh.read(n_bytes)
+            d2 = np.frombuffer(d, dtype=np.float32)
+            d3 = d2.reshape(d2.size // n_ch, n_ch)                              # data, shape : (n_samples, n_ch)
+            log.info(f'data shape: {d3.shape}')
+            
+            if data.size == 0 and data.shape[1] != d3.shape[1]:
+                log.warning(f'Shape mismatch. {d3.shape} vs {data.shape[1]}. Using data shape from file: {d3.shape}')
+                data = np.empty((0, d3.shape[1]))
+
+            # fct = params.daq.spike_rates.bin_width * 30000                        # factor to get correct starting time in ticks
+            # time_stamps.extend(np.arange(t_now2-d3.shape[0]*fct + 1, t_now2+1))   # check if +1 index is correct
+            # time_stamps.extend(np.arange(t_now2-d3.shape[0]*fct + 1, t_now2+1, 3000))   # check if +1 index is correct
+            ts = np.frombuffer(fh.read(8 * d3.shape[0]), dtype=np.uint64)
+
+            if (t_lim_start is None) or (t_lim_start <= t_now1[0] -  np.timedelta64(int((t_now2 - ts[0]) / 3e4), 's')):
+                data = np.concatenate((data, d3))
+                time_stamps.extend(ts)
+                date_times.append(t_now1[0])
+                time_stamps_rcv.append(t_now2)
+            else:
+                log.info(f"Skipped set {t_now1[0]} | {ts[0]} n_samples: {n_samples}")
+
+            # if ts.size[0] > 0:
+            log.info(f'ts size: {ts.size}')
+            # log.info(time_stamps)
+
+            # if verbose:
+            # print(ii, t_now1[0], t_now2, n_bytes, n_samples, n_ch, d3[10:20, 0], np.any(d3))
+            ii += 1
+
+    return data, np.array(time_stamps, dtype=np.uint), ch_rec_list
+    # return data, np.array(time_stamps, dtype=np.uint), ch_rec_list, np.array(date_times, dtype='datetime64[us]'), \
+    #        np.array(time_stamps_rcv, dtype=np.uint64)
+
+def get_accuracy_question(fname, n_triggers=2):
+    with open(fname, 'r') as fh:
+        events = fh.read().splitlines()
+
+        cl1 = []
+        cl2 = []
+
+        for ev in events:
+            if 'Yes Question' in ev:
+                if 'Decoder decision: yes' in ev:
+                    cl1.append(True)
+                else:
+                    cl1.append(False)
+            elif 'No Question' in ev:
+                if 'Decoder decision: no' in ev:
+                    cl2.append(True)
+                else:
+                    cl2.append(False)
+
+        return [cl1, cl2]               # per-trial decoder correctness for yes- and no-questions
+
+def get_triggers(fname, time_stamps, trigger_pos, n_triggers=2):
+    with open(fname, 'r') as fh:
+        # events = fh.readlines()
+        events = fh.read().splitlines()
+
+        cl1 = []
+        cl2 = []
+        cl3 = []
+        tt1 = []
+        tt2 = []
+        tt3 = []
+
+        for ev in events:
+            if 'response' in ev and 'yes' in ev and trigger_pos in ev:
+                cl1.append(int(ev.split(',')[0]))
+            elif 'response' in ev and 'no' in ev and trigger_pos in ev:
+                cl2.append(int(ev.split(',')[0]))
+            elif 'baseline' in ev and 'start'in ev:                
+                cl3.append(int(ev.split(',')[0]))
+
+
+    if n_triggers == 2:
+        # cl2.extend(cl3)         # add baseline to class 2
+        cl3 = []
+
+    for ev in events:
+        if 'response' in ev and trigger_pos in ev:
+            print(f'\033[91m{ev}\033[0m')
+        else:
+            print(ev)
+        
+
+    for  ii in cl1:
+        tt1.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+
+    for  ii in cl2:
+        tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+    
+    for  ii in cl3:
+        tt3.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+
+    ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]  
+    ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
+    ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
+
+    print()
+    print(cl1, tt1, ind1)
+    print(cl2, tt2, ind2)
+    print(cl3, tt3, ind3)
+
+    if n_triggers == 1:
+        res = [ind1]
+    elif n_triggers == 2:
+        res = [ind1, ind2]
+        # res = [ind1, np.hstack((ind2, ind3))]       # put no and baseline together
+    elif n_triggers == 3:
+        res = [ind1, ind2, ind3]
+
+    return res
+
+def get_triggers_feedback(fname, time_stamps, trigger_pos, n_triggers=2):
+    with open(fname, 'r') as fh:
+        # events = fh.readlines()
+        events = fh.read().splitlines()
+
+        cl1 = []
+        cl2 = []
+        cl3 = []
+        tt1 = []
+        tt2 = []
+        tt3 = []
+
+        for ev in events:
+            if 'response' in ev and 'down' in ev and trigger_pos in ev:
+                cl2.append(int(ev.split(',')[0]))
+            elif 'response' in ev and 'up' in ev and trigger_pos in ev:
+                cl1.append(int(ev.split(',')[0]))
+            elif 'baseline' in ev and 'start'in ev:                
+                cl3.append(int(ev.split(',')[0]))
+    
+    if n_triggers == 2:
+        # cl2.extend(cl3)         # add baseline to class 2
+        cl3 = []
+
+    for ev in events:
+        if 'response' in ev and trigger_pos in ev:
+            print(f'\033[91m{ev}\033[0m')
+        else:
+            print(ev)
+        
+
+    for  ii in cl1:
+        tt1.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+
+    for  ii in cl2:
+        tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+    
+    for  ii in cl3:
+        tt3.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+
+    ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]  
+    ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
+    ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
+
+    print()
+    print(cl1, tt1, ind1)
+    print(cl2, tt2, ind2)
+    print(cl3, tt3, ind3)
+
+    if n_triggers == 1:
+        res = [ind1]
+    elif n_triggers == 2:
+        res = [ind1, ind2]
+        # res = [ind1, np.hstack((ind2, ind3))]       # put no and baseline together
+    elif n_triggers == 3:
+        res = [ind1, ind2, ind3]
+
+    return res
+
+def get_triggers_exploration(fname, time_stamps):
+    with open(fname, 'r') as fh:
+        # events = fh.readlines()
+        events = fh.read().splitlines()
+
+        config = read_config('paradigm.yaml')
+        states = config.exploration.states
+
+        cl1 = [[] for x in range(len(states))]
+        cl2 = []
+        # cl3 = []
+        tt1 = [[] for x in range(len(states))]
+        tt2 = []
+        # tt3 = []
+        for ev in events:
+            for ii,state in enumerate(states):
+                if 'response' in ev and state in ev and 'start' in ev:
+                    cl1[ii].append(int(ev.split(',')[0]))
+            if 'baseline' in ev and 'start'in ev:                
+                cl2.append(int(ev.split(',')[0]))
+
+        
+    for ev in events:
+        if 'response' in ev and 'start' in ev:
+            print(f'\033[91m{ev}\033[0m')
+        else:
+            print(ev)
+        
+
+    for  ii in range(len(cl1)):
+        for jj in cl1[ii]:
+            tt1[ii].append(time_stamps.flat[np.abs(time_stamps - jj).argmin()])
+
+    for  ii in cl2:
+        tt2.append(time_stamps.flat[np.abs(time_stamps - ii).argmin()])
+    
+    ind1 = [[] for x in range(len(states))]
+    for  ii in range(len(tt1)):
+        ind1[ii] = np.where(np.in1d(time_stamps, tt1[ii]))[0][np.newaxis, :] 
+
+    ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
+    # ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
+
+    print()
+    print(cl1, tt1, ind1)
+    print(cl2, tt2, ind2)
+    # print(cl3, tt3, ind3)
+    # xx
+    res = [ind1,ind2]
+
+    return res
+
+
+def get_triggers_all(fname, time_stamps, trigger_pos, n_triggers=2):
+    with open(fname, 'r') as fh:
+        # events = fh.readlines()
+        events = fh.read().splitlines()
+
+        cl1 = []
+        cl2 = []
+        cl3 = []
+        cl4 = []
+        cl5 = []
+        cl6 = []
+        tt1 = []
+        tt2 = []
+        tt3 = []
+        tt4 = []
+        tt5 = []
+        tt6 = []
+
+        for ev in events:
+            if 'baseline' in ev and 'start'in ev:
+                cl1.append(int(ev.split(',')[0]))
+
+            elif 'baseline' in ev and 'stop'in ev:
+                cl2.append(int(ev.split(',')[0]))
+
+            elif 'stimulus' in ev and 'start'in ev:
+                cl3.append(int(ev.split(',')[0]))
+
+            elif 'stimulus' in ev and 'stop'in ev:
+                cl4.append(int(ev.split(',')[0]))
+
+            elif 'response' in ev and 'start' in ev:
+                cl5.append(int(ev.split(',')[0]))
+
+            elif 'response' in ev and 'stop' in ev:
+                cl6.append(int(ev.split(',')[0]))
+    
+
+    for ev in events:
+        if 'response' in ev and trigger_pos in ev:
+            print(f'\033[91m{ev}\033[0m')
+        else:
+            print(ev)
+
+    for  ii in cl1:
+        tt1.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
+
+    for  ii in cl2:
+        tt2.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
+    
+    for  ii in cl3:
+        tt3.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
+
+    for  ii in cl4:
+        tt4.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
+
+    for  ii in cl5:
+        tt5.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
+    
+    for  ii in cl6:
+        tt6.append(time_stamps.flat[np.abs(np.int64(time_stamps - ii)).argmin()])
+
+    ind1 = np.where(np.in1d(time_stamps, tt1))[0][np.newaxis, :]
+    ind2 = np.where(np.in1d(time_stamps, tt2))[0][np.newaxis, :]
+    ind3 = np.where(np.in1d(time_stamps, tt3))[0][np.newaxis, :]
+    ind4 = np.where(np.in1d(time_stamps, tt4))[0][np.newaxis, :]
+    ind5 = np.where(np.in1d(time_stamps, tt5))[0][np.newaxis, :]
+    ind6 = np.where(np.in1d(time_stamps, tt6))[0][np.newaxis, :]
+
+    print('\nTriggers and timestamps')
+    print(cl1, tt1, ind1)
+    print(cl2, tt2, ind2)
+    print(cl3, tt3, ind3)
+    print(cl4, tt4, ind4)
+    print(cl5, tt5, ind5)
+    print(cl6, tt6, ind6)
+
+    # if n_triggers == 1:
+    #     res = [ind1]
+    # elif n_triggers == 2:
+    #     res = [ind1, ind2]
+    #     # res = [ind1, np.hstack((ind2, ind3))]       # put no and baseline together
+    # elif n_triggers == 3:
+    #     res = [ind1, ind2, ind3]
+
+    res = [ind1, ind2, ind3, ind4, ind5, ind6]
+
+    return res
+
+def read_config(file_name):
+    try:
+        with open(file_name) as stream:
+            config = munch.fromYAML(stream)
+        return config
+    except Exception as e:
+        raise e
+
+
+if __name__ == '__main__':
+
+    print("\nto read binary data use: 'data = get_raw(verbose=1)'")
+    print("\nto read log file use: 'log = read_log(date)'")
+
+
+    if aux.args.speller == 'exploration':
+        exploration = True
+    else:
+        exploration = False
+    if aux.args.speller == 'feedback':
+        feedback = True
+    else:
+        feedback = False
+
+    col = ['b', 'r', 'g']
+    params = aux.load_config()
+    data_tot, tt, triggers_tot, ch_rec_list, file_names = get_raw(n_triggers=params.classifier.n_classes, exploration=exploration, feedback=feedback)
+
+    if not exploration:
+        plt.figure(1)
+        plt.clf()
+        xx = np.arange(data_tot[0, 0].shape[0]) * 0.05
+
+        for cl_id in range(triggers_tot.shape[1]):
+            markerline, stemlines, baseline = plt.stem(triggers_tot[0, cl_id][0]*0.05, triggers_tot[0, cl_id][0] * 0 + 50, '--', basefmt=' ')
+            plt.setp(stemlines, alpha=0.8, color=col[cl_id], lw=1)
+            plt.setp(markerline, marker=None)
+
+        plt.gca().set_prop_cycle(None)
+        plt.plot(xx, data_tot[0, 0][:, :2], alpha=0.5)
+        # plt.plot(cur_data[0, 0][:, :2])
+    
+        plt.show()
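
The record layout documented at the top of `data_management.py` (laptop timestamp, NSP timestamp, byte/sample/channel counts, followed by a channel list, the sample block, and per-sample timestamps) is parsed inline by `get_session`. A stripped-down sketch of reading just the fixed header fields of one record, assuming the same field order and little-endian counts as `get_session`:

    import numpy as np

    def read_record_header(fh):
        """Read the fixed-size header of one data.bin record; return None at end of file (sketch only)."""
        raw = fh.read(8)
        if not raw:
            return None
        laptop_ts = np.frombuffer(raw, dtype='datetime64[us]')[0]     # laptop timestamp
        nsp_ts    = np.frombuffer(fh.read(8), dtype=np.int64)[0]      # NSP timestamp
        n_bytes   = int.from_bytes(fh.read(8), byteorder='little')    # size of the sample block in bytes
        n_samples = int.from_bytes(fh.read(8), byteorder='little')    # number of samples
        n_ch      = int.from_bytes(fh.read(8), byteorder='little')    # number of channels
        return laptop_ts, nsp_ts, n_bytes, n_samples, n_ch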

+ 276 - 0
scripts_for_figures/helpers/fbplot.py

@@ -0,0 +1,276 @@
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+import numpy as np
+
+from munch import Munch
+from .kaux import mergemunch
+
+def plot_df(df, plot_order, plot_t, fb_states, fig=None, options=None):
+    """
+    options: Munch dictionary with keys:
+        show_beginning: 0: don't show beginning of trial trace
+            1: mark beginning by a symbol
+            > 1: show first number of samples with a heavy line
+        show_end: 0: don't show end of trial trace
+            1: mark end by a symbol
+            > 1: show last number of samples with a heavy line
+        show_thresholds: T: will use thresholds in fb_states to plot horizontal lines
+        show_median: T: will show median lines for groups
+    """
+    default_options = Munch({'show_beginning': 1, 'show_end': 2, 'show_thresholds': True, 'show_median':True})
+    if options is None:
+        options = default_options
+    else:
+        options = Munch(mergemunch(default_options, options))
+
+    if fig is None:
+        fig = plt.figure(34, figsize=(16, 4))
+        fig.clf()
+
+    n_rows = np.max([x.r for x in plot_order.p]) + 1
+    n_c = np.max([x.c for x in plot_order.p]) + 1
+    n_cols = len(plot_order.s) * n_c
+
+    axs = fig.subplots(n_rows, n_cols, False, True, squeeze=False)
+
+    for i, s in enumerate(plot_order.s):
+        good_sample_idx = f'{s.dcol}_good'
+        xtr_n = plot_t[s.dcol].twin
+        xtr_n = xtr_n - xtr_n[0]
+        xtr_n = xtr_n / xtr_n[-1]
+
+        i_c = i
+        i_r = 0
+
+        if s.dcol == 'stimulus_start_samples':
+            stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
+
+            # rect = patches.Rectangle((0.0,0.0), stim_end_t, 1, linewidth=1, edgecolor='none', facecolor=(.01,.01,.01,.14))
+            # axs[i_c].add_patch(rect)
+
+            for (st, vals) in fb_states.items():
+                axs[i_r, i_c].plot([0, stim_end_t], [vals[0], vals[0]], lw=4, ls='-', c=plot_order.state[st].c,
+                                   alpha=.99, clip_on=False)
+                axs[i_r, i_c].text(0, vals[0] + .02, plot_order.state[st].d, c=plot_order.state[st].c, alpha=1)
+
+        for jj, pp in enumerate(plot_order.p):
+            pdd = df[pp.sel(df)]
+            if len(pdd) == 0:
+                continue
+            i_c = i + pp.c * len(plot_order.s)
+            i_r = pp.r
+
+            axs[i_r, i_c].spines['top'].set_visible(False)
+            axs[i_r, i_c].spines['right'].set_visible(False)
+            # if s.dcol == 'stimulus_start_samples':
+            #     stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
+            #
+            #     axs[i_r, i_c].plot([0, stim_end_t], [fb_states[pp.t][0], fb_states[pp.t][0]], lw=4, ls='-',
+            #                        c=plot_order.state[pp.t].c, alpha=.99, clip_on=False)
+            #     axs[i_r, i_c].text(0, fb_states[pp.t][0] + .02, plot_order.state[pp.t].d, c=plot_order.state[pp.t].c,
+            #                        alpha=1)
+
+            for ix, pd_row in pdd.iterrows():
+                if pd_row[s.dcol] is None:
+                    continue
+                y = pd_row[s.dcol][pd_row[good_sample_idx]]
+                xtr = plot_t[s.dcol].twin[pd_row[good_sample_idx]]
+                if s.norm:
+                    xtr = xtr - xtr[0]
+                    xtr = xtr / xtr[-1]
+                axs[i_r, i_c].plot(xtr, y, c=pp.col, alpha=.3, lw=.8)
+                if options.show_beginning == 1:
+                    axs[i_r, i_c].plot(xtr[0], y[0], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
+                elif options.show_beginning > 1:
+                    axs[i_r, i_c].plot(xtr[:(options.show_beginning - 1)], y[:(options.show_beginning - 1)], c=pp.col, alpha=.8, lw=1.5)
+                if options.show_end == 1:
+                    axs[i_r, i_c].plot(xtr[-1], y[-1], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
+                elif options.show_end > 1:
+                    axs[i_r, i_c].plot(xtr[(-options.show_end ):], y[(-options.show_end ):], c=pp.col, alpha=.8, lw=1.5)
+                # axs[i_r, i_c].plot(xtr_n, pd_row['response_start_samples_norm'], c=plot_order.cols[i_col], alpha=.3, lw=1)
+            if options.show_median and not pdd[s.dcol].empty:
+                if s.norm:
+                    all_traces = np.vstack(pdd[f'{s.dcol}_norm'].to_numpy())
+                    mean_tr = np.median(all_traces, 0)
+                    axs[i_r, i_c].plot(xtr_n, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls)
+                else:
+                    all_traces = np.vstack(pdd[s.dcol].to_numpy())
+                    mean_tr = np.nanmedian(all_traces, 0)
+                    axs[i_r, i_c].plot(plot_t[s.dcol].twin, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls)
+            axs[i_r, i_c].set_title(f'{s.title} {pp.desc} n={len(pdd)}')
+            if options.show_thresholds:
+                thr = fb_states.up
+                axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[1], thr[1]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
+                thr = fb_states.down
+                axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[2], thr[2]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
+            axs[i_r, i_c].set_xlim(plot_t[s.dcol].t)
+    return fig
+
+def plot_df_combined(df, plot_order, plot_t, fb_states, fig=None, options=None):
+    """
+    options: Munch dictionary with keys:
+        show_beginning: 0: don't show beginning of trial trace
+            1: mark beginning by a symbol
+            > 1: show first number of samples with a heavy line
+        show_end: 0: don't show end of trial trace
+            1: mark end by a symbol
+            > 1: show last number of samples with a heavy line
+        show_thresholds: T: will use thresholds in fb_states to plot horizontal lines
+        show_median: T: will show median lines for groups
+    """
+    default_options = Munch({'show_beginning': 1, 'show_end': 2, 'show_thresholds': True, 'show_median':True})
+    if options is None:
+        options = default_options
+    else:
+        options = Munch(mergemunch(default_options, options))
+
+    if fig is None:
+        fig = plt.figure(35, figsize=(16, 4))
+        fig.clf()
+
+    n_cols = len(plot_order.s)
+
+    axs = fig.subplots(1, n_cols, False, True, squeeze=False)
+
+    for i, s in enumerate(plot_order.s):
+        good_sample_idx = f'{s.dcol}_good'
+        xtr_n = plot_t[s.dcol].twin
+        xtr_n = xtr_n - xtr_n[0]
+        xtr_n = xtr_n / xtr_n[-1]
+        i_c = i
+        i_r = 0
+        axs[i_r, i_c].set_title(f'{s.title}')
+        
+        if options.show_thresholds:
+            thr = fb_states.up
+            axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[1], thr[1]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
+            thr = fb_states.down
+            axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[2], thr[2]], linestyle='-', c="#333333", lw=0.5, alpha=.4)
+        axs[i_r, i_c].set_xlim(plot_t[s.dcol].t)
+        axs[i_r, i_c].spines['top'].set_visible(False)
+        axs[i_r, i_c].spines['right'].set_visible(False)
+
+        for jj, pp in enumerate(plot_order.p):
+            pdd = df[pp.sel(df)]
+
+            if s.dcol == 'stimulus_start_samples':
+                stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
+
+                axs[i_r, i_c].plot([0, stim_end_t], [fb_states[pp.t][0], fb_states[pp.t][0]], lw=4, ls='-',
+                                   c=plot_order.state[pp.t].c, alpha=.99, clip_on=False)
+                axs[i_r, i_c].text(0, fb_states[pp.t][0] + .02, plot_order.state[pp.t].d, c=plot_order.state[pp.t].c,
+                                   alpha=1)
+
+            my_label = f"{pp.desc} (n = {len(pdd)})"
+            for ix, pd_row in pdd.iterrows():
+                y = pd_row[s.dcol][pd_row[good_sample_idx]]
+                xtr = plot_t[s.dcol].twin[pd_row[good_sample_idx]]
+                if s.norm:
+                    xtr = xtr - xtr[0]
+                    xtr = xtr / xtr[-1]
+                axs[i_r, i_c].plot(xtr, y, c=pp.col, alpha=.3, lw=.8, ls=pp.ls, label=my_label)
+                my_label = None
+                if options.show_beginning == 1:
+                    axs[i_r, i_c].plot(xtr[0], y[0], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
+                elif options.show_beginning > 1:
+                    axs[i_r, i_c].plot(xtr[:(options.show_beginning - 1)], y[:(options.show_beginning - 1)], c=pp.col, alpha=.8, lw=1.5)
+                if options.show_end == 1:
+                    axs[i_r, i_c].plot(xtr[-1], y[-1], c=pp.col, alpha=.3, lw=.8, marker='o', markersize=2)
+                elif options.show_end > 1:
+                    axs[i_r, i_c].plot(xtr[(-options.show_end ):], y[(-options.show_end ):], c=pp.col, alpha=.8, lw=1.5)
+                # axs[i_r, i_c].plot(xtr_n, pd_row['response_start_samples_norm'], c=plot_order.cols[i_col], alpha=.3, lw=1)
+            if options.show_median and not pdd[s.dcol].empty:
+                if s.norm:
+                    all_traces = np.vstack(pdd[f'{s.dcol}_norm'].to_numpy())
+                    mean_tr = np.median(all_traces, 0)
+                    axs[i_r, i_c].plot(xtr_n, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls)
+                else:
+                    all_traces = np.vstack(pdd[s.dcol].to_numpy())
+                    mean_tr = np.nanmedian(all_traces, 0)
+                    axs[i_r, i_c].plot(plot_t[s.dcol].twin, mean_tr, c=pp.col, alpha=.6, lw=4, ls=pp.ls, label=pp.desc)
+        axs[i_r, i_c].legend()
+    return fig
+
+def plot_df_avg(df, plot_order, plot_t, fb_states, fig=None, show_stimulus_start=False):
+    if fig is None:
+        fig = plt.figure(35, figsize=(16, 4))
+        fig.clf()
+
+    n_rows = 1
+    # n_c = len(plot_order.c)
+    n_cols = len(plot_order.s)
+    axs = fig.subplots(n_rows, n_cols, False, True, squeeze=False)
+    i_r = 0
+
+    for i, s in enumerate(plot_order.s):
+        good_sample_idx = f'{s.dcol}_good'
+        xtr_n = plot_t[s.dcol].twin
+        xtr_n = xtr_n - xtr_n[0]
+        xtr_n = xtr_n / xtr_n[-1]
+        i_c = i
+        i_r = 0
+        axs[i_r, i_c].spines['top'].set_visible(False)
+        axs[i_r, i_c].spines['right'].set_visible(False)
+        if s.dcol == 'stimulus_start_samples':
+            stim_end_t = np.mean((df.stimulus_stop - df.stimulus_start) / 3e4)
+
+            # rect = patches.Rectangle((0.0,0.0), stim_end_t, 1, linewidth=1, edgecolor='none', facecolor=(.01,.01,.01,.14))
+            # axs[i_c].add_patch(rect)
+
+            for (st, vals) in fb_states.items():
+                axs[i_r, i_c].plot([0, stim_end_t], [vals[0], vals[0]], lw=4, ls='-', c=plot_order.state[st].c,
+                                   alpha=.99, clip_on=False)
+                axs[i_r, i_c].text(0, vals[0] + .02, plot_order.state[st].d, c=plot_order.state[st].c, alpha=1)
+
+        for jj, pp in enumerate(plot_order.p):
+            pdd = df[pp.sel(df)]
+
+            i_r = 0
+
+            if not pdd[s.dcol].empty:
+                if s.norm:
+                    all_traces = np.vstack(pdd[f'{s.dcol}_norm'].to_numpy())
+                    avg_x = xtr_n
+                else:
+                    all_traces = np.vstack(pdd[s.dcol].to_numpy())
+                    avg_x = plot_t[s.dcol].twin
+
+                mean_tr = np.nanmedian(all_traces, 0)
+                perct_tr = np.nanpercentile(all_traces, [25, 75], axis=0)
+
+                axs[i_r, i_c].plot(avg_x, mean_tr, c=pp.col, alpha=1, lw=pp.lw, clip_on=False, ls=pp.ls)
+                if pp.show_var:
+                    axs[i_r, i_c].plot(avg_x, perct_tr[0, :], c=pp.col, alpha=.1, lw=pp.lw, clip_on=False, ls=pp.ls)
+                    axs[i_r, i_c].plot(avg_x, perct_tr[1, :], c=pp.col, alpha=.1, lw=pp.lw, clip_on=False, ls=pp.ls)
+                    axs[i_r, i_c].fill_between(avg_x, perct_tr[0, :], perct_tr[1, :],
+                                               alpha=0.20, facecolor=pp.col, edgecolor='none', ls=pp.ls, lw=1,
+                                               antialiased=True, clip_on=False)
+
+                if show_stimulus_start and s.dcol == 'stimulus_stop_samples':
+                    stim_start_offset = (pdd.stimulus_start - pdd.stimulus_stop) / 3e4
+                    for x in stim_start_offset:
+                        axs[i_r, i_c].axvline(x, c=[0, 0, 0], alpha=.1)
+            thr = fb_states.up
+            # axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[1], thr[1]], linestyle='--', c="#333333")
+            thr = fb_states.down
+
+            # axs[i_r, i_c].plot(plot_t[s.dcol].twin[[0,-1]], [thr[2], thr[2]], linestyle='--', c="#333333")
+        # axs[i_c].set_ylim([0,1])
+        if s.norm:
+            axs[i_r, i_c].set_xlim([0, 1])
+            axs[i_r, i_c].set_xlabel('t (normalized)')
+
+        else:
+            axs[i_r, i_c].set_xlim(plot_t[s.dcol].t)
+            axs[i_r, i_c].set_xlabel('t [s]')
+
+        if i_c == 0:
+            axs[i_r, i_c].set_ylabel('normalized neural activity')
+        axs[i_r, i_c].set_title(f'{s.title}')
+
+    custom_lines = [Line2D([0], [0], color=(.3, .3, .3), lw=2, ls='-'),
+                    Line2D([0], [0], color=(.3, .3, .3), lw=2, ls=':')]
+    axs[i_r, n_cols - 1].legend(custom_lines, ['Correct Trials', 'Error Trials'])
+
+    fig.show()
+    return fig
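
Both `plot_df` and `plot_df_combined` take the `options` Munch described in their docstrings; keys that are not supplied fall back to the defaults via `mergemunch`. An illustrative call, assuming `df`, `plot_order`, `plot_t` and `fb_states` have been built by the analysis scripts:

    from munch import Munch
    from helpers import fbplot

    options = Munch(show_beginning=1,      # mark the first sample of each trial with a dot
                    show_end=5,            # draw the last five samples with a heavier line
                    show_thresholds=True,  # horizontal lines at the up/down thresholds
                    show_median=True)      # overlay the per-group median trace

    # fig = fbplot.plot_df(df, plot_order, plot_t, fb_states, options=options)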

+ 268 - 0
scripts_for_figures/helpers/kaux.py

@@ -0,0 +1,268 @@
+import argparse
+import datetime
+import glob
+import logging
+import os
+import pathlib
+import sys
+from enum import Enum
+
+import yaml
+import munch
+import numpy as np
+import subprocess
+from colorlog import ColoredFormatter
+
+from . import validate_config as vd
+from functools import reduce
+
+parser = argparse.ArgumentParser()
+parser.add_argument("gui", nargs='?', help="flag, 1:start gui", type=int, default=0)
+parser.add_argument("plot", nargs='?', help="flag, 1:start plot", type=int, default=0)
+parser.add_argument("--log", help="set verbosity level [DEBUG, INFO, WARNING, ERROR]")
+parser.add_argument('--speller', default='')
+parser.add_argument('-l', '--list', help='delimited list input', type=str)
+args = parser.parse_args()
+
+class decision(Enum):
+    yes = 0
+    no = 1
+    nc = -3                     # not confirmed
+    baseline = 2
+    unclassified = -1        # e.g. when history is not big enough yet
+    error1 = -1             # not enough data to get_class, see classifier get_class2
+    error2 = -2
+
+
+def static_vars(**kwargs):
+    def decorate(func):
+        for k in kwargs:
+            setattr(func, k, kwargs[k])
+        return func
+    return decorate
+
+
+
+def config_logging():
+    LOG_LEVEL = logging.WARN
+    
+    if args.log is not None:
+        LOG_LEVEL = eval('logging.' + args.log.upper())
+
+    # LOGFORMAT = "  %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
+    LOGFORMAT = "%(log_color)s %(asctime)s [%(filename)-12.12s] [%(lineno)4d] [%(processName)-12.12s] [%(threadName)-12.12s] [%(levelname)-7.7s]  %(message)s"
+    logging.root.setLevel(LOG_LEVEL)
+    formatter = ColoredFormatter(LOGFORMAT)
+    stream = logging.StreamHandler()
+    stream.setLevel(LOG_LEVEL)
+    stream.setFormatter(formatter)
+    log = logging.getLogger('pythonConfig')
+    log.setLevel(LOG_LEVEL)
+    log.addHandler(stream)
+
+
+    return log
+
+
+@static_vars(my_config=None)
+def load_config(force_reload=False):
+    """Finds the config file and loads it"""
+    if (not force_reload) and (load_config.my_config is not None):
+        log.info("Found cached config data and will use that.")
+        return load_config.my_config
+    
+    # config_files = glob.glob('/kiap/src/kiap_bci/config.yaml', recursive=True)
+    config_files = glob.glob('./config.yaml', recursive=True)
+
+    # Load the params file
+    if config_files:
+        config_fname = config_files[0]
+
+        with open(config_fname) as stream:
+            params = munch.fromYAML(stream, Loader=yaml.FullLoader)
+
+        validation_passed, validation_error = vd.validate_schema(params)
+        if not validation_passed:
+            log.error(validation_error)
+            raise ValueError('Configuration is not valid !')
+
+        try:
+            with open('paradigm.yaml') as stream:
+                params.paradigms = munch.fromYAML(stream, Loader=yaml.FullLoader)
+        except Exception as e:
+            log.warning(f'Could not load paradigm yaml file.\n{e}')
+
+            
+        supplemental_cfgs = []
+        try:
+            for sfn in params.supplemental_config:
+                file_list = []
+                sfn_path = pathlib.Path(sfn)
+                if sfn_path.exists() and sfn_path.is_dir():
+                    file_list += sfn_path.glob('**/*.yml')
+                    file_list += sfn_path.glob('**/*.yaml')
+                    file_list.sort(key=lambda p : str(p.absolute()).lower())
+                else:
+                    file_list.append(sfn_path)
+                for a_file in file_list:
+                    try:
+                        log.info("Reading supplementary config file '{}'.".format(a_file))
+                        with open(a_file) as stream:
+                            supplemental_cfgs.append(munch.fromYAML(stream, Loader=yaml.Loader))
+                    except FileNotFoundError as e:
+                        log.warning("Supplemental config file '{}' not found. This option will be ignored.".format(a_file), exc_info=1)
+            params = reduce(lambda xx, yy: munch.Munch(mergemunch(xx, yy)), supplemental_cfgs, params)
+        except AttributeError as e:
+            log.info("Attribute 'supplemental_config' not set in config file.")
+        
+        validation_passed, validation_error = vd.validate_schema(params)
+        if not validation_passed:
+            log.error(validation_error)
+            raise ValueError('Configuration is not valid !')
+
+        params = setfileattr(params)
+        params = config_setup(params)
+        params = eval_ranges(params)
+
+        params.buffer.shape = [params.buffer.length, params.daq.n_channels]
+
+    else:
+        log.debug("No file called 'config.yaml' found, please save the file in BCI folder. Shutting down...")
+        sys.exit("CONFIG FILE NOT FOUND")
+    load_config.my_config = params
+    return params
+
+# This function merges two munch dictionaries. Use as munch.Munch(mergemunch(m1, m2))
+def mergemunch(dict1, dict2):
+    for k in set(dict1) | set(dict2):
+        if k in dict1 and k in dict2:
+            if isinstance(dict1[k], dict) and isinstance(dict2[k], dict):
+                yield k, munch.Munch(mergemunch(dict1[k], dict2[k]))
+            else:
+                # If one of the values is not a dict, you can't continue merging it.
+                # Value from second dict overrides one in first and we move on.
+                yield k, dict2[k]
+                # Alternatively, replace this with exception raiser to alert you of value conflicts
+        elif k in dict1:
+            yield k, dict1[k]
+        else:
+            yield k, dict2[k]
+
+
+def eval_ranges(params):
+    '''evaluate all ranges, currently only for template'''
+
+    params.daq.n_channels = params.daq.n_channels_max -len(params.daq.exclude_channels)
+
+    if 'range' in params.classifier.template:
+        params.classifier.template = np.array(eval(params.classifier.template))
+    else:
+        params.classifier.template = np.array(params.classifier.template)
+
+    # if 'all' in params.classifier.channel_mask:
+        # params.classifier.channel_mask = list(range(0,params.daq.n_channels))
+    if 'range' in params.classifier.exclude_channels:
+        params.classifier.exclude_channels = list(eval(params.classifier.exclude_channels))
+
+    if 'range' in params.classifier.include_channels:
+        params.classifier.include_channels = list(eval(params.classifier.include_channels))
+
+    if 'range' in params.lfp.array1:
+        params.lfp.array1 = list(eval(params.lfp.array1))
+    if 'range' in params.lfp.array21:
+        params.lfp.array21 = list(eval(params.lfp.array21))
+    if 'range' in params.lfp.array22:
+        params.lfp.array22 = list(eval(params.lfp.array22))
+    params.lfp.array2 = list(params.lfp.array21 + params.lfp.array22)
+    return params
+
+
+def setfileattr(params):
+
+    tnow = datetime.datetime.now().strftime('%H_%M_%S')
+    today = str(datetime.date.today())
+
+    datafile_path = os.path.join(params.file_handling.data_path, today)
+    params.file_handling.datafile_path = datafile_path
+    pathlib.Path(datafile_path).mkdir(parents=True, exist_ok=True)
+
+    # filename_data = 'data.bin'
+    # filename_log_info = 'info.log'
+    # filename_log_debug = 'debug.log'
+    # filename_events = 'events.txt'
+    filename_data = f'data_{tnow}.bin'
+    filename_baseline = f'bl_{tnow}.npy'
+    filename_log_info = f'info_{tnow}.log'
+    filename_log_debug = f'debug_{tnow}.log'
+    filename_events = f'events_{tnow}.txt'
+    filename_config = f'config_{tnow}.yaml'
+    filename_paradigm = f'paradigm_{tnow}.yaml'
+    filename_config_dump = f'config_dump_{tnow}.yaml'
+    filename_git_patch = f'git_changes_{tnow}.patch'
+    filename_history = f'history.bin'
+    params.file_handling.filename_data = os.path.join(datafile_path, filename_data)
+    params.file_handling.filename_baseline = os.path.join(datafile_path, filename_baseline)
+    params.file_handling.filename_log_info = os.path.join(datafile_path, filename_log_info)
+    params.file_handling.filename_log_debug = os.path.join(datafile_path, filename_log_debug)
+    params.file_handling.filename_events = os.path.join(datafile_path, filename_events)
+    params.file_handling.filename_config = os.path.join(datafile_path, filename_config)
+    params.file_handling.filename_paradigm = os.path.join(datafile_path, filename_paradigm)
+    params.file_handling.filename_config_dump = os.path.join(datafile_path, filename_config_dump)
+    params.file_handling.filename_git_patch = os.path.join(datafile_path, filename_git_patch)
+    params.file_handling.filename_history = os.path.join(datafile_path, filename_history)
+
+    # get current git hash and store it
+    git_hash = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
+    params.file_handling.git_hash = git_hash.decode("utf-8")[:-1]
+    #update config file before saving it
+        
+    return params
+
+# def add_timestamps(params):
+
+#     tnow = datetime.datetime.now().strftime('%H_%M_%S')
+    
+#     today = str(datetime.date.today())
+#     datafile_path = os.path.join(params.file_handling.data_path, today)
+#     params.file_handling.datafile_path = datafile_path
+#     pathlib.Path(datafile_path).mkdir(parents=True, exist_ok=True)
+
+#     filename_data = f'data_{tnow}.bin'
+#     filename_log_info = f'info_{tnow}.log'
+#     filename_log_debug = f'debug_{tnow}.log'
+#     filename_events = f'events_{tnow}.txt'
+
+#     filename_data = os.path.join(datafile_path, filename_data)
+#     filename_log_info = os.path.join(datafile_path, filename_log_info)
+#     filename_log_debug = os.path.join(datafile_path, filename_log_debug)
+#     filename_events = os.path.join(datafile_path, filename_events)
+
+#     log.info(f'Data file: {filename_data}')
+#     # os.rename(params.file_handling.filename_data, filename_data)
+#     # os.rename(params.file_handling.filename_log_info, filename_log_info)
+#     # os.rename(params.file_handling.filename_log_debug, filename_log_debug)
+#     # os.rename(params.file_handling.filename_events, filename_events)
+
+#     # log.info('Added timestamps to files')
+
+#     return None
+
+def config_setup(params):
+    """Interprets the config file and updates BCI parameters."""
+
+    params.session.flags.stimulus = True
+    params.session.flags.decode = False
+
+    if params.speller.type == 'norm':
+        params.session.flags.stimulus = False
+
+    return params
+
+
+if 'log' not in locals():
+    log = config_logging()
+
+if args.log is not None:
+    log.info(f'Debug level is: {args.log.upper()}')
+else:
+    log.info('Debug level is INFO')
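
`load_config` merges the main `config.yaml` with any supplemental files through `mergemunch`, which combines two Munch dictionaries recursively, with the second argument winning on conflicts. A small example of that merge behaviour (standalone, assuming the `helpers` package is importable):

    from munch import Munch
    from helpers.kaux import mergemunch

    base     = Munch(daq=Munch(n_channels=128, normalization=Munch(len=30)))
    override = Munch(daq=Munch(normalization=Munch(len=60)))

    merged = Munch(mergemunch(base, override))
    print(merged.daq.n_channels, merged.daq.normalization.len)   # -> 128 60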

+ 12 - 0
scripts_for_figures/helpers/nsp.py

@@ -0,0 +1,12 @@
+import numpy as np
+
+def fix_timestamps(tstmp, offsets=None):
+    tstmp = tstmp.astype(np.int64)
+    rollover, = np.where(np.diff(tstmp) < 0)
+    if offsets is None:
+        offsets = tstmp[rollover]
+    for j, i in enumerate(rollover):
+        tstmp[(i + 1):] = tstmp[(i + 1):] + offsets[j]
+        # offsets[j:] += tstmp[i]
+    cum_offsets = np.cumsum(offsets)
+    return tstmp, offsets, cum_offsets
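
`fix_timestamps` compensates for rollovers in the NSP sample counter: wherever the sequence decreases, the value just before the drop is added as an offset to all later samples. A quick illustration with a toy sequence:

    import numpy as np
    from helpers.nsp import fix_timestamps

    raw = np.array([100, 200, 300, 5, 50], dtype=np.uint64)   # counter wraps after the third sample
    fixed, offsets, cum_offsets = fix_timestamps(raw)
    print(fixed)    # [100 200 300 305 350] -- monotonically increasing again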

+ 181 - 0
scripts_for_figures/helpers/ringbuffer.py

@@ -0,0 +1,181 @@
+import numpy as np
+from collections.abc import Sequence      # Sequence is in collections.abc (required on Python 3.10+)
+
+class RingBuffer(Sequence):
+	def __init__(self, capacity, dtype=float, allow_overwrite=True):
+		"""
+		Create a new ring buffer with the given capacity and element type
+		Parameters
+		----------
+		capacity: int
+			The maximum capacity of the ring buffer
+		dtype: data-type, optional
+			Desired type of buffer elements. Use a type like (float, 2) to
+			produce a buffer with shape (N, 2)
+		allow_overwrite: bool
+			If false, throw an IndexError when trying to append to an already
+			full buffer
+		"""
+		self._arr = np.empty(capacity, dtype)
+		self._left_index = 0
+		self._right_index = 0
+		self._capacity = capacity
+		self._allow_overwrite = allow_overwrite
+
+	def _unwrap(self):
+		""" Copy the data from this buffer into unwrapped form """
+		return np.concatenate((
+			self._arr[self._left_index:min(self._right_index, self._capacity)],
+			self._arr[:max(self._right_index - self._capacity, 0)]
+		))
+
+	def _fix_indices(self):
+		"""
+		Enforce our invariant that 0 <= self._left_index < self._capacity
+		"""
+		if self._left_index >= self._capacity:
+			self._left_index -= self._capacity
+			self._right_index -= self._capacity
+		elif self._left_index < 0:
+			self._left_index += self._capacity
+			self._right_index += self._capacity
+
+	@property
+	def is_full(self):
+		""" True if there is no more space in the buffer """
+		return len(self) == self._capacity
+
+	# numpy compatibility
+	def __array__(self):
+		return self._unwrap()
+
+	@property
+	def dtype(self):
+		return self._arr.dtype
+
+	@property
+	def shape(self):
+		return (len(self),) + self._arr.shape[1:]
+
+
+	# these mirror methods from deque
+	@property
+	def maxlen(self):
+		return self._capacity
+
+	def append(self, value):
+		if self.is_full:
+			if not self._allow_overwrite:
+				raise IndexError('append to a full RingBuffer with overwrite disabled')
+			elif not len(self):
+				return
+			else:
+				self._left_index += 1
+
+		self._arr[self._right_index % self._capacity] = value
+		self._right_index += 1
+		self._fix_indices()
+
+	def appendleft(self, value):
+		if self.is_full:
+			if not self._allow_overwrite:
+				raise IndexError('append to a full RingBuffer with overwrite disabled')
+			elif not len(self):
+				return
+			else:
+				self._right_index -= 1
+
+		self._left_index -= 1
+		self._fix_indices()
+		self._arr[self._left_index] = value
+
+	def pop(self):
+		if len(self) == 0:
+			raise IndexError("pop from an empty RingBuffer")
+		self._right_index -= 1
+		self._fix_indices()
+		res = self._arr[self._right_index % self._capacity]
+		return res
+
+	def popleft(self):
+		if len(self) == 0:
+			raise IndexError("pop from an empty RingBuffer")
+		res = self._arr[self._left_index]
+		self._left_index += 1
+		self._fix_indices()
+		return res
+
+	def extend(self, values):
+		lv = len(values)
+		if len(self) + lv > self._capacity:
+			if not self._allow_overwrite:
+				raise IndexError('extend a RingBuffer such that it would overflow, with overwrite disabled')
+
+		if lv >= self._capacity:
+			# wipe the entire array! - this may not be threadsafe
+			self._arr[...] = values[-self._capacity:]
+			self._right_index = self._capacity
+			self._left_index = 0
+			return
+
+		ri = self._right_index % self._capacity
+		sl1 = np.s_[ri:min(ri + lv, self._capacity)]
+		sl2 = np.s_[:max(ri + lv - self._capacity, 0)]
+		self._arr[sl1] = values[:sl1.stop - sl1.start]
+		self._arr[sl2] = values[sl1.stop - sl1.start:]
+		self._right_index += lv
+
+		self._left_index = max(self._left_index, self._right_index - self._capacity)
+		self._fix_indices()
+
+	def extendleft(self, values):
+		lv = len(values)
+		if len(self) + lv > self._capacity:
+			if not self._allow_overwrite:
+				raise IndexError('extend a RingBuffer such that it would overflow, with overwrite disabled')
+			elif not len(self):
+				return
+		if lv >= self._capacity:
+			# wipe the entire array! - this may not be threadsafe
+			self._arr[...] = values[:self._capacity]
+			self._right_index = self._capacity
+			self._left_index = 0
+			return
+
+		self._left_index -= lv
+		self._fix_indices()
+		li = self._left_index
+		sl1 = np.s_[li:min(li + lv, self._capacity)]
+		sl2 = np.s_[:max(li + lv - self._capacity, 0)]
+		self._arr[sl1] = values[:sl1.stop - sl1.start]
+		self._arr[sl2] = values[sl1.stop - sl1.start:]
+
+		self._right_index = min(self._right_index, self._left_index + self._capacity)
+
+
+	# implement Sequence methods
+	def __len__(self):
+		return self._right_index - self._left_index
+
+	def __getitem__(self, item):
+		# handle simple (b[1]) and basic (b[np.array([1, 2, 3])]) fancy indexing specially
+		if not isinstance(item, tuple):
+			if isinstance(item, int) and item < 0:
+				# ringbuf[-<n>]
+				item = (item + self._right_index) % self._capacity
+				return self._arr[item]
+			item_arr = np.asarray(item)
+			if issubclass(item_arr.dtype.type, np.integer):
+				item_arr = (item_arr + self._left_index) % self._capacity
+				return self._arr[item_arr]
+
+		# for everything else, get it right at the expense of efficiency
+		return self._unwrap()[item]
+
+	def __iter__(self):
+		# alarmingly, this is comparable in speed to using itertools.chain
+		return iter(self._unwrap())
+
+	# Everything else
+	def __repr__(self):
+		return '<RingBuffer of {!r}>'.format(np.asarray(self))

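For orientation, a minimal usage sketch of the `RingBuffer` class above. It assumes the class is in scope; the two-column dtype mirrors the `(N, 2)` example from its docstring, and the values are purely illustrative.

    import numpy as np

    rb = RingBuffer(capacity=4, dtype=(float, 2), allow_overwrite=True)
    rb.append([1.0, 2.0])                   # one (2,)-shaped row
    rb.extend(np.arange(8.).reshape(4, 2))  # overflows the capacity; earlier content is overwritten
    print(np.asarray(rb).shape)             # (4, 2) -- __array__ unwraps the ring
    print(rb.pop())                         # last row written, i.e. [6., 7.]
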
+ 98 - 0
scripts_for_figures/helpers/sessions.py

@@ -0,0 +1,98 @@
+from pathlib import Path
+import yaml
+import munch
+import pandas as pd
+import numpy as np
+from helpers import data_management as dm
+from helpers.nsp import fix_timestamps
+import os
+import re
+import logging
+
+EVENT_FIRST_LINE_RE = re.compile(r"^(\d+),.*Block, start'$")
+EVENT_LAST_LINE_RE = re.compile(r"^(\d+),.*Block, stop'$")
+
+logger = logging.getLogger('KIAP.sessions')
+
+
+def check_event_file_format(ev_file):
+    with open(ev_file, "rb") as f:
+        first = f.readline().decode()  # Read the first line.
+        f.seek(-2, os.SEEK_END)  # Jump to the second last byte.
+        while f.read(1) != b"\n":  # Until EOL is found...
+            f.seek(-2, os.SEEK_CUR)  # ...jump back the read byte plus one more.
+        last = f.readline().decode()  # Read last line.
+    return (EVENT_FIRST_LINE_RE.match(first) is not None) and (EVENT_LAST_LINE_RE.match(last) is not None)
+
+
+def get_sessions(path, mode=None, n=None, start=0):
+    this_path = Path(path)
+    config_files = sorted(this_path.glob("**/config_dump_*.yaml"))
+
+    config_files_read = []
+    for cfg_path in config_files[start:]:
+        try:
+            with open(cfg_path, 'r') as f:
+                # cfg = yaml.load(f, Loader=yaml.Loader)
+                cfg = munch.Munch.fromYAML(f, Loader=yaml.Loader)
+            logger.info(f"Loading {cfg_path}")
+
+            if (mode is None) or (mode == cfg.speller.type):
+                cfg_d = {
+                    'mode': cfg.speller.type,
+                    'cfg': str(Path(*Path(cfg_path).parts[-2:])),
+                    'events': cfg.file_handling.get('filename_events'),
+                    'data': cfg.file_handling.get('filename_data'),
+                    'log': cfg.file_handling.get('filename_log_info')
+                }
+                if cfg_d['events'] is None or cfg_d['data'] is None or cfg_d['log'] is None:
+                    continue
+
+                cfg_d['events'] = str(Path(*Path(cfg_d['events']).parts[-2:]))
+                if not check_event_file_format(this_path / cfg_d['events']):
+                    logger.warning(f"{cfg_d['events']} is not valid (first / last line does not match schema)")
+                    #continue
+
+                cfg_d['data'] = str(Path(*Path(cfg_d['data']).parts[-2:]))
+                cfg_d['log'] = str(Path(*Path(cfg_d['log']).parts[-2:]))
+                config_files_read.append(cfg_d)
+                if (n is not None) and (len(config_files_read) >= n):
+                    break
+        except FileNotFoundError as e:
+            logger.warning(f"A file related to {cfg_path} was not found ({e}).")
+
+    #            config_files_read.append(cfg)
+    cfg_pd = pd.DataFrame(config_files_read)
+
+    return (cfg_pd, len(config_files))
+
+
+TRE = re.compile(r"^(\d+),.*$")
+
+
+def get_session_data(path, session):
+    """
+    Load data for a session. Requires the session configuration file to deduce the file format.
+    Returns the sample time vector (s), the data array, the channel list, and the
+    [start, stop] times of the first and last events (s).
+    """
+    fn_sess = Path(path, session['data'])
+    fn_evs = Path(path, session['events'])
+    logger.debug(f"get_session_data: loading {fn_evs}")
+    with open(Path(path, session['cfg']), 'r') as f:
+        params = munch.Munch.fromYAML(f, Loader=yaml.Loader)
+
+    datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params)
+    ts, offsets, _ = fix_timestamps(ts)
+
+    with open(fn_evs, 'r') as f:
+        evs = f.readlines()
+    times = []
+    for ev in evs:
+        mtch = TRE.match(ev)
+        times.append(int(mtch.group(1)))
+    tsevs, _, _ = fix_timestamps(np.array(times))
+    evt = [tsevs[0] / 3e4, tsevs[-1] / 3e4]
+
+    tv = ts / 3e4
+
+    return tv, datav, ch_rec_list, evt

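A hypothetical call pattern for the two helpers above; the folder name `KIAP_BCI_speller` matches the layout used by the plotting scripts below, but the arguments are otherwise illustrative.

    from basics import BASE_PATH
    from helpers.sessions import get_sessions, get_session_data

    # tabulate up to five 'question' speller sessions found under the data folder
    sessions, n_found = get_sessions(BASE_PATH / 'KIAP_BCI_speller', mode='question', n=5)
    # load sample times, data, channel list and first/last event times for the first one
    tv, datav, ch_rec_list, evt = get_session_data(BASE_PATH / 'KIAP_BCI_speller', sessions.iloc[0])
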
+ 16 - 0
scripts_for_figures/helpers/tsdumper.py

@@ -0,0 +1,16 @@
+from yaml import CDumper
+from yaml.representer import SafeRepresenter
+import datetime
+import pandas as pd
+
+
+class TSDumper(CDumper):
+    pass
+
+
+def timestamp_representer(dumper, data):
+    return SafeRepresenter.represent_datetime(dumper, data.to_pydatetime())
+
+
+TSDumper.add_representer(datetime.datetime, SafeRepresenter.represent_datetime)
+TSDumper.add_representer(pd.Timestamp, timestamp_representer)

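`TSDumper` only registers representers; a sketch of how it might be used so that `pandas.Timestamp` values are written as plain YAML timestamps (the record contents are made up):

    import pandas as pd
    import yaml

    from helpers.tsdumper import TSDumper

    record = {'session_start': pd.Timestamp('2019-11-21 15:23:18')}
    print(yaml.dump(record, Dumper=TSDumper))
    # session_start: 2019-11-21 15:23:18
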
+ 80 - 0
scripts_for_figures/helpers/validate_config.py

@@ -0,0 +1,80 @@
+import glob
+
+import yaml
+
+from cerberus import Validator
+
+
+def validate_schema(params):
+    # config_files = glob.glob('**/config.yaml', recursive=True)
+    # fname = config_files[0]
+
+    # with open(fname) as stream:
+    #     params = yaml.load(stream, Loader=yaml.FullLoader)
+
+    # print(params)
+
+    schema = {
+        'daq': {
+            'type': 'dict',
+            'schema': {
+                'n_channels_max': {'required': True, 'type': 'integer', 'min': 1, 'max': 128},
+                'daq_sleep': {'required': True, 'type': 'float', 'min': 0, 'max': 1.},
+                'fs': {'required': True, 'type': 'float', 'min': 1., 'max': 30000.},
+                'smpl_fct': {'required': True, 'type': 'integer', 'min': 1, 'max': 30000},
+                        }
+                    },
+        'speller': {
+            'type': 'dict',
+            'schema': {
+                'type': {'required': True, 'type': 'string', 'allowed': ['question', 'exploration', 'training_color', 'color', 'feedback']}
+                      }
+                    },
+        'recording': {
+            'type': 'dict',
+            'schema': {
+                'timing': {
+                    'type': 'dict',
+                    'schema': {
+                        't_baseline_1': {'required': True, 'type': 'float', 'min': 0., 'max': 600.},
+                        't_baseline_all': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                        't_baseline_rand': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                        't_after_stimulus': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                        't_response': {'required': True, 'type': 'float', 'min': 0., 'max': 60.},
+                        'decoder_refresh_interval': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                        'bci_loop_interval': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                        'recording_loop_interval': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                        'recording_loop_interval_data': {'required': True, 'type': 'float', 'min': 0., 'max': 10.},
+                              }
+                            }
+                        }
+                    }
+            }
+
+    # print(schema)
+
+    v = Validator()
+    v.schema = schema
+    v.allow_unknown = True
+
+    validation_passed = v.validate(params, schema)
+    errors = v.errors
+
+    # print("Validation validation_passed: {}".format(validation_passed))
+    # print("Errors: {}".format(errors))
+
+    return validation_passed, errors
+
+
+if __name__ == '__main__':
+    config_files = glob.glob('**/config.yaml', recursive=True)
+    fname = config_files[0]
+    # validate_schema expects the parsed configuration, not the file name
+    with open(fname) as stream:
+        params = yaml.load(stream, Loader=yaml.FullLoader)
+    validation_passed, errors = validate_schema(params)
+    # print(validation_passed)
+    if validation_passed is True:
+        print("Config file validation success! All values are as expected.")
+    else:
+        print("Config file validation failed with the following errors:")
+
+        for key, value in errors.items():
+            print("\nerror in {}: {}".format(key, value))

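For reference, a minimal parameter dictionary that should satisfy the schema above; all numbers are placeholders chosen to lie inside the allowed ranges.

    minimal_params = {
        'daq': {'n_channels_max': 128, 'daq_sleep': 0.01, 'fs': 30000.0, 'smpl_fct': 20},
        'speller': {'type': 'question'},
        'recording': {'timing': {
            't_baseline_1': 1.0, 't_baseline_all': 1.0, 't_baseline_rand': 0.5,
            't_after_stimulus': 1.0, 't_response': 10.0,
            'decoder_refresh_interval': 0.1, 'bci_loop_interval': 0.05,
            'recording_loop_interval': 0.05, 'recording_loop_interval_data': 0.05,
        }},
    }
    passed, errs = validate_schema(minimal_params)
    assert passed and not errs
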
File diff suppressed because it is too large
+ 1142 - 0
scripts_for_figures/plot_figures_part_A.py


+ 213 - 0
scripts_for_figures/plot_figures_part_B.py

@@ -0,0 +1,213 @@
+from helpers import data_management as dm
+import matplotlib.pyplot as plt
+import matplotlib
+from pathlib import Path
+# matplotlib.use('TkAgg')
+
+import numpy as np
+import pandas as pd
+
+from helpers.data import DataNormalizer
+import re
+
+import yaml
+import munch
+
+import scipy
+import scipy.interpolate
+from scipy import stats
+
+from helpers.fbplot import plot_df_combined
+from helpers.nsp import fix_timestamps
+from basics import BASE_PATH, BASE_PATH_OUT
+
+plot_t = munch.munchify({'response_start_samples': {'t': [-1.5, 3.0]}})
+
+plot_order_comb = munch.munchify({
+    'p': [
+        {'r': 0, 'c': 0, 'col': (0.635, 0.078, 0.184, .2), 'ls': '-', 'lw': 2, 'show_var': True,
+         'sel': lambda df: ((df.target == 'up') | (df.target == 'yes')) & (df.decision == 'up'), 'desc': 'Target: up'},
+        {'r': 0, 'c': 1, 'col': (0, 0.447, 0.741, .2), 'ls': '-', 'lw': 2, 'show_var': True,
+         'sel': lambda df: ((df.target == 'down') | (df.target == 'no')) & (df.decision == 'down'),
+         'desc': 'Target: down'},
+        {'r': 1, 'c': 0, 'col': (0.635, 0.078, 0.184, .2), 'ls': ':', 'lw': 2,
+         'show_var': False, 'sel': lambda df: (df.target == 'up') & (df.decision == 'down'),
+         'desc': 'Target: up, decision: down'},
+        {'r': 1, 'c': 1, 'col': (0, 0.447, 0.741, .2), 'ls': ':', 'lw': 2, 'show_var': False,
+         'sel': lambda df: (df.target == 'down') & (df.decision == 'up'), 'desc': 'Target: down, decision: up'}
+    ],
+    's': [
+        {'dcol': 'response_start_samples', 'title': 'Response Period', 'norm': False}],
+    'state': {'down': {'c': (0, 0.447, 0.741, .2), 'd': 'Low Frequency Target Tone'},
+              'up': {'c': (0.635, 0.078, 0.184, .2), 'd': 'High Frequency Target Tone'},
+              'no': {'c': (0, 0.447, 0.741, .2), 'd': 'Low Frequency Target Tone'},
+              'yes': {'c': (0.635, 0.078, 0.184, .2), 'd': 'High Frequency Target Tone'}}},
+)
+
+SAVENAME = 'Figure_3A_FBTrials'
+
+
+def extract_trials(filename, offsets=None):
+    if offsets is None:
+        offsets = []
+    with open(filename, 'r') as f:
+        evs = f.read().splitlines()
+    # fix event timestamps
+    tpat = re.compile(r"^(\d+)(, .*)$")
+    stage_pat = re.compile(r"(\d+), b'(feedback|question), (\w+), (\w+), (\w+), (\w+)'$")
+
+    decpat = re.compile(r".*\s(\w+), Decoder decision is: (.*)'$")
+
+    evs_time = []
+    for ev in evs:
+        m = tpat.match(ev)
+        evs_time.append(np.int64(m.group(1)))
+    evs_time = np.asarray(evs_time, dtype=np.int64)
+    rollev, = np.where(np.diff(evs_time) < 0)
+    for j, i in enumerate(rollev):
+        evs_time[(i + 1):] += offsets[j]
+
+    trials = []
+
+    this_trial = {}
+    for ev, ev_time in zip(evs, evs_time):
+        m = stage_pat.match(ev)
+        if m is not None:
+            if m.group(5) == 'Block':
+                continue
+            if m.group(5) == 'baseline' and m.group(6) == 'start':
+                if len(this_trial) > 0:
+                    trials.append(this_trial)
+                this_trial = {'baseline_start': -1, 'baseline_stop': -1, 'stimulus_start': -1,
+                              'stimulus_stop': -1, 'response_start': -1, 'response_stop': -1, 'target': m.group(4),
+                              'response_start_samples': None, 'response_start_samples_wnan': None,
+                              'response_start_samples_norm': None, 'response_start_samples_good': None}
+            this_trial[f'{m.group(5)}_{m.group(6)}'] = ev_time
+            continue
+        m = decpat.match(ev)
+        if m is not None:
+            if m.group(2) == 'yes':
+                this_trial['decision'] = 'up'
+            elif m.group(2) == 'no':
+                this_trial['decision'] = 'down'
+            else:
+                this_trial['decision'] = m.group(2)
+    if this_trial.get('decision'):
+        trials.append(this_trial)
+
+    return pd.DataFrame(trials)
+
+
+s = {'day': '2019-11-21', 'cfg_t': '15_23_18'}
+
+day_str = s['day']
+cfg_t_str = s['cfg_t']
+data_t_str = s.get('data_t', s['cfg_t'])
+
+try:
+    fn_cfgdump = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'config_dump_{cfg_t_str}.yaml'
+    with open(fn_cfgdump) as stream:
+        params = yaml.load(stream, Loader=yaml.Loader)
+except Exception as e:
+    print(e)
+
+fb_states = params.paradigms.feedback.states
+
+for (k, v) in plot_t.items():
+    plot_t[k].s = [int(plot_t[k].t[0] * 1000 / params.daq.spike_rates.loop_interval),
+                   int(plot_t[k].t[1] * 1000 / params.daq.spike_rates.loop_interval)]
+    plot_t[k].swin = np.arange(plot_t[k].s[0], plot_t[k].s[1] + 1)
+    plot_t[k].twin = plot_t[k].swin / 1000.0 * params.daq.spike_rates.loop_interval
+
+# print(params)
+
+fn_sess = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'data_{data_t_str}.bin'
+fn_evs = BASE_PATH / 'KIAP_BCI_neurofeedback' / day_str / f'events_{data_t_str}.txt'
+
+datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params)
+ts, offsets, _ = fix_timestamps(ts)
+
+tv = ts / 3e4
+
+trs = extract_trials(fn_evs, offsets=offsets)
+
+dn = DataNormalizer(params)
+fr = dn.calculate_norm_rate(datav)
+
+plot_data = np.reshape(fr, (-1, 1))
+labels = ('fr',)
+
+sess_info = 'Channels used for control: [' + ' '.join([f"{ch.id}" for ch in params.daq.normalization.channels]) + ']'
+
+for label, dv in zip(labels, np.hsplit(plot_data, plot_data.shape[1])):
+    dv = np.reshape(dv, (-1))
+    for ii, row in trs.iterrows():
+        t = row['stimulus_start']
+        t_off = np.where(ts >= t)[0][0] - 1
+
+        t_stop = row['response_stop']
+        if len(np.where(ts >= t_stop)[0]) > 0:
+            t_stop_off = np.where(ts >= t_stop)[0][0] - 1
+
+        t_rstart = row['response_start']
+        if len(np.where(ts >= t_rstart)[0]) > 0:
+            t_rstart_off = np.where(ts >= t_rstart)[0][0] - 1
+
+            resp_start_idx = plot_t.response_start_samples.swin + t_rstart_off
+            if np.all(resp_start_idx < len(dv)):
+                trs.at[ii, 'response_start_samples'] = dv[resp_start_idx]
+            else:
+                trs.at[ii, 'response_start_samples'] = np.empty(resp_start_idx.shape)
+                trs.at[ii, 'response_start_samples'][:] = np.nan
+                good_vals = resp_start_idx < len(dv)
+                trs.at[ii, 'response_start_samples'][good_vals] = dv[resp_start_idx[good_vals]]
+
+            trs.at[ii, 'response_start_offset'] = (ts[t_rstart_off] - t_rstart) / 3e4
+            trs.at[ii, 'response_start_samples_good'] = resp_start_idx < t_stop_off - 1
+            # trs.at[ii, 'response_start_samples_good'] = np.full(plot_t.response_start_samples.swin.shape, True)
+
+            trs.at[ii, 'response_start_samples_wnan'] = trs.at[ii, 'response_start_samples']
+            trs.at[ii, 'response_start_samples_wnan'][~ trs.at[ii, 'response_start_samples_good']] = np.nan
+            fp = trs.at[ii, 'response_start_samples'][trs.at[ii, 'response_start_samples_good']]
+            xp = plot_t.response_start_samples.twin[trs.at[ii, 'response_start_samples_good']]
+            x = np.linspace(xp[0], xp[-1], len(plot_t.response_start_samples.twin))
+            f = scipy.interpolate.interp1d(xp, fp, kind='cubic')
+            trs.at[ii, 'response_start_samples_norm'] = np.interp(x, xp, fp)
+
+    if np.all(np.vstack(trs['response_start_samples']) == 0):
+        continue
+
+BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
+
+fig = plt.figure(35, figsize=(16, 9))
+fig.clf()
+
+plot_df_combined(trs, plot_order_comb, plot_t, fb_states, fig=fig,
+                 options=munch.Munch({'show_thresholds': label == 'fr', 'show_end': 0, 'show_median': False}))
+fig.suptitle(f"{day_str} {cfg_t_str} ch:{label}\n{sess_info} n={len(trs)}")
+fig.show()
+
+savename = BASE_PATH_OUT / SAVENAME
+
+fig.savefig(savename.with_suffix('.pdf'))
+fig.savefig(savename.with_suffix('.eps'))
+fig.savefig(savename.with_suffix('.svg'))
+
+n_correct_up = ((trs.target == trs.decision) & (trs.target == 'up')).sum()
+n_correct_down = ((trs.target == trs.decision) & (trs.target == 'down')).sum()
+n_error_up = (('down' == trs.decision) & (trs.target == 'up')).sum()
+n_error_down = (('up' == trs.decision) & (trs.target == 'down')).sum()
+n_timeout_up = (('unclassified' == trs.decision) & (trs.target == 'up')).sum()
+n_timeout_down = (('unclassified' == trs.decision) & (trs.target == 'down')).sum()
+
+n_total = len(trs)
+print(
+    f"Total trials: {n_total}\nTarget up, decision up: {n_correct_up} trials.\nTarget up, decision down: {n_error_up} trials.\nTarget down, decision up: {n_error_down} trials.\nTarget down, decision down: {n_correct_down} trials.\nCorrect: {(n_correct_up + n_correct_down) / n_total}. Correct up: {n_correct_up / (n_correct_up + n_error_up)}\nCorrect down: {n_correct_down / (n_correct_down + n_error_down)}\nTime-out up target: {n_timeout_up}; Time-out down target: {n_timeout_down}")
+
+# Export as CSV
+trs_export = pd.DataFrame()
+trs_export['trial_data'] = [r.response_start_samples[r.response_start_samples_good] for i, r in trs.iterrows()]
+trs_export['trial_times'] = [plot_t['response_start_samples']['twin'][r.response_start_samples_good] for i, r in trs.iterrows()]
+trs_export['target'] = trs['target']
+trs_export['decision'] = trs['decision']
+trs_export.to_csv(savename.with_suffix('.csv'))

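The regular expressions in `extract_trials` imply event lines of roughly the shape shown below. These lines are constructed to match the patterns and are not taken from the recorded data; the meaning of the second field inside the `b'…'` payload is not documented here, and timestamps are in the tick units that the script later divides by 3e4 to obtain seconds.

    59382041, b'feedback, 12, up, baseline, start'
    59427041, b'feedback, 12, up, response, start'
    59517041, b'feedback, 12, up, Decoder decision is: yes'
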
+ 241 - 0
scripts_for_figures/plot_figures_part_C.py

@@ -0,0 +1,241 @@
+import os
+from typing import List, Union
+from datetime import datetime as dt
+
+from helpers import data_management as dm
+import matplotlib.pyplot as plt
+import matplotlib
+import pandas as pd
+import numpy as np
+import yaml
+from helpers.data import DataNormalizer
+import re
+
+from helpers.nsp import fix_timestamps
+from basics import BASE_PATH, BASE_PATH_OUT, IMPLANT_DATE
+
+plot_win_len = 120.0
+
+s = {'day': '2019-07-05', 'cfg_t': '14_42_36', 'data_t': '14_42_36', 's': np.datetime64('2019-07-05T14:40:00'),
+     'e': np.datetime64('2019-07-05T15:11:45'), 'title_str': 'KIAP Session day 108 14:42 Free Speller',
+     'plot_start': 980, 'plot_win': 90, 'plot_end': 1070}
+
+fn_sess = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'data_{s["data_t"]}.bin'
+fn_evs = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'events_{s["data_t"]}.txt'
+fn_spl = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'debug_{s["cfg_t"]}.log'
+fn_cfg = BASE_PATH / 'KIAP_BCI_speller' / s["day"] / f'config_dump_{s["cfg_t"]}.yaml'
+
+s_dt = dt.strptime(s["day"], '%Y-%m-%d')
+
+days_post_implant = (s_dt - IMPLANT_DATE).days
+
+with open(fn_cfg) as stream:
+    params = yaml.load(stream, Loader=yaml.Loader)
+
+datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params, t_lim_start=s.get('s'), t_lim_end=s.get('e'))
+ts, offsets, _ = fix_timestamps(ts)
+tv = ts / 3e4
+
+with open(fn_evs, 'r') as f:
+    evs = f.read().splitlines()
+with open(fn_spl, 'r') as f:
+    splevs = f.read().splitlines()
+
+# fix event timestamps
+tpat = re.compile(r"^(\d+)(, .*)$")
+evs_time = []
+for ev in evs:
+    m = tpat.match(ev)
+    evs_time.append(np.int64(m.group(1)))
+evs_time = np.asarray(evs_time, dtype=np.int64)
+rollev, = np.where(np.diff(evs_time) < 0)
+for j, i in enumerate(rollev):
+    evs_time[(i + 1):] += offsets[j]
+# parse decoder decisions
+pevs = []
+lpat = re.compile(r"^(\d+), b'(.*)'$")
+decpat = re.compile(r".*Decoder decision is: (.*)$")
+
+alt_resp_pat = re.compile(r".*, response, stop$")
+
+for ev, ev_time in zip(evs, evs_time):
+    m = lpat.match(ev)
+    # evts = m.group(1)
+    evstr = m.group(2)
+    m2 = decpat.match(evstr)
+    if m2 is not None:
+        pevs.append([float(ev_time) / 3e4, m2.group(1)])
+if len(pevs) == 0:
+    for ev, ev_time in zip(evs, evs_time):
+        m = lpat.match(ev)
+        # evts = m.group(1)
+        evstr = m.group(2)
+        m2 = alt_resp_pat.match(evstr)
+        if m2 is not None:
+            pevs.append([float(ev_time) / 3e4, 'decision'])
+# parse events
+evlist = {}
+evlist['bl'] = []
+evlist['st'] = []
+evlist['re'] = []
+
+rp_ev_pat = re.compile(r"^(\d+), b'.*(stimulus|response|baseline), (start|stop)'$")
+# rp_rst_pat = re.compile(r"^(\d+), b'.*response, start'$")
+tmp_itb = []
+tmp_its = []
+tmp_itr = []
+
+for ev_time, ev in zip(evs_time, evs):
+    m = rp_ev_pat.match(ev)
+    if m is not None:
+        if m.group(2) == 'baseline':
+            if m.group(3) == 'start':
+                tmp_itb.append(float(ev_time) / 3e4)
+                continue
+            elif m.group(3) == 'stop':
+                tmp_itb.append(float(ev_time) / 3e4)
+                evlist['bl'].append(tmp_itb)
+                tmp_itb = []
+                continue
+            continue
+        elif m.group(2) == 'stimulus':
+            if m.group(3) == 'start':
+                tmp_its.append(float(ev_time) / 3e4)
+                continue
+            elif m.group(3) == 'stop':
+                tmp_its.append(float(ev_time) / 3e4)
+                evlist['st'].append(tmp_its)
+                tmp_its = []
+                continue
+            continue
+        elif m.group(2) == 'response':
+            if m.group(3) == 'start':
+                tmp_itr.append(float(ev_time) / 3e4)
+                continue
+            elif m.group(3) == 'stop':
+                tmp_itr.append(float(ev_time) / 3e4)
+                evlist['re'].append(tmp_itr)
+                tmp_itr = []
+                continue
+            continue
+        continue
+
+splpevs = []
+spl_pat = re.compile(r"^.* (\w+) - \('(.*)', '(.*)'\)")
+for sev in splevs:
+    m = spl_pat.match(sev)
+    if m is not None:
+        splpevs.append([m.group(2), m.group(3), m.group(1)])
+
+dn = DataNormalizer(params)
+fr = dn.calculate_norm_rate(datav)
+br_chnum = np.array(dn.norm_rate['ch_ids']) + 1
+
+i_plt = 0
+
+p_start = s.get('plot_start', tv[0])
+p_end = min(s.get('plot_end', tv[-1]), tv[-1])
+p_step = s.get('plot_win', plot_win_len)
+
+while p_start + i_plt * p_step < p_end:
+    i_plt += 1
+    t_win = np.array([p_start + (i_plt - 1) * p_step, min(p_end, p_start + i_plt * p_step)])
+    pl_idx = np.logical_and(t_win[0] <= tv, tv <= t_win[1])
+
+    fig = plt.figure(34, figsize=(12, 6.22))
+    fig.clf()
+    fig.set_tight_layout(True)
+    axs = [None, None]  # fig.subplots(2,1)
+    axs[0] = fig.add_axes((.04, .65, .93, .21))
+    axs[1] = fig.add_axes((.04, .1, .93, .35))
+
+    raw_phs = axs[0].plot(tv[pl_idx], datav[pl_idx, :][:, dn.norm_rate['ch_ids']])
+
+    df_raw_data = pd.DataFrame(index=tv[pl_idx], data=datav[pl_idx, :][:, dn.norm_rate['ch_ids']], columns=dn.norm_rate['ch_ids'])
+    df_raw_data['normalized_Rate'] = fr[pl_idx]
+
+    for ph, ch in zip(raw_phs, br_chnum):
+        ph.set_label(f'Channel {ch}')
+
+    axs[0].set_xlim(t_win[0], t_win[0] + p_step)
+    axs[0].spines['right'].set_color('none')
+    axs[0].spines['top'].set_color('none')
+    axs[0].spines['left'].set_position(('data', t_win[0] - 1))
+    axs[0].spines['bottom'].set_position('zero')
+    axs[0].set_ylabel('Firing Rate [Hz]')
+    axs[0].legend(loc='upper right', ncol=2, fontsize='small')
+    axs[0].set_title('Raw firing rates of channels used for “yes”/“no” classification')
+    frph, = axs[1].plot(tv[pl_idx], fr[pl_idx], label='Normalized Rate', color=(157.0/255, 157.0/255, 157.0/255))
+
+    thr_top = params.paradigms.feedback.states.up[1]
+    thr_bot = params.paradigms.feedback.states.down[2]
+
+    topl = axs[1].hlines(thr_top, t_win[0], t_win[1], color=(0, .6, .1), label='“Yes” threshold')
+    botl = axs[1].hlines(thr_bot, t_win[0], t_win[1], color=(.9, 0, .1), label='“No” threshold')
+
+    plot_speller_events = {'t': [], 'event': [], 'option': [], 'txt': []}
+
+    p: List[Union[float, str]]
+    for p, sp in zip(pevs, splpevs):
+        if t_win[0] <= p[0] <= t_win[1]:
+            cur_spel_st = sp[0]
+            if len(cur_spel_st) > 5:
+                cur_spel_st = "…" + sp[0][-4:]
+            spel_ev_txt = f'{cur_spel_st}\n“{sp[1]}”\n{sp[2]}'
+            plot_speller_events['t'].append(p[0])
+            plot_speller_events['event'].append(sp[2])
+            plot_speller_events['option'].append(sp[1])
+            plot_speller_events['txt'].append(cur_spel_st)
+
+            if sp[2] == 'yes':
+                axs[1].vlines(p[0], 0, 1, color=(0, .6, .1), linestyles=':')
+                axs[1].text(p[0], 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=(0, .6, .1))
+            elif sp[2] == 'no':
+                axs[1].vlines(p[0], 0, 1, color=(.9, 0, .1), linestyles=':')
+                axs[1].text(p[0], 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=(.9, 0, .1))
+            else:
+                axs[1].vlines(p[0], 0, 1, color=(.7, .7, .7), linestyles=':')
+                axs[1].text(p[0], 1, spel_ev_txt, horizontalalignment='center', va='bottom', color=(.7, .7, .7))
+
+    df_speller_events = pd.DataFrame(plot_speller_events)
+    sp_periods = {'t_begin': [], 't_end': [], 'event': []}
+    # for p in evlist['bl']:
+    #     axs[1].add_patch(matplotlib.patches.Rectangle((p[0],0), p[1]-p[0], 1, color=(.9, .9, .9)))
+    for p in evlist['st']:
+        if t_win[0] <= p[0] <= t_win[1] and t_win[0] <= p[1] <= t_win[1]:
+            axs[1].add_patch(matplotlib.patches.Rectangle((p[0], 0), p[1] - p[0], 1, color=(198.0/255.0, 198.0/255.0, 198.0/255.0), alpha=.15))
+            sp_periods['t_begin'].append(p[0])
+            sp_periods['t_end'].append(p[1])
+            sp_periods['event'].append('stimulus')
+    for p in evlist['re']:
+        if t_win[0] <= p[0] <= t_win[1] and t_win[0] <= p[1] <= t_win[1]:
+            axs[1].add_patch(matplotlib.patches.Rectangle((p[0], 0), p[1] - p[0], 1, color=(87.0/255.0, 46.0/255.0, 136.0/255.0), alpha=.15))
+            sp_periods['t_begin'].append(p[0])
+            sp_periods['t_end'].append(p[1])
+            sp_periods['event'].append('response')
+    df_speller_periods = pd.DataFrame(sp_periods)
+
+    axs[1].set_xlim(t_win[0], t_win[0] + p_step)
+    axs[1].set_xlabel('Time [s]')
+    axs[1].set_ylabel('Normalized Firing Rate')
+    axs[1].legend(fontsize='small', loc='upper right')
+    axs[1].set_title("Normalized rate and speller state", y=1.3)
+
+    axs[1].spines['right'].set_color('none')
+    axs[1].spines['top'].set_color('none')
+    axs[1].spines['left'].set_position(('data', t_win[0] - 1))
+    axs[1].spines['bottom'].set_position('zero')
+
+    fig.suptitle(f'{s["title_str"]} – “{splpevs[-1][0]}”', fontsize=15, wrap=True)#, fontweight='medium')
+
+    fig.show()
+    BASE_PATH_OUT.mkdir(parents=True, exist_ok=True)
+    savename = BASE_PATH_OUT / f'Figure_4_SpellerProgress_plt{i_plt}'
+    fig.savefig(savename.with_suffix('.pdf'), transparent=True)
+    fig.savefig(savename.with_suffix('.svg'), transparent=True)
+    fig.savefig(savename.with_suffix('.eps'), transparent=True)
+
+    df_raw_data.to_csv(savename.with_suffix('.raw.csv'))
+    df_speller_events.to_csv(savename.with_suffix('.events.csv'))
+    df_speller_periods.to_csv(savename.with_suffix('.periods.csv'))
+