# -*- coding: utf-8 -*-
"""
GPFA util functions.

:copyright: Copyright 2015-2019 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""

from __future__ import division, print_function, unicode_literals

import warnings

import numpy as np
import quantities as pq
import scipy as sp
# Explicit submodule imports so that `sp.linalg` and `sp.sparse`,
# used below, are guaranteed to resolve.
import scipy.linalg
import scipy.sparse

from elephant.conversion import BinnedSpikeTrain
from elephant.utils import deprecated_alias


@deprecated_alias(binsize='bin_size')
def get_seqs(data, bin_size, use_sqrt=True):
    """
    Converts the data into a rec array, using BinnedSpikeTrain internally.

    Parameters
    ----------
    data : list of list of neo.SpikeTrain
        The outer list corresponds to trials and the inner list corresponds
        to the neurons recorded in that trial, such that `data[l][n]` is the
        spike train of neuron `n` in trial `l`. Note that the number and
        order of `neo.SpikeTrain` objects per trial must be fixed, such that
        `data[l][n]` and `data[k][n]` refer to the same spike generator for
        any choice of `l`, `k`, and `n`.
    bin_size : pq.Quantity
        Spike bin width.
    use_sqrt : bool
        Whether to apply a square-root transform to the spike counts (see
        the original paper for motivation).
        Default: True

    Returns
    -------
    seq : np.recarray
        data structure, whose n-th entry (corresponding to the n-th
        experimental trial) has fields
            T : int
                number of timesteps in the trial
            y : (yDim, T) np.ndarray
                neural data

    Raises
    ------
    ValueError
        If `bin_size` is not a pq.Quantity.
    """
    if not isinstance(bin_size, pq.Quantity):
        raise ValueError("'bin_size' must be of type pq.Quantity")

    seqs = []
    for dat in data:
        sts = dat
        binned_spiketrain = BinnedSpikeTrain(sts, bin_size=bin_size)
        if use_sqrt:
            binned = np.sqrt(binned_spiketrain.to_array())
        else:
            binned = binned_spiketrain.to_array()
        seqs.append((binned_spiketrain.n_bins, binned))
    seqs = np.array(seqs, dtype=[('T', int), ('y', 'O')])

    # Remove trials that are shorter than one bin width
    if len(seqs) > 0:
        trials_to_keep = seqs['T'] > 0
        seqs = seqs[trials_to_keep]

    return seqs


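# A minimal usage sketch for `get_seqs` (illustrative only, not part of
# the Elephant API; spike times are made up and `neo` is assumed to be
# installed, as Elephant depends on it).
def _example_get_seqs():
    import neo
    # One trial with one neuron spiking at 0.1, 0.4 and 0.9 s.
    st = neo.SpikeTrain([0.1, 0.4, 0.9] * pq.s, t_stop=1 * pq.s)
    seqs = get_seqs([[st]], bin_size=0.5 * pq.s, use_sqrt=False)
    # Two 0.5-s bins: counts [[2, 1]], so T == 2 and y has shape (1, 2).
    assert seqs[0]['T'] == 2
    assert seqs[0]['y'].shape == (1, 2)

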
def cut_trials(seq_in, seg_length=20):
    """
    Extracts trial segments that are all of the same length. Uses
    overlapping segments if the trial length is not an integer multiple
    of the segment length. Ignores trials that are shorter than one
    segment length.

    Parameters
    ----------
    seq_in : np.recarray
        data structure, whose n-th entry (corresponding to the n-th
        experimental trial) has fields
            T : int
                number of timesteps in trial
            y : (yDim, T) np.ndarray
                neural data
    seg_length : int
        length of segments to extract, in number of timesteps. If infinite,
        entire trials are extracted, i.e., no segmenting.
        Default: 20

    Returns
    -------
    seqOut : np.recarray
        data structure, whose n-th entry (corresponding to the n-th
        experimental trial) has fields
            T : int
                number of timesteps in segment
            y : (yDim, T) np.ndarray
                neural data

    Raises
    ------
    ValueError
        If `seg_length == 0`.
    """
    if seg_length == 0:
        raise ValueError("At least 1 extracted trial must be returned")
    if np.isinf(seg_length):
        seqOut = seq_in
        return seqOut

    dtype_seqOut = [('segId', int), ('T', int), ('y', object)]
    seqOut_buff = []
    for n, seqIn_n in enumerate(seq_in):
        T = seqIn_n['T']

        # Skip trials that are shorter than segLength
        if T < seg_length:
            warnings.warn(
                'trial corresponding to index {} is shorter than one '
                'segLength... skipping'.format(n))
            continue

        numSeg = int(np.ceil(float(T) / seg_length))

        # Randomize the sizes of overlaps
        if numSeg == 1:
            cumOL = np.array([0, ])
        else:
            totalOL = (seg_length * numSeg) - T
            probs = np.ones(numSeg - 1, float) / (numSeg - 1)
            randOL = np.random.multinomial(totalOL, probs)
            cumOL = np.hstack([0, np.cumsum(randOL)])

        seg = np.empty(numSeg, dtype_seqOut)
        seg['T'] = seg_length

        for s, seg_s in enumerate(seg):
            tStart = seg_length * s - cumOL[s]
            seg_s['y'] = seqIn_n['y'][:, tStart:tStart + seg_length]

        seqOut_buff.append(seg)

    if len(seqOut_buff) > 0:
        seqOut = np.hstack(seqOut_buff)
    else:
        seqOut = np.empty(0, dtype_seqOut)

    return seqOut


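# A small sketch of `cut_trials` on made-up data: a 25-step trial cut
# into 20-step segments yields two overlapping segments (the size of the
# overlap is randomized inside `cut_trials`).
def _example_cut_trials():
    seq_in = np.empty(1, dtype=[('T', int), ('y', object)])
    seq_in[0]['T'] = 25
    seq_in[0]['y'] = np.arange(2 * 25).reshape(2, 25)  # (yDim=2, T=25)
    seq_out = cut_trials(seq_in, seg_length=20)
    # ceil(25 / 20) == 2 segments, each seg_length timesteps long.
    assert len(seq_out) == 2
    assert all(seg['T'] == 20 for seg in seq_out)

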
def rdiv(a, b):
    """
    Returns the solution to x b = a. Equivalent to MATLAB right matrix
    division: a / b
    """
    return np.linalg.solve(b.T, a.T).T


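# Sanity sketch for `rdiv` with made-up matrices: x = rdiv(a, b) solves
# x b = a without forming inv(b), mirroring MATLAB's right division a / b.
def _example_rdiv():
    a = np.array([[7., 8.], [9., 10.]])
    b = np.array([[1., 2.], [3., 4.]])
    x = rdiv(a, b)
    assert np.allclose(x.dot(b), a)

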
def logdet(A):
    """
    log(det(A)) where A is positive-definite.
    This is faster and more stable than using log(det(A)).

    Written by Tom Minka
    (c) Microsoft Corporation. All rights reserved.
    """
    U = np.linalg.cholesky(A)
    return 2 * (np.log(np.diag(U))).sum()


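# Sanity sketch for `logdet` on a small symmetric positive-definite
# matrix: the Cholesky route agrees with the naive log(det(A)), while
# being more robust when det(A) would over- or underflow.
def _example_logdet():
    A = np.array([[4., 1.], [1., 3.]])
    assert np.isclose(logdet(A), np.log(np.linalg.det(A)))

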
def make_k_big(params, n_timesteps):
    """
    Constructs full GP covariance matrix across all state dimensions and
    timesteps.

    Parameters
    ----------
    params : dict
        GPFA model parameters
    n_timesteps : int
        number of timesteps

    Returns
    -------
    K_big : np.ndarray
        GP covariance matrix with dimensions (xDim * T) x (xDim * T).
        The (t1, t2) block is diagonal, has dimensions xDim x xDim, and
        represents the covariance between the state vectors at timesteps
        t1 and t2. K_big is sparse and striped.
    K_big_inv : np.ndarray
        Inverse of K_big
    logdet_K_big : float
        Log determinant of K_big

    Raises
    ------
    ValueError
        If `params['covType'] != 'rbf'`.
    """
    if params['covType'] != 'rbf':
        raise ValueError("Only 'rbf' GP covariance type is supported.")

    xDim = params['C'].shape[1]

    K_big = np.zeros((xDim * n_timesteps, xDim * n_timesteps))
    K_big_inv = np.zeros((xDim * n_timesteps, xDim * n_timesteps))
    Tdif = np.tile(np.arange(0, n_timesteps), (n_timesteps, 1)).T \
        - np.tile(np.arange(0, n_timesteps), (n_timesteps, 1))
    logdet_K_big = 0

    for i in range(xDim):
        K = (1 - params['eps'][i]) * np.exp(-params['gamma'][i] / 2 *
                                            Tdif ** 2) \
            + params['eps'][i] * np.eye(n_timesteps)
        K_big[i::xDim, i::xDim] = K
        # The original MATLAB program uses a special algorithm here,
        # provided in C and MEX, for the inversion of a Toeplitz matrix:
        # [K_big_inv(idx+i, idx+i), logdet_K] = invToeplitz(K);
        # TODO: use an inversion method optimized for Toeplitz matrices.
        # The attempt below did not lead to a speed-up:
        # K_big_inv[i::xDim, i::xDim] = sp.linalg.solve_toeplitz(
        #     (K[:, 0], K[0, :]), np.eye(n_timesteps))
        K_big_inv[i::xDim, i::xDim] = np.linalg.inv(K)
        logdet_K = logdet(K)
        logdet_K_big = logdet_K_big + logdet_K

    return K_big, K_big_inv, logdet_K_big


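# A sketch of `make_k_big` with a hand-built parameter dict (the values
# below are arbitrary; in GPFA the dict comes from the fitted model).
def _example_make_k_big():
    xDim, T = 2, 5
    params = {
        'covType': 'rbf',
        'C': np.zeros((10, xDim)),      # only C.shape[1] is used here
        'gamma': np.array([1.0, 0.5]),  # 1 / timescale**2 per latent
        'eps': np.array([1e-3, 1e-3]),  # GP noise variances
    }
    K_big, K_big_inv, logdet_K_big = make_k_big(params, n_timesteps=T)
    assert K_big.shape == (xDim * T, xDim * T)
    # Each latent's T x T kernel occupies an interleaved stripe, so the
    # striped inverse is a true inverse of the full matrix.
    assert np.allclose(K_big.dot(K_big_inv), np.eye(xDim * T))

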
def inv_persymm(M, blk_size):
    """
    Inverts a matrix that is block persymmetric. This function is
    faster than calling inv(M) directly because it only computes the
    top half of inv(M). The bottom half of inv(M) is made up of
    elements from the top half of inv(M).

    WARNING: If the input matrix M is not block persymmetric, no
    error message will be produced and the output of this function will
    not be meaningful.

    Parameters
    ----------
    M : (blkSize*T, blkSize*T) np.ndarray
        The block persymmetric matrix to be inverted.
        Each block is blkSize x blkSize, arranged in a T x T grid.
    blk_size : int
        Edge length of one block

    Returns
    -------
    invM : (blkSize*T, blkSize*T) np.ndarray
        Inverse of M
    logdet_M : float
        Log determinant of M
    """
    T = int(M.shape[0] / blk_size)
    Thalf = int(np.ceil(T / 2.0))
    mkr = blk_size * Thalf

    invA11 = np.linalg.inv(M[:mkr, :mkr])
    invA11 = (invA11 + invA11.T) / 2

    # Multiplication of a sparse matrix by a dense matrix is not supported
    # by SciPy. Making A12 a sparse matrix here would cause an error later.
    off_diag_sparse = False
    if off_diag_sparse:
        A12 = sp.sparse.csr_matrix(M[:mkr, mkr:])
    else:
        A12 = M[:mkr, mkr:]

    term = invA11.dot(A12)
    F22 = M[mkr:, mkr:] - A12.T.dot(term)

    res12 = rdiv(-term, F22)
    res11 = invA11 - res12.dot(term.T)
    res11 = (res11 + res11.T) / 2

    # Fill in bottom half of invM by picking elements from res11 and res12
    invM = fill_persymm(np.hstack([res11, res12]), blk_size, T)

    logdet_M = -logdet(invA11) + logdet(F22)

    return invM, logdet_M


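# A sketch of `inv_persymm`: the striped GP covariance from `make_k_big`
# is block persymmetric, so it serves as a convenient test input
# (parameter values are again arbitrary).
def _example_inv_persymm():
    params = {'covType': 'rbf', 'C': np.zeros((4, 2)),
              'gamma': np.array([1.0, 0.5]), 'eps': np.array([1e-3, 1e-3])}
    K_big, K_big_inv, logdet_K_big = make_k_big(params, n_timesteps=6)
    invM, logdet_M = inv_persymm(K_big, blk_size=2)
    assert np.allclose(invM, K_big_inv)
    assert np.isclose(logdet_M, logdet_K_big)

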
def fill_persymm(p_in, blk_size, n_blocks, blk_size_vert=None):
    """
    Fills in the bottom half of a block persymmetric matrix, given the
    top half.

    Parameters
    ----------
    p_in : (xDim*Thalf, xDim*T) np.ndarray
        Top half of block persymmetric matrix, where Thalf = ceil(T/2)
    blk_size : int
        Edge length of one block
    n_blocks : int
        Number of blocks making up a row of Pin
    blk_size_vert : int, optional
        Vertical block edge length if blocks are not square.
        `blk_size` is assumed to be the horizontal block edge length.

    Returns
    -------
    Pout : (xDim*T, xDim*T) np.ndarray
        Full block persymmetric matrix
    """
    if blk_size_vert is None:
        blk_size_vert = blk_size

    Nh = blk_size * n_blocks
    Nv = blk_size_vert * n_blocks
    Thalf = int(np.floor(n_blocks / 2.0))
    THalf = int(np.ceil(n_blocks / 2.0))

    Pout = np.empty((blk_size_vert * n_blocks, blk_size * n_blocks))
    Pout[:blk_size_vert * THalf, :] = p_in
    for i in range(Thalf):
        for j in range(n_blocks):
            Pout[Nv - (i + 1) * blk_size_vert:Nv - i * blk_size_vert,
                 Nh - (j + 1) * blk_size:Nh - j * blk_size] \
                = p_in[i * blk_size_vert:(i + 1) * blk_size_vert,
                       j * blk_size:(j + 1) * blk_size]

    return Pout


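# A sketch of `fill_persymm`: rebuilding a persymmetric matrix from its
# top half reproduces the original (a 1-dimensional `make_k_big` output
# is symmetric Toeplitz, hence persymmetric).
def _example_fill_persymm():
    params = {'covType': 'rbf', 'C': np.zeros((4, 1)),
              'gamma': np.array([1.0]), 'eps': np.array([1e-3])}
    K, _, _ = make_k_big(params, n_timesteps=6)
    Thalf = int(np.ceil(6 / 2.0))
    rebuilt = fill_persymm(K[:Thalf, :], blk_size=1, n_blocks=6)
    assert np.allclose(rebuilt, K)

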
def make_precomp(seqs, xDim):
    """
    Make the precomputation matrices specified by the GPFA algorithm.

    Usage: [precomp] = makePautoSum(seq, xDim)

    Parameters
    ----------
    seqs : np.recarray
        The sequence struct of inferred latents, etc.
    xDim : int
        The dimension of the latent space.

    Returns
    -------
    precomp : np.recarray
        The precomp struct will be updated with the posterior covariance
        and the other requirements.

    Notes
    -----
    All inputs are named analogously to those in `learnGPparams`.
    This code probably should not be called from anywhere but there.

    We bother with this method because we need this particular matrix sum
    to be as fast as possible. Thus, no error checking is done here, as
    that would add needless computation. Instead, the onus is on the
    caller (which should be `learnGPparams()`) to make sure this is
    called correctly.

    Finally, see the notes in the GPFA README.
    """
    Tall = seqs['T']
    Tmax = Tall.max()
    Tdif = np.tile(np.arange(0, Tmax), (Tmax, 1)).T \
        - np.tile(np.arange(0, Tmax), (Tmax, 1))

    # Assign some helpful precomp items. This is computationally cheap,
    # so a few loops (kept from the original MATLAB code) are retained
    # for readability.
    precomp = np.empty(xDim, dtype=[(
        'absDif', object), ('difSq', object), ('Tall', object),
        ('Tu', object)])
    for i in range(xDim):
        precomp[i]['absDif'] = np.abs(Tdif)
        precomp[i]['difSq'] = Tdif ** 2
        precomp[i]['Tall'] = Tall

    # Find unique numbers of trial lengths
    trial_lengths_num_unique = np.unique(Tall)
    # Loop once for each state dimension (each GP)
    for i in range(xDim):
        precomp_Tu = np.empty(len(trial_lengths_num_unique), dtype=[(
            'nList', object), ('T', int), ('numTrials', int),
            ('PautoSUM', object)])
        for j, trial_len_num in enumerate(trial_lengths_num_unique):
            precomp_Tu[j]['nList'] = np.where(Tall == trial_len_num)[0]
            precomp_Tu[j]['T'] = trial_len_num
            precomp_Tu[j]['numTrials'] = len(precomp_Tu[j]['nList'])
            precomp_Tu[j]['PautoSUM'] = np.zeros((trial_len_num,
                                                  trial_len_num))
        precomp[i]['Tu'] = precomp_Tu

    # At this point the basic precomp is built; the steps above are
    # computationally cheap. The expensive part is filling out PautoSUM,
    # initialized to zeros above. (The original MATLAB code embedded this
    # step in a MEX call, falling back to MATLAB if that failed.)
    ############################################################
    # Fill out PautoSum
    ############################################################
    # Loop once for each state dimension (each GP)
    for i in range(xDim):
        # Loop once for each trial length (each of Tu)
        for j in range(len(trial_lengths_num_unique)):
            # Loop once for each trial (each of nList)
            for n in precomp[i]['Tu'][j]['nList']:
                precomp[i]['Tu'][j]['PautoSUM'] += \
                    seqs[n]['VsmGP'][:, :, i] \
                    + np.outer(seqs[n]['latent_variable'][i, :],
                               seqs[n]['latent_variable'][i, :])
    return precomp


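# A sketch of `make_precomp` with hand-made posterior quantities
# (`VsmGP` and `latent_variable` would normally come from the GPFA
# E-step; random placeholders of the right shape suffice here).
def _example_make_precomp():
    xDim, T, n_trials = 2, 5, 3
    seqs = np.empty(n_trials, dtype=[('T', int), ('VsmGP', object),
                                     ('latent_variable', object)])
    rng = np.random.RandomState(0)
    for n in range(n_trials):
        seqs[n]['T'] = T
        seqs[n]['VsmGP'] = rng.rand(T, T, xDim)
        seqs[n]['latent_variable'] = rng.rand(xDim, T)
    precomp = make_precomp(seqs, xDim)
    # One unique trial length, so each latent has a single (T, T) PautoSUM.
    assert precomp[0]['Tu'][0]['PautoSUM'].shape == (T, T)
    assert precomp[0]['Tu'][0]['numTrials'] == n_trials
    return precomp

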
def grad_betgam(p, pre_comp, const):
    """
    Gradient computation for GP timescale optimization.
    In the original MATLAB code, this function was called by `minimize.m`.

    Parameters
    ----------
    p : float
        variable with respect to which optimization is performed,
        where :math:`p = log(1 / timescale^2)`
    pre_comp : np.recarray
        structure containing precomputations
    const : dict
        contains hyperparameters

    Returns
    -------
    f : float
        value of objective function E[log P({x},{y})] at p
    df : float
        gradient at p
    """
    Tall = pre_comp['Tall']
    Tmax = Tall.max()

    # temp is Tmax x Tmax
    temp = (1 - const['eps']) * np.exp(-np.exp(p) / 2 * pre_comp['difSq'])
    Kmax = temp + const['eps'] * np.eye(Tmax)
    dKdgamma_max = -0.5 * temp * pre_comp['difSq']

    dEdgamma = 0
    f = 0
    for j in range(len(pre_comp['Tu'])):
        T = pre_comp['Tu'][j]['T']
        Thalf = int(np.ceil(T / 2.0))

        Kinv = np.linalg.inv(Kmax[:T, :T])
        logdet_K = logdet(Kmax[:T, :T])

        KinvM = Kinv[:Thalf, :].dot(dKdgamma_max[:T, :T])  # Thalf x T
        KinvMKinv = (KinvM.dot(Kinv)).T  # T x Thalf

        dg_KinvM = np.diag(KinvM)
        tr_KinvM = 2 * dg_KinvM.sum() - np.fmod(T, 2) * dg_KinvM[-1]

        mkr = int(np.ceil(0.5 * T ** 2))
        numTrials = pre_comp['Tu'][j]['numTrials']
        PautoSUM = pre_comp['Tu'][j]['PautoSUM']

        pauto_kinv_dot = PautoSUM.ravel('F')[:mkr].dot(
            KinvMKinv.ravel('F')[:mkr])
        pauto_kinv_dot_rest = PautoSUM.ravel('F')[-1:mkr - 1:-1].dot(
            KinvMKinv.ravel('F')[:(T ** 2 - mkr)])
        dEdgamma = dEdgamma - 0.5 * numTrials * tr_KinvM \
            + 0.5 * pauto_kinv_dot \
            + 0.5 * pauto_kinv_dot_rest

        f = f - 0.5 * numTrials * logdet_K \
            - 0.5 * (PautoSUM * Kinv).sum()

    f = -f
    # exp(p) is needed because we're computing gradients with
    # respect to log(gamma), rather than gamma
    df = -dEdgamma * np.exp(p)

    return f, df


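# A sketch of `grad_betgam`: evaluate the negated objective and its
# gradient at one value of p = log(gamma) for a single latent dimension,
# reusing the placeholder precomp built in `_example_make_precomp` above.
def _example_grad_betgam():
    precomp = _example_make_precomp()
    const = {'eps': 1e-3}
    f, df = grad_betgam(np.log(1.0), precomp[0], const)
    assert np.isfinite(f) and np.isfinite(df)

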
def orthonormalize(x, l):
    """
    Orthonormalize the columns of the loading matrix and apply the
    corresponding linear transform to the latent variables.

    In the following description, yDim and xDim refer to data
    dimensionality and latent dimensionality, respectively.

    Parameters
    ----------
    x : (xDim, T) np.ndarray
        Latent variables
    l : (yDim, xDim) np.ndarray
        Loading matrix

    Returns
    -------
    latent_variable_orth : (xDim, T) np.ndarray
        Orthonormalized latent variables
    Lorth : (yDim, xDim) np.ndarray
        Orthonormalized loading matrix
    TT : (xDim, xDim) np.ndarray
        Linear transform applied to latent variables
    """
    xDim = l.shape[1]
    if xDim == 1:
        TT = np.sqrt(np.dot(l.T, l))
        Lorth = rdiv(l, TT)
        latent_variable_orth = np.dot(TT, x)
    else:
        UU, DD, VV = sp.linalg.svd(l, full_matrices=False)
        # TT is transform matrix
        TT = np.dot(np.diag(DD), VV)
        Lorth = UU
        latent_variable_orth = np.dot(TT, x)

    return latent_variable_orth, Lorth, TT


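# A sketch of `orthonormalize` with random inputs: the orthonormalized
# loading matrix has orthonormal columns, and the transformed quantities
# leave the reconstruction l.dot(x) unchanged.
def _example_orthonormalize():
    rng = np.random.RandomState(0)
    l = rng.rand(5, 2)   # (yDim, xDim) loading matrix
    x = rng.rand(2, 10)  # (xDim, T) latents
    x_orth, l_orth, TT = orthonormalize(x, l)
    assert np.allclose(l_orth.T.dot(l_orth), np.eye(2))
    assert np.allclose(l_orth.dot(x_orth), l.dot(x))

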
def segment_by_trial(seqs, x, fn):
    """
    Segment and store data by trial.

    Parameters
    ----------
    seqs : np.recarray
        Data structure that has field T, the number of timesteps
    x : np.ndarray
        Data to be segmented (any dimensionality x total number of
        timesteps)
    fn : str
        New field name of seq where segments of X are stored

    Returns
    -------
    seqs_new : np.recarray
        Data structure with new field `fn`

    Raises
    ------
    ValueError
        If `np.sum(seqs['T']) != x.shape[1]`.
    """
    if np.sum(seqs['T']) != x.shape[1]:
        raise ValueError('size of X incorrect.')

    dtype_new = [(i, seqs[i].dtype) for i in seqs.dtype.names]
    dtype_new.append((fn, object))
    seqs_new = np.empty(len(seqs), dtype=dtype_new)
    for dtype_name in seqs.dtype.names:
        seqs_new[dtype_name] = seqs[dtype_name]

    ctr = 0
    for n, T in enumerate(seqs['T']):
        seqs_new[n][fn] = x[:, ctr:ctr + T]
        ctr += T

    return seqs_new


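# A sketch of `segment_by_trial` with made-up data: the columns of a
# concatenated array are distributed back into per-trial segments under
# a new field name.
def _example_segment_by_trial():
    seqs = np.empty(2, dtype=[('T', int), ('y', object)])
    seqs['T'] = [3, 2]
    for n, T in enumerate(seqs['T']):
        seqs[n]['y'] = np.zeros((4, T))
    x = np.arange(2 * 5).reshape(2, 5)  # any dimensionality x sum(T)
    seqs_new = segment_by_trial(seqs, x, 'latent')
    assert seqs_new[0]['latent'].shape == (2, 3)
    assert np.array_equal(seqs_new[1]['latent'], x[:, 3:])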