from scipy.optimize import least_squares
from statsmodels.stats.stattools import durbin_watson
import numpy as np


def fitKmodel(subdata, nolog=None, pfit=None, p0=None):
    """
    Fit the two-state Kalman filter model to a subject's duration-reproduction data.

    Parameters:
    - subdata : subject data (DataFrame with 'Duration' and 'Reproduction' columns)
    - nolog   : if nonzero, no logarithmic transformation is applied
                (default is 0, i.e., the model operates on log-transformed durations)
    - pfit    : list of logical values indicating which parameters to fit
                (default is [True, True, True])
    - p0      : initial parameters (must always have length 3)

    Returns:
    - px : fitted parameters, steady-state Kalman gain, AIC, and Durbin-Watson statistic

    By S. Glasauer 2019 (MATLAB), translated to Python by Strongway.
    Adds AIC and Durbin-Watson statistics.
    """
    # Handle default arguments
    if p0 is None:
        p0 = [1., 1., 0.]
    if pfit is None:
        pfit = [True, True, True]
    if nolog is None:
        nolog = 0

    # Convert pfit to logical and keep only the initial values of fitted parameters
    pfit = np.array(pfit, dtype=bool)
    p0 = np.array(p0)[pfit]

    # Lower bounds (lb) for the optimization
    lb = np.array([0, 0, -np.inf])[pfit]

    # Extract Duration and Reproduction from subdata as float arrays
    # (astype(float) also copies, so the DataFrame is not modified below)
    x = subdata['Duration'].values.astype(float)
    y = subdata['Reproduction'].values.astype(float)

    # Mark extreme reproductions (y > 3x or y < x/3) as outliers
    y[(y > 3 * x) | (y < x / 3)] = np.nan

    # Combine x, y as a two-column array (stimulus, reproduction)
    stimrep = np.vstack([x, y]).T

    # Perform the optimization using least_squares (equivalent to lsqnonlin in
    # MATLAB); pass the transformation flag and parameter mask through to the model
    result = least_squares(kmodelY, p0, args=(stimrep, nolog, pfit),
                           bounds=(lb, np.inf), method='trf')

    # Kalman filter parameters relative to measurement noise r = 1
    # (indexing assumes the default pfit, so both q parameters were fitted)
    q11 = result.x[0]
    q22 = result.x[1]
    r = 1

    # Residual sum of squares and Durbin-Watson statistic of the residuals
    rss = np.sum(result.fun ** 2)
    dw = durbin_watson(result.fun)

    # Number of parameters and of valid (non-NaN) observations
    k = len(result.x)
    n = len(result.fun)

    # Log-likelihood under Gaussian residuals
    ll = -n / 2 * (np.log(2 * np.pi) + np.log(rss / n) + 1)
    # Akaike information criterion (AIC)
    aic = 2 * k - 2 * ll

    # Steady-state solution of the Riccati equation and resulting Kalman gain
    p22 = (q22 + np.sqrt(q22 * q22 + 4 * (q11 + r) * q22)) / 2
    K = np.array([p22 + q11, p22]) / (p22 + q11 + r)

    # Return the optimized parameters, steady-state gain, AIC, and DW statistic
    return np.append(np.append(result.x, K), [aic, dw])
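
# Usage sketch for fitKmodel (hypothetical data; only the column names
# 'Duration' and 'Reproduction' are required, and the output indices below
# assume the default pfit = [True, True, True]):
#
#   import pandas as pd
#   subdata = pd.read_csv('subject01.csv')   # hypothetical file with the two columns
#   px = fitKmodel(subdata)
#   q1_r, q2_r, shift = px[0], px[1], px[2]  # fitted parameters (q1/r, q2/r, shift)
#   K1, K2 = px[3], px[4]                    # steady-state Kalman gain
#   aic, dw = px[5], px[6]                   # model AIC and Durbin-Watson statistic
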

def kmodelY(par, stimrep, nolog=1, pfit=[1, 1, 1]):
    """
    Kalman filter-based estimation of duration reproduction (two-state model).

    Parameters:
    - par     : model parameters; with pfit = [1, 1, 1],
                par = [q1/r, q2/r, cost-related shift parameter (0 for the median)]
    - stimrep : two-column array of stimulus durations and reproductions
    - nolog   : if nonzero, skip the logarithmic transformation
    - pfit    : parameter fitting mask (note: len(par) = sum(pfit))

    Returns:
    - sres : stimulus residuals (reproduction minus model response, NaNs removed)

    Computed internally: xest (estimated state), pest (estimation error
    covariance), resp (model response), perr (prediction error).

    S. Glasauer 2019/2023, translated to Python by Strongway.
    """
    # Convert pfit to a boolean array
    pfit = np.array(pfit, dtype=bool)

    # Align pfit with the number of supplied parameters (covers calls that
    # pass a short par vector together with the default all-True pfit)
    if len(par) < pfit.sum():
        pfit[len(par):] = False

    # If only the stimulus column is given, duplicate it as the response column
    if stimrep.shape[1] == 1:
        stimrep = np.tile(stimrep, (1, 2))

    # The first column is the stimulus, the second column is the response;
    # a third column could mark the start of a new sequence:
    # if stimrep.shape[1] == 2:
    #     stimrep = np.hstack((stimrep, np.zeros((stimrep.shape[0], 1))))
    #     stimrep[0, 2] = 1

    # Initialize the full parameter vector and fill in the fitted entries
    pars = np.array([0.0, 0.0, 0.0])
    pars[pfit] = par
    par = pars

    # Constants for the model
    a = 10.0
    off = 1.
    r = 1.
    q1 = par[0] * r
    q2 = par[1] * r

    # Matrices Q, P, H, and F for the Kalman filter of the two-state model
    # (details: Glasauer & Shi, 2022, Sci. Rep.,
    # https://doi.org/10.1038/s41598-022-14939-8)
    Q = np.array([[q1, 0], [0, q2]])
    P = np.array([[r, 0], [0, r]])
    H = np.array([[1., 0]])
    F = np.array([[0, 1.], [0, 1.]])

    # Apply the logarithmic transformation unless nolog is set
    if nolog:
        z = stimrep[:, 0]
    else:
        z = np.log(a * stimrep[:, 0] + off)

    # Initialize the state vector x with the first measurement
    x = np.array([[z[0]], [z[0]]])

    # Matrices for storing results
    xest = np.zeros((len(z), 2))
    pest = np.zeros((len(z), 2))
    perr = np.zeros(len(z))

    # Kalman filter estimation loop: predict, then update
    for i in range(len(z)):
        x = F @ x                             # predicted state
        P = F @ P @ F.T + Q                   # predicted error covariance
        K = P @ H.T / (H @ P @ H.T + r)       # Kalman gain
        perr[i] = z[i] - (H @ x).item()       # prediction error (innovation)
        x = x + K * perr[i]                   # updated state estimate
        P = (np.eye(2) - K @ H) @ P           # updated error covariance
        pest[i, :] = np.diag(P)
        xest[i, :] = x.reshape(-1)

    # Cost-related shift parameter, if present
    if len(par) == 3:
        sh = par[2]
    else:
        sh = 0

    # Compute the response, inverting the logarithmic transformation if needed
    if nolog:
        resp = xest[:, 0] + sh
    else:
        resp = (np.exp(xest[:, 0] + sh) - off) / a

    # Stimulus residuals: observed reproduction minus model response
    sres = stimrep[:, 1] - resp

    # Remove NaNs (excluded outlier trials) from sres
    sres = sres[np.isfinite(sres)]
    return sres
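

# Minimal self-test sketch (illustrative only, not part of the original code):
# simulate reproductions with a central-tendency bias and fit the model.
# Assumes pandas is installed; the toy generative process below is an
# assumption, not the authors' simulation.
if __name__ == '__main__':
    import pandas as pd

    rng = np.random.default_rng(1)
    durations = rng.uniform(0.5, 1.5, size=200)
    # Toy data: responses regress toward the mean duration, plus noise
    reproductions = (0.7 * durations + 0.3 * durations.mean()
                     + rng.normal(0.0, 0.05, size=200))
    subdata = pd.DataFrame({'Duration': durations,
                            'Reproduction': reproductions})

    px = fitKmodel(subdata)
    print('q1/r, q2/r, shift :', px[:3])
    print('steady-state gain :', px[3:5])
    print('AIC, Durbin-Watson:', px[5:])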