Browse Source

get firing rates

Paul Pfeiffer 1 year ago
parent
commit
42dc042301
3 changed files with 179 additions and 10 deletions
  1. 10 10
      makefile
  2. 73 0
      scripts/analyse_firing_rates.py
  3. 96 0
      scripts/tools/voltage_trace_analyzer.py

+ 10 - 10
makefile

@@ -1,30 +1,30 @@
-all: prepare_folder_structure extract_stimulations assign_protocols filter analyse_firing_rates analyse_persistent_activity plot_firing_rates plot_traces report
+all: prepare_folder_structure extract_stimulations assign_protocols filter analyse_firing_rates
 
 prepare_folder_structure:
 	python scripts/check_folder_structure.py
 
-extract_stimulations: prepare_folder_structure
+extract_stimulations:
 	python scripts/extract_stimulations.py
 
-assign_protocols: prepare_folder_structure
+assign_protocols:
 	python scripts/extract_protocols.py
 
-filter: prepare_folder_structure
+filter:
 	python scripts/filter_stimulations.py
 
-analyse_firing_rates: prepare_folder_structure
-	python analyse_firing_rates.py
+analyse_firing_rates:
+	python scripts/analyse_firing_rates.py
 
-analyse_persistent_activity: prepare_folder_structure
+analyse_persistent_activity:
 	python analyse_persistent_activity.py
 
-plot_firing_rates: prepare_folder_structure
+plot_firing_rates:
 	python plot_traces.py
 
-plot_traces: prepare_folder_structure
+plot_traces:
 	python plot_traces.py
 
-report: prepare_folder_structure
+report: 
 	python generate_report.py
 
 clean:

+ 73 - 0
scripts/analyse_firing_rates.py

@@ -0,0 +1,73 @@
+import warnings
+
+import pandas as pd
+
+from tools.helper import get_trace, get_times, get_pulse_times, get_filters
+from tools.definitions import OUTPUT_FOLDER, HELPER_TABLE_FOLDER, STIMULATION_METADATA_FILTERED
+from tools.voltage_trace_analyzer import VoltageTraceAnalyzer
+
+path_to_filtered_stimulus_file = OUTPUT_FOLDER+HELPER_TABLE_FOLDER+STIMULATION_METADATA_FILTERED
+
+print "# Analyse firing during and after pulses"
+print ""
+
+spiking_threshold = -10
+minimum_width_of_a_spike_in_time_points = 100
+
+pulse_mean_string = "p_{:d}_ISI_mean"
+after_pulse_mean_string = "ap_{:d}_ISI_mean"
+
+stimulations = pd.read_csv(path_to_filtered_stimulus_file, index_col="stimulation_id")
+stimulations = stimulations[get_filters(stimulations)]
+# get protocols
+type_number_combinations = stimulations[["protocol_type", "pulse_number"]].drop_duplicates().values
+
+
+def get_firing_rates(stimulation):
+    times = get_times(stimulation)
+    voltage_trace = get_trace(stimulation, "V-1")
+    v_analyser = VoltageTraceAnalyzer(times, voltage_trace, spiking_threshold, minimum_width_of_a_spike_in_time_points)
+
+    transient_period = 2
+
+    pulses = get_pulse_times(stimulation)
+
+    # Suppress RuntimeWarnings generated by empty ISI arrays
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+
+        for idx, pulse in enumerate(pulses):
+            during_pulse_column = pulse_mean_string.format(idx + 1)
+            after_pulse_column = after_pulse_mean_string.format(idx + 1)
+            stimulation[during_pulse_column] = v_analyser.get_interspike_interval(pulse["start"]+transient_period, pulse["pulse_end"]).mean()
+
+            stimulation[after_pulse_column] = v_analyser.get_interspike_interval(pulse["pulse_end"]+transient_period, pulse["end"]).mean()
+
+
+        stimulation["isi_before"] = v_analyser.get_interspike_interval(-stimulation["before_protocol"]+transient_period, 0).mean()
+        stimulation["isi_after"] = v_analyser.get_interspike_interval(stimulation["stimulus_length"]+transient_period, stimulation[
+            "stimulus_length"] + stimulation["after_protocol"]).mean()
+    return stimulation
+
+
+for type_number in type_number_combinations:
+    protocol_type = type_number[0]
+    number = type_number[1]
+
+    # Determine the protocol
+    protocol_id = "{}-{:d}".format(protocol_type, number)
+    stimulation_of_current_protocol = stimulations[(stimulations["protocol_type"] == protocol_type) & (stimulations[
+                                                                                                           "pulse_number"] == number)]
+    print "{}: {} stimulations".format(protocol_id, len(stimulation_of_current_protocol))
+
+    # Analyse firing rates
+    stimulation_with_firing_info = stimulation_of_current_protocol.apply(get_firing_rates, axis=1)
+
+    # Save firing rates
+    firing_info_columns = ["protocol_type", "pulse_number", "isi_before", "isi_after"] + \
+                          [pulse_mean_string.format(pulse_idx) for pulse_idx in range(1, number + 1)] + \
+                          [after_pulse_mean_string.format(pulse_idx) for pulse_idx in range(1, number + 1)]
+    path_to_firing_rate_file = OUTPUT_FOLDER+HELPER_TABLE_FOLDER+"{:s}.csv".format(protocol_id)
+    stimulation_with_firing_info[firing_info_columns].to_csv(path_to_firing_rate_file)
+    print "Firing rates saved in {:s}".format(path_to_firing_rate_file)
+    print ""

+ 96 - 0
scripts/tools/voltage_trace_analyzer.py

@@ -0,0 +1,96 @@
+import numpy as np
+from scipy import signal
+
+
+class VoltageTraceAnalyzer:
+    DEFAULT_SPIKE_WIDTH = 100
+    DEFAULT_SPIKE_THRESHOLD = 0
+
+    def set_spike_width(self, width):
+        self.spike_width = width
+        return self
+
+    def set_spike_threshold(self, threshold):
+        self.spike_threshold = threshold
+        return self
+
+    def get_interspike_interval(self, start_time, end_time):
+        start_idx = get_index_of_time_point(self.times, start_time)
+        end_idx = get_index_of_time_point(self.times, end_time)
+        return get_interspike_intervals(self.times[start_idx:end_idx], self.voltage[start_idx:end_idx],
+                                        self.spike_threshold)
+
+    def get_firing_rate(self, shifting_window_length, shifting_window_overlap):
+        t_max = self.times[-1]
+        sample_times = get_sample_times(shifting_window_length, shifting_window_overlap, t_max)
+
+        mean_frequency = get_spike_rates(self.spiking_times, sample_times, shifting_window_length)
+        return sample_times, mean_frequency
+
+    def get_local_firing_rate(self, t_start, t_end):
+        return get_local_spiking_rate(self.spiking_times, t_start, t_end)
+
+    def __init__(self, times, voltage, spike_threshold=DEFAULT_SPIKE_THRESHOLD, spike_width=DEFAULT_SPIKE_WIDTH):
+        self.times = times
+        self.voltage = voltage
+        self.spike_threshold = spike_threshold
+        self.spike_width = spike_width
+        self.spiking_times = self.times[get_spike_indices(self.voltage, self.spike_threshold, self.spike_width)]
+
+
+def get_index_of_time_point(times, time_point):
+    return np.searchsorted(times, time_point)
+
+
+def get_negative_threshold_crossings(voltage_trace, threshold):
+    first_values = voltage_trace[:-1]
+    second_values = voltage_trace[1:]
+    crossing_detection = np.where(np.logical_and(first_values >= threshold, second_values < threshold))
+    crossing_indices = crossing_detection[0]
+    return crossing_indices
+
+
+def get_sample_times(shifting_window_length, shifting_window_overlap, t_max):
+    sample_times = np.arange(shifting_window_length / 2,
+                             t_max - shifting_window_length / 2 + 1,
+                             shifting_window_length - shifting_window_overlap)
+    return sample_times
+
+
+def get_spike_indices(voltage, threshold, width):
+    local_maxima = signal.argrelextrema(voltage, np.greater_equal, order=width)[0]
+    #filter consecutive maxima that is the rare case that during a spike there is a saddle
+    indices = filter(lambda idx: voltage[idx] > threshold, local_maxima)
+    return indices
+
+
+def get_local_spiking_rate(spiking_times, t_start, t_end):
+    number_of_spikes = np.sum(np.logical_and(
+        t_start <= spiking_times,
+        t_end > spiking_times))
+    return float(number_of_spikes) / (t_end - t_start)
+
+
+def get_spike_rates(spiking_times, sample_times, shifting_window_length):
+    spiking_times_array = np.repeat(spiking_times.reshape(1, len(spiking_times)), len(sample_times), axis=0)
+    sample_times_array = np.repeat(sample_times.reshape(len(sample_times), 1), len(spiking_times),
+                                   axis=1)
+    number_of_spikes_in_shifting_window = np.sum(
+        np.logical_and(
+            (sample_times_array - shifting_window_length / 2) <= spiking_times_array,
+            spiking_times_array <= (sample_times_array + shifting_window_length / 2))
+        , axis=1)
+    mean_frequency = number_of_spikes_in_shifting_window / float(shifting_window_length)
+    return mean_frequency
+
+
+def get_interspike_intervals(times, voltage_trace, threshold):
+    threshold_crossing_indices = get_negative_threshold_crossings(voltage_trace, threshold)
+    if threshold_crossing_indices.shape[0] == 0:
+        interspike_intervals = np.array([])
+    elif threshold_crossing_indices.shape[0] == 1:
+        interspike_intervals = np.array([np.Infinity])
+    else:
+        time_points = times[threshold_crossing_indices]
+        interspike_intervals = time_points[1:] - time_points[:-1]
+    return interspike_intervals