gekko_related.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. from gekko import GEKKO
  2. import pandas as pd
  3. import numpy as np
  4. # to counteract a bug in gekko: https://github.com/BYU-PRISM/GEKKO/issues/108
  5. from subprocess import TimeoutExpired
  class GekkoSolver(object):
      """Build a GEKKO model from named state variables, parameters and equations.

      The model always has one input (``m.input``, a GEKKO ``Param``) and one
      output (``m.output``).  Sub-classes customise the GEKKO object types used
      for the parameters and the output through :meth:`get_param_var_type` and
      :meth:`get_output_variable`.
      """

      # Substrings of GEKKO exception messages that mean "the solve failed"
      # rather than "the program is broken"; `solve` returns None for these.
      # The last entry covers the GEKKO bug referenced next to the
      # `TimeoutExpired` import at the top of the file.
      _SOLVE_FAILURE_MARKERS = (
          "@error: Solution Not Found",
          "Time Limit Exceeded:",
          "name 'TimeoutExpired' is not defined",
      )

      def __init__(self, state_variables, parameters, equations):
          """Create the GEKKO model.

          :param state_variables: iterable of state variable names
          :param parameters: iterable of parameter names
          :param str equations: Python source of a list of GEKKO equations in
              which the model is referred to as ``m`` (e.g. ``m.A.dt() == ...``)
          """
          self.m = GEKKO(remote=False)

          # Input, fixed and to be provided
          self.m.input = self.m.Param()

          # State variables, attached to the model under their own names so
          # that the equations can refer to them as ``m.<name>``
          self.state_variables = {}
          for var_name in state_variables:
              setattr(self.m, var_name, self.m.SV())
              self.state_variables[var_name] = getattr(self.m, var_name)

          # Parameters, fixed by default; the GEKKO factory used is chosen by
          # the (overridable) get_param_var_type()
          make_parameter = getattr(self.m, self.get_param_var_type())
          self.parameters = {}
          for param in parameters:
              setattr(self.m, param, make_parameter(0))
              self.parameters[param] = getattr(self.m, param)

          # Output
          self.m.output = self.get_output_variable()

          # Equations
          self.equations = equations
          # NOTE(security): `equations` is evaluated as Python source, so it
          # must come from trusted model definitions only.  The model is
          # exposed to the expression under the name ``m`` instead of textually
          # rewriting "m." (which also mangled any identifier ending in "m").
          self.m.Equations(eval(equations.strip(), globals(), {"m": self.m}))

      @classmethod
      def init_from_model(cls, model):
          """Alternate constructor taking a model description object.

          :param model: object providing ``get_state_variable_inits()``,
              ``get_parameter_inits(sampling_period)``, ``get_equations()`` and
              ``extra_init_steps(m)``
          :return: a new instance of ``cls``
          """
          state_variable_inits = model.get_state_variable_inits()
          # sampling period is irrelevant as we only need the parameter names
          parameter_inits = model.get_parameter_inits(sampling_period=1)

          gm = cls(
              state_variables=state_variable_inits.keys(),
              parameters=parameter_inits.keys(),
              equations=model.get_equations(),
          )
          # let the model add whatever cannot be expressed as an equation string
          model.extra_init_steps(gm.m)
          return gm

      def get_param_var_type(self):
          """Return the name of the GEKKO factory method used for parameters."""
          return "Const"

      def get_output_variable(self):
          """Create and return the GEKKO object used as model output."""
          return self.m.Var(fixed_initial=True)

      def solve(self, time_vec, input_vec, parameter_values, state_variable_init_dict, output_init):
          """Run the model for the given input and parameter values.

          :param time_vec: time points, assigned to ``m.time``
          :param input_vec: input values, assigned to ``m.input``
          :param dict parameter_values: parameter name -> value; names unknown
              to the model are ignored
          :param dict state_variable_init_dict: state variable name -> sequence
              whose LAST element is used as initial value; unknown names are
              ignored
          :param output_init: sequence whose last element initialises the output
          :return: ``pandas.Series`` with the entry ``"output"`` and one entry
              per state variable (each a numpy array), or ``None`` when GEKKO
              could not produce a solution
          :raises Exception: any GEKKO error that is not a known solve failure
          """
          self.m.output.value = output_init[-1]

          # initialize input
          self.m.input.value = input_vec

          # initialize time
          self.m.time = time_vec

          # initialize parameters
          for param_name, param_value in parameter_values.items():
              if param_name in self.parameters:
                  self.parameters[param_name].value = param_value

          # initialize state variables
          for sv_name, sv_value in state_variable_init_dict.items():
              if sv_name in self.state_variables:
                  self.state_variables[sv_name].value = sv_value[-1]

          # NOTE(review): the original comment called this "sequential dynamic
          # simulation", but the GEKKO docs list IMODE 6 as simultaneous
          # dynamic control (sequential simulation is 7).  Value kept as is.
          self.m.options.imode = 6

          # Solver time limit.  The original comment said "5 minutes", but the
          # value (interpreted by GEKKO in seconds) is one minute.  Value kept.
          self.m.options.max_time = 60

          # solve
          try:
              self.m.solve(disp=False)
          except FileNotFoundError as fnfe:
              print(fnfe)
              return None
          except Exception as e:
              # Only known solver failures are reported as "no result".  The
              # original check was missing ">= 0" on one str.find() call, which
              # made it true for practically every exception and silently
              # swallowed unrelated errors.
              message = str(e)
              if any(marker in message for marker in self._SOLVE_FAILURE_MARKERS):
                  print(f"Encountered an error during function solving/fitting with GEKKO: {e}")
                  return None
              raise

          results = {"output": np.array(self.m.output.value)}
          for sv_name, sv in self.state_variables.items():
              results[sv_name] = np.array(sv.value)
          # explicit object dtype: every entry holds a whole array
          return pd.Series(results, dtype=object)
  class GekkoFitter(GekkoSolver):
      """Variant of GekkoSolver that estimates the model parameters from data.

      Parameters are created as GEKKO ``FV`` objects and the output as a GEKKO
      ``CV`` so that :meth:`fit` can adjust the former to match measurements
      loaded into the latter.
      """

      # Substrings of GEKKO exception messages that mean "no fit was found";
      # `fit` returns None for these and re-raises everything else.
      _FIT_FAILURE_MARKERS = (
          "@error: Solution Not Found",
          "Time Limit Exceeded:",
          "name 'TimeoutExpired' is not defined",
      )

      def get_output_variable(self):
          """Create and return the GEKKO ``CV`` used as model output."""
          return self.m.CV()

      def get_param_var_type(self):
          """Return the name of the GEKKO factory method used for parameters."""
          return "FV"

      def fit(
              self, time_vec, input_vec, output_vec, state_variable_init_dict=None, parameter_lb_init_ub=None,
              ev_type=1, dead_band=None
      ):
          """Estimate the parameters so that the model output matches ``output_vec``.

          :param time_vec: time points, assigned to ``m.time``
          :param input_vec: input values, assigned to ``m.input``
          :param output_vec: measured output values to fit
          :param dict state_variable_init_dict: state variable name -> initial
              value; unknown names are ignored, ``None`` means "nothing to set"
          :param dict parameter_lb_init_ub: parameter name ->
              ``(lower bound, initial value, upper bound)``; only the parameters
              listed here are released for estimation; unknown names are
              ignored, ``None`` means "nothing to set"
          :param ev_type: value for the GEKKO ``EV_TYPE`` option
          :param dead_band: if not ``None``, value for the GEKKO ``MEAS_GAP``
              option
          :return: ``pandas.Series`` with the entry ``"output"``, one array per
              state variable and one scalar per parameter (first element of the
              parameter's value), or ``None`` when GEKKO could not produce a fit
          :raises Exception: any GEKKO error that is not a known fit failure
          """
          # The signature advertises None defaults; treat them as "empty"
          # instead of failing on `.items()`.
          if state_variable_init_dict is None:
              state_variable_init_dict = {}
          if parameter_lb_init_ub is None:
              parameter_lb_init_ub = {}

          # output fixed, given, to be used for estimation
          self.m.output.status = 0
          self.m.output.fstatus = 1
          self.m.output.value = output_vec

          # initialize input
          self.m.input.value = input_vec

          # initialize time
          self.m.time = time_vec

          # initialize parameters initial value, lower and upper bounds
          for param_name, lb_init_ub in parameter_lb_init_ub.items():
              if param_name in self.parameters:
                  parameter = self.parameters[param_name]
                  parameter.status = 1
                  parameter.fstatus = 0
                  parameter.lower, parameter.value, parameter.upper = lb_init_ub

          # initialize state variables
          for sv_name, sv_value in state_variable_init_dict.items():
              if sv_name in self.state_variables:
                  self.state_variables[sv_name].value = sv_value

          self.m.options.imode = 5  # moving horizon estimate method
          self.m.options.ev_type = ev_type
          if dead_band is not None:
              self.m.options.meas_gap = dead_band

          # Solver time limit.  The original comment said "5 minutes", but the
          # value (interpreted by GEKKO in seconds) is one minute.  Value kept.
          self.m.options.max_time = 60

          try:
              self.m.solve(disp=False)
          except FileNotFoundError as fnfe:
              print(fnfe)
              return None
          except Exception as e:
              message = str(e)
              if any(marker in message for marker in self._FIT_FAILURE_MARKERS):
                  print(e)
                  return None
              raise

          results = {"output": np.array(self.m.output)}
          for sv_name, sv in self.state_variables.items():
              results[sv_name] = np.array(sv.value)
          for param_name, param in self.parameters.items():
              results[param_name] = param.value[0]
          # explicit object dtype: entries mix whole arrays and scalars
          return pd.Series(results, dtype=object)
  class ModelOneComp(object):
      """Model description with the state variables ``A`` and ``F``.

      Provides the equations (as source text in which the GEKKO model is called
      ``m``), the initial state variable values and the parameter bounds.
      """

      def extra_init_steps(self, m):
          """Hook called with the GEKKO model after it was built; does nothing.

          Defined as a regular method rather than a closure stored on the
          instance in ``__init__``, so that instances stay picklable.
          Sub-classes may still replace it by assigning an instance attribute.

          :param m: the GEKKO model
          """

      def get_equations(self):
          """Return the model equations as Python source of a list."""
          return """
          [
              m.A.dt() == - m.A / m.kA + m.input,
              m.F.dt() == - m.F / m.kF + m.kAF * m.A,
              m.output == m.F
          ]
          """

      def get_state_variable_inits(self):
          """Return a new dict: state variable name -> initial value."""
          return {"F": 0, "A": 0}

      def get_parameter_inits(self, sampling_period):
          """Return a new dict: parameter name -> 3-element array.

          fit_compare_models hands this dict to GekkoFitter.fit, which unpacks
          each array as ``(lower bound, initial value, upper bound)``.

          :param sampling_period: factor every array is multiplied with
          """
          return {
              "kF": np.array([0.01, 20, 200]) * sampling_period,
              "kAF": np.array([1e-3, 1, 200]) * sampling_period,
              "kA": np.array([0.01, 1, 200]) * sampling_period,
          }
  class ModelTwoCompNoDelay(ModelOneComp):
      """ModelOneComp extended by a second output component ``S`` driven by ``A``.

      Adds the state variable ``S`` and the parameters ``kS`` and ``kAS``; the
      output is ``F + S`` and ``kS > kF`` is added as a constraint.
      """

      def __init__(self, param_inits=None, params_fixed=None):
          """
          :param dict param_inits: parameter name -> value that replaces the
              default INITIAL value (middle element) of that parameter
          :param dict params_fixed: parameter name -> value that replaces the
              whole default entry of that parameter
          """
          # The no-op `extra_init_steps` hook is inherited from ModelOneComp,
          # so it is not redefined here.
          super().__init__()
          self.param_inits = param_inits
          self.params_fixed = params_fixed

      def get_equations(self):
          """Return the model equations as Python source of a list."""
          return """
          [
              m.A.dt() == - m.A / m.kA + m.input,
              m.F.dt() == - m.F / m.kF + m.kAF * m.A,
              m.S.dt() == - m.S / m.kS - m.kAS * m.A,
              m.output == m.F + m.S,
              m.kS > m.kF
          ]
          """

      def get_state_variable_inits(self):
          """Return a new dict: state variable name -> initial value."""
          inits = super().get_state_variable_inits()
          inits["S"] = 0
          return inits

      def get_parameter_inits(self, sampling_period):
          """Return a new dict: parameter name -> ``[lower, initial, upper]``.

          Defaults are multiplied with ``sampling_period``; afterwards the
          overrides from ``param_inits`` and ``params_fixed`` are applied
          unscaled.  Override names that are not model parameters are ignored.

          :param sampling_period: factor the default arrays are multiplied with
          """
          inits = super().get_parameter_inits(sampling_period)
          # Float literals on purpose: with integer literals and an integer
          # sampling period these arrays had an integer dtype, and a fractional
          # value written into them below was silently truncated.
          inits["kS"] = np.array([1.0, 100.0, 200.0]) * sampling_period
          inits["kAS"] = np.array([-200.0, 1.0, 200.0]) * sampling_period

          if self.param_inits is not None:
              for name, start_value in self.param_inits.items():
                  if name in inits:
                      inits[name][1] = start_value  # use specified value as starting point

          if self.params_fixed is not None:
              for name, lb_init_ub in self.params_fixed.items():
                  if name in inits:
                      inits[name] = lb_init_ub  # use specified starting points and limits

          return inits
  class ModelTwoComp(ModelTwoCompNoDelay):
      """ModelTwoCompNoDelay in which ``S`` is driven by a delayed copy of ``A``.

      Adds the state variable ``A_delayed``, which replaces ``A`` in the
      equation for ``S``.
      """

      def __init__(self, delay, param_inits=None, params_fixed=None):
          """
          :param delay: delay between ``A`` and ``A_delayed``, passed to GEKKO's
              ``m.delay`` as its third argument (presumably a number of time
              steps - confirm against the GEKKO docs)
          :param dict param_inits: forwarded to ModelTwoCompNoDelay
          :param dict params_fixed: forwarded to ModelTwoCompNoDelay
          """
          super().__init__(param_inits=param_inits, params_fixed=params_fixed)
          self.delay = delay

          def extra_init_steps(m):
              if delay > 0:
                  m.delay(m.A, m.A_delayed, delay)
              else:
                  # Without a delay A_delayed must still be tied to A.
                  # Previously nothing defined it in this case, which left a
                  # free variable in the equation for S.
                  m.Equation(m.A_delayed == m.A)

          # Stored on the instance (after super().__init__()) so that it
          # replaces whatever hook the parent classes provide.
          self.extra_init_steps = extra_init_steps

      def get_equations(self):
          """Return the model equations as Python source of a list."""
          return """
          [
              m.A.dt() == - m.A / m.kA + m.input,
              m.F.dt() == - m.F / m.kF + m.kAF * m.A,
              m.S.dt() == - m.S / m.kS - m.kAS * m.A_delayed,
              m.output == m.F + m.S,
              m.kS > m.kF
          ]
          """

      def get_state_variable_inits(self):
          """Return a new dict: state variable name -> initial value."""
          inits = super().get_state_variable_inits()
          inits["A_delayed"] = 0
          return inits
  def fit_compare_models(time_vec, output_vec, model2consider, input_vec=None, dead_band=None):
      """
      Fits the given model to output_vec over time_vec and returns the fitted traces and
      parameters together with the BIC of the fit. Loosely based on
      "Temporal Responses of C. elegans Chemosensory Neurons Are Preserved in Behavioral Dynamics"
      # https://doi.org/10.1016/j.neuron.2013.11.020, Figure 3

      The fit uses ev_type=2, see https://gekko.readthedocs.io/en/latest/global.html#ev-type

      :param Sequence time_vec: time values; the sampling period is taken as the difference
          between the first two values
      :param Sequence output_vec: ca response values
      :param model2consider: model description object, e.g. ModelOneComp or ModelTwoComp
      :param Sequence input_vec: input values, with maximum scaled to 1. Required, despite
          the default of None.
      :param dead_band: noise half band-width of output signal. Fitting cost in this band will be 0.
          See https://gekko.readthedocs.io/en/latest/tuning_params.html#meas-gap
      :return: None if the fit failed, else the pandas.Series returned by GekkoFitter.fit
          extended by the entries "model_params", "model_name", "model", "bic",
          "stimulus_trace_fitted", "output_trace_expected" and "time_trace_fitted"
      :raises ValueError: if input_vec is None
      """
      if input_vec is None:
          # previously this surfaced later as an AttributeError on None.copy()
          raise ValueError("input_vec is required for fitting, got None")

      # Converting to fresh arrays gives the fitter its own copies and also
      # accepts plain sequences, which have no array-style `.copy()`.
      time_arr = np.array(time_vec)
      input_arr = np.array(input_vec)
      output_vec = np.asarray(output_vec)

      sampling_period = time_arr[1] - time_arr[0]

      gm = GekkoFitter.init_from_model(model=model2consider)
      state_variable_inits = model2consider.get_state_variable_inits()
      parameter_inits = model2consider.get_parameter_inits(sampling_period=sampling_period)

      fit_params = gm.fit(
          time_vec=time_arr, input_vec=input_arr, output_vec=output_vec.copy(),
          state_variable_init_dict=state_variable_inits,
          parameter_lb_init_ub=parameter_inits,
          ev_type=2, dead_band=dead_band
      )

      if fit_params is not None:
          # residual sum of squares
          rss = np.power(output_vec - fit_params["output"], 2).sum()

          # definition from https://en.wikipedia.org/wiki/Bayesian_information_criterion
          n = output_vec.shape[0]
          k = len(parameter_inits)
          bic = n * np.log(rss / n) + k * np.log(n)

          fit_params["model_params"] = list(parameter_inits.keys())
          fit_params["model_name"] = model2consider.__class__.__name__
          fit_params["model"] = model2consider
          fit_params["bic"] = bic
          fit_params["stimulus_trace_fitted"] = input_vec
          fit_params["output_trace_expected"] = output_vec
          fit_params["time_trace_fitted"] = time_vec

      return fit_params
  260. # if len(bics):
  261. # best_model_ind = np.argmin(bics)
  262. #
  263. # return \
  264. # fit_params_all[best_model_ind], gms_all[best_model_ind], \
  265. # {f"{model.__class__.__name__};bic={bic:2.3g}": fit_params["output"]
  266. # for model, fit_params, bic in zip(models_to_consider, fit_params_all, bics)}
  267. # else:
  268. # return None, None, None