voltage_trace_analyzer.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import numpy as np
  2. from scipy import signal
  3. class VoltageTraceAnalyzer:
  4. DEFAULT_SPIKE_WIDTH = 100
  5. DEFAULT_SPIKE_THRESHOLD = 0
  6. def set_spike_width(self, width):
  7. self.spike_width = width
  8. return self
  9. def set_spike_threshold(self, threshold):
  10. self.spike_threshold = threshold
  11. return self
  12. def get_interspike_interval(self, start_time, end_time):
  13. start_idx = get_index_of_time_point(self.times, start_time)
  14. end_idx = get_index_of_time_point(self.times, end_time)
  15. return get_interspike_intervals(self.times[start_idx:end_idx], self.voltage[start_idx:end_idx],
  16. self.spike_threshold)
  17. def get_firing_rate(self, shifting_window_length, shifting_window_overlap):
  18. t_max = self.times[-1]
  19. sample_times = get_sample_times(shifting_window_length, shifting_window_overlap, t_max)
  20. mean_frequency = get_spike_rates(self.spiking_times, sample_times, shifting_window_length)
  21. return sample_times, mean_frequency
  22. def get_local_firing_rate(self, t_start, t_end):
  23. return get_local_spiking_rate(self.spiking_times, t_start, t_end)
  24. def __init__(self, times, voltage, spike_threshold=DEFAULT_SPIKE_THRESHOLD, spike_width=DEFAULT_SPIKE_WIDTH):
  25. self.times = times
  26. self.voltage = voltage
  27. self.spike_threshold = spike_threshold
  28. self.spike_width = spike_width
  29. self.spiking_times = self.times[get_spike_indices(self.voltage, self.spike_threshold, self.spike_width)]
  30. def get_index_of_time_point(times, time_point):
  31. return np.searchsorted(times, time_point)
  32. def get_negative_threshold_crossings(voltage_trace, threshold):
  33. first_values = voltage_trace[:-1]
  34. second_values = voltage_trace[1:]
  35. crossing_detection = np.where(np.logical_and(first_values >= threshold, second_values < threshold))
  36. crossing_indices = crossing_detection[0]
  37. return crossing_indices
  38. def get_sample_times(shifting_window_length, shifting_window_overlap, t_max):
  39. sample_times = np.arange(shifting_window_length / 2,
  40. t_max - shifting_window_length / 2 + 1,
  41. shifting_window_length - shifting_window_overlap)
  42. return sample_times
  43. def get_spike_indices(voltage, threshold, width):
  44. local_maxima = signal.argrelextrema(voltage, np.greater_equal, order=width)[0]
  45. #filter consecutive maxima that is the rare case that during a spike there is a saddle
  46. indices = filter(lambda idx: voltage[idx] > threshold, local_maxima)
  47. return indices
  48. def get_local_spiking_rate(spiking_times, t_start, t_end):
  49. number_of_spikes = np.sum(np.logical_and(
  50. t_start <= spiking_times,
  51. t_end > spiking_times))
  52. return float(number_of_spikes) / (t_end - t_start)
  53. def get_spike_rates(spiking_times, sample_times, shifting_window_length):
  54. spiking_times_array = np.repeat(spiking_times.reshape(1, len(spiking_times)), len(sample_times), axis=0)
  55. sample_times_array = np.repeat(sample_times.reshape(len(sample_times), 1), len(spiking_times),
  56. axis=1)
  57. number_of_spikes_in_shifting_window = np.sum(
  58. np.logical_and(
  59. (sample_times_array - shifting_window_length / 2) <= spiking_times_array,
  60. spiking_times_array <= (sample_times_array + shifting_window_length / 2))
  61. , axis=1)
  62. mean_frequency = number_of_spikes_in_shifting_window / float(shifting_window_length)
  63. return mean_frequency
  64. def get_interspike_intervals(times, voltage_trace, threshold):
  65. threshold_crossing_indices = get_negative_threshold_crossings(voltage_trace, threshold)
  66. if threshold_crossing_indices.shape[0] == 0:
  67. interspike_intervals = np.array([])
  68. elif threshold_crossing_indices.shape[0] == 1:
  69. interspike_intervals = np.array([np.Infinity])
  70. else:
  71. time_points = times[threshold_crossing_indices]
  72. interspike_intervals = time_points[1:] - time_points[:-1]
  73. return interspike_intervals