epoch_analysis.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import numpy as np
  # Behavioral states, listed in the same order as the PATTERNS entries below.
  STATES = ("AtEnd", "Backward", "Turn", "Forward", "Expect", "Lick")

  # One pattern string per state in STATES (same keys, same order).
  # The token wrapped in angle brackets appears to mark the state being
  # matched, with the surrounding tokens providing context; `{turn}` is a
  # placeholder presumably substituted by the caller before matching
  # (TODO confirm against `trial.get_timeranges`).
  # NOTE: the Turn and Forward entries intentionally differ from an earlier
  # pattern set that is not part of this file.
  PATTERNS = {
      "AtEnd": "<AtEnd> Backward {turn}",
      "Backward": "AtEnd <Backward> {turn}",
      "Turn": "Backward <{turn}> Forward",
      "Forward": "{turn} <Forward>",
      "Expect": "{turn} Forward <Expect>",
      "Lick": "{turn} Forward Expect <Lick>",
  }
  def collect_epochs_with_pattern(trials, pattern, property="right_whisker_angle_deg",
                                  min_duration=2):
      """collect tracking traces for every epoch in which `pattern` matches.

      Parameters
      ----------
      trials : iterable
          trial objects. each one must provide `get_timeranges(pattern)`
          (yielding `(start, stop)` frame pairs), a `tracking` object
          indexable by `property`, and `subject` / `session` attributes.
      pattern : str
          passed as-is to `trial.get_timeranges()`.
      property : str
          the tracking column to read. (the name shadows the builtin, but is
          kept so that existing keyword callers keep working.)
      min_duration : int
          epochs with `stop - start < min_duration` are skipped. since the
          slice below includes frame `stop`, every kept epoch yields at least
          `min_duration + 1` samples (3 with the default of 2).

      Returns
      -------
      list of dict
          one entry per trial that has at least one kept epoch, of the form
          `dict(session=(subject, session), traces=[numpy.ndarray, ...])`.
          trials without any kept epoch are omitted.
      """
      ret = []
      for trial in trials:
          # traces within the matched periods of this trial, in match order.
          # frames start from 1, so `start` is shifted to a 0-based index;
          # `stop` is used unshifted, which makes frame `stop` inclusive.
          traces = [
              np.array(trial.tracking[property][start - 1:stop])
              for start, stop in trial.get_timeranges(pattern)
              if (stop - start) >= min_duration
          ]
          # ignore trials in which the pattern did not match (or only matched
          # epochs that were too short)
          if not traces:
              continue
          # otherwise: add with subject/session information
          ret.append(dict(session=(trial.subject, trial.session), traces=traces))
      return ret
  def normalize_time(trace, t_out):
      """sample temporal normalization using a 1-D trace.

      the samples of `trace` are treated as evenly spaced and mapped onto
      normalized time [0, 1] (first sample at 0, last sample at 1); the trace
      is then linearly interpolated at `t_out`.

      Parameters
      ----------
      trace : array_like
          1-D sequence of samples (numpy arrays, lists, ... are accepted).
      t_out : array_like or float
          normalized time point(s), ranging [0, 1]. values outside that range
          are clamped to the first/last sample (`numpy.interp` default).

      Returns
      -------
      numpy.ndarray or float
          interpolated values, shaped like `t_out`.

      Raises
      ------
      ValueError
          if `trace` is empty.
      """
      trace = np.asarray(trace)
      if trace.size == 0:
          raise ValueError("cannot time-normalize an empty trace")
      if trace.size == 1:
          # a single sample has no duration: avoid the 0/0 below (which would
          # warn and produce a NaN sample point) and return that sample's
          # value at every requested time point.
          t_in = np.zeros(1)
      else:
          t_in = np.arange(trace.size) / (trace.size - 1)
      return np.interp(t_out, t_in, trace)
  def get_normalized_traces(epochs, bins=100):
      """returns a numpy.ndarray of time-normalized traces.

      every trace is resampled with `normalize_time()` onto `bins` evenly
      spaced normalized time points spanning [0, 1], both ends included.

      Parameters
      ----------
      epochs : list of dict
          must be in the form as they are collected by
          `collect_epochs_with_pattern()`; only the "traces" entry of each
          dict is used.
      bins : int
          number of normalized time points per trace.

      Returns
      -------
      numpy.ndarray
          shape (bins, ntraces): one column per trace, in the order the
          traces appear in `epochs`. if there are no traces at all, an empty
          (bins, 0) array is returned.
      """
      t_out = np.linspace(0, 1, bins, endpoint=True)
      traces = [
          normalize_time(trace, t_out)
          for epoch in epochs
          for trace in epoch["traces"]
      ]
      if not traces:
          # np.stack() refuses an empty list; keep the documented shape instead
          return np.empty((t_out.size, 0))
      return np.stack(traces, axis=-1)