"""Plot firing rates per stimulation for every (protocol_type, pulse_number) pair.

For each distinct combination found in the filtered stimulation metadata, the
helper table ``<protocol_type>-<pulse_number>.csv`` is read and the selected
ISI columns of each stimulation are converted to rates as ``1 / ISI``.
Each stimulation is drawn as one line: black when its ``was_successful`` flag
is truthy, red otherwise. One PNG per protocol is written to the plots folder.
"""
import warnings
from itertools import cycle

import matplotlib.pyplot as plt
import numpy
import pandas as pd
from matplotlib.lines import Line2D

from tools.helper import after_pulse_mean_string, pulse_mean_string
from tools.definitions import OUTPUT_FOLDER, HELPER_TABLE_FOLDER, PLOTS_FOLDER, STIMULATION_METADATA_FILTERED

path_to_plot_folder = OUTPUT_FOLDER + PLOTS_FOLDER
path_to_filtered_stimulus_file = OUTPUT_FOLDER + HELPER_TABLE_FOLDER + STIMULATION_METADATA_FILTERED

stimulations = pd.read_csv(path_to_filtered_stimulus_file, index_col="stimulation_id")
type_number_combinations = stimulations[["protocol_type", "pulse_number"]].drop_duplicates().values


def _firing_info_columns(protocol_type, number):
    """Return the ISI column names to plot for one protocol, or None if unknown.

    "UP":   ``"isi_before"`` followed by ``after_pulse_mean_string`` formatted
            with pulse indices 1..number.
    "DOWN": ``pulse_mean_string`` formatted with pulse indices 1..number,
            followed by ``"isi_after"``.
    Any other protocol type returns None.
    """
    pulse_indices = range(1, number + 1)
    if protocol_type == "UP":
        return ["isi_before"] + [after_pulse_mean_string.format(idx) for idx in pulse_indices]
    if protocol_type == "DOWN":
        return [pulse_mean_string.format(idx) for idx in pulse_indices] + ["isi_after"]
    return None


print("# Plotting firing rates")
for protocol_type, number in type_number_combinations:
    # Cast defensively: "{:d}" and range() both need an integer pulse count.
    number = int(number)
    protocol_id = "{}-{:d}".format(protocol_type, number)

    firing_info_columns = _firing_info_columns(protocol_type, number)
    if firing_info_columns is None:
        # Emit the warning and skip this protocol. Merely constructing a
        # RuntimeWarning (as before) did nothing, and the plot then used
        # undefined or stale columns from the previous protocol.
        warnings.warn("Unknown protocol {}".format(protocol_type), RuntimeWarning)
        continue

    path_to_firing_rates_file = OUTPUT_FOLDER + HELPER_TABLE_FOLDER + "{}.csv".format(protocol_id)
    isis = pd.read_csv(path_to_firing_rates_file, index_col="stimulation_id")

    fig = plt.figure(figsize=(9, 9))
    ax = fig.add_subplot(111)
    # Cycle marker styles so a protocol with more stimulations than there are
    # filled markers does not raise StopIteration.
    markers = cycle(Line2D.filled_markers)

    for stim_id, isi in isis.iterrows():
        # Rate = 1 / ISI. nan_to_num maps NaN (missing ISI) to 0; a zero ISI
        # gives inf, which nan_to_num replaces with the largest finite float.
        # NOTE(review): the "Hz" axis label assumes ISIs are in seconds - confirm.
        firing_rates = numpy.nan_to_num(1.0 / isi[firing_info_columns].values.astype(float))
        color = 'k' if stimulations.loc[stim_id, "was_successful"] else 'r'
        ax.plot(firing_rates, ls='-', marker=next(markers), markersize=10, color=color, label=stim_id)

    # One tick per plotted point: the column list has number + 1 entries.
    ax.set_xticks(numpy.arange(0, number + 1))
    ax.legend()
    ax.set_xlabel("Pulse")
    ax.set_ylabel("Firing [Hz]")
    fig.savefig("{}/{}.png".format(path_to_plot_folder, protocol_id))
    # Release the figure; otherwise every protocol's figure stays open until exit.
    plt.close(fig)