spectral.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. '''
  2. spectral.py
  3. Provides tools for spectral analysis
  4. Version 1.3 - adapted for ANDA 2024
  5. Date 03.10.2024
  6. '''
  7. # %%
  8. import numpy as np
  9. import pywt
  10. import matplotlib.pyplot as plt
  11. import scipy
  12. def hilbert_bandpassed(signal, f_mid, f_width, dt):
  13. # we design a 3-pole lowpass filter at 0.1x Nyquist frequency
  14. nyqf = 0.5 / dt
  15. b, a = scipy.signal.butter(3, [(f_mid - f_width) / nyqf, (f_mid + f_width) / nyqf], 'band')
  16. filtered = scipy.signal.filtfilt(b, a, signal, method='gust', axis=-1)
  17. analytic = scipy.signal.hilbert(filtered, axis=-1)
  18. return analytic, filtered
  19. def phase_locking_value(signal_a, signal_b):
  20. phase_a = signal_a / np.abs(signal_a)
  21. mphase_b = np.conj(signal_b) / np.abs(signal_b)
  22. plv = np.mean(phase_a * mphase_b)
  23. return plv
  24. # Calculate the wavelet scales we requested
  25. def wavelet_transform_morlet(
  26. data: np.ndarray,
  27. n_freqs: int,
  28. freq_min: float,
  29. freq_max: float,
  30. dt: float,
  31. bandwidth: float = 1.5,
  32. ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  33. # wavelet scales derived from parameters
  34. s_spacing: np.ndarray = (1.0 / (n_freqs - 1)) * np.log2(freq_max / freq_min)
  35. scale: np.ndarray = np.power(2, np.arange(0, n_freqs) * s_spacing)
  36. freq_axis: np.ndarray = freq_min * scale
  37. wave_scales: np.ndarray = 1.0 / (freq_axis * dt)
  38. # the wavelet we want to use
  39. mother = pywt.ContinuousWavelet(f"cmor{bandwidth}-1.0")
  40. # one or multiple time series? --> expand
  41. data_2d = data
  42. if data.ndim == 1:
  43. data_2d = data_2d[np.newaxis, :]
  44. n_trials = data_2d.shape[0]
  45. complex_spectrum = np.zeros([n_trials, n_freqs, data_2d.shape[1]], dtype=np.complex128)
  46. for i_trial in range(n_trials):
  47. complex_spectrum[i_trial, :, :], freq_axis_wavelets = pywt.cwt(
  48. data=data_2d[i_trial, :], scales=wave_scales, wavelet=mother, sampling_period=dt
  49. )
  50. # one or multiple time series? <-- flatten
  51. if data.ndim == 1:
  52. complex_spectrum = complex_spectrum[0, :, :]
  53. # generate time axis and cone-of-influence
  54. t_axis = dt * np.arange(data_2d.shape[1])
  55. t_coi = (bandwidth * 3) / 2 / np.pi * np.sqrt(2) / freq_axis_wavelets
  56. return complex_spectrum, t_axis, freq_axis_wavelets, t_coi
  57. def wavelet_dsignal_show(
  58. wavelet_dsignal: np.ndarray,
  59. t_axis: np.ndarray,
  60. f_axis: np.ndarray,
  61. t_coi: None | np.ndarray = None
  62. ):
  63. # average over first dimension, if signal_wavelet has three dims
  64. to_show = wavelet_dsignal # np.abs(wavelet_dsignal) ** 2
  65. if to_show.ndim == 3:
  66. to_show = to_show.mean(axis=0)
  67. # compute and plot power, but show just a few frequencies from all in vector
  68. f_pick = np.arange(0, f_axis.size, max(1, int(f_axis.size / 15)))
  69. plt.pcolor(t_axis, np.arange(f_axis.size), to_show)
  70. ax = plt.gca()
  71. ax.set_yticks(0.5 + f_pick)
  72. ax.set_yticklabels([str(int(f * 100) / 100) for f in f_axis[f_pick]])
  73. # cone-of-influence
  74. if t_coi is not None:
  75. # t_coi = (show_coi_bandwidth * 4) / 2 / np.pi * np.sqrt(2) / f_axis
  76. plt.plot(t_axis[0] + t_coi, np.arange(f_axis.size), 'w-')
  77. plt.plot(t_axis[-1] - t_coi, np.arange(f_axis.size), 'w-')
  78. # labeling
  79. plt.xlabel('time t')
  80. plt.ylabel('frequency f')
  81. plt.colorbar()
  82. return
  83. def coherence(a1, a2, tau_max, ntau, dt, ts=None, te=None):
  84. """a1 and a2 are (trials, freqs, time) arrays, ts and te time start/end indices
  85. tau_max is the maximal time delay (dimension time)
  86. returns C of shape (freqs, taus) where taus is of
  87. lenght 2*ntau+1 linearly from -tau_max to +tau_max
  88. """
  89. ntrials, nfreqs, ntime = a1.shape
  90. assert a2.shape == a1.shape
  91. if ts is None:
  92. ts = 0
  93. if te is None:
  94. te = ntime
  95. taus = np.linspace(-tau_max, tau_max, 2 * ntau + 1)
  96. c = np.zeros((nfreqs, len(taus)))
  97. a2_conj = np.conj(a2)
  98. a1_abs2 = np.abs(a1) ** 2
  99. a2_abs2 = np.abs(a2) ** 2
  100. # zero time delay at index ntau
  101. c[:, ntau] = np.abs(np.sum(a1[:, :, ts:te] * a2_conj[:, :, ts:te], axis=(0, 2))) ** 2 / \
  102. np.sum(a1_abs2[:, :, ts:te], axis=(0, 2)) / np.sum(a2_abs2[:, :, ts:te], axis=(0, 2))
  103. for itau in range(1, ntau + 1):
  104. tau = taus[ntau + itau] # absolute tau value
  105. taui = int(tau / dt) # index shift by tau
  106. # shift by +tau
  107. c[:, ntau + itau] = np.abs(np.sum(a1[:, :, (ts + taui):te] * a2_conj[:, :, ts:(te - taui)], axis=(0, 2))) ** 2 / \
  108. np.sum(a1_abs2[:, :, (ts + taui):te], axis=(0, 2)) / np.sum(a2_abs2[:, :, ts:(te - taui)], axis=(0, 2))
  109. # shift by -tau -> switch t boundaries
  110. c[:, ntau - itau] = np.abs(np.sum(a1[:, :, ts:(te - taui)] * a2_conj[:, :, (ts + taui):te], axis=(0, 2))) ** 2 / \
  111. np.sum(a1_abs2[:, :, ts:(te - taui)], axis=(0, 2)) / np.sum(a2_abs2[:, :, (ts + taui):te], axis=(0, 2))
  112. return c, taus
  113. def power(signal, dt_bin):
  114. # assert signal is 1D and real
  115. n_bins = signal.shape[-1]
  116. t_max = dt_bin * n_bins
  117. signal_fft = np.fft.rfft(signal)
  118. # frequency resolution and frequency axis
  119. n_fft = signal_fft.shape[-1]
  120. df = 1 / t_max
  121. f_axis = df * np.arange(n_fft)
  122. # power DENSITY, therefore divide by frequency resolution
  123. signal_power = 1 / df * (np.abs(signal_fft) / n_bins) ** 2
  124. signal_power[..., 1:-1] *= 2 # compensate lack of two-sided representation
  125. # case distinction necessary for odd/even number of bins
  126. if np.mod(n_bins, 2) == 1:
  127. signal_power[..., -1] *= 2
  128. return signal_power, f_axis, df
  129. def power_average(signal, dt_bin, n_average):
  130. n_bins_total = signal.shape[-1]
  131. n_bins = n_bins_total // n_average
  132. assert n_bins > 0, "Signal has too few bins for averaging!"
  133. shape_chunks = signal.shape[:-1] + (n_average, n_bins)
  134. signal_chunks = np.reshape(signal[..., :n_bins * n_average], shape_chunks)
  135. signal_power_chunks, f_axis, df = power(signal_chunks, dt_bin)
  136. signal_power = signal_power_chunks.mean(axis=-2)
  137. return signal_power, f_axis, df
  138. if __name__ == "__main__":
  139. import matplotlib.pyplot as plt
  140. print("Computing an example!")
  141. t_max = 3
  142. f_sin = 42
  143. dt_bin = 0.0025
  144. n_bins = np.ceil(t_max / dt_bin)
  145. a_sin = 3.2
  146. a_ofs = 2.1
  147. t = dt_bin * np.arange(n_bins)
  148. signal = a_ofs + a_sin * np.sin(2 * np.pi * f_sin * t)
  149. signal_power, f_axis, df = power(signal, dt_bin)
  150. plt.plot(f_axis, signal_power)
  151. plt.xlabel("frequency f [Hz]")
  152. plt.ylabel("spectral power [1/Hz]")
  153. plt.show()
  154. print("Checking the Tafelrunde:")
  155. print(f"var={np.var(signal):.3f}, int={np.sum(signal_power[1:])*df:.3f}")
  156. print(f"mean={np.mean(signal):.3f}, zeropow={np.sqrt(signal_power[0]*df):.3f}")
  157. complex_spectrum, freq_axis_wavelets = wavelet_transform_morlet(
  158. signal, n_freqs=100, freq_min=f_sin / 2, freq_max=f_sin * 2, dt=dt_bin,
  159. bandwidth=1.5)
  160. plt.imshow(abs(complex_spectrum) ** 2, cmap="hot", aspect="auto", interpolation="None")
  161. plt.colorbar()
  162. # %%