# -*- coding: utf-8 -*-
"""
GPFA core functionality.

:copyright: Copyright 2015-2019 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""

from __future__ import division, print_function, unicode_literals

import time
import warnings

import numpy as np
import scipy.linalg as linalg
import scipy.optimize as optimize
import scipy.sparse as sparse
from sklearn.decomposition import FactorAnalysis
from tqdm import trange

from . import gpfa_util


def fit(seqs_train, x_dim=3, bin_width=20.0, min_var_frac=0.01, em_tol=1.0E-8,
        em_max_iters=500, tau_init=100.0, eps_init=1.0E-3, freq_ll=5,
        verbose=False):
    """
    Fit the GPFA model with the given training data.

    Parameters
    ----------
    seqs_train : np.recarray
        training data structure, whose n-th element (corresponding to
        the n-th experimental trial) has fields
        T : int
            number of bins
        y : (#units, T) np.ndarray
            neural data
    x_dim : int, optional
        state dimensionality
        Default: 3
    bin_width : float, optional
        spike bin width in msec
        Default: 20.0
    min_var_frac : float, optional
        fraction of overall data variance for each observed dimension to set
        as the private variance floor. This is used to combat Heywood cases,
        where ML parameter learning returns one or more zero private
        variances.
        Default: 0.01
        (See Martin & McDonald, Psychometrika, Dec 1975.)
    em_tol : float, optional
        stopping criterion for EM
        Default: 1e-8
    em_max_iters : int, optional
        number of EM iterations to run
        Default: 500
    tau_init : float, optional
        GP timescale initialization in msec
        Default: 100
    eps_init : float, optional
        GP noise variance initialization
        Default: 1e-3
    freq_ll : int, optional
        data likelihood is computed at every freq_ll EM iterations.
        freq_ll = 1 means that data likelihood is computed at every
        iteration.
        Default: 5
    verbose : bool, optional
        specifies whether to display status messages
        Default: False

    Returns
    -------
    parameter_estimates : dict
        Estimated model parameters. When the GPFA method is used, the
        following parameters are contained:
        covType : {'rbf', 'tri', 'logexp'}
            type of GP covariance
        gamma : np.ndarray of shape (1, #latent_vars)
            related to GP timescales by 'bin_width / sqrt(gamma)'
        eps : np.ndarray of shape (1, #latent_vars)
            GP noise variances
        d : np.ndarray of shape (#units, 1)
            observation mean
        C : np.ndarray of shape (#units, #latent_vars)
            mapping between the neuronal data space and the latent variable
            space
        R : np.ndarray of shape (#units, #units)
            observation noise covariance
    fit_info : dict
        Information about the fitting process and the parameters used there:
        iteration_time : list
            containing the runtime for each iteration step in the EM
            algorithm.
        log_likelihoods : list
            log likelihood after each EM iteration.
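
    Examples
    --------
    A minimal usage sketch; this function is normally driven by the GPFA
    class in the parent package, and the synthetic recarray built below is
    only an illustration of the expected input layout, not real data:

    >>> import numpy as np
    >>> rng = np.random.RandomState(0)
    >>> seqs = np.array(
    ...     [(50, rng.poisson(5.0, size=(10, 50)).astype(float))
    ...      for _ in range(4)],
    ...     dtype=[('T', int), ('y', object)])  # 4 trials, 10 units, 50 bins
    >>> params_est, fit_info = fit(seqs, x_dim=2,
    ...                            em_max_iters=10)  # doctest: +SKIP
    >>> sorted(params_est.keys())                    # doctest: +SKIP
    ['C', 'R', 'covType', 'd', 'eps', 'gamma', 'notes']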
  84. """
    # For compute efficiency, train on equal-length segments of trials
    seqs_train_cut = gpfa_util.cut_trials(seqs_train)
    if len(seqs_train_cut) == 0:
        warnings.warn('No segments extracted for training. Defaulting to '
                      'segLength=Inf.')
        seqs_train_cut = gpfa_util.cut_trials(seqs_train, seg_length=np.inf)

    # ==================================
    # Initialize state model parameters
    # ==================================
    params_init = dict()
    params_init['covType'] = 'rbf'
    # GP timescale
    # Assume binWidth is the time step size.
    params_init['gamma'] = (bin_width / tau_init) ** 2 * np.ones(x_dim)
    # GP noise variance
    params_init['eps'] = eps_init * np.ones(x_dim)

    # ========================================
    # Initialize observation model parameters
    # ========================================
    print('Initializing parameters using factor analysis...')
    y_all = np.hstack(seqs_train_cut['y'])
    fa = FactorAnalysis(n_components=x_dim, copy=True,
                        noise_variance_init=np.diag(np.cov(y_all, bias=True)))
    fa.fit(y_all.T)
    params_init['d'] = y_all.mean(axis=1)
    params_init['C'] = fa.components_.T
    params_init['R'] = np.diag(fa.noise_variance_)

    # Define parameter constraints
    params_init['notes'] = {
        'learnKernelParams': True,
        'learnGPNoise': False,
        'RforceDiagonal': True,
    }

    # =====================
    # Fit model parameters
    # =====================
    print('\nFitting GPFA model...')

    params_est, seqs_train_cut, ll_cut, iter_time = em(
        params_init, seqs_train_cut, min_var_frac=min_var_frac,
        max_iters=em_max_iters, tol=em_tol, freq_ll=freq_ll, verbose=verbose)

    fit_info = {'iteration_time': iter_time, 'log_likelihoods': ll_cut}

    return params_est, fit_info


def em(params_init, seqs_train, max_iters=500, tol=1.0E-8, min_var_frac=0.01,
       freq_ll=5, verbose=False):
    """
    Fits GPFA model parameters using expectation-maximization (EM) algorithm.

    Parameters
    ----------
    params_init : dict
        GPFA model parameters at which EM algorithm is initialized
        covType : {'rbf', 'tri', 'logexp'}
            type of GP covariance
        gamma : np.ndarray of shape (1, #latent_vars)
            related to GP timescales by
            'bin_width / sqrt(gamma)'
        eps : np.ndarray of shape (1, #latent_vars)
            GP noise variances
        d : np.ndarray of shape (#units, 1)
            observation mean
        C : np.ndarray of shape (#units, #latent_vars)
            mapping between the neuronal data space and the
            latent variable space
        R : np.ndarray of shape (#units, #units)
            observation noise covariance
    seqs_train : np.recarray
        training data structure, whose n-th entry (corresponding to the n-th
        experimental trial) has fields
        T : int
            number of bins
        y : np.ndarray (yDim x T)
            neural data
    max_iters : int, optional
        number of EM iterations to run
        Default: 500
    tol : float, optional
        stopping criterion for EM
        Default: 1e-8
    min_var_frac : float, optional
        fraction of overall data variance for each observed dimension to set
        as the private variance floor. This is used to combat Heywood cases,
        where ML parameter learning returns one or more zero private
        variances.
        Default: 0.01
        (See Martin & McDonald, Psychometrika, Dec 1975.)
    freq_ll : int, optional
        data likelihood is computed at every freq_ll EM iterations.
        freq_ll = 1 means that data likelihood is computed at every
        iteration.
        Default: 5
    verbose : bool, optional
        specifies whether to display status messages
        Default: False

    Returns
    -------
    params_est : dict
        GPFA model parameter estimates, returned by EM algorithm (same
        format as params_init)
    seqs_latent : np.recarray
        a copy of the training data structure, augmented with the new
        fields:
        latent_variable : np.ndarray of shape (#latent_vars, #bins)
            posterior mean of latent variables at each time bin
        Vsm : np.ndarray of shape (#latent_vars, #latent_vars, #bins)
            posterior covariance between latent variables at each
            timepoint
        VsmGP : np.ndarray of shape (#bins, #bins, #latent_vars)
            posterior covariance over time for each latent
            variable
    ll : list
        list of log likelihoods after each EM iteration
    iter_time : list
        list of computation times (in seconds) for each EM iteration
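
    Notes
    -----
    Reading aid for the M-step below (a restatement of the code, not extra
    functionality): with E-step sufficient statistics
    Sxx = sum_t (Vsm_t + x_t x_t'), Sx = sum_t x_t, Sy = sum_t y_t and
    Syx = sum_t y_t x_t', the loading matrix and observation mean are
    updated jointly via the normal equations

        [C d] = [Syx Sy] [[Sxx, Sx], [Sx', sum(T)]]^(-1)

    (the right-division performed by `gpfa_util.rdiv`), and R is
    re-estimated from the residuals, optionally forced to be diagonal and
    floored at `min_var_frac` times the per-dimension data variance.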
  196. """
    params = params_init
    t = seqs_train['T']
    y_dim, x_dim = params['C'].shape
    lls = []
    ll_old = ll_base = ll = 0.0
    iter_time = []
    var_floor = min_var_frac * np.diag(np.cov(np.hstack(seqs_train['y'])))
    seqs_latent = None

    # Loop once for each iteration of EM algorithm
    for iter_id in trange(1, max_iters + 1, desc='EM iteration',
                          disable=not verbose):
        if verbose:
            print()
        tic = time.time()
        get_ll = (np.fmod(iter_id, freq_ll) == 0) or (iter_id <= 2)

        # ==== E STEP =====
        if not np.isnan(ll):
            ll_old = ll
        seqs_latent, ll = exact_inference_with_ll(seqs_train, params,
                                                  get_ll=get_ll)
        lls.append(ll)

        # ==== M STEP ====
        sum_p_auto = np.zeros((x_dim, x_dim))
        for seq_latent in seqs_latent:
            sum_p_auto += seq_latent['Vsm'].sum(axis=2) \
                + seq_latent['latent_variable'].dot(
                    seq_latent['latent_variable'].T)
        y = np.hstack(seqs_train['y'])
        latent_variable = np.hstack(seqs_latent['latent_variable'])
        sum_yxtrans = y.dot(latent_variable.T)
        sum_xall = latent_variable.sum(axis=1)[:, np.newaxis]
        sum_yall = y.sum(axis=1)[:, np.newaxis]

        # term is (xDim+1) x (xDim+1)
        term = np.vstack([np.hstack([sum_p_auto, sum_xall]),
                          np.hstack([sum_xall.T, t.sum().reshape((1, 1))])])
        # yDim x (xDim+1)
        cd = gpfa_util.rdiv(np.hstack([sum_yxtrans, sum_yall]), term)

        params['C'] = cd[:, :x_dim]
        params['d'] = cd[:, -1]

        # yCent must be based on the new d
        # yCent = bsxfun(@minus, [seq.y], currentParams.d);
        # R = (yCent * yCent' - (yCent * [seq.latent_variable]') * \
        #     currentParams.C') / sum(T);
        c = params['C']
        d = params['d'][:, np.newaxis]
        if params['notes']['RforceDiagonal']:
            sum_yytrans = (y * y).sum(axis=1)[:, np.newaxis]
            yd = sum_yall * d
            term = ((sum_yxtrans - d.dot(sum_xall.T)) * c).sum(axis=1)
            term = term[:, np.newaxis]
            r = d ** 2 + (sum_yytrans - 2 * yd - term) / t.sum()
            # Set minimum private variance
            r = np.maximum(var_floor, r)
            params['R'] = np.diag(r[:, 0])
        else:
            sum_yytrans = y.dot(y.T)
            yd = sum_yall.dot(d.T)
            term = (sum_yxtrans - d.dot(sum_xall.T)).dot(c.T)
            r = d.dot(d.T) + (sum_yytrans - yd - yd.T - term) / t.sum()

            params['R'] = (r + r.T) / 2  # ensure symmetry

        if params['notes']['learnKernelParams']:
            res = learn_gp_params(seqs_latent, params, verbose=verbose)
            params['gamma'] = res['gamma']

        t_end = time.time() - tic
        iter_time.append(t_end)

        # Verify that likelihood is growing monotonically
        if iter_id <= 2:
            ll_base = ll
        elif verbose and ll < ll_old:
            print('\nError: Data likelihood has decreased ',
                  'from {0} to {1}'.format(ll_old, ll))
        elif (ll - ll_base) < (1 + tol) * (ll_old - ll_base):
            break

    if len(lls) < max_iters:
        print('Fitting has converged after {0} EM iterations.'.format(
            len(lls)))

    if np.any(np.diag(params['R']) == var_floor):
        warnings.warn('Private variance floor used for one or more observed '
                      'dimensions in GPFA.')

    return params, seqs_latent, lls, iter_time


def exact_inference_with_ll(seqs, params, get_ll=True):
    """
    Extracts latent trajectories from neural data, given GPFA model
    parameters.

    Parameters
    ----------
    seqs : np.recarray
        Input data structure, whose n-th element (corresponding to the n-th
        experimental trial) has fields:
        y : np.ndarray of shape (#units, #bins)
            neural data
        T : int
            number of bins
    params : dict
        GPFA model parameters with the following fields:
        C : np.ndarray
            FA factor loadings matrix
        d : np.ndarray
            FA mean vector
        R : np.ndarray
            FA noise covariance matrix
        gamma : np.ndarray
            GP timescale
        eps : np.ndarray
            GP noise variance
    get_ll : bool, optional
        specifies whether to compute data log likelihood (default: True)

    Returns
    -------
    seqs_latent : np.recarray
        a copy of the input data structure, augmented with the new
        fields:
        latent_variable : (#latent_vars, #bins) np.ndarray
            posterior mean of latent variables at each time bin
        Vsm : (#latent_vars, #latent_vars, #bins) np.ndarray
            posterior covariance between latent variables at each
            timepoint
        VsmGP : (#bins, #bins, #latent_vars) np.ndarray
            posterior covariance over time for each latent
            variable
    ll : float
        data log likelihood; np.nan is returned when `get_ll` is set to
        False
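
    Notes
    -----
    Sketch of the computation, as implemented below: for all trials of a
    given length T, the latent variables stacked over time have GP prior
    covariance K_big (from `gpfa_util.make_k_big`). With
    M = K_big^(-1) + C' R^(-1) C (the second term block-diagonal over time),
    the Gaussian posterior is

        cov[x | y] = M^(-1),
        E[x | y]   = M^(-1) C' R^(-1) (y - d).

    `Vsm` and `VsmGP` are the per-timepoint and per-latent-dimension blocks
    of M^(-1); the persymmetric structure of M^(-1) is exploited so that
    only the top half of the block product has to be computed explicitly.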
  318. """
    y_dim, x_dim = params['C'].shape

    # copy the contents of the input data structure to output structure
    dtype_out = [(x, seqs[x].dtype) for x in seqs.dtype.names]
    dtype_out.extend([('latent_variable', object), ('Vsm', object),
                      ('VsmGP', object)])
    seqs_latent = np.empty(len(seqs), dtype=dtype_out)
    for dtype_name in seqs.dtype.names:
        seqs_latent[dtype_name] = seqs[dtype_name]

    # Precomputations
    if params['notes']['RforceDiagonal']:
        rinv = np.diag(1.0 / np.diag(params['R']))
        logdet_r = (np.log(np.diag(params['R']))).sum()
    else:
        rinv = linalg.inv(params['R'])
        rinv = (rinv + rinv.T) / 2  # ensure symmetry
        logdet_r = gpfa_util.logdet(params['R'])

    c_rinv = params['C'].T.dot(rinv)
    c_rinv_c = c_rinv.dot(params['C'])

    t_all = seqs_latent['T']
    t_uniq = np.unique(t_all)
    ll = 0.

    # Overview:
    # - Outer loop on each element of Tu.
    # - For each element of Tu, find all trials with that length.
    # - Do inference and LL computation for all those trials together.
    for t in t_uniq:
        k_big, k_big_inv, logdet_k_big = gpfa_util.make_k_big(params, t)
        k_big = sparse.csr_matrix(k_big)

        blah = [c_rinv_c for _ in range(t)]
        c_rinv_c_big = linalg.block_diag(*blah)  # (xDim*T) x (xDim*T)
        minv, logdet_m = gpfa_util.inv_persymm(k_big_inv + c_rinv_c_big,
                                               x_dim)

        # Note that posterior covariance does not depend on observations,
        # so can compute once for all trials with same T.
        # xDim x xDim posterior covariance for each timepoint
        vsm = np.full((x_dim, x_dim, t), np.nan)
        idx = np.arange(0, x_dim * t + 1, x_dim)
        for i in range(t):
            vsm[:, :, i] = minv[idx[i]:idx[i + 1], idx[i]:idx[i + 1]]

        # T x T posterior covariance for each GP
        vsm_gp = np.full((t, t, x_dim), np.nan)
        for i in range(x_dim):
            vsm_gp[:, :, i] = minv[i::x_dim, i::x_dim]

        # Process all trials with length T
        n_list = np.where(t_all == t)[0]
        # dif is yDim x sum(T)
        dif = np.hstack(seqs_latent[n_list]['y']) - params['d'][:, np.newaxis]
        # term1Mat is (xDim*T) x length(nList)
        term1_mat = c_rinv.dot(dif).reshape((x_dim * t, -1), order='F')

        # Compute blkProd = CRinvC_big * invM efficiently
        # blkProd is block persymmetric, so just compute top half
        t_half = int(np.ceil(t / 2.0))
        blk_prod = np.zeros((x_dim * t_half, x_dim * t))
        idx = range(0, x_dim * t_half + 1, x_dim)
        for i in range(t_half):
            blk_prod[idx[i]:idx[i + 1], :] = c_rinv_c.dot(
                minv[idx[i]:idx[i + 1], :])
        blk_prod = k_big[:x_dim * t_half, :].dot(
            gpfa_util.fill_persymm(np.eye(x_dim * t_half, x_dim * t) -
                                   blk_prod, x_dim, t))
        # latent_variableMat is (xDim*T) x length(nList)
        latent_variable_mat = gpfa_util.fill_persymm(
            blk_prod, x_dim, t).dot(term1_mat)

        for i, n in enumerate(n_list):
            seqs_latent[n]['latent_variable'] = \
                latent_variable_mat[:, i].reshape((x_dim, t), order='F')
            seqs_latent[n]['Vsm'] = vsm
            seqs_latent[n]['VsmGP'] = vsm_gp

        if get_ll:
            # Compute data likelihood
            val = -t * logdet_r - logdet_k_big - logdet_m \
                  - y_dim * t * np.log(2 * np.pi)
            ll = ll + len(n_list) * val - (rinv.dot(dif) * dif).sum() \
                + (term1_mat.T.dot(minv) * term1_mat.T).sum()

    if get_ll:
        ll /= 2
    else:
        ll = np.nan

    return seqs_latent, ll


def learn_gp_params(seqs_latent, params, verbose=False):
    """Updates parameters of GP state model, given neural trajectories.

    Parameters
    ----------
    seqs_latent : np.recarray
        data structure containing neural trajectories
    params : dict
        current GP state model parameters, which gives the starting point
        for gradient optimization
    verbose : bool, optional
        specifies whether to display status messages (default: False)

    Returns
    -------
    param_opt : dict
        updated GP state model parameters

    Raises
    ------
    ValueError
        If `params['covType'] != 'rbf'`.
        If `params['notes']['learnGPNoise']` is set to True.
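
    Notes
    -----
    Each latent dimension is optimized independently: the objective and its
    analytic gradient are provided by `gpfa_util.grad_betgam` (hence
    `jac=True` below), and the optimization is carried out over log(gamma)
    with L-BFGS-B, which keeps gamma positive without explicit bound
    constraints.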
  417. """
    if params['covType'] != 'rbf':
        raise ValueError("Only 'rbf' GP covariance type is supported.")
    if params['notes']['learnGPNoise']:
        raise ValueError("learnGPNoise is not supported.")

    param_name = 'gamma'
    param_init = params[param_name]
    param_opt = {param_name: np.empty_like(param_init)}

    x_dim = param_init.shape[-1]
    precomp = gpfa_util.make_precomp(seqs_latent, x_dim)

    # Loop once for each state dimension (each GP)
    for i in range(x_dim):
        const = {'eps': params['eps'][i]}
        initp = np.log(param_init[i])
        res_opt = optimize.minimize(gpfa_util.grad_betgam, initp,
                                    args=(precomp[i], const),
                                    method='L-BFGS-B', jac=True)
        param_opt['gamma'][i] = np.exp(res_opt.x)

        if verbose:
            print('\n Converged p; xDim:{}, p:{}'.format(i, res_opt.x))

    return param_opt


def orthonormalize(params_est, seqs):
    """
    Orthonormalize the columns of the loading matrix C and apply the
    corresponding linear transform to the latent variables.

    Parameters
    ----------
    params_est : dict
        First return value of extract_trajectory() on the training data set.
        Estimated model parameters. When the GPFA method is used, the
        following parameters are contained:
        covType : {'rbf', 'tri', 'logexp'}
            type of GP covariance
            Currently, only 'rbf' is supported.
        gamma : np.ndarray of shape (1, #latent_vars)
            related to GP timescales by 'bin_width / sqrt(gamma)'
        eps : np.ndarray of shape (1, #latent_vars)
            GP noise variances
        d : np.ndarray of shape (#units, 1)
            observation mean
        C : np.ndarray of shape (#units, #latent_vars)
            mapping between the neuronal data space and the latent variable
            space
        R : np.ndarray of shape (#units, #units)
            observation noise covariance
    seqs : np.recarray
        Contains the embedding of the training data into the latent variable
        space.
        Data structure, whose n-th entry (corresponding to the n-th
        experimental trial) has fields
        T : int
            number of timesteps
        y : np.ndarray of shape (#units, #bins)
            neural data
        latent_variable : np.ndarray of shape (#latent_vars, #bins)
            posterior mean of latent variables at each time bin
        Vsm : np.ndarray of shape (#latent_vars, #latent_vars, #bins)
            posterior covariance between latent variables at each
            timepoint
        VsmGP : np.ndarray of shape (#bins, #bins, #latent_vars)
            posterior covariance over time for each latent variable

    Returns
    -------
    Corth : np.ndarray of shape (#units, #latent_vars)
        the orthonormalized columns of the loading matrix C. The same matrix
        is also stored in `params_est['Corth']`, i.e. `params_est` is updated
        in place.
    seqs : np.recarray
        Training data structure that contains the new field
        `latent_variable_orth`, the orthonormalized neural trajectories.
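
    Notes
    -----
    The transform itself is delegated to `gpfa_util.orthonormalize`; the
    following is only an illustrative sketch of the idea based on a singular
    value decomposition of the loading matrix. With C = U S Vt, the columns
    of U form an orthonormal basis for the column space of C, and the
    latents re-expressed in that basis are x_orth = S Vt x, so that
    C x == U x_orth:

    >>> import numpy as np
    >>> C = np.random.RandomState(0).randn(10, 3)   # hypothetical loadings
    >>> x = np.random.RandomState(1).randn(3, 100)  # hypothetical latents
    >>> U, S, Vt = np.linalg.svd(C, full_matrices=False)
    >>> x_orth = np.diag(S).dot(Vt).dot(x)
    >>> np.allclose(C.dot(x), U.dot(x_orth))
    True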
  487. """
    C = params_est['C']
    X = np.hstack(seqs['latent_variable'])
    latent_variable_orth, Corth, _ = gpfa_util.orthonormalize(X, C)
    seqs = gpfa_util.segment_by_trial(
        seqs, latent_variable_orth, 'latent_variable_orth')

    params_est['Corth'] = Corth

    return Corth, seqs