filters.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from scipy.ndimage.filters import median_filter, uniform_filter
  2. import numpy as np
  3. from scipy.signal import kaiserord, firwin, freqz, lfilter
  4. def apply_filter(matrix_in: np.ndarray, view_flags, filter_type: str):
  5. if filter_type == "median":
  6. func = median_filter
  7. interpretation_func = view_flags.interpret_median_filter_params
  8. elif filter_type == "mean":
  9. func = uniform_filter
  10. interpretation_func = view_flags.interpret_mean_filter_params
  11. else:
  12. raise NotImplementedError(f"Filter type can be either 'median' or 'mean', got {filter_type}")
  13. size_in_space, size_in_time = interpretation_func()
  14. if size_in_time is None and size_in_space is None:
  15. return matrix_in
  16. if len(matrix_in.shape) == 3: # assume data format is XYT
  17. sizes_along_dimension_of_input = (size_in_space, size_in_space, size_in_time)
  18. elif len(matrix_in.shape) == 2: # assume data format is XY
  19. sizes_along_dimension_of_input = (size_in_space, size_in_space)
  20. elif len(matrix_in.shape) == 1: # assume data is a time trace
  21. sizes_along_dimension_of_input = (size_in_time,)
  22. else:
  23. raise NotImplementedError
  24. return func(matrix_in, size=sizes_along_dimension_of_input, mode="nearest")
  25. def filter_kaisord_highpass(signal, sampling_rate, cutoff=100, transitionWidth=40, rippleDB=20):
  26. """
  27. Applies a digital high pass filter to <signal>.
  28. :param Sequence signal: sequence of floats representing the signal to be filtered
  29. :param float sampling_rate: sampling rate of <signal> in Hz
  30. :param float cutoff: in Hz, frequencies above this will pass
  31. :param float transitionWidth: in Hz, over which filter gain transits from pass to stop
  32. :param float rippleDB: in DB, amplitude of ripple of frequency band stopped
  33. :return: Sequence of float, representing the filtered signal
  34. """
  35. nyqFreq = sampling_rate / 2
  36. transitionWidth = min(transitionWidth, cutoff)
  37. N, beta = kaiserord(rippleDB, transitionWidth / nyqFreq)
  38. tapsLP = firwin(N, cutoff / nyqFreq, window=('kaiser', beta))
  39. delay_samples = int((N - 1) * 0.5)
  40. temp = np.zeros((N,))
  41. temp[delay_samples] = 1
  42. tapsHP = temp - tapsLP
  43. temp = np.empty((len(signal) + 2 * delay_samples))
  44. temp[:delay_samples] = signal[0]
  45. temp[delay_samples: delay_samples + len(signal)] = signal
  46. temp[-delay_samples:] = signal[-1]
  47. temp1 = lfilter(tapsHP, 1.0, temp)
  48. signal_filtered = temp1[2 * delay_samples: 2 * delay_samples + len(signal)]
  49. # ----- code for debugging ----
  50. # from matplotlib import pyplot as plt
  51. # plt.ion()
  52. # fig, ax = plt.subplots(figsize=(7, 5.6))
  53. # ax.plot(signal, "-b", label="unfiltered")
  54. # ax.plot(signal_filtered, "-r", label='filtered')
  55. # ax.legend(loc="best")
  56. # plt.draw()
  57. # input("Press any key to continue")
  58. # plt.close()
  59. # ------------------------------
  60. return signal_filtered