import argparse
import datetime
import glob
import logging
import os
import pathlib
import subprocess
import sys
from enum import Enum
from functools import reduce

import munch
import numpy as np
import yaml
from colorlog import ColoredFormatter

from . import validate_config as vd

parser = argparse.ArgumentParser()
parser.add_argument("gui", nargs='?', help="flag, 1:start gui", type=int, default=0)
parser.add_argument("plot", nargs='?', help="flag, 1:start plot", type=int, default=0)
parser.add_argument("--log", help="set verbosity level [DEBUG, INFO, WARNING, ERROR]")
parser.add_argument('--speller', default='')
parser.add_argument('-l', '--list', help='delimited list input', type=str)
args = parser.parse_args()

class decision(Enum):
    yes = 0
    no = 1
    nc = -3             # not confirmed
    baseline = 2
    unclassified = -1   # e.g. when history is not big enough yet
    error1 = -1         # not enough data to get_class, see classifier get_class2
    error2 = -2

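# Note: 'unclassified' and 'error1' share the value -1, so Enum makes error1 an alias of
# unclassified (decision(-1) resolves to decision.unclassified and the two names compare equal).
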
def static_vars(**kwargs):
    """Decorator that attaches the given keyword arguments to the decorated function as attributes."""
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate

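# Minimal sketch of the decorator pattern used below (the 'tick' function is a hypothetical
# example, not part of this module): the attached attribute acts as per-function static storage.
#
#   @static_vars(counter=0)
#   def tick():
#       tick.counter += 1
#       return tick.counter
#
#   tick()  # -> 1
#   tick()  # -> 2
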
def config_logging():
    """Configures a colored console logger; verbosity is taken from the --log argument."""
    LOG_LEVEL = logging.WARN
    if args.log is not None:
        LOG_LEVEL = getattr(logging, args.log.upper())

    # LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
    LOGFORMAT = "%(log_color)s %(asctime)s [%(filename)-12.12s] [%(lineno)4d] [%(processName)-12.12s] [%(threadName)-12.12s] [%(levelname)-7.7s] %(message)s"

    logging.root.setLevel(LOG_LEVEL)
    formatter = ColoredFormatter(LOGFORMAT)
    stream = logging.StreamHandler()
    stream.setLevel(LOG_LEVEL)
    stream.setFormatter(formatter)
    log = logging.getLogger('pythonConfig')
    log.setLevel(LOG_LEVEL)
    log.addHandler(stream)
    return log

@static_vars(my_config=None)
def load_config(force_reload=False):
    """Finds the config file and loads it"""
    if (not force_reload) and (load_config.my_config is not None):
        log.info("Found cached config data and will use that.")
        return load_config.my_config

    # config_files = glob.glob('/kiap/src/kiap_bci/config.yaml', recursive=True)
    config_files = glob.glob('./config.yaml', recursive=True)

    # Load the params file
    if config_files:
        config_fname = config_files[0]
        with open(config_fname) as stream:
            params = munch.fromYAML(stream, Loader=yaml.FullLoader)

        validation_passed, validation_error = vd.validate_schema(params)
        if not validation_passed:
            log.error(validation_error)
            raise ValueError('Configuration is not valid!')

        try:
            with open('paradigm.yaml') as stream:
                params.paradigms = munch.fromYAML(stream, Loader=yaml.FullLoader)
        except Exception as e:
            log.warning(f'Could not load paradigm yaml file.\n{e}')

        supplemental_cfgs = []
        try:
            for sfn in params.supplemental_config:
                file_list = []
                sfn_path = pathlib.Path(sfn)
                if sfn_path.exists() and sfn_path.is_dir():
                    file_list += sfn_path.glob('**/*.yml')
                    file_list += sfn_path.glob('**/*.yaml')
                    file_list.sort(key=lambda p: str(p.absolute()).lower())
                else:
                    file_list.append(sfn_path)

                for a_file in file_list:
                    try:
                        log.info("Reading supplementary config file '{}'.".format(a_file))
                        with open(a_file) as stream:
                            supplemental_cfgs.append(munch.fromYAML(stream, Loader=yaml.Loader))
                    except FileNotFoundError:
                        log.warning("Supplemental config file '{}' not found. This option will be ignored.".format(a_file), exc_info=1)

            params = reduce(lambda xx, yy: munch.Munch(mergemunch(xx, yy)), supplemental_cfgs, params)
        except AttributeError:
            log.info("Attribute 'supplemental_config' not set in config file.")

        validation_passed, validation_error = vd.validate_schema(params)
        if not validation_passed:
            log.error(validation_error)
            raise ValueError('Configuration is not valid!')

        params = setfileattr(params)
        params = config_setup(params)
        params = eval_ranges(params)
        params.buffer.shape = [params.buffer.length, params.daq.n_channels]
    else:
        log.debug("No file called 'config.yaml' found, please save the file in the BCI folder. Shutting down...")
        sys.exit("CONFIG FILE NOT FOUND")

    load_config.my_config = params
    return params

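# Usage sketch (assumes a valid config.yaml in the working directory; field names follow the
# schema used elsewhere in this module):
#
#   params = load_config()                    # parses, validates and caches ./config.yaml
#   n_ch = params.daq.n_channels              # munch allows attribute-style access
#   params = load_config(force_reload=True)   # ignores the cached copy and re-reads the file
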
def mergemunch(dict1, dict2):
    """Recursively merges two munch dictionaries. Use as munch.Munch(mergemunch(m1, m2))."""
    for k in set(dict1) | set(dict2):
        if k in dict1 and k in dict2:
            if isinstance(dict1[k], dict) and isinstance(dict2[k], dict):
                yield k, munch.Munch(mergemunch(dict1[k], dict2[k]))
            else:
                # If one of the values is not a dict, it cannot be merged any further:
                # the value from the second dict overrides the one in the first.
                # Alternatively, raise an exception here to be alerted of value conflicts.
                yield k, dict2[k]
        elif k in dict1:
            yield k, dict1[k]
        else:
            yield k, dict2[k]

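# Illustration of the merge semantics (hypothetical values, not taken from the real config):
# nested dicts are merged recursively, scalar conflicts are resolved in favour of the second dict.
#
#   base = munch.munchify({'daq': {'fs': 30000, 'n_channels_max': 128}})
#   override = munch.munchify({'daq': {'n_channels_max': 96}})
#   merged = munch.Munch(mergemunch(base, override))
#   # merged.daq.fs -> 30000, merged.daq.n_channels_max -> 96
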
def eval_ranges(params):
    """Evaluates 'range(...)' expressions in the config (template, channel lists, LFP arrays)."""
    params.daq.n_channels = params.daq.n_channels_max - len(params.daq.exclude_channels)
    if 'range' in params.classifier.template:
        params.classifier.template = np.array(eval(params.classifier.template))
    else:
        params.classifier.template = np.array(params.classifier.template)
    # if 'all' in params.classifier.channel_mask:
    #     params.classifier.channel_mask = list(range(0, params.daq.n_channels))
    if 'range' in params.classifier.exclude_channels:
        params.classifier.exclude_channels = list(eval(params.classifier.exclude_channels))
    if 'range' in params.classifier.include_channels:
        params.classifier.include_channels = list(eval(params.classifier.include_channels))
    if 'range' in params.lfp.array1:
        params.lfp.array1 = list(eval(params.lfp.array1))
    if 'range' in params.lfp.array21:
        params.lfp.array21 = list(eval(params.lfp.array21))
    if 'range' in params.lfp.array22:
        params.lfp.array22 = list(eval(params.lfp.array22))
    params.lfp.array2 = list(params.lfp.array21 + params.lfp.array22)
    return params

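# Sketch of the config values eval_ranges expects (hypothetical YAML snippet, not the project's
# actual file): entries given as 'range(...)' strings are evaluated into concrete sequences.
#
#   classifier:
#     template: 'range(0, 6)'             # becomes np.array([0, 1, 2, 3, 4, 5])
#     exclude_channels: 'range(96, 128)'  # becomes list(range(96, 128))
#   lfp:
#     array1: 'range(0, 32)'
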
def setfileattr(params):
    """Creates the session data folder and sets timestamped output file names in params.file_handling."""
    tnow = datetime.datetime.now().strftime('%H_%M_%S')
    today = str(datetime.date.today())
    datafile_path = os.path.join(params.file_handling.data_path, today)
    params.file_handling.datafile_path = datafile_path
    pathlib.Path(datafile_path).mkdir(parents=True, exist_ok=True)

    # filename_data = 'data.bin'
    # filename_log_info = 'info.log'
    # filename_log_debug = 'debug.log'
    # filename_events = 'events.txt'
    filename_data = f'data_{tnow}.bin'
    filename_baseline = f'bl_{tnow}.npy'
    filename_log_info = f'info_{tnow}.log'
    filename_log_debug = f'debug_{tnow}.log'
    filename_events = f'events_{tnow}.txt'
    filename_config = f'config_{tnow}.yaml'
    filename_paradigm = f'paradigm_{tnow}.yaml'
    filename_config_dump = f'config_dump_{tnow}.yaml'
    filename_git_patch = f'git_changes_{tnow}.patch'
    filename_history = 'history.bin'

    params.file_handling.filename_data = os.path.join(datafile_path, filename_data)
    params.file_handling.filename_baseline = os.path.join(datafile_path, filename_baseline)
    params.file_handling.filename_log_info = os.path.join(datafile_path, filename_log_info)
    params.file_handling.filename_log_debug = os.path.join(datafile_path, filename_log_debug)
    params.file_handling.filename_events = os.path.join(datafile_path, filename_events)
    params.file_handling.filename_config = os.path.join(datafile_path, filename_config)
    params.file_handling.filename_paradigm = os.path.join(datafile_path, filename_paradigm)
    params.file_handling.filename_config_dump = os.path.join(datafile_path, filename_config_dump)
    params.file_handling.filename_git_patch = os.path.join(datafile_path, filename_git_patch)
    params.file_handling.filename_history = os.path.join(datafile_path, filename_history)

    # get the current git hash and store it
    git_hash = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
    params.file_handling.git_hash = git_hash.decode("utf-8")[:-1]

    # update config file before saving it
    return params

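# Resulting session layout (illustrative only; assumes file_handling.data_path='/data/bci' and a
# session started at 14:03:27 on 2024-01-15):
#
#   /data/bci/2024-01-15/data_14_03_27.bin
#   /data/bci/2024-01-15/bl_14_03_27.npy
#   /data/bci/2024-01-15/info_14_03_27.log
#   /data/bci/2024-01-15/events_14_03_27.txt
#   /data/bci/2024-01-15/history.bin
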
# def add_timestamps(params):
#     tnow = datetime.datetime.now().strftime('%H_%M_%S')
#     today = str(datetime.date.today())
#     datafile_path = os.path.join(params.file_handling.data_path, today)
#     params.file_handling.datafile_path = datafile_path
#     pathlib.Path(datafile_path).mkdir(parents=True, exist_ok=True)
#     filename_data = f'data_{tnow}.bin'
#     filename_log_info = f'info_{tnow}.log'
#     filename_log_debug = f'debug_{tnow}.log'
#     filename_events = f'events_{tnow}.txt'
#     filename_data = os.path.join(datafile_path, filename_data)
#     filename_log_info = os.path.join(datafile_path, filename_log_info)
#     filename_log_debug = os.path.join(datafile_path, filename_log_debug)
#     filename_events = os.path.join(datafile_path, filename_events)
#     log.info(f'Data file: {filename_data}')
#     # os.rename(params.file_handling.filename_data, filename_data)
#     # os.rename(params.file_handling.filename_log_info, filename_log_info)
#     # os.rename(params.file_handling.filename_log_debug, filename_log_debug)
#     # os.rename(params.file_handling.filename_events, filename_events)
#     # log.info('Added timestamps to files')
#     return None

def config_setup(params):
    """Interprets the config file and updates BCI parameters."""
    params.session.flags.stimulus = True
    params.session.flags.decode = False
    if params.speller.type == 'norm':
        params.session.flags.stimulus = False
    return params

if 'log' not in locals():
    log = config_logging()

if args.log is not None:
    log.info(f'Debug level is: {args.log.upper()}')
else:
    log.info('Debug level is WARNING')