# sessions.py
  1. from pathlib import Path
  2. import yaml
  3. import munch
  4. import pandas as pd
  5. import numpy as np
  6. from helpers import data_management as dm
  7. from helpers.nsp import fix_timestamps
  8. import os
  9. import re
  10. import logging
  11. EVENT_FIRST_LINE_RE = re.compile(r"^(\d+),.*Block, start'$")
  12. EVENT_LAST_LINE_RE = re.compile(r"^(\d+),.*Block, stop'$")
  13. logger = logging.getLogger('KIAP.sessions')
  14. def check_event_file_format(ev_file):
  15. with open(ev_file, "rb") as f:
  16. first = f.readline().decode() # Read the first line.
  17. f.seek(-2, os.SEEK_END) # Jump to the second last byte.
  18. while f.read(1) != b"\n": # Until EOL is found...
  19. f.seek(-2, os.SEEK_CUR) # ...jump back the read byte plus one more.
  20. last = f.readline().decode() # Read last line.
  21. return (EVENT_FIRST_LINE_RE.match(first) is not None) and (EVENT_LAST_LINE_RE.match(last) is not None)
  22. def get_sessions(path, mode=None, n=None, start=0):
  23. this_path = Path(path)
  24. config_files = sorted(this_path.glob("**/config_dump_*.yaml"))
  25. config_files_read = []
  26. for cfg_path in config_files[start:]:
  27. try:
  28. with open(cfg_path, 'r') as f:
  29. # cfg = yaml.load(f, Loader=yaml.Loader)
  30. cfg = munch.Munch.fromYAML(f, Loader=yaml.Loader)
  31. logger.info(f"Loading {cfg_path}")
  32. if (mode is None) or (mode == cfg.speller.type):
  33. cfg_d = {
  34. 'mode': cfg.speller.type,
  35. 'cfg': str(Path(*Path(cfg_path).parts[-2:])),
  36. 'events': cfg.file_handling.get('filename_events'),
  37. 'data': cfg.file_handling.get('filename_data'),
  38. 'log': cfg.file_handling.get('filename_log_info')
  39. }
  40. if cfg_d['events'] is None or cfg_d['data'] is None or cfg_d['log'] is None:
  41. continue
  42. cfg_d['events'] = str(Path(*Path(cfg_d['events']).parts[-2:]))
  43. if not check_event_file_format(this_path / cfg_d['events']):
  44. logger.warning(f"{cfg_d['events']} is not valid (first / last line does not match schema)")
  45. #continue
  46. cfg_d['data'] = str(Path(*Path(cfg_d['data']).parts[-2:]))
  47. cfg_d['log'] = str(Path(*Path(cfg_d['log']).parts[-2:]))
  48. config_files_read.append(cfg_d)
  49. if (n is not None) and (len(config_files_read) >= n):
  50. break
  51. except FileNotFoundError as e:
  52. logger.warning(f"A file related to {cfg_path} was not found ({e}).")
  53. # config_files_read.append(cfg)
  54. cfg_pd = pd.DataFrame(config_files_read)
  55. return (cfg_pd, len(config_files))
  56. TRE = re.compile(r"^(\d+),.*$")
  57. def get_session_data(path, session):
  58. """
  59. Load data for a session. Requires the the session configuration file to deduce file format.
  60. returns time vector for samples, data, and channel list.
  61. """
  62. fn_sess = Path(path, session['data'])
  63. fn_evs = Path(path, session['events'])
  64. logger.debug(f"gsd loading {fn_evs}")
  65. with open(Path(path, session['cfg']), 'r') as f:
  66. params = munch.Munch.fromYAML(f, Loader=yaml.Loader)
  67. datav, ts, ch_rec_list = dm.get_session(fn_sess, params=params)
  68. ts, offsets, _ = fix_timestamps(ts)
  69. with open(fn_evs, 'r') as f:
  70. evs = f.readlines()
  71. times = []
  72. for ev in evs:
  73. mtch = TRE.match(ev)
  74. times.append(int(mtch.group(1)))
  75. tsevs, _, _ = fix_timestamps(np.array(times))
  76. evt = [tsevs[0] / 3e4, tsevs[-1] / 3e4]
  77. tv = ts / 3e4
  78. return tv, datav, ch_rec_list, evt