lif_results.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. ###############################################################################
  2. ## plot lif simulation results. ##
  3. import os
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. import matplotlib.image as mpimg
  7. from .figure_style import subfig_labelsize, subfig_labelweight, despine
# Apply the project's matplotlib style sheet. This is a module-level statement,
# so it runs as soon as the module is imported. The path is relative --
# presumably the code is meant to be launched from the repository root;
# verify against the entry point.
plt.style.use("code/plots/pnas_onecolumn.mplstyle")
  9. def plot_info_errorbar(axis, pop_size, info, label="", ls="-", color="tab:blue"):
  10. info_mean = np.mean(info, axis=1)
  11. axis.plot(pop_size, info_mean, lw=0.75, label=label, ls=ls, color=color)
  12. axis.set_xlim([0, 210])
  13. def plot_info_per_band(ax1, ax2, ax3, population_size, info_no_delay, info1_delay, info2_delay, info3_delay,
  14. labels=[], colors=[], ls="-", titles=[]):
  15. if len(labels) < 3:
  16. labels = ["" for i in range(4)]
  17. if len(colors) < 4:
  18. colors = ["tab:blue" for i in range(4)]
  19. if len(titles) < 3:
  20. titles = ["" for i in range(3)]
  21. if info_no_delay is not None:
  22. plot_info_errorbar(ax1, population_size, info_no_delay, label=labels[0], color=colors[0])
  23. plot_info_errorbar(ax2, population_size, info_no_delay, label=labels[0], color=colors[0])
  24. plot_info_errorbar(ax3, population_size, info_no_delay, label=labels[0], color=colors[0])
  25. plot_info_errorbar(ax1, population_size, info1_delay, label=labels[1], ls=ls, color=colors[1])
  26. plot_info_errorbar(ax2, population_size, info2_delay, label=labels[2], ls=ls, color=colors[2])
  27. plot_info_errorbar(ax3, population_size, info3_delay, label=labels[3], ls=ls, color=colors[3])
  28. def pimp_lif_sim_axes(axis, set_ylabel=True, set_xlabel=True):
  29. despine(axis, ["top", "right"], False)
  30. axis.set_xlim([0, 210])
  31. axis.set_xticks(np.arange(0, 201, 100))
  32. axis.set_xticks(np.arange(0, 201, 50), minor=True)
  33. axis.set_xticklabels(np.arange(0, 201, 100))
  34. if set_xlabel:
  35. axis.set_xlabel("population size", labelpad=1.5)
  36. axis.set_ylim([0, 750])
  37. axis.set_yticks(np.arange(0, 801, 200))
  38. axis.set_yticks(np.arange(0, 801, 100), minor=True)
  39. if set_ylabel:
  40. axis.set_ylabel("mutual information [bit/s]")
  41. else:
  42. axis.set_yticklabels([])
  43. if set_xlabel:
  44. axis.legend(bbox_to_anchor=(0.3, -1.075, 2.1, 0.25), ncol=2, frameon=True)
  45. def add_additional_xaxes(fig, axis, population_size, density, conduction_velocity, delay_dx=1, add_label=True):
  46. pos = axis.get_position().bounds
  47. ax2 = fig.add_axes((pos[0], 0.325, pos[2], 0.0))
  48. ax2.yaxis.set_visible(False) # hide the yaxis
  49. max_rf = max(population_size) / density * 1000
  50. rf_ticklabels = np.arange(0, np.ceil(max_rf)+1, np.ceil(max_rf/2), dtype=int)
  51. rf_ticks = rf_ticklabels / 1000 * density
  52. rf_minorticklabels = np.arange(0, np.ceil(max_rf)+1, np.ceil(max_rf/4))
  53. rf_minorticks = rf_minorticklabels / 1000 * density
  54. ax2.set_xticks(rf_ticks)
  55. ax2.set_xticklabels(rf_ticklabels)
  56. ax2.set_xticks(rf_minorticks, minor=True)
  57. if add_label:
  58. ax2.set_xlabel("spatial extent [mm]", labelpad=1.52)
  59. ax2.set_xlim([0, 210])
  60. pos = axis.get_position().bounds
  61. ax3 = fig.add_axes((pos[0], 0.2, pos[2], 0.0))
  62. ax3.yaxis.set_visible(False) # hide the yaxis
  63. max_delay = np.ceil(max(population_size) / density / conduction_velocity * 1000)
  64. delay_xticklabels = np.arange(0, max_delay+delay_dx, delay_dx, dtype=int)
  65. delay_xticks = delay_xticklabels /1000 * conduction_velocity * density
  66. minor_delay_xticklabels = np.arange(0.0, max_delay+1, delay_dx/5)
  67. minor_delay_xticks = minor_delay_xticklabels / 1000 * conduction_velocity * density
  68. ax3.set_xticks(delay_xticks)
  69. ax3.set_xticklabels(delay_xticklabels)
  70. ax3.set_xticks(minor_delay_xticks, minor=True)
  71. if add_label:
  72. ax3.set_xlabel("maximum delay [ms]", labelpad=1.5)
  73. ax3.set_xlim([0, 210])
  74. def add_model_sketches(fig):
  75. pic_ax = fig.add_axes((0., 0.7, 0.325, 0.30))
  76. if os.path.exists(os.path.join("figures", "model_sketch_a.png")):
  77. img = mpimg.imread(os.path.join("figures", "model_sketch_a.png"))
  78. pic_ax.imshow(img)
  79. despine(pic_ax, ["top", "bottom", "left", "right"], True)
  80. pic_ax = fig.add_axes((0.35, 0.7, 0.325, 0.30))
  81. if os.path.exists(os.path.join("figures", "model_sketch_b.png")):
  82. img = mpimg.imread(os.path.join("figures", "model_sketch_b.png"))
  83. pic_ax.imshow(img)
  84. despine(pic_ax, ["top", "bottom", "left", "right"], True)
  85. pic_ax = fig.add_axes((0.7, 0.7, 0.325, 0.30))
  86. if os.path.exists(os.path.join("figures", "model_sketch_c.png")):
  87. img = mpimg.imread(os.path.join("figures", "model_sketch_c.png"))
  88. pic_ax.imshow(img)
  89. despine(pic_ax, ["top", "bottom", "left", "right"], True)
  90. def layout_figure(titles, colors):
  91. fig = plt.figure(figsize=(5.1, 3.5))#, constrained_layout=True)
  92. axes = []
  93. fig_grid = (8, 8)
  94. colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"]
  95. # low velocity
  96. axes.append(plt.subplot2grid(fig_grid, (0, 0), 4, 2))
  97. # squid velocity
  98. axes.append(plt.subplot2grid(fig_grid, (0, 3), 4, 2))
  99. # efish velocity
  100. axes.append(plt.subplot2grid(fig_grid, (0, 6), 4, 2))
  101. axes[0].text(0.5, 0.95, titles[0], transform=axes[0].transAxes, ha="center",fontsize=8, color=colors[1])
  102. axes[0].text(-0.5, 1.1, "A", transform=axes[0].transAxes, fontsize=subfig_labelsize, weight=subfig_labelweight, color="k")
  103. axes[1].text(0.5, 0.95, titles[1], transform=axes[1].transAxes, ha="center",fontsize=8, color=colors[2])
  104. axes[1].text(-0.35, 1.1, "B", transform=axes[1].transAxes, fontsize=subfig_labelsize, weight=subfig_labelweight)
  105. axes[2].text(0.5, 0.95, titles[2], transform=axes[2].transAxes, ha="center", fontsize=8, color=colors[3])
  106. axes[2].text(-0.35, 1.1, "C", transform=axes[2].transAxes,fontsize=subfig_labelsize, weight=subfig_labelweight)
  107. fig.subplots_adjust(left=0.15, bottom=0.0, top=0.9, right=0.975, wspace=-0.2)
  108. return fig, axes
  109. def lif_simulations(args):
  110. dsets = [args.results_file_low, args.results_file_med, args.results_file_high]
  111. vel_low = 7 # m/s
  112. vel_med = 25 # m/s
  113. vel_high = 50 # m/s
  114. titles = [r"%.1f$m/s$" % 7.0,
  115. r"%.1f$m/s$" % 25.,
  116. r"%.1f$m/s$ " % 50.]
  117. labels = ["without delay; 0 - 100 Hz",
  118. "with delay; 0 - 100 Hz",
  119. "with delay; 100 - 200 Hz",
  120. "with delay; 200 - 300 Hz"]
  121. line_styles = ["--", "-.", ":"]
  122. colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"]
  123. fig, axes = layout_figure(titles, colors)
  124. for i, dset in enumerate(dsets):
  125. print(i, dset)
  126. if not os.path.exists(dset):
  127. raise ValueError("Results file %s does not exist!" % dset)
  128. data = np.load(dset)
  129. pop_size = data["population_sizes"]
  130. density = data["density"]
  131. info_no_delay = data["info_no_delays"]
  132. info_efish = data["info_delay_efish"]
  133. info_squid = data["info_delay_squid"]
  134. info_slow = data["info_delay_cc"]
  135. lbls = [labels[0]]
  136. lbls.extend([labels[i+1] for j in range(3)])
  137. if i < 1:
  138. plot_info_per_band(axes[0], axes[1], axes[2], pop_size, info_no_delay, info_slow, info_squid, info_efish, colors=colors, ls=line_styles[i], labels=lbls)
  139. else:
  140. plot_info_per_band(axes[0], axes[1], axes[2], pop_size, None, info_slow, info_squid, info_efish, colors=colors, ls=line_styles[i], labels=lbls)
  141. if i == 2: # lowest condction delay
  142. from IPython import embed
  143. #embed()
  144. avg_info = np.mean(info_slow, axis=1)
  145. min_idx = np.argmin(avg_info[pop_size < 100])
  146. min_popsize = pop_size[min_idx]
  147. #n.x, receptive_field.center[1] + 1.25, 0.0, -0.4, width=0.005, ec="tab:red", lw=0.5,
  148. # head_width=10*0.005, head_length=25*0.005)
  149. axes[0].arrow(min_popsize, avg_info[min_idx] + 100, 0.0, -50, ec="tab:red", lw=0.75, head_width=5,
  150. head_length=8)
  151. for i, a in enumerate(axes):
  152. pimp_lif_sim_axes(a, i == 0, i==1)
  153. add_additional_xaxes(fig, axes[0], pop_size, density, vel_low, 5, add_label=False)
  154. add_additional_xaxes(fig, axes[1], pop_size, density, vel_med, 2.0)
  155. add_additional_xaxes(fig, axes[2], pop_size, density, vel_high, 1.0, add_label=False)
  156. if args.nosave:
  157. plt.show()
  158. else:
  159. fig.savefig(args.outfile, dpi=500)
  160. plt.close()
  161. def command_line_parser(subparsers):
  162. parser = subparsers.add_parser("lif_results", help="Plots the simulation results obtained with the LIF toy model.")
  163. parser.add_argument("-rl", "--results_file_low", default=os.path.join("derived_data","lif_simulation_0_100.npz"), help="the numpy npz file containing the simulation results.")
  164. parser.add_argument("-rm", "--results_file_med", default=os.path.join("derived_data","lif_simulation_100_200.npz"), help="the numpy npz file containing the simulation results.")
  165. parser.add_argument("-rh", "--results_file_high", default=os.path.join("derived_data","lif_simulation_200_300.npz"), help="the numpy npz file containing the simulation results.")
  166. parser.add_argument("-o", "--outfile", type=str, default=os.path.join("figures","lif_simulations.pdf"), help="The filename of the figure")
  167. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  168. parser.set_defaults(func=lif_simulations)