phase_analysis.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # -*- coding: utf-8 -*-
  2. """
  3. Methods for performing phase analysis.
  4. :copyright: Copyright 2014-2018 by the Elephant team, see AUTHORS.txt.
  5. :license: Modified BSD, see LICENSE.txt for details.
  6. """
  7. import numpy as np
  8. import quantities as pq
  9. def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
  10. """
  11. Calculate the set of spike-triggered phases of an AnalogSignal.
  12. Parameters
  13. ----------
  14. hilbert_transform : AnalogSignal or list of AnalogSignal
  15. AnalogSignal of the complex analytic signal (e.g., returned by the
  16. elephant.signal_processing.hilbert()). All spike trains are compared to
  17. this signal, if only one signal is given. Otherwise, length of
  18. hilbert_transform must match the length of spiketrains.
  19. spiketrains : Spiketrain or list of Spiketrain
  20. Spiketrains on which to trigger hilbert_transform extraction
  21. interpolate : bool
  22. If True, the phases and amplitudes of hilbert_transform for spikes
  23. falling between two samples of signal is interpolated. Otherwise, the
  24. closest sample of hilbert_transform is used.
  25. Returns
  26. -------
  27. phases : list of arrays
  28. Spike-triggered phases. Entries in the list correspond to the
  29. SpikeTrains in spiketrains. Each entry contains an array with the
  30. spike-triggered angles (in rad) of the signal.
  31. amp : list of arrays
  32. Corresponding spike-triggered amplitudes.
  33. times : list of arrays
  34. A list of times corresponding to the signal
  35. Corresponding times (corresponds to the spike times).
  36. Example
  37. -------
  38. Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson
  39. spike train:
  40. >>> f_osc = 20. * pq.Hz
  41. >>> f_sampling = 1 * pq.ms
  42. >>> tlen = 100 * pq.s
  43. >>> time_axis = np.arange(
  44. 0, tlen.magnitude,
  45. f_sampling.rescale(pq.s).magnitude) * pq.s
  46. >>> analogsignal = AnalogSignal(
  47. np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude),
  48. units=pq.mV, t_start=0 * pq.ms, sampling_period=f_sampling)
  49. >>> spiketrain = elephant.spike_train_generation.
  50. homogeneous_poisson_process(
  51. 50 * pq.Hz, t_start=0.0 * ms, t_stop=tlen.rescale(pq.ms))
  52. Calculate spike-triggered phases and amplitudes of the oscillation:
  53. >>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
  54. elephant.signal_processing.hilbert(analogsignal),
  55. spiketrain,
  56. interpolate=True)
  57. """
  58. # Convert inputs to lists
  59. if not isinstance(spiketrains, list):
  60. spiketrains = [spiketrains]
  61. if not isinstance(hilbert_transform, list):
  62. hilbert_transform = [hilbert_transform]
  63. # Number of signals
  64. num_spiketrains = len(spiketrains)
  65. num_phase = len(hilbert_transform)
  66. if num_spiketrains != 1 and num_phase != 1 and \
  67. num_spiketrains != num_phase:
  68. raise ValueError(
  69. "Number of spike trains and number of phase signals"
  70. "must match, or either of the two must be a single signal.")
  71. # For each trial, select the first input
  72. start = [elem.t_start for elem in hilbert_transform]
  73. stop = [elem.t_stop for elem in hilbert_transform]
  74. result_phases = []
  75. result_amps = []
  76. result_times = []
  77. # Step through each signal
  78. for spiketrain_i, spiketrain in enumerate(spiketrains):
  79. # Check which hilbert_transform AnalogSignal to look at - if there is
  80. # only one then all spike trains relate to this one, otherwise the two
  81. # lists of spike trains and phases are matched up
  82. if num_phase > 1:
  83. phase_i = spiketrain_i
  84. else:
  85. phase_i = 0
  86. # Take only spikes which lie directly within the signal segment -
  87. # ignore spikes sitting on the last sample
  88. sttimeind = np.where(np.logical_and(
  89. spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0]
  90. # Find index into signal for each spike
  91. ind_at_spike = np.round(
  92. (spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) /
  93. hilbert_transform[phase_i].sampling_period).magnitude.astype(int)
  94. # Extract times for speed reasons
  95. times = hilbert_transform[phase_i].times
  96. # Append new list to the results for this spiketrain
  97. result_phases.append([])
  98. result_amps.append([])
  99. result_times.append([])
  100. # Step through all spikes
  101. for spike_i, ind_at_spike_j in enumerate(ind_at_spike):
  102. # Difference vector between actual spike time and sample point,
  103. # positive if spike time is later than sample point
  104. dv = spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]
  105. # Make sure ind_at_spike is to the left of the spike time
  106. if dv < 0 and ind_at_spike_j > 0:
  107. ind_at_spike_j = ind_at_spike_j - 1
  108. if interpolate:
  109. # Get relative spike occurrence between the two closest signal
  110. # sample points
  111. # if z->0 spike is more to the left sample
  112. # if z->1 more to the right sample
  113. z = (spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]) /\
  114. hilbert_transform[phase_i].sampling_period
  115. # Save hilbert_transform (interpolate on circle)
  116. p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
  117. p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1])
  118. result_phases[spiketrain_i].append(
  119. np.angle(
  120. (1 - z) * np.exp(np.complex(0, p1)) +
  121. z * np.exp(np.complex(0, p2))))
  122. # Save amplitude
  123. result_amps[spiketrain_i].append(
  124. (1 - z) * np.abs(
  125. hilbert_transform[phase_i][ind_at_spike_j]) +
  126. z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1]))
  127. else:
  128. p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
  129. result_phases[spiketrain_i].append(p1)
  130. # Save amplitude
  131. result_amps[spiketrain_i].append(
  132. np.abs(hilbert_transform[phase_i][ind_at_spike_j]))
  133. # Save time
  134. result_times[spiketrain_i].append(spiketrain[sttimeind[spike_i]])
  135. # Convert outputs to arrays
  136. for i, entry in enumerate(result_phases):
  137. result_phases[i] = np.array(entry).flatten()
  138. for i, entry in enumerate(result_amps):
  139. result_amps[i] = pq.Quantity(entry, units=entry[0].units).flatten()
  140. for i, entry in enumerate(result_times):
  141. result_times[i] = pq.Quantity(entry, units=entry[0].units).flatten()
  142. return result_phases, result_amps, result_times