#
# MIT License
#
# Copyright (c) 2020 Keisuke Sehara
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

from collections import namedtuple as _namedtuple
from pathlib import Path as _Path
import re as _re

import numpy as _np
import pandas as _pd

DEBUG = False


def debug(msg, end="\n"):
    """Print `msg` (flushed immediately) when the module-level DEBUG flag is set.

    `end` is forwarded to print(), so that consecutive debug messages can be
    written on the same line.
    """
    if DEBUG:
        print(msg, end=end, flush=True)


class Specification(_namedtuple("_Specification",
                                ("dataset", "subject", "session", "domain", "trial", "suffix"))):
    """Immutable locator for (a subset of) the files of a dataset.

    The file-system layout assumed by the properties below is::

        <dataset>/<subject>/<session>/<domain>/<subject>_<session>_<domain>_run<trial:05d><suffix>

    Any field may be None (= unspecified). The `subjects`, `sessions` and
    `trials` properties expand the unspecified fields by scanning the file system.
    """
    # domain directory that `trials` scans when `domain` is unspecified
    DOMAIN_TO_SEARCH = "states"
    # glob patterns used to discover subject / session directories
    SUBJECT_PATTERN = "MLA*"
    SESSION_PATTERN = "session*"
    # extracts the trial index from a file stem such as "..._run00001"
    TRIAL_CAPTURE = _re.compile(r"_run(\d+)")

    def __new__(cls, dataset=None, subject=None, session=None,
                domain=None, trial=None, suffix=None):
        """Normalize the field values.

        `dataset` becomes a resolved Path, `trial` an int, the other fields
        str. A `suffix` without a leading '.' gets one prepended.
        """
        if dataset is not None:
            dataset = _Path(dataset).resolve()
        if subject is not None:
            subject = str(subject)
        if session is not None:
            session = str(session)
        if trial is not None:
            trial = int(trial)
        if domain is not None:
            domain = str(domain)
        if suffix is not None:
            suffix = str(suffix)
            if not suffix.startswith("."):
                suffix = "." + suffix
        # zero-argument super() also works for subclasses; the former
        # `super(cls, Specification)` raised TypeError whenever cls was a subclass.
        return super().__new__(cls, dataset, subject, session, domain, trial, suffix)

    def with_values(self, **kwargs):
        """Return a copy of this specification with the given fields replaced.

        Keyword names that are not fields are ignored. The values are
        re-normalized through __new__.
        """
        values = {field: kwargs.get(field, current)
                  for field, current in zip(self._fields, self)}
        return self.__class__(**values)

    @staticmethod
    def _existing_child(parent, name, label):
        """Return `parent / name`, raising FileNotFoundError if it does not exist."""
        child = parent / name
        if not child.exists():
            raise FileNotFoundError(f"{label} directory not found: {child}")
        return child

    @property
    def dataset_directory(self):
        """The dataset root directory.

        Raises ValueError if it is unspecified, FileNotFoundError if it is missing.
        """
        if self.dataset is None:
            raise ValueError("dataset is not specified")
        elif not self.dataset.exists():
            raise FileNotFoundError(f"dataset directory not found: {self.dataset}")
        return self.dataset

    @property
    def subject_directory(self):
        """<dataset>/<subject>; raises ValueError / FileNotFoundError if unavailable."""
        dsdir = self.dataset_directory
        if self.subject is None:
            raise ValueError("subject is not specified")
        return self._existing_child(dsdir, self.subject, "subject")

    @property
    def session_directory(self):
        """<dataset>/<subject>/<session>; raises ValueError / FileNotFoundError if unavailable."""
        subdir = self.subject_directory
        if self.session is None:
            raise ValueError("session is not specified")
        return self._existing_child(subdir, self.session, "session")

    @property
    def domain_directory(self):
        """<dataset>/<subject>/<session>/<domain>; raises ValueError / FileNotFoundError if unavailable."""
        sessdir = self.session_directory
        if self.domain is None:
            raise ValueError("domain is not specified")
        return self._existing_child(sessdir, self.domain, "domain")

    @property
    def filepath(self):
        """Path of the run file described by this specification.

        Raises ValueError if a required field is unspecified, and
        FileNotFoundError if the file (or one of its directories) is missing.
        """
        domdir = self.domain_directory
        if self.trial is None:
            raise ValueError("trial index is not specified")
        name = f"{self.subject}_{self.session}_{self.domain}_run{self.trial:05d}"
        # Append the suffix rather than using Path.with_suffix(): the latter would
        # replace everything after a '.' that happens to be in a subject/session name.
        if self.suffix is not None:
            name += self.suffix
        ret = domdir / name
        if not ret.exists():
            raise FileNotFoundError(f"file not found: {ret.name}")
        return ret

    @property
    def subjects(self):
        """Specifications for each subject.

        If `subject` is set, the result is just (self,). Otherwise one entry is
        returned per directory matching SUBJECT_PATTERN, sorted by path.
        """
        if self.subject is not None:
            debug(f"found subject: {self.subject}")
            return (self,)
        ret = []
        for sub in sorted(self.dataset_directory.glob(self.SUBJECT_PATTERN)):
            if not sub.is_dir():
                continue
            debug(f"found subject: {sub.name}")
            ret.append(self.with_values(subject=sub.name))
        return tuple(ret)

    @property
    def sessions(self):
        """Specifications for each session directory, across all subjects.

        Sessions are discovered with SESSION_PATTERN. If `session` is set, only
        directories with exactly that name are kept.
        """
        ret = []
        for sub in self.subjects:
            for path in sorted(sub.subject_directory.glob(self.SESSION_PATTERN)):
                if not path.is_dir():
                    continue
                if (self.session is None) or (path.name == self.session):
                    debug(f"found session: {sub.subject}/{path.name}")
                    ret.append(sub.with_values(session=path.name))
        return tuple(ret)

    @property
    def trials(self):
        """Specifications for each trial (run), across all sessions.

        Trial indices are read from file names in the domain directory, or in
        DOMAIN_TO_SEARCH if `domain` is unset. Each index is reported once per
        session, even when several files (e.g. different suffixes) exist for it.
        If `trial` is set, only that index is kept.
        """
        ret = []
        for session in self.sessions:
            if self.domain is None:
                domdir = session.with_values(domain=self.DOMAIN_TO_SEARCH).domain_directory
            else:
                domdir = session.domain_directory
            seen = set()
            for path in sorted(domdir.glob("*run*")):
                captured = self.TRIAL_CAPTURE.search(path.stem)
                if not captured:
                    continue
                idx = int(captured.group(1))  # (\d+) always parses as int
                if idx in seen:
                    continue
                if (self.trial is not None) and (idx != self.trial):
                    continue
                seen.add(idx)
                debug(f"found trial: {session.subject}/{session.session}/trial{idx:05d}")
                ret.append(session.with_values(trial=idx))
        return tuple(ret)


def load_trials(dataset, subject=None, session=None, trial=None):
    """Load every Trial matching the given (partial) specification, as a tuple."""
    basespecs = Specification(dataset=dataset, subject=subject, session=session, trial=trial)
    return tuple(Trial.load(specs) for specs in basespecs.trials)


class Trial:
    """
    [usage]

    ```
    trial = Trial.load("../datasets/merged-output", <subject>, <session>, <trial-index>)
    ```

    or

    ```
    trial = Trial(<'states' DataFrame>, <'tracking' DataFrame>)
    ```

    - the 'states' data-frame can be accessed by `trial.states`.
    - the 'tracking' data-frame can be accessed by `trial.tracking`.
    """

    @classmethod
    def load(cls, dataset, subject=None, session=None, trial=None):
        """Read the 'states' and 'tracking' CSV files of one trial.

        `dataset` is either a dataset path (str or pathlib.Path), combined with
        `subject`, `session` and `trial`, or a Specification (or a tuple of its
        fields). In the latter case the other arguments are ignored.
        """
        if isinstance(dataset, (str, _Path)):
            basespecs = Specification(dataset=dataset, subject=subject, session=session, trial=trial)
        else:
            # assumed to be a Specification (or a tuple of its fields)
            basespecs = Specification(*dataset)
        states_path = basespecs.with_values(domain="states", suffix=".csv").filepath
        track_path = basespecs.with_values(domain="tracking", suffix=".csv").filepath
        return cls(_pd.read_csv(str(states_path)), _pd.read_csv(str(track_path)),
                   specification=basespecs)

    def __init__(self, states_dataframe, tracking_dataframe, specification=None):
        self.states = states_dataframe
        self.tracking = tracking_dataframe
        self._specs = specification

    @property
    def subject(self):
        """Subject name, or None if no specification is attached."""
        return None if self._specs is None else self._specs.subject

    @property
    def session(self):
        """Session name, or None if no specification is attached."""
        return None if self._specs is None else self._specs.session

    @property
    def index(self):
        """Trial index, or None if no specification is attached."""
        return None if self._specs is None else self._specs.trial

    def has_eyedata(self):
        """Return True if the `left_pupil_normalized_position` tracking column has any non-NaN value."""
        return not _np.all(_np.isnan(self.tracking.left_pupil_normalized_position))

    def get_timeranges(self, pattern):
        """Return the (FromFrame, ToFrame) pairs of the captured state for each occurrence of `pattern`.

        `pattern` is a StateCapture or a pattern string such as "A <B> C" (see
        Capture). The values come from the states table; presumably they are
        frame indices into the tracking table (TODO confirm).
        """
        if isinstance(pattern, str):
            pattern = StateCapture(pattern)
        return pattern.parse(self.states)

    @property
    def specification(self):
        """The Specification this trial was loaded from (None if constructed directly)."""
        return self._specs


class Capture:
    """Matches a whitespace-separated sequence of states pushed one item at a time.

    Exactly one state in the pattern must be bracketed, e.g. "A <B> C". The
    hooks below are no-ops here and are meant to be overridden:

    - `capture()` fires when the bracketed state is consumed as part of the
      current alignment.
    - `matched()` fires when the whole sequence has been matched.
    - `clear()` fires whenever a pushed item does not continue the current
      alignment.
    """
    CAPTURE_PATTERN = _re.compile(r"<([a-zA-Z]+)>$")
    # Subclasses set this to True when their clear() rewinds the input, so that
    # the next pushed item is the one right after the start of the failed attempt.
    REWINDS_ON_CLEAR = False

    def __init__(self, pattern):
        self._seq = pattern.split()
        if len(self._seq) == 0:
            raise ValueError("add at least one state in the pattern")
        self._captured = -1  # position of the bracketed state within _seq
        self._offset = 0     # number of pattern items matched so far
        for i, item in enumerate(self._seq):
            if self.CAPTURE_PATTERN.match(item):
                if self._captured >= 0:
                    raise ValueError(f"cannot take two or more captured states in one pattern: {pattern}")
                self._captured = i
        if self._captured == -1:
            raise ValueError(f"bracket the state to capture by '<' and '>': {pattern}")
        self._seq[self._captured] = self._seq[self._captured][1:-1]  # strip '<' and '>'
        debug(f"capture: {self._seq} (captured: '{self._seq[self._captured]}')")

    def push(self, item):
        """Feed the next state label into the matcher."""
        debug(f"pushed: {item}")
        if item == self._seq[self._offset]:
            if self._captured == self._offset:
                debug("[capture]", end=" ")
                self.capture()
            self._offset += 1
            if self._offset == len(self._seq):
                debug("[match]", end=" ")
                self.matched()
                self._offset = 0
        else:
            debug("[clear]", end=" ")
            restart = (item == self._seq[0])
            self.clear()
            self._offset = 0
            if restart and not self.REWINDS_ON_CLEAR:
                # Streaming mode: this item starts a new alignment. Push it again
                # so that capture() also fires for it when the first state is captured.
                self.push(item)
                return
        debug(f"offset={self._offset}")

    def clear(self):
        pass

    def matched(self):
        pass

    def capture(self):
        pass


class StateCapture(Capture):
    """Finds all occurrences of a state pattern in a 'states' table.

    The `State` column is scanned row by row and every start row is tried in
    turn, so overlapping occurrences are found as well. For each occurrence,
    the (FromFrame, ToFrame) pair of the captured row is collected; duplicates
    are dropped and the order of discovery is kept.
    """
    REWINDS_ON_CLEAR = True

    def __init__(self, pattern):
        super().__init__(pattern)

    def initialize(self, states):
        self.__states = states
        self.__labels = list(states["State"])  # extracted once instead of one iloc per push
        self.__started = 0  # row where the current alignment attempt began
        self.__current = 0  # row currently being pushed
        self.__size = self.__states.shape[0]
        self.__captured = None
        self.__periods = []
        self._offset = 0  # drop any partial match left over from a previous parse()

    def parse(self, states):
        self.initialize(states)
        while self.__current < self.__size:
            self.push(self.__labels[self.__current])
            self.__current += 1
        debug(f"return: {self.__periods}")
        return self.__periods

    def _rewind(self):
        # Resume the search at the row after the start of the attempt that just
        # ended. parse() increments __current right after push() returns.
        self.__captured = None
        self.__current = self.__started
        self.__started += 1

    def clear(self):
        self._rewind()

    def matched(self):
        if self.__captured is None:
            raise RuntimeError("matched() called without any capture")
        period = (self.__captured.FromFrame, self.__captured.ToFrame)
        if period not in self.__periods:
            debug(f"adding: {period}")
            self.__periods.append(period)
        self._rewind()

    def capture(self):
        self.__captured = self.__states.iloc[self.__current]