1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import matplotlib.pyplot as plt
- import numpy
- import pandas as pd
- from matplotlib.lines import Line2D
- from tools.helper import after_pulse_mean_string, pulse_mean_string
- from tools.definitions import OUTPUT_FOLDER, HELPER_TABLE_FOLDER, PLOTS_FOLDER, STIMULATION_METADATA_FILTERED
# Output locations are built by plain string concatenation, so the folder
# constants from tools.definitions must already carry their own path
# separators (NOTE(review): confirm against tools.definitions).
path_to_plot_folder = OUTPUT_FOLDER + PLOTS_FOLDER
path_to_filtered_stimulus_file = OUTPUT_FOLDER + HELPER_TABLE_FOLDER + STIMULATION_METADATA_FILTERED

# Filtered stimulation metadata, indexed by its "stimulation_id" column.
stimulations = pd.read_csv(path_to_filtered_stimulus_file, index_col="stimulation_id")

# Every distinct (protocol_type, pulse_number) pair; each one gets its own plot.
type_number_combinations = stimulations[["protocol_type", "pulse_number"]].drop_duplicates().values

# Single-argument print() behaves identically on Python 2 and 3, whereas the
# bare print statement is a SyntaxError on Python 3.
print("# Plotting firing rates")
import itertools
import warnings

# One figure per distinct (protocol_type, pulse_number) pair. Each stimulation
# of that protocol is drawn as one line of firing rates (1 / ISI), black when
# the stimulation was successful and red otherwise, and the figure is saved as
# "<plot folder>/<protocol_id>.png".
for protocol_type, number in type_number_combinations:
    # Identifier shared by the helper table and the output plot, e.g. "UP-3".
    protocol_id = "{}-{:d}".format(protocol_type, number)

    # Select the ISI columns in plotting order. Both branches yield
    # number + 1 columns, which is why the x ticks below run from 0 to number.
    if protocol_type == "UP":
        # "isi_before" followed by one after-pulse mean per pulse.
        firing_info_columns = ["isi_before"] + [
            after_pulse_mean_string.format(pulse_idx)
            for pulse_idx in range(1, number + 1)
        ]
    elif protocol_type == "DOWN":
        # One per-pulse mean per pulse, followed by "isi_after".
        firing_info_columns = [
            pulse_mean_string.format(pulse_idx)
            for pulse_idx in range(1, number + 1)
        ] + ["isi_after"]
    else:
        # Actually emit the warning and skip this protocol. Falling through
        # would use an undefined column list, or silently reuse the previous
        # protocol's columns.
        warnings.warn("Unknown protocol {}".format(protocol_type), RuntimeWarning)
        continue

    path_to_firing_rates_file = OUTPUT_FOLDER + HELPER_TABLE_FOLDER + "{}.csv".format(protocol_id)
    isis = pd.read_csv(path_to_firing_rates_file, index_col="stimulation_id")

    fig = plt.figure(figsize=(9, 9))
    ax = fig.add_subplot(111)
    # cycle() lets protocols with more stimulations than filled marker styles
    # reuse markers instead of raising StopIteration.
    markers = itertools.cycle(Line2D.filled_markers)

    for stim_id, isi in isis.iterrows():
        # ISI -> rate; NaN ISIs become 0 via nan_to_num. The builtin float
        # replaces numpy.float, an alias for it that NumPy has since removed.
        # noinspection PyTypeChecker
        firing_rates = numpy.nan_to_num(1.0 / isi[firing_info_columns].values.astype(float))
        color = 'k' if stimulations.loc[stim_id, "was_successful"] else 'r'
        ax.plot(firing_rates, ls='-', marker=next(markers), markersize=10,
                color=color, label=stim_id)

    ax.set_xticks(numpy.arange(0, number + 1))
    ax.legend()
    ax.set_xlabel("Pulse")
    ax.set_ylabel("Firing [Hz]")
    fig.savefig("{}/{}.png".format(path_to_plot_folder, protocol_id))
    # Release the figure; otherwise pyplot keeps every protocol's figure open.
    plt.close(fig)
|