spike_train_generation.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970
  1. # -*- coding: utf-8 -*-
  2. """
  3. Functions to generate spike trains from analog signals,
  4. or to generate random spike trains.
  5. Some functions are based on the NeuroTools stgen module, which was mostly
  6. written by Eilif Muller, or from the NeuroTools signals.analogs module.
  7. :copyright: Copyright 2015 by the Elephant team, see AUTHORS.txt.
  8. :license: Modified BSD, see LICENSE.txt for details.
  9. """
  10. from __future__ import division
  11. import numpy as np
  12. from quantities import ms, mV, Hz, Quantity, dimensionless
  13. from neo import SpikeTrain
  14. import random
  15. from elephant.spike_train_surrogates import dither_spike_train
  16. import warnings
  17. def spike_extraction(signal, threshold=0.0 * mV, sign='above',
  18. time_stamps=None, extr_interval=(-2 * ms, 4 * ms)):
  19. """
  20. Return the peak times for all events that cross threshold and the
  21. waveforms. Usually used for extracting spikes from a membrane
  22. potential to calculate waveform properties.
  23. Similar to spike_train_generation.peak_detection.
  24. Parameters
  25. ----------
  26. signal : neo AnalogSignal object
  27. 'signal' is an analog signal.
  28. threshold : A quantity, e.g. in mV
  29. 'threshold' contains a value that must be reached for an event
  30. to be detected. Default: 0.0 * mV.
  31. sign : 'above' or 'below'
  32. 'sign' determines whether to count thresholding crossings
  33. that cross above or below the threshold. Default: 'above'.
  34. time_stamps: None, quantity array or Object with .times interface
  35. if 'spike_train' is a quantity array or exposes a quantity array
  36. exposes the .times interface, it provides the time_stamps
  37. around which the waveform is extracted. If it is None, the
  38. function peak_detection is used to calculate the time_stamps
  39. from signal. Default: None.
  40. extr_interval: unpackable time quantities, len == 2
  41. 'extr_interval' specifies the time interval around the
  42. time_stamps where the waveform is extracted. The default is an
  43. interval of '6 ms'. Default: (-2 * ms, 4 * ms).
  44. Returns
  45. -------
  46. result_st : neo SpikeTrain object
  47. 'result_st' contains the time_stamps of each of the spikes and
  48. the waveforms in result_st.waveforms.
  49. """
  50. # Get spike time_stamps
  51. if time_stamps is None:
  52. time_stamps = peak_detection(signal, threshold, sign=sign)
  53. elif hasattr(time_stamps, 'times'):
  54. time_stamps = time_stamps.times
  55. elif type(time_stamps) is Quantity:
  56. raise TypeError("time_stamps must be None, a quantity array or" +
  57. " expose the.times interface")
  58. if len(time_stamps) == 0:
  59. return SpikeTrain(time_stamps, units=signal.times.units,
  60. t_start=signal.t_start, t_stop=signal.t_stop,
  61. waveforms=np.array([]),
  62. sampling_rate=signal.sampling_rate)
  63. # Unpack the extraction interval from tuple or array
  64. extr_left, extr_right = extr_interval
  65. if extr_left > extr_right:
  66. raise ValueError("extr_interval[0] must be < extr_interval[1]")
  67. if any(np.diff(time_stamps) < extr_interval[1]):
  68. warnings.warn("Waveforms overlap.", UserWarning)
  69. data_left = ((extr_left * signal.sampling_rate).simplified).magnitude
  70. data_right = ((extr_right * signal.sampling_rate).simplified).magnitude
  71. data_stamps = (((time_stamps - signal.t_start) *
  72. signal.sampling_rate).simplified).magnitude
  73. data_stamps = data_stamps.astype(int)
  74. borders_left = data_stamps + data_left
  75. borders_right = data_stamps + data_right
  76. borders = np.dstack((borders_left, borders_right)).flatten()
  77. waveforms = np.array(
  78. np.split(np.array(signal), borders.astype(int))[1::2]) * signal.units
  79. # len(np.shape(waveforms)) == 1 if waveforms do not have the same width.
  80. # this can occur when extr_interval indexes beyond the signal.
  81. # Workaround: delete spikes shorter than the maximum length with
  82. if len(np.shape(waveforms)) == 1:
  83. max_len = (np.array([len(x) for x in waveforms])).max()
  84. to_delete = np.array([idx for idx, x in enumerate(waveforms)
  85. if len(x) < max_len])
  86. waveforms = np.delete(waveforms, to_delete, axis=0)
  87. waveforms = np.array([x for x in waveforms])
  88. warnings.warn("Waveforms " +
  89. ("{:d}, " * len(to_delete)).format(*to_delete) +
  90. "exceeded signal and had to be deleted. " +
  91. "Change extr_interval to keep.")
  92. waveforms = waveforms[:, np.newaxis, :]
  93. return SpikeTrain(time_stamps, units=signal.times.units,
  94. t_start=signal.t_start, t_stop=signal.t_stop,
  95. sampling_rate=signal.sampling_rate, waveforms=waveforms,
  96. left_sweep=extr_left)
  97. def threshold_detection(signal, threshold=0.0 * mV, sign='above'):
  98. """
  99. Returns the times when the analog signal crosses a threshold.
  100. Usually used for extracting spike times from a membrane potential.
  101. Adapted from version in NeuroTools.
  102. Parameters
  103. ----------
  104. signal : neo AnalogSignal object
  105. 'signal' is an analog signal.
  106. threshold : A quantity, e.g. in mV
  107. 'threshold' contains a value that must be reached
  108. for an event to be detected. Default: 0.0 * mV.
  109. sign : 'above' or 'below'
  110. 'sign' determines whether to count thresholding crossings
  111. that cross above or below the threshold.
  112. format : None or 'raw'
  113. Whether to return as SpikeTrain (None)
  114. or as a plain array of times ('raw').
  115. Returns
  116. -------
  117. result_st : neo SpikeTrain object
  118. 'result_st' contains the spike times of each of the events (spikes)
  119. extracted from the signal.
  120. """
  121. assert threshold is not None, "A threshold must be provided"
  122. if sign is 'above':
  123. cutout = np.where(signal > threshold)[0]
  124. elif sign in 'below':
  125. cutout = np.where(signal < threshold)[0]
  126. if len(cutout) <= 0:
  127. events = np.zeros(0)
  128. else:
  129. take = np.where(np.diff(cutout) > 1)[0] + 1
  130. take = np.append(0, take)
  131. time = signal.times
  132. events = time[cutout][take]
  133. events_base = events.base
  134. if events_base is None:
  135. # This occurs in some Python 3 builds due to some
  136. # bug in quantities.
  137. events_base = np.array([event.base for event in events]) # Workaround
  138. result_st = SpikeTrain(events_base, units=signal.times.units,
  139. t_start=signal.t_start, t_stop=signal.t_stop)
  140. return result_st
  141. def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
  142. """
  143. Return the peak times for all events that cross threshold.
  144. Usually used for extracting spike times from a membrane potential.
  145. Similar to spike_train_generation.threshold_detection.
  146. Parameters
  147. ----------
  148. signal : neo AnalogSignal object
  149. 'signal' is an analog signal.
  150. threshold : A quantity, e.g. in mV
  151. 'threshold' contains a value that must be reached
  152. for an event to be detected.
  153. sign : 'above' or 'below'
  154. 'sign' determines whether to count thresholding crossings that
  155. cross above or below the threshold. Default: 'above'.
  156. format : None or 'raw'
  157. Whether to return as SpikeTrain (None) or as a plain array
  158. of times ('raw'). Default: None.
  159. Returns
  160. -------
  161. result_st : neo SpikeTrain object
  162. 'result_st' contains the spike times of each of the events
  163. (spikes) extracted from the signal.
  164. """
  165. assert threshold is not None, "A threshold must be provided"
  166. if sign is 'above':
  167. cutout = np.where(signal > threshold)[0]
  168. peak_func = np.argmax
  169. elif sign in 'below':
  170. cutout = np.where(signal < threshold)[0]
  171. peak_func = np.argmin
  172. else:
  173. raise ValueError("sign must be 'above' or 'below'")
  174. if len(cutout) <= 0:
  175. events_base = np.zeros(0)
  176. else:
  177. # Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
  178. # This avoids empty slices
  179. border_start = np.where(np.diff(cutout) > 1)[0]
  180. border_end = border_start + 1
  181. borders = np.concatenate((border_start, border_end))
  182. borders = np.append(0, borders)
  183. borders = np.append(borders, len(cutout)-1)
  184. borders = np.sort(borders)
  185. true_borders = cutout[borders]
  186. right_borders = true_borders[1::2] + 1
  187. true_borders = np.sort(np.append(true_borders[0::2], right_borders))
  188. # Workaround for bug that occurs when signal goes below thr for 1 dtp,
  189. # Workaround eliminates empy slices from np. split
  190. backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0
  191. forward_mask = np.absolute(np.ediff1d(true_borders[::-1],
  192. to_begin=1)[::-1]) > 0
  193. true_borders = true_borders[backward_mask * forward_mask]
  194. split_signal = np.split(np.array(signal), true_borders)[1::2]
  195. maxima_idc_split = np.array([peak_func(x) for x in split_signal])
  196. max_idc = maxima_idc_split + true_borders[0::2]
  197. events = signal.times[max_idc]
  198. events_base = events.base
  199. if events_base is None:
  200. # This occurs in some Python 3 builds due to some
  201. # bug in quantities.
  202. events_base = np.array([event.base for event in events]) # Workaround
  203. if format is None:
  204. result_st = SpikeTrain(events_base, units=signal.times.units,
  205. t_start=signal.t_start, t_stop=signal.t_stop)
  206. elif 'raw':
  207. result_st = events_base
  208. else:
  209. raise ValueError("Format argument must be None or 'raw'")
  210. return result_st
  211. def _homogeneous_process(interval_generator, args, mean_rate, t_start, t_stop,
  212. as_array):
  213. """
  214. Returns a spike train whose spikes are a realization of a random process
  215. generated by the function `interval_generator` with the given rate,
  216. starting at time `t_start` and stopping `time t_stop`.
  217. """
  218. def rescale(x):
  219. return (x / mean_rate.units).rescale(t_stop.units)
  220. n = int(((t_stop - t_start) * mean_rate).simplified)
  221. number = np.ceil(n + 3 * np.sqrt(n))
  222. if number < 100:
  223. number = min(5 + np.ceil(2 * n), 100)
  224. assert number > 4 # if positive, number cannot be less than 5
  225. isi = rescale(interval_generator(*args, size=int(number)))
  226. spikes = np.cumsum(isi)
  227. spikes += t_start
  228. i = spikes.searchsorted(t_stop)
  229. if i == len(spikes):
  230. # ISI buffer overrun
  231. extra_spikes = []
  232. t_last = spikes[-1] + rescale(interval_generator(*args, size=1))[0]
  233. while t_last < t_stop:
  234. extra_spikes.append(t_last)
  235. t_last = t_last + rescale(interval_generator(*args, size=1))[0]
  236. # np.concatenate does not conserve units
  237. spikes = Quantity(
  238. np.concatenate(
  239. (spikes, extra_spikes)).magnitude, units=spikes.units)
  240. else:
  241. spikes = spikes[:i]
  242. if as_array:
  243. spikes = spikes.magnitude
  244. else:
  245. spikes = SpikeTrain(
  246. spikes, t_start=t_start, t_stop=t_stop, units=spikes.units)
  247. return spikes
  248. def homogeneous_poisson_process(rate, t_start=0.0 * ms, t_stop=1000.0 * ms,
  249. as_array=False):
  250. """
  251. Returns a spike train whose spikes are a realization of a Poisson process
  252. with the given rate, starting at time `t_start` and stopping time `t_stop`.
  253. All numerical values should be given as Quantities, e.g. 100*Hz.
  254. Parameters
  255. ----------
  256. rate : Quantity scalar with dimension 1/time
  257. The rate of the discharge.
  258. t_start : Quantity scalar with dimension time
  259. The beginning of the spike train.
  260. t_stop : Quantity scalar with dimension time
  261. The end of the spike train.
  262. as_array : bool
  263. If True, a NumPy array of sorted spikes is returned,
  264. rather than a SpikeTrain object.
  265. Raises
  266. ------
  267. ValueError : If `t_start` and `t_stop` are not of type `pq.Quantity`.
  268. Examples
  269. --------
  270. >>> from quantities import Hz, ms
  271. >>> spikes = homogeneous_poisson_process(50*Hz, 0*ms, 1000*ms)
  272. >>> spikes = homogeneous_poisson_process(
  273. 20*Hz, 5000*ms, 10000*ms, as_array=True)
  274. """
  275. if not isinstance(t_start, Quantity) or not isinstance(t_stop, Quantity):
  276. raise ValueError("t_start and t_stop must be of type pq.Quantity")
  277. rate = rate.rescale((1 / t_start).units)
  278. mean_interval = 1 / rate.magnitude
  279. return _homogeneous_process(
  280. np.random.exponential, (mean_interval,), rate, t_start, t_stop,
  281. as_array)
  282. def homogeneous_gamma_process(a, b, t_start=0.0 * ms, t_stop=1000.0 * ms,
  283. as_array=False):
  284. """
  285. Returns a spike train whose spikes are a realization of a gamma process
  286. with the given parameters, starting at time `t_start` and stopping time
  287. `t_stop` (average rate will be b/a).
  288. All numerical values should be given as Quantities, e.g. 100*Hz.
  289. Parameters
  290. ----------
  291. a : int or float
  292. The shape parameter of the gamma distribution.
  293. b : Quantity scalar with dimension 1/time
  294. The rate parameter of the gamma distribution.
  295. t_start : Quantity scalar with dimension time
  296. The beginning of the spike train.
  297. t_stop : Quantity scalar with dimension time
  298. The end of the spike train.
  299. as_array : bool
  300. If True, a NumPy array of sorted spikes is returned,
  301. rather than a SpikeTrain object.
  302. Raises
  303. ------
  304. ValueError : If `t_start` and `t_stop` are not of type `pq.Quantity`.
  305. Examples
  306. --------
  307. >>> from quantities import Hz, ms
  308. >>> spikes = homogeneous_gamma_process(2.0, 50*Hz, 0*ms, 1000*ms)
  309. >>> spikes = homogeneous_gamma_process(
  310. 5.0, 20*Hz, 5000*ms, 10000*ms, as_array=True)
  311. """
  312. if not isinstance(t_start, Quantity) or not isinstance(t_stop, Quantity):
  313. raise ValueError("t_start and t_stop must be of type pq.Quantity")
  314. b = b.rescale((1 / t_start).units).simplified
  315. rate = b / a
  316. k, theta = a, (1 / b.magnitude)
  317. return _homogeneous_process(np.random.gamma, (k, theta), rate, t_start, t_stop, as_array)
  318. def _n_poisson(rate, t_stop, t_start=0.0 * ms, n=1):
  319. """
  320. Generates one or more independent Poisson spike trains.
  321. Parameters
  322. ----------
  323. rate : Quantity or Quantity array
  324. Expected firing rate (frequency) of each output SpikeTrain.
  325. Can be one of:
  326. * a single Quantity value: expected firing rate of each output
  327. SpikeTrain
  328. * a Quantity array: rate[i] is the expected firing rate of the i-th
  329. output SpikeTrain
  330. t_stop : Quantity
  331. Single common stop time of each output SpikeTrain. Must be > t_start.
  332. t_start : Quantity (optional)
  333. Single common start time of each output SpikeTrain. Must be < t_stop.
  334. Default: 0 s.
  335. n: int (optional)
  336. If rate is a single Quantity value, n specifies the number of
  337. SpikeTrains to be generated. If rate is an array, n is ignored and the
  338. number of SpikeTrains is equal to len(rate).
  339. Default: 1
  340. Returns
  341. -------
  342. list of neo.SpikeTrain
  343. Each SpikeTrain contains one of the independent Poisson spike trains,
  344. either n SpikeTrains of the same rate, or len(rate) SpikeTrains with
  345. varying rates according to the rate parameter. The time unit of the
  346. SpikeTrains is given by t_stop.
  347. """
  348. # Check that the provided input is Hertz of return error
  349. try:
  350. for r in rate.reshape(-1, 1):
  351. r.rescale('Hz')
  352. except AttributeError:
  353. raise ValueError('rate argument must have rate unit (1/time)')
  354. # Check t_start < t_stop and create their strip dimensions
  355. if not t_start < t_stop:
  356. raise ValueError(
  357. 't_start (=%s) must be < t_stop (=%s)' % (t_start, t_stop))
  358. # Set number n of output spike trains (specified or set to len(rate))
  359. if not (type(n) == int and n > 0):
  360. raise ValueError('n (=%s) must be a positive integer' % str(n))
  361. rate_dl = rate.simplified.magnitude.flatten()
  362. # Check rate input parameter
  363. if len(rate_dl) == 1:
  364. if rate_dl < 0:
  365. raise ValueError('rate (=%s) must be non-negative.' % rate)
  366. rates = np.array([rate_dl] * n)
  367. else:
  368. rates = rate_dl.flatten()
  369. if any(rates < 0):
  370. raise ValueError('rate must have non-negative elements.')
  371. sts = []
  372. for r in rates:
  373. sts.append(homogeneous_poisson_process(r * Hz, t_start, t_stop))
  374. return sts
  375. def single_interaction_process(
  376. rate, rate_c, t_stop, n=2, jitter=0 * ms, coincidences='deterministic',
  377. t_start=0 * ms, min_delay=0 * ms, return_coinc=False):
  378. """
  379. Generates a multidimensional Poisson SIP (single interaction process)
  380. plus independent Poisson processes
  381. A Poisson SIP consists of Poisson time series which are independent
  382. except for simultaneous events in all of them. This routine generates
  383. a SIP plus additional parallel independent Poisson processes.
  384. See [1].
  385. Parameters
  386. -----------
  387. t_stop: quantities.Quantity
  388. Total time of the simulated processes. The events are drawn between
  389. 0 and `t_stop`.
  390. rate: quantities.Quantity
  391. Overall mean rate of the time series to be generated (coincidence
  392. rate `rate_c` is subtracted to determine the background rate). Can be:
  393. * a float, representing the overall mean rate of each process. If
  394. so, it must be higher than `rate_c`.
  395. * an iterable of floats (one float per process), each float
  396. representing the overall mean rate of a process. If so, all the
  397. entries must be larger than `rate_c`.
  398. rate_c: quantities.Quantity
  399. Coincidence rate (rate of coincidences for the n-dimensional SIP).
  400. The SIP spike trains will have coincident events with rate `rate_c`
  401. plus independent 'background' events with rate `rate-rate_c`.
  402. n: int, optional
  403. If `rate` is a single Quantity value, `n` specifies the number of
  404. SpikeTrains to be generated. If rate is an array, `n` is ignored and
  405. the number of SpikeTrains is equal to `len(rate)`.
  406. Default: 1
  407. jitter: quantities.Quantity, optional
  408. Jitter for the coincident events. If `jitter == 0`, the events of all
  409. n correlated processes are exactly coincident. Otherwise, they are
  410. jittered around a common time randomly, up to +/- `jitter`.
  411. coincidences: string, optional
  412. Whether the total number of injected coincidences must be determin-
  413. istic (i.e. rate_c is the actual rate with which coincidences are
  414. generated) or stochastic (i.e. rate_c is the mean rate of coincid-
  415. ences):
  416. * 'deterministic': deterministic rate
  417. * 'stochastic': stochastic rate
  418. Default: 'deterministic'
  419. t_start: quantities.Quantity, optional
  420. Starting time of the series. If specified, it must be lower than
  421. t_stop
  422. Default: 0 * ms
  423. min_delay: quantities.Quantity, optional
  424. Minimum delay between consecutive coincidence times.
  425. Default: 0 * ms
  426. return_coinc: bool, optional
  427. Whether to return the coincidence times for the SIP process
  428. Default: False
  429. Returns
  430. --------
  431. output: list
  432. Realization of a SIP consisting of n Poisson processes characterized
  433. by synchronous events (with the given jitter)
  434. If `return_coinc` is `True`, the coincidence times are returned as a
  435. second output argument. They also have an associated time unit (same
  436. as `t_stop`).
  437. References
  438. ----------
  439. [1] Kuhn, Aertsen, Rotter (2003) Neural Comput 15(1):67-101
  440. EXAMPLE:
  441. >>> import quantities as qt
  442. >>> import jelephant.core.stocmod as sm
  443. >>> sip, coinc = sm.sip_poisson(n=10, n=0, t_stop=1*qt.sec, \
  444. rate=20*qt.Hz, rate_c=4, return_coinc = True)
  445. *************************************************************************
  446. """
  447. # Check if n is a positive integer
  448. if not (isinstance(n, int) and n > 0):
  449. raise ValueError('n (=%s) must be a positive integer' % str(n))
  450. # Assign time unit to jitter, or check that its existing unit is a time
  451. # unit
  452. jitter = abs(jitter)
  453. # Define the array of rates from input argument rate. Check that its length
  454. # matches with n
  455. if rate.ndim == 0:
  456. if rate < 0 * Hz:
  457. raise ValueError(
  458. 'rate (=%s) must be non-negative.' % str(rate))
  459. rates_b = np.array(
  460. [rate.magnitude for _ in range(n)]) * rate.units
  461. else:
  462. rates_b = np.array(rate).flatten() * rate.units
  463. if not all(rates_b >= 0. * Hz):
  464. raise ValueError('*rate* must have non-negative elements')
  465. # Check: rate>=rate_c
  466. if np.any(rates_b < rate_c):
  467. raise ValueError('all elements of *rate* must be >= *rate_c*')
  468. # Check min_delay < 1./rate_c
  469. if not (rate_c == 0 * Hz or min_delay < 1. / rate_c):
  470. raise ValueError(
  471. "'*min_delay* (%s) must be lower than 1/*rate_c* (%s)." %
  472. (str(min_delay), str((1. / rate_c).rescale(min_delay.units))))
  473. # Generate the n Poisson processes there are the basis for the SIP
  474. # (coincidences still lacking)
  475. embedded_poisson_trains = _n_poisson(
  476. rate=rates_b - rate_c, t_stop=t_stop, t_start=t_start)
  477. # Convert the trains from neo SpikeTrain objects to simpler Quantity
  478. # objects
  479. embedded_poisson_trains = [
  480. emb.view(Quantity) for emb in embedded_poisson_trains]
  481. # Generate the array of times for coincident events in SIP, not closer than
  482. # min_delay. The array is generated as a quantity from the Quantity class
  483. # in the quantities module
  484. if coincidences == 'deterministic':
  485. Nr_coinc = int(((t_stop - t_start) * rate_c).rescale(dimensionless))
  486. while True:
  487. coinc_times = t_start + \
  488. np.sort(np.random.random(Nr_coinc)) * (t_stop - t_start)
  489. if len(coinc_times) < 2 or min(np.diff(coinc_times)) >= min_delay:
  490. break
  491. elif coincidences == 'stochastic':
  492. while True:
  493. coinc_times = homogeneous_poisson_process(
  494. rate=rate_c, t_stop=t_stop, t_start=t_start)
  495. if len(coinc_times) < 2 or min(np.diff(coinc_times)) >= min_delay:
  496. break
  497. # Convert coinc_times from a neo SpikeTrain object to a Quantity object
  498. # pq.Quantity(coinc_times.base)*coinc_times.units
  499. coinc_times = coinc_times.view(Quantity)
  500. # Set the coincidence times to T-jitter if larger. This ensures that
  501. # the last jittered spike time is <T
  502. for i in range(len(coinc_times)):
  503. if coinc_times[i] > t_stop - jitter:
  504. coinc_times[i] = t_stop - jitter
  505. # Replicate coinc_times n times, and jitter each event in each array by
  506. # +/- jitter (within (t_start, t_stop))
  507. embedded_coinc = coinc_times + \
  508. np.random.random(
  509. (len(rates_b), len(coinc_times))) * 2 * jitter - jitter
  510. embedded_coinc = embedded_coinc + \
  511. (t_start - embedded_coinc) * (embedded_coinc < t_start) - \
  512. (t_stop - embedded_coinc) * (embedded_coinc > t_stop)
  513. # Inject coincident events into the n SIP processes generated above, and
  514. # merge with the n independent processes
  515. sip_process = [
  516. np.sort(np.concatenate((
  517. embedded_poisson_trains[m].rescale(t_stop.units),
  518. embedded_coinc[m].rescale(t_stop.units))) * t_stop.units)
  519. for m in range(len(rates_b))]
  520. # Convert back sip_process and coinc_times from Quantity objects to
  521. # neo.SpikeTrain objects
  522. sip_process = [
  523. SpikeTrain(t, t_start=t_start, t_stop=t_stop).rescale(t_stop.units)
  524. for t in sip_process]
  525. coinc_times = [
  526. SpikeTrain(t, t_start=t_start, t_stop=t_stop).rescale(t_stop.units)
  527. for t in embedded_coinc]
  528. # Return the processes in the specified output_format
  529. if not return_coinc:
  530. output = sip_process
  531. else:
  532. output = sip_process, coinc_times
  533. return output
  534. def _pool_two_spiketrains(a, b, extremes='inner'):
  535. """
  536. Pool the spikes of two spike trains a and b into a unique spike train.
  537. Parameters
  538. ----------
  539. a, b : neo.SpikeTrains
  540. Spike trains to be pooled
  541. extremes: str, optional
  542. Only spikes of a and b in the specified extremes are considered.
  543. * 'inner': pool all spikes from max(a.tstart_ b.t_start) to
  544. min(a.t_stop, b.t_stop)
  545. * 'outer': pool all spikes from min(a.tstart_ b.t_start) to
  546. max(a.t_stop, b.t_stop)
  547. Default: 'inner'
  548. Output
  549. ------
  550. neo.SpikeTrain containing all spikes in a and b falling in the
  551. specified extremes
  552. """
  553. unit = a.units
  554. times_a_dimless = list(a.view(Quantity).magnitude)
  555. times_b_dimless = list(b.rescale(unit).view(Quantity).magnitude)
  556. times = (times_a_dimless + times_b_dimless) * unit
  557. if extremes == 'outer':
  558. t_start = min(a.t_start, b.t_start)
  559. t_stop = max(a.t_stop, b.t_stop)
  560. elif extremes == 'inner':
  561. t_start = max(a.t_start, b.t_start)
  562. t_stop = min(a.t_stop, b.t_stop)
  563. times = times[times > t_start]
  564. times = times[times < t_stop]
  565. else:
  566. raise ValueError(
  567. 'extremes (%s) can only be "inner" or "outer"' % extremes)
  568. pooled_train = SpikeTrain(
  569. times=sorted(times.magnitude), units=unit, t_start=t_start,
  570. t_stop=t_stop)
  571. return pooled_train
  572. def _pool_spiketrains(trains, extremes='inner'):
  573. """
  574. Pool spikes from any number of spike trains into a unique spike train.
  575. Parameters
  576. ----------
  577. trains: list
  578. list of spike trains to merge
  579. extremes: str, optional
  580. Only spikes of a and b in the specified extremes are considered.
  581. * 'inner': pool all spikes from min(a.t_start b.t_start) to
  582. max(a.t_stop, b.t_stop)
  583. * 'outer': pool all spikes from max(a.tstart_ b.t_start) to
  584. min(a.t_stop, b.t_stop)
  585. Default: 'inner'
  586. Output
  587. ------
  588. neo.SpikeTrain containing all spikes in trains falling in the
  589. specified extremes
  590. """
  591. merge_trains = trains[0]
  592. for t in trains[1:]:
  593. merge_trains = _pool_two_spiketrains(
  594. merge_trains, t, extremes=extremes)
  595. t_start, t_stop = merge_trains.t_start, merge_trains.t_stop
  596. merge_trains = sorted(merge_trains)
  597. merge_trains = np.squeeze(merge_trains)
  598. merge_trains = SpikeTrain(
  599. merge_trains, t_stop=t_stop, t_start=t_start, units=trains[0].units)
  600. return merge_trains
  601. def _sample_int_from_pdf(a, n):
  602. """
  603. Draw n independent samples from the set {0,1,...,L}, where L=len(a)-1,
  604. according to the probability distribution a.
  605. a[j] is the probability to sample j, for each j from 0 to L.
  606. Parameters
  607. -----
  608. a: numpy.array
  609. Probability vector (i..e array of sum 1) that at each entry j carries
  610. the probability to sample j (j=0,1,...,len(a)-1).
  611. n: int
  612. Number of samples generated with the function
  613. Output
  614. -------
  615. array of n samples taking values between 0 and n=len(a)-1.
  616. """
  617. A = np.cumsum(a) # cumulative distribution of a
  618. u = np.random.uniform(0, 1, size=n)
  619. U = np.array([u for i in a]).T # copy u (as column vector) len(a) times
  620. return (A < U).sum(axis=1)
  621. def _mother_proc_cpp_stat(A, t_stop, rate, t_start=0 * ms):
  622. """
  623. Generate the hidden ("mother") Poisson process for a Compound Poisson
  624. Process (CPP).
  625. Parameters
  626. ----------
  627. A : numpy.array
  628. Amplitude distribution. A[j] represents the probability of a
  629. synchronous event of size j.
  630. The sum over all entries of a must be equal to one.
  631. t_stop : quantities.Quantity
  632. The stopping time of the mother process
  633. rate : quantities.Quantity
  634. Homogeneous rate of the n spike trains that will be genereted by the
  635. CPP function
  636. t_start : quantities.Quantity, optional
  637. The starting time of the mother process
  638. Default: 0 ms
  639. Output
  640. ------
  641. Poisson spike train representing the mother process generating the CPP
  642. """
  643. N = len(A) - 1
  644. exp_A = np.dot(A, range(N + 1)) # expected value of a
  645. exp_mother = (N * rate) / float(exp_A) # rate of the mother process
  646. return homogeneous_poisson_process(
  647. rate=exp_mother, t_stop=t_stop, t_start=t_start)
  648. def _cpp_hom_stat(A, t_stop, rate, t_start=0 * ms):
  649. """
  650. Generate a Compound Poisson Process (CPP) with amplitude distribution
  651. A and heterogeneous firing rates r=r[0], r[1], ..., r[-1].
  652. Parameters
  653. ----------
  654. A : numpy.ndarray
  655. Amplitude distribution. A[j] represents the probability of a
  656. synchronous event of size j.
  657. The sum over all entries of A must be equal to one.
  658. t_stop : quantities.Quantity
  659. The end time of the output spike trains
  660. rate : quantities.Quantity
  661. Average rate of each spike train generated
  662. t_start : quantities.Quantity, optional
  663. The start time of the output spike trains
  664. Default: 0 ms
  665. Output
  666. ------
  667. List of n neo.SpikeTrains, having average firing rate r and correlated
  668. such to form a CPP with amplitude distribution a
  669. """
  670. # Generate mother process and associated spike labels
  671. mother = _mother_proc_cpp_stat(
  672. A=A, t_stop=t_stop, rate=rate, t_start=t_start)
  673. labels = _sample_int_from_pdf(A, len(mother))
  674. N = len(A) - 1 # Number of trains in output
  675. try: # Faster but more memory-consuming approach
  676. M = len(mother) # number of spikes in the mother process
  677. spike_matrix = np.zeros((N, M), dtype=bool)
  678. # for each spike, take its label l
  679. for spike_id, l in enumerate(labels):
  680. # choose l random trains
  681. train_ids = random.sample(range(N), l)
  682. # and set the spike matrix for that train
  683. for train_id in train_ids:
  684. spike_matrix[train_id, spike_id] = True # and spike to True
  685. times = [[] for i in range(N)]
  686. for train_id, row in enumerate(spike_matrix):
  687. times[train_id] = mother[row].view(Quantity)
  688. except MemoryError: # Slower (~2x) but less memory-consuming approach
  689. print('memory case')
  690. times = [[] for i in range(N)]
  691. for t, l in zip(mother, labels):
  692. train_ids = random.sample(range(N), l)
  693. for train_id in train_ids:
  694. times[train_id].append(t)
  695. trains = [SpikeTrain(
  696. times=t, t_start=t_start, t_stop=t_stop) for t in times]
  697. return trains
  698. def _cpp_het_stat(A, t_stop, rate, t_start=0. * ms):
  699. """
  700. Generate a Compound Poisson Process (CPP) with amplitude distribution
  701. A and heterogeneous firing rates r=r[0], r[1], ..., r[-1].
  702. Parameters
  703. ----------
  704. A : array
  705. CPP's amplitude distribution. A[j] represents the probability of
  706. a synchronous event of size j among the generated spike trains.
  707. The sum over all entries of A must be equal to one.
  708. t_stop : Quantity (time)
  709. The end time of the output spike trains
  710. rate : Quantity (1/time)
  711. Average rate of each spike train generated
  712. t_start : quantities.Quantity, optional
  713. The start time of the output spike trains
  714. Default: 0 ms
  715. Output
  716. ------
  717. List of neo.SpikeTrains with different firing rates, forming
  718. a CPP with amplitude distribution A
  719. """
  720. # Computation of Parameters of the two CPPs that will be merged
  721. # (uncorrelated with heterog. rates + correlated with homog. rates)
  722. N = len(rate) # number of output spike trains
  723. A_exp = np.dot(A, range(N + 1)) # expectation of A
  724. r_sum = np.sum(rate) # sum of all output firing rates
  725. r_min = np.min(rate) # minimum of the firing rates
  726. r1 = r_sum - N * r_min # rate of the uncorrelated CPP
  727. r2 = r_sum / float(A_exp) - r1 # rate of the correlated CPP
  728. r_mother = r1 + r2 # rate of the hidden mother process
  729. # Check the analytical constraint for the amplitude distribution
  730. if A[1] < (r1 / r_mother).rescale(dimensionless).magnitude:
  731. raise ValueError('A[1] too small / A[i], i>1 too high')
  732. # Compute the amplitude distrib of the correlated CPP, and generate it
  733. a = [(r_mother * i) / float(r2) for i in A]
  734. a[1] = a[1] - r1 / float(r2)
  735. CPP = _cpp_hom_stat(a, t_stop, r_min, t_start)
  736. # Generate the independent heterogeneous Poisson processes
  737. POISS = [
  738. homogeneous_poisson_process(i - r_min, t_start, t_stop) for i in rate]
  739. # Pool the correlated CPP and the corresponding Poisson processes
  740. out = [_pool_two_spiketrains(CPP[i], POISS[i]) for i in range(N)]
  741. return out
  742. def compound_poisson_process(rate, A, t_stop, shift=None, t_start=0 * ms):
  743. """
  744. Generate a Compound Poisson Process (CPP; see [1]) with a given amplitude
  745. distribution A and stationary marginal rates r.
  746. The CPP process is a model for parallel, correlated processes with Poisson
  747. spiking statistics at pre-defined firing rates. It is composed of len(A)-1
  748. spike trains with a correlation structure determined by the amplitude
  749. distribution A: A[j] is the probability that a spike occurs synchronously
  750. in any j spike trains.
  751. The CPP is generated by creating a hidden mother Poisson process, and then
  752. copying spikes of the mother process to j of the output spike trains with
  753. probability A[j].
  754. Note that this function decorrelates the firing rate of each SpikeTrain
  755. from the probability for that SpikeTrain to participate in a synchronous
  756. event (which is uniform across SpikeTrains).
  757. Parameters
  758. ----------
  759. rate : quantities.Quantity
  760. Average rate of each spike train generated. Can be:
  761. - a single value, all spike trains will have same rate rate
  762. - an array of values (of length len(A)-1), each indicating the
  763. firing rate of one process in output
  764. A : array
  765. CPP's amplitude distribution. A[j] represents the probability of
  766. a synchronous event of size j among the generated spike trains.
  767. The sum over all entries of A must be equal to one.
  768. t_stop : quantities.Quantity
  769. The end time of the output spike trains.
  770. shift : None or quantities.Quantity, optional
  771. If None, the injected synchrony is exact. If shift is a Quantity, all
  772. the spike trains are shifted independently by a random amount in
  773. the interval [-shift, +shift].
  774. Default: None
  775. t_start : quantities.Quantity, optional
  776. The t_start time of the output spike trains.
  777. Default: 0 s
  778. Returns
  779. -------
  780. List of neo.SpikeTrains
  781. SpikeTrains with specified firing rates forming the CPP with amplitude
  782. distribution A.
  783. References
  784. ----------
  785. [1] Staude, Rotter, Gruen (2010) J Comput Neurosci 29:327-350.
  786. """
  787. # Check A is a probability distribution (it sums to 1 and is positive)
  788. if abs(sum(A) - 1) > np.finfo('float').eps:
  789. raise ValueError(
  790. 'A must be a probability vector, sum(A)= %f !=1' % (sum(A)))
  791. if any([a < 0 for a in A]):
  792. raise ValueError(
  793. 'A must be a probability vector, all the elements of must be >0')
  794. # Check that the rate is not an empty Quantity
  795. if rate.ndim == 1 and len(rate.magnitude) == 0:
  796. raise ValueError('Rate is an empty Quantity array')
  797. # Return empty spike trains for specific parameters
  798. elif A[0] == 1 or np.sum(np.abs(rate.magnitude)) == 0:
  799. return [
  800. SpikeTrain([] * t_stop.units, t_stop=t_stop,
  801. t_start=t_start) for i in range(len(A) - 1)]
  802. else:
  803. # Homogeneous rates
  804. if rate.ndim == 0:
  805. cpp = _cpp_hom_stat(A=A, t_stop=t_stop, rate=rate, t_start=t_start)
  806. # Heterogeneous rates
  807. else:
  808. cpp = _cpp_het_stat(A=A, t_stop=t_stop, rate=rate, t_start=t_start)
  809. if shift is None:
  810. return cpp
  811. # Dither the output spiketrains
  812. else:
  813. cpp = [
  814. dither_spike_train(cp, shift=shift, edges=True)[0]
  815. for cp in cpp]
  816. return cpp
# Short alias for `compound_poisson_process`: `cpp` and
# `compound_poisson_process` refer to the same function object.
cpp = compound_poisson_process