recover_model_diff_weightf.py

#!/usr/bin/env python
# coding=utf-8
"""
@author: yannansu
@created at: 28.09.22 10:25

Recover model profiles with MLE using lmfit.
Version 2: different pdf_c profiles for I and E conditions.
"""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.stats import gamma
# from scipy.optimize import least_squares, minimize, differential_evolution, basinhopping, shgo, brute
from lmfit import Parameters, minimize, Minimizer, fit_report
import json
import time
import datetime
def gaussian(x, mu, sigma):
    # Discretized Gaussian pdf on grid x, normalized to sum to 1
    y = 1. / (np.sqrt(2. * np.pi) * sigma) * np.exp(-np.power((x - mu) / sigma, 2.) / 2)
    return y / np.sum(y)
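
# Usage sketch (hypothetical values, commented out so the script behavior is unchanged):
# x_demo = np.arange(-60., 60., .5)
# p_demo = gaussian(x_demo, mu=0., sigma=3.)
# assert np.isclose(p_demo.sum(), 1.)  # the discretized pdf is normalized over the grid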
def mix_gamma(loc, a, scale):
    # Mirror-symmetric gamma mixture on the stimulus grid ss (defined in the run section below):
    # one gamma for cw (+ss) and its mirror for ccw (-ss), normalized to sum to 1
    w = np.zeros(shape=(2, nss))
    w[0, :] = gamma.pdf(ss, loc=loc, a=a, scale=scale)   # cw
    w[1, :] = gamma.pdf(-ss, loc=loc, a=a, scale=scale)  # ccw
    w /= np.sum(w)
    return np.sum(w, axis=0)
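
# Usage sketch (hypothetical parameter values; requires ss/nss from the run section below):
# pdf_c_demo = mix_gamma(loc=-.5, a=3, scale=10)
# plt.plot(ss, pdf_c_demo)  # symmetric around 0: mirrored gammas for cw/ccw
# plt.xlabel('relative stimulus'); plt.show()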
def simul_estimate(stim, n_sample, pars):
    # Simulate the normalized estimate distribution over the grid ss for each stimulus.
    sigma, loc, a, scale, truc_p, motor_bias = pars[0], pars[1], pars[2], pars[3], pars[4], pars[5]
    pdf_c = mix_gamma(loc, a, scale)
    # one-sided (truncated) versions of the context prior
    pdf_c_cw = np.hstack([np.repeat(0, sum(ss < 0)), pdf_c[ss >= 0]])
    pdf_c_ccw = np.hstack([pdf_c[ss < 0], np.repeat(0, sum(ss >= 0))])
    if sampling:
        non_truc_n = int(n_sample * (1 - truc_p))
        np.random.seed(i_btrp)
        measure_samples = np.hstack([np.random.normal(s_i, sigma, n_sample) for s_i in stim])
        pdf_s = np.vstack([gaussian(x=ss, mu=m, sigma=sigma) for m in measure_samples])
        samples_size = np.array([len(stim), n_sample])
        # reshape to n_stim x n_sample x n_ss
        pdf_s = np.reshape(pdf_s, (samples_size[0], samples_size[1], len(ss)))
        measure_samples = np.reshape(measure_samples, samples_size)
        along_ax = 1
        if cmb_type == 'mix':
            pdf_cmb_marginal = np.vstack(
                [pdf_s[sidx[0], sidx[1], :] * (1 - w_c) + pdf_c * w_c if sidx[1] < non_truc_n
                 else (pdf_s[sidx[0], sidx[1], :] * (1 - w_c) + pdf_c_ccw * w_c if s <= 0
                       else pdf_s[sidx[0], sidx[1], :] * (1 - w_c) + pdf_c_cw * w_c)
                 for sidx, s in np.ndenumerate(measure_samples)])
        elif cmb_type == 'mul':
            pdf_cmb_marginal = np.vstack(
                [pdf_s[sidx[0], sidx[1], :] * pdf_c if sidx[1] < non_truc_n
                 else (pdf_s[sidx[0], sidx[1], :] * pdf_c_ccw if s <= 0
                       else pdf_s[sidx[0], sidx[1], :] * pdf_c_cw)
                 for sidx, s in np.ndenumerate(measure_samples)])
        else:
            raise ValueError("combination type not defined!")
        pdf_cmb_marginal = np.reshape(pdf_cmb_marginal, (samples_size[0], samples_size[1], len(ss)))
    else:
        pdf_s = np.vstack([gaussian(x=ss, mu=s, sigma=sigma) for s in stim])
        # pdf_s = np.vstack([gaussian(x=ss, mu=m, sigma=sigma) for m in mm])
        # measure_samples = stim
        along_ax = 0
        if truc_p not in [0, 1]:
            raise ValueError("Truncation portion must be 0 or 1 due to no sampling!")
        if cmb_type == 'mix':
            if truc_p == 0:
                pdf_cmb_marginal = np.vstack([pdf_s[sidx, :] * (1 - w_c) + pdf_c * w_c
                                              for sidx, s in np.ndenumerate(stim)])
            else:
                pdf_cmb_marginal = np.vstack([pdf_s[sidx, :] * (1 - w_c) + pdf_c_ccw * w_c if s <= 0
                                              else pdf_s[sidx, :] * (1 - w_c) + pdf_c_cw * w_c
                                              for sidx, s in np.ndenumerate(stim)])
        elif cmb_type == 'mul':
            if truc_p == 0:
                pdf_cmb_marginal = np.vstack([pdf_s[sidx, :] * pdf_c for sidx, s in np.ndenumerate(stim)])
            else:
                pdf_cmb_marginal = np.vstack([pdf_s[sidx, :] * pdf_c_ccw if s <= 0 else pdf_s[sidx, :] * pdf_c_cw
                                              for sidx, s in np.ndenumerate(stim)])
        else:
            raise ValueError("combination type not defined!")
    pdf_cmb = np.array([(p.T / np.sum(p, axis=along_ax)).T for p in pdf_cmb_marginal])
    # MAP = np.array([np.max(p, axis=along_ax) for p in pdf_cmb])
    # MAP_s = np.array([ss[np.argmax(p, axis=along_ax)] for p in pdf_cmb])
    # median_s = np.array([ss[np.argmin(np.abs(np.cumsum(p, axis=along_ax) - .5), axis=along_ax)] for p in pdf_cmb])
    # mean_s = np.array([np.sum(ss * p, axis=along_ax) for p in pdf_cmb])
    #
    # pred = mean_s + motor_bias
    if sampling:
        # average the per-sample posteriors over measurement samples
        pdf_cmb = np.mean(pdf_cmb, axis=1)
    return pdf_cmb
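
# Usage sketch (hypothetical values; sampling/cmb_type/w_c and ss come from the run section below):
# pdf_demo = simul_estimate(np.array([-10., 0., 10.]), n_sample=100,
#                           pars=[3., -.5, 3., 10., 0., 0.])  # sigma, loc, a, scale, truc_p, motor_bias
# pdf_demo.shape  # -> (3, nss): one normalized pdf per stimulus (samples averaged out if sampling=True)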
def find_nearest(arr, vals):
    # Index of the grid point in arr closest to each value in vals
    ids = [(np.abs(arr - v)).argmin() for v in vals]
    return np.array(ids)
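
# Usage sketch: map continuous responses onto grid indices for the likelihood lookup, e.g.
# find_nearest(np.arange(-60., 60., .5), [0.26, -3.1])  # -> array([121, 114]), i.e. grid points 0.5 and -3.0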
def negloglik(pars, df):
    # Negative log-likelihood of the observed responses given the model parameters.
    # sigma_l, sigma_h = pars[0], pars[1]
    # loc_i, loc_e = pars[2], pars[3]
    # a_i, a_e = pars[4], pars[5]
    # scale_i, scale_e = pars[6], pars[7]
    # truc_p_i, truc_p_e = pars[8]/10., pars[9]/10.
    # motor_bias_i, motor_bias_e = pars[10], pars[11]
    # if I and E only differ in truncation portion
    vals = pars.valuesdict()
    sigma_l, sigma_h = vals['sigma_l'], vals['sigma_h']
    loc_i, loc_e = vals['loc_i'], vals['loc_e']
    a_i, a_e = vals['a_i'], vals['a_e']
    scale_i, scale_e = vals['scale_i'], vals['scale_e']
    truc_p_i, truc_p_e = vals['truc_p_i'] / 10., vals['truc_p_e'] / 10.
    motor_bias_i, motor_bias_e = vals['motor_bias'], vals['motor_bias']  # shared motor bias
    noises = np.sort(df['noise'].unique())
    unique_stim = df['relative_stimulus'].unique()
    # trials per stimulus per (noise x condition) cell, assuming a balanced design
    n_obs = int(df['relative_stimulus'].value_counts()[0] / 4)
    lik_i = []
    lik_e = []
    for idx, sigma in enumerate([sigma_l, sigma_h]):
        df_i = df[(df.condition == 'I') & (df.noise == noises[idx])]
        df_e = df[(df.condition == 'E') & (df.noise == noises[idx])]
        # stim_i = df_i['relative_stimulus'].to_numpy()
        # stim_e = df_e['relative_stimulus'].to_numpy()
        obs_i = np.reshape(df_i['relative_response'].to_numpy(), (len(unique_stim), n_obs)) - motor_bias_i
        obs_e = np.reshape(df_e['relative_response'].to_numpy(), (len(unique_stim), n_obs)) - motor_bias_e
        # bootstrap sampling from measured data
        np.random.seed(i_btrp)
        # print(np.random.normal(1))
        obs_i_btrp = np.apply_along_axis(lambda x: np.random.choice(x, replace=True, size=n_obs), arr=obs_i, axis=1)
        obs_e_btrp = np.apply_along_axis(lambda x: np.random.choice(x, replace=True, size=n_obs), arr=obs_e, axis=1)
        # obs_i_btrp = obs_i
        # obs_e_btrp = obs_e
        pdf_cmb_i = simul_estimate(unique_stim, n_sample, [sigma, loc_i, a_i, scale_i, truc_p_i, motor_bias_i])
        pdf_cmb_e = simul_estimate(unique_stim, n_sample, [sigma, loc_e, a_e, scale_e, truc_p_e, motor_bias_e])
        # look up each bias-corrected response on the nearest grid point of the predicted pdf
        lik_i.append([np.sum(np.log10(pdf[find_nearest(ss, obs)])) for obs, pdf in zip(obs_i_btrp, pdf_cmb_i)])
        lik_e.append([np.sum(np.log10(pdf[find_nearest(ss, obs)])) for obs, pdf in zip(obs_e_btrp, pdf_cmb_e)])
    print(-(np.sum(lik_i) + np.sum(lik_e)))
    return -(np.sum(lik_i) + np.sum(lik_e))
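
# Smoke-test sketch (hypothetical, commented out): evaluate the objective once at the initial
# parameter values before fitting; requires dat and i_btrp from the run section below.
# i_btrp = 0
# print(negloglik(params, dat))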
# def negloglik(pars, stim, obs):
#     n_stim = len(stim)
#     residuals = obs - simul_estimate(stim, pars)
#     ll = -(n_stim * 1 / 2) * (1 + np.log(2 * np.pi)) - (n_stim / 2) * np.log(residuals.dot(residuals) / n_stim)
#     # print(-ll/n_stim)
#     return -ll/n_stim
  140. """
  141. ===========================================================
  142. Run modeling from here
  143. ===========================================================
  144. """
# save directory
save_dir = 'data_analysis/recover_model_v4/'
# stimulus grid
ss = np.arange(-60., 60., .5)
nss = len(ss)
# measurement grid (unused)
# mm = np.arange(-80, 80, .5)
# nmm = len(mm)
# model settings
n_sample = 100    # measurement samples per stimulus (only used if sampling=True)
sampling = False  # if True, Monte-Carlo sample measurements; if False, use the noise-free grid
cmb_type = 'mul'  # 'mul': multiply pdf_s and pdf_c; 'mix': weighted mixture with weight w_c
w_c = None        # mixture weight for cmb_type='mix' (unused for 'mul')
n_btrp = 100      # number of bootstrap fits
  159. # """
  160. model_cfg = {
  161. 'date': datetime.datetime.now().strftime('%c'),
  162. 'n_sample': n_sample,
  163. 'sampling': sampling,
  164. 'cmb_type': cmb_type,
  165. 'w_c': None,
  166. 'n_btrp': n_btrp
  167. }
  168. json.dump(model_cfg, open(save_dir + f'model_cfg.json', 'w'))
# param settings
params = Parameters()
# add with tuples: (NAME, VALUE, VARY, MIN, MAX, EXPR, BRUTE_STEP)
params.add_many(
    ('sigma_l', 3, True, 1, 20, None, .1),
    ('sigma_h', 5, True, 1, 20, None, .1),
    ('loc_i', -.5, True, -10, 10, None, .1),
    ('loc_e', -.5, True, -10, 10, None, .1),
    ('a_i', 3, True, 2, 30, None, .1),
    ('a_e', 3, True, 2, 30, None, .1),
    ('scale_i', 10, True, 1, 50, None, .1),
    ('scale_e', 10, True, 1, 50, None, .1),
    ('truc_p_i', 0, False, 0, 10, None, .1),
    ('truc_p_e', 0, False, 0, 10, None, .1),
    ('motor_bias', -.5, True, -5, 5, None, .1),
)
params.pretty_print()
params.dump(open(save_dir + 'param_cfg.json', 'w'))
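
# Note: each tuple above maps onto lmfit keyword arguments in order, e.g.
# ('sigma_l', 3, True, 1, 20, None, .1) is equivalent to
# params.add('sigma_l', value=3, vary=True, min=1, max=20, expr=None, brute_step=.1).
# truc_p_i / truc_p_e are held fixed at 0 here (vary=False) and are rescaled by /10 inside negloglik.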
all_dat = pd.read_csv('data/dat.csv')
# for sub in ['sPool']:
for sub in ['S1', 'S2', 'S3', 'S4', 'S5']:
    if sub != 'sPool':
        dat = all_dat.query("sub == @sub").sort_values(by=['noise', 'condition', 'relative_stimulus'])
    else:
        dat = all_dat.sort_values(by=['noise', 'condition', 'relative_stimulus'])
    res_list = []
    stim_finer = np.arange(-20, 20, .5)
    res_pdf = np.zeros(shape=[4, n_btrp, len(stim_finer), nss])
    start = time.time()
    # bootstrapping
    for i_btrp in range(n_btrp):
        # np.random.seed(i_btrp)
        fitter = Minimizer(negloglik, params, fcn_args=[dat], nan_policy='omit')
        res = fitter.minimize(method='nelder')
        print(fit_report(res))
        print(f"time: {time.time() - start}")
        res_df = pd.DataFrame([{'sub': sub,
                                'i_btrp': i_btrp,
                                'sigma_l': res.params['sigma_l'].value,
                                'sigma_h': res.params['sigma_h'].value,
                                'loc_i': res.params['loc_i'].value,
                                'loc_e': res.params['loc_e'].value,
                                'a_i': res.params['a_i'].value,
                                'a_e': res.params['a_e'].value,
                                'scale_i': res.params['scale_i'].value,
                                'scale_e': res.params['scale_e'].value,
                                'truc_p_i': res.params['truc_p_i'].value / 10.,
                                'truc_p_e': res.params['truc_p_e'].value / 10.,
                                'motor_bias': res.params['motor_bias'].value,
                                'negloglik': negloglik(res.params, dat),
                                'chi-square': res.chisqr,
                                'aic': res.aic,
                                'bic': res.bic
                                }])
        # res_df.to_csv(save_dir + f'{sub}_pars.csv', index=False)
        res_list.append(res_df)
        # low+I, low+E, high+I, high+E
        res_pdf[0, i_btrp, :, :] = simul_estimate(stim_finer, n_sample, res_df[
            ['sigma_l', 'loc_i', 'a_i', 'scale_i', 'truc_p_i', 'motor_bias']].values.flatten())
        res_pdf[1, i_btrp, :, :] = simul_estimate(stim_finer, n_sample, res_df[
            ['sigma_l', 'loc_e', 'a_e', 'scale_e', 'truc_p_e', 'motor_bias']].values.flatten())
        res_pdf[2, i_btrp, :, :] = simul_estimate(stim_finer, n_sample, res_df[
            ['sigma_h', 'loc_i', 'a_i', 'scale_i', 'truc_p_i', 'motor_bias']].values.flatten())
        res_pdf[3, i_btrp, :, :] = simul_estimate(stim_finer, n_sample, res_df[
            ['sigma_h', 'loc_e', 'a_e', 'scale_e', 'truc_p_e', 'motor_bias']].values.flatten())
    # Save parameters
    pd.concat(res_list, ignore_index=True).to_csv(save_dir + f'{sub}_pars.csv', index=False)
    # Save simulation results
    np.save(save_dir + f"{sub}_pdf_cmb.npy", res_pdf)
# """