sta.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # -*- coding: utf-8 -*-
  2. '''
  3. Functions to calculate spike-triggered average and spike-field coherence of
  4. analog signals.
  5. :copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
  6. :license: Modified BSD, see LICENSE.txt for details.
  7. '''
  8. from __future__ import division
  9. import numpy as np
  10. import scipy.signal
  11. import quantities as pq
  12. from neo.core import AnalogSignal, SpikeTrain
  13. import warnings
  14. from .conversion import BinnedSpikeTrain
  15. def spike_triggered_average(signal, spiketrains, window):
  16. """
  17. Calculates the spike-triggered averages of analog signals in a time window
  18. relative to the spike times of a corresponding spiketrain for multiple
  19. signals each. The function receives n analog signals and either one or
  20. n spiketrains. In case it is one spiketrain this one is muliplied n-fold
  21. and used for each of the n analog signals.
  22. Parameters
  23. ----------
  24. signal : neo AnalogSignal object
  25. 'signal' contains n analog signals.
  26. spiketrains : one SpikeTrain or one numpy ndarray or a list of n of either of these.
  27. 'spiketrains' contains the times of the spikes in the spiketrains.
  28. window : tuple of 2 Quantity objects with dimensions of time.
  29. 'window' is the start time and the stop time, relative to a spike, of
  30. the time interval for signal averaging.
  31. If the window size is not a multiple of the sampling interval of the
  32. signal the window will be extended to the next multiple.
  33. Returns
  34. -------
  35. result_sta : neo AnalogSignal object
  36. 'result_sta' contains the spike-triggered averages of each of the
  37. analog signals with respect to the spikes in the corresponding
  38. spiketrains. The length of 'result_sta' is calculated as the number
  39. of bins from the given start and stop time of the averaging interval
  40. and the sampling rate of the analog signal. If for an analog signal
  41. no spike was either given or all given spikes had to be ignored
  42. because of a too large averaging interval, the corresponding returned
  43. analog signal has all entries as nan. The number of used spikes and
  44. unused spikes for each analog signal are returned as annotations to
  45. the returned AnalogSignal object.
  46. Examples
  47. --------
  48. >>> signal = neo.AnalogSignal(np.array([signal1, signal2]).T, units='mV',
  49. ... sampling_rate=10/ms)
  50. >>> stavg = spike_triggered_average(signal, [spiketrain1, spiketrain2],
  51. ... (-5 * ms, 10 * ms))
  52. """
  53. # checking compatibility of data and data types
  54. # window_starttime: time to specify the start time of the averaging
  55. # interval relative to a spike
  56. # window_stoptime: time to specify the stop time of the averaging
  57. # interval relative to a spike
  58. window_starttime, window_stoptime = window
  59. if not (isinstance(window_starttime, pq.quantity.Quantity) and
  60. window_starttime.dimensionality.simplified ==
  61. pq.Quantity(1, "s").dimensionality):
  62. raise TypeError("The start time of the window (window[0]) "
  63. "must be a time quantity.")
  64. if not (isinstance(window_stoptime, pq.quantity.Quantity) and
  65. window_stoptime.dimensionality.simplified ==
  66. pq.Quantity(1, "s").dimensionality):
  67. raise TypeError("The stop time of the window (window[1]) "
  68. "must be a time quantity.")
  69. if window_stoptime <= window_starttime:
  70. raise ValueError("The start time of the window (window[0]) must be "
  71. "earlier than the stop time of the window (window[1]).")
  72. # checks on signal
  73. if not isinstance(signal, AnalogSignal):
  74. raise TypeError(
  75. "Signal must be an AnalogSignal, not %s." % type(signal))
  76. if len(signal.shape) > 1:
  77. # num_signals: number of analog signals
  78. num_signals = signal.shape[1]
  79. else:
  80. raise ValueError("Empty analog signal, hence no averaging possible.")
  81. if window_stoptime - window_starttime > signal.t_stop - signal.t_start:
  82. raise ValueError("The chosen time window is larger than the "
  83. "time duration of the signal.")
  84. # spiketrains type check
  85. if isinstance(spiketrains, (np.ndarray, SpikeTrain)):
  86. spiketrains = [spiketrains]
  87. elif isinstance(spiketrains, list):
  88. for st in spiketrains:
  89. if not isinstance(st, (np.ndarray, SpikeTrain)):
  90. raise TypeError(
  91. "spiketrains must be a SpikeTrain, a numpy ndarray, or a "
  92. "list of one of those, not %s." % type(spiketrains))
  93. else:
  94. raise TypeError(
  95. "spiketrains must be a SpikeTrain, a numpy ndarray, or a list of "
  96. "one of those, not %s." % type(spiketrains))
  97. # multiplying spiketrain in case only a single spiketrain is given
  98. if len(spiketrains) == 1 and num_signals != 1:
  99. template = spiketrains[0]
  100. spiketrains = []
  101. for i in range(num_signals):
  102. spiketrains.append(template)
  103. # checking for matching numbers of signals and spiketrains
  104. if num_signals != len(spiketrains):
  105. raise ValueError(
  106. "The number of signals and spiketrains has to be the same.")
  107. # checking the times of signal and spiketrains
  108. for i in range(num_signals):
  109. if spiketrains[i].t_start < signal.t_start:
  110. raise ValueError(
  111. "The spiketrain indexed by %i starts earlier than "
  112. "the analog signal." % i)
  113. if spiketrains[i].t_stop > signal.t_stop:
  114. raise ValueError(
  115. "The spiketrain indexed by %i stops later than "
  116. "the analog signal." % i)
  117. # *** Main algorithm: ***
  118. # window_bins: number of bins of the chosen averaging interval
  119. window_bins = int(np.ceil(((window_stoptime - window_starttime) *
  120. signal.sampling_rate).simplified))
  121. # result_sta: array containing finally the spike-triggered averaged signal
  122. result_sta = AnalogSignal(np.zeros((window_bins, num_signals)),
  123. sampling_rate=signal.sampling_rate, units=signal.units)
  124. # setting of correct times of the spike-triggered average
  125. # relative to the spike
  126. result_sta.t_start = window_starttime
  127. used_spikes = np.zeros(num_signals, dtype=int)
  128. unused_spikes = np.zeros(num_signals, dtype=int)
  129. total_used_spikes = 0
  130. for i in range(num_signals):
  131. # summing over all respective signal intervals around spiketimes
  132. for spiketime in spiketrains[i]:
  133. # checks for sufficient signal data around spiketime
  134. if (spiketime + window_starttime >= signal.t_start and
  135. spiketime + window_stoptime <= signal.t_stop):
  136. # calculating the startbin in the analog signal of the
  137. # averaging window for spike
  138. startbin = int(np.floor(((spiketime + window_starttime -
  139. signal.t_start) * signal.sampling_rate).simplified))
  140. # adds the signal in selected interval relative to the spike
  141. result_sta[:, i] += signal[
  142. startbin: startbin + window_bins, i]
  143. # counting of the used spikes
  144. used_spikes[i] += 1
  145. else:
  146. # counting of the unused spikes
  147. unused_spikes[i] += 1
  148. # normalization
  149. result_sta[:, i] = result_sta[:, i] / used_spikes[i]
  150. total_used_spikes += used_spikes[i]
  151. if total_used_spikes == 0:
  152. warnings.warn(
  153. "No spike at all was either found or used for averaging")
  154. result_sta.annotate(used_spikes=used_spikes, unused_spikes=unused_spikes)
  155. return result_sta
  156. def spike_field_coherence(signal, spiketrain, **kwargs):
  157. """
  158. Calculates the spike-field coherence between a analog signal(s) and a
  159. (binned) spike train.
  160. The current implementation makes use of scipy.signal.coherence(). Additional
  161. kwargs will will be directly forwarded to scipy.signal.coherence(),
  162. except for the axis parameter and the sampling frequency, which will be
  163. extracted from the input signals.
  164. The spike_field_coherence function receives an analog signal array and
  165. either a binned spike train or a spike train containing the original spike
  166. times. In case of original spike times the spike train is binned according
  167. to the sampling rate of the analog signal array.
  168. The AnalogSignal object can contain one or multiple signal traces. In case
  169. of multiple signal traces, the spike field coherence is calculated
  170. individually for each signal trace and the spike train.
  171. Parameters
  172. ----------
  173. signal : neo AnalogSignal object
  174. 'signal' contains n analog signals.
  175. spiketrain : SpikeTrain or BinnedSpikeTrain
  176. Single spike train to perform the analysis on. The binsize of the
  177. binned spike train must match the sampling_rate of signal.
  178. KWArgs
  179. ------
  180. All KWArgs are passed to scipy.signal.coherence().
  181. Returns
  182. -------
  183. coherence : complex Quantity array
  184. contains the coherence values calculated for each analog signal trace
  185. in combination with the spike train. The first dimension corresponds to
  186. the frequency, the second to the number of the signal trace.
  187. frequencies : Quantity array
  188. contains the frequency values corresponding to the first dimension of
  189. the 'coherence' array
  190. Example
  191. -------
  192. Plot the SFC between a regular spike train at 20 Hz, and two sinusoidal
  193. time series at 20 Hz and 23 Hz, respectively.
  194. >>> import numpy as np
  195. >>> import matplotlib.pyplot as plt
  196. >>> from quantities import s, ms, mV, Hz, kHz
  197. >>> import neo, elephant
  198. >>> t = pq.Quantity(range(10000),units='ms')
  199. >>> f1, f2 = 20. * Hz, 23. * Hz
  200. >>> signal = neo.AnalogSignal(np.array([
  201. np.sin(f1 * 2. * np.pi * t.rescale(s)),
  202. np.sin(f2 * 2. * np.pi * t.rescale(s))]).T,
  203. units=pq.mV, sampling_rate=1. * kHz)
  204. >>> spiketrain = neo.SpikeTrain(
  205. range(t[0], t[-1], 50), units='ms',
  206. t_start=t[0], t_stop=t[-1])
  207. >>> sfc, freqs = elephant.sta.spike_field_coherence(
  208. signal, spiketrain, window='boxcar')
  209. >>> plt.plot(freqs, sfc[:,0])
  210. >>> plt.plot(freqs, sfc[:,1])
  211. >>> plt.xlabel('Frequency [Hz]')
  212. >>> plt.ylabel('SFC')
  213. >>> plt.xlim((0, 60))
  214. >>> plt.show()
  215. """
  216. if not hasattr(scipy.signal, 'coherence'):
  217. raise AttributeError('scipy.signal.coherence is not available. The sfc '
  218. 'function uses scipy.signal.coherence for '
  219. 'the coherence calculation. This function is '
  220. 'available for scipy version 0.16 or newer. '
  221. 'Please update you scipy version.')
  222. # spiketrains type check
  223. if not isinstance(spiketrain, (SpikeTrain, BinnedSpikeTrain)):
  224. raise TypeError(
  225. "spiketrain must be of type SpikeTrain or BinnedSpikeTrain, "
  226. "not %s." % type(spiketrain))
  227. # checks on analogsignal
  228. if not isinstance(signal, AnalogSignal):
  229. raise TypeError(
  230. "Signal must be an AnalogSignal, not %s." % type(signal))
  231. if len(signal.shape) > 1:
  232. # num_signals: number of individual traces in the analog signal
  233. num_signals = signal.shape[1]
  234. elif len(signal.shape) == 1:
  235. num_signals = 1
  236. else:
  237. raise ValueError("Empty analog signal.")
  238. len_signals = signal.shape[0]
  239. # bin spiketrain if necessary
  240. if isinstance(spiketrain, SpikeTrain):
  241. spiketrain = BinnedSpikeTrain(
  242. spiketrain, binsize=signal.sampling_period)
  243. # check the start and stop times of signal and spike trains
  244. if spiketrain.t_start < signal.t_start:
  245. raise ValueError(
  246. "The spiketrain starts earlier than the analog signal.")
  247. if spiketrain.t_stop > signal.t_stop:
  248. raise ValueError(
  249. "The spiketrain stops later than the analog signal.")
  250. # check equal time resolution for both signals
  251. if spiketrain.binsize != signal.sampling_period:
  252. raise ValueError(
  253. "The spiketrain and signal must have a "
  254. "common sampling frequency / binsize")
  255. # calculate how many bins to add on the left of the binned spike train
  256. delta_t = spiketrain.t_start - signal.t_start
  257. if delta_t % spiketrain.binsize == 0:
  258. left_edge = int((delta_t / spiketrain.binsize).magnitude)
  259. else:
  260. raise ValueError("Incompatible binning of spike train and LFP")
  261. right_edge = int(left_edge + spiketrain.num_bins)
  262. # duplicate spike trains
  263. spiketrain_array = np.zeros((1, len_signals))
  264. spiketrain_array[0, left_edge:right_edge] = spiketrain.to_array()
  265. spiketrains_array = np.repeat(spiketrain_array, repeats=num_signals, axis=0).transpose()
  266. # calculate coherence
  267. frequencies, sfc = scipy.signal.coherence(
  268. spiketrains_array, signal.magnitude,
  269. fs=signal.sampling_rate.rescale('Hz').magnitude,
  270. axis=0, **kwargs)
  271. return (pq.Quantity(sfc, units=pq.dimensionless),
  272. pq.Quantity(frequencies, units=pq.Hz))