import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm
import emd
import scipy.signal
from scipy import fft
from scipy.signal import windows, detrend, find_peaks
from scipy.interpolate import BPoly
from sklearn.decomposition import PCA

from parameters import DATAPATH, NIMFCYCLES
from util import (load_data, zero_runs, merge_ranges, circhist, circmean_angle,
                  kl_divergence, switch_ranges, match_distributions)


def pad_signal(x, method='peaks', npts=None, edge_tolerance=3, npeaks=5):
    """
    Extend a signal at both ends with various methods. Return the padded
    signal and a binary array for recovering the original signal, i.e.
    orig_signal = padded_signal[inds == 1].

    method: 'even' - reflect signal at edges.
            'odd' - "rotate" signal 180deg at edges.
            'peaks' - extrapolate signal by reflecting peaks at edges and
                      fitting a Bernstein polynomial constrained by the 1st
                      derivative.
    npts: number of points added to each end of signal when using the
          'even' or 'odd' methods.
    edge_tolerance: number of points to extrapolate when checking if the
                    signal edges represent a peak.
    npeaks: number of peaks to use at each end of signal when using the
            'peaks' method.
    """
    if npts is None:
        npts = np.round(len(x) / 10).astype('int')
    if method == 'even':
        assert type(npts) in [int, np.int64], "npts must be an integer"
        from scipy.signal._arraytools import even_ext
        padded_signal = even_ext(x, npts)  # extend signal using 'even' method
        # create array indicating original signal
        inds = np.full(padded_signal.shape, False)
        inds[npts:-npts] = True
        return padded_signal, inds
    elif method == 'odd':
        assert type(npts) in [int, np.int64], "npts must be an integer"
        from scipy.signal._arraytools import odd_ext
        padded_signal = odd_ext(x, npts)  # extend signal using 'odd' method
        # create array indicating original signal
        inds = np.full(padded_signal.shape, False)
        inds[npts:-npts] = True
        return padded_signal, inds
    elif method == 'peaks':
        assert type(npeaks) == int, "npeaks must be an integer"
        pre = _pad_signal_mirror_peaks(x, edge_tolerance, npeaks)
        post = _pad_signal_mirror_peaks(np.flip(x), edge_tolerance, npeaks)
        padded_signal = np.concatenate((pre, x, np.flip(post)))
        # create array indicating original signal
        inds = np.full(padded_signal.shape, False)
        inds[len(pre):-len(post)] = True
        return padded_signal, inds
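
# Example usage (a minimal sketch; the sinusoidal test signal is illustrative,
# not part of the pipeline):
#   >>> t = np.linspace(0, 10, 1000)
#   >>> x = np.sin(2 * np.pi * t)
#   >>> padded, inds = pad_signal(x, method='peaks')
#   >>> np.array_equal(padded[inds], x)  # the original signal is recoverable
#   True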


def get_peaks_simple(x, edge_tolerance=2):
    """Return two arrays of indices indicating where peaks & troughs of a signal occur."""
    assert x.ndim == 1, "Signal must be 1D"
    peaks, troughs = np.array([]), np.array([])
    search_range = np.arange(len(x))[1:-1]  # search whole signal except endpoints
    for i in search_range:  # go forwards through signal
        if x[i-1] < x[i] > x[i+1]:  # point is a peak
            peaks = np.append(peaks, i)
        elif x[i-1] > x[i] < x[i+1]:  # point is a trough
            troughs = np.append(troughs, i)
    # estimate gradient change beyond signal edges for requested tolerance
    # by assuming a constant 2nd derivative (i.e. peaks are parabolic)
    gradient = np.gradient(x)  # first derivative
    gradient2 = np.gradient(gradient)  # second derivative
    # before signal
    grad_pre = gradient[0] - edge_tolerance * gradient2[0]
    if np.sign(grad_pre) + -np.sign(gradient[0]) == 2:  # first point is a peak
        peaks = np.concatenate(([0], peaks))
    elif np.sign(grad_pre) + -np.sign(gradient[0]) == -2:  # first point is a trough
        troughs = np.concatenate(([0], troughs))
    # after signal
    grad_post = gradient[-1] + edge_tolerance * gradient2[-1]
    if np.sign(gradient[-1]) + -np.sign(grad_post) == 2:  # last point is a peak
        peaks = np.concatenate((peaks, [len(x) - 1]))
    elif np.sign(gradient[-1]) + -np.sign(grad_post) == -2:  # last point is a trough
        troughs = np.concatenate((troughs, [len(x) - 1]))
    # return as ints for indexing
    return peaks.astype('int'), troughs.astype('int')
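
# Example (sketch on a three-cycle sinusoid; only interior extrema are found
# here, since the flat endpoints of this particular signal do not qualify as
# edge peaks):
#   >>> t = np.linspace(0, 3, 301)
#   >>> peaks, troughs = get_peaks_simple(np.sin(2 * np.pi * t))
#   >>> len(peaks), len(troughs)
#   (3, 3)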


def _pad_signal_mirror_peaks(x, edge_tolerance, npeaks):
    """
    Extend a signal from the beginning by mirroring npeaks and interpolating
    over these peak values with splines restricted by the gradients at the
    signal end points.

    Notes
    -----
    - To extend a signal from the end, simply pass np.flip(x) instead of x,
      then flip the returned array before adding it to the original signal.
    """
    peaks = np.concatenate(get_peaks_simple(x, edge_tolerance))  # peaks & troughs
    peaks.sort()
    assert len(peaks) >= 2, "<2 peaks present in signal"
    if np.sign(x[peaks[0]]) == np.sign(x[0]):
        peaks = peaks[1:npeaks].astype('int')  # convert to ints for indexing
    else:
        peaks = peaks[:npeaks].astype('int')
    #half_period = peaks[1] - peaks[0]  # half period of first oscillation
    #offset = half_period - peaks[0]  # "phase" of oscillation at beginning
    grad = np.gradient(x)[0]  # derivative at beginning of signal
    inds = np.concatenate(([0], peaks))
    # y-values of points to interpolate
    y = np.concatenate(([x[0] - grad], x[peaks]))
    # gradients at points to interpolate (0 except for 1st point)
    grads = np.concatenate(([-grad], np.full(peaks.shape, 0)))
    # get Bernstein polynomial splines for given values and derivatives
    splines = BPoly.from_derivatives(inds, np.vstack((y, grads)).T, orders=3)
    return np.flip(splines(np.arange(inds[-1])))  # interpolate & flip


def hilbert(x, fs=1, axis=1):
    """
    Perform Hilbert spectral analysis on a signal.

    Parameters
    ----------
    x : ndarray
        the signal
    fs : float (default = 1)
        sampling frequency
    axis : int
        axis of x along which to perform the analysis
    """
    y = scipy.signal.hilbert(x, axis=axis)  # complex analytic signal
    phase = np.angle(y)  # CCW angle from positive real axis
    # rate of change of phase (unwrap along the analysis axis)
    freq = np.gradient(np.unwrap(phase, axis=axis), axis=axis) / (2 * np.pi) * fs
    amp = np.abs(y)  # length of signal vector at each timepoint
    return phase, freq, amp
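
# Example (sketch; a pure 2 Hz tone sampled at 100 Hz):
#   >>> fs = 100
#   >>> t = np.arange(0, 5, 1 / fs)
#   >>> phase, freq, amp = hilbert(np.sin(2 * np.pi * 2 * t), fs=fs, axis=0)
#   >>> round(float(freq[50:-50].mean()), 1)  # instantaneous frequency away from edges
#   2.0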


def fft_psd(signal, fs):
    """
    Compute the power spectral density of a signal using the Fourier transform.
    """
    signal = signal - signal.mean()  # remove DC component
    fft_freq = np.fft.rfftfreq(len(signal), 1 / fs)[1:]
    signal_ft = np.fft.rfft(signal)[1:]
    #fft_power = np.abs(signal_ft) ** 2 * 2 / len(signal) ** 2
    fft_power = np.abs(signal_ft) / len(signal)
    return fft_freq, fft_power
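
# Example (sketch; the spectrum should peak at the tone frequency):
#   >>> fs = 100
#   >>> t = np.arange(0, 10, 1 / fs)
#   >>> f, p = fft_psd(np.sin(2 * np.pi * 5 * t), fs)
#   >>> f[np.argmax(p)]
#   5.0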


def hsa_psd(f, a, fs=1, f_bins=None):
    """
    Compute the marginal power spectrum of a set of instantaneous frequency and
    amplitude traces.
    """
    if f_bins is None:
        f_bins = np.fft.rfftfreq(len(f), 1 / fs)[1:]
    psd = np.zeros(len(f_bins))
    # assign a frequency bin to each value (clip values beyond the last bin)
    f_inds = np.digitize(f, f_bins).clip(0, len(f_bins) - 1)
    for x, i in zip(a.ravel(), f_inds.ravel()):  # loop over all values
        psd[i] += x  # accumulate amplitudes
    psd /= len(f)  # normalize by the number of timepoints
    return psd
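
# Example (sketch; the marginal spectrum of a Hilbert-analysed tone
# concentrates its amplitude around 5 Hz):
#   >>> fs = 100
#   >>> t = np.arange(0, 10, 1 / fs)
#   >>> ph, fr, am = hilbert(np.sin(2 * np.pi * 5 * t), fs=fs, axis=0)
#   >>> psd = hsa_psd(fr, am, fs=fs)
#   >>> len(psd)  # one value per positive rfft frequency bin
#   500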


def check_binned_visits(data, bins, n_visits):
    """
    Check that, for each signal in the set, each of the given bins is
    visited a certain number of times.

    Parameters
    ----------
    data : ndarray
        A set of signals (e.g. instantaneous phase traces), rows are signals,
        columns are time-points.
    bins : ndarray
        The start and stop values of a set of bins, edge-inclusive.
    n_visits : int
        The number of times that each bin should be visited.

    Returns
    -------
    sufficient_visits : ndarray
        Boolean array indicating if each signal in the set visited each of the
        bins the desired number of times.
    """
    # allow single channel input
    if data.ndim == 1:
        data = data[:, np.newaxis].T
    # initialize boolean array (all true)
    sufficient_visits = np.ones(len(data)).astype('bool')
    # loop over signals
    for ind, var in enumerate(data):
        # bin according to value
        binned_var = np.digitize(var, bins)
        # time ranges during which the signal is in each bin
        bin_tranges = [zero_runs(~np.equal(binned_var, b)) for b in np.arange(1, len(bins))]
        # minimum number of visits across bins
        min_tranges = min([len(tranges) for tranges in bin_tranges])
        # mark the signal as invalid if any bin was visited too few times
        if min_tranges < n_visits:
            sufficient_visits[ind] = False
    return sufficient_visits
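
# Example (sketch; a phase trace covering five full cycles visits each of the
# four phase quadrants five times):
#   >>> phase = np.angle(np.exp(1j * 2 * np.pi * np.linspace(0, 5, 5000)))
#   >>> bins = np.linspace(-np.pi, np.pi, 5)
#   >>> check_binned_visits(phase, bins, n_visits=4)
#   array([ True])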


def compute_fev(signal, components, add_mean=True):
    """
    Compute the fraction of variance explained in a target signal by a set of
    components.
    """
    recon = components.sum(axis=0)  # reconstructed signal
    if add_mean:
        recon += signal.mean()
    mse = ((signal - recon) ** 2).mean()  # mean squared error
    fev = 1 - (mse / signal.var())  # fraction of explained variance
    return fev
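
# Example (sketch; components that sum to the demeaned signal explain all of
# its variance):
#   >>> t = np.linspace(0, 1, 200)
#   >>> c1, c2 = np.sin(2 * np.pi * t), 0.5 * np.sin(2 * np.pi * 3 * t)
#   >>> round(compute_fev(c1 + c2 + 2.0, np.vstack((c1, c2))), 6)
#   1.0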


def mtcsd(x, fs=1, nperseg=None, nfft=None, noverlap=None, nw=3, ntapers=None,
          detrend_method='constant'):
    """
    Pair-wise cross-spectral density using Slepian tapers. Adapted from the
    mtcsd function in the labbox Matlab toolbox (authors: Partha Mitra,
    Ken Harris).

    Parameters
    ----------
    x : ndarray
        2D array of signals across which to compute CSD, columns treated as
        channels
    fs : float (default = 1)
        sampling frequency
    nperseg : int, None (default = None)
        number of data points per segment, if None nperseg is set to 256
    nfft : int, None (default = None)
        number of points to include in scipy.fft.fft, if None nfft is set to
        2 * nperseg, if nfft > nperseg data will be zero-padded
    noverlap : int, None (default = None)
        amount of overlap between consecutive segments, if None noverlap is
        set to nperseg / 2
    nw : int (default = 3)
        time-frequency bandwidth for Slepian tapers, passed on to
        scipy.signal.windows.dpss
    ntapers : int, None (default = None)
        number of tapers, passed on to scipy.signal.windows.dpss, if None
        ntapers is set to nw * 2 - 1 (as suggested by original authors)
    detrend_method : {'constant', 'linear'} (default = 'constant')
        method used by scipy.signal.detrend to detrend each segment

    Returns
    -------
    f : ndarray
        frequency bins
    csd : ndarray
        full cross-spectral density matrix
    """
    # allow single channel input
    if x.ndim == 1:
        x = x[:, np.newaxis]
    # ensure no more than 2D input
    assert x.ndim == 2
    # set defaults for parameter values
    if nperseg is None:
        nperseg = 256
    if nfft is None:
        nfft = nperseg * 2
    if noverlap is None:
        noverlap = nperseg / 2
    if ntapers is None:
        ntapers = 2 * nw - 1
    # get step size and total number of segments
    stepsize = nperseg - noverlap
    nsegs = int(np.floor(len(x) / stepsize))
    # initialize csd matrix
    csd = np.zeros((x.shape[1], x.shape[1], nfft), dtype='complex128')
    # get FFT frequency bins
    f = fft.fftfreq(nfft, 1 / fs)
    # get tapers
    tapers = windows.dpss(nperseg, nw, Kmax=ntapers)
    # loop over segments
    for seg_ind in range(nsegs):
        # prepare segment
        i0 = int(seg_ind * stepsize)
        i1 = int(seg_ind * stepsize + nperseg)
        if i1 > len(x):  # stop if segment is out of range of data
            nsegs = seg_ind  # reduce segment count
            break
        seg = x[i0:i1, :]
        seg = detrend(seg, type=detrend_method, axis=0)
        # apply tapers
        tapered_seg = np.full((len(tapers), seg.shape[0], seg.shape[1]), np.nan)
        for taper_ind, taper in enumerate(tapers):
            tapered_seg[taper_ind] = (seg.T * taper).T
        # compute FFT for each channel-taper combination
        fftnorm = np.sqrt(2)  # value taken from original matlab function
        pxx = fft.fft(tapered_seg, n=nfft, axis=1) / fftnorm
        # fill upper triangle of csd matrix
        for ch1 in range(x.shape[1]):  # loop over unique channel combinations
            for ch2 in range(ch1, x.shape[1]):
                # compute csd between channels, summing over tapers and segments
                csd[ch1, ch2, :] += (pxx[:, :, ch1] * np.conjugate(pxx[:, :, ch2])).sum(axis=0)
    # normalize csd by number of taper-segment combinations
    # (equivalent to averaging over segments and tapers)
    csdnorm = ntapers * nsegs
    csd /= csdnorm
    # fill lower triangle of csd matrix with complex conjugate of upper triangle
    for ch1 in range(x.shape[1]):
        for ch2 in range(ch1 + 1, x.shape[1]):
            csd[ch2, ch1, :] = np.conjugate(csd[ch1, ch2, :])
    return f, csd
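
# Example (sketch; the auto-spectrum of a 10 Hz tone should peak near 10 Hz):
#   >>> fs = 200
#   >>> t = np.arange(0, 20, 1 / fs)
#   >>> f, csd = mtcsd(np.sin(2 * np.pi * 10 * t), fs=fs, nperseg=512)
#   >>> pos = f > 0
#   >>> f[pos][np.argmax(np.abs(csd[0, 0, pos]))]  # close to 10 Hz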


def mtcoh(x, **kwargs):
    """
    Pair-wise multi-taper coherence for a set of signals.

    Parameters
    ----------
    See mtcsd documentation.

    Returns
    -------
    f : ndarray
        frequency bins
    coh : ndarray
        full magnitude-squared coherence matrix
    pdiff : ndarray
        pair-wise phase differences (angle of the cross-spectral density)
    """
    # Compute cross-spectral density
    f, csd = mtcsd(x, **kwargs)
    # Compute power normalization matrix (product of the auto-spectra)
    powernorm = np.zeros((x.shape[1], x.shape[1], len(f)))
    for ch1 in range(x.shape[1]):
        for ch2 in range(x.shape[1]):
            powernorm[ch1, ch2] = np.abs(csd[ch1, ch1]) * np.abs(csd[ch2, ch2])
    # Normalize CSD to get magnitude-squared coherence (values in [0, 1])
    coh = np.abs(csd) ** 2 / powernorm
    # Return frequency array, coherence, and phase differences
    return f, coh, np.angle(csd)
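
# Example (sketch; a signal and its noisy, slightly lagged copy should be
# strongly coherent at the tone frequency):
#   >>> fs = 100
#   >>> t = np.arange(0, 30, 1 / fs)
#   >>> rng = np.random.default_rng(0)
#   >>> x = np.sin(2 * np.pi * 4 * t)
#   >>> y = np.roll(x, 3) + 0.1 * rng.standard_normal(len(t))
#   >>> f, coh, pdiff = mtcoh(np.column_stack((x, y)), fs=fs, nperseg=256)
#   >>> coh[0, 1, np.abs(f - 4).argmin()]  # expected to be close to 1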


class HHT():

    def __init__(self, signal, fs):
        self.signal = signal
        self.fs = fs
        self.n_samples = len(self.signal)

    def emd(self):
        signal = self.signal - self.signal.mean()
        self.imfs = emd.sift.sift(signal)
        self.n_imfs = self.imfs.shape[1]

    def hsa(self):
        phases, frequencies, amplitudes = np.full((3, self.n_samples, self.n_imfs), np.nan)
        for i, imf in enumerate(self.imfs.T):
            # pad signal to reduce edge-effects for Hilbert analysis
            try:  # extrapolate by mirroring peaks
                imf_padded, orig_inds = pad_signal(imf, method='peaks')
            except AssertionError:  # signal has fewer than two peaks
                # extrapolate by "reflecting signal 180deg"
                imf_padded, orig_inds = pad_signal(imf, method='odd')
            # get analytic signal
            phase, frequency, amplitude = hilbert(imf_padded, fs=self.fs, axis=0)
            # keep only values corresponding to original signal
            phases[:, i] = phase[orig_inds]
            frequencies[:, i] = frequency[orig_inds]
            amplitudes[:, i] = amplitude[orig_inds]
        self.phase, self.frequency, self.amplitude = phases, frequencies, amplitudes
        # Amplitude-weighted mean frequency for each IMF
        self.characteristic_frequency = (self.frequency * self.amplitude).sum(axis=0) / self.amplitude.sum(axis=0)
        # Power of each IMF (time-averaged squared amplitude)
        self.power_density = (self.amplitude ** 2).sum(axis=0) / self.amplitude.shape[0]
        self.power_ratio = self.power_density / self.power_density.sum()

    def marginal_spectrum(self, f_bins=None, ranges=None):
        """
        Compute the marginal power spectrum of the IMF set.
        """
        if f_bins is None:
            f_bins = np.fft.rfftfreq(self.n_samples, 1 / self.fs)[1:]
        psd = np.zeros(len(f_bins))
        if ranges is None:
            frequency = self.frequency
            amplitude = self.amplitude
        else:
            frequency = np.concatenate([self.frequency[i0:i1] for i0, i1 in ranges])
            amplitude = np.concatenate([self.amplitude[i0:i1] for i0, i1 in ranges])
        # assign a frequency bin to each value (clip values beyond the last bin)
        binned_frequency = np.digitize(frequency, f_bins).clip(0, len(f_bins) - 1)
        for x, i in zip(amplitude.ravel(), binned_frequency.ravel()):  # loop over all values
            psd[i] += x  # accumulate amplitudes
        psd /= len(frequency)  # normalize by the number of timepoints
        return psd

    def check_number_of_phasebin_visits(self, phasebins=None, ncycles=4, remove_invalid=False):
        if phasebins is None:
            phasebins = np.linspace(-np.pi, np.pi, 5)
        self.sufficient_phasebin_visits = check_binned_visits(self.phase.T, phasebins, ncycles)
        if remove_invalid:
            for attr in ['imfs', 'phase', 'frequency', 'amplitude']:
                setattr(self, attr, getattr(self, attr)[:, self.sufficient_phasebin_visits])
            for attr in ['characteristic_frequency', 'power_density', 'power_ratio']:
                setattr(self, attr, getattr(self, attr)[self.sufficient_phasebin_visits])
            self.n_imfs = self.sufficient_phasebin_visits.sum()

    def check_imf_significance(self):
        print("WARNING: IMF significance is deprecated.")
        assert hasattr(self, 'imfs')
        # NOTE: imf_statsig is assumed to be provided elsewhere; it is not
        # defined in this module or its imports.
        ln_f, ln_E, bounds = imf_statsig(self.imfs.T, return_period=False, use_hilbert=True)
        self.imf_significance = ln_E > bounds[1]

    def get_synchronous_events(self, dt=0.5, n_cycles=0.25, threshold_qt=0.95):
        """
        Perform a sliding window correlation between pairs of IMFs with
        similar frequencies.

        Notes
        -----
        This measure is similar to a time-resolved version of the pseudo mode
        splitting index from Wang et al. (2018) and Fabus et al. (2021).
        """
        # keep only IMFs with a positive characteristic frequency
        keep = np.where(self.characteristic_frequency > 0)[0]
        imfs = self.imfs[:, keep]
        freqs = self.characteristic_frequency[keep]
        # pair each IMF with its neighbour (dropping the last, unpaired IMF)
        imfs1 = imfs[:, :-1]
        imfs2 = np.roll(imfs, -1, axis=1)[:, :-1]
        freqs2 = np.roll(freqs, -1)[:-1]
        step_size = np.round(dt * self.fs).astype(int)
        samples = np.arange(0, len(imfs1), step_size)
        sync = np.full((len(samples), imfs1.shape[1]), np.nan)
        for i, (imf1, imf2, freq) in enumerate(zip(imfs1.T, imfs2.T, freqs2)):
            window_size = np.round(n_cycles * self.fs / freq).astype(int)
            starts = np.clip(samples - window_size, a_min=0, a_max=None)
            stops = np.clip(samples + window_size, a_min=None, a_max=(len(imf1) - 1))
            sync[:, i] = [np.dot(imf1[start:stop], imf2[start:stop]) / (stop - start)
                          for start, stop in zip(starts, stops)]
        # threshold the synchrony averaged across IMF pairs and find
        # supra-threshold runs (zero_runs applied to the inverted mask)
        mean_sync = sync.mean(axis=1)
        threshold = np.quantile(mean_sync, threshold_qt)
        events = zero_runs(~(mean_sync > threshold))
        # map run edges (in units of steps) back to sample indices
        self.synchronous_events = samples[events.clip(0, len(samples) - 1)]

    def pairwise_coherence(self, ncycles=4):
        """
        Compute phase coherence between all pairs of IMFs.
        """
        coh_mat, pdiff_mat = np.full((2, self.n_imfs, self.n_imfs), np.nan)
        for imfi, imf in enumerate(self.imfs.T):
            # Get appropriate window size for this IMF's characteristic frequency
            period = 1 / self.characteristic_frequency[imfi]  # IMF period from characteristic frequency
            seglen = ncycles * period
            nperseg = int(2 ** np.floor(np.log2(seglen * self.fs)))  # number of samples
            # Skip if the data is too short for the required segment length
            if nperseg > self.n_samples:
                continue
            # Compute pair-wise cross-spectral density
            f, coh, pdiff = mtcoh(self.imfs, fs=self.fs, nperseg=nperseg)
            # Take only the row corresponding to the current IMF
            coh = coh[imfi]
            pdiff = pdiff[imfi]
            # Get index of the appropriate frequency bin (consider only +ve freqs)
            f_ind = f[f > 0].searchsorted(self.characteristic_frequency[imfi])
            # Take the mean of the two most appropriate frequency bins
            coh = coh[:, f_ind:(f_ind + 2)].mean(axis=1)
            pdiff = circmean_angle(pdiff[:, f_ind:(f_ind + 2)], axis=1)
            # Fill row of matrix
            coh_mat[imfi] = coh
            pdiff_mat[imfi] = pdiff
        # Normalize each row by its maximum to remove contributions of power
        self.coherence = (coh_mat.T / coh_mat.max(axis=1)).T
        self.phasediff = pdiff_mat

    def phase_synchrony(self, n_bins=16, n_shf=1000):
        phases = self.phase.T
        n_phases = len(phases)
        freqs = self.characteristic_frequency
        bin_edges = np.linspace(-np.pi, np.pi, n_bins + 1)
        # Get bin areas array to normalize density function
        D_areas = np.outer(np.diff(bin_edges), np.diff(bin_edges))
        # Get a reference uniform distribution
        D_uniform = np.ones((n_bins, n_bins)) / n_bins ** 2
        # Initialize array to collect the joint distributions
        DD = np.full((n_phases, n_phases, n_bins, n_bins), np.nan)
        # Initialize array to collect KLDs
        DD_kld = np.full((n_phases, n_phases), np.nan)
        # Initialize array to collect distribution p-values
        DD_p = np.full((n_phases, n_phases), np.nan)
        # Initialize array to collect the significance masks
        DD_masks = np.full((n_phases, n_phases, n_bins, n_bins), np.nan)
        # Initialize array to collect synchronous time ranges
        DD_ranges = np.full((n_phases, n_phases), np.nan, dtype='object')
        # Loop over pairs of phase traces
        for i in range(n_phases):
            for j in range(n_phases):
                # Skip self-comparisons (the diagonal is uninformative)
                if i == j:
                    continue
                # Make marginal distributions uniform by converting phases to ranks
                #ranks_i = phase2rank(phases[i]) - np.pi
                #ranks_j = phase2rank(phases[j]) - np.pi
                # Get the joint probability function
                D = np.histogram2d(phases[i], phases[j], bins=bin_edges, density=True)[0] * D_areas
                # Normalize by marginal distributions
                #D = (D.T / np.histogram(phases[i], bins=phase_bins)[0]).T  # normalize rows
                #D = D / np.histogram(phases[j], bins=phase_bins)[0]  # normalize columns
                DD[i, j] = D
                # Compute Kullback-Leibler divergence from uniform
                ## TODO: add eps to D to ensure no zero values? --> no negative KLDs
                kld = kl_divergence(D, D_uniform)
                # Initialize array to collect shuffle distributions
                DD_shf = np.full((n_shf, D.shape[0], D.shape[1]), np.nan)
                # Initialize array to collect shuffle KLDs
                kld_shf = np.full(n_shf, np.nan)
                # Perform shuffles
                for shf in range(n_shf):
                    # Randomly shuffle time points
                    #shf_i = np.random.choice(np.arange(len(phases[i])), size=len(phases[i]), replace=False)
                    #shf_j = np.random.choice(np.arange(len(phases[j])), size=len(phases[j]), replace=False)
                    # Shuffle cycle order
                    cycles_i = np.split(phases[i], np.where(np.diff(phases[i]) < -np.pi)[0])
                    cycles_j = np.split(phases[j], np.where(np.diff(phases[j]) < -np.pi)[0])
                    np.random.shuffle(cycles_i)
                    np.random.shuffle(cycles_j)
                    shf_i = np.concatenate(cycles_i)
                    shf_j = np.concatenate(cycles_j)
                    # Get PDF of shuffle
                    D_shf = np.histogram2d(shf_i, shf_j, bins=bin_edges, density=True)[0] * D_areas
                    DD_shf[shf] = D_shf
                    # Compute KLD of shuffle
                    kld_shf[shf] = kl_divergence(D_shf, D_uniform)
                # Get KLD as a z-score relative to the shuffle distribution
                DD_kld[i, j] = (kld - kld_shf.mean()) / kld_shf.std()
                # Get KLD p-value
                DD_p[i, j] = (kld_shf > kld).mean()
                # Get significance mask for joint distribution (bin-wise 95th
                # percentile across all shuffles)
                mask = D > np.percentile(DD_shf, 95, axis=0)
                DD_masks[i, j] = mask
                # Find time ranges where phase traces pass through significant regions
                ranges = []
                for pi, pj in np.column_stack(np.where(mask)):
                    mask_i = (phases[i] > bin_edges[pi]) & (phases[i] <= bin_edges[pi + 1])
                    mask_j = (phases[j] > bin_edges[pj]) & (phases[j] <= bin_edges[pj + 1])
                    ranges.append(zero_runs(~(mask_i & mask_j)))
                DD_ranges[i, j] = merge_ranges(np.concatenate(ranges))
        return DD, DD_kld, DD_p, DD_masks, DD_ranges

    def modemix(self, alpha=0.05):
        """
        Compute a metric for the amount of frequency overlap between all pairs
        of IMFs. For each pair, the metric represents the proportion of time
        during which the instantaneous frequencies of the two IMFs crossed.

        Parameters
        ----------
        alpha : float (default = 0.05)
            quantile of the instantaneous frequency distributions used to
            test for overlap

        Returns
        -------
        metric : ndarray
            pair-wise matrix of overlap proportions (upper triangle filled)

        Notes
        -----
        - Designed for use as a metric of frequency overlap (i.e. 'mode mixing')
          between a set of IMFs resulting from EMD, operating on the
          instantaneous frequency traces.

        References
        ----------
        [1] Laszuk, D., Cadenas, O., & Nasuto, S. J. (2015, July). Objective
        empirical mode decomposition metric. In 2015 38th International Conference
        on Telecommunications and Signal Processing (TSP) (pp. 504-507). IEEE.
        """
        metric = np.full((self.n_imfs, self.n_imfs), np.nan)  # pair-wise matrix
        for i in range(self.n_imfs):  # loop over unique pairs
            for j in range(i + 1, self.n_imfs):
                assert self.characteristic_frequency[i] > self.characteristic_frequency[j]
                overlap_i = (self.frequency[:, i] < np.quantile(self.frequency[:, j], 1 - alpha)).sum()
                overlap_j = (self.frequency[:, j] > np.quantile(self.frequency[:, i], alpha)).sum()
                # proportion of time for which there is overlap
                metric[i, j] = (overlap_i + overlap_j) / self.n_samples
        return metric

    def get_imf_colors(self, cmap):
        color_vals = 1 - np.linspace(0.1, 1, self.n_imfs)
        return cmap(color_vals)
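
# Example (sketch; a hypothetical two-component signal run through the full
# EMD + Hilbert pipeline):
#   >>> fs = 50
#   >>> t = np.arange(0, 60, 1 / fs)
#   >>> x = np.sin(2 * np.pi * 0.3 * t) + 0.5 * np.sin(2 * np.pi * 2 * t)
#   >>> hht = HHT(x, fs)
#   >>> hht.emd()  # decompose into IMFs
#   >>> hht.hsa()  # Hilbert spectral analysis of each IMF
#   >>> hht.characteristic_frequency  # leading IMFs near ~2 Hz and ~0.3 Hz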


def tranges_with_events(tranges, events):
    """Return a boolean array indicating which time ranges contain at least one event."""
    with_event = np.full(len(tranges), False)
    for i, (t0, t1) in enumerate(tranges):
        with_event[i] = any((evt >= t0) & (evt <= t1) for evt in events)
    return with_event
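
# Example (sketch):
#   >>> tranges = np.array([[0.0, 1.0], [2.0, 3.0]])
#   >>> tranges_with_events(tranges, events=[2.5])
#   array([False,  True])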


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    args = parser.parse_args()

    df_pupil = load_data('pupil', [args.e_name])
    df_run = load_data('ball', [args.e_name])
    df = pd.merge(df_pupil, df_run, on=['m', 's', 'e'])

    seriess = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        pupil_area = row['pupil_area']
        pupil_tpts = row['pupil_tpts']

        # Get IMFs
        fs = 1 / np.diff(pupil_tpts).mean()
        hht = HHT(pupil_area, fs)
        hht.emd()
        hht.hsa()

        f_bins, psd = fft_psd(pupil_area, hht.fs)
        hht_psd = hht.marginal_spectrum(f_bins=f_bins)
        frequencies = hht.characteristic_frequency.copy()
        powers = hht.power_ratio.copy()

        run_ranges = row['pupil_tpts'].searchsorted(row['run_bouts'])
        run_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=run_ranges)
        sit_ranges = row['pupil_tpts'].searchsorted(row['sit_bouts'])
        sit_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=sit_ranges)
        #tranges = np.array([[0, int(len(pupil_tpts) / 2)]])
        #half1_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=tranges)
        #tranges = np.array([[int(len(pupil_tpts) / 2), len(pupil_tpts)]])
        #half2_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=tranges)

        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
        pbi = np.full(hht.n_imfs, np.nan)
        cycle_tranges = []
        for i, phase in enumerate(hht.phase.T):
            # Get phase bias index
            counts, _ = circhist(phase)
            pbi[i] = (counts.max() - counts.min()) / counts.max()

        #phase_components = np.column_stack([np.cos(hht.phase), np.sin(hht.phase)])
        #pca = PCA()
        #pca.fit(phase_components)

        #hht.pairwise_coherence()
        jpd, sync_kld, sync_p, sync_masks, sync_ranges = hht.phase_synchrony()

        # Take ranges only for non-uniform distributions
        #ranges = merge_ranges(np.concatenate(sync_ranges[sync_p <= 0.05]))
        #synchronous_bouts = pupil_tpts[ranges]
        sync_boutss = []
        desync_boutss = []
        for i, (ranges, ps) in enumerate(zip(sync_ranges, sync_p)):
            bouts = ranges[ps <= 0.05]  # take only if overall distribution is significant
            #bouts = np.delete(ranges, i)  # take all (except self)
            if len(bouts) > 0:
                sync_bouts = merge_ranges(np.concatenate(bouts))
                desync_bouts = switch_ranges(sync_bouts, maxval=(hht.n_samples - 1))
                sync_bouts = pupil_tpts[sync_bouts]
                desync_bouts = pupil_tpts[desync_bouts]
                sync_boutss.append(sync_bouts)
                desync_boutss.append(desync_bouts)

        # Get pupil size matched time ranges for each IMF
        pupilarea_norm = pupil_area / pupil_area.max()
        phase_bins = np.linspace(-np.pi, np.pi, 5)
        phase_nbins = len(phase_bins) - 1
        size_bins = np.linspace(0, 1, 11)
        unmatched_means, matched_means = np.full((2, hht.n_imfs, phase_nbins, phase_nbins), np.nan)
        matched_ranges = []
        # loop over IMFs
        for imf_i, imf_phase in enumerate(hht.phase.T):
            phase_binned = np.digitize(imf_phase, phase_bins).clip(1, phase_nbins) - 1
            phasebin_ranges = [zero_runs(~np.equal(phase_binned, phase_bin)) for phase_bin in np.arange(phase_nbins)]
            unmatched_dists = [np.concatenate([pupilarea_norm[i0:i1] for i0, i1 in ranges]) for ranges in phasebin_ranges]
            matched_ranges_imf = match_distributions(imf_phase, pupilarea_norm, phase_bins, size_bins)
            matched_ranges.append(pupil_tpts[np.concatenate(matched_ranges_imf).clip(0, len(pupil_tpts) - 1).astype(int)])
            matched_dists = [np.concatenate([pupilarea_norm[i0:i1] for i0, i1 in ranges]) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges_imf]
            for i in np.arange(phase_nbins):
                for j in np.arange(phase_nbins):
                    if j > i:
                        unmatched_means[imf_i, i, j] = unmatched_dists[i].mean() - unmatched_dists[j].mean()
                        matched_means[imf_i, i, j] = matched_dists[i].mean() - matched_dists[j].mean()

        # Get, for each IMF and phase bin, the time ranges that contain no saccades
        nosaccade_ranges = []
        saccades = pupil_tpts.searchsorted(row['saccade_times'])
        for imf_i, imf_phase in enumerate(hht.phase.T):
            phase_binned = np.digitize(imf_phase, phase_bins).clip(1, phase_nbins) - 1
            phasebin_ranges = [zero_runs(~np.equal(phase_binned, phase_bin)) for phase_bin in np.arange(phase_nbins)]
            ranges2keep = [~tranges_with_events(ranges, saccades) for ranges in phasebin_ranges]
            phasebin_ranges = np.concatenate([pupil_tpts.searchsorted(ranges[keep]) for ranges, keep in zip(phasebin_ranges, ranges2keep)])
            nosaccade_ranges.append(phasebin_ranges)

        data = {
            'm': row['m'],
            's': row['s'],
            'e': row['e'],
            'condition': args.e_name,
            't0': pupil_tpts[0],
            't1': pupil_tpts[-1],
            'segment_length': pupil_tpts[-1] - pupil_tpts[0],
            'frequency': frequencies,
            'power': powers,
            'hht_psd': hht_psd,
            'fft_psd': psd,
            'psd_freq': f_bins,
            'run_psd': run_psd,
            'sit_psd': sit_psd,
            #'half1_psd': half1_psd,
            #'half2_psd': half2_psd,
            'pbi': pbi,
            #'pc_variance': pca.explained_variance_ratio_,
            #'coherence': hht.coherence,
            #'phasediff': hht.phasediff,
            'sync_p': sync_p,
            'sync_kld': sync_kld,
            'sync_bouts': sync_boutss,
            'desync_bouts': desync_boutss,
            'sizematched_bouts': matched_ranges,
            'unmatched_meansize': unmatched_means,
            'matched_meansize': matched_means,
            'nosaccade_bouts': nosaccade_ranges
        }
        seriess.append(pd.Series(data=data))

    df_hht = pd.DataFrame(seriess)
    df_hht.to_pickle(DATAPATH + 'hht_{}.pkl'.format(args.e_name))