123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- # -*- coding: utf-8 -*-
- """
- Methods for performing phase analysis.
- :copyright: Copyright 2014-2018 by the Elephant team, see AUTHORS.txt.
- :license: Modified BSD, see LICENSE.txt for details.
- """
- import numpy as np
- import quantities as pq
def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
    """
    Calculate the set of spike-triggered phases of an AnalogSignal.

    Parameters
    ----------
    hilbert_transform : AnalogSignal or list of AnalogSignal
        AnalogSignal of the complex analytic signal (e.g., returned by
        elephant.signal_processing.hilbert()). If only one signal is given,
        all spike trains are compared to this signal. Otherwise, the length
        of hilbert_transform must match the length of spiketrains and
        signals and spike trains are paired by position.
    spiketrains : SpikeTrain or list of SpikeTrain
        Spike trains on which to trigger hilbert_transform extraction. Only
        spikes with ``t_start <= t < t_stop`` of the corresponding signal
        are used.
    interpolate : bool
        If True, the phases and amplitudes of hilbert_transform for spikes
        falling between two samples of the signal are linearly interpolated
        between those two samples (phases are interpolated on the unit
        circle). Otherwise, the sample at or immediately preceding the spike
        time is used. Spikes falling after the last sample of the signal have
        no right-hand neighbour and always use the last sample.

    Returns
    -------
    phases : list of np.ndarray
        Spike-triggered phases. Entries in the list correspond to the
        SpikeTrains in spiketrains. Each entry contains an array with the
        spike-triggered angles (in rad) of the signal.
    amp : list of pq.Quantity
        Corresponding spike-triggered amplitudes, in the units of the
        corresponding hilbert_transform signal.
    times : list of pq.Quantity
        Corresponding times (the spike times), in the units of the
        corresponding spike train.

    Raises
    ------
    ValueError
        If both more than one spike train and more than one signal are given
        and their numbers differ.

    Notes
    -----
    If a single spike train is given together with several signals, only the
    first signal is used.

    Example
    -------
    Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson
    spike train:

    >>> f_osc = 20. * pq.Hz
    >>> sampling_period = 1 * pq.ms
    >>> tlen = 100 * pq.s
    >>> time_axis = np.arange(
    ...     0, tlen.magnitude,
    ...     sampling_period.rescale(pq.s).magnitude) * pq.s
    >>> analogsignal = AnalogSignal(
    ...     np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude),
    ...     units=pq.mV, t_start=0 * pq.ms, sampling_period=sampling_period)
    >>> spiketrain = (elephant.spike_train_generation.
    ...     homogeneous_poisson_process(
    ...         50 * pq.Hz, t_start=0.0 * pq.ms, t_stop=tlen.rescale(pq.ms)))

    Calculate spike-triggered phases and amplitudes of the oscillation:

    >>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
    ...     elephant.signal_processing.hilbert(analogsignal),
    ...     spiketrain,
    ...     interpolate=True)
    """
    # Accept single objects as well as lists/tuples of them
    if not isinstance(spiketrains, (list, tuple)):
        spiketrains = [spiketrains]
    if not isinstance(hilbert_transform, (list, tuple)):
        hilbert_transform = [hilbert_transform]

    num_spiketrains = len(spiketrains)
    num_phase = len(hilbert_transform)

    if num_spiketrains != 1 and num_phase != 1 and \
            num_spiketrains != num_phase:
        raise ValueError(
            "Number of spike trains and number of phase signals "
            "must match, or either of the two must be a single signal.")

    result_phases = []
    result_amps = []
    result_times = []

    for spiketrain_i, spiketrain in enumerate(spiketrains):
        # With a single signal every spike train relates to it; otherwise
        # the lists of spike trains and signals are matched up by position.
        phase_i = spiketrain_i if num_phase > 1 else 0
        signal = hilbert_transform[phase_i]
        sampling_period = signal.sampling_period
        # Extract times once for speed reasons
        times = signal.times
        num_samples = len(times)

        # Take only spikes which lie within the signal segment
        # [t_start, t_stop)
        sttimeind = np.where(np.logical_and(
            spiketrain >= signal.t_start, spiketrain < signal.t_stop))[0]
        spike_times = spiketrain[sttimeind]

        # Nearest sample index for each spike. The ratio is simplified to a
        # dimensionless number before taking the magnitude, so spike trains
        # and signals expressed in different time units (e.g. s vs. ms) give
        # correct indices. Rounding may point one past the last sample for a
        # spike in the final half sampling period, hence the clip.
        ind_at_spike = np.round(
            ((spike_times - signal.t_start) /
             sampling_period).simplified.magnitude).astype(int)
        ind_at_spike = np.minimum(ind_at_spike, num_samples - 1)

        phases = []
        amps = []
        for spike_i, ind in zip(sttimeind, ind_at_spike):
            spike_time = spiketrain[spike_i]
            # Make sure the index points to the sample at or to the left of
            # the spike time (the difference is negative if the spike lies
            # before the chosen sample)
            if spike_time - times[ind] < 0 and ind > 0:
                ind = ind - 1

            if interpolate and ind + 1 < num_samples:
                # Relative spike position between the two closest samples:
                # z -> 0 close to the left sample, z -> 1 close to the right
                z = float(((spike_time - times[ind]) /
                           sampling_period).simplified.magnitude)
                left = signal[ind]
                right = signal[ind + 1]
                p1 = np.asarray(np.angle(left), dtype=float)
                p2 = np.asarray(np.angle(right), dtype=float)
                # Interpolate on the unit circle so that phase wrapping at
                # +-pi is handled correctly
                phases.append(np.angle(
                    (1 - z) * np.exp(1j * p1) + z * np.exp(1j * p2)))
                amps.append((1 - z) * np.abs(left) + z * np.abs(right))
            else:
                sample = signal[ind]
                phases.append(np.angle(sample))
                amps.append(np.abs(sample))

        # Convert outputs to arrays; empty results still carry proper units
        result_phases.append(np.array(phases, dtype=float).flatten())
        if amps:
            result_amps.append(
                pq.Quantity(amps, units=amps[0].units).flatten())
        else:
            result_amps.append(pq.Quantity([], units=signal.units))
        result_times.append(
            pq.Quantity(np.asarray(spike_times),
                        units=spiketrain.units).flatten())

    return result_phases, result_amps, result_times
|