"""
Gaussian-process factor analysis (GPFA) is a dimensionality reduction method
[#f1]_ for neural trajectory visualization of parallel spike trains. GPFA
applies factor analysis (FA) to time-binned spike count data to reduce the
dimensionality and at the same time smooths the resulting low-dimensional
trajectories by fitting a Gaussian process (GP) model to them.

The input consists of a set of trials (Y), each containing a list of spike
trains (N neurons). The output is the projection (X) of the data in a space
of pre-chosen dimensionality x_dim < N.
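
As an illustration of this data layout (and of step 0 listed below), the
sketch bins one trial of N spike trains into an (N, T) array of spike counts
with `np.histogram`. It assumes spike times given as plain NumPy arrays in ms,
with every trial starting at 0. `bin_trial` is a hypothetical helper. This
module itself expects `neo.SpikeTrain` objects and performs the binning in
`gpfa_util.get_seqs()`.

```python
import numpy as np


def bin_trial(spike_times_ms, t_stop_ms, bin_size_ms=20.0):
    # Hypothetical helper: bin one trial of N spike trains into an (N, T)
    # array of spike counts; spikes beyond the last full bin are dropped.
    n_bins = int(t_stop_ms // bin_size_ms)
    edges = np.arange(n_bins + 1) * bin_size_ms
    counts = [np.histogram(st, bins=edges)[0] for st in spike_times_ms]
    return np.vstack(counts)


# Two trials with N = 3 neurons each (spike times in ms)
trials = [
    [np.array([5.0, 25.0, 30.0]), np.array([]), np.array([90.0])],
    [np.array([1.0]), np.array([45.0, 46.0]), np.array([10.0, 99.0])],
]
Y = [bin_trial(trial, t_stop_ms=100.0) for trial in trials]
print(Y[0].shape)  # (3, 5)
```

Each trial yields one count matrix, so trials of different durations simply
produce matrices with different numbers of columns.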

Under the assumption of a linear relation (transform matrix C) between the
latent variable X following a Gaussian process and the spike train data Y with
a bias d and a noise term of zero mean and (co)variance R (i.e.,
:math:`Y = C X + d + Gauss(0,R)`), the projection corresponds to the
conditional expectation E[X|Y].
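
For a single trial with T bins, the bins can be stacked into one vector. This
gives y_bar = C_bar x_bar + d_bar + noise, with C_bar = kron(I_T, C),
R_bar = kron(I_T, R) and a GP prior covariance K_bar over x_bar. Gaussian
conditioning then yields
E[x_bar | y_bar] = K_bar C_bar^T (C_bar K_bar C_bar^T + R_bar)^-1 (y_bar - d_bar).

The NumPy sketch below evaluates this formula directly. It assumes the
squared-exponential covariance
K(t, s) = (1 - eps) exp(-gamma (t - s)^2 / 2) + eps delta(t, s)
for each latent variable, with time measured in bins, to match the `gamma` and
`eps` attributes of `GPFA`. The function names are hypothetical. The full
(N T) x (N T) solve is written out for clarity only; it is not the
implementation used by `gpfa_core.exact_inference_with_ll()`.

```python
import numpy as np


def gp_cov(T, gamma, eps):
    # Squared-exponential GP covariance over T bins (time in units of bins)
    t = np.arange(T)
    diff = t[:, None] - t[None, :]
    return (1.0 - eps) * np.exp(-0.5 * gamma * diff ** 2) + eps * np.eye(T)


def gpfa_posterior_mean(Y, C, d, R, gammas, epss):
    # Y: (N, T) binned counts of one trial; C: (N, q); d: (N,); R: (N, N)
    T = Y.shape[1]
    q = C.shape[1]
    # Stack time-major: entry (t, i) of a (T, q) array goes to index t * q + i
    K_bar = np.zeros((T * q, T * q))
    for i, (gamma, eps) in enumerate(zip(gammas, epss)):
        E_ii = np.zeros((q, q))
        E_ii[i, i] = 1.0
        K_bar += np.kron(gp_cov(T, gamma, eps), E_ii)  # independent latents
    C_bar = np.kron(np.eye(T), C)  # same loading matrix in every bin
    R_bar = np.kron(np.eye(T), R)  # noise independent across bins
    y_bar = (Y - d[:, None]).T.reshape(-1)
    S = C_bar @ K_bar @ C_bar.T + R_bar  # marginal covariance of y_bar
    x_bar = K_bar @ C_bar.T @ np.linalg.solve(S, y_bar)
    return x_bar.reshape(T, q).T  # back to shape (q, T)


rng = np.random.default_rng(0)
N, q, T = 5, 2, 10
C = rng.normal(size=(N, q))
d = np.ones(N)
R = 0.5 * np.eye(N)
Y = rng.poisson(2.0, size=(N, T)).astype(float)
X = gpfa_posterior_mean(Y, C, d, R, gammas=[0.05, 0.2], epss=[1e-3, 1e-3])
print(X.shape)  # (2, 10)
```

With a very large `gamma` the bins decouple. The result then reduces, bin by
bin, to the factor-analysis posterior mean
(I + C^T R^-1 C)^-1 C^T R^-1 (y_t - d).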

The parameters (C, d, R) as well as the time scales and variances of the
Gaussian process are estimated from the data using an expectation-maximization
(EM) algorithm.
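
EM alternates between two steps. The E-step infers the latent trajectories;
the M-step re-estimates the parameters. Given the posterior moments of the
latent variables, the M-step for (C, d, R) has the closed form known from
factor analysis, sketched below in NumPy.

This is an illustration under assumptions:

- the function name `m_step_cdr` is hypothetical;
- R is taken to be diagonal, as in factor analysis;
- the floor `min_var_frac * var(y)` is our reading of the `min_var_frac`
  parameter of `GPFA`.

The sketch does not cover the update of the GP time scales, which has no
closed form and needs numerical optimization. It is not the code of
`gpfa_core.em()`.

```python
import numpy as np


def m_step_cdr(Y, Ex, sum_Exx, min_var_frac=0.01):
    # Y: (N, T) observations pooled over trials; Ex: (q, T) posterior means;
    # sum_Exx: (q, q) sum over bins of posterior second moments E[x_t x_t^T]
    T = Y.shape[1]
    q = Ex.shape[0]
    # Augment x_t with a constant 1 so that C and d are estimated jointly
    sum_xx = np.zeros((q + 1, q + 1))
    sum_xx[:q, :q] = sum_Exx
    sum_xx[:q, q] = Ex.sum(axis=1)
    sum_xx[q, :q] = Ex.sum(axis=1)
    sum_xx[q, q] = T
    sum_yx = np.hstack([Y @ Ex.T, Y.sum(axis=1, keepdims=True)])  # (N, q+1)
    Cd = np.linalg.solve(sum_xx.T, sum_yx.T).T  # Cd = sum_yx @ inv(sum_xx)
    C, d = Cd[:, :q], Cd[:, q]
    # Diagonal observation noise, floored to combat Heywood cases
    R_diag = np.diag(Y @ Y.T - Cd @ sum_yx.T) / T
    R_diag = np.maximum(R_diag, min_var_frac * np.var(Y, axis=1))
    return C, d, np.diag(R_diag)


# Noise-free check: with exact latents the update recovers C and d
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 50))
C_true, d_true = rng.normal(size=(6, 2)), rng.normal(size=6)
Y = C_true @ x + d_true[:, None]
C, d, R = m_step_cdr(Y, Ex=x, sum_Exx=x @ x.T)
print(np.allclose(C, C_true), np.allclose(d, d_true))  # True True
```

In this noise-free case the residual variance is zero, so every diagonal
entry of R ends up at its floor.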

Internally, the analysis consists of the following steps:

0) bin the spike train data to get a sequence of N-dimensional vectors of
   spike counts in respective time bins, and choose the reduced
   dimensionality x_dim

1) expectation-maximization for fitting the parameters C, d, R and the
   time scales and variances of the Gaussian process, using all the trials
   provided as input (cf. `gpfa_core.em()`)

2) projection of single trials into the low-dimensional space (cf.
   `gpfa_core.exact_inference_with_ll()`)

3) orthonormalization of the matrix C and the corresponding subspace, for
   visualization purposes (cf. `gpfa_core.orthonormalize()`)
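
Step 3 relies on a simple observation. The product C X is unchanged when C
is replaced by C A^-1 and X by A X, for any invertible matrix A. The
orthonormalization can therefore re-express the same trajectories in a basis
with orthonormal columns, taken from the singular value decomposition
C = U S V^T.

The sketch below assumes C of shape (N, x_dim) and X of shape (x_dim, T).
`orthonormalize` here is a hypothetical stand-alone function. It is not the
signature of `gpfa_core.orthonormalize()`, which operates on the estimated
parameters and the sequences.

```python
import numpy as np


def orthonormalize(C, X):
    # C: (N, q) loading matrix; X: (q, T) latent trajectories.
    # With C = U S V^T, the columns of U are orthonormal and
    # C X = U (S V^T X), so X_orth = S V^T X describes the same data.
    U, s, Vt = np.linalg.svd(C, full_matrices=False)
    X_orth = np.diag(s) @ Vt @ X
    return U, X_orth


rng = np.random.default_rng(0)
C = rng.normal(size=(6, 3))
X = rng.normal(size=(3, 20))
C_orth, X_orth = orthonormalize(C, X)
print(np.allclose(C_orth @ X_orth, C @ X))  # True
```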

.. autosummary::
    :toctree: toctree/gpfa

    GPFA

Visualization
-------------

Visualization of GPFA transforms is covered in Viziphant:
https://viziphant.readthedocs.io/en/latest/modules.html

Tutorial
--------

:doc:`View tutorial <../tutorials/gpfa>`

Run tutorial interactively:

.. image:: https://mybinder.org/badge.svg
   :target: https://mybinder.org/v2/gh/NeuralEnsemble/elephant/master
            ?filepath=doc/tutorials/gpfa.ipynb

References
----------

The code was ported from the MATLAB code based on Byron Yu's implementation.
The original MATLAB code is available at Byron Yu's website:
https://users.ece.cmu.edu/~byronyu/software.shtml

.. [#f1] Yu BM, Cunningham JP, Santhanam G, Ryu SI, Shenoy KV, Sahani M (2009)
   Gaussian-process factor analysis for low-dimensional single-trial analysis
   of neural population activity. J Neurophysiol 102:614-635.

:copyright: Copyright 2015-2019 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""

from __future__ import division, print_function, unicode_literals

import neo
import numpy as np
import quantities as pq
import sklearn
import warnings

from elephant.gpfa import gpfa_core, gpfa_util
from elephant.utils import deprecated_alias

__all__ = [
    "GPFA"
]


class GPFA(sklearn.base.BaseEstimator):
    """
    Apply Gaussian process factor analysis (GPFA) to spike train data.

    There are two principal scenarios of using the GPFA analysis, both of which
    can be performed in an instance of the GPFA() class.

    In the first scenario, a single dataset is used both to fit the model and
    to extract the neural trajectories. The parameters that describe the
    transformation are first extracted from the data using the `fit()` method
    of the GPFA class. Then the same data is projected into the orthonormal
    basis using the method `transform()`. The `fit_transform()` method can be
    used to perform these two steps at once.

    In the second scenario, a single dataset is split into training and test
    datasets. Here, the parameters are estimated from the training data. Then
    the test data is projected into the low-dimensional space previously
    obtained from the training data. This analysis is performed by executing
    first the `fit()` method on the training data, followed by the
    `transform()` method on the test dataset.

    The GPFA class is compatible with the cross-validation functions of
    `sklearn.model_selection`, so users can run cross-validation with these
    functions to search for the set of parameters that yields the best
    performance.

    Parameters
    ----------
    x_dim : int, optional
        state dimensionality
        Default: 3
    bin_size : pq.Quantity, optional
        spike bin width
        Default: 20 * pq.ms
    min_var_frac : float, optional
        fraction of overall data variance for each observed dimension to set as
        the private variance floor. This is used to combat Heywood cases,
        where ML parameter learning returns one or more zero private variances.
        Default: 0.01
        (See Martin & McDonald, Psychometrika, Dec 1975.)
    em_tol : float, optional
        stopping criterion for EM
        Default: 1e-8
    em_max_iters : int, optional
        maximum number of EM iterations to run
        Default: 500
    tau_init : pq.Quantity, optional
        GP timescale initialization
        Default: 100 * pq.ms
    eps_init : float, optional
        GP noise variance initialization
        Default: 1e-3
    freq_ll : int, optional
        the data likelihood is computed every `freq_ll` EM iterations;
        `freq_ll = 1` means that it is computed at every iteration.
        Default: 5
    verbose : bool, optional
        specifies whether to display status messages
        Default: False

    Attributes
    ----------
    valid_data_names : tuple of str
        Names of the data contained in the resultant data structure, used to
        check the validity of users' requests.
    has_spikes_bool : np.ndarray of bool
        Indicates if a neuron has any spikes across trials of the training
        data.
    params_estimated : dict
        Estimated model parameters. Updated at each run of the fit() method.

        covType : str
            type of GP covariance, either 'rbf', 'tri', or 'logexp'.
            Currently, only 'rbf' is supported.
        gamma : (1, #latent_vars) np.ndarray
            related to the GP timescales of the latent variables before
            orthonormalization via :math:`tau = bin_size / sqrt(gamma)`
        eps : (1, #latent_vars) np.ndarray
            GP noise variances
        d : (#units, 1) np.ndarray
            observation mean
        C : (#units, #latent_vars) np.ndarray
            loading matrix, representing the mapping between the neuronal data
            space and the latent variable space
        R : (#units, #units) np.ndarray
            observation noise covariance

    fit_info : dict
        Information of the fitting process. Updated at each run of the fit()
        method.

        iteration_time : list
            containing the runtime for each iteration step in the EM algorithm.
        log_likelihoods : list
            log likelihoods after each EM iteration.

    transform_info : dict
        Information of the transforming process. Updated at each run of the
        transform() method.

        log_likelihood : float
            log-likelihood of the transformed data under the fitted model
        num_bins : np.ndarray
            number of bins in each trial
        Corth : (#units, #latent_vars) np.ndarray
            mapping between the neuronal data space and the orthonormal
            latent variable space

    Methods
    -------
    fit
    transform
    fit_transform
    score

    Raises
    ------
    ValueError
        If `bin_size` or `tau_init` is not a `pq.Quantity`.

    Examples
    --------
    In the following example, we calculate the neural trajectories of 20
    independent Poisson spike trains recorded in 50 trials with randomized
    rates below 100 Hz.

    >>> import numpy as np
    >>> import quantities as pq
    >>> from elephant.gpfa import GPFA
    >>> from elephant.spike_train_generation import homogeneous_poisson_process
    >>> data = []
    >>> for trial in range(50):
    ...     n_channels = 20
    ...     firing_rates = np.random.randint(low=1, high=100,
    ...                                      size=n_channels) * pq.Hz
    ...     spike_times = [homogeneous_poisson_process(rate=rate)
    ...                    for rate in firing_rates]
    ...     data.append(spike_times)
    ...
    >>> gpfa = GPFA(bin_size=20*pq.ms, x_dim=8)
    >>> gpfa.fit(data)
    >>> results = gpfa.transform(data, returned_data=['latent_variable_orth',
    ...                                               'latent_variable'])
    >>> latent_variable_orth = results['latent_variable_orth']
    >>> latent_variable = results['latent_variable']

    or simply

    >>> results = GPFA(bin_size=20*pq.ms, x_dim=8).fit_transform(data,
    ...                returned_data=['latent_variable_orth',
    ...                               'latent_variable'])
    """

    @deprecated_alias(binsize='bin_size')
    def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
                 tau_init=100.0 * pq.ms, eps_init=1.0E-3, em_tol=1.0E-8,
                 em_max_iters=500, freq_ll=5, verbose=False):
        self.bin_size = bin_size
        self.x_dim = x_dim
        self.min_var_frac = min_var_frac
        self.tau_init = tau_init
        self.eps_init = eps_init
        self.em_tol = em_tol
        self.em_max_iters = em_max_iters
        self.freq_ll = freq_ll
        self.valid_data_names = (
            'latent_variable_orth',
            'latent_variable',
            'Vsm',
            'VsmGP',
            'y')
        self.verbose = verbose

        if not isinstance(self.bin_size, pq.Quantity):
            raise ValueError("'bin_size' must be of type pq.Quantity")
        if not isinstance(self.tau_init, pq.Quantity):
            raise ValueError("'tau_init' must be of type pq.Quantity")

        # will be updated later
        self.params_estimated = dict()
        self.fit_info = dict()
        self.transform_info = dict()

    @property
    def binsize(self):
        warnings.warn("'binsize' is deprecated; use 'bin_size'")
        return self.bin_size

    def fit(self, spiketrains):
        """
        Fit the model with the given training data.

        Parameters
        ----------
        spiketrains : list of list of neo.SpikeTrain
            Spike train data to be fit to latent variables.
            The outer list corresponds to trials and the inner list corresponds
            to the neurons recorded in that trial, such that
            `spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
            Note that the number and order of `neo.SpikeTrain` objects per
            trial must be fixed such that `spiketrains[l][n]` and
            `spiketrains[k][n]` refer to spike trains of the same neuron
            for any choice of `l`, `k`, and `n`.

        Returns
        -------
        self : object
            Returns the instance itself.

        Raises
        ------
        ValueError
            If `spiketrains` is an empty list.

            If `spiketrains[0][0]` is not a `neo.SpikeTrain`.

            If the covariance matrix of the input spike data is rank deficient.
        """
        self._check_training_data(spiketrains)
        seqs_train = self._format_training_data(spiketrains)
        # Check if training data covariance is full rank
        y_all = np.hstack(seqs_train['y'])
        y_dim = y_all.shape[0]

        if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
            errmesg = 'Observation covariance matrix is rank deficient.\n' \
                      'Possible causes: ' \
                      'repeated units, not enough observations.'
            raise ValueError(errmesg)

        if self.verbose:
            print('Number of training trials: {}'.format(len(seqs_train)))
            print('Latent space dimensionality: {}'.format(self.x_dim))
            print('Observation dimensionality: {}'.format(
                self.has_spikes_bool.sum()))

        # The following does the heavy lifting.
        self.params_estimated, self.fit_info = gpfa_core.fit(
            seqs_train=seqs_train,
            x_dim=self.x_dim,
            bin_width=self.bin_size.rescale('ms').magnitude,
            min_var_frac=self.min_var_frac,
            em_max_iters=self.em_max_iters,
            em_tol=self.em_tol,
            tau_init=self.tau_init.rescale('ms').magnitude,
            eps_init=self.eps_init,
            freq_ll=self.freq_ll,
            verbose=self.verbose)

        return self

    @staticmethod
    def _check_training_data(spiketrains):
        if len(spiketrains) == 0:
            raise ValueError("Input spiketrains cannot be empty")
        if not isinstance(spiketrains[0][0], neo.SpikeTrain):
            raise ValueError("structure of the spiketrains is not correct: "
                             "0-axis should be trials, 1-axis neo.SpikeTrain "
                             "and 2-axis spike times")

    def _format_training_data(self, spiketrains):
        seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
        # Remove inactive units based on training set
        self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1)
        for seq in seqs:
            seq['y'] = seq['y'][self.has_spikes_bool, :]
        return seqs

    def transform(self, spiketrains, returned_data=['latent_variable_orth']):
        """
        Obtain trajectories of neural activity in a low-dimensional latent
        variable space by inferring the posterior mean of the latent variables
        under the fitted GPFA model and orthonormalizing the latent variable
        space.

        Parameters
        ----------
        spiketrains : list of list of neo.SpikeTrain
            Spike train data to be transformed to latent variables.
            The outer list corresponds to trials and the inner list corresponds
            to the neurons recorded in that trial, such that
            `spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
            Note that the number and order of `neo.SpikeTrain` objects per
            trial must be fixed such that `spiketrains[l][n]` and
            `spiketrains[k][n]` refer to spike trains of the same neuron
            for any choice of `l`, `k`, and `n`.
        returned_data : list of str
            The dimensionality reduction transform generates the following
            resultant data:

               'latent_variable_orth': orthonormalized posterior mean of latent
               variable

               'latent_variable': posterior mean of latent variable before
               orthonormalization

               'Vsm': posterior covariance between latent variables

               'VsmGP': posterior covariance over time for each latent variable

               'y': binned spike counts of the given `spiketrains`, restricted
               to the units that were active in the training data

            `returned_data` specifies the keys by which the data dict is
            returned.

            Default is ['latent_variable_orth'].

        Returns
        -------
        np.ndarray or dict
            When the length of `returned_data` is one, a single np.ndarray,
            containing the requested data (the first entry in `returned_data`
            keys list), is returned. Otherwise, a dict of multiple np.ndarrays
            with the keys identical to the data names in `returned_data` is
            returned.

            The n-th entry of each np.ndarray is a np.ndarray of the following
            shape, specific to each data type, containing the corresponding
            data for the n-th trial:

                `latent_variable_orth`: (#latent_vars, #bins) np.ndarray

                `latent_variable`: (#latent_vars, #bins) np.ndarray

                `y`: (#units, #bins) np.ndarray

                `Vsm`: (#latent_vars, #latent_vars, #bins) np.ndarray

                `VsmGP`: (#bins, #bins, #latent_vars) np.ndarray

            Note that the number of bins (#bins) can vary across trials,
            reflecting the trial durations in the given `spiketrains` data.

        Raises
        ------
        ValueError
            If the number of neurons in `spiketrains` is different from that
            in the training spiketrain data.

            If `returned_data` contains keys different from the ones in
            `self.valid_data_names`.
        """
        if len(spiketrains[0]) != len(self.has_spikes_bool):
            raise ValueError("'spiketrains' must contain the same number of "
                             "neurons as the training spiketrain data")
        invalid_keys = set(returned_data).difference(self.valid_data_names)
        if len(invalid_keys) > 0:
            raise ValueError("'returned_data' can only have the following "
                             "entries: {}".format(self.valid_data_names))
        # Bin the data and drop units that were silent in the training data
        seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
        for seq in seqs:
            seq['y'] = seq['y'][self.has_spikes_bool, :]
        # Posterior inference and log-likelihood under the fitted model
        seqs, ll = gpfa_core.exact_inference_with_ll(seqs,
                                                     self.params_estimated,
                                                     get_ll=True)
        self.transform_info['log_likelihood'] = ll
        self.transform_info['num_bins'] = seqs['T']
        # Orthonormalize the loading matrix and the latent trajectories
        Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs)
        self.transform_info['Corth'] = Corth
        if len(returned_data) == 1:
            return seqs[returned_data[0]]
        return {x: seqs[x] for x in returned_data}

    def fit_transform(self, spiketrains, returned_data=[
            'latent_variable_orth']):
        """
        Fit the model with `spiketrains` data and apply the dimensionality
        reduction on `spiketrains`.

        Parameters
        ----------
        spiketrains : list of list of neo.SpikeTrain
            Refer to the :func:`GPFA.fit` docstring.
        returned_data : list of str
            Refer to the :func:`GPFA.transform` docstring.

        Returns
        -------
        np.ndarray or dict
            Refer to the :func:`GPFA.transform` docstring.

        Raises
        ------
        ValueError
            Refer to :func:`GPFA.fit` and :func:`GPFA.transform`.

        See Also
        --------
        GPFA.fit : fit the model with `spiketrains`
        GPFA.transform : transform `spiketrains` into trajectories

        """
        self.fit(spiketrains)
        return self.transform(spiketrains, returned_data=returned_data)

    def score(self, spiketrains):
        """
        Returns the log-likelihood of the given data under the fitted model.

        Parameters
        ----------
        spiketrains : list of list of neo.SpikeTrain
            Spike train data to be scored.
            The outer list corresponds to trials and the inner list corresponds
            to the neurons recorded in that trial, such that
            `spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
            Note that the number and order of `neo.SpikeTrain` objects per
            trial must be fixed such that `spiketrains[l][n]` and
            `spiketrains[k][n]` refer to spike trains of the same neuron
            for any choice of `l`, `k`, and `n`.

        Returns
        -------
        log_likelihood : float
            Log-likelihood of the given spiketrains under the fitted model.
        """
        self.transform(spiketrains)
        return self.transform_info['log_likelihood']