# plot_firing_rates.py
import warnings

import matplotlib.pyplot as plt
import numpy
import pandas as pd
from matplotlib.lines import Line2D

from tools.definitions import OUTPUT_FOLDER, HELPER_TABLE_FOLDER, PLOTS_FOLDER, STIMULATION_METADATA_FILTERED
from tools.helper import after_pulse_mean_string, pulse_mean_string
  7. path_to_plot_folder = OUTPUT_FOLDER+PLOTS_FOLDER
  8. path_to_filtered_stimulus_file = OUTPUT_FOLDER+HELPER_TABLE_FOLDER+STIMULATION_METADATA_FILTERED
  9. stimulations = pd.read_csv(path_to_filtered_stimulus_file, index_col="stimulation_id")
  10. type_number_combinations = stimulations[["protocol_type", "pulse_number"]].drop_duplicates().values
  11. print "# Plotting firing rates"
  12. for type_number in type_number_combinations:
  13. protocol_type = type_number[0]
  14. number = type_number[1]
  15. # Determine the protocol
  16. protocol_id = "{}-{:d}".format(protocol_type, number)
  17. path_to_firing_rates_file = OUTPUT_FOLDER + HELPER_TABLE_FOLDER + "{}.csv".format(protocol_id)
  18. isis = pd.read_csv(path_to_firing_rates_file, index_col="stimulation_id")
  19. fig = plt.figure(figsize=(9, 9))
  20. ax = fig.add_subplot(111)
  21. markers = iter(Line2D.filled_markers)
  22. if protocol_type == "UP":
  23. firing_info_columns = ["isi_before"] + \
  24. [after_pulse_mean_string.format(pulse_idx) for pulse_idx in range(1, number + 1)]
  25. elif protocol_type == "DOWN":
  26. firing_info_columns = [pulse_mean_string.format(pulse_idx) for pulse_idx in range(1, number + 1)] + [
  27. "isi_after"]
  28. else:
  29. RuntimeWarning("Unknown protocol {}".format(protocol_type))
  30. for stim_id, isi in isis.iterrows():
  31. legend = stim_id
  32. # noinspection PyTypeChecker
  33. firing_rates = numpy.nan_to_num(1.0 / isi[firing_info_columns].values.astype(numpy.float))
  34. color = 'k' if stimulations.loc[stim_id]["was_successful"] else 'r'
  35. ax.plot(firing_rates, ls='-', marker=markers.next(), markersize=10, color=color, label=stim_id)
  36. ax.set_xticks(ticks=numpy.arange(0, number + 1))
  37. ax.legend()
  38. ax.set_xlabel("Pulse")
  39. ax.set_ylabel("Firing [Hz]")
  40. fig.savefig("{}/{}.png".format(path_to_plot_folder, protocol_id))