|
@@ -0,0 +1,311 @@
|
|
|
+#
|
|
|
+# MIT License
|
|
|
+#
|
|
|
+# Copyright (c) 2020 Keisuke Sehara
|
|
|
+#
|
|
|
+# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
|
+# of this software and associated documentation files (the "Software"), to deal
|
|
|
+# in the Software without restriction, including without limitation the rights
|
|
|
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
|
+# copies of the Software, and to permit persons to whom the Software is
|
|
|
+# furnished to do so, subject to the following conditions:
|
|
|
+#
|
|
|
+# The above copyright notice and this permission notice shall be included in all
|
|
|
+# copies or substantial portions of the Software.
|
|
|
+#
|
|
|
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
|
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
|
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
|
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
|
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
|
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
|
+# SOFTWARE.
|
|
|
+#
|
|
|
+
|
|
|
+from collections import namedtuple as _namedtuple
|
|
|
+from pathlib import Path as _Path
|
|
|
+import re as _re
|
|
|
+
|
|
|
+import numpy as _np
|
|
|
+import pandas as _pd
|
|
|
+
|
|
|
+DEBUG = False
|
|
|
def debug(msg, end="\n"):
    """Print a diagnostic message when the module-level ``DEBUG`` flag is set.

    Parameters
    ----------
    msg:
        the object to print.
    end:
        the string appended after the message (forwarded to ``print``),
        so that callers can keep several messages on a single line.
    """
    if DEBUG:
        # flush immediately so that messages stay in order with other output
        print(msg, end=end, flush=True)
|
|
|
+
|
|
|
class Specification(_namedtuple("_Specification",
                                ("dataset", "subject", "session", "domain", "trial", "suffix"))):
    """An immutable (possibly partial) address of data inside a dataset.

    The fields map onto the following on-disk layout::

        <dataset>/<subject>/<session>/<domain>/
            <subject>_<session>_<domain>_run<trial, 5 digits><suffix>

    Any field may be ``None`` (i.e. "not specified").  The ``subjects``,
    ``sessions`` and ``trials`` properties expand the unspecified levels by
    searching the file system.
    """

    # the domain directory being searched for trial files
    # when no domain is specified
    DOMAIN_TO_SEARCH = "states"
    # extracts the trial index out of a file name
    TRIAL_CAPTURE = _re.compile(r"_run(\d+)")

    def __new__(cls, dataset=None, subject=None, session=None, domain=None, trial=None, suffix=None):
        """Normalize the field values and create the tuple.

        - ``dataset`` is converted to a resolved ``Path``.
        - ``subject``, ``session`` and ``domain`` are converted to ``str``.
        - ``trial`` is converted to ``int``.
        - ``suffix`` is converted to ``str``, with a leading period being
          prepended when missing.

        ``None`` values are kept as they are.
        """
        if dataset is not None:
            dataset = _Path(dataset).resolve()
        if subject is not None:
            subject = str(subject)
        if session is not None:
            session = str(session)
        if trial is not None:
            trial = int(trial)
        if domain is not None:
            domain = str(domain)
        if suffix is not None:
            suffix = str(suffix)
            if not suffix.startswith("."):
                suffix = "." + suffix
        # zero-argument super() so that subclasses can be instantiated, too
        return super().__new__(cls, dataset, subject, session, domain, trial, suffix)

    def with_values(self, **kwargs):
        """Return a copy of this object, with the given fields being replaced.

        Keyword arguments that do not correspond to any field are ignored.
        The new values go through the same normalization as in the constructor.
        """
        values = {field: kwargs.get(field, current)
                  for field, current in zip(self._fields, self)}
        return self.__class__(**values)

    @property
    def dataset_directory(self):
        """The path to the dataset root directory.

        Raises ``ValueError`` if the dataset is not specified, and
        ``FileNotFoundError`` if the directory does not exist.
        """
        if self.dataset is None:
            raise ValueError("dataset is not specified")
        if not self.dataset.exists():
            raise FileNotFoundError(f"dataset directory not found: {self.dataset}")
        return self.dataset

    @property
    def subject_directory(self):
        """The path to the subject directory, i.e. ``<dataset>/<subject>``.

        Raises ``ValueError`` if any of the required fields is not specified,
        and ``FileNotFoundError`` if any of the directories does not exist.
        """
        dsdir = self.dataset_directory
        if self.subject is None:
            raise ValueError("subject is not specified")
        subdir = dsdir / self.subject
        if not subdir.exists():
            raise FileNotFoundError(f"subject directory not found: {subdir}")
        return subdir

    @property
    def session_directory(self):
        """The path to the session directory, i.e. ``<dataset>/<subject>/<session>``.

        Raises ``ValueError`` if any of the required fields is not specified,
        and ``FileNotFoundError`` if any of the directories does not exist.
        """
        subdir = self.subject_directory
        if self.session is None:
            raise ValueError("session is not specified")
        sessdir = subdir / self.session
        if not sessdir.exists():
            raise FileNotFoundError(f"session directory not found: {sessdir}")
        return sessdir

    @property
    def domain_directory(self):
        """The path to the domain directory, i.e. ``<dataset>/<subject>/<session>/<domain>``.

        Raises ``ValueError`` if any of the required fields is not specified,
        and ``FileNotFoundError`` if any of the directories does not exist.
        """
        sessdir = self.session_directory
        if self.domain is None:
            raise ValueError("domain is not specified")
        domdir = sessdir / self.domain
        if not domdir.exists():
            raise FileNotFoundError(f"domain directory not found: {domdir}")
        return domdir

    @property
    def filepath(self):
        """The path to the file being specified by this object.

        The suffix (if any) is appended to the file name.

        Raises ``ValueError`` if any of the required fields is not specified,
        and ``FileNotFoundError`` if the file (or any of its parent directories)
        does not exist.
        """
        domdir = self.domain_directory
        if self.trial is None:
            raise ValueError("trial index is not specified")
        name = f"{self.subject}_{self.session}_{self.domain}_run{self.trial:05d}"
        if self.suffix is not None:
            # append instead of `Path.with_suffix()`: the latter would truncate
            # the name in case the subject or the session contains a period.
            name += self.suffix
        ret = domdir / name
        if not ret.exists():
            raise FileNotFoundError(f"file not found: {ret.name}")
        return ret

    @property
    def subjects(self):
        """A tuple of specifications, one for each of the matching subjects.

        If the subject is specified, this object itself is the only item.
        Otherwise the entries in the dataset directory whose name start with
        'MLA' are listed up in the sorted order.
        """
        if self.subject is not None:
            debug(f"found subject: {self.subject}")
            return (self,)
        ret = []
        for sub in sorted(self.dataset_directory.glob("MLA*")):
            debug(f"found subject: {sub.name}")
            ret.append(self.with_values(subject=sub.name))
        return tuple(ret)

    @property
    def sessions(self):
        """A tuple of specifications, one for each of the matching sessions.

        For each of the matching subjects, the entries in the subject directory
        whose name start with 'session' are listed up in the sorted order.
        If the session is specified, only the entries having the exact name
        are included.
        """
        ret = []
        for sub in self.subjects:
            for path in sorted(sub.subject_directory.glob("session*")):
                if (self.session is None) or (path.name == self.session):
                    debug(f"found session: {sub.subject}/{path.name}")
                    ret.append(sub.with_values(session=path.name))
        return tuple(ret)

    @property
    def trials(self):
        """A tuple of specifications, one for each of the matching trials.

        For each of the matching sessions, the files in the domain directory
        (the one for `DOMAIN_TO_SEARCH`, if the domain is not specified) are
        searched for the '_run<digits>' part in their names.
        If the trial is specified, only the files having the same index
        are included.  One item is generated for every matching file.
        """
        ret = []
        for session in self.sessions:
            if self.domain is None:
                domdir = session.with_values(domain=self.DOMAIN_TO_SEARCH).domain_directory
            else:
                domdir = session.domain_directory

            for path in sorted(domdir.glob("*run*")):
                captured = self.TRIAL_CAPTURE.search(path.stem)
                if not captured:
                    continue
                # the group consists only of digits, and hence always converts
                idx = int(captured.group(1))
                if (self.trial is None) or (idx == self.trial):
                    debug(f"found trial: {session.subject}/{session.session}/trial{idx:05d}")
                    ret.append(session.with_values(trial=idx))

        return tuple(ret)
|
|
|
+
|
|
|
def load_trials(dataset, subject=None, session=None, trial=None):
    """Load all the trials in `dataset` that match the given conditions.

    `subject`, `session` and `trial` narrow down the search when specified;
    the levels being left as ``None`` are searched through the file system.

    Returns a tuple of `Trial` objects.
    """
    query = Specification(dataset=dataset,
                          subject=subject,
                          session=session,
                          trial=trial)
    loaded = []
    for spec in query.trials:
        loaded.append(Trial.load(spec))
    return tuple(loaded)
|
|
|
+
|
|
|
class Trial:
    """
    A pair of the 'states' and the 'tracking' tables from a single trial.

    [usage]
    ```
    trial = Trial.load("../datasets/merged-output", <subject>, <session>, <trial-index>)
    ```

    or

    ```
    trial = Trial(<'states' DataFrame>, <'tracking' DataFrame>)
    ```

    - the 'states' data-frame can be accessed by `trial.states`.
    - the 'tracking' data-frame can be accessed by `trial.tracking`.
    - `trial.subject`, `trial.session`, `trial.index` and `trial.specification`
      are `None` unless a specification is supplied (`Trial.load` does so).
    """
    @classmethod
    def load(cls, dataset, subject=None, session=None, trial=None):
        """Read the 'states' and 'tracking' CSV files of a trial.

        `dataset` may be either the path to the dataset directory (a string
        or a `Path` object; the other arguments are used in this case),
        or a `Specification` object (the other arguments are ignored).

        Exceptions from resolving the file paths (`ValueError` for missing
        fields, `FileNotFoundError` for missing files) are propagated.
        """
        if isinstance(dataset, (str, _Path)):
            basespecs = Specification(dataset=dataset,
                                      subject=subject,
                                      session=session,
                                      trial=trial)
        else:
            # assumed to be a Specification
            basespecs = Specification(*dataset)
        states_path = basespecs.with_values(domain="states", suffix=".csv").filepath
        track_path = basespecs.with_values(domain="tracking", suffix=".csv").filepath
        return cls(_pd.read_csv(str(states_path)),
                   _pd.read_csv(str(track_path)),
                   specification=basespecs)

    def __init__(self, states_dataframe, tracking_dataframe, specification=None):
        self.states = states_dataframe
        self.tracking = tracking_dataframe
        self._specs = specification

    @property
    def subject(self):
        """the name of the subject, or None if no specification is attached."""
        return None if self._specs is None else self._specs.subject

    @property
    def session(self):
        """the name of the session, or None if no specification is attached."""
        return None if self._specs is None else self._specs.session

    @property
    def index(self):
        """the index of the trial, or None if no specification is attached."""
        return None if self._specs is None else self._specs.trial

    def has_eyedata(self):
        """returns if eye data is there in this trial.

        i.e. False if and only if all the values in the
        `left_pupil_normalized_position` column of the tracking table are NaN."""
        return not _np.all(_np.isnan(self.tracking.left_pupil_normalized_position))

    def get_timeranges(self, pattern):
        """returns a list of (from-rowindex, to-rowindex) objects
        in the tracking table, given a pattern of states.

        `pattern` may be either a string (being converted to a `StateCapture`)
        or an object having the `parse(states)` method."""
        if isinstance(pattern, str):
            pattern = StateCapture(pattern)
        return pattern.parse(self.states)

    @property
    def specification(self):
        """the specification object attached to this trial (may be None)."""
        return self._specs
|
|
|
+
|
|
|
class Capture:
    """A matcher for a fixed sequence of states, fed with one state at a time.

    The pattern is a whitespace-separated sequence of state names, with
    exactly one of them being bracketed by '<' and '>' (e.g. ``"A <B> C"``).
    The bracketed one is the state 'to be captured'.

    Subclasses override the following hooks:

    - `capture()`: called when the state to be captured is pushed
      at the expected position in the sequence.
    - `matched()`: called when the whole sequence is completed.
    - `clear()`: called when the pushed state breaks the sequence.
    """
    # a captured state: only alphabetic names are recognized
    CAPTURE_PATTERN = _re.compile(r"<([a-zA-Z]+)>$")

    def __init__(self, pattern):
        """Parse `pattern` into the sequence of states.

        Raises ``ValueError`` if the pattern is empty, or if it does not
        contain exactly one captured state.
        """
        self._seq = pattern.split()
        if not self._seq:
            raise ValueError("add at least one state in the pattern")
        self._captured = -1  # the position of the captured state in the sequence
        self._offset = 0     # the position of the state being expected next
        for i, item in enumerate(self._seq):
            if self.CAPTURE_PATTERN.match(item):
                if self._captured >= 0:
                    raise ValueError(f"cannot take two or more captured states in one pattern: {pattern}")
                self._captured = i
        if self._captured == -1:
            raise ValueError(f"bracket the state to capture by '<' and '>': {pattern}")
        # remove the brackets
        self._seq[self._captured] = self._seq[self._captured][1:-1]
        debug(f"capture: {self._seq} (captured: '{self._seq[self._captured]}')")

    def push(self, item):
        """Feed the next state into the matcher.

        When `item` breaks the sequence, `clear()` is called and the matching
        is restarted.  `item` itself is then examined as the candidate for
        the first state of the sequence (including the call to `capture()`,
        in case the first state is the one to be captured).
        """
        debug(f"pushed: {item}")
        if item != self._seq[self._offset]:
            debug("[clear]", end=" ")
            self.clear()
            self._offset = 0
            if item != self._seq[0]:
                debug(f"offset={self._offset}")
                return
            # otherwise: `item` starts a new sequence; fall through.

        # here, `item` is the state being expected at the current offset
        if self._captured == self._offset:
            debug("[capture]", end=" ")
            self.capture()
        self._offset += 1
        if self._offset == len(self._seq):
            debug("[match]", end=" ")
            self.matched()
            self._offset = 0
        debug(f"offset={self._offset}")

    def clear(self):
        """Hook: the sequence has been broken.  Does nothing by default."""
        pass

    def matched(self):
        """Hook: the sequence has been completed.  Does nothing by default."""
        pass

    def capture(self):
        """Hook: the state to be captured has been pushed.  Does nothing by default."""
        pass
|
|
|
+
|
|
|
class StateCapture(Capture):
    """A `Capture` that runs over the rows of a 'states' table.

    The table is read through `.shape[0]` and `.iloc[<row>]`, and each row
    through its `State`, `FromFrame` and `ToFrame` attributes.
    `parse()` returns the (FromFrame, ToFrame) pairs of the captured rows.
    """

    def __init__(self, pattern):
        super().__init__(pattern)

    def initialize(self, states):
        """Reset the cursors and the results for a new run over `states`."""
        self.__states = states
        self.__started = self.__current = 0  # rescan position / row being read
        self.__size = self.__states.shape[0]
        self.__captured = None               # the row being captured, if any
        self.__periods = []

    def parse(self, states):
        """Push all the rows in `states`, and return the list of
        (FromFrame, ToFrame) pairs of the captured rows, without duplicates."""
        self.initialize(states)
        while True:
            if self.__current >= self.__size:
                break
            row = self.__states.iloc[self.__current]
            # note: push() may move the cursor backwards through matched()
            self.push(row.State)
            self.__current += 1
        debug(f"return: {self.__periods}")
        return self.__periods

    def clear(self):
        """Move the rescan position to the row next to the current one."""
        self.__started = self.__current + 1

    def matched(self):
        """Register the period of the captured row, and rewind the cursor
        to the rescan position."""
        row = self.__captured
        if row is None:
            raise RuntimeError("matched() called without any capture")
        period = (row.FromFrame, row.ToFrame)
        already_registered = period in self.__periods
        if not already_registered:
            debug(f"adding: {period}")
            self.__periods.append(period)
        self.__captured = None
        # go back to the rescan position (parse() then steps to the row
        # next to it), and advance the rescan position by one.
        resume = self.__started
        self.__current = resume
        self.__started = resume + 1

    def capture(self):
        """Keep the current row as the captured one."""
        self.__captured = self.__states.iloc[self.__current]
|