Keisuke Sehara 1 year ago
parent
commit
33073240d8

File diff suppressed because it is too large
+ 1143 - 0
04_analysis/01_asymmetry.ipynb


File diff suppressed because it is too large
+ 846 - 0
04_analysis/02a_prediction_analysis-ROC.ipynb


File diff suppressed because it is too large
+ 897 - 0
04_analysis/02b_prediction-analysis-perceptron.ipynb


File diff suppressed because it is too large
+ 859 - 0
04_analysis/03_saccades.ipynb


+ 73 - 0
04_analysis/ROC.py

@@ -0,0 +1,73 @@
+from collections import namedtuple
+import numpy as np
+import matplotlib.pyplot as plt
+
+def scale(v, vmin=-180, vmax=180):
+    return (v - vmin) / (vmax - vmin)
+
+def by_small(v, frac=0.001, add=True):
+    # nudge `v` up (add=True) or down (add=False) by a small fraction of its magnitude;
+    # used to place a threshold just beyond the largest observed value
+    by = abs(v)*frac
+    if not add:
+        by = -by
+    return v + by
+
+def reported(v):
+    print(v)
+    return v
+
+def compute(positive, negative):
+    merged_data = [(v, True) for v in positive] + \
+                  [(v, False) for v in negative]
+    merged_data = sorted(merged_data, key=lambda v: v[0])
+
+    sorted_values = np.array([v[0] for v in merged_data])
+    sorted_flags  = np.array([v[1] for v in merged_data])
+    true_total  = np.count_nonzero(sorted_flags)
+    false_total = sorted_flags.size - true_total
+
+    thresholds  = [sorted_values.min()]
+    sensitivity = [1]  # true-positive rate at each threshold
+    nonspecific = [1]  # false-positive rate (1 - specificity) at each threshold
+
+    above_init = np.where(sorted_values > sorted_values.min())[0]
+
+    if above_init.size > 0:
+        offset    = above_init.min()
+
+        while offset < sorted_values.size:
+            threshold     = sorted_values[offset]
+            detected_positive = sorted_flags[offset:]
+            true_positive     = np.count_nonzero(detected_positive)
+            false_positive    = np.count_nonzero(~detected_positive)
+
+            thresholds.append(threshold)
+            sensitivity.append(true_positive / true_total)
+            nonspecific.append(false_positive / false_total)
+
+            # compute next offset
+            ceiling = sorted_values[:(offset+1)].max()
+            above   = sorted_values > ceiling
+            if np.count_nonzero(above) == 0:
+                break
+            offset  = np.where(above)[0].min()
+
+    thresholds.append(by_small(sorted_values.max()))
+    sensitivity.append(0)
+    nonspecific.append(0)
+
+    return ROC(np.array(thresholds)[::-1],
+               np.array(sensitivity)[::-1],
+               np.array(nonspecific)[::-1])
+
+class ROC(namedtuple("_ROC", ("thresholds", "sensitivity", "nonspecific"))):   
+    @property
+    def AUC(self):
+        return sum((base1+base2)*abs(width)/2 \
+                   for base1, base2, width in zip(self.sensitivity[1:], self.sensitivity[:-1],
+                                                  (self.nonspecific[1:] - self.nonspecific[:-1])))
+            
+    
+    def plot(self, ax, cmap=plt.get_cmap("viridis"),
+             vmin=-180, vmax=180, alpha=1, markersize=10):
+        for t, n, s in zip(scale(self.thresholds, vmin=vmin, vmax=vmax), self.nonspecific, self.sensitivity):
+            ax.plot((n,), (s,), ".", color=cmap(t)[:3], alpha=alpha, mec="none", ms=markersize)
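
A minimal usage sketch (not part of this commit) of how `ROC.compute` and `ROC.plot` might be driven. The synthetic angle samples, and running from inside `04_analysis/` so that `import ROC` resolves, are assumptions.

```python
import numpy as np
import matplotlib.pyplot as plt
import ROC  # 04_analysis/ROC.py, assuming it is on the import path

rng = np.random.default_rng(0)
positive = rng.normal(30, 20, size=200)    # hypothetical "signal" angles, in degrees
negative = rng.normal(-10, 20, size=200)   # hypothetical "noise" angles, in degrees

roc = ROC.compute(positive, negative)
print(f"AUC = {roc.AUC:.3f}")              # trapezoidal area under the curve

fig, ax = plt.subplots()
roc.plot(ax, vmin=-180, vmax=180)          # points colored by threshold value
ax.set_xlabel("false-positive rate (nonspecific)")
ax.set_ylabel("true-positive rate (sensitivity)")
plt.show()
```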

+ 311 - 0
04_analysis/datareader/__init__.py

@@ -0,0 +1,311 @@
+#
+# MIT License
+#
+# Copyright (c) 2020 Keisuke Sehara
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+
+from collections import namedtuple as _namedtuple
+from pathlib import Path as _Path
+import re as _re
+
+import numpy as _np
+import pandas as _pd
+
+DEBUG = False
+def debug(msg, end="\n"):
+    if DEBUG:
+        print(msg, end=end, flush=True)
+
+class Specification(_namedtuple("_Specification",
+                            ("dataset", "subject", "session", "domain", "trial", "suffix"))):
+    DOMAIN_TO_SEARCH = "states"
+    TRIAL_CAPTURE    = _re.compile(r"_run(\d+)")
+
+    def __new__(cls, dataset=None, subject=None, session=None, domain=None, trial=None, suffix=None):
+        if dataset is not None:
+            dataset = _Path(dataset).resolve()
+        if subject is not None:
+            subject = str(subject)
+        if session is not None:
+            session = str(session)
+        if trial is not None:
+            trial   = int(trial)
+        if domain is not None:
+            domain  = str(domain)
+        if suffix is not None:
+            suffix  = str(suffix)
+            if not suffix.startswith("."):
+                suffix = "." + suffix
+        return super().__new__(cls, dataset, subject, session, domain, trial, suffix)
+
+    def with_values(self, **kwargs):
+        out = {}
+        for field, defaultvalue in zip(self._fields, self):
+            out[field] = kwargs.get(field, defaultvalue)
+        return self.__class__(**out)
+
+    @property
+    def dataset_directory(self):
+        if self.dataset is None:
+            raise ValueError("dataset is not specified")
+        elif not self.dataset.exists():
+            raise FileNotFoundError(f"dataset directory not found: {self.dataset}")
+        return self.dataset
+
+    @property
+    def subject_directory(self):
+        dsdir = self.dataset_directory
+        if self.subject is None:
+            raise ValueError("subject is not specified")
+        subdir = dsdir / self.subject
+        if not subdir.exists():
+            raise FileNotFoundError(f"subject directory not found: {subdir}")
+        return subdir
+
+    @property
+    def session_directory(self):
+        subdir = self.subject_directory
+        if self.session is None:
+            raise ValueError("session is not specified")
+        sessdir = subdir / self.session
+        if not sessdir.exists():
+            raise FileNotFoundError(f"session directory not found: {sessdir}")
+        return sessdir
+
+    @property
+    def domain_directory(self):
+        sessdir = self.session_directory
+        if self.domain is None:
+            raise ValueError("domain is not specified")
+        domdir = sessdir / self.domain
+        if not domdir.exists():
+            raise FileNotFoundError(f"domain directory not found: {domdir}")
+        return domdir
+
+    @property
+    def filepath(self):
+        domdir = self.domain_directory
+        if self.trial is None:
+            raise ValueError("trial index is not specified")
+        ret = domdir / f"{self.subject}_{self.session}_{self.domain}_run{self.trial:05d}"
+        if self.suffix is not None:
+            ret = ret.with_suffix(self.suffix)
+        if not ret.exists():
+            raise FileNotFoundError(f"file not found: {ret.name}")
+        return ret
+
+    @property
+    def subjects(self):
+        if self.subject is not None:
+            debug(f"found subject: {self.subject}")
+            return (self,)
+        else:
+            ret = []
+            for sub in sorted(self.dataset_directory.glob("MLA*")):
+                debug(f"found subject: {sub.name}")
+                ret.append(self.with_values(subject=sub.name))
+            return tuple(ret)
+
+    @property
+    def sessions(self):
+        ret = []
+        for sub in self.subjects:
+            for path in sorted(sub.subject_directory.glob("session*")):
+                if (self.session is None) or (path.name == self.session):
+                    debug(f"found session: {sub.subject}/{path.name}")
+                    ret.append(sub.with_values(session=path.name))
+        return tuple(ret)
+
+    @property
+    def trials(self):
+        ret = []
+        for session in self.sessions:
+            if self.domain is None:
+                domdir = session.with_values(domain=self.DOMAIN_TO_SEARCH).domain_directory
+            else:
+                domdir = session.domain_directory
+
+            for path in sorted(domdir.glob("*run*")):
+                captured = self.TRIAL_CAPTURE.search(path.stem)
+                if not captured:
+                    continue
+                try:
+                    idx = int(captured.group(1))
+                    if (self.trial is None) or (idx == self.trial):
+                        debug(f"found trial: {session.subject}/{session.session}/trial{idx:05d}")
+                        ret.append(session.with_values(trial=idx))
+                except ValueError:
+                    pass
+
+        return tuple(ret)
+
+def load_trials(dataset, subject=None, session=None, trial=None):
+    basespecs = Specification(dataset=dataset,
+                              subject=subject,
+                              session=session,
+                              trial=trial)
+    return tuple(Trial.load(trial) for trial in basespecs.trials)
+
+class Trial:
+    """
+    [usage]
+    ```
+    trial = Trial.load("../datasets/merged-output", <subject>, <session>, <trial-index>)
+    ```
+
+    or
+
+    ```
+    trial = Trial(<'states' DataFrame>, <'tracking' DataFrame>)
+    ```
+
+    - the 'states' data-frame can be accessed by `trial.states`.
+    - the 'tracking' data-frame can be accessed by `trial.tracking`.
+    """
+    @classmethod
+    def load(cls, dataset, subject=None, session=None, trial=None):
+        if not isinstance(dataset, str):
+            # assumed to be a Specification
+            basespecs = Specification(*dataset)
+        else:
+            basespecs   = Specification(dataset=dataset,
+                                        subject=subject,
+                                        session=session,
+                                        trial=trial)
+        states_path = basespecs.with_values(domain="states", suffix=".csv").filepath
+        track_path  = basespecs.with_values(domain="tracking", suffix=".csv").filepath
+        return cls(_pd.read_csv(str(states_path)),
+                   _pd.read_csv(str(track_path)),
+                   specification=basespecs)
+
+    def __init__(self, states_dataframe, tracking_dataframe, specification=None):
+        self.states   = states_dataframe
+        self.tracking = tracking_dataframe
+        self._specs   = specification
+
+    @property
+    def subject(self):
+        return self._specs.subject
+
+    @property
+    def session(self):
+        return self._specs.session
+
+    @property
+    def index(self):
+        return self._specs.trial
+
+    def has_eyedata(self):
+        """returns if eye data is there in this trial."""
+        return not _np.all(_np.isnan(self.tracking.left_pupil_normalized_position))
+
+    def get_timeranges(self, pattern):
+        """returns a list of (from-rowindex, to-rowindex) objects
+        in the tracking table, given a pattern of states."""
+        if isinstance(pattern, str):
+            pattern = StateCapture(pattern)
+        return pattern.parse(self.states)
+
+    @property
+    def specification(self):
+        return self._specs
+
+class Capture:
+    CAPTURE_PATTERN = _re.compile(r"<([a-zA-Z]+)>$")
+
+    def __init__(self, pattern):
+        self._seq = [item.strip() for item in pattern.split()]
+        if len(self._seq) == 0:
+            raise ValueError("add at least one state in the pattern")
+        self._captured = -1
+        self._offset   = 0
+        for i, item in enumerate(self._seq):
+            if self.CAPTURE_PATTERN.match(item):
+                if self._captured >= 0:
+                    raise ValueError(f"cannot take two or more captured states in one pattern: {pattern}")
+                self._captured = i
+        if self._captured == -1:
+            raise ValueError(f"bracket the state to capture by '<' and '>': {pattern}")
+        self._seq[self._captured] = self._seq[self._captured][1:-1]
+        debug(f"capture: {self._seq} (captured: '{self._seq[self._captured]}')")
+
+    def push(self, item):
+        debug(f"pushed: {item}")
+        if item == self._seq[self._offset]:
+            if self._captured == self._offset:
+                debug("[capture]", end=" ")
+                self.capture()
+            self._offset += 1
+            if self._offset == len(self._seq):
+                debug("[match]", end=" ")
+                self.matched()
+                self._offset = 0
+        else:
+            debug("[clear]", end=" ")
+            self.clear()
+            self._offset = 1 if item == self._seq[0] else 0
+        debug(f"offset={self._offset}")
+
+    def clear(self):
+        pass
+
+    def matched(self):
+        pass
+
+    def capture(self):
+        pass
+
+class StateCapture(Capture):
+    def __init__(self, pattern):
+        super().__init__(pattern)
+
+    def initialize(self, states):
+        self.__states  = states
+        self.__started = 0
+        self.__current  = 0
+        self.__size    = self.__states.shape[0]
+        self.__captured = None
+        self.__periods = []
+
+    def parse(self, states):
+        self.initialize(states)
+        while self.__current < self.__size:
+            self.push(self.__states.iloc[self.__current].State)
+            self.__current += 1
+        debug(f"return: {self.__periods}")
+        return self.__periods
+
+    def clear(self):
+        self.__started = self.__current + 1
+
+    def matched(self):
+        if self.__captured is None:
+            raise RuntimeError("matched() called without any capture")
+        period = (self.__captured.FromFrame, self.__captured.ToFrame)
+        if period not in self.__periods:
+            debug(f"adding: {period}")
+            self.__periods.append(period)
+        self.__captured = None
+        self.__current  = self.__started
+        self.__started += 1
+
+    def capture(self):
+        self.__captured = self.__states.iloc[self.__current]
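
A hedged sketch (not part of this commit) of how the `datareader` module might be driven; the dataset path is the placeholder from the `Trial` docstring, and whether the states 'Forward' and 'Expect' actually occur back-to-back in the task is an assumption.

```python
import datareader

# enumerate every trial under every MLA* subject and session* directory
trials = datareader.load_trials("../datasets/merged-output")

for trial in trials:
    if not trial.has_eyedata():
        continue
    # (from-frame, to-frame) periods where a 'Forward' state is followed by 'Expect'
    periods = trial.get_timeranges("Forward <Expect>")
    print(trial.subject, trial.session, trial.index, periods)
```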

+ 50 - 0
04_analysis/epoch_analysis.py

@@ -0,0 +1,50 @@
+import numpy as np
+
+STATES   = ("AtEnd", "Backward", "Turn", "Forward", "Expect", "Lick")
+
+PATTERNS = dict(AtEnd="<AtEnd> Backward {turn}",
+                Backward="AtEnd <Backward> {turn}",
+                Turn="Backward <{turn}> Forward", # note a difference from the previous section
+                Forward="{turn} <Forward>", # note a difference from the previous section
+                Expect="{turn} Forward <Expect>",
+                Lick="{turn} Forward Expect <Lick>")
+
+def collect_epochs_with_pattern(trials, pattern, property="right_whisker_angle_deg"):
+    ret = []
+    for trial in trials:
+        # we collect traces within the matched periods, trial-by-trial
+        traces = []
+        for start, stop in trial.get_timeranges(pattern):
+            if (stop - start) < 2:
+                continue
+            # frames start from 1, so we have to adjust when reading from tracking data
+            epoch = np.array(trial.tracking[property][start-1:stop])
+            traces.append(epoch)
+
+        # ignore cases when the pattern did not match in this trial
+        if len(traces) == 0:
+            continue
+
+        # otherwise: add with subject/session information
+        ret.append(dict(session=(trial.subject, trial.session), traces=traces))
+    return ret
+
+def normalize_time(trace, t_out):
+    """sample temporal normalization using a 1-D trace.
+    `t_out` is a sequence of normalized timepoints, ranging [0,1].
+    """
+    t_in = np.arange(0, trace.size) / (trace.size - 1)
+    return np.interp(t_out, t_in, trace)
+
+def get_normalized_traces(epochs, bins=100):
+    """returns a numpy.ndarray of time-normalized traces.
+    the shape of the returned array is (bins, ntraces).
+
+    `epochs` must be in the form as they are collected
+    by `collect_epochs_with_pattern()`."""
+    t_out = np.linspace(0, 1, bins, endpoint=True)
+    traces = []
+    for epoch in epochs:
+        for trace in epoch["traces"]:
+            traces.append(normalize_time(trace, t_out))
+    return np.stack(traces, axis=-1)
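
A sketch (not part of this commit) tying `datareader` and `epoch_analysis` together. The dataset path and the state name substituted for `{turn}` are assumptions.

```python
import numpy as np
import datareader
import epoch_analysis as ea

trials  = datareader.load_trials("../datasets/merged-output")
pattern = ea.PATTERNS["Forward"].format(turn="Turn")  # actual turn-state name is assumed

# collect whisker-angle traces for every matched period, then time-normalize them
epochs = ea.collect_epochs_with_pattern(trials, pattern,
                                        property="right_whisker_angle_deg")
traces = ea.get_normalized_traces(epochs, bins=100)   # shape: (100, number-of-traces)
mean_trace = traces.mean(axis=1)                      # grand average over normalized time
```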

+ 165 - 0
04_analysis/kw_dunn.py

@@ -0,0 +1,165 @@
+"""
+Downloaded from: https://gist.github.com/alimuldal/fbb19b73fa25423f02e8
+"""
+
+import numpy as np
+from scipy import stats
+from itertools import combinations
+from statsmodels.stats.multitest import multipletests
+from statsmodels.stats.libqsturng import psturng
+import warnings
+
+from collections import namedtuple
+
+def kw_dunn(groups, pairs=None, alpha=0.05, method='bonf'):
+    return KWDunn.compute(groups=groups,
+                          pairs=pairs,
+                          alpha=alpha,
+                          method=method)
+
+def chisqprob(chisq, df):
+    """a helper function to deal with recent SciPy versions."""
+    return stats.chi2.sf(chisq, df)
+
+Dunn = namedtuple('Dunn',
+    ["Z", "p_corrected", "alpha", "reject"]
+)
+
+class KWDunn(namedtuple('_KWDunn',
+                [
+                    "H",
+                    "p_omnibus",
+                    "pairwise",
+                ]
+             )):
+    @classmethod
+    def compute(cls, groups, pairs=None, alpha=0.05, method='bonf'):
+        H, p_omni, Z_pair, p_pair, rej = kw_dunn_raw(groups, to_compare=pairs,
+                                                     alpha=alpha, method=method)
+        pairwise = {}
+        if (pairs is None) or (len(pairs) == 0):
+            pass
+        else:
+            for pair, Z, p, r in zip(pairs, Z_pair, p_pair, rej):
+                pairwise[pair] = Dunn(Z, p, alpha, r)
+        return cls(H, p_omni, pairwise)
+
+def kw_dunn_raw(groups, to_compare=None, alpha=0.05, method='bonf'):
+    """
+
+    Kruskal-Wallis 1-way ANOVA with Dunn's multiple comparison test
+
+    Arguments:
+    ---------------
+    groups: sequence
+        arrays corresponding to k mutually independent samples from
+        continuous populations
+
+    to_compare: sequence
+        tuples specifying the indices of pairs of groups to compare, e.g.
+        [(0, 1), (0, 2)] would compare group 0 with 1 & 2. by default, all
+        possible pairwise comparisons between groups are performed.
+
+    alpha: float
+        family-wise error rate used for correcting for multiple comparisons
+        (see statsmodels.stats.multitest.multipletests for details)
+
+    method: string
+        method used to adjust p-values to account for multiple corrections (see
+        statsmodels.stats.multitest.multipletests for options)
+
+    Returns:
+    ---------------
+    H: float
+        Kruskal-Wallis H-statistic
+
+    p_omnibus: float
+        p-value corresponding to the global null hypothesis that the medians of
+        the groups are all equal
+
+    Z_pairs: float array
+        Z-scores computed for the absolute difference in mean ranks for each
+        pairwise comparison
+
+    p_corrected: float array
+        corrected p-values for each pairwise comparison, corresponding to the
+        null hypothesis that the pair of groups has equal medians. note that
+        these are only meaningful if the global null hypothesis is rejected.
+
+    reject: bool array
+        True for pairs where the null hypothesis can be rejected for the given
+        alpha
+
+    Reference:
+    ---------------
+    Gibbons, J. D., & Chakraborti, S. (2011). Nonparametric Statistical
+    Inference (5th ed., pp. 353-357). Boca Raton, FL: Chapman & Hall.
+
+    """
+
+    # omnibus test (K-W ANOVA)
+    # -------------------------------------------------------------------------
+
+    groups = [np.array(gg) for gg in groups]
+
+    k = len(groups)
+
+    n = np.array([len(gg) for gg in groups])
+    if np.any(n < 5):
+        warnings.warn("Sample sizes < 5 are not recommended (K-W test assumes "
+                      "a chi square distribution)")
+
+    allgroups = np.concatenate(groups)
+    N = len(allgroups)
+    ranked = stats.rankdata(allgroups)
+
+    # correction factor for ties
+    T = stats.tiecorrect(ranked)
+    if T == 0:
+        raise ValueError('All numbers are identical in kruskal')
+
+    # sum of ranks for each group
+    j = np.insert(np.cumsum(n), 0, 0)
+    R = np.empty(k, dtype=float)
+    for ii in range(k):
+        R[ii] = ranked[j[ii]:j[ii + 1]].sum()
+
+    # the Kruskal-Wallis H-statistic
+    H = (12. / (N * (N + 1.))) * ((R ** 2.) / n).sum() - 3 * (N + 1)
+
+    # apply correction factor for ties
+    H /= T
+
+    df_omnibus = k - 1
+    p_omnibus = chisqprob(H, df_omnibus)
+
+    # multiple comparisons
+    # -------------------------------------------------------------------------
+
+    # by default we compare every possible pair of groups
+    if to_compare is None:
+        to_compare = tuple(combinations(range(k), 2))
+
+    ncomp = len(to_compare)
+
+    Z_pairs = np.empty(ncomp, dtype=float)
+    p_uncorrected = np.empty(ncomp, dtype=float)
+    Rmean = R / n
+
+    for pp, (ii, jj) in enumerate(to_compare):
+
+        # standardized score
+        Zij = (np.abs(Rmean[ii] - Rmean[jj]) /
+               np.sqrt((1. / 12.) * N * (N + 1) * (1. / n[ii] + 1. / n[jj])))
+        Z_pairs[pp] = Zij
+
+    # corresponding p-values obtained from upper quantiles of the standard
+    # normal distribution
+    p_uncorrected = stats.norm.sf(Z_pairs) * 2.
+
+    # correction for multiple comparisons
+    reject, p_corrected, alphac_sidak, alphac_bonf = multipletests(
+        p_uncorrected, method=method
+    )
+
+    return H, p_omnibus, Z_pairs, p_corrected, reject
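
A self-contained sketch (not part of this commit) of the `kw_dunn` wrapper on synthetic groups; the group sizes and effect sizes are arbitrary.

```python
import numpy as np
from kw_dunn import kw_dunn

rng = np.random.default_rng(1)
groups = [rng.normal(0.0, 1.0, size=30),
          rng.normal(0.5, 1.0, size=30),
          rng.normal(1.0, 1.0, size=30)]
pairs  = [(0, 1), (0, 2), (1, 2)]

result = kw_dunn(groups, pairs=pairs, alpha=0.05, method='bonf')
print(f"H = {result.H:.2f}, omnibus p = {result.p_omnibus:.4g}")
for pair, dunn in result.pairwise.items():
    verdict = "reject" if dunn.reject else "n.s."
    print(pair, f"Z = {dunn.Z:.2f}", f"corrected p = {dunn.p_corrected:.4g}", verdict)
```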

+ 115 - 0
04_analysis/saccades.py

@@ -0,0 +1,115 @@
+import numpy as _np
+from scipy.signal import find_peaks as _find_peaks
+import sliding1d as _sliding
+
+def get_sign(x):
+    return x / _np.abs(x)
+
+def sliding_diff(x):
+    x   = _np.array(x)
+    ret = _np.empty(x.size, dtype=_np.float64)
+    ret[1:-1] = (x[2:] - x[:-2])/2
+    ret[0]    = x[1] - x[0]
+    ret[-1]   = x[-1] - x[-2]
+    return ret
+
+def smoothing_diff(x, rad=5, num=3):
+    return sliding_diff(_sliding.nanmean(x, rad, num))
+
+def threshold_values(x, threshold):
+    ret = _np.array(x, copy=True)
+    ret[_np.abs(x) < threshold] = 0
+    return ret
+
+def detect(time, left_pupil, right_pupil,
+           std_period=None, std_threshold=5, min_distance_seconds=0.2,
+           smoothing_radius_seconds=0.05, smoothing_number=3):
+    """detect saccade events.
+    
+    Parameters
+    ----------
+    
+    time        -- the zero-starting trial time array.
+    left_pupil  -- the left-pupil position at each time point.
+    right_pupil -- the right-pupil position at each time point.
+    
+    For the keyword arguments, refer to the "Procedures" section below.
+    
+    Returns
+    -------
+    
+    a numpy.ndarray object whose size equals that of the `time` parameter.
+    the returned array has all-zero values, except for the positions of the detected saccades.
+    the values at the time points of detected saccades represent the average velocity of the two eyes.
+    
+    Procedures
+    ----------
+    
+    1. smoothing_diff() is used to compute velocity from position.
+       - `smoothing_radius_seconds` and `smoothing_number` are used here.
+       
+    2. the velocity is thresholded by its absolute value
+       with the criterion: `abs(v) > v[std_period].std(ddof=1)*std_threshold`
+       - if `std_period` is None, then the initial 1 second of the recording is used.
+       
+    3. "average velocity" is calculated as the geometric mean 
+       of the (thresholded) velocities of the left and the right pupils.
+       
+    4. `scipy.signal.find_peaks` is used to detect peaks from the average velocity.
+       - `min_distance_seconds` is used to set the minimum distance between the peaks.
+     
+    5. the returned array is formatted. the value of each event is computed from:
+       - amplitude: the peak value of the average velocity
+       - sign: the sign of the left-pupil velocity at the peak.
+       
+    """
+    dt = _np.diff(time).mean()
+    if std_period is None:
+        std_period = time < 1 # the initial 1 second
+    smoothing_radius     = int(round(smoothing_radius_seconds / dt))
+    min_distance_samples = int(round(min_distance_seconds / dt))
+    
+    vleft  = smoothing_diff(left_pupil, rad=smoothing_radius, num=smoothing_number) / dt
+    vright = smoothing_diff(right_pupil, rad=smoothing_radius, num=smoothing_number) / dt
+    vleft_thresholded  = threshold_values(vleft, vleft[std_period].std(ddof=1)*std_threshold)
+    vright_thresholded = threshold_values(vright, vright[std_period].std(ddof=1)*std_threshold)
+    
+    # there are some "negative spikes", possibly due to slight desynchronization of the pupil motions;
+    # we take both negative and positive spikes here
+    vamp  = _np.sqrt(_np.abs(vleft_thresholded*vright_thresholded))
+    peaks = _find_peaks(vamp, distance=min_distance_samples)[0]
+    
+    ret   = _np.zeros(left_pupil.size, dtype=_np.float64)
+    ret[peaks] = vamp[peaks] * get_sign(vleft_thresholded[peaks])
+    return ret
+
+def annotate(events):
+    """splits the saccade events array into three traces for the plotting purpose.
+    
+    Parameter
+    ---------
+    
+    events -- as it is returned from saccades.detect()
+    
+    Returns
+    -------
+    
+    a `dict` object, which contains the following keys and values:
+    
+    - none:  no-event period trace, for plotting "baseline" period.
+    - left:  leftward saccade event trace
+    - right: rightward saccade event trace
+    
+    For event traces, NaNs are used to represent values at irrelevant time points.
+    
+    """
+    noevents    = _np.zeros(events.size, dtype=_np.float64)
+    leftevents  = _np.empty(events.size, dtype=_np.float64); leftevents[:]  = _np.nan
+    rightevents = _np.empty(events.size, dtype=_np.float64); rightevents[:] = _np.nan
+    for levt in _np.where(events > 0)[0]:
+        rng = slice(levt-1, levt+2)
+        leftevents[rng] = events[rng]
+    for revt in _np.where(events < 0)[0]:
+        rng = slice(revt-1, revt+2)
+        rightevents[rng] = events[rng]
+    return dict(none=noevents, left=leftevents, right=rightevents)
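
A sketch (not part of this commit) running saccade detection on a single trial. It assumes the `sliding1d` dependency is installed; the column name for the right pupil is an assumption (only `left_pupil_normalized_position` appears elsewhere in this commit).

```python
import numpy as np
import datareader
import saccades

trial = datareader.load_trials("../datasets/merged-output")[0]
time  = np.array(trial.tracking["time"])
left  = np.array(trial.tracking["left_pupil_normalized_position"])
right = np.array(trial.tracking["right_pupil_normalized_position"])  # assumed column name

events = saccades.detect(time, left, right,
                         std_threshold=5, min_distance_seconds=0.2)
traces = saccades.annotate(events)   # dict with 'none', 'left', 'right' traces for plotting
print(f"{np.count_nonzero(events > 0)} leftward / "
      f"{np.count_nonzero(events < 0)} rightward saccades detected")
```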

+ 40 - 0
04_analysis/trial_envelope.py

@@ -0,0 +1,40 @@
+from collections import namedtuple
+import numpy as np
+import sliding1d as sliding
+
+def interpolate(vec, min_size=10):
+    valid = ~np.isnan(vec)
+    if np.count_nonzero(valid) < min_size:
+        return np.empty(vec.size) * np.nan
+    t     = np.arange(vec.size)
+    return np.interp(t, t[valid], vec[valid])
+
+def whisker(trial, side="left", radius_sample=10, smooth=True):
+    return Envelope.whisker(trial, side=side, radius_sample=radius_sample, smooth=smooth)
+
+class Envelope(namedtuple("_Envelope", ("time", "raw", "bottom", "top"))):
+    @classmethod
+    def whisker(cls, trial, side="left", radius_sample=10, smooth=True):
+        vec  = interpolate(trial.tracking[f"{side}_whisker_angle_deg"])
+        vec  = (vec - vec.min()) / (vec.max() - vec.min())
+        time = np.array(trial.tracking["time"])
+        return cls.compute(time, vec, radius_sample=radius_sample, smooth=smooth)
+    
+    @classmethod
+    def compute(cls, time, vec, radius_sample=10, smooth=True):
+        bottom = sliding.nanmin(vec, radius_sample)
+        top    = sliding.nanmax(vec, radius_sample)
+        if smooth:
+            bottom = sliding.nanmean(bottom, radius_sample)
+            top    = sliding.nanmean(top, radius_sample)
+        return cls(time, vec, bottom, top)
+    
+    @property
+    def amplitude(self):
+        return self.top - self.bottom
+    
+    def with_range(self, rng):
+        return self.__class__(self.time[rng],
+                              self.raw[rng],
+                              self.bottom[rng],
+                              self.top[rng])

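A sketch (not part of this commit) of the whisking-envelope helper; it assumes the `sliding1d` dependency is installed and a dataset at the placeholder path.

```python
import matplotlib.pyplot as plt
import datareader
import trial_envelope

trial = datareader.load_trials("../datasets/merged-output")[0]
env   = trial_envelope.whisker(trial, side="left", radius_sample=10, smooth=True)

fig, ax = plt.subplots()
ax.plot(env.time, env.raw, lw=0.5, label="normalized whisker angle")
ax.plot(env.time, env.bottom, label="bottom envelope")
ax.plot(env.time, env.top, label="top envelope")
ax.plot(env.time, env.amplitude, label="envelope amplitude")
ax.set_xlabel("time (s)")
ax.legend()
plt.show()
```
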
+ 2 - 1
README.md

@@ -5,6 +5,7 @@ The data repository for Bergmann, Sehara et al., 2022 (eye-whisker coordination
 - [01_data](01_data): contains datasets used for this study. includes the experimental / analytical procedures used to generate data.
 - [02_models](02_models): contains the computational model used for this study.
 - [03_demos](03_demos): some code demonstrating the usage of the `datareader` module, and the core part of our analysis.
+- [04_analysis](04_analysis): code for generating the analysis figures.
 
 ## Downloading the repository
 
@@ -30,4 +31,4 @@ $ git submodule foreach --recursive 'gin init'
 
 ----
 
-Copyright (c) 2022, Ronny Bergmann, [Keisuke Sehara](https://orcid.org/0000-0003-4368-8143), Sina E. Dominiak, [Julien Colomb](https://orcid.org/0000-0002-3127-5520), [Jens Kremkow](https://orcid.org/0000-0001-7077-4528), [Matthew E. Larkum](https://orcid.org/0000-0001-9799-2656), [Robert N. S. Sachdev](https://orcid.org/0000-0002-6627-0199), [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/)
+Copyright (c) 2022, [Ronny Bergmann](https://orcid.org/0000-0002-1477-7502), [Keisuke Sehara](https://orcid.org/0000-0003-4368-8143), Sina E. Dominiak, [Julien Colomb](https://orcid.org/0000-0002-3127-5520), [Jens Kremkow](https://orcid.org/0000-0001-7077-4528), [Matthew E. Larkum](https://orcid.org/0000-0001-9799-2656), [Robert N. S. Sachdev](https://orcid.org/0000-0002-6627-0199), [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/)