# -*- coding: utf-8 -*-
"""
This module allows converting standard data representations
(e.g., a spike train stored as a Neo SpikeTrain object)
into other representations useful for performing calculations on the data.
An example is the representation of a spike train as a sequence of 0-1 values
(binned spike train).

.. autosummary::
    :toctree: toctree/conversion

    BinnedSpikeTrain
    binarize

:copyright: Copyright 2014-2016 by the Elephant team, see `doc/authors.rst`.
:license: BSD, see LICENSE.txt for details.
"""
  15. from __future__ import division, print_function, unicode_literals
  16. import warnings
  17. import neo
  18. import numpy as np
  19. import quantities as pq
  20. import scipy.sparse as sps
  21. from elephant.utils import is_binary, deprecated_alias, \
  22. check_neo_consistency, get_common_start_stop_times
  23. __all__ = [
  24. "binarize",
  25. "BinnedSpikeTrain"
  26. ]
  27. def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None,
  28. return_times=False):
  29. """
  30. Return an array indicating if spikes occurred at individual time points.
  31. The array contains boolean values identifying whether at least one spike
  32. occurred in the corresponding time bin. Time bins start at `t_start`
  33. and end at `t_stop`, spaced in `1/sampling_rate` intervals.
  34. Accepts either a `neo.SpikeTrain`, a `pq.Quantity` array, or a plain
  35. `np.ndarray`.
  36. Returns a boolean array with each element indicating the presence or
  37. absence of a spike in that time bin.
  38. Optionally also returns an array of time points corresponding to the
  39. elements of the boolean array. The units of this array will be the same as
  40. the units of the neo.SpikeTrain, if any.
  41. Parameters
  42. ----------
  43. spiketrain : neo.SpikeTrain or pq.Quantity or np.ndarray
  44. The spike times. Does not have to be sorted.
  45. sampling_rate : float or pq.Quantity, optional
  46. The sampling rate to use for the time points.
  47. If not specified, retrieved from the `sampling_rate` attribute of
  48. `spiketrain`.
  49. Default: None.
  50. t_start : float or pq.Quantity, optional
  51. The start time to use for the time points.
  52. If not specified, retrieved from the `t_start` attribute of
  53. `spiketrain`. If this is not present, defaults to `0`. Any element of
  54. `spiketrain` lower than `t_start` is ignored.
  55. Default: None.
  56. t_stop : float or pq.Quantity, optional
  57. The stop time to use for the time points.
  58. If not specified, retrieved from the `t_stop` attribute of
  59. `spiketrain`. If this is not present, defaults to the maximum value of
  60. `spiketrain`. Any element of `spiketrain` higher than `t_stop` is
  61. ignored.
  62. Default: None.
  63. return_times : bool, optional
  64. If True, also return the corresponding time points.
  65. Default: False.
  66. Returns
  67. -------
  68. values : np.ndarray of bool
  69. A True value at a particular index indicates the presence of one or
  70. more spikes at the corresponding time point.
  71. times : np.ndarray or pq.Quantity, optional
  72. The time points. This will have the same units as `spiketrain`.
  73. If `spiketrain` has no units, this will be an `np.ndarray` array.
  74. Raises
  75. ------
  76. TypeError
  77. If `spiketrain` is an `np.ndarray` and `t_start`, `t_stop`, or
  78. `sampling_rate` is a `pq.Quantity`.
  79. ValueError
  80. If `sampling_rate` is not explicitly defined and not present as an
  81. attribute of `spiketrain`.
  82. Notes
  83. -----
  84. Spike times are placed in the bin of the closest time point, going to the
  85. higher bin if exactly between two bins.
  86. So in the case where the bins are `5.5` and `6.5`, with the spike time
  87. being `6.0`, the spike will be placed in the `6.5` bin.
  88. The upper edge of the last bin, equal to `t_stop`, is inclusive. That is,
  89. a spike time exactly equal to `t_stop` will be included.
  90. If `spiketrain` is a `pq.Quantity` or `neo.SpikeTrain` and `t_start`,
  91. `t_stop` or `sampling_rate` is not, then the arguments that are not
  92. `pq.Quantity` will be assumed to have the same units as `spiketrain`.
  93. """
  94. # get the values from spiketrain if they are not specified.
  95. if sampling_rate is None:
  96. sampling_rate = getattr(spiketrain, 'sampling_rate', None)
  97. if sampling_rate is None:
  98. raise ValueError('sampling_rate must either be explicitly defined '
  99. 'or must be an attribute of spiketrain')
  100. if t_start is None:
  101. t_start = getattr(spiketrain, 't_start', 0)
  102. if t_stop is None:
  103. t_stop = getattr(spiketrain, 't_stop', np.max(spiketrain))
  104. # we don't actually want the sampling rate, we want the sampling period
  105. sampling_period = 1. / sampling_rate
  106. # figure out what units, if any, we are dealing with
  107. if hasattr(spiketrain, 'units'):
  108. units = spiketrain.units
  109. spiketrain = spiketrain.magnitude
  110. else:
  111. units = None
  112. # convert everything to the same units, then get the magnitude
  113. if hasattr(sampling_period, 'units'):
  114. if units is None:
  115. raise TypeError('sampling_period cannot be a Quantity if '
  116. 'spiketrain is not a quantity')
  117. sampling_period = sampling_period.rescale(units).magnitude
  118. if hasattr(t_start, 'units'):
  119. if units is None:
  120. raise TypeError('t_start cannot be a Quantity if '
  121. 'spiketrain is not a quantity')
  122. t_start = t_start.rescale(units).magnitude
  123. if hasattr(t_stop, 'units'):
  124. if units is None:
  125. raise TypeError('t_stop cannot be a Quantity if '
  126. 'spiketrain is not a quantity')
  127. t_stop = t_stop.rescale(units).magnitude
  128. # figure out the bin edges
  129. edges = np.arange(t_start - sampling_period / 2,
  130. t_stop + sampling_period * 3 / 2,
  131. sampling_period)
  132. # we don't want to count any spikes before t_start or after t_stop
  133. if edges[-2] > t_stop:
  134. edges = edges[:-1]
  135. if edges[1] < t_start:
  136. edges = edges[1:]
  137. edges[0] = t_start
  138. edges[-1] = t_stop
  139. # this is where we actually get the binarized spike train
  140. res = np.histogram(spiketrain, edges)[0].astype('bool')
  141. # figure out what to output
  142. if not return_times:
  143. return res
  144. elif units is None:
  145. return res, np.arange(t_start, t_stop + sampling_period,
  146. sampling_period)
  147. else:
  148. return res, pq.Quantity(np.arange(t_start, t_stop + sampling_period,
  149. sampling_period), units=units)
  150. ###########################################################################
  151. #
  152. # Methods to calculate parameters, t_start, t_stop, bin size,
  153. # number of bins
  154. #
  155. ###########################################################################
  156. def _detect_rounding_errors(values, tolerance):
  157. """
  158. Finds rounding errors in values that will be cast to int afterwards.
  159. Returns True for values that are within tolerance of the next integer.
  160. Works for both scalars and numpy arrays.
  161. """
  162. if tolerance is None or tolerance == 0:
  163. return np.zeros_like(values, dtype=bool)
  164. # same as '1 - (values % 1) <= tolerance' but faster
  165. return 1 - tolerance <= values % 1
  166. class BinnedSpikeTrain(object):
  167. """
  168. Class which calculates a binned spike train and provides methods to
  169. transform the binned spike train to a boolean matrix or a matrix with
  170. counted time points.
  171. A binned spike train represents the occurrence of spikes in a certain time
  172. frame.
  173. I.e., a time series like [0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] is
  174. represented as [0, 0, 1, 3, 4, 5, 6]. The outcome is dependent on given
  175. parameter such as size of bins, number of bins, start and stop points.
  176. A boolean matrix represents the binned spike train in a binary (True/False)
  177. manner. Its rows represent the number of spike trains and the columns
  178. represent the binned index position of a spike in a spike train.
  179. The calculated matrix entry containing `True` indicates a spike.
  180. A matrix with counted time points is calculated the same way, but its
  181. entries contain the number of spikes that occurred in the given bin of the
  182. given spike train.
  183. Note that with most common parameter combinations spike times can end up
  184. on bin edges. This makes the binning susceptible to rounding errors which
  185. is accounted for by moving spikes which are within tolerance of the next
  186. bin edge into the following bin. This can be adjusted using the tolerance
  187. parameter and turned off by setting `tolerance=None`.
  188. Parameters
  189. ----------
  190. spiketrains : neo.SpikeTrain or list of neo.SpikeTrain or np.ndarray
  191. Spike train(s) to be binned.
  192. bin_size : pq.Quantity, optional
  193. Width of a time bin.
  194. Default: None
  195. n_bins : int, optional
  196. Number of bins of the binned spike train.
  197. Default: None
  198. t_start : pq.Quantity, optional
  199. Time of the left edge of the first bin (left extreme; included).
  200. Default: None
  201. t_stop : pq.Quantity, optional
  202. Time of the right edge of the last bin (right extreme; excluded).
  203. Default: None
  204. tolerance : float, optional
  205. Tolerance for rounding errors in the binning process and in the input
  206. data
  207. Default: 1e-8
  208. Raises
  209. ------
  210. AttributeError
  211. If less than 3 optional parameters are `None`.
  212. TypeError
  213. If `spiketrains` is an np.ndarray with dimensionality different than
  214. NxM or
  215. if type of `n_bins` is not an `int` or `n_bins` < 0.
  216. ValueError
  217. When number of bins calculated from `t_start`, `t_stop` and `bin_size`
  218. differs from provided `n_bins` or
  219. if `t_stop` of any spike train is smaller than any `t_start` or
  220. if any spike train does not cover the full [`t_start`, t_stop`] range.
  221. Warns
  222. -----
  223. UserWarning
  224. If some spikes fall outside of [`t_start`, `t_stop`] range
  225. See also
  226. --------
  227. _convert_to_binned
  228. spike_indices
  229. to_bool_array
  230. to_array
  231. Notes
  232. -----
  233. There are four minimal configurations of the optional parameters which have
  234. to be provided, otherwise a `ValueError` will be raised:
  235. * `t_start`, `n_bins`, `bin_size`
  236. * `t_start`, `n_bins`, `t_stop`
  237. * `t_start`, `bin_size`, `t_stop`
  238. * `t_stop`, `n_bins`, `bin_size`
  239. If `spiketrains` is a `neo.SpikeTrain` or a list thereof, it is enough to
  240. explicitly provide only one parameter: `n_bins` or `bin_size`. The
  241. `t_start` and `t_stop` will be calculated from given `spiketrains` (max
  242. `t_start` and min `t_stop` of `neo.SpikeTrain`s).
  243. Missing parameter will be calculated automatically.
  244. All parameters will be checked for consistency. A corresponding error will
  245. be raised, if one of the four parameters does not match the consistency
  246. requirements.
  247. """
  248. @deprecated_alias(binsize='bin_size', num_bins='n_bins')
  249. def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None,
  250. t_stop=None, tolerance=1e-8):
  251. # Converting spiketrains to a list, if spiketrains is one
  252. # SpikeTrain object
  253. if isinstance(spiketrains, neo.SpikeTrain):
  254. spiketrains = [spiketrains]
  255. # Set given parameters
  256. self._t_start = t_start
  257. self._t_stop = t_stop
  258. self.n_bins = n_bins
  259. self._bin_size = bin_size
  260. self.units = None # will be set later
  261. # Check all parameter, set also missing values
  262. self._resolve_input_parameters(spiketrains, tolerance=tolerance)
  263. # Now create the sparse matrix
  264. self.sparse_matrix = self._create_sparse_matrix(spiketrains,
  265. tolerance=tolerance)
  266. @property
  267. def shape(self):
  268. return self.sparse_matrix.shape
  269. @property
  270. def bin_size(self):
  271. return pq.Quantity(self._bin_size, units=self.units, copy=False)
  272. @property
  273. def t_start(self):
  274. return pq.Quantity(self._t_start, units=self.units, copy=False)
  275. @property
  276. def t_stop(self):
  277. return pq.Quantity(self._t_stop, units=self.units, copy=False)
  278. @property
  279. def binsize(self):
  280. warnings.warn("'.binsize' is deprecated; use '.bin_size'",
  281. DeprecationWarning)
  282. return self._bin_size
  283. @property
  284. def num_bins(self):
  285. warnings.warn("'.num_bins' is deprecated; use '.n_bins'")
  286. return self.n_bins
  287. def __repr__(self):
  288. return "{klass}(t_start={t_start}, t_stop={t_stop}, " \
  289. "bin_size={bin_size}; shape={shape})".format(
  290. klass=type(self).__name__,
  291. t_start=self.t_start,
  292. t_stop=self.t_stop,
  293. bin_size=self.bin_size,
  294. shape=self.shape)
  295. def rescale(self, units):
  296. """
  297. Inplace rescaling to the new quantity units.
  298. Parameters
  299. ----------
  300. units : pq.Quantity or str
  301. New quantity units.
  302. Raises
  303. ------
  304. TypeError
  305. If the input units are not quantities.
  306. """
  307. if isinstance(units, str):
  308. units = pq.Quantity(1, units=units)
  309. if units == self.units:
  310. # do nothing
  311. return
  312. if not isinstance(units, pq.Quantity):
  313. raise TypeError("The input units must be quantities or string")
  314. scale = self.units.rescale(units).item()
  315. self._t_stop *= scale
  316. self._t_start *= scale
  317. self._bin_size *= scale
  318. self.units = units
  319. def __resolve_binned(self, spiketrains):
  320. spiketrains = np.asarray(spiketrains)
  321. if spiketrains.ndim != 2 or spiketrains.dtype == np.dtype('O'):
  322. raise ValueError("If the input is not a spiketrain(s), it "
  323. "must be an MxN numpy array, each cell of "
  324. "which represents the number of (binned) "
  325. "spikes that fall in an interval - not "
  326. "raw spike times.")
  327. if self.n_bins is not None:
  328. raise ValueError("When the input is a binned matrix, 'n_bins' "
  329. "must be set to None - it's extracted from the "
  330. "input shape.")
  331. self.n_bins = spiketrains.shape[1]
  332. if self._bin_size is None:
  333. if self._t_start is None or self._t_stop is None:
  334. raise ValueError("To determine the bin size, both 't_start' "
  335. "and 't_stop' must be set")
  336. self._bin_size = (self._t_stop - self._t_start) / self.n_bins
  337. if self._t_start is None and self._t_stop is None:
  338. raise ValueError("Either 't_start' or 't_stop' must be set")
  339. if self._t_start is None:
  340. self._t_start = self._t_stop - self._bin_size * self.n_bins
  341. if self._t_stop is None:
  342. self._t_stop = self._t_start + self._bin_size * self.n_bins
  343. def _resolve_input_parameters(self, spiketrains, tolerance):
  344. """
  345. Calculates `t_start`, `t_stop` from given spike trains.
  346. The start and stop points are calculated from given spike trains only
  347. if they are not calculable from given parameters or the number of
  348. parameters is less than three.
  349. Parameters
  350. ----------
  351. spiketrains : neo.SpikeTrain or list or np.ndarray of neo.SpikeTrain
  352. """
  353. def get_n_bins():
  354. n_bins = (self._t_stop - self._t_start) / self._bin_size
  355. if isinstance(n_bins, pq.Quantity):
  356. n_bins = n_bins.simplified.item()
  357. if _detect_rounding_errors(n_bins, tolerance=tolerance):
  358. warnings.warn('Correcting a rounding error in the calculation '
  359. 'of n_bins by increasing n_bins by 1. '
  360. 'You can set tolerance=None to disable this '
  361. 'behaviour.')
  362. return int(n_bins)
  363. def check_n_bins_consistency():
  364. if self.n_bins != get_n_bins():
  365. raise ValueError(
  366. "Inconsistent arguments: t_start ({t_start}), "
  367. "t_stop ({t_stop}), bin_size ({bin_size}), and "
  368. "n_bins ({n_bins})".format(
  369. t_start=self.t_start, t_stop=self.t_stop,
  370. bin_size=self.bin_size, n_bins=self.n_bins))
  371. def check_consistency():
  372. if self.t_start >= self.t_stop:
  373. raise ValueError("t_start must be smaller than t_stop")
  374. if not isinstance(self.n_bins, int) or self.n_bins <= 0:
  375. raise TypeError("The number of bins ({}) must be a positive "
  376. "integer".format(self.n_bins))
  377. if not _check_neo_spiketrain(spiketrains):
  378. # a binned numpy matrix
  379. self.__resolve_binned(spiketrains)
  380. self.units = self._bin_size.units
  381. check_n_bins_consistency()
  382. check_consistency()
  383. self._t_start = self._t_start.rescale(self.units).item()
  384. self._t_stop = self._t_stop.rescale(self.units).item()
  385. self._bin_size = self._bin_size.rescale(self.units).item()
  386. return
  387. if self._bin_size is None and self.n_bins is None:
  388. raise ValueError("Either 'bin_size' or 'n_bins' must be given")
  389. try:
  390. check_neo_consistency(spiketrains,
  391. object_type=neo.SpikeTrain,
  392. t_start=self._t_start,
  393. t_stop=self._t_stop,
  394. tolerance=tolerance)
  395. except ValueError as er:
  396. # different t_start/t_stop
  397. raise ValueError(er, "If you want to bin over the shared "
  398. "[t_start, t_stop] interval, provide "
  399. "shared t_start and t_stop explicitly, "
  400. "which can be obtained like so: "
  401. "t_start, t_stop = elephant.utils."
  402. "get_common_start_stop_times(spiketrains)"
  403. )
  404. if self._t_start is None:
  405. self._t_start = spiketrains[0].t_start
  406. if self._t_stop is None:
  407. self._t_stop = spiketrains[0].t_stop
  408. # At this point, all spiketrains share the same units.
  409. self.units = spiketrains[0].units
  410. try:
  411. self._t_start = self._t_start.rescale(self.units).item()
  412. self._t_stop = self._t_stop.rescale(self.units).item()
  413. except AttributeError:
  414. raise ValueError("'t_start' and 't_stop' must be quantities")
  415. start_shared, stop_shared = get_common_start_stop_times(spiketrains)
  416. start_shared = start_shared.rescale(self.units).item()
  417. stop_shared = stop_shared.rescale(self.units).item()
  418. if tolerance is None:
  419. tolerance = 0
  420. if self._t_start < start_shared - tolerance \
  421. or self._t_stop > stop_shared + tolerance:
  422. raise ValueError("'t_start' ({t_start}) or 't_stop' ({t_stop}) is "
  423. "outside of the shared [{start_shared}, "
  424. "{stop_shared}] interval".format(
  425. t_start=self.t_start, t_stop=self.t_stop,
  426. start_shared=start_shared,
  427. stop_shared=stop_shared))
  428. if self.n_bins is None:
  429. # bin_size is provided
  430. self._bin_size = self._bin_size.rescale(self.units).item()
  431. self.n_bins = get_n_bins()
  432. elif self._bin_size is None:
  433. # n_bins is provided
  434. self._bin_size = (self._t_stop - self._t_start) / self.n_bins
  435. else:
  436. # both n_bins are bin_size are given
  437. self._bin_size = self._bin_size.rescale(self.units).item()
  438. check_n_bins_consistency()
  439. check_consistency()
  440. @property
  441. def bin_edges(self):
  442. """
  443. Returns all time edges as a quantity array with :attr:`n_bins` bins.
  444. The borders of all time steps between :attr:`t_start` and
  445. :attr:`t_stop` with a step :attr:`bin_size`. It is crucial for many
  446. analyses that all bins have the same size, so if
  447. :attr:`t_stop` - :attr:`t_start` is not divisible by :attr:`bin_size`,
  448. there will be some leftover time at the end
  449. (see https://github.com/NeuralEnsemble/elephant/issues/255).
  450. The length of the returned array should match :attr:`n_bins`.
  451. Returns
  452. -------
  453. bin_edges : pq.Quantity
  454. All edges in interval [:attr:`t_start`, :attr:`t_stop`] with
  455. :attr:`n_bins` bins are returned as a quantity array.
  456. """
  457. bin_edges = np.linspace(self._t_start, self._t_start + self.n_bins *
  458. self._bin_size,
  459. num=self.n_bins + 1, endpoint=True)
  460. return pq.Quantity(bin_edges, units=self.units, copy=False)
  461. @property
  462. def bin_centers(self):
  463. """
  464. Returns each center time point of all bins between :attr:`t_start` and
  465. :attr:`t_stop` points.
  466. The center of each bin of all time steps between start and stop.
  467. Returns
  468. -------
  469. bin_edges : pq.Quantity
  470. All center edges in interval (:attr:`start`, :attr:`stop`).
  471. """
  472. start = self._t_start + self._bin_size / 2
  473. stop = start + (self.n_bins - 1) * self._bin_size
  474. bin_centers = np.linspace(start=start,
  475. stop=stop,
  476. num=self.n_bins, endpoint=True)
  477. bin_centers = pq.Quantity(bin_centers, units=self.units, copy=False)
  478. return bin_centers
  479. def to_sparse_array(self):
  480. """
  481. Getter for sparse matrix with time points.
  482. Returns
  483. -------
  484. scipy.sparse.csr_matrix
  485. Sparse matrix, version with spike counts.
  486. See also
  487. --------
  488. scipy.sparse.csr_matrix
  489. to_array
  490. """
  491. warnings.warn("'.to_sparse_array()' function is deprecated; "
  492. "use '.sparse_matrix' attribute directly",
  493. DeprecationWarning)
  494. return self.sparse_matrix
  495. def to_sparse_bool_array(self):
  496. """
  497. Getter for boolean version of the sparse matrix, calculated from
  498. sparse matrix with counted time points.
  499. Returns
  500. -------
  501. scipy.sparse.csr_matrix
  502. Sparse matrix, binary, boolean version.
  503. See also
  504. --------
  505. scipy.sparse.csr_matrix
  506. to_bool_array
  507. """
  508. # Return sparse Matrix as a copy
  509. spmat_copy = self.sparse_matrix.copy()
  510. spmat_copy.data = spmat_copy.data.astype(bool)
  511. return spmat_copy
  512. def get_num_of_spikes(self, axis=None):
  513. """
  514. Compute the number of binned spikes.
  515. Parameters
  516. ----------
  517. axis : int, optional
  518. If `None`, compute the total num. of spikes.
  519. Otherwise, compute num. of spikes along axis.
  520. If axis is `1`, compute num. of spikes per spike train (row).
  521. Default is `None`.
  522. Returns
  523. -------
  524. n_spikes_per_row : int or np.ndarray
  525. The number of binned spikes.
  526. """
  527. if axis is None:
  528. return self.sparse_matrix.sum(axis=axis)
  529. n_spikes_per_row = self.sparse_matrix.sum(axis=axis)
  530. n_spikes_per_row = np.ravel(n_spikes_per_row)
  531. return n_spikes_per_row
  532. @property
  533. def spike_indices(self):
  534. """
  535. A list of lists for each spike train (i.e., rows of the binned matrix),
  536. that in turn contains for each spike the index into the binned matrix
  537. where this spike enters.
  538. In contrast to `to_sparse_array().nonzero()`, this function will report
  539. two spikes falling in the same bin as two entries.
  540. Examples
  541. --------
  542. >>> import elephant.conversion as conv
  543. >>> import neo as n
  544. >>> import quantities as pq
  545. >>> st = n.SpikeTrain([0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s,
  546. ... t_stop=10.0 * pq.s)
  547. >>> x = conv.BinnedSpikeTrain(st, n_bins=10, bin_size=1 * pq.s,
  548. ... t_start=0 * pq.s)
  549. >>> print(x.spike_indices)
  550. [[0, 0, 1, 3, 4, 5, 6]]
  551. >>> print(x.sparse_matrix.nonzero()[1])
  552. [0 1 3 4 5 6]
  553. >>> print(x.to_array())
  554. [[2, 1, 0, 1, 1, 1, 1, 0, 0, 0]]
  555. """
  556. spike_idx = []
  557. for row in self.sparse_matrix:
  558. # Extract each non-zeros column index and how often it exists,
  559. # i.e., how many spikes fall in this column
  560. n_cols = np.repeat(row.indices, row.data)
  561. spike_idx.append(n_cols)
  562. return spike_idx
  563. @property
  564. def is_binary(self):
  565. """
  566. Checks and returns `True` if given input is a binary input.
  567. Beware, that the function does not know if the input is binary
  568. because e.g `to_bool_array()` was used before or if the input is just
  569. sparse (i.e. only one spike per bin at maximum).
  570. Returns
  571. -------
  572. bool
  573. True for binary input, False otherwise.
  574. """
  575. return is_binary(self.sparse_matrix.data)
  576. def to_bool_array(self):
  577. """
  578. Returns a matrix, in which the rows correspond to the spike trains and
  579. the columns correspond to the bins in the `BinnedSpikeTrain`.
  580. `True` indicates a spike in given bin of given spike train and
  581. `False` indicates lack of spikes.
  582. Returns
  583. -------
  584. numpy.ndarray
  585. Returns a dense matrix representation of the sparse matrix,
  586. with `True` indicating a spike and `False` indicating a no-spike.
  587. The columns represent the index position of the bins and rows
  588. represent the number of spike trains.
  589. See also
  590. --------
  591. scipy.sparse.csr_matrix
  592. scipy.sparse.csr_matrix.toarray
  593. Examples
  594. --------
  595. >>> import elephant.conversion as conv
  596. >>> import neo as n
  597. >>> import quantities as pq
  598. >>> a = n.SpikeTrain([0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s,
  599. ... t_stop=10.0 * pq.s)
  600. >>> x = conv.BinnedSpikeTrain(a, n_bins=10, bin_size=1 * pq.s,
  601. ... t_start=0 * pq.s)
  602. >>> print(x.to_bool_array())
  603. [[ True True False True True True True False False False]]
  604. """
  605. return self.to_array(dtype=bool)
  606. def to_array(self, dtype=None):
  607. """
  608. Returns a dense matrix, calculated from the sparse matrix, with counted
  609. time points of spikes. The rows correspond to spike trains and the
  610. columns correspond to bins in a `BinnedSpikeTrain`.
  611. Entries contain the count of spikes that occurred in the given bin of
  612. the given spike train.
  613. Returns
  614. -------
  615. matrix : np.ndarray
  616. Matrix with spike counts. Columns represent the index positions of
  617. the binned spikes and rows represent the spike trains.
  618. Examples
  619. --------
  620. >>> import elephant.conversion as conv
  621. >>> import neo as n
  622. >>> import quantities as pq
  623. >>> a = n.SpikeTrain([0.5, 0.7, 1.2, 3.1, 4.3, 5.5, 6.7] * pq.s,
  624. ... t_stop=10.0 * pq.s)
  625. >>> x = conv.BinnedSpikeTrain(a, n_bins=10, bin_size=1 * pq.s,
  626. ... t_start=0 * pq.s)
  627. >>> print(x.to_array())
  628. [[2 1 0 1 1 1 1 0 0 0]]
  629. See also
  630. --------
  631. scipy.sparse.csr_matrix
  632. scipy.sparse.csr_matrix.toarray
  633. """
  634. spmat = self.sparse_matrix
  635. if dtype is not None and dtype != spmat.data.dtype:
  636. # avoid a copy
  637. spmat = sps.csr_matrix(
  638. (spmat.data.astype(dtype), spmat.indices, spmat.indptr),
  639. shape=spmat.shape)
  640. return spmat.toarray()
  641. def binarize(self, copy=None):
  642. """
  643. Clip the internal array (no. of spikes in a bin) to `0` (no spikes) or
  644. `1` (at least one spike) values only.
  645. Parameters
  646. ----------
  647. copy : bool, optional
  648. Deprecated parameter. It has no effect.
  649. Returns
  650. -------
  651. bst : _BinnedSpikeTrainView
  652. A view of `BinnedSpikeTrain` with a sparse matrix containing
  653. data clipped to `0`s and `1`s.
  654. """
  655. if copy is not None:
  656. warnings.warn("'copy' parameter is deprecated - a view is always "
  657. "returned; set this parameter to None.",
  658. DeprecationWarning)
  659. spmat = self.sparse_matrix
  660. spmat = sps.csr_matrix(
  661. (spmat.data.clip(max=1), spmat.indices, spmat.indptr),
  662. shape=spmat.shape, copy=False)
  663. bst = _BinnedSpikeTrainView(t_start=self._t_start,
  664. t_stop=self._t_stop,
  665. bin_size=self._bin_size,
  666. units=self.units,
  667. sparse_matrix=spmat)
  668. return bst
  669. @property
  670. def sparsity(self):
  671. """
  672. Returns
  673. -------
  674. float
  675. Matrix sparsity defined as no. of nonzero elements divided by
  676. the matrix size
  677. """
  678. num_nonzero = self.sparse_matrix.data.shape[0]
  679. return num_nonzero / np.prod(self.sparse_matrix.shape)
  680. def _create_sparse_matrix(self, spiketrains, tolerance):
  681. """
  682. Converts `neo.SpikeTrain` objects to a sparse matrix
  683. (`scipy.sparse.csr_matrix`), which contains the binned spike times, and
  684. stores it in :attr:`_sparse_mat_u`.
  685. Parameters
  686. ----------
  687. spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
  688. Spike trains to bin.
  689. """
  690. if not _check_neo_spiketrain(spiketrains):
  691. # a binned numpy array
  692. sparse_matrix = sps.csr_matrix(spiketrains, dtype=np.int32)
  693. return sparse_matrix
  694. row_ids, column_ids = [], []
  695. # data
  696. counts = []
  697. n_discarded = 0
  698. # all spiketrains carry the same units
  699. scale_units = 1 / self._bin_size
  700. for idx, st in enumerate(spiketrains):
  701. times = st.magnitude
  702. times = times[(times >= self._t_start) & (
  703. times <= self._t_stop)] - self._t_start
  704. bins = times * scale_units
  705. # shift spikes that are very close
  706. # to the right edge into the next bin
  707. rounding_error_indices = _detect_rounding_errors(
  708. bins, tolerance=tolerance)
  709. num_rounding_corrections = rounding_error_indices.sum()
  710. if num_rounding_corrections > 0:
  711. warnings.warn('Correcting {} rounding errors by shifting '
  712. 'the affected spikes into the following bin. '
  713. 'You can set tolerance=None to disable this '
  714. 'behaviour.'.format(num_rounding_corrections))
  715. bins[rounding_error_indices] += .5
  716. bins = bins.astype(np.int32)
  717. valid_bins = bins[bins < self.n_bins]
  718. n_discarded += len(bins) - len(valid_bins)
  719. f, c = np.unique(valid_bins, return_counts=True)
  720. column_ids.append(f)
  721. counts.append(c)
  722. row_ids.append(np.repeat(idx, repeats=len(f)))
  723. if n_discarded > 0:
  724. warnings.warn("Binning discarded {} last spike(s) of the "
  725. "input spiketrain".format(n_discarded))
  726. counts = np.hstack(counts)
  727. row_ids = np.hstack(row_ids)
  728. column_ids = np.hstack(column_ids)
  729. sparse_matrix = sps.csr_matrix((counts, (row_ids, column_ids)),
  730. shape=(len(spiketrains), self.n_bins),
  731. dtype=np.int32, copy=False)
  732. return sparse_matrix
  733. class _BinnedSpikeTrainView(BinnedSpikeTrain):
  734. # Experimental feature and should not be public now.
  735. def __init__(self, t_start, t_stop, bin_size, units, sparse_matrix):
  736. self._t_start = t_start
  737. self._t_stop = t_stop
  738. self._bin_size = bin_size
  739. self.n_bins = sparse_matrix.shape[1]
  740. self.units = units
  741. self.sparse_matrix = sparse_matrix
  742. def _check_neo_spiketrain(matrix):
  743. """
  744. Checks if given input contains neo.SpikeTrain objects
  745. Parameters
  746. ----------
  747. matrix
  748. Object to test for `neo.SpikeTrain`s
  749. Returns
  750. -------
  751. bool
  752. True if `matrix` is a neo.SpikeTrain or a list or tuple thereof,
  753. otherwise False.
  754. """
  755. # Check for single spike train
  756. if isinstance(matrix, neo.SpikeTrain):
  757. return True
  758. # Check for list or tuple
  759. if isinstance(matrix, (list, tuple)):
  760. return all(map(_check_neo_spiketrain, matrix))
  761. return False