spike_train_correlation.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. # -*- coding: utf-8 -*-
  2. """
  3. This module provides functions to calculate correlations between spike trains.
  4. :copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. from __future__ import division
  8. import numpy as np
  9. import neo
  10. import quantities as pq
  11. def covariance(binned_sts, binary=False):
  12. '''
  13. Calculate the NxN matrix of pairwise covariances between all combinations
  14. of N binned spike trains.
  15. For each pair of spike trains :math:`(i,j)`, the covariance :math:`C[i,j]`
  16. is obtained by binning :math:`i` and :math:`j` at the desired bin size. Let
  17. :math:`b_i` and :math:`b_j` denote the binary vectors and :math:`m_i` and
  18. :math:`m_j` their respective averages. Then
  19. .. math::
  20. C[i,j] = <b_i-m_i, b_j-m_j> / (l-1)
  21. where <..,.> is the scalar product of two vectors.
  22. For an input of n spike trains, a n x n matrix is returned containing the
  23. covariances for each combination of input spike trains.
  24. If binary is True, the binned spike trains are clipped to 0 or 1 before
  25. computing the covariance, so that the binned vectors :math:`b_i` and
  26. :math:`b_j` are binary.
  27. Parameters
  28. ----------
  29. binned_sts : elephant.conversion.BinnedSpikeTrain
  30. A binned spike train containing the spike trains to be evaluated.
  31. binary : bool, optional
  32. If True, two spikes of a particular spike train falling in the same bin
  33. are counted as 1, resulting in binary binned vectors :math:`b_i`. If
  34. False, the binned vectors :math:`b_i` contain the spike counts per bin.
  35. Default: False
  36. Returns
  37. -------
  38. C : ndarrray
  39. The square matrix of covariances. The element :math:`C[i,j]=C[j,i]` is
  40. the covariance between binned_sts[i] and binned_sts[j].
  41. Examples
  42. --------
  43. Generate two Poisson spike trains
  44. >>> from elephant.spike_train_generation import homogeneous_poisson_process
  45. >>> st1 = homogeneous_poisson_process(
  46. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  47. >>> st2 = homogeneous_poisson_process(
  48. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  49. Calculate the covariance matrix.
  50. >>> from elephant.conversion import BinnedSpikeTrain
  51. >>> cov_matrix = covariance(BinnedSpikeTrain([st1, st2], binsize=5*ms))
  52. The covariance between the spike trains is stored in cc_matrix[0,1] (or
  53. cov_matrix[1,0]).
  54. Notes
  55. -----
  56. * The spike trains in the binned structure are assumed to all cover the
  57. complete time span of binned_sts [t_start,t_stop).
  58. '''
  59. return __calculate_correlation_or_covariance(
  60. binned_sts, binary, corrcoef_norm=False)
  61. def corrcoef(binned_sts, binary=False):
  62. '''
  63. Calculate the NxN matrix of pairwise Pearson's correlation coefficients
  64. between all combinations of N binned spike trains.
  65. For each pair of spike trains :math:`(i,j)`, the correlation coefficient
  66. :math:`C[i,j]` is obtained by binning :math:`i` and :math:`j` at the
  67. desired bin size. Let :math:`b_i` and :math:`b_j` denote the binary vectors
  68. and :math:`m_i` and :math:`m_j` their respective averages. Then
  69. .. math::
  70. C[i,j] = <b_i-m_i, b_j-m_j> /
  71. \sqrt{<b_i-m_i, b_i-m_i>*<b_j-m_j,b_j-m_j>}
  72. where <..,.> is the scalar product of two vectors.
  73. For an input of n spike trains, a n x n matrix is returned.
  74. Each entry in the matrix is a real number ranging between -1 (perfectly
  75. anti-correlated spike trains) and +1 (perfectly correlated spike trains).
  76. If binary is True, the binned spike trains are clipped to 0 or 1 before
  77. computing the correlation coefficients, so that the binned vectors
  78. :math:`b_i` and :math:`b_j` are binary.
  79. Parameters
  80. ----------
  81. binned_sts : elephant.conversion.BinnedSpikeTrain
  82. A binned spike train containing the spike trains to be evaluated.
  83. binary : bool, optional
  84. If True, two spikes of a particular spike train falling in the same bin
  85. are counted as 1, resulting in binary binned vectors :math:`b_i`. If
  86. False, the binned vectors :math:`b_i` contain the spike counts per bin.
  87. Default: False
  88. Returns
  89. -------
  90. C : ndarrray
  91. The square matrix of correlation coefficients. The element
  92. :math:`C[i,j]=C[j,i]` is the Pearson's correlation coefficient between
  93. binned_sts[i] and binned_sts[j]. If binned_sts contains only one
  94. SpikeTrain, C=1.0.
  95. Examples
  96. --------
  97. Generate two Poisson spike trains
  98. >>> from elephant.spike_train_generation import homogeneous_poisson_process
  99. >>> st1 = homogeneous_poisson_process(
  100. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  101. >>> st2 = homogeneous_poisson_process(
  102. rate=10.0*Hz, t_start=0.0*s, t_stop=10.0*s)
  103. Calculate the correlation matrix.
  104. >>> from elephant.conversion import BinnedSpikeTrain
  105. >>> cc_matrix = corrcoef(BinnedSpikeTrain([st1, st2], binsize=5*ms))
  106. The correlation coefficient between the spike trains is stored in
  107. cc_matrix[0,1] (or cc_matrix[1,0]).
  108. Notes
  109. -----
  110. * The spike trains in the binned structure are assumed to all cover the
  111. complete time span of binned_sts [t_start,t_stop).
  112. '''
  113. return __calculate_correlation_or_covariance(
  114. binned_sts, binary, corrcoef_norm=True)
  115. def __calculate_correlation_or_covariance(binned_sts, binary, corrcoef_norm):
  116. '''
  117. Helper function for covariance() and corrcoef() that performs the complete
  118. calculation for either the covariance (corrcoef_norm=False) or correlation
  119. coefficient (corrcoef_norm=True). Both calculations differ only by the
  120. denominator.
  121. Parameters
  122. ----------
  123. binned_sts : elephant.conversion.BinnedSpikeTrain
  124. See covariance() or corrcoef(), respectively.
  125. binary : bool
  126. See covariance() or corrcoef(), respectively.
  127. corrcoef_norm : bool
  128. Use normalization factor for the correlation coefficient rather than
  129. for the covariance.
  130. '''
  131. num_neurons = binned_sts.matrix_rows
  132. # Pre-allocate correlation matrix
  133. C = np.zeros((num_neurons, num_neurons))
  134. # Retrieve unclipped matrix
  135. spmat = binned_sts.to_sparse_array()
  136. # For each row, extract the nonzero column indices and the corresponding
  137. # data in the matrix (for performance reasons)
  138. bin_idx_unique = []
  139. bin_counts_unique = []
  140. if binary:
  141. for s in spmat:
  142. bin_idx_unique.append(s.nonzero()[1])
  143. else:
  144. for s in spmat:
  145. bin_counts_unique.append(s.data)
  146. # All combinations of spike trains
  147. for i in range(num_neurons):
  148. for j in range(i, num_neurons):
  149. # Enumerator:
  150. # $$ <b_i-m_i, b_j-m_j>
  151. # = <b_i, b_j> + l*m_i*m_j - <b_i, M_j> - <b_j, M_i>
  152. # =: ij + l*m_i*m_j - n_i * m_j - n_j * m_i
  153. # = ij - n_i*n_j/l $$
  154. # where $n_i$ is the spike count of spike train $i$,
  155. # $l$ is the number of bins used (i.e., length of $b_i$ or $b_j$),
  156. # and $M_i$ is a vector [m_i, m_i,..., m_i].
  157. if binary:
  158. # Intersect indices to identify number of coincident spikes in
  159. # i and j (more efficient than directly using the dot product)
  160. ij = len(np.intersect1d(
  161. bin_idx_unique[i], bin_idx_unique[j], assume_unique=True))
  162. # Number of spikes in i and j
  163. n_i = len(bin_idx_unique[i])
  164. n_j = len(bin_idx_unique[j])
  165. else:
  166. # Calculate dot product b_i*b_j between unclipped matrices
  167. ij = spmat[i].dot(spmat[j].transpose()).toarray()[0][0]
  168. # Number of spikes in i and j
  169. n_i = np.sum(bin_counts_unique[i])
  170. n_j = np.sum(bin_counts_unique[j])
  171. enumerator = ij - n_i * n_j / binned_sts.num_bins
  172. # Denominator:
  173. if corrcoef_norm:
  174. # Correlation coefficient
  175. # Note:
  176. # $$ <b_i-m_i, b_i-m_i>
  177. # = <b_i, b_i> + m_i^2 - 2 <b_i, M_i>
  178. # =: ii + m_i^2 - 2 n_i * m_i
  179. # = ii - n_i^2 / $$
  180. if binary:
  181. # Here, b_i*b_i is just the number of filled bins (since
  182. # each filled bin of a clipped spike train has value equal
  183. # to 1)
  184. ii = len(bin_idx_unique[i])
  185. jj = len(bin_idx_unique[j])
  186. else:
  187. # directly calculate the dot product based on the counts of
  188. # all filled entries (more efficient than using the dot
  189. # product of the rows of the sparse matrix)
  190. ii = np.dot(bin_counts_unique[i], bin_counts_unique[i])
  191. jj = np.dot(bin_counts_unique[j], bin_counts_unique[j])
  192. denominator = np.sqrt(
  193. (ii - (n_i ** 2) / binned_sts.num_bins) *
  194. (jj - (n_j ** 2) / binned_sts.num_bins))
  195. else:
  196. # Covariance
  197. # $$ l-1 $$
  198. denominator = (binned_sts.num_bins - 1)
  199. # Fill entry of correlation matrix
  200. C[i, j] = C[j, i] = enumerator / denominator
  201. return np.squeeze(C)
  202. def cross_correlation_histogram(
  203. binned_st1, binned_st2, window='full', border_correction=False, binary=False,
  204. kernel=None, method='speed'):
  205. """
  206. Computes the cross-correlation histogram (CCH) between two binned spike
  207. trains binned_st1 and binned_st2.
  208. Parameters
  209. ----------
  210. binned_st1, binned_st2 : BinnedSpikeTrain
  211. Binned spike trains to cross-correlate. The two spike trains must have
  212. same t_start and t_stop
  213. window : string or list (optional)
  214. ‘full’: This returns the crosscorrelation at each point of overlap,
  215. with an output shape of (N+M-1,). At the end-points of the
  216. cross-correlogram, the signals do not overlap completely, and
  217. boundary effects may be seen.
  218. ‘valid’: Mode valid returns output of length max(M, N) - min(M, N) + 1.
  219. The cross-correlation product is only given for points where the
  220. signals overlap completely.
  221. Values outside the signal boundary have no effect.
  222. Default: 'full'
  223. list of integer of of quantities (window[0]=minimum, window[1]=maximum
  224. lag): The entries of window can be integer (number of bins) or
  225. quantities (time units of the lag), in the second case they have to be
  226. a multiple of the binsize
  227. Default: 'Full'
  228. border_correction : bool (optional)
  229. whether to correct for the border effect. If True, the value of the
  230. CCH at bin b (for b=-H,-H+1, ...,H, where H is the CCH half-length)
  231. is multiplied by the correction factor:
  232. (H+1)/(H+1-|b|),
  233. which linearly corrects for loss of bins at the edges.
  234. Default: False
  235. binary : bool (optional)
  236. whether to binary spikes from the same spike train falling in the
  237. same bin. If True, such spikes are considered as a single spike;
  238. otherwise they are considered as different spikes.
  239. Default: False.
  240. kernel : array or None (optional)
  241. A one dimensional array containing an optional smoothing kernel applied
  242. to the resulting CCH. The length N of the kernel indicates the
  243. smoothing window. The smoothing window cannot be larger than the
  244. maximum lag of the CCH. The kernel is normalized to unit area before
  245. being applied to the resulting CCH. Popular choices for the kernel are
  246. * normalized boxcar kernel: numpy.ones(N)
  247. * hamming: numpy.hamming(N)
  248. * hanning: numpy.hanning(N)
  249. * bartlett: numpy.bartlett(N)
  250. If None is specified, the CCH is not smoothed.
  251. Default: None
  252. method : string (optional)
  253. Defines the algorithm to use. "speed" uses numpy.correlate to calculate
  254. the correlation between two binned spike trains using a non-sparse data
  255. representation. Due to various optimizations, it is the fastest
  256. realization. In contrast, the option "memory" uses an own
  257. implementation to calculate the correlation based on sparse matrices,
  258. which is more memory efficient but slower than the "speed" option.
  259. Default: "speed"
  260. Returns
  261. -------
  262. cch : AnalogSignal
  263. Containing the cross-correlation histogram between binned_st1 and binned_st2.
  264. The central bin of the histogram represents correlation at zero
  265. delay. Offset bins correspond to correlations at a delay equivalent
  266. to the difference between the spike times of binned_st1 and those of binned_st2: an
  267. entry at positive lags corresponds to a spike in binned_st2 following a
  268. spike in binned_st1 bins to the right, and an entry at negative lags
  269. corresponds to a spike in binned_st1 following a spike in binned_st2.
  270. To illustrate this definition, consider the two spike trains:
  271. binned_st1: 0 0 0 0 1 0 0 0 0 0 0
  272. binned_st2: 0 0 0 0 0 0 0 1 0 0 0
  273. Here, the CCH will have an entry of 1 at lag h=+3.
  274. Consistent with the definition of AnalogSignals, the time axis
  275. represents the left bin borders of each histogram bin. For example,
  276. the time axis might be:
  277. np.array([-2.5 -1.5 -0.5 0.5 1.5]) * ms
  278. bin_ids : ndarray of int
  279. Contains the IDs of the individual histogram bins, where the central
  280. bin has ID 0, bins the left have negative IDs and bins to the right
  281. have positive IDs, e.g.,:
  282. np.array([-3, -2, -1, 0, 1, 2, 3])
  283. Example
  284. -------
  285. Plot the cross-correlation histogram between two Poisson spike trains
  286. >>> import elephant
  287. >>> import matplotlib.pyplot as plt
  288. >>> import quantities as pq
  289. >>> binned_st1 = elephant.conversion.BinnedSpikeTrain(
  290. elephant.spike_train_generation.homogeneous_poisson_process(
  291. 10. * pq.Hz, t_start=0 * pq.ms, t_stop=5000 * pq.ms),
  292. binsize=5. * pq.ms)
  293. >>> binned_st2 = elephant.conversion.BinnedSpikeTrain(
  294. elephant.spike_train_generation.homogeneous_poisson_process(
  295. 10. * pq.Hz, t_start=0 * pq.ms, t_stop=5000 * pq.ms),
  296. binsize=5. * pq.ms)
  297. >>> cc_hist = elephant.spike_train_correlation.cross_correlation_histogram(
  298. binned_st1, binned_st2, window=[-30,30],
  299. border_correction=False,
  300. binary=False, kernel=None, method='memory')
  301. >>> plt.bar(
  302. left=cc_hist[0].times.magnitude,
  303. height=cc_hist[0][:, 0].magnitude,
  304. width=cc_hist[0].sampling_period.magnitude)
  305. >>> plt.xlabel('time (' + str(cc_hist[0].times.units) + ')')
  306. >>> plt.ylabel('cross-correlation histogram')
  307. >>> plt.axis('tight')
  308. >>> plt.show()
  309. Alias
  310. -----
  311. cch
  312. """
  313. def _border_correction(counts, max_num_bins, l, r):
  314. # Correct the values taking into account lacking contributes
  315. # at the edges
  316. correction = float(max_num_bins + 1) / np.array(
  317. max_num_bins + 1 - abs(
  318. np.arange(l, r + 1)), float)
  319. return counts * correction
  320. def _kernel_smoothing(counts, kern, l, r):
  321. # Define the kern for smoothing as an ndarray
  322. if hasattr(kern, '__iter__'):
  323. if len(kern) > np.abs(l) + np.abs(r) + 1:
  324. raise ValueError(
  325. 'The length of the kernel cannot be larger than the '
  326. 'length %d of the resulting CCH.' % (
  327. np.abs(l) + np.abs(r) + 1))
  328. kern = np.array(kern, dtype=float)
  329. kern = 1. * kern / sum(kern)
  330. # Check kern parameter
  331. else:
  332. raise ValueError('Invalid smoothing kernel.')
  333. # Smooth the cross-correlation histogram with the kern
  334. return np.convolve(counts, kern, mode='same')
  335. def _cch_memory(binned_st1, binned_st2, win, border_corr, binary, kern):
  336. # Retrieve unclipped matrix
  337. st1_spmat = binned_st1.to_sparse_array()
  338. st2_spmat = binned_st2.to_sparse_array()
  339. binsize = binned_st1.binsize
  340. max_num_bins = max(binned_st1.num_bins, binned_st2.num_bins)
  341. # Set the time window in which is computed the cch
  342. if not isinstance(win, str):
  343. # Window parameter given in number of bins (integer)
  344. if isinstance(win[0], int) and isinstance(win[1], int):
  345. # Check the window parameter values
  346. if win[0] >= win[1] or win[0] <= -max_num_bins \
  347. or win[1] >= max_num_bins:
  348. raise ValueError(
  349. "The window exceeds the length of the spike trains")
  350. # Assign left and right edges of the cch
  351. l, r = win[0], win[1]
  352. # Window parameter given in time units
  353. else:
  354. # Check the window parameter values
  355. if win[0].rescale(binsize.units).magnitude % \
  356. binsize.magnitude != 0 or win[1].rescale(
  357. binsize.units).magnitude % binsize.magnitude != 0:
  358. raise ValueError(
  359. "The window has to be a multiple of the binsize")
  360. if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
  361. or win[1] >= max_num_bins * binsize:
  362. raise ValueError("The window exceeds the length of the"
  363. " spike trains")
  364. # Assign left and right edges of the cch
  365. l, r = int(win[0].rescale(binsize.units) / binsize), int(
  366. win[1].rescale(binsize.units) / binsize)
  367. # Case without explicit window parameter
  368. elif window == 'full':
  369. # cch computed for all the possible entries
  370. # Assign left and right edges of the cch
  371. r = binned_st2.num_bins - 1
  372. l = - binned_st1.num_bins + 1
  373. # cch compute only for the entries that completely overlap
  374. elif window == 'valid':
  375. # cch computed only for valid entries
  376. # Assign left and right edges of the cch
  377. r = max(binned_st2.num_bins - binned_st1.num_bins, 0)
  378. l = min(binned_st2.num_bins - binned_st1.num_bins, 0)
  379. # Check the mode parameter
  380. else:
  381. raise KeyError("Invalid window parameter")
  382. # For each row, extract the nonzero column indices
  383. # and the corresponding # data in the matrix (for performance reasons)
  384. st1_bin_idx_unique = st1_spmat.nonzero()[1]
  385. st2_bin_idx_unique = st2_spmat.nonzero()[1]
  386. # Case with binary entries
  387. if binary:
  388. st1_bin_counts_unique = np.array(st1_spmat.data > 0, dtype=int)
  389. st2_bin_counts_unique = np.array(st2_spmat.data > 0, dtype=int)
  390. # Case with all values
  391. else:
  392. st1_bin_counts_unique = st1_spmat.data
  393. st2_bin_counts_unique = st2_spmat.data
  394. # Initialize the counts to an array of zeroes,
  395. # and the bin IDs to integers
  396. # spanning the time axis
  397. counts = np.zeros(np.abs(l) + np.abs(r) + 1)
  398. bin_ids = np.arange(l, r + 1)
  399. # Compute the CCH at lags in l,...,r only
  400. for idx, i in enumerate(st1_bin_idx_unique):
  401. il = np.searchsorted(st2_bin_idx_unique, l + i)
  402. ir = np.searchsorted(st2_bin_idx_unique, r + i, side='right')
  403. timediff = st2_bin_idx_unique[il:ir] - i
  404. assert ((timediff >= l) & (timediff <= r)).all(), 'Not all the '
  405. 'entries of cch lie in the window'
  406. counts[timediff + np.abs(l)] += (st1_bin_counts_unique[idx] *
  407. st2_bin_counts_unique[il:ir])
  408. st2_bin_idx_unique = st2_bin_idx_unique[il:]
  409. st2_bin_counts_unique = st2_bin_counts_unique[il:]
  410. # Border correction
  411. if border_corr is True:
  412. counts = _border_correction(counts, max_num_bins, l, r)
  413. if kern is not None:
  414. # Smoothing
  415. counts = _kernel_smoothing(counts, kern, l, r)
  416. # Transform the array count into an AnalogSignal
  417. cch_result = neo.AnalogSignal(
  418. signal=counts.reshape(counts.size, 1),
  419. units=pq.dimensionless,
  420. t_start=(bin_ids[0] - 0.5) * binned_st1.binsize,
  421. sampling_period=binned_st1.binsize)
  422. # Return only the hist_bins bins and counts before and after the
  423. # central one
  424. return cch_result, bin_ids
  425. def _cch_speed(binned_st1, binned_st2, win, border_corr, binary, kern):
  426. # Retrieve the array of the binne spik train
  427. st1_arr = binned_st1.to_array()[0, :]
  428. st2_arr = binned_st2.to_array()[0, :]
  429. binsize = binned_st1.binsize
  430. # Convert the to binary version
  431. if binary:
  432. st1_arr = np.array(st1_arr > 0, dtype=int)
  433. st2_arr = np.array(st2_arr > 0, dtype=int)
  434. max_num_bins = max(len(st1_arr), len(st2_arr))
  435. # Cross correlate the spiketrains
  436. # Case explicit temporal window
  437. if not isinstance(win, str):
  438. # Window parameter given in number of bins (integer)
  439. if isinstance(win[0], int) and isinstance(win[1], int):
  440. # Check the window parameter values
  441. if win[0] >= win[1] or win[0] <= -max_num_bins \
  442. or win[1] >= max_num_bins:
  443. raise ValueError(
  444. "The window exceed the length of the spike trains")
  445. # Assign left and right edges of the cch
  446. l, r = win
  447. # Window parameter given in time units
  448. else:
  449. # Check the window parameter values
  450. if win[0].rescale(binsize.units).magnitude % \
  451. binsize.magnitude != 0 or win[1].rescale(
  452. binsize.units).magnitude % binsize.magnitude != 0:
  453. raise ValueError(
  454. "The window has to be a multiple of the binsize")
  455. if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
  456. or win[1] >= max_num_bins * binsize:
  457. raise ValueError("The window exceed the length of the"
  458. " spike trains")
  459. # Assign left and right edges of the cch
  460. l, r = int(win[0].rescale(binsize.units) / binsize), int(
  461. win[1].rescale(binsize.units) / binsize)
  462. # Zero padding
  463. st1_arr = np.pad(
  464. st1_arr, (int(np.abs(np.min([l, 0]))), np.max([r, 0])),
  465. mode='constant')
  466. cch_mode = 'valid'
  467. else:
  468. # Assign the edges of the cch for the different mode parameters
  469. if win == 'full':
  470. # Assign left and right edges of the cch
  471. r = binned_st2.num_bins - 1
  472. l = - binned_st1.num_bins + 1
  473. # cch compute only for the entries that completely overlap
  474. elif win == 'valid':
  475. # Assign left and right edges of the cch
  476. r = max(binned_st2.num_bins - binned_st1.num_bins, 0)
  477. l = min(binned_st2.num_bins - binned_st1.num_bins, 0)
  478. cch_mode = win
  479. # Cross correlate the spike trains
  480. counts = np.correlate(st2_arr, st1_arr, mode=cch_mode)
  481. bin_ids = np.r_[l:r + 1]
  482. # Border correction
  483. if border_corr is True:
  484. counts = _border_correction(counts, max_num_bins, l, r)
  485. if kern is not None:
  486. # Smoothing
  487. counts = _kernel_smoothing(counts, kern, l, r)
  488. # Transform the array count into an AnalogSignal
  489. cch_result = neo.AnalogSignal(
  490. signal=counts.reshape(counts.size, 1),
  491. units=pq.dimensionless,
  492. t_start=(bin_ids[0] - 0.5) * binned_st1.binsize,
  493. sampling_period=binned_st1.binsize)
  494. # Return only the hist_bins bins and counts before and after the
  495. # central one
  496. return cch_result, bin_ids
  497. # Check that the spike trains are binned with the same temporal
  498. # resolution
  499. if not binned_st1.matrix_rows == 1:
  500. raise AssertionError("Spike train must be one dimensional")
  501. if not binned_st2.matrix_rows == 1:
  502. raise AssertionError("Spike train must be one dimensional")
  503. if not binned_st1.binsize == binned_st2.binsize:
  504. raise AssertionError("Bin sizes must be equal")
  505. # Check t_start and t_stop identical (to drop once that the
  506. # pad functionality wil be available in the BinnedSpikeTrain classe)
  507. if not binned_st1.t_start == binned_st2.t_start:
  508. raise AssertionError("Spike train must have same t start")
  509. if not binned_st1.t_stop == binned_st2.t_stop:
  510. raise AssertionError("Spike train must have same t stop")
  511. if method == "memory":
  512. cch_result, bin_ids = _cch_memory(
  513. binned_st1, binned_st2, window, border_correction, binary,
  514. kernel)
  515. elif method == "speed":
  516. cch_result, bin_ids = _cch_speed(
  517. binned_st1, binned_st2, window, border_correction, binary,
  518. kernel)
  519. return cch_result, bin_ids
  520. # Alias for common abbreviation
  521. cch = cross_correlation_histogram