###############################################################################
## Compare the stimulus encoding (mutual information) of homogeneous and     ##
## heterogeneous populations of P-units.                                     ##
###############################################################################
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.stats import mannwhitneyu

from .figure_style import subfig_labelsize, subfig_labelweight, despine, label_size, tick_label_size

fig3_help = "Compares the stimulus encoding performance in homogeneous and heterogeneous populations of P-units. Depends on the presence of the DataFrames with the respective analysis results."

# Analysis condition used by every panel of this figure: results obtained with
# a kernel width of 0.001 (column ``kernel`` in the homogeneous table,
# ``kernel_sigma`` in the heterogeneous table) and, for the heterogeneous
# populations, a ``delay`` of 0.0.
_KERNEL_WIDTH = 0.001
_DELAY = 0.0


def _select_homogeneous(hom_df):
    """Return the homogeneous-population rows analysed with ``_KERNEL_WIDTH``."""
    return hom_df[hom_df.kernel == _KERNEL_WIDTH]


def _select_heterogeneous(het_df):
    """Return the heterogeneous-population rows analysed with ``_KERNEL_WIDTH`` and ``_DELAY``."""
    return het_df[(het_df.kernel_sigma == _KERNEL_WIDTH) & (het_df.delay == _DELAY)]


def _information(df, absolute_mi=True):
    """Return the per-row mutual information of ``df``.

    With ``absolute_mi`` the raw ``mi`` column (bit/s) is returned, otherwise
    ``mi`` divided by ``population_rate`` (bit/spike).
    """
    if absolute_mi:
        return df.mi.values
    return df.mi.values / df.population_rate.values


def _value_at(pop_sizes, values, pop_size):
    """Return the entry of ``values`` that belongs to ``pop_size``, or NaN if that size is absent."""
    matches = np.flatnonzero(np.asarray(pop_sizes) == pop_size)
    return values[matches[0]] if matches.size else np.nan


def _significance_label(p):
    """Translate a (corrected) p-value into "n.s.", "*" (p <= 0.05) or "**" (p <= 0.001)."""
    # Written as ``not p <= 0.05`` so that a NaN p-value is reported as n.s.
    if not p <= 0.05:
        return "n.s."
    return "*" if p > 0.001 else "**"


def plot_mi_at_similar_rate_modulation(hom_df, het_df, axis, pop_sizes=(5, 10, 20), range=0.2):
    """Compare the mutual information of heterogeneous and homogeneous populations
    that have similar rate modulations of the population response.

    For every population size, dataset and contrast, the homogeneous trials are
    averaged. Then all heterogeneous populations of the same size are averaged
    whose rate modulation lies within ``[1 - range, 1 + range)`` times the
    homogeneous rate modulation. Each such pair is drawn as one dot, and the
    mean over all pairs of a population size as a diamond.

    Args:
        hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
        het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
        axis (pyplot.axis): the figure axis to draw to
        pop_sizes (sequence, optional): The population sizes to test. Defaults to (5, 10, 20).
        range (float, optional): the relative plus-minus range around the rate modulation of the
            homogeneous population. Defaults to 0.2. (The name shadows the builtin ``range`` inside
            this function; it is kept for backward compatibility of the keyword.)
    """
    lower_limit = 1. - range
    upper_limit = 1. + range
    hom_sel = _select_homogeneous(hom_df)
    het_sel = _select_heterogeneous(het_df)
    datasets = hom_df.dataset_id.unique()

    for pop_size in pop_sizes:
        hom_infos = []
        het_infos = []
        het_pop = het_sel[het_sel.pop_size == pop_size]
        for d in datasets:
            hom_trials = hom_sel[(hom_sel.dataset_id == d) & (hom_sel.pop_size == pop_size)]
            for c in hom_trials.contrast.unique():
                trials = hom_trials[hom_trials.contrast == c]
                rate_modulation = np.mean(trials.rate_modulation)
                het_trials = het_pop[(het_pop.rate_modulation >= lower_limit * rate_modulation) &
                                     (het_pop.rate_modulation < upper_limit * rate_modulation)]
                # only pairs with at least one matching heterogeneous population are compared
                if len(het_trials) > 0:
                    hom_infos.append(np.mean(trials.mi))
                    het_infos.append(np.mean(het_trials.mi))

        sc = axis.scatter(hom_infos, het_infos, edgecolor="white", s=10, lw=0.3,
                          label="population size: %i" % pop_size)
        if hom_infos:  # the mean of an empty list would only yield NaN plus a warning
            axis.scatter(np.mean(hom_infos), np.mean(het_infos), s=50, marker="d",
                         facecolor=sc.get_facecolor(), edgecolor="k", lw=1., zorder=10)

    axis.set_xlim([0, 1000])
    axis.set_ylim([0, 1000])
    axis.legend(handletextpad=0.01, loc="lower right", frameon=False, labelspacing=0.2)
    axis.plot([0, 1000], [0, 1000], lw=1, ls='--', color="black")
    axis.set_xticks(np.arange(0, 1001, 250))
    axis.set_xticks(np.arange(0, 1001, 50), minor=True)
    axis.set_yticks(np.arange(0, 1001, 250))
    axis.set_yticks(np.arange(0, 1001, 50), minor=True)
    axis.set_yticklabels([])
    axis.set_ylabel("m.i. heterogeneous [bit/s]", fontsize=label_size)
    axis.set_xlabel("m.i. homogeneous [bit/s]", fontsize=label_size)
    axis.xaxis.set_label_coords(0.5, -0.125)
    axis.yaxis.set_label_coords(-0.1, 0.5)
    despine(axis, ["top", "right"], False)


def label_diff(i, j, text, X, Y, axis, offset=50):
    """Draw a bracket between entries ``i`` and ``j`` and annotate it with ``text``.

    Args:
        i, j (int): indices into ``X`` and ``Y`` of the two compared entries.
        text (str): the annotation, e.g. a significance label.
        X, Y (sequence of float): x positions and heights of the entries.
        axis (pyplot.axis): the figure axis to draw to.
        offset (float, optional): vertical gap between the higher entry and the bracket,
            and again between the bracket and the text. Defaults to 50.
    """
    x = (X[i] + X[j]) / 2
    y = max(Y[i], Y[j]) + offset
    props = {'connectionstyle': 'bar', 'arrowstyle': '-',
             'shrinkA': 20, 'shrinkB': 20, 'linewidth': 1}
    axis.annotate(text, xy=(x, y + offset), zorder=10, ha="center", va="bottom", fontsize=7)
    axis.annotate('', xy=(X[i], y), xytext=(X[j], y), arrowprops=props)


def statistical_comparison(hom_df, het_df, pop_sizes=(5, 10, 20)):
    """Compare the mutual information carried by homogeneous and heterogeneous populations.

    For each population size, a two-sided Mann-Whitney U test compares two samples.
    The first holds the per-dataset averages of the homogeneous populations; the
    second holds all heterogeneous populations of that size. The p-values are
    Bonferroni corrected for the number of tested sizes and capped at 1.

    Args:
        hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
        het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
        pop_sizes (sequence, optional): The population sizes to test. Defaults to (5, 10, 20).

    Returns:
        list of float: the corrected p-values, one per population size.
        list of str: the matching labels ("n.s.", "*" or "**").
    """
    hom_sel = _select_homogeneous(hom_df)
    het_sel = _select_heterogeneous(het_df)
    n_tests = len(pop_sizes)
    significances = []
    significance_labels = []
    for ps in pop_sizes:
        hom_trials = hom_sel[hom_sel.pop_size == ps]
        hom_infos = np.asarray([np.mean(hom_trials.mi[hom_trials.dataset_id == d].values)
                                for d in hom_trials.dataset_id.unique()])
        het_infos = het_sel.mi[het_sel.pop_size == ps].values
        _, p = mannwhitneyu(hom_infos, het_infos, alternative="two-sided")
        p = min(p * n_tests, 1.0)  # Bonferroni correction
        significance_labels.append(_significance_label(p))
        significances.append(p)
    return significances, significance_labels


def plot_all_population_performances(hom_df, het_df, axis, stats_pop_sizes, labels, absolute_mi=True):
    """Plot the mutual information as a function of the population size.

    Homogeneous populations:
        - one thin dashed line is drawn for each dataset and contrast;
        - a thicker line shows the average over these lines.

    Heterogeneous populations:
        - an errorbar shows the mean and standard deviation across all populations of each size;
        - two lines depict the best and the worst population.

    Both population types use the same analysis condition (kernel width 0.001;
    for heterogeneous populations also delay 0.0). The population sizes in
    ``stats_pop_sizes`` are highlighted with a shaded bar labelled with the
    matching entry of ``labels``.

    Args:
        hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
        het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
        axis (pyplot.axis): the figure axis to draw to
        stats_pop_sizes (sequence): population sizes for which a statistical label is shown.
        labels (sequence of str): the labels drawn at ``stats_pop_sizes``.
        absolute_mi (bool, optional): plot bit/s if True, otherwise bit/spike. Defaults to True.
    """
    unit = "bit/s" if absolute_mi else "bit/spike"

    # heterogeneous populations
    het_sel = _select_heterogeneous(het_df)
    het_pop_sizes = np.sort(het_sel.pop_size.unique())
    het_mi_avgs = np.zeros(len(het_pop_sizes))
    het_mi_errors = np.zeros(len(het_pop_sizes))
    het_max = np.zeros(len(het_pop_sizes))
    het_min = np.zeros(len(het_pop_sizes))
    for i, ps in enumerate(het_pop_sizes):
        values = _information(het_sel[het_sel.pop_size == ps], absolute_mi)
        het_mi_avgs[i] = np.mean(values)
        het_mi_errors[i] = np.std(values)
        het_max[i] = np.max(values)
        het_min[i] = np.min(values)

    het_avg_line = axis.errorbar(het_pop_sizes, het_mi_avgs, yerr=het_mi_errors, label="heterogeneous average",
                                 color="tab:red", fmt="-o", markersize=2.5, linewidth=0.7)
    het_minmax_line, = axis.plot(het_pop_sizes, het_max, ls="-", lw=1, color='tab:red',
                                 label="heterogeneous best/worst")
    axis.plot(het_pop_sizes, het_min, ls="-", lw=1, color='tab:red')

    # homogeneous populations
    hom_sel = _select_homogeneous(hom_df)
    all_hom_pop_sizes = np.sort(hom_sel.pop_size.unique())
    hom_sums = np.zeros(len(all_hom_pop_sizes))
    trial_counter = np.zeros(len(all_hom_pop_sizes))
    for d in hom_sel.dataset_id.unique():
        dataset = hom_sel[hom_sel.dataset_id == d]
        for c in dataset.contrast.unique():
            condition = dataset[dataset.contrast == c]
            pop_sizes = np.sort(condition.pop_size.unique())
            hom_avg = np.zeros(len(pop_sizes))
            for i, ps in enumerate(pop_sizes):
                hom_avg[i] = np.mean(_information(condition[condition.pop_size == ps], absolute_mi))
                # all_hom_pop_sizes is sorted and contains ps, so searchsorted finds its index
                index = np.searchsorted(all_hom_pop_sizes, ps)
                hom_sums[index] += hom_avg[i]
                trial_counter[index] += 1
            axis.plot(pop_sizes, hom_avg, lw=0.2, ls="--", color="tab:blue")
    hom_avg_mi = hom_sums / trial_counter

    # white underlay makes the average line stand out against the thin per-dataset lines
    axis.plot(all_hom_pop_sizes, hom_avg_mi, ls='-', lw=2.5, color="white")
    hom_avg_line, = axis.plot(all_hom_pop_sizes, hom_avg_mi, ls='-', lw=1.5, color="tab:blue",
                              label="homogeneous average")

    axis.set_xlim([0, 30.5])
    axis.set_xticks(np.arange(0, 31, 10))
    axis.set_xticks(np.arange(0, 31, 2), minor=True)
    axis.set_xticklabels(np.arange(0, 31, 10))

    # 4 major / 20 minor intervals: 250/50 steps for bit/s, scaled for bit/spike
    max_y = 1000 if absolute_mi else 5
    major_ticks = np.linspace(0, max_y, 5)
    axis.set_ylim([0, max_y])
    axis.set_yticks(major_ticks)
    axis.set_yticks(np.linspace(0, max_y, 21), minor=True)
    axis.set_yticklabels([f"{t:g}" for t in major_ticks])

    axis.set_xlabel("population size", fontsize=label_size)
    axis.set_ylabel(f"mutual information {'[bit/s]' if absolute_mi else '[bit/spike]'}", fontsize=label_size)
    axis.yaxis.set_label_coords(-0.2, 0.5)
    axis.xaxis.set_label_coords(0.5, -0.125)
    axis.legend((hom_avg_line, het_avg_line, het_minmax_line),
                ("hom. average", "het. average", "het. best/worst"),
                frameon=True, loc="lower right", bbox_to_anchor=[1.0, 0.0], edgecolor="none",
                labelspacing=0.2, handlelength=1.0)
    despine(axis, ["top", "right"], False)

    # look the summary values up by population size instead of by array position
    print("Homogeneous population @2: %.2f %s; @30 %.2f %s"
          % (_value_at(all_hom_pop_sizes, hom_avg_mi, 2), unit,
             _value_at(all_hom_pop_sizes, hom_avg_mi, 30), unit))
    print("Heterogeneous population @2: %.2f %s; @30 %.2f %s"
          % (_value_at(het_pop_sizes, het_mi_avgs, 2), unit,
             _value_at(het_pop_sizes, het_mi_avgs, 30), unit))

    for ps, label in zip(stats_pop_sizes, labels):
        r = Rectangle([ps - 0.4, 0.0], 0.8, 0.975 * max_y, linewidth=0.5, edgecolor="silver",
                      fill=True, facecolor="silver", alpha=0.5)
        axis.add_patch(r)
        axis.text(ps, max_y, label, ha="center", va="center")


def layout_figure():
    """Create the two-panel figure and label the panels "A" and "B".

    Returns:
        (pyplot.Figure, array of pyplot.axis): the figure and its two axes.
    """
    fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(5.1, 2.5))
    axes[0].text(-0.28, 1.02, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                 transform=axes[0].transAxes)
    axes[1].text(-0.225, 1.02, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                 transform=axes[1].transAxes)
    fig.subplots_adjust(left=0.11, bottom=0.145, right=0.965, top=0.9, wspace=0.3)
    return fig, axes


def compare_population_encoding(args):
    """Compare mutual information carried by homogeneous and heterogeneous populations.

    Reads both result tables (';'-separated csv files), draws panel A (information vs.
    population size with significance labels) and panel B (information at similar
    rate modulation), then shows or saves the figure.

    Args:
        args [ArgumentParser] : command line arguments.
    """
    hom = pd.read_csv(args.homogeneous_data, sep=";", index_col=0)
    het = pd.read_csv(args.heterogeneous_data, sep=";", index_col=0)
    fig, axes = layout_figure()

    pop_sizes = [5, 10, 20]
    _, labels = statistical_comparison(hom, het, pop_sizes)
    plot_all_population_performances(hom, het, axes[0], pop_sizes, labels, args.absolute_info)
    plot_mi_at_similar_rate_modulation(hom, het, axes[1], pop_sizes)

    if args.nosave:
        plt.show()
    else:
        fig.savefig(args.outfile)
    plt.close(fig)


def command_line_parser(subparsers):
    """Register the "population_coding" sub-command on ``subparsers``.

    Args:
        subparsers: the object returned by ``ArgumentParser.add_subparsers``.
    """
    default_homogeneous_data_frame = os.path.join("derived_data", "homogeneous_populationcoding.csv")
    default_heterogeneous_data_frame = os.path.join("derived_data", "heterogeneous_populationcoding.csv")
    parser = subparsers.add_parser("population_coding", help=fig3_help)
    # -a is kept for backward compatibility (it matches the default);
    # -r is the switch that actually selects bit/spike.
    parser.add_argument("-a", "--absolute_info", dest="absolute_info", action="store_true", default=True,
                        help="Plot absolute information values in bit/s (default)")
    parser.add_argument("-r", "--relative_info", dest="absolute_info", action="store_false",
                        help="Plot information values in bit/spike instead of bit/s")
    parser.add_argument("-homdf", "--homogeneous_data", default=default_homogeneous_data_frame,
                        help="';'-separated csv file with the homogeneous population results")
    parser.add_argument("-hetdf", "--heterogeneous_data", default=default_heterogeneous_data_frame,
                        help="';'-separated csv file with the heterogeneous population results")
    parser.add_argument("-o", "--outfile", default=os.path.join("figures", "comparison_homogeneous_heterogeneous.pdf"),
                        help="file the figure is saved to")
    parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
    parser.set_defaults(func=compare_population_encoding)