generate_test_data.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
from pathlib import Path

import nitime.algorithms as tsa
import numpy as np
import scipy.io
  4. def generate_data(length, coeffs, variance, fs, transients=10):
  5. """
  6. Generate model data to test PSD prediction
  7. Parameters
  8. ----------
  9. length : int
  10. length of generated time series
  11. coeffs : np.array
  12. coefficients of auto-regressive model used to generate time series.
  13. coefficients need to be 1D
  14. variance : float
  15. variance of noise used to generate time series from auto-regressive
  16. model
  17. fs : float
  18. sampling frequency of time series data
  19. transients : int
  20. number of data points dropped from beginning of time series
  21. Returns
  22. -------
  23. times : np.array
  24. time axis of generated data
  25. time_series : np.array
  26. generated time series data
  27. freqs : np.array
  28. frequencies of PSD of generated data
  29. psd_time_series : np.array
  30. PSD of generated time series data
  31. """
  32. # generate time axis for data to be generated
  33. times = np.linspace(0, 1 / fs, length)
  34. # determine order of auto-regressive model to generate data
  35. order = np.size(coeffs)
  36. # array to store generated data
  37. time_series_full = np.zeros(length + transients)
  38. # generate noise
  39. noise = np.random.normal(0, variance, length + transients)
  40. # generate time series data from autoregressive model
  41. for i in range(length + transients):
  42. for j in range(order):
  43. time_series_full[i] += time_series_full[i - j - 1] * coeffs[j]
  44. time_series_full[i] += noise[i]
  45. # get rid off transients to finalize generated time series
  46. time_series = time_series_full[transients:]
  47. # generate frequencies for PSD
  48. freqs = np.fft.rfftfreq(length, d=1 / fs)
  49. # generate PSD of generated time series
  50. arguments = np.arange(1, order + 1, 1)
  51. prod_f_arg = np.outer(freqs, arguments)
  52. exps = np.exp(-2 * np.pi * 1j * prod_f_arg / fs)
  53. sum_exps = np.matmul(exps, coeffs)
  54. psd_time_series = variance / (fs * np.abs(1 - sum_exps)**2)
  55. return times, time_series, freqs, psd_time_series
  56. # Choose parameters, coeffs as in nitime utils.ar_generator default selection
  57. # See: http://nipy.org/nitime/api/generated/nitime.utils.html?highlight=utils%20ar_generator#nitime.utils.ar_generator # noqa
  58. length = 2**10
  59. coeffs = np.array([2.7607, -3.8106, 2.6535, -0.9238])
  60. variance = 1.
  61. fs = 0.1
  62. np.random.seed(1234)
  63. times, time_series, freqs, psd_time_series = generate_data(
  64. length=length,
  65. coeffs=coeffs,
  66. variance=variance,
  67. fs=fs)
  68. f, psd_mt, nu = tsa.multi_taper_psd(time_series, Fs=fs, NW=4,
  69. jackknife=False, low_bias=False)
  70. assert np.all(freqs == f), "Mismatch of frequencies between ground truth " \
  71. "and nitime"
  72. data_folder = '../data/'
  73. np.save(data_folder + 'time_series.npy', time_series)
  74. scipy.io.savemat(data_folder + 'time_series.mat', {"time_series": time_series})
  75. np.save(data_folder + 'psd_nitime.npy', psd_mt)
  76. np.save(data_folder + 'psd_ground_truth.npy', psd_time_series)