123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- ###############################################################################
- ## plot lif simulation results. ##
- import os
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib.image as mpimg
- from .figure_style import subfig_labelsize, subfig_labelweight, despine
- plt.style.use("code/plots/pnas_onecolumn.mplstyle")
def plot_info_errorbar(axis, pop_size, info, label="", ls="-", color="tab:blue"):
    """Draw the mean information curve on ``axis``.

    ``info`` is averaged along its axis 1 and the result is plotted against
    ``pop_size``. Despite the function name, only the mean line is drawn;
    no error bars are added. Afterwards the x-range of ``axis`` is fixed to
    [0, 210].

    Parameters
    ----------
    axis : object offering ``plot`` and ``set_xlim`` (e.g. a matplotlib Axes)
    pop_size : x-values of the curve
    info : 2-D array-like; averaged along axis 1
    label, ls, color : forwarded to ``axis.plot`` as legend label, line
        style and line colour
    """
    line_kwargs = {"lw": 0.75, "label": label, "ls": ls, "color": color}
    axis.plot(pop_size, np.mean(info, axis=1), **line_kwargs)
    axis.set_xlim([0, 210])
def plot_info_per_band(ax1, ax2, ax3, population_size, info_no_delay, info1_delay, info2_delay, info3_delay,
                       labels=None, colors=None, ls="-", titles=None):
    """Plot one delayed information curve per axis, plus an optional reference.

    If ``info_no_delay`` is not None it is drawn on all three axes using
    ``labels[0]`` / ``colors[0]`` and the default line style. Then
    ``info1_delay``, ``info2_delay`` and ``info3_delay`` are drawn on
    ``ax1``, ``ax2`` and ``ax3`` respectively, using entries 1..3 of
    ``labels`` and ``colors`` and the line style ``ls``.

    Parameters
    ----------
    ax1, ax2, ax3 : axes to draw on
    population_size : x-values shared by all curves
    info_no_delay : reference data, or None to skip the reference curve
    info1_delay, info2_delay, info3_delay : data for the three axes
    labels : sequence of at least four legend labels; if missing or shorter,
        four empty labels are used
    colors : sequence of at least four colours; if missing or shorter,
        "tab:blue" is used for every curve
    ls : line style of the delayed curves
    titles : accepted for backward compatibility; not used by this function
    """
    # Four entries are indexed below (reference + three delayed curves), so
    # anything shorter than four falls back to the defaults.
    if labels is None or len(labels) < 4:
        labels = [""] * 4
    if colors is None or len(colors) < 4:
        colors = ["tab:blue"] * 4

    axes = (ax1, ax2, ax3)
    if info_no_delay is not None:
        for axis in axes:
            plot_info_errorbar(axis, population_size, info_no_delay, label=labels[0], color=colors[0])

    delayed = (info1_delay, info2_delay, info3_delay)
    for idx, (axis, info) in enumerate(zip(axes, delayed), start=1):
        plot_info_errorbar(axis, population_size, info, label=labels[idx], ls=ls, color=colors[idx])
def pimp_lif_sim_axes(axis, set_ylabel=True, set_xlabel=True):
    """Apply the common tick, limit and label layout to one result axis.

    Parameters
    ----------
    axis : the axes to format
    set_ylabel : if True the y-axis label is set; otherwise the y tick
        labels are blanked instead
    set_xlabel : if True the x-axis label is set and the legend is attached
        to this axis
    """
    major_xticks = np.arange(0, 201, 100)

    despine(axis, ["top", "right"], False)

    # x-axis: major ticks every 100, minor ticks every 50
    axis.set_xlim([0, 210])
    axis.set_xticks(major_xticks)
    axis.set_xticks(np.arange(0, 201, 50), minor=True)
    axis.set_xticklabels(major_xticks)

    # y-axis: major ticks every 200, minor ticks every 100
    axis.set_ylim([0, 750])
    axis.set_yticks(np.arange(0, 801, 200))
    axis.set_yticks(np.arange(0, 801, 100), minor=True)
    if not set_ylabel:
        axis.set_yticklabels([])
    else:
        axis.set_ylabel("mutual information [bit/s]")

    if set_xlabel:
        axis.set_xlabel("population size", labelpad=1.5)
        # The legend is anchored relative to the axis that carries the x label.
        axis.legend(bbox_to_anchor=(0.3, -1.075, 2.1, 0.25), ncol=2, frameon=True)
def _add_tick_only_xaxis(fig, left, bottom, width, ticks, ticklabels, minor_ticks, xlabel, labelpad):
    """Add a zero-height axes to ``fig`` that shows only an x-axis.

    The axes spans ``width`` starting at ``left`` (figure coordinates) and
    sits at height ``bottom``. ``xlabel`` may be None to omit the label.
    The x-range is fixed to [0, 210].
    """
    extra = fig.add_axes((left, bottom, width, 0.0))
    extra.yaxis.set_visible(False)  # hide the yaxis
    extra.set_xticks(ticks)
    extra.set_xticklabels(ticklabels)
    extra.set_xticks(minor_ticks, minor=True)
    if xlabel is not None:
        extra.set_xlabel(xlabel, labelpad=labelpad)
    extra.set_xlim([0, 210])
    return extra


def add_additional_xaxes(fig, axis, population_size, density, conduction_velocity, delay_dx=1, add_label=True):
    """Add two extra x-axes below ``axis`` that re-express the population size.

    Both extra axes take their horizontal position and width from ``axis``
    and are placed at the fixed figure heights 0.325 and 0.2. Tick positions
    are given in population-size units, so they line up with ``axis`` as
    long as that axis also spans [0, 210].

    The first axis is labelled "spatial extent [mm]"; its tick labels are
    ``population / density * 1000``. The second is labelled
    "maximum delay [ms]"; its tick labels are
    ``population / density / conduction_velocity * 1000``.
    NOTE(review): this presumably means density is in neurons per metre and
    conduction_velocity in m/s - confirm against the simulation code.

    Parameters
    ----------
    fig : figure to which the extra axes are added
    axis : reference axes; only its position is read
    population_size : iterable of population sizes; only its maximum is used
    density : conversion factor between population size and space
    conduction_velocity : conversion factor between space and delay
    delay_dx : spacing of the major delay ticks; minor ticks use
        ``delay_dx / 5``. Major tick labels are integers, so this should be
        integer-valued.
    add_label : whether the two axes get their axis labels
    """
    bounds = axis.get_position().bounds
    left, width = bounds[0], bounds[2]
    largest_population = max(population_size)

    # spatial extent: major ticks at 0, half and full extent; minor at quarters
    max_rf = largest_population / density * 1000
    rf_ticklabels = np.arange(0, np.ceil(max_rf) + 1, np.ceil(max_rf / 2), dtype=int)
    rf_minorticklabels = np.arange(0, np.ceil(max_rf) + 1, np.ceil(max_rf / 4))
    _add_tick_only_xaxis(
        fig, left, 0.325, width,
        rf_ticklabels / 1000 * density,
        rf_ticklabels,
        rf_minorticklabels / 1000 * density,
        "spatial extent [mm]" if add_label else None,
        1.52,
    )

    # maximum conduction delay
    max_delay = np.ceil(largest_population / density / conduction_velocity * 1000)
    delay_xticklabels = np.arange(0, max_delay + delay_dx, delay_dx, dtype=int)
    minor_delay_xticklabels = np.arange(0.0, max_delay + 1, delay_dx / 5)
    _add_tick_only_xaxis(
        fig, left, 0.2, width,
        delay_xticklabels / 1000 * conduction_velocity * density,
        delay_xticklabels,
        minor_delay_xticklabels / 1000 * conduction_velocity * density,
        "maximum delay [ms]" if add_label else None,
        1.5,
    )
def add_model_sketches(fig):
    """Add three image panels in a row along the top of ``fig``.

    For each of ``model_sketch_a.png``, ``model_sketch_b.png`` and
    ``model_sketch_c.png`` in the ``figures`` folder an axes is created; the
    image is shown in it only if the file exists. ``despine`` is then
    called on the axes for all four spines.
    """
    sketch_panels = (
        (0., "model_sketch_a.png"),
        (0.35, "model_sketch_b.png"),
        (0.7, "model_sketch_c.png"),
    )
    for left_edge, filename in sketch_panels:
        pic_ax = fig.add_axes((left_edge, 0.7, 0.325, 0.30))
        sketch_path = os.path.join("figures", filename)
        # A missing sketch file leaves the panel empty rather than failing.
        if os.path.exists(sketch_path):
            pic_ax.imshow(mpimg.imread(sketch_path))
        despine(pic_ax, ["top", "bottom", "left", "right"], True)
def layout_figure(titles, colors):
    """Create the figure with three side-by-side result axes.

    The axes are placed on an 8x8 grid, each spanning 4 rows and 2 columns,
    starting at columns 0, 3 and 6. Every axis gets a centred title from
    ``titles`` and a panel letter ("A", "B", "C").

    Parameters
    ----------
    titles : sequence of three title strings, one per axis
    colors : sequence of at least four colours; entries 1..3 colour the
        three titles. If None or shorter than four, the default palette
        (blue, red, orange, green) is used.

    Returns
    -------
    (fig, axes) : the figure and the list of the three axes
    """
    # Previously the argument was always overwritten by this palette; it now
    # only serves as the fallback.
    if colors is None or len(colors) < 4:
        colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"]

    fig = plt.figure(figsize=(5.1, 3.5))
    fig_grid = (8, 8)

    # (grid column, panel letter, x-offset of the letter, extra text kwargs)
    panels = (
        (0, "A", -0.5, {"color": "k"}),
        (3, "B", -0.35, {}),
        (6, "C", -0.35, {}),
    )
    axes = []
    for idx, (column, letter, letter_x, letter_kwargs) in enumerate(panels):
        ax = plt.subplot2grid(fig_grid, (0, column), 4, 2)
        ax.text(0.5, 0.95, titles[idx], transform=ax.transAxes, ha="center", fontsize=8,
                color=colors[idx + 1])
        ax.text(letter_x, 1.1, letter, transform=ax.transAxes, fontsize=subfig_labelsize,
                weight=subfig_labelweight, **letter_kwargs)
        axes.append(ax)

    fig.subplots_adjust(left=0.15, bottom=0.0, top=0.9, right=0.975, wspace=-0.2)
    return fig, axes
def lif_simulations(args):
    """Plot the LIF simulation results of three result files into one figure.

    Each result file is an ``.npz`` archive from which the entries
    ``population_sizes``, ``density``, ``info_delay_cc``,
    ``info_delay_squid`` and ``info_delay_efish`` are read (plus
    ``info_no_delays`` from the first file). The three delayed data sets are
    drawn on the three axes; every file adds one curve per axis with its own
    line style and label. The no-delay reference is drawn only for the first
    file.

    Parameters
    ----------
    args : namespace with the attributes ``results_file_low``,
        ``results_file_med``, ``results_file_high``, ``outfile`` and
        ``nosave`` (see ``command_line_parser``).

    Raises
    ------
    ValueError
        If one of the result files does not exist.
    """
    dsets = [args.results_file_low, args.results_file_med, args.results_file_high]
    # Validate all inputs before any figure is created.
    for dset in dsets:
        if not os.path.exists(dset):
            raise ValueError("Results file %s does not exist!" % dset)

    vel_low = 7  # m/s
    vel_med = 25  # m/s
    vel_high = 50  # m/s
    titles = [r"%.1f$m/s$" % vel_low,
              r"%.1f$m/s$" % vel_med,
              r"%.1f$m/s$ " % vel_high]
    labels = ["without delay; 0 - 100 Hz",
              "with delay; 0 - 100 Hz",
              "with delay; 100 - 200 Hz",
              "with delay; 200 - 300 Hz"]
    line_styles = ["--", "-.", ":"]
    colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"]

    fig, axes = layout_figure(titles, colors)
    for i, dset in enumerate(dsets):
        print(i, dset)
        # The context manager closes the archive once the arrays are read.
        with np.load(dset) as data:
            pop_size = data["population_sizes"]
            density = data["density"]
            # the reference curve is only drawn for the first file
            info_no_delay = data["info_no_delays"] if i < 1 else None
            info_efish = data["info_delay_efish"]
            info_squid = data["info_delay_squid"]
            info_slow = data["info_delay_cc"]

        # All three delayed curves of file i share that file's label.
        lbls = [labels[0]] + [labels[i + 1]] * 3
        plot_info_per_band(axes[0], axes[1], axes[2], pop_size, info_no_delay, info_slow, info_squid,
                           info_efish, colors=colors, ls=line_styles[i], labels=lbls)

        if i == 2:
            # Mark the minimum of the mean info_slow curve among population
            # sizes below 100 with an arrow on the first axis.
            avg_info = np.mean(info_slow, axis=1)
            # Index into the full arrays so the result does not depend on
            # the small populations being stored first.
            candidates = np.flatnonzero(np.asarray(pop_size) < 100)
            if candidates.size > 0:
                min_idx = candidates[np.argmin(avg_info[candidates])]
                axes[0].arrow(pop_size[min_idx], avg_info[min_idx] + 100, 0.0, -50, ec="tab:red", lw=0.75,
                              head_width=5, head_length=8)

    for i, a in enumerate(axes):
        # y label on the first axis only, x label (and legend) on the middle one
        pimp_lif_sim_axes(a, i == 0, i == 1)

    # pop_size and density of the last loaded file are used for the extra axes
    add_additional_xaxes(fig, axes[0], pop_size, density, vel_low, 5, add_label=False)
    add_additional_xaxes(fig, axes[1], pop_size, density, vel_med, 2.0)
    add_additional_xaxes(fig, axes[2], pop_size, density, vel_high, 1.0, add_label=False)

    if args.nosave:
        plt.show()
    else:
        fig.savefig(args.outfile, dpi=500)
        plt.close()
def command_line_parser(subparsers):
    """Register the ``lif_results`` sub-command on ``subparsers``.

    The sub-command takes three result-file options (``-rl``, ``-rm``,
    ``-rh``), the output file (``-o``) and the ``-n`` switch that shows the
    figure instead of saving it. ``lif_simulations`` is stored as the
    ``func`` default of the sub-parser.
    """
    lif_parser = subparsers.add_parser(
        "lif_results",
        help="Plots the simulation results obtained with the LIF toy model.",
    )

    # (short flag, long flag, default file name inside "derived_data")
    result_options = (
        ("-rl", "--results_file_low", "lif_simulation_0_100.npz"),
        ("-rm", "--results_file_med", "lif_simulation_100_200.npz"),
        ("-rh", "--results_file_high", "lif_simulation_200_300.npz"),
    )
    for short_flag, long_flag, default_name in result_options:
        lif_parser.add_argument(
            short_flag,
            long_flag,
            default=os.path.join("derived_data", default_name),
            help="the numpy npz file containing the simulation results.",
        )

    lif_parser.add_argument(
        "-o",
        "--outfile",
        type=str,
        default=os.path.join("figures", "lif_simulations.pdf"),
        help="The filename of the figure",
    )
    lif_parser.add_argument(
        "-n",
        "--nosave",
        action="store_true",
        help="no saving of the figure, just showing",
    )
    lif_parser.set_defaults(func=lif_simulations)
|