# kaux.py — auxiliary configuration / logging helpers for the kiap BCI.
  1. import argparse
  2. import datetime
  3. import glob
  4. import logging
  5. import os
  6. import pathlib
  7. import sys
  8. from enum import Enum
  9. import yaml
  10. import munch
  11. import numpy as np
  12. import subprocess
  13. from colorlog import ColoredFormatter
  14. from . import validate_config as vd
  15. from functools import reduce
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument("gui", nargs='?', help="flag, 1:start gui", type=int, default=0)
  18. parser.add_argument("plot", nargs='?', help="flag, 1:start plot", type=int, default=0)
  19. parser.add_argument("--log", help="set verbosity level [DEBUG, INFO, WARNING, ERROR]")
  20. parser.add_argument('--speller', default='')
  21. parser.add_argument('-l', '--list', help='delimited list input', type=str)
  22. args = parser.parse_args()
class decision(Enum):
    """Outcome / decision codes shared across the BCI pipeline."""
    yes = 0
    no = 1
    nc = -3  # not confirmed
    baseline = 2
    unclassified = -1  # e.g. when history is not big enough yet
    # NOTE(review): -1 duplicates `unclassified`, so under Enum semantics
    # `decision.error1` is an *alias* of `decision.unclassified` (they are the
    # same member object). Confirm this aliasing is intentional before relying
    # on `decision.error1` as a distinct state.
    error1 = -1  # not enough data to get_class, see classifier get_class2
    error2 = -2
  31. def static_vars(**kwargs):
  32. def decorate(func):
  33. for k in kwargs:
  34. setattr(func, k, kwargs[k])
  35. return func
  36. return decorate
  37. def config_logging():
  38. LOG_LEVEL = logging.WARN
  39. if args.log is not None:
  40. LOG_LEVEL = eval('logging.' + args.log.upper())
  41. # LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
  42. LOGFORMAT = "%(log_color)s %(asctime)s [%(filename)-12.12s] [%(lineno)4d] [%(processName)-12.12s] [%(threadName)-12.12s] [%(levelname)-7.7s] %(message)s"
  43. logging.root.setLevel(LOG_LEVEL)
  44. formatter = ColoredFormatter(LOGFORMAT)
  45. stream = logging.StreamHandler()
  46. stream.setLevel(LOG_LEVEL)
  47. stream.setFormatter(formatter)
  48. log = logging.getLogger('pythonConfig')
  49. log.setLevel(LOG_LEVEL)
  50. log.addHandler(stream)
  51. return log
@static_vars(my_config=None)
def load_config(force_reload=False):
    """Finds the config file and loads it.

    The parsed configuration is cached on the function object itself (via
    static_vars); pass force_reload=True to re-read everything from disk.

    Raises:
        ValueError: when schema validation fails (before or after merging).
    Exits the process when no ./config.yaml is found.
    """
    # Serve the cached config unless the caller explicitly forces a reload.
    if (not force_reload) and (load_config.my_config is not None):
        log.info("Found cached config data and will use that.")
        return load_config.my_config
    # config_files = glob.glob('/kiap/src/kiap_bci/config.yaml', recursive=True)
    config_files = glob.glob('./config.yaml', recursive=True)
    # Load the params file
    if config_files:
        config_fname = config_files[0]
        with open(config_fname) as stream:
            params = munch.fromYAML(stream, Loader=yaml.FullLoader)
        # Validate the raw config before anything else is merged in.
        validation_passed, validation_error = vd.validate_schema(params)
        if not validation_passed:
            log.error(validation_error)
            raise ValueError('Configuration is not valid !')
        # Paradigm definitions are optional; a missing/broken file only warns.
        try:
            with open('paradigm.yaml') as stream:
                params.paradigms = munch.fromYAML(stream, Loader=yaml.FullLoader)
        except Exception as e:
            log.warning(f'Could not load paridigm yaml file.\n{e}')
        # Merge any supplemental config files listed under
        # params.supplemental_config into the base config.
        supplemental_cfgs = []
        try:
            for sfn in params.supplemental_config:
                file_list = []
                sfn_path = pathlib.Path(sfn)
                # A directory entry pulls in every *.yml / *.yaml beneath it,
                # sorted case-insensitively by absolute path for a stable order.
                if sfn_path.exists() and sfn_path.is_dir():
                    file_list += sfn_path.glob('**/*.yml')
                    file_list += sfn_path.glob('**/*.yaml')
                    file_list.sort(key=lambda p : str(p.absolute()).lower())
                else:
                    file_list.append(sfn_path)
                for a_file in file_list:
                    try:
                        log.info("Reading supplementary config file '{}'.".format(a_file))
                        with open(a_file) as stream:
                            supplemental_cfgs.append(munch.fromYAML(stream, Loader=yaml.Loader))
                    except FileNotFoundError as e:
                        log.warning("Supplemental config file '{}' not found. This option will be ignored.".format(a_file), exc_info=1)
            # Fold supplemental configs onto the base; later files win on
            # conflicting non-dict values (see mergemunch).
            params = reduce(lambda xx, yy: munch.Munch(mergemunch(xx, yy)), supplemental_cfgs, params)
        except AttributeError as e:
            # 'supplemental_config' key absent from config — nothing to merge.
            log.info("Attribute 'supplemental_config' not set in config file.")
        # Re-validate after the merge, then derive runtime parameters.
        validation_passed, validation_error = vd.validate_schema(params)
        if not validation_passed:
            log.error(validation_error)
            raise ValueError('Configuration is not valid !')
        params = setfileattr(params)
        params = config_setup(params)
        params = eval_ranges(params)
        params.buffer.shape = [params.buffer.length, params.daq.n_channels]
    else:
        log.debug("No file called 'config.yaml' found, please save the file in BCI folder. Shutting down...")
        sys.exit("CONFIG FILE NOT FOUND")
    load_config.my_config = params
    return params
  108. # This function merges two munch dictionaries. Use as munch.Munch(mergemunch(m1, m2))
  109. def mergemunch(dict1, dict2):
  110. for k in set(dict1) | set(dict2):
  111. if k in dict1 and k in dict2:
  112. if isinstance(dict1[k], dict) and isinstance(dict2[k], dict):
  113. yield k, munch.Munch(mergemunch(dict1[k], dict2[k]))
  114. else:
  115. # If one of the values is not a dict, you can't continue merging it.
  116. # Value from second dict overrides one in first and we move on.
  117. yield k, dict2[k]
  118. # Alternatively, replace this with exception raiser to alert you of value conflicts
  119. elif k in dict1:
  120. yield k, dict1[k]
  121. else:
  122. yield k, dict2[k]
  123. def eval_ranges(params):
  124. '''evaluate all ranges, currently only for template'''
  125. params.daq.n_channels = params.daq.n_channels_max -len(params.daq.exclude_channels)
  126. if 'range' in params.classifier.template:
  127. params.classifier.template = np.array(eval(params.classifier.template))
  128. else:
  129. params.classifier.template = np.array(params.classifier.template)
  130. # if 'all' in params.classifier.channel_mask:
  131. # params.classifier.channel_mask = list(range(0,params.daq.n_channels))
  132. if 'range' in params.classifier.exclude_channels:
  133. params.classifier.exclude_channels = list(eval(params.classifier.exclude_channels))
  134. if 'range' in params.classifier.include_channels:
  135. params.classifier.include_channels = list(eval(params.classifier.include_channels))
  136. if 'range' in params.lfp.array1:
  137. params.lfp.array1 = list(eval(params.lfp.array1))
  138. if 'range' in params.lfp.array21:
  139. params.lfp.array21 = list(eval(params.lfp.array21))
  140. if 'range' in params.lfp.array22:
  141. params.lfp.array22 = list(eval(params.lfp.array22))
  142. params.lfp.array2 = list(params.lfp.array21 + params.lfp.array22)
  143. return params
  144. def setfileattr(params):
  145. tnow = datetime.datetime.now().strftime('%H_%M_%S')
  146. today = str(datetime.date.today())
  147. datafile_path = os.path.join(params.file_handling.data_path, today)
  148. params.file_handling.datafile_path = datafile_path
  149. pathlib.Path(datafile_path).mkdir(parents=True, exist_ok=True)
  150. # filename_data = 'data.bin'
  151. # filename_log_info = 'info.log'
  152. # filename_log_debug = 'debug.log'
  153. # filename_events = 'events.txt'
  154. filename_data = f'data_{tnow}.bin'
  155. filename_baseline = f'bl_{tnow}.npy'
  156. filename_log_info = f'info_{tnow}.log'
  157. filename_log_debug = f'debug_{tnow}.log'
  158. filename_events = f'events_{tnow}.txt'
  159. filename_config = f'config_{tnow}.yaml'
  160. filename_paradigm = f'paradigm_{tnow}.yaml'
  161. filename_config_dump = f'config_dump_{tnow}.yaml'
  162. filename_git_patch = f'git_changes_{tnow}.patch'
  163. filename_history = f'history.bin'
  164. params.file_handling.filename_data = os.path.join(datafile_path, filename_data)
  165. params.file_handling.filename_baseline = os.path.join(datafile_path, filename_baseline)
  166. params.file_handling.filename_log_info = os.path.join(datafile_path, filename_log_info)
  167. params.file_handling.filename_log_debug = os.path.join(datafile_path, filename_log_debug)
  168. params.file_handling.filename_events = os.path.join(datafile_path, filename_events)
  169. params.file_handling.filename_config = os.path.join(datafile_path, filename_config)
  170. params.file_handling.filename_paradigm = os.path.join(datafile_path, filename_paradigm)
  171. params.file_handling.filename_config_dump = os.path.join(datafile_path, filename_config_dump)
  172. params.file_handling.filename_git_patch = os.path.join(datafile_path, filename_git_patch)
  173. params.file_handling.filename_history = os.path.join(datafile_path, filename_history)
  174. # get current git hash and store it
  175. git_hash = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
  176. params.file_handling.git_hash = git_hash.decode("utf-8")[:-1]
  177. #update config file before saving it
  178. return params
  179. # def add_timestamps(params):
  180. # tnow = datetime.datetime.now().strftime('%H_%M_%S')
  181. # today = str(datetime.date.today())
  182. # datafile_path = os.path.join(params.file_handling.data_path, today)
  183. # params.file_handling.datafile_path = datafile_path
  184. # pathlib.Path(datafile_path).mkdir(parents=True, exist_ok=True)
  185. # filename_data = f'data_{tnow}.bin'
  186. # filename_log_info = f'info_{tnow}.log'
  187. # filename_log_debug = f'debug_{tnow}.log'
  188. # filename_events = f'events_{tnow}.txt'
  189. # filename_data = os.path.join(datafile_path, filename_data)
  190. # filename_log_info = os.path.join(datafile_path, filename_log_info)
  191. # filename_log_debug = os.path.join(datafile_path, filename_log_debug)
  192. # filename_events = os.path.join(datafile_path, filename_events)
  193. # log.info(f'Data file: {filename_data}')
  194. # # os.rename(params.file_handling.filename_data, filename_data)
  195. # # os.rename(params.file_handling.filename_log_info, filename_log_info)
  196. # # os.rename(params.file_handling.filename_log_debug, filename_log_debug)
  197. # # os.rename(params.file_handling.filename_events, filename_events)
  198. # # log.info('Added timestamps to files')
  199. # return None
  200. def config_setup(params):
  201. """Interprets the config file and updates BCI parameters."""
  202. params.session.flags.stimulus = True
  203. params.session.flags.decode = False
  204. if params.speller.type == 'norm':
  205. params.session.flags.stimulus = False
  206. return params
  207. if 'log' not in locals():
  208. log = config_logging()
  209. if args.log is not None:
  210. log.info(f'Debug level is: {args.log.upper()}')
  211. else:
  212. log.info('Debug level is INFO')