figure4.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. ########################################################################################
  2. ## effect of delay on the coding performance ##
  3. import os
  4. import numpy as np
  5. import pandas as pd
  6. import matplotlib as mplt
  7. import matplotlib.pyplot as plt
  8. plt.style.use("./code/plots/pnas_onecolumn.mplstyle")
  9. from .figure_style import subfig_labelsize, subfig_labelweight, despine
  10. fig4_help = "plots the effect of conduction delays on stimulus encoding"
  11. def get_mutual_info(df, delay, column="mi", error="std", kernel=0.001):
  12. """Read the mutual information from the passed dataframe for all population sizes the specified delay and kernel.
  13. Function returns the average and yerr values as np.arrays.
  14. Parameters
  15. ----------
  16. df : pd.DataFrame
  17. DataFrame containing the analysis results.
  18. delay : float
  19. One of the delays used during analysis.
  20. column : str, optional
  21. The column name to read, by default "mi", i.e. the full mutual information
  22. error : str, optional
  23. The error measure that should be returned on of {"std", "se", "quartile"}, by default "std"
  24. kernel : float, optional
  25. The kernel width used in the analysis, by default 0.001
  26. Returns
  27. -------
  28. np.array
  29. the average mutual information for each population size
  30. np.array
  31. the respective error measures, For standard deviation and standard error, this will be a vector. For quartiles the 25 and 75 percent quartiles are returned in a 2 * number_of_population_sizes array
  32. """
  33. pop_sizes = df.pop_size.unique()
  34. delays = df.delay.unique()
  35. assert(delay in delays)
  36. if error == "quartile":
  37. yerr = np.zeros((2, len(pop_sizes)))
  38. if error == "std" or error == "se":
  39. yerr = np.zeros(len(pop_sizes))
  40. avg = np.zeros(len(pop_sizes))
  41. for i, p in enumerate(pop_sizes):
  42. selection = df[column][(df.pop_size == p) & (df.delay == delay) & (df.kernel_sigma == kernel)].values
  43. avg[i] = np.mean(selection)
  44. if error == "std":
  45. yerr[i] = np.std(selection)
  46. elif error == "se":
  47. yerr[i] = np.std(selection)/np.sqrt(len(selection))
  48. elif error == "quartile":
  49. yerr[0, i] = np.abs(avg[i] - np.percentile(selection, 25))
  50. yerr[1, i] = np.abs(avg[i] - np.percentile(selection, 75))
  51. return avg, yerr
  52. def get_mutual_info_delay(df, population_size, column="mi", error="std", kernel=0.001):
  53. """read mutual information values from the dataframe.
  54. Args:
  55. df (pandas.DataFrame): the data frame
  56. population_size (int): population size to be read out
  57. column (str, optional): column name. Defaults to "mi".
  58. error (str, optional): column name for the standard deviation. Defaults to "std".
  59. kernel (float, optional): kernel size. Defaults to 0.001.
  60. Returns:
  61. [type]: [description]
  62. """
  63. pop_sizes = df.pop_size.unique()
  64. delays = df.delay.unique()
  65. assert(population_size in pop_sizes)
  66. if error == "quartile":
  67. yerr = np.zeros((2, len(delays)))
  68. if error == "std" or error == "se":
  69. yerr = np.zeros(len(delays))
  70. avg = np.zeros(len(delays))
  71. for i, d in enumerate(delays):
  72. selection = df[column][(df.pop_size == population_size) & (df.delay == d) & (df.kernel_sigma == kernel)].values
  73. avg[i] = np.mean(selection)
  74. if error == "std":
  75. yerr[i] = np.std(selection)
  76. elif error == "se":
  77. yerr[i] = np.std(selection)/np.sqrt(len(selection))
  78. elif error == "quartile":
  79. yerr[0, i] = np.abs(avg[i] - np.percentile(selection, 25))
  80. yerr[1, i] = np.abs(avg[i] - np.percentile(selection, 75))
  81. return avg, yerr
  82. def compare_homogeneous_with_similar_heterogeneous(homogeneous, heterogeneous, axes=[]):
  83. """
  84. for each heterogeneous population size compare the mutual info with all
  85. homogenous population with similar rates/rate modulations
  86. Args:
  87. homogeneous ([pandas.DataFrame]): [description]
  88. heterogeneous (pandas.DataFrame): [description]
  89. axes (list(pyplot.axis)): optional, default None, the axis into which the results should be plotted
  90. """
  91. features = ["population_rate", "rate_modulation"]
  92. if len(axes) < len(features):
  93. fig = plt.figure()
  94. axes = []
  95. axes.append(fig.add_subplot(211))
  96. axes.append(fig.add_subplot(212))
  97. population_sizes = heterogeneous.pop_size[heterogeneous.pop_size > 1].unique()
  98. for i, f in enumerate(features):
  99. print(f)
  100. ax = axes[i]
  101. ax.set_title(f)
  102. for ps in population_sizes:
  103. hom_populations = homogeneous[(homogeneous.pop_size == ps)]
  104. het_populations = heterogeneous[(heterogeneous.pop_size == ps) & (heterogeneous.delay == 0) & (heterogeneous.kernel_sigma == 0.001)]
  105. het_mis = np.zeros(len(het_populations))
  106. for count, (_, row) in enumerate(het_populations.iterrows()):
  107. feat_value = row[f]
  108. het_mis[count] = row["mi_100"]
  109. trials = hom_populations[(hom_populations[f] >= 0.8 * feat_value) & (hom_populations[f] < 1.0 * feat_value)]
  110. hom_datasets = trials.dataset.unique()
  111. hom_mis = np.zeros(len(hom_datasets))
  112. for j, d in enumerate(hom_datasets):
  113. selection = trials[trials.dataset == d]
  114. hom_mis[j] = np.mean(selection.mi_100)
  115. ax.scatter(ps-0.25, np.mean(het_mis))
  116. ax.scatter(ps+0.25, np.mean(hom_mis))
  117. plt.show()
  118. def layout_figure():
  119. fig = plt.figure(figsize=(3.42, 4.2))
  120. shape = (20, 7)
  121. ax_total = plt.subplot2grid(shape, (0, 0), rowspan=8, colspan=3)
  122. ax_100 = plt.subplot2grid(shape, (0, 4), rowspan=6, colspan=3 )
  123. ax_200 = plt.subplot2grid(shape, (7, 4), rowspan=6, colspan=3 )
  124. ax_300 = plt.subplot2grid(shape, (14, 4), rowspan=6, colspan=3 )
  125. delay_axis = plt.subplot2grid(shape, (12, 0), rowspan=8, colspan=3 )
  126. ax_total.text(-.425, 1.125, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight, transform=ax_total.transAxes)
  127. ax_100.text(-.35, 1.125, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight, transform=ax_100.transAxes)
  128. delay_axis.text(-.425, 1.125, "C", fontsize=subfig_labelsize, fontweight=subfig_labelweight, transform=delay_axis.transAxes)
  129. axes = [ax_total, ax_100, ax_200, ax_300, delay_axis]
  130. fig.subplots_adjust(left=0.15, bottom=0.125, top=0.925, right=0.99, hspace=0.25)
  131. return fig, axes
  132. def plot_delay_effect(args):
  133. """Illustrate the effect of a delay, e.g. induced by neuronal conduction delays.
  134. Args:
  135. args ([type]): the command line arguments
  136. """
  137. fig, axes = layout_figure()
  138. df = pd.read_csv(args.heterogeneous_data, sep=";", index_col=0)
  139. pop_sizes = df.pop_size.unique()
  140. delays = df.delay.unique()
  141. mi_axes = axes[:4]
  142. columns = ["mi", "mi_100", "mi_200", "mi_300"]
  143. titles = [r"$0-300$Hz", r"$0-100$Hz",r"$100-200$Hz",r"$200-300$Hz"]
  144. cmaps = [mplt.cm.get_cmap('Greys'), mplt.cm.get_cmap('Reds'), mplt.cm.get_cmap('Greens'), mplt.cm.get_cmap('PuBu')]
  145. limits = [[0, 1000], [0, 400], [0, 400], [0, 400]]
  146. markers = ["o", "X", "d", "s"]
  147. # as a function of population size
  148. marked_delays = [0, 2, 4, 9, len(delays)-1]
  149. for ax, col, title, ylim, cmap, marker in zip(mi_axes, columns, titles, limits, cmaps, markers):
  150. for i, d in enumerate(delays):
  151. avg, yerr = get_mutual_info(df, d, col, error="std", kernel=args.kernel)
  152. ax.errorbar(pop_sizes, avg, yerr, lw=0.5,markeredgewidth=0.2, markeredgecolor="white",
  153. markersize=4, linestyle="--", marker=marker,
  154. color=cmap(i/len(delays)*0.8 + 0.2), label=r"%.2fms" % (d * 1000))
  155. if "$0-300$" in title:
  156. if i in marked_delays:
  157. ax.text(31, avg[-1], "%.1f" % (d *1000), fontsize=5, ha="left", va="center")
  158. if i == 0:
  159. ax.text(31, avg[-1] + 150, r"$\sigma_{delay}$", fontsize=5, ha="left", va="center")
  160. ax.text(31, avg[-1] + 75, "[ms]", fontsize=5, ha="left", va="center")
  161. plt.text(0.05, 0.90 if i == 0 else 0.85, title, transform=ax.transAxes, fontsize=7, ha="left")
  162. ax.set_ylim(ylim)
  163. ax.set_yticks(np.arange(0, ylim[1]+1, 250))
  164. ax.set_yticks(np.arange(0, ylim[1]+1, 50), minor=True)
  165. ax.set_xlim(0, np.max(pop_sizes)+5)
  166. ax.set_xticks(range(0, np.max(pop_sizes) + 5, 10))
  167. ax.set_xticks(range(0, np.max(pop_sizes) + 5, 2), minor=True)
  168. despine(ax, ["top", "right"], False)
  169. axes[0].set_ylabel("mutual information [bit/s]")
  170. axes[0].set_xlabel("population size")
  171. axes[0].yaxis.set_label_coords(-0.275, 0.5)
  172. axes[1].set_xticklabels([])
  173. axes[2].set_xticklabels([])
  174. axes[2].set_ylabel("mutual information [bit/s]")
  175. axes[2].yaxis.set_label_coords(-0.225, 0.5)
  176. axes[3].set_xlabel("population size", ha="center")
  177. axes[3].xaxis.set_label_coords(0.5, -0.35)
  178. delays[0] += 0.00025 # just for the log scale
  179. for col, title, cmap, marker in zip(columns[1:], titles[1:], cmaps[1:], markers[1:]):
  180. avg, yerr = get_mutual_info_delay(df, 16, column=col, error="std", kernel=args.kernel)
  181. axes[-1].errorbar(delays*1000, avg/avg[0], yerr= yerr/avg[0], c=cmap(0.75), lw=0.5,
  182. markeredgewidth=0.2, markersize=4, marker=marker, label=title,
  183. markeredgecolor="white", zorder=2)
  184. axes[-1].plot(delays*1000, avg/avg[0], c=cmap(0.75), lw=1.0, ls="--", zorder=1)
  185. axes[-1].legend(ncol=1, handletextpad=0.1, loc="lower left", columnspacing=.5, frameon=True,
  186. fancybox=False, shadow=False, markerscale=0.9, borderaxespad=0.05)
  187. axes[-1].set_xlabel(r"$\sigma_{delay}$ [ms]")
  188. axes[-1].set_ylabel(r"mutual information [rel.]")
  189. axes[-1].set_xlim(0.2, 30)
  190. axes[-1].set_ylim(0.0, 1.1)
  191. axes[-1].set_yticks(np.arange(0.0, 1.01, 0.25))
  192. axes[-1].set_xscale("log")
  193. axes[-1].set_xticks(delays*1000)
  194. delays[0] -= 0.00025
  195. yticklabels = list(map(str, np.round(delays*1000,1)))
  196. for i in range(1, len(yticklabels), 2):
  197. yticklabels[i] = ""
  198. yticklabels[yticklabels.index("7.0")] = ""
  199. axes[-1].set_xticklabels(yticklabels, rotation=45)
  200. axes[-1].yaxis.set_label_coords(-0.275, 0.5)
  201. despine(axes[-1], ["top", "right"], False)
  202. if args.nosave:
  203. plt.show()
  204. else:
  205. fig.savefig(args.outfile)
  206. plt.close()
  207. def command_line_parser(subparsers):
  208. parser = subparsers.add_parser("figure4", help="Plots the effect of conduction delays on stimulus encoding in heterogeneous populations.")
  209. parser.add_argument("-hetdf", "--heterogeneous_data", default=os.path.join("derived_data", "heterogeneous_populationcoding.csv"))
  210. parser.add_argument("-k", "--kernel", type=float, default=0.001, help="The kernel width used for the analysis")
  211. parser.add_argument("-o", "--outfile", type=str, default=os.path.join("figures", "delay_effect.pdf"), help="The filename of the figure")
  212. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  213. parser.set_defaults(func=plot_delay_effect)