hht.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738
  1. import argparse
  2. import numpy as np
  3. import pandas as pd
  4. from tqdm import tqdm
  5. import emd
  6. import scipy.signal
  7. from scipy import fft
  8. from scipy.signal import windows, detrend, find_peaks
  9. from scipy.interpolate import BPoly
  10. from sklearn.decomposition import PCA
  11. from parameters import DATAPATH, NIMFCYCLES
  12. from util import (load_data, zero_runs, merge_ranges, circhist, circmean_angle,
  13. kl_divergence, switch_ranges, match_distributions)
  14. def pad_signal(x, method='peaks', npts=None, edge_tolerance=3, npeaks=5):
  15. """
  16. Extend a signal at both ends with various methods. Return the padded
  17. signal and a binary array for recovering the orignal signal i.e.
  18. orig_signal = padded_signal[inds == 1].
  19. method: 'even' - reflect signal at edges.
  20. 'odd' - "rotate" signal 180deg at edges.
  21. 'peaks' - extrapolate signal by reflecting peaks at edges and
  22. fitting a Bernstein polynomail constrained by 1st derivative.
  23. npts: number of points added to each end of signal when using the
  24. 'even' or 'odd' methods.
  25. edge_tolerance: number of points to extrapolate when checking if the
  26. signal edges represent a peak.
  27. npeaks: number of peaks to use at each end of signal when using the
  28. 'peaks' method.
  29. """
  30. if npts is None:
  31. npts = np.round(len(x) / 10).astype('int')
  32. if method == 'even':
  33. assert type(npts) in [int, np.int64], "npts must be an integer"
  34. from scipy.signal._arraytools import even_ext
  35. padded_signal = even_ext(x, npts) # extend signal using 'even' method
  36. # create array indicating original signal
  37. inds = np.full(padded_signal.shape, False)
  38. inds[npts:-npts] = True
  39. return padded_signal, inds
  40. elif method == 'odd':
  41. assert type(npts) in [int, np.int64], "npts must be an integer"
  42. from scipy.signal._arraytools import odd_ext
  43. padded_signal = odd_ext(x, npts) # extend signal using 'odd' method
  44. # create array indicating original signal
  45. inds = np.full(padded_signal.shape, False)
  46. inds[npts:-npts] = True
  47. return padded_signal, inds
  48. elif method == 'peaks':
  49. assert type(npeaks) == int, "npeaks must be an integer"
  50. pre = _pad_signal_mirror_peaks(x, edge_tolerance, npeaks)
  51. post = _pad_signal_mirror_peaks(np.flip(x), edge_tolerance, npeaks)
  52. padded_signal = np.concatenate((pre, x, np.flip(post)))
  53. # create array indicating original signal
  54. inds = np.full(padded_signal.shape, False)
  55. inds[len(pre):-len(post)] = True
  56. return padded_signal, inds
  57. def get_peaks_simple(x, edge_tolerance=2):
  58. """Return two arrays of indices indicating peaks & troughs of a signal occur."""
  59. assert x.ndim == 1, "Signal must be 1D"
  60. peaks, troughs = np.array([]), np.array([])
  61. search_range = np.arange(len(x))[1:-1] # search whole signal except endpoints
  62. for i in search_range: # go forwards through signal
  63. if (x[i-1] < x[i] > x[i+1]): # point is a peak
  64. peaks = np.append(peaks, i)
  65. elif (x[i-1] > x[i] < x[i+1]): # point is a trough
  66. troughs = np.append(troughs, i)
  67. # estimate gradient change beyond signal edges for requested tolerance
  68. # by assuming a constant 2nd derivative (i.e. peaks are parabolic)
  69. gradient = np.gradient(x) # first derivative
  70. gradient2 = np.gradient(gradient) # second derivative
  71. # before signal
  72. grad_pre = gradient[0] - edge_tolerance*gradient2[0]
  73. if np.sign(grad_pre) + -np.sign(gradient[0]) == 2: # first point is peak
  74. peaks = np.concatenate(([0], peaks))
  75. elif np.sign(grad_pre) + -np.sign(gradient[0]) == -2: # first point is trough
  76. troughs = np.concatenate(([0], troughs))
  77. # after signal
  78. grad_post = gradient[-1] + edge_tolerance*gradient2[-1]
  79. if np.sign(gradient[-1]) + -np.sign(grad_post) == 2: # last point is peak
  80. peaks = np.concatenate((peaks, [len(x) - 1]))
  81. elif np.sign(gradient[-1]) + -np.sign(grad_post) == -2: # last point is trough
  82. troughs = np.concatenate((troughs, [len(x) - 1]))
  83. # return as ints for indexing
  84. return peaks.astype('int'), troughs.astype('int')
  85. def _pad_signal_mirror_peaks(x, edge_tolerance, npeaks):
  86. """
  87. Extend a signal from the beginning by mirroring npeaks and interpolating
  88. over these peak values with splines restricted by the gradients at the
  89. signal end points.
  90. Notes
  91. -----
  92. - To extend a signal from the end, simply pass np.flip(x) instead of x,
  93. then flip the returned array before adding it to the original signal.
  94. """
  95. peaks = np.concatenate(get_peaks_simple(x, edge_tolerance)) # peaks & troughs
  96. peaks.sort()
  97. assert len(peaks) >= 2, "<2 peaks present in signal"
  98. if np.sign(x[peaks[0]]) == np.sign(x[0]):
  99. peaks = peaks[1:npeaks].astype('int') # convert to ints for indexing
  100. else:
  101. peaks = peaks[:npeaks].astype('int')
  102. #half_period = peaks[1] - peaks[0] # half period of first oscillation
  103. #offset = half_period - peaks[0] # "phase" of oscillation at beginning
  104. grad = np.gradient(x)[0] # derivative at beginning of signal
  105. inds = np.concatenate(([0], peaks))
  106. # y-values of points to interpolate
  107. y = np.concatenate(([x[0] - grad], x[peaks]))
  108. # gradients at points to interpolate (0 except for 1st point)
  109. grads = np.concatenate(([-grad], np.full(peaks.shape, 0)))
  110. # get Bernstein polynomial splines for given values and derivatives
  111. splines = BPoly.from_derivatives(inds, np.vstack((y, grads)).T, orders=3)
  112. return np.flip(splines(np.arange(inds[-1]))) # interpolate & flip
  113. def hilbert(x, fs=1, axis=1):
  114. """
  115. Perform Hilbert spectral analysis on a signal.
  116. Parameters
  117. ----------
  118. x : ndarray
  119. the signal
  120. fs : float (default = 1)
  121. sampling frequency
  122. axis : int
  123. axis of x along which to perform the analysis
  124. """
  125. y = scipy.signal.hilbert(x, axis=axis) # complex analytic signal
  126. phase = np.angle(y) # CCW angle from positive real axis
  127. # rate of change of phase
  128. freq = np.gradient(np.unwrap(phase), axis=axis) / (2*np.pi) * fs
  129. amp = np.abs(y) # length of signal vector at each timepoint
  130. return phase, freq, amp
  131. def fft_psd(signal, fs):
  132. """
  133. Compute the power spectral density of a signal using the Fourier transform.
  134. """
  135. signal = signal - signal.mean()
  136. fft_freq = np.fft.rfftfreq(len(signal), 1 / fs)[1:]
  137. signal_ft = np.fft.rfft(signal - signal.mean())[1:]
  138. #fft_power = np.abs(signal_ft) ** 2 * 2 / len(signal) ** 2
  139. fft_power = np.abs(signal_ft) / len(signal)
  140. return fft_freq, fft_power
  141. def hsa_psd(f, a, f_bins=None):
  142. """
  143. Compute the marginal power spectrum of a set of instantaneous frequency and
  144. amplitude traces.
  145. """
  146. if f_bins is None:
  147. f_bins = np.fft.rfftfreq(len(f), 1 / fs)[1:]
  148. psd = np.zeros(len(f_bins))
  149. f_inds = np.digitize(f, f_bins) # assign a frequency bin to each value
  150. for x, i in zip(a.ravel(), f_inds.ravel()): # loop over all values
  151. psd[i] += x # accumulate squared amplitudes
  152. psd /= len(f) # normalize by the number of timepoints
  153. return psd
  154. def check_binned_visits(data, bins, n_visits):
  155. """
  156. ## TODO: update docstring
  157. Check that, for each signal in the set, each of the given phase bins is
  158. visited a certain number of times.
  159. Parameters
  160. ----------
  161. signal_phases : ndarray
  162. A set of instantaneous phase traces, rows are signals, columns are
  163. time-points.
  164. phase_bins : ndarray
  165. The start and stop values of a set of phase bins, edge-inclusive.
  166. n_visits : int
  167. The number of times that each phase bin should be visited.
  168. Returns
  169. -------
  170. sufficient_visits : ndarray
  171. Boolean array indicating if each signal in the set visited each of the
  172. phase bins the desired number of times.
  173. """
  174. # allow single channel input
  175. if data.ndim == 1:
  176. data = data[:, np.newaxis].T
  177. # initialize boolean array (all true)
  178. sufficient_visits = np.ones(len(data)).astype('bool')
  179. # loop over instantaneous phase traces
  180. for ind, var in enumerate(data):
  181. # bin according to phase
  182. binned_var = np.digitize(var, bins)
  183. # time ranges during which IMF is in each phase bin
  184. bin_tranges = [zero_runs(~np.equal(binned_var, b)) for b in np.arange(1, len(bins))]
  185. # minimum number of visits across phase bins
  186. min_tranges = min([len(tranges) for tranges in bin_tranges])
  187. # change 1 to 0 if number of times for each phase bin is insufficient
  188. if min_tranges < n_visits:
  189. sufficient_visits[ind] = False
  190. return sufficient_visits
  191. def compute_fev(signal, components, add_mean=True):
  192. """
  193. Compute the fraction of variance explained in a target signal by a set of
  194. components.
  195. """
  196. recon = components.sum(axis=0) # reconstructed signal
  197. if add_mean:
  198. recon += signal.mean()
  199. mse = ((signal - recon) ** 2).mean() # mean squared error
  200. fev = 1 - (mse / signal.var()) # fraction explained variance
  201. return fev
  202. def mtcsd(x, fs=1, nperseg=None, nfft=None, noverlap=None, nw=3, ntapers=None,
  203. detrend_method='constant'):
  204. """
  205. Pair-wise cross-spectral density using Slepian tapers. Adapted from the
  206. mtcsd function in the labbox Matlab toolbox (authors: Partha Mitra,
  207. Ken Harris).
  208. Parameters
  209. ----------
  210. x : ndarray
  211. 2D array of signals across which to compute CSD, columns treated as
  212. channels
  213. fs : float (default = 1)
  214. sampling frequency
  215. nperseg : int, None (default = None)
  216. number of data points per segment, if None nperseg is set to 256
  217. nfft : int, None (default = None)
  218. number of points to include in scipy.fft.fft, if None nfft is set to
  219. 2 * nperseg, if nfft > nperseg data will be zero-padded
  220. noverlap : int, None (default = None)
  221. amout of overlap between consecutive segments, if None noverlap is set
  222. to nperseg / 2
  223. nw : int (default = 3)
  224. time-frequency bandwidth for Slepian tapers, passed on to
  225. scipy.signal.windows.dpss
  226. ntapers : int, None (default = None)
  227. number of tapers, passed on to scipy.signal.windows.dpss, if None
  228. ntapers is set to nw * 2 - 1 (as suggested by original authors)
  229. detrend_method : {'constant', 'linear'} (default = 'constant')
  230. method used by scipy.signal.detrend to detrend each segment
  231. Returns
  232. -------
  233. f : ndarray
  234. frequency bins
  235. csd : ndarray
  236. full cross-spectral density matrix
  237. """
  238. # allow single channel input
  239. if x.ndim == 1:
  240. x = x[:, np.newaxis]
  241. # ensure no more than 2D input
  242. assert x.ndim == 2
  243. # set some default for parameters values
  244. if nperseg is None:
  245. nperseg = 256
  246. if nfft is None:
  247. nfft = nperseg * 2
  248. if noverlap is None:
  249. noverlap = nperseg / 2
  250. if ntapers is None:
  251. ntapers = 2 * nw - 1
  252. # get step size and total number of segments
  253. stepsize = nperseg - noverlap
  254. nsegs = int(np.floor(len(x) / stepsize))
  255. # initialize csd matrix
  256. csd = np.zeros((x.shape[1], x.shape[1], nfft), dtype='complex128')
  257. # get FFT frequency bins
  258. f = fft.fftfreq(nfft, 1/fs)
  259. # get tapers
  260. tapers = windows.dpss(nperseg, nw, Kmax=ntapers)
  261. # loop over segments
  262. for seg_ind in range(nsegs):
  263. # prepare segment
  264. i0 = int(seg_ind * stepsize)
  265. i1 = int(seg_ind * stepsize + nperseg)
  266. if i1 > len(x): # stop if segment is out of range of data
  267. nsegs -= (nsegs - seg_ind) # reduce segment count
  268. break
  269. seg = x[i0:i1, :]
  270. seg = detrend(seg, type=detrend_method, axis=0)
  271. # apply tapers
  272. tapered_seg = np.full((len(tapers), seg.shape[0], seg.shape[1]), np.nan)
  273. for taper_ind, taper in enumerate(tapers):
  274. tapered_seg[taper_ind] = (seg.T * taper).T
  275. # compute FFT for each channel-taper combination
  276. fftnorm = np.sqrt(2) # value taken from original matlab function
  277. pxx = fft.fft(tapered_seg, n=nfft, axis=1) / fftnorm
  278. # fill upper triangle of csd matrix
  279. for ch1 in range(x.shape[1]): # loop over unique channel combinations
  280. for ch2 in range(ch1, x.shape[1]):
  281. # compute csd bewteen channels, summing over tapers and segments
  282. csd[ch1, ch2, :] += (pxx[:, :, ch1] * np.conjugate(pxx[:, :, ch2])).sum(axis=0)
  283. # normalize csd by number of taper-segment combinations
  284. # (equivalent to averaging over segments and tapers)
  285. csdnorm = ntapers * nsegs
  286. csd /= csdnorm
  287. # fill lower triangle of csd matrix with complex conjugate of upper triangle
  288. for ch1 in range(x.shape[1]):
  289. for ch2 in range(ch1 + 1, x.shape[1]):
  290. csd[ch2, ch1, :] = np.conjugate(csd[ch1, ch2, :])
  291. return f, csd
  292. def mtcoh(x, **kwargs):
  293. """
  294. Pair-wise multi-taper coherence for a set of signals.
  295. Parameters
  296. ----------
  297. See mtcsd documentation.
  298. Returns
  299. -------
  300. f : ndarray
  301. frequency bins
  302. coh : ndarray
  303. full spectral coherence matrix
  304. """
  305. # Compute cross-spectral density
  306. f, csd = mtcsd(x, **kwargs)
  307. # Compute power normalization matrix
  308. powernorm = np.zeros((x.shape[1], x.shape[1], len(f)))
  309. for ch1 in range(x.shape[1]):
  310. for ch2 in range(x.shape[1]):
  311. powernorm[ch1, ch2] = np.sqrt(np.abs(csd[ch1, ch1]) * np.abs(csd[ch2, ch2]))
  312. # Normalize CSD to get coherence
  313. coh = np.abs(csd) ** 2 / powernorm
  314. # Return frequency array, coherence, and phase differences
  315. return f, coh, np.angle(csd)
  316. class HHT():
  317. def __init__(self, signal, fs):
  318. self.signal = signal
  319. self.fs = fs
  320. self.n_samples = len(self.signal)
  321. def emd(self):
  322. signal = self.signal - self.signal.mean()
  323. self.imfs = emd.sift.sift(signal)
  324. self.n_imfs = self.imfs.shape[1]
  325. def hsa(self):
  326. phases, frequencies, amplitudes = np.full((3, self.n_samples, self.n_imfs), np.nan)
  327. for i, imf in enumerate(self.imfs.T):
  328. # pad signal to reduce edge-effects for Hilbert analysis
  329. try: # extrapolate by mirroring peaks
  330. imf_padded, orig_inds = pad_signal(imf, method='peaks')
  331. except AssertionError: # signal has fewer than two peaks
  332. # extrapolate by "reflecting signal 180deg"
  333. imf_padded, orig_inds = pad_signal(imf, method='odd')
  334. # get analytic signal
  335. phase, frequency, amplitude = hilbert(imf_padded, fs=self.fs, axis=0)
  336. # keep only values corresponding to original signal
  337. phases[:, i] = phase[orig_inds]
  338. frequencies[:, i] = frequency[orig_inds]
  339. amplitudes[:, i] = amplitude[orig_inds]
  340. self.phase, self.frequency, self.amplitude = phases, frequencies, amplitudes
  341. # Amplitude-weighted mean frequency for each IMF
  342. self.characteristic_frequency = (self.frequency * self.amplitude).sum(axis=0) / self.amplitude.sum(axis=0)
  343. # Power of each IMF
  344. self.power_density = (self.amplitude ** 2).sum(axis=0) / self.amplitude.shape[1]
  345. self.power_ratio = self.power_density / self.power_density.sum()
  346. def marginal_spectrum(self, f_bins=None, ranges=None):
  347. """
  348. Compute the marginal power spectrum of the IMF set.
  349. """
  350. if f_bins is None:
  351. f_bins = np.fft.rfftfreq(self.n_samples, 1 / self.fs)[1:]
  352. psd = np.zeros(len(f_bins))
  353. if ranges is None:
  354. frequency = self.frequency
  355. amplitude = self.amplitude
  356. else:
  357. frequency = np.concatenate([self.frequency[i0:i1] for i0, i1 in ranges])
  358. amplitude = np.concatenate([self.amplitude[i0:i1] for i0, i1 in ranges])
  359. binned_frequency = np.digitize(frequency, f_bins) # assign a frequency bin to each value
  360. for x, i in zip(amplitude.ravel(), binned_frequency.ravel()): # loop over all values
  361. psd[i] += x # accumulate squared amplitudes
  362. psd /= len(frequency) # normalize by the number of timepoints
  363. return psd
  364. def check_number_of_phasebin_visits(self, phasebins=None, ncycles=4, remove_invalid=False):
  365. if phasebins is None:
  366. phasebins = np.linspace(-np.pi, np.pi, 5)
  367. self.sufficient_phasebin_visits = check_binned_visits(self.phase.T, phasebins, ncycles)
  368. if remove_invalid:
  369. for attr in ['imfs', 'phase', 'frequency', 'amplitude']:
  370. setattr(self, attr, getattr(self, attr)[:, self.sufficient_phasebin_visits])
  371. for attr in ['characteristic_frequency', 'power_density', 'power_ratio']:
  372. setattr(self, attr, getattr(self, attr)[self.sufficient_phasebin_visits])
  373. self.n_imfs = self.sufficient_phasebin_visits.sum()
  374. def check_imf_significance(self):
  375. print("WARNING: IMF significance depricated.")
  376. assert hasattr(self, 'imfs')
  377. ln_f, ln_E, bounds = imf_statsig(self.imfs.T, return_period=False, use_hilbert=True)
  378. self.imf_significance = ln_E > bounds[1]
  379. def get_synchronous_events(self, dt=0.5, n_cycles=0.25, threshold_qt=0.95):
  380. """
  381. Perform a sliding window correlation between pairs of IMFs with similar frequencies.
  382. Notes
  383. -----
  384. This measure is similar to a time-resolved version of the pseudo mode splitting index
  385. from Wang et al. (2018) and Fabus et al. (2021).
  386. """
  387. imfs = self.imfs[:, np.where(self.characteristic_frequency > 0)[0]]
  388. freqs = self.characteristic_frequency[np.where(self.characteristic_frequency > 0)[0]]
  389. imfs1 = imfs[:-1]
  390. imfs2 = np.roll(imfs, -1, axis=1)[:-1]
  391. freqs2 = np.roll(freqs, -1, axis=1)[:-1]
  392. step_size = np.round(dt * self.fs).astype(int)
  393. samples = np.arange(0, len(imfs1), step_size)
  394. sync = np.full((len(samples), imfs1.shape[1]),np.nan)
  395. for i, (imf1, imf2, freq) in enumerate(zip(imfs1.T, imfs2.T, freqs2)):
  396. window_size = np.round(n_cycles * self.fs / freq).astype(int)
  397. starts = np.clip(samples - window_size, a_min=0, a_max=None)
  398. stops = np.clip(samples + window_size, a_min=None, a_max=(len(imf1) - 1))
  399. sync[:, i] = [np.dot(imf1[start:stop], imf2[start:stop]) / (stop - start) for start, stop in zip(starts, stops)]
  400. threshold = np.quantile(sync.mean(axis=0), threshold_qt)
  401. events = continuous_runs(sync.mean(axis=0) > threshold, min1len=5)
  402. self.synchronous_events = pts[events]
  403. def pairwise_coherence(self, ncycles=4):
  404. """
  405. Compute phase coherence between all pairs of IMFs.
  406. """
  407. coh_mat, pdiff_mat = np.full((2, self.n_imfs, self.n_imfs), np.nan)
  408. for imfi, imf in enumerate(self.imfs.T):
  409. # Get appropriate window size for this IMFs characteristic frequency
  410. period = 1 / self.characteristic_frequency[imfi] # get IMF period from characteristic frequency
  411. seglen = ncycles * period
  412. nperseg = int(2 ** np.floor((np.log2(seglen * self.fs)))) # number of samples
  413. # Skip if segment not long enough to estimate coherence
  414. if nperseg > self.n_samples:
  415. continue
  416. # Compute pair-wise cross-spectral density
  417. f, coh, pdiff = mtcoh(self.imfs, fs=self.fs, nperseg=nperseg)
  418. # Take only the row corresponding to the current IMF
  419. coh = coh[imfi]
  420. pdiff = pdiff[imfi]
  421. # Get index of the appropriate frequency bin (consider only +ve freqs)
  422. f_ind = f[f > 0].searchsorted(self.characteristic_frequency[imfi])
  423. # Take mean of two most apropriate frequency bins
  424. coh = coh[:, f_ind:(f_ind + 2)].mean(axis=1)
  425. pdiff = circmean_angle(pdiff[:, f_ind:(f_ind + 2)], axis=1)
  426. # Fill row of matrix
  427. coh_mat[imfi] = coh
  428. pdiff_mat[imfi] = pdiff
  429. # Normalize each row by it's maximum to get rid of contributions of power
  430. self.coherence = (coh_mat.T / coh_mat.max(axis=1)).T
  431. self.phasediff = pdiff_mat
  432. def phase_synchrony(self, n_bins=16, n_shf=1000):
  433. phases = self.phase.T
  434. n_phases = len(phases)
  435. freqs = self.characteristic_frequency
  436. bin_edges = np.linspace(-np.pi, np.pi, n_bins + 1)
  437. # Get bin areas array to normalize density function
  438. D_areas = np.outer(np.diff(bin_edges), np.diff(bin_edges))
  439. # Get a reference uniform distribution
  440. D_uniform = np.ones((n_bins, n_bins)) / n_bins**2
  441. # Initialize array to collect the joint distributions
  442. DD = np.full((n_phases, n_phases, n_bins, n_bins), np.nan)
  443. # Initialize array to colled KLDs
  444. DD_kld = np.full((len(phases), len(phases)), np.nan)
  445. # Initialize array to collect distribution p-values
  446. DD_p = np.full((len(phases), len(phases)), np.nan)
  447. # Initialize array to collect the significance masks
  448. DD_masks = np.full((n_phases, n_phases, n_bins, n_bins), np.nan)
  449. # Initialize array to collect synchronous time ranges
  450. DD_ranges = np.full((n_phases, n_phases), np.nan, dtype='object')
  451. # Loop over pairs of phase traces
  452. for i in range(len(phases)):
  453. for j in range(len(phases)):
  454. # Skip if not in upper triangle of pairwise matrix (redundant info)
  455. if i == j: continue
  456. # Make marginal distributions uniform by converting phases to ranks
  457. #ranks_i = phase2rank(phases[i]) - np.pi
  458. #ranks_j = phase2rank(phases[j]) - np.pi
  459. # Get the joint probability functions
  460. D = np.histogram2d(phases[i], phases[j], bins=bin_edges, density=True)[0] * D_areas
  461. # Normalize by marginal distributions
  462. #D = (D.T / np.histogram(phases[i], bins=phase_bins)[0]).T # normalize rows
  463. #D = D / np.histogram(phases[j], bins=phase_bins)[0] # normalize columns
  464. DD[i, j] = D
  465. # Compute Kullback-Leibler divergence from uniform
  466. ## TODO: add eps to D to ensure no zero values? --> no negative KLDs
  467. kld = kl_divergence(D, D_uniform)
  468. # Initialize array to collect shuffle distributions
  469. DD_shf = np.full((n_shf, D.shape[0], D.shape[1]), np.nan)
  470. # Initialize array to collect shuffle KLDs
  471. kld_shf = np.full(n_shf, np.nan)
  472. # Perform shuffles
  473. for shf in range(n_shf):
  474. # Randomly shuffle time points
  475. #shf_i = np.random.choice(np.arange(len(phases[i])), size=len(phases[i]), replace=False)
  476. #shf_j = np.random.choice(np.arange(len(phases[j])), size=len(phases[j]), replace=False)
  477. # Shuffle cycle order
  478. cycles_i = np.split(phases[i], np.where(np.diff(phases[i]) < -np.pi)[0])
  479. cycles_j = np.split(phases[j], np.where(np.diff(phases[j]) < -np.pi)[0])
  480. np.random.shuffle(cycles_i)
  481. np.random.shuffle(cycles_j)
  482. shf_i = np.concatenate(cycles_i)
  483. shf_j = np.concatenate(cycles_j)
  484. # Get PDF of shuffle
  485. D_shf = np.histogram2d(shf_i, shf_j, bins=bin_edges, density=True)[0] * D_areas
  486. DD_shf[shf] = D_shf
  487. # Compute KLD of shuffle
  488. kld_shf[shf] = kl_divergence(D_shf, D_uniform)
  489. # Get KLD diff
  490. DD_kld[i, j] = (kld - kld_shf.mean()) / kld_shf.std()
  491. # Get KLD p-values
  492. DD_p[i, j] = (kld_shf > kld).mean()
  493. # Get significance mask for joint distribution
  494. mask = D > np.percentile(D_shf, 95, axis=0)
  495. DD_masks[i, j] = mask
  496. # Find time ranges where phase traces pass though significant regions
  497. ranges = []
  498. for pi, pj in np.column_stack(np.where(mask)):
  499. mask_i = (phases[i] > bin_edges[pi]) & (phases[i] <= bin_edges[pi + 1])
  500. mask_j = (phases[j] > bin_edges[pj]) & (phases[j] <= bin_edges[pj + 1])
  501. ranges.append(zero_runs(~(mask_i & mask_j)))
  502. DD_ranges[i, j] = merge_ranges(np.concatenate(ranges))
  503. return DD, DD_kld, DD_p, DD_masks, DD_ranges
  504. def modemix(self, alpha=0.05):
  505. """
  506. Compute a metric for the amount of overlap between all pairs of signals in a
  507. set. Metric represents the average (over all signal pairs) proportion of
  508. time during which the signals crossed.
  509. Parameters
  510. ----------
  511. signals : ndarray
  512. signals array of shape (nchannels, ntimepoints)
  513. Returns
  514. -------
  515. out : float
  516. mean proportion of overlap between all pairs of signals
  517. Notes
  518. -----
  519. - Designed for use as a metric of frequency overlap (i.e. 'mode mixing')
  520. between a set of IMFs resulting from EMD, in this case the input
  521. should be a set of instantaneous frequency traces
  522. References
  523. ----------
  524. [1] Laszuk, D., Cadenas, O., & Nasuto, S. J. (2015, July). Objective
  525. empirical mode decomposition metric. In 2015 38th International Conference
  526. on Telecommunications and Signal Processing (TSP) (pp. 504-507). IEEE.
  527. """
  528. metric = np.full((self.n_imfs, self.n_imfs), np.nan) # pair-wise matrix
  529. for i in range(self.n_imfs): # loop over unique pairs
  530. for j in range(i + 1, self.n_imfs):
  531. assert self.characteristic_frequency[i] > self.characteristic_frequency[j]
  532. overlap_i = (self.frequency[:, i] < np.quantile(self.frequency[:, j], 1 - alpha)).sum()
  533. overlap_j = (self.frequency[:, j] > np.quantile(self.frequency[:, i], alpha)).sum()
  534. # proportion of time for which there is overlap
  535. metric[i, j] = (overlap_i + overlap_j) / self.n_samples
  536. return metric
  537. def get_imf_colors(self, cmap):
  538. color_vals = 1 - np.linspace(0.1, 1, self.n_imfs)
  539. return cmap(color_vals)
  540. def tranges_with_events(tranges, events):
  541. with_event = np.full(len(tranges), False)
  542. for i, (t0, t1) in enumerate(tranges):
  543. with_event[i] = any([(evt >= t0) & (evt <= t1) for evt in events])
  544. return with_event
  545. if __name__ == "__main__":
  546. parser = argparse.ArgumentParser()
  547. parser.add_argument('e_name')
  548. args = parser.parse_args()
  549. df_pupil = load_data('pupil', [args.e_name])
  550. df_run = load_data('ball', [args.e_name])
  551. df = pd.merge(df_pupil, df_run, on=['m', 's', 'e'])
  552. seriess = []
  553. for _, row in tqdm(df.iterrows(), total=len(df)):
  554. pupil_area = row['pupil_area']
  555. pupil_tpts = row['pupil_tpts']
  556. # Get IMFs
  557. fs = 1 / np.diff(pupil_tpts).mean()
  558. hht = HHT(pupil_area, fs)
  559. hht.emd()
  560. hht.hsa()
  561. f_bins, psd = fft_psd(pupil_area, hht.fs)
  562. hht_psd = hht.marginal_spectrum(f_bins=f_bins)
  563. frequencies = hht.characteristic_frequency.copy()
  564. powers = hht.power_ratio.copy()
  565. run_ranges = row['pupil_tpts'].searchsorted(row['run_bouts'])
  566. run_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=run_ranges)
  567. sit_ranges = row['pupil_tpts'].searchsorted(row['sit_bouts'])
  568. sit_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=sit_ranges)
  569. #tranges = np.array([[0, int(len(pupil_tpts) / 2)]])
  570. #half1_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=tranges)
  571. #tranges = np.array([[int(len(pupil_tpts) / 2), len(pupil_tpts)]])
  572. #half2_psd = hht.marginal_spectrum(f_bins=f_bins, ranges=tranges)
  573. hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
  574. pbi = np.full(hht.n_imfs, np.nan)
  575. cycle_tranges = []
  576. for i, phase in enumerate(hht.phase.T):
  577. # Get phase bias index
  578. counts, _ = circhist(phase)
  579. pbi[i] = (counts.max() - counts.min()) / counts.max()
  580. #phase_components = np.column_stack([np.cos(hht.phase), np.sin(hht.phase)])
  581. #pca = PCA()
  582. #pca.fit(phase_components)
  583. #hht.pairwise_coherence()
  584. jpd, sync_kld, sync_p, sync_masks, sync_ranges = hht.phase_synchrony()
  585. # Take ranges only for non-uniform distrbutions
  586. #ranges = merge_ranges(np.concatenate(sync_ranges[sync_p <= 0.05]))
  587. #synchronous_bouts = pupil_tpts[ranges]
  588. sync_boutss = []
  589. desync_boutss = []
  590. for i, (ranges, ps) in enumerate(zip(sync_ranges, sync_p)):
  591. bouts = ranges[ps <= 0.05] # take only if overall distribution is significant
  592. #bouts = np.delete(ranges, i) # take all (except self)
  593. if len(bouts) > 0:
  594. sync_bouts = merge_ranges(np.concatenate(bouts))
  595. desync_bouts = switch_ranges(sync_bouts, maxval=(hht.n_samples - 1))
  596. sync_bouts = pupil_tpts[sync_bouts]
  597. desync_bouts = pupil_tpts[desync_bouts]
  598. sync_boutss.append(sync_bouts)
  599. desync_boutss.append(desync_bouts)
  600. # Get pupil size matched time ranges for each IMF
  601. pupilarea_norm = pupil_area / pupil_area.max()
  602. phase_bins = np.linspace(-np.pi, np.pi, 5)
  603. phase_nbins = len(phase_bins) - 1
  604. size_bins = np.linspace(0, 1, 11)
  605. unmatched_means, matched_means = np.full((2, hht.n_imfs, phase_nbins, phase_nbins), np.nan)
  606. matched_ranges = []
  607. # loop over IMFs
  608. for imf_i, imf_phase in enumerate(hht.phase.T):
  609. phase_binned = np.digitize(imf_phase, phase_bins).clip(1, phase_nbins) - 1
  610. phasebin_ranges = [zero_runs(~np.equal(phase_binned, phase_bin)) for phase_bin in np.arange(phase_nbins)]
  611. unmatched_dists = [np.concatenate([pupilarea_norm[i0:i1] for i0, i1 in ranges]) for ranges in phasebin_ranges]
  612. matched_ranges_imf = match_distributions(imf_phase, pupilarea_norm, phase_bins, size_bins)
  613. matched_ranges.append(pupil_tpts[np.concatenate(matched_ranges_imf).clip(0, len(pupil_tpts) - 1).astype(int)])
  614. matched_dists = [np.concatenate([pupilarea_norm[i0:i1] for i0, i1 in ranges]) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges_imf]
  615. for i in np.arange(phase_nbins):
  616. for j in np.arange(phase_nbins):
  617. if j > i:
  618. unmatched_means[imf_i, i, j] = unmatched_dists[i].mean() - unmatched_dists[j].mean()
  619. matched_means[imf_i, i, j] = matched_dists[i].mean() - matched_dists[j].mean()
  620. nosaccade_ranges = []
  621. saccades = pupil_tpts.searchsorted(row['saccade_times'])
  622. for imf_i, imf_phase in enumerate(hht.phase.T):
  623. phase_binned = np.digitize(imf_phase, phase_bins).clip(1, phase_nbins) - 1
  624. phasebin_ranges = [zero_runs(~np.equal(phase_binned, phase_bin)) for phase_bin in np.arange(phase_nbins)]
  625. ranges2keep = [~tranges_with_events(ranges, saccades) for ranges in phasebin_ranges]
  626. phasebin_ranges = np.concatenate([pupil_tpts.searchsorted(ranges[keep]) for ranges, keep in zip(phasebin_ranges, ranges2keep)])
  627. nosaccade_ranges.append(phasebin_ranges)
  628. data = {
  629. 'm': row['m'],
  630. 's': row['s'],
  631. 'e': row['e'],
  632. 'condition': args.e_name,
  633. 't0': pupil_tpts[0],
  634. 't1': pupil_tpts[-1],
  635. 'segment_length': pupil_tpts[-1] - pupil_tpts[0],
  636. 'frequency': frequencies,
  637. 'power': powers,
  638. 'hht_psd': hht_psd,
  639. 'fft_psd': psd,
  640. 'psd_freq': f_bins,
  641. 'run_psd': run_psd,
  642. 'sit_psd': sit_psd,
  643. #'half1_psd': half1_psd,
  644. #'half2_psd': half2_psd,
  645. 'pbi': pbi,
  646. #'pc_variance':pca.explained_variance_ratio_,
  647. #'coherence': hht.coherence,
  648. #'phasediff': hht.phasediff,
  649. 'sync_p': sync_p,
  650. 'sync_kld': sync_kld,
  651. 'sync_bouts': sync_boutss,
  652. 'desync_bouts': desync_boutss,
  653. 'sizematched_bouts': matched_ranges,
  654. 'unmatched_meansize': unmatched_means,
  655. 'matched_meansize': matched_means,
  656. 'nosaccade_bouts': nosaccade_ranges
  657. }
  658. seriess.append(pd.Series(data=data))
  659. df_hht = pd.DataFrame(seriess)
  660. df_hht.to_pickle(DATAPATH + 'hht_{}.pkl'.format(args.e_name))