123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- import numpy as np
- from scipy import signal
class VoltageTraceAnalyzer:
    """Spike analysis of a sampled voltage trace.

    ``times`` and ``voltage`` are parallel per-sample arrays; ``times`` is
    assumed to be sorted ascending because it is binary-searched. Spike peaks
    are detected with ``get_spike_indices`` and their times are cached in
    ``self.spiking_times``; the cache is rebuilt whenever the spike threshold
    or width is changed through the setters.
    """

    # Peak-detection window forwarded to get_spike_indices.
    DEFAULT_SPIKE_WIDTH = 100
    # Voltage a peak must exceed to count as a spike.
    DEFAULT_SPIKE_THRESHOLD = 0

    def __init__(self, times, voltage, spike_threshold=DEFAULT_SPIKE_THRESHOLD,
                 spike_width=DEFAULT_SPIKE_WIDTH):
        """Store the trace and detect spikes with the given settings.

        Args:
            times: sample times (assumed sorted ascending).
            voltage: one voltage sample per entry of ``times``.
            spike_threshold: a peak must exceed this value to be a spike; also
                the level used for threshold crossings in
                ``get_interspike_interval``.
            spike_width: peak-detection window passed to ``get_spike_indices``.
        """
        self.times = times
        self.voltage = voltage
        self.spike_threshold = spike_threshold
        self.spike_width = spike_width
        self._update_spiking_times()

    def _update_spiking_times(self):
        """Recompute ``self.spiking_times`` from the current threshold and width."""
        # list() accepts either an index array or a lazy iterator coming from
        # get_spike_indices (a Python 3 ``filter`` object is not a valid NumPy
        # index); the explicit integer dtype keeps an empty result indexable.
        spike_indices = np.asarray(
            list(get_spike_indices(self.voltage, self.spike_threshold, self.spike_width)),
            dtype=np.intp,
        )
        self.spiking_times = np.asarray(self.times)[spike_indices]

    def set_spike_width(self, width):
        """Set the peak-detection width, re-detect spikes, and return ``self``."""
        self.spike_width = width
        # Without this, spiking_times (and the firing-rate methods) stay stale.
        self._update_spiking_times()
        return self

    def set_spike_threshold(self, threshold):
        """Set the spike threshold, re-detect spikes, and return ``self``."""
        self.spike_threshold = threshold
        # Without this, spiking_times (and the firing-rate methods) stay stale.
        self._update_spiking_times()
        return self

    def get_interspike_interval(self, start_time, end_time):
        """Return inter-spike intervals for samples with ``start_time <= t < end_time``.

        Spikes here are downward crossings of ``spike_threshold`` (computed by
        ``get_interspike_intervals``), not the cached peak times.
        """
        start_idx = get_index_of_time_point(self.times, start_time)
        end_idx = get_index_of_time_point(self.times, end_time)
        return get_interspike_intervals(self.times[start_idx:end_idx],
                                        self.voltage[start_idx:end_idx],
                                        self.spike_threshold)

    def get_firing_rate(self, shifting_window_length, shifting_window_overlap):
        """Return ``(sample_times, rates)`` for sliding windows over the trace.

        Windows of length ``shifting_window_length`` overlap by
        ``shifting_window_overlap`` and are laid out up to the last sample
        time; each rate is the number of cached spikes in the window divided by
        the window length.
        """
        t_max = self.times[-1]
        sample_times = get_sample_times(shifting_window_length, shifting_window_overlap, t_max)
        mean_frequency = get_spike_rates(self.spiking_times, sample_times, shifting_window_length)
        return sample_times, mean_frequency

    def get_local_firing_rate(self, t_start, t_end):
        """Return cached spikes per unit time over ``[t_start, t_end)``."""
        return get_local_spiking_rate(self.spiking_times, t_start, t_end)
def get_index_of_time_point(times, time_point):
    """Locate ``time_point`` in the ascending ``times`` by binary search.

    Returns the left insertion position: the first index ``i`` with
    ``times[i] >= time_point``, or ``len(times)`` when every entry is smaller.
    """
    position = np.searchsorted(times, time_point, side="left")
    return position
def get_negative_threshold_crossings(voltage_trace, threshold):
    """Return the indices where the trace falls through ``threshold``.

    Index ``i`` is reported when ``voltage_trace[i] >= threshold`` and
    ``voltage_trace[i + 1] < threshold``, i.e. ``i`` is the last sample at or
    above the threshold before a downward crossing.
    """
    at_or_above_now = voltage_trace[:-1] >= threshold
    below_next = voltage_trace[1:] < threshold
    falling_edge = np.logical_and(at_or_above_now, below_next)
    return np.nonzero(falling_edge)[0]
def get_sample_times(shifting_window_length, shifting_window_overlap, t_max):
    """Return the centre times of sliding windows that fit inside ``[0, t_max]``.

    Consecutive windows of length ``shifting_window_length`` overlap by
    ``shifting_window_overlap``, so centres are spaced by
    ``shifting_window_length - shifting_window_overlap``. The first centre is
    half a window length after 0; the last is the largest centre whose window
    still ends at or before ``t_max``.

    Returns:
        1-D float array of window centres (empty if no full window fits).

    Raises:
        ValueError: if the overlap is not smaller than the window length, which
            would make the step between windows zero or negative.
    """
    step = shifting_window_length - shifting_window_overlap
    if step <= 0:
        raise ValueError(
            "shifting_window_overlap must be smaller than shifting_window_length")
    half_window = shifting_window_length / 2
    # Count centres c = half_window + k * step with c + half_window <= t_max.
    # The former ``np.arange(..., t_max - half_window + 1, step)`` used "+1" to
    # include the endpoint, which let centres up to one time unit past
    # t_max - half_window through (windows overhanging t_max) whenever the grid
    # did not land on the endpoint, e.g. for steps below 1. The tiny tolerance
    # absorbs floating-point round-off in the division.
    n_windows = int(np.floor((t_max - shifting_window_length) / step + 1e-9)) + 1
    return half_window + step * np.arange(max(n_windows, 0))
def get_spike_indices(voltage, threshold, width):
    """Return the sample indices of spike peaks in ``voltage``.

    A peak is a sample that is ``>=`` every other sample within ``width``
    samples on either side (``scipy.signal.argrelextrema`` with
    ``np.greater_equal`` and ``order=width``) and strictly above ``threshold``.

    Returns:
        1-D integer ``np.ndarray`` of indices in increasing order, directly
        usable to index a matching time array.
    """
    voltage = np.asarray(voltage)
    local_maxima = signal.argrelextrema(voltage, np.greater_equal, order=width)[0]
    # Boolean-mask the candidates instead of using the builtin ``filter``: on
    # Python 3 ``filter`` returns a lazy iterator, which is not a valid NumPy
    # index, so ``times[get_spike_indices(...)]`` failed.
    # NOTE(review): with ``np.greater_equal`` every sample of a flat-topped peak
    # qualifies, so a plateau yields several adjacent indices for one spike.
    # An earlier comment said consecutive maxima were filtered, but the code
    # never did that — TODO decide whether to collapse such runs.
    return local_maxima[voltage[local_maxima] > threshold]
def get_local_spiking_rate(spiking_times, t_start, t_end):
    """Return the mean spike rate over the half-open interval ``[t_start, t_end)``.

    A spike exactly at ``t_start`` is counted; one exactly at ``t_end`` is not.
    The count is divided by ``t_end - t_start``.
    """
    inside = np.logical_and(spiking_times >= t_start, spiking_times < t_end)
    spike_count = float(np.count_nonzero(inside))
    duration = t_end - t_start
    return spike_count / duration
def get_spike_rates(spiking_times, sample_times, shifting_window_length):
    """Return the spike rate in a window centred on each sample time.

    For each ``t`` in ``sample_times``, count the spikes ``s`` with
    ``t - L/2 <= s <= t + L/2`` (both edges inclusive), where ``L`` is
    ``shifting_window_length``, and divide by ``L``.

    Returns:
        Float array with one rate per entry of ``sample_times``.
    """
    sorted_spikes = np.sort(np.asarray(spiking_times))
    samples = np.asarray(sample_times)
    half_window = shifting_window_length / 2
    # Binary search on the sorted spike times: (#spikes <= right edge) minus
    # (#spikes < left edge) is the inclusive count per window. This avoids
    # materialising two (n_samples x n_spikes) matrices.
    upper = np.searchsorted(sorted_spikes, samples + half_window, side="right")
    lower = np.searchsorted(sorted_spikes, samples - half_window, side="left")
    # An inverted window (negative length) can contain no spike; clip at 0.
    number_of_spikes_in_shifting_window = np.maximum(upper - lower, 0)
    return number_of_spikes_in_shifting_window / float(shifting_window_length)
def get_interspike_intervals(times, voltage_trace, threshold):
    """Return the times between successive downward crossings of ``threshold``.

    Crossings come from ``get_negative_threshold_crossings``: index ``i`` with
    ``voltage_trace[i] >= threshold > voltage_trace[i + 1]``, timed at
    ``times[i]``.

    Returns:
        - an empty array when there is no crossing,
        - ``array([inf])`` when there is exactly one (no second spike to
          measure an interval to),
        - otherwise the differences between consecutive crossing times.
    """
    crossing_indices = get_negative_threshold_crossings(voltage_trace, threshold)
    if len(crossing_indices) == 0:
        return np.array([])
    if len(crossing_indices) == 1:
        # ``np.Infinity`` was removed in NumPy 2.0; ``np.inf`` works everywhere.
        return np.array([np.inf])
    crossing_times = np.asarray(times)[crossing_indices]
    return np.diff(crossing_times)
|