population_coding.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. ###############################################################################
  2. ## plot the baseline and driven response properties as a function ##
  3. ## receptive field pos ##
  4. import os
  5. import numpy as np
  6. import pandas as pd
  7. import matplotlib.pyplot as plt
  8. from matplotlib.patches import Rectangle
  9. from scipy.stats import mannwhitneyu
  10. from .figure_style import subfig_labelsize, subfig_labelweight, despine, label_size, tick_label_size
  11. fig3_help = "Compares the stimulus encoding performance in homogeneous and heterogeneous populations of P-units. Depends on the presence of the DataFrames with the respective analysis results."
  12. def plot_mi_at_similar_rate_modulation(hom_df, het_df, axis, pop_sizes=[5, 10, 20], range=0.2):
  13. """Compares the mutual information for heterogeneous and homogeneous populations while comparing only populations with similar
  14. rate modulations of the population response.
  15. Args:
  16. hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
  17. het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
  18. axis (pyplot.axis): the figure axis to draw to
  19. pop_sizes (list, optional): The population sizes to test. Defaults to [5, 10, 20].
  20. range (scalar, float): the plus-minus range around the rate modlation of the homogeneous population.
  21. """
  22. datasets = hom_df.dataset_id.unique()
  23. lower_limit = 1. - range
  24. upper_limit = 1. + range
  25. for pop_size in pop_sizes:
  26. hom_infos = []
  27. het_infos = []
  28. for d in datasets:
  29. hom_trials = hom_df[(hom_df.dataset_id == d) & (hom_df.kernel == 0.001) & (hom_df.pop_size == pop_size)]
  30. contrasts = hom_trials.contrast.unique()
  31. for c in contrasts:
  32. trials = hom_trials[hom_trials.contrast == c]
  33. rate_modulation = np.mean(trials.rate_modulation)
  34. hom_info = np.mean(trials.mi)
  35. het_trials = het_df[(het_df.kernel_sigma == 0.001) & (het_df.delay == 0.0) & (het_df.pop_size == pop_size) &
  36. (het_df.rate_modulation >= lower_limit * rate_modulation) & (het_df.rate_modulation < upper_limit * rate_modulation)]
  37. if len(het_trials) > 0:
  38. het_info = np.mean(het_trials.mi)
  39. hom_infos.append(hom_info)
  40. het_infos.append(het_info)
  41. sc = axis.scatter(hom_infos, het_infos, edgecolor="white", s=10, lw=0.3, label="population size: %i" % pop_size)
  42. axis.scatter(np.mean(hom_infos), np.mean(het_infos), s=50, marker="d", facecolor=sc.get_facecolor(), edgecolor="k", lw=1., zorder=10)
  43. axis.set_xlim([0, 1000])
  44. axis.set_ylim([0, 1000])
  45. axis.legend(handletextpad=0.01, loc="lower right", frameon=False, labelspacing=0.2)
  46. axis.plot([0, 1000], [0,1000], lw=1, ls='--', color="black")
  47. axis.set_xticks(np.arange(0, 1001, 250))
  48. axis.set_xticks(np.arange(0, 1001, 50), minor=True)
  49. axis.set_yticks(np.arange(0, 1001, 250))
  50. axis.set_yticks(np.arange(0, 1001, 50), minor=True)
  51. axis.set_yticklabels([])
  52. axis.set_ylabel("m.i. heterogeneous [bit/s]", fontsize=label_size)
  53. axis.set_xlabel("m.i. homogeneous [bit/s]", fontsize=label_size)
  54. axis.xaxis.set_label_coords(0.5, -0.125)
  55. axis.yaxis.set_label_coords(-0.1, 0.5)
  56. despine(axis, ["top", "right"], False)
  57. def label_diff(i, j, text, X, Y, axis, offset=50):
  58. x = (X[i] + X[j]) / 2
  59. y = max(Y[i], Y[j]) + offset
  60. dx = abs(X[i] - X[j])
  61. props = {'connectionstyle':'bar','arrowstyle':'-', 'shrinkA':20,'shrinkB':20,'linewidth':1}
  62. axis.annotate(text, xy=(x, y + offset ), zorder=10, ha="center", va="bottom", fontsize=7)
  63. axis.annotate('', xy=(X[i], y), xytext=(X[j],y), arrowprops=props)
  64. def statistical_comparison(hom_df, het_df, pop_sizes=[5, 10, 20]):
  65. """Compares the mutual information carried by homogeneous and heterogeneous populations.
  66. Performs a Mann Whitney U test with Bonferroni correction.
  67. Args:
  68. hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
  69. het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
  70. pop_sizes (list, optional): The population sizes to test. Defaults to [5, 10, 20].
  71. """
  72. hom_infos = []
  73. het_infos = []
  74. significance_labels = []
  75. significances = []
  76. for i, ps in enumerate(pop_sizes, 1):
  77. hom_trials = hom_df[(hom_df.pop_size == ps) & (hom_df.kernel == 0.001)]
  78. infos = []
  79. for d in hom_trials.dataset_id.unique():
  80. infos.append(np.mean(hom_trials.mi[hom_trials.dataset_id == d].values))
  81. hom_infos.append(np.asarray(infos))
  82. het_trials = het_df[(het_df.pop_size == ps) & (het_df.kernel_sigma == 0.001) & (het_df.delay == 0.0)]
  83. het_infos.append(het_trials.mi.values)
  84. s, p = mannwhitneyu(hom_infos[-1], het_infos[-1])
  85. p *= len(pop_sizes)
  86. significance_labels.append( "n.s." if p > 0.05 else f"{'*' if p < 0.05 and p > 0.001 else '**'}")
  87. significances.append(p)
  88. return significances, significance_labels
  89. def plot_all_population_performances(hom_df, het_df, axis, stats_pop_sizes, labels, absolute_mi=True):
  90. """Plot the mutual information as a function of the population size.
  91. For homogeneous populations an individual line is drawn for each dataset. A thicker line representing the average
  92. is added.
  93. For heterogeneous populations an errorbar is used that shows mean and standard deviation across all populations for.
  94. Additionally two lines depicting the best and worst populations are added.
  95. Args:
  96. hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
  97. het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
  98. axis (pyplot.axis): the figure axis to draw to
  99. """
  100. # heterogeneous population:
  101. het_pop_size = het_df.pop_size.unique()
  102. het_mi_avgs = np.zeros(len(het_pop_size))
  103. het_mi_errors = np.zeros(len(het_pop_size))
  104. het_max = np.zeros(len(het_pop_size))
  105. het_min = np.zeros(len(het_pop_size))
  106. for i, htp in enumerate(het_pop_size):
  107. trials = het_df[(het_df.pop_size == htp) & (het_df.kernel_sigma == 0.001) & (het_df.delay == 0.0)]
  108. if absolute_mi:
  109. het_mi_avgs[i] = np.mean(trials.mi.values)
  110. het_mi_errors[i] = np.std(trials.mi.values)
  111. het_max[i] = np.max(trials.mi.values)
  112. het_min[i] = np.min(trials.mi.values)
  113. else:
  114. het_mi_avgs[i] = np.mean(trials.mi.values / trials.population_rate.values)
  115. het_mi_errors[i] = np.std(trials.mi.values / trials.population_rate.values)
  116. het_max[i] = np.max(trials.mi.values / trials.population_rate.values)
  117. het_min[i] = np.min(trials.mi.values / trials.population_rate.values)
  118. het_avg_line = axis.errorbar(het_pop_size, het_mi_avgs, yerr=het_mi_errors, label="heterogeneous average", color="tab:red", fmt="-o", markersize=2.5, linewidth=0.7)
  119. het_minmax_line, = axis.plot(het_pop_size, het_max, ls="-", lw=1, color='tab:red', label="heterogeneous best/worst")
  120. axis.plot(het_pop_size, het_min, ls="-", lw=1, color='tab:red')
  121. # homogenous population:
  122. dsets = hom_df.dataset_id.unique()
  123. all_hom_pop_sizes = hom_df.pop_size.unique()
  124. trial_counter = np.zeros(all_hom_pop_sizes.shape)
  125. hom_avg_mi = np.zeros(all_hom_pop_sizes.shape)
  126. for d in dsets:
  127. contrasts = hom_df.contrast[hom_df.dataset_id == d].unique()
  128. for c in contrasts:
  129. pop_sizes = hom_df.pop_size[(hom_df.dataset_id == d) & (hom_df.contrast == c)].unique()
  130. hom_avg = np.zeros(len(pop_sizes))
  131. for i, ps in enumerate(pop_sizes):
  132. mis = hom_df[(hom_df.dataset_id == d) & (hom_df.contrast == c) & (hom_df.pop_size == ps)]
  133. if absolute_mi:
  134. hom_avg[i] = np.mean(mis.mi.values)
  135. index = np.argwhere(all_hom_pop_sizes == ps)[0][0]
  136. hom_avg_mi[index] += np.mean(mis.mi.values)
  137. else:
  138. hom_avg[i] = np.mean(mis.mi.values / mis.population_rate.values)
  139. index = np.argwhere(all_hom_pop_sizes == ps)[0][0]
  140. hom_avg_mi[index] += np.mean(mis.mi.values / mis.population_rate.values)
  141. trial_counter[index] += 1
  142. axis.plot(pop_sizes, hom_avg, lw=0.2, ls="--", color="tab:blue")
  143. hom_avg_mi /= trial_counter
  144. axis.plot(all_hom_pop_sizes, hom_avg_mi, ls='-', lw=2.5, color="white")
  145. hom_avg_line, = axis.plot(all_hom_pop_sizes, hom_avg_mi, ls='-', lw=1.5, color="tab:blue",
  146. label="homogeneous average")
  147. axis.set_xlim([0, 30.5])
  148. axis.set_xticks(np.arange(0, 31, 10))
  149. axis.set_xticks(np.arange(0, 31, 2), minor=True)
  150. axis.set_xticklabels(np.arange(0, 31, 10))
  151. max_y = 1000 if absolute_mi else 5
  152. axis.set_ylim([0, max_y])
  153. axis.set_yticks(np.arange(0, max_y+1, 250))
  154. axis.set_yticks(np.arange(0, max_y+1, 50), minor=True)
  155. axis.set_yticklabels(np.arange(0, max_y+1, 250))
  156. axis.set_xlabel("population size", fontsize=label_size)
  157. axis.set_ylabel(f"mutual information {'[bit/s]' if absolute_mi else '[bit/spike]'}", fontsize=label_size)
  158. axis.yaxis.set_label_coords(-0.2, 0.5)
  159. axis.xaxis.set_label_coords(0.5, -0.125)
  160. axis.legend((hom_avg_line, het_avg_line, het_minmax_line), ("hom. average", "het. average", "het. best/worst"),
  161. frameon=True, loc="lower right", bbox_to_anchor=[1.0, 0.0], edgecolor="none", labelspacing=0.2,
  162. handlelength=1.0)
  163. despine(axis, ["top", "right"], False)
  164. print("Homogeneous population @2: %.2f bit/s; @30 %.2f bit/s" % (hom_avg_mi[1], hom_avg_mi[all_hom_pop_sizes == 30]))
  165. print("Heterogeneous population @2: %.2f bit/s; @30 %.2f bit/s" % (het_mi_avgs[0], het_mi_avgs[-1]))
  166. for ps, label in zip(stats_pop_sizes, labels):
  167. r = Rectangle([ps-0.4, 0.0], 0.8, 975, linewidth=0.5, edgecolor="silver", fill=True, facecolor="silver", alpha=0.5)
  168. axis.add_patch(r)
  169. axis.text(ps, max_y, label, ha="center", va="center")
  170. def layout_figure():
  171. fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(5.1, 2.5))
  172. axes[0].text(-0.28, 1.02, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  173. transform=axes[0].transAxes)
  174. axes[1].text(-0.225, 1.02, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  175. transform=axes[1].transAxes)
  176. fig.subplots_adjust(left=0.11, bottom=0.145, right=0.965, top=0.9, wspace=0.3)
  177. return fig, axes
  178. def compare_population_encoding(args):
  179. """Compare mutual information carried by homogeneous and heterogeneous populations.
  180. Args:
  181. args [ArgumentParser] : command line arguments.
  182. """
  183. fig, axes = layout_figure()
  184. hom = pd.read_csv(args.homogeneous_data, sep=";", index_col=0)
  185. het = pd.read_csv(args.heterogeneous_data, sep=";", index_col=0)
  186. pop_sizes = [5, 10, 20]
  187. significances, labels = statistical_comparison(hom, het, pop_sizes)
  188. plot_all_population_performances(hom, het, axes[0], pop_sizes, labels, args.absolute_info)
  189. #plot_statistic_comparison(hom, het, axes[1])
  190. plot_mi_at_similar_rate_modulation(hom, het, axes[1])
  191. if args.nosave:
  192. plt.show()
  193. else:
  194. fig.savefig(args.outfile)
  195. plt.close()
  196. plt.close()
  197. def command_line_parser(subparsers):
  198. default_homogeneous_data_frame = os.path.join("derived_data", "homogeneous_populationcoding.csv")
  199. default_heterogeneous_data_frame = os.path.join("derived_data", "heterogeneous_populationcoding.csv")
  200. comparison_hom_vs_het_parser = subparsers.add_parser("population_coding", help=fig3_help)
  201. comparison_hom_vs_het_parser.add_argument("-a", "--absolute_info", action="store_true", default=True, help="Whether absolute information values in bit/s or bit/spike are plotted")
  202. comparison_hom_vs_het_parser.add_argument("-homdf", "--homogeneous_data", help=f"",
  203. default=default_homogeneous_data_frame)
  204. comparison_hom_vs_het_parser.add_argument("-hetdf", "--heterogeneous_data",
  205. default=default_heterogeneous_data_frame)
  206. comparison_hom_vs_het_parser.add_argument("-o", "--outfile", default=os.path.join("figures", "comparison_homogeneous_heterogeneous.pdf"))
  207. comparison_hom_vs_het_parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  208. comparison_hom_vs_het_parser.set_defaults(func=compare_population_encoding)