import numpy as np
from scipy import signal


class VoltageTraceAnalyzer:
    """Spike and firing-rate analysis of one sampled voltage trace.

    Spikes are the local maxima of ``voltage`` (compared against
    ``spike_width`` samples on each side) whose value exceeds
    ``spike_threshold``.  Their times are cached in ``self.spiking_times``
    and refreshed whenever the threshold or width is changed through the
    setters.
    """

    DEFAULT_SPIKE_WIDTH = 100
    DEFAULT_SPIKE_THRESHOLD = 0

    def __init__(self, times, voltage, spike_threshold=DEFAULT_SPIKE_THRESHOLD,
                 spike_width=DEFAULT_SPIKE_WIDTH):
        """Store the trace and detect its spikes.

        Args:
            times: sample time points (presumably sorted ascending; the
                ``searchsorted`` lookups below rely on that).
            voltage: voltage samples, same length as ``times``.
            spike_threshold: a local maximum counts as a spike only if its
                voltage is strictly greater than this value.
            spike_width: ``order`` passed to ``scipy.signal.argrelextrema``,
                i.e. how many samples on each side a maximum is compared to.
        """
        self.times = np.asarray(times)
        self.voltage = np.asarray(voltage)
        self.spike_threshold = spike_threshold
        self.spike_width = spike_width
        self._update_spiking_times()

    def _update_spiking_times(self):
        """Recompute the cached spike times from the current settings."""
        spike_indices = get_spike_indices(self.voltage, self.spike_threshold,
                                          self.spike_width)
        self.spiking_times = self.times[spike_indices]

    def set_spike_width(self, width):
        """Set the spike width, refresh ``spiking_times`` and return ``self``."""
        self.spike_width = width
        # The cached spike times depend on the width, so they must be redone.
        self._update_spiking_times()
        return self

    def set_spike_threshold(self, threshold):
        """Set the spike threshold, refresh ``spiking_times`` and return ``self``."""
        self.spike_threshold = threshold
        # The cached spike times depend on the threshold, so they must be redone.
        self._update_spiking_times()
        return self

    def get_interspike_interval(self, start_time, end_time):
        """Return interspike intervals for samples with start_time <= t < end_time.

        Intervals are measured between successive downward crossings of
        ``spike_threshold`` (see ``get_interspike_intervals``); ``spike_width``
        is not used here.
        """
        start_idx = get_index_of_time_point(self.times, start_time)
        end_idx = get_index_of_time_point(self.times, end_time)
        return get_interspike_intervals(self.times[start_idx:end_idx],
                                        self.voltage[start_idx:end_idx],
                                        self.spike_threshold)

    def get_firing_rate(self, shifting_window_length, shifting_window_overlap):
        """Return ``(sample_times, mean_frequency)`` from a sliding window.

        Window centres come from ``get_sample_times`` and start at
        ``shifting_window_length / 2`` regardless of ``times[0]``
        (NOTE(review): this presumably assumes the trace starts at 0).
        """
        t_max = self.times[-1]
        sample_times = get_sample_times(shifting_window_length,
                                        shifting_window_overlap, t_max)
        mean_frequency = get_spike_rates(self.spiking_times, sample_times,
                                         shifting_window_length)
        return sample_times, mean_frequency

    def get_local_firing_rate(self, t_start, t_end):
        """Return the spike count in [t_start, t_end) divided by the interval length."""
        return get_local_spiking_rate(self.spiking_times, t_start, t_end)


def get_index_of_time_point(times, time_point):
    """Return the left insertion index of ``time_point`` in the sorted ``times``."""
    return np.searchsorted(times, time_point)


def get_negative_threshold_crossings(voltage_trace, threshold):
    """Return indices i where voltage[i] >= threshold and voltage[i + 1] < threshold."""
    voltage_trace = np.asarray(voltage_trace)
    at_or_above = voltage_trace[:-1] >= threshold
    below_next = voltage_trace[1:] < threshold
    return np.where(np.logical_and(at_or_above, below_next))[0]


def get_sample_times(shifting_window_length, shifting_window_overlap, t_max):
    """Return sliding-window centre times.

    Centres start at ``shifting_window_length / 2`` and advance by
    ``shifting_window_length - shifting_window_overlap``.

    Raises:
        ValueError: if the overlap is not smaller than the window length
            (the step would be zero or negative).
    """
    step = shifting_window_length - shifting_window_overlap
    if step <= 0:
        raise ValueError("shifting_window_overlap must be smaller than "
                         "shifting_window_length")
    half_window = shifting_window_length / 2
    # The "+ 1" raises np.arange's exclusive stop by one time unit so that the
    # centre t_max - half_window is included; NOTE(review): this is
    # unit-dependent (it assumes unit-ish time steps) - confirm intent.
    return np.arange(half_window, t_max - half_window + 1, step)


def get_spike_indices(voltage, threshold, width):
    """Return indices of local maxima of ``voltage`` that exceed ``threshold``.

    A sample is a local maximum if it is >= every sample within ``width``
    positions on either side (``scipy.signal.argrelextrema`` with
    ``np.greater_equal``).
    """
    voltage = np.asarray(voltage)
    local_maxima = signal.argrelextrema(voltage, np.greater_equal, order=width)[0]
    # NOTE: with np.greater_equal every sample of a flat-topped peak counts as
    # a maximum, so a plateau above threshold yields several adjacent indices;
    # they are not merged here.
    return local_maxima[voltage[local_maxima] > threshold]


def get_local_spiking_rate(spiking_times, t_start, t_end):
    """Return the number of spikes in [t_start, t_end) divided by t_end - t_start.

    Raises:
        ValueError: if ``t_end <= t_start``.
    """
    if t_end <= t_start:
        raise ValueError("t_end must be greater than t_start")
    spiking_times = np.asarray(spiking_times)
    number_of_spikes = np.count_nonzero(
        np.logical_and(t_start <= spiking_times, spiking_times < t_end))
    return float(number_of_spikes) / (t_end - t_start)


def get_spike_rates(spiking_times, sample_times, shifting_window_length):
    """Return, per sample time c, the spike count in [c - L/2, c + L/2] divided by L.

    ``L`` is ``shifting_window_length``. Both window bounds are inclusive.
    """
    sorted_spikes = np.sort(np.asarray(spiking_times).ravel())
    centres = np.asarray(sample_times)
    half_window = shifting_window_length / 2
    # Counting with binary search on the sorted spikes avoids building an
    # (n_samples x n_spikes) comparison matrix.
    upper_counts = np.searchsorted(sorted_spikes, centres + half_window, side="right")
    lower_counts = np.searchsorted(sorted_spikes, centres - half_window, side="left")
    # Clamp at 0 so a negative window length yields empty windows, as a direct
    # lower <= t <= upper comparison would.
    number_of_spikes_in_shifting_window = np.maximum(upper_counts - lower_counts, 0)
    return number_of_spikes_in_shifting_window / float(shifting_window_length)


def get_interspike_intervals(times, voltage_trace, threshold):
    """Return the intervals between successive downward threshold crossings.

    Returns an empty array if there is no crossing, and ``[inf]`` if there is
    exactly one.
    """
    crossing_indices = get_negative_threshold_crossings(voltage_trace, threshold)
    if crossing_indices.shape[0] == 0:
        return np.array([])
    if crossing_indices.shape[0] == 1:
        # np.Infinity was removed in NumPy 2.0; np.inf is the portable spelling.
        return np.array([np.inf])
    crossing_times = np.asarray(times)[crossing_indices]
    return np.diff(crossing_times)