"""
References:
1. Kato, S. et al., 2014.
“Temporal Responses of C.elegans Chemosensory Neurons Are Preserved in Behavioral Dynamics.” Neuron, 81(3),
pp.616–628. Available at: http://dx.doi.org/10.1016/j.neuron.2013.11.020.
"""
import shutil
import tempfile
from typing import Sequence
import numpy as np
import math
import pandas as pd
from scipy.interpolate import interp1d
import pathlib as pl
from view.python_core.ctvs.cascade_model.gekko_related import fit_compare_models, GekkoSolver, ModelOneComp, \
ModelTwoComp, ModelTwoCompNoDelay
def estimate_delay_rising_phase_only(stimulus_trace: Sequence[float], output_trace: Sequence[float]):
"""
Estimate delay between the signals and using a cross correlation and only those points with
positive derivative
:param Sequence stimulus_trace: sequence of floats
:param Sequence output_trace: sequence of floats
:return: delay (number of samples)
:rtype: int
"""
stimulus_trace_non_rising_mask = (np.diff(stimulus_trace) <= 0).tolist() + [False]
stim_min, stim_max = min(stimulus_trace), max(stimulus_trace)
stimulus_trace_copy = np.array(stimulus_trace, copy=True)
stimulus_trace_copy[stimulus_trace_non_rising_mask] = np.median(stimulus_trace)
stimulus_trace_copy[stimulus_trace_copy >= stim_min + 0.85 * (stim_max - stim_min)] = np.median(stimulus_trace)
output_trace_non_rising_mask = (np.diff(output_trace) <= 0).tolist() + [False]
output_min, output_max = min(output_trace), max(output_trace)
output_trace_copy = np.array(output_trace, copy=True)
output_trace_copy[output_trace_non_rising_mask] = np.median(output_trace)
output_trace_copy[output_trace_copy >= output_min + 0.85 * (output_max - output_min)] = np.median(output_trace)
corr = np.correlate(stimulus_trace_copy, output_trace_copy, "full")
delay_estimated = np.argmax(corr) - stimulus_trace_copy.shape[0] + 1
# # code for debugging starts here
# from matplotlib import pyplot as plt
# plt.ion()
# fig, axs = plt.subplots(nrows=3, figsize=(10, 8), squeeze=True)
# axs[0].plot(stimulus_trace, 'b-')
# axs[0].plot(stimulus_trace_copy, "r-o")
# axs[0].set_ylabel("Stimulus")
#
# axs[1].plot(output_trace, 'b-')
# axs[1].plot(output_trace_copy, "r-o")
# axs[1].set_ylabel("Response")
#
# axs[2].plot(corr, "b-")
# axs[2].axvline(np.argmax(corr), color='r')
# axs[2].set_ylabel("Correlation")
# axs[2].set_title(f"Estimated delay in samples: {delay_estimated}")
#
# plt.draw()
# input("Press any key to continue..")
# plt.close()
# # code for debugging ends here
return delay_estimated
def predict_into_future(fit_params, factor):
output = factor * fit_params["output"]
trace = np.array(output)
n_pts = trace.shape[0]
time_trace = fit_params["time_trace_fitted"]
sampling_period = time_trace[1] - time_trace[0]
stimulus_trace = np.array(fit_params["stimulus_trace_fitted"])
stimulus_peak_pos = np.argmax(stimulus_trace)
central_part_start = max(0, int(stimulus_peak_pos - 0.25 * n_pts))
central_part_end = min(n_pts, int(stimulus_peak_pos + 0.25 * n_pts))
central_part = trace[central_part_start: central_part_end + 1]
max_pos = central_part.argmax() + central_part_start
prediction_start = time_trace[-2]
longest_time_constant = fit_params["kF"] + fit_params['kA']
if "kS" in fit_params:
longest_time_constant = max(fit_params["kS"], longest_time_constant)
prediction_end = time_trace[max_pos] + 6 * longest_time_constant # units is seconds
predicted_trace = None
if prediction_end > prediction_start:
prediction_time_trace = np.arange(prediction_start, prediction_end + sampling_period, sampling_period)
gs = GekkoSolver.init_from_model(fit_params["model"])
predicted_traces_dict = gs.solve(
time_vec=prediction_time_trace,
input_vec=np.zeros_like(prediction_time_trace),
parameter_values={k: fit_params[k] for k in gs.parameters},
state_variable_init_dict={k: fit_params[k][-2:] for k in gs.state_variables},
output_init=output[-2:]
)
predicted_trace = predicted_traces_dict["output"]
predicted_trace_no_overlap = None
predicted_time_trace_no_overlap = None
if predicted_trace is not None:
if predicted_trace[-1] <= predicted_trace[-2]:
predicted_trace_no_overlap = np.array(predicted_trace)[1:] * factor
predicted_time_trace_no_overlap = np.array(prediction_time_trace)[1:]
return predicted_trace_no_overlap, predicted_time_trace_no_overlap
def fit_cascade_model(
stimulus_trace, output_trace, time_trace, delays_to_test=np.arange(-15, 9, 1).astype(int)
):
stimulus_trace = np.array(stimulus_trace)
output_trace = np.array(output_trace)
time_trace = np.array(time_trace)
assert stimulus_trace.shape == output_trace.shape, "Stimulus and output traces have different shapes"
assert time_trace.shape == output_trace.shape, "time trace and output trace have different shapes"
sampling_period = time_trace[1] - time_trace[0]
pcc = PrelimChunkClassifier(output_trace)
factor = 1 if pcc.is_chunk_response_positive() else -1
bics = []
fit_params_all = []
model_one_comp = ModelOneComp()
for delay in delays_to_test:
if delay > 0:
output_trace_to_fit = output_trace[:-delay]
stimulus_trace_to_fit = stimulus_trace[delay:]
time_trace_to_fit = time_trace[delay:]
else:
output_trace_to_fit = output_trace[-delay:]
stimulus_trace_to_fit = stimulus_trace[:stimulus_trace.shape[0] + delay]
time_trace_to_fit = time_trace[:time_trace.shape[0] + delay]
fit_params = fit_compare_models(
time_vec=time_trace_to_fit,
input_vec=stimulus_trace_to_fit, output_vec=factor * output_trace_to_fit,
model2consider=model_one_comp
)
if fit_params is not None:
fit_params["output"] = factor * fit_params["output"]
fit_params["output_trace_expected"] = factor * fit_params["output_trace_expected"]
fit_params_all.append(fit_params)
bics.append(fit_params["bic"])
if len(fit_params_all) == 0:
return None
else:
best_ind = np.argmin(bics)
fit_params_best_one_comp = fit_params_all[best_ind]
delay_best = delays_to_test[best_ind]
fit_params_best_one_comp["delay_input"] = delay_best * sampling_period
cascade_model_output = dict(
fit_params_one_comp=fit_params_best_one_comp
)
fit_params_all_second = []
bics_second = []
params_fixed = ["kA"]
models_to_fit = []
param_fixed_arg = {k: [fit_params_best_one_comp[k]] * 3 for k in params_fixed}
param_fixed_arg["kF"] = [
fit_params_best_one_comp["kF"] * 0.5,
fit_params_best_one_comp["kF"],
1.5 * fit_params_best_one_comp["kF"]]
param_inits = {
k: fit_params_best_one_comp[k]
for k in fit_params_best_one_comp["model_params"] if k not in param_fixed_arg}
model_two_comp_no_delay = ModelTwoCompNoDelay(
param_inits=param_inits,
params_fixed=param_fixed_arg)
models_to_fit.append(model_two_comp_no_delay)
second_comp_delays = list(range(0, 20))
for second_comp_delay in second_comp_delays[1:]:
model_two_comp = ModelTwoComp(
delay=second_comp_delay,
param_inits=param_inits,
params_fixed=param_fixed_arg
)
models_to_fit.append(model_two_comp)
for model2fit, second_comp_delay in zip(models_to_fit, second_comp_delays):
fit_params = fit_compare_models(
time_vec=fit_params_best_one_comp["time_trace_fitted"],
input_vec=fit_params_best_one_comp["stimulus_trace_fitted"],
output_vec=factor * fit_params_best_one_comp["output_trace_expected"],
model2consider=model2fit
)
if fit_params is not None:
fit_params["output"] = factor * fit_params["output"]
fit_params["output_trace_expected"] = factor * fit_params["output_trace_expected"]
fit_params_all_second.append(fit_params)
bics_second.append(fit_params["bic"])
best_ind = None
if len(fit_params_all):
best_ind = np.argmin(bics_second)
fit_params_best_two_comp = fit_params_all_second[best_ind]
fit_params_best_two_comp["delay_second_comp"] = second_comp_delays[best_ind]
cascade_model_output["fit_params_two_comp"] = fit_params_best_two_comp
# https://imaging.mrc-cbu.cam.ac.uk/statswiki/FAQ/AICreg
# difference of 50 chosen based fitting on a test set
if fit_params_all_second[best_ind]["bic"] >= fit_params_best_one_comp["bic"] - 50:
best_ind = None
if best_ind is None:
fit_params_best = fit_params_best_one_comp
else:
fit_params_best = fit_params_all_second[best_ind]
predicted_trace, predicted_time_trace = predict_into_future(fit_params_best, factor)
cascade_model_output.update(dict(
sign_factor=factor, fit_params=fit_params_best,
predicted_time_trace=predicted_time_trace, predicted_trace=predicted_trace
))
tempdir = pl.Path(tempfile.gettempdir())
for child in tempdir.iterdir():
if child.is_dir() and child.name.startswith("tmp"):
shutil.rmtree(child)
return cascade_model_output
class PrelimChunkClassifier(object):
def __init__(self, chunk: Sequence[float]):
self.chunk = np.asarray(chunk)
self.chunk_abs = np.abs(chunk)
self.chunk_abs_argmax = np.argmax(self.chunk_abs)
def is_chunk_contaminated(self):
# is the maximum of absolute values of chunk in its middle half?
lower_limit = 0.25 * self.chunk.shape[0]
upper_limit = 0.75 * self.chunk.shape[0]
return not (lower_limit <= self.chunk_abs_argmax <= upper_limit)
def is_chunk_response_positive(self):
# is the maximum value of the chunk positive?
return self.chunk[self.chunk_abs_argmax] > 0