123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- ###############################################################################
## compare the stimulus encoding performance of homogeneous and ##
## heterogeneous P-unit populations ##
- import os
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- from matplotlib.patches import Rectangle
- from scipy.stats import mannwhitneyu
- from .figure_style import subfig_labelsize, subfig_labelweight, despine, label_size, tick_label_size
# Help text of the "population_coding" sub-command; handed to ``add_parser`` in
# ``command_line_parser``.
fig3_help = "Compares the stimulus encoding performance in homogeneous and heterogeneous populations of P-units. Depends on the presence of the DataFrames with the respective analysis results."
def plot_mi_at_similar_rate_modulation(hom_df, het_df, axis, pop_sizes=(5, 10, 20), range=0.2):
    """Scatter the mutual information of heterogeneous against homogeneous populations.

    For every population size, dataset and contrast the mean mutual
    information of the homogeneous population is paired with the mean mutual
    information of all heterogeneous populations of the same size whose rate
    modulation lies within ``[1 - range, 1 + range)`` times the mean rate
    modulation of the homogeneous population. Pairs without any matching
    heterogeneous population are dropped. Individual pairs are drawn as small
    dots, their mean per population size as a diamond.

    Only rows with ``kernel == 0.001`` (homogeneous) and
    ``kernel_sigma == 0.001`` and ``delay == 0.0`` (heterogeneous) are used.

    Args:
        hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
        het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
        axis (pyplot.axis): the figure axis to draw to
        pop_sizes (sequence, optional): The population sizes to test. Defaults to (5, 10, 20).
        range (scalar, float): the plus-minus range around the rate modulation of the homogeneous
            population. The name shadows the builtin but is kept for backward compatibility.
    """
    lower_limit = 1. - range
    upper_limit = 1. + range
    datasets = hom_df.dataset_id.unique()

    # These selections do not depend on dataset or contrast; apply them once.
    hom_selection = hom_df[hom_df.kernel == 0.001]
    het_selection = het_df[(het_df.kernel_sigma == 0.001) & (het_df.delay == 0.0)]

    for pop_size in pop_sizes:
        hom_pop = hom_selection[hom_selection.pop_size == pop_size]
        het_pop = het_selection[het_selection.pop_size == pop_size]
        hom_infos = []
        het_infos = []
        for d in datasets:
            hom_trials = hom_pop[hom_pop.dataset_id == d]
            for c in hom_trials.contrast.unique():
                trials = hom_trials[hom_trials.contrast == c]
                rate_modulation = np.mean(trials.rate_modulation)
                het_trials = het_pop[(het_pop.rate_modulation >= lower_limit * rate_modulation) &
                                     (het_pop.rate_modulation < upper_limit * rate_modulation)]
                if len(het_trials) == 0:
                    # no heterogeneous population with a comparable rate modulation
                    continue
                hom_infos.append(np.mean(trials.mi))
                het_infos.append(np.mean(het_trials.mi))

        # Always draw the (possibly empty) scatter so that the legend entry exists.
        sc = axis.scatter(hom_infos, het_infos, edgecolor="white", s=10, lw=0.3,
                          label="population size: %i" % pop_size)
        if hom_infos:
            # The mean of an empty list would only yield NaN and a RuntimeWarning.
            axis.scatter(np.mean(hom_infos), np.mean(het_infos), s=50, marker="d",
                         facecolor=sc.get_facecolor(), edgecolor="k", lw=1., zorder=10)

    axis.set_xlim([0, 1000])
    axis.set_ylim([0, 1000])
    axis.legend(handletextpad=0.01, loc="lower right", frameon=False, labelspacing=0.2)
    # identity line: points above it carry more information in the heterogeneous population
    axis.plot([0, 1000], [0, 1000], lw=1, ls='--', color="black")
    axis.set_xticks(np.arange(0, 1001, 250))
    axis.set_xticks(np.arange(0, 1001, 50), minor=True)
    axis.set_yticks(np.arange(0, 1001, 250))
    axis.set_yticks(np.arange(0, 1001, 50), minor=True)
    axis.set_yticklabels([])
    axis.set_ylabel("m.i. heterogeneous [bit/s]", fontsize=label_size)
    axis.set_xlabel("m.i. homogeneous [bit/s]", fontsize=label_size)
    axis.xaxis.set_label_coords(0.5, -0.125)
    axis.yaxis.set_label_coords(-0.1, 0.5)
    despine(axis, ["top", "right"], False)
def label_diff(i, j, text, X, Y, axis, offset=50):
    """Draw a bracket between two data points and put a text label above it.

    The bracket is placed ``offset`` above the larger of ``Y[i]`` and ``Y[j]``
    and spans from ``X[i]`` to ``X[j]``; the text is centered between the two
    x-positions, another ``offset`` above the bracket.

    Args:
        i (int): index of the first data point in X and Y
        j (int): index of the second data point in X and Y
        text (str): the label, e.g. a significance marker
        X (sequence): x-positions of the data points
        Y (sequence): y-values of the data points
        axis (pyplot.axis): the figure axis to draw to
        offset (scalar, optional): vertical spacing in data units. Defaults to 50.
    """
    x_center = (X[i] + X[j]) / 2
    bar_y = max(Y[i], Y[j]) + offset
    bar_props = {'connectionstyle': 'bar', 'arrowstyle': '-', 'shrinkA': 20, 'shrinkB': 20, 'linewidth': 1}
    axis.annotate(text, xy=(x_center, bar_y + offset), zorder=10, ha="center", va="bottom", fontsize=7)
    # an arrow without heads and with the 'bar' connection style renders as a bracket
    axis.annotate('', xy=(X[i], bar_y), xytext=(X[j], bar_y), arrowprops=bar_props)
def _significance_label(p):
    """Map a (corrected) p-value to 'n.s.', '*' (p < 0.05) or '**' (p <= 0.001)."""
    if np.isnan(p) or p >= 0.05:
        return "n.s."
    return "*" if p > 0.001 else "**"


def statistical_comparison(hom_df, het_df, pop_sizes=(5, 10, 20)):
    """Compares the mutual information carried by homogeneous and heterogeneous populations.

    For each population size the per-dataset mean mutual information of the
    homogeneous populations (``kernel == 0.001``) is compared with the mutual
    information of the individual heterogeneous populations
    (``kernel_sigma == 0.001`` and ``delay == 0.0``) using a two-sided
    Mann Whitney U test. The p-values are Bonferroni corrected for the number
    of tested population sizes and capped at 1.

    Args:
        hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
        het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
        pop_sizes (sequence, optional): The population sizes to test. Defaults to (5, 10, 20).

    Returns:
        tuple: ``(significances, significance_labels)``; the corrected p-values and the
        matching labels ("n.s.", "*" or "**"), one entry per population size. If one
        of the two groups is empty the p-value is NaN and the label is "n.s.".
    """
    n_comparisons = len(pop_sizes)
    significances = []
    significance_labels = []
    for ps in pop_sizes:
        hom_trials = hom_df[(hom_df.pop_size == ps) & (hom_df.kernel == 0.001)]
        # one value per dataset: the mean across all of its trials
        hom_infos = np.asarray([np.mean(hom_trials.mi[hom_trials.dataset_id == d].values)
                                for d in hom_trials.dataset_id.unique()])
        het_trials = het_df[(het_df.pop_size == ps) & (het_df.kernel_sigma == 0.001) & (het_df.delay == 0.0)]
        het_infos = het_trials.mi.values

        if hom_infos.size == 0 or het_infos.size == 0:
            # Nothing to compare. Depending on the SciPy version the test would raise
            # or return NaN, therefore report an undefined p-value explicitly.
            p = float("nan")
        else:
            # The default of ``alternative`` differs between SciPy versions; be explicit.
            _, p = mannwhitneyu(hom_infos, het_infos, alternative="two-sided")
            # Bonferroni correction; a p-value cannot exceed 1.
            p = float(min(p * n_comparisons, 1.0))

        significances.append(p)
        significance_labels.append(_significance_label(p))
    return significances, significance_labels
def _information_values(trials, absolute_mi):
    """Return the mutual information per row, absolute or divided by the population rate."""
    if absolute_mi:
        return trials.mi.values
    return trials.mi.values / trials.population_rate.values


def _value_at_pop_size(pop_sizes, values, pop_size):
    """Return the entry of ``values`` that belongs to ``pop_size`` or NaN if it is absent."""
    matches = values[pop_sizes == pop_size]
    return float(matches[0]) if matches.size > 0 else float("nan")


def plot_all_population_performances(hom_df, het_df, axis, stats_pop_sizes, labels, absolute_mi=True):
    """Plot the mutual information as a function of the population size.

    For homogeneous populations an individual line is drawn for each dataset and contrast.
    A thicker line representing the average is added.
    For heterogeneous populations (``kernel_sigma == 0.001`` and ``delay == 0.0``) an errorbar
    is used that shows mean and standard deviation across all populations of a given size.
    Additionally two lines depicting the best and worst populations are added.
    The population sizes listed in ``stats_pop_sizes`` are highlighted by a shaded bar with
    the respective label on top. The averages at population sizes 2 and 30 are printed.

    Args:
        hom_df (pandas DataFrame): the results table containing the results from the homogeneous populations
        het_df (pandas DataFrame): the results table containing the results from the heterogeneous populations
        axis (pyplot.axis): the figure axis to draw to
        stats_pop_sizes (sequence): the population sizes to highlight
        labels (sequence of str): one text label per entry of ``stats_pop_sizes``
        absolute_mi (bool, optional): plot the mutual information in bit/s (True) or divided by the
            population rate in bit/spike (False). Defaults to True.
    """
    # heterogeneous population:
    # ``unique`` keeps the order of appearance; sort so that lines run from small to large sizes
    het_pop_size = np.sort(het_df.pop_size.unique())
    het_selection = het_df[(het_df.kernel_sigma == 0.001) & (het_df.delay == 0.0)]
    het_mi_avgs = np.full(len(het_pop_size), np.nan)
    het_mi_errors = np.full(len(het_pop_size), np.nan)
    het_max = np.full(len(het_pop_size), np.nan)
    het_min = np.full(len(het_pop_size), np.nan)
    for i, htp in enumerate(het_pop_size):
        values = _information_values(het_selection[het_selection.pop_size == htp], absolute_mi)
        if values.size == 0:
            # no population of this size for the selected kernel/delay; leave a gap (NaN)
            continue
        het_mi_avgs[i] = np.mean(values)
        het_mi_errors[i] = np.std(values)
        het_max[i] = np.max(values)
        het_min[i] = np.min(values)
    het_avg_line = axis.errorbar(het_pop_size, het_mi_avgs, yerr=het_mi_errors, label="heterogeneous average",
                                 color="tab:red", fmt="-o", markersize=2.5, linewidth=0.7)
    het_minmax_line, = axis.plot(het_pop_size, het_max, ls="-", lw=1, color='tab:red',
                                 label="heterogeneous best/worst")
    axis.plot(het_pop_size, het_min, ls="-", lw=1, color='tab:red')

    # homogeneous population:
    # NOTE(review): unlike the other analyses in this file, the homogeneous data is not
    # restricted to kernel == 0.001 here -- confirm that this is intended.
    all_hom_pop_sizes = np.sort(hom_df.pop_size.unique())
    trial_counter = np.zeros(all_hom_pop_sizes.shape)
    hom_avg_mi = np.zeros(all_hom_pop_sizes.shape)
    for d in hom_df.dataset_id.unique():
        dataset_trials = hom_df[hom_df.dataset_id == d]
        for c in dataset_trials.contrast.unique():
            contrast_trials = dataset_trials[dataset_trials.contrast == c]
            pop_sizes = np.sort(contrast_trials.pop_size.unique())
            hom_avg = np.zeros(len(pop_sizes))
            for i, ps in enumerate(pop_sizes):
                mis = contrast_trials[contrast_trials.pop_size == ps]
                hom_avg[i] = np.mean(_information_values(mis, absolute_mi))
                # accumulate for the grand average across datasets and contrasts
                index = np.argwhere(all_hom_pop_sizes == ps)[0][0]
                hom_avg_mi[index] += hom_avg[i]
                trial_counter[index] += 1
            axis.plot(pop_sizes, hom_avg, lw=0.2, ls="--", color="tab:blue")
    with np.errstate(divide="ignore", invalid="ignore"):
        hom_avg_mi /= trial_counter
    # white underlay makes the average stand out against the thin individual lines
    axis.plot(all_hom_pop_sizes, hom_avg_mi, ls='-', lw=2.5, color="white")
    hom_avg_line, = axis.plot(all_hom_pop_sizes, hom_avg_mi, ls='-', lw=1.5, color="tab:blue",
                              label="homogeneous average")

    axis.set_xlim([0, 30.5])
    axis.set_xticks(np.arange(0, 31, 10))
    axis.set_xticks(np.arange(0, 31, 2), minor=True)
    axis.set_xticklabels(np.arange(0, 31, 10))
    # y-axis range and tick steps depend on the unit, otherwise the bit/spike plot has no ticks
    max_y = 1000 if absolute_mi else 5
    major_step = 250 if absolute_mi else 1
    minor_step = 50 if absolute_mi else 0.25
    axis.set_ylim([0, max_y])
    axis.set_yticks(np.arange(0, max_y + 1, major_step))
    axis.set_yticks(np.arange(0, max_y + minor_step, minor_step), minor=True)
    axis.set_yticklabels(np.arange(0, max_y + 1, major_step))
    axis.set_xlabel("population size", fontsize=label_size)
    axis.set_ylabel(f"mutual information {'[bit/s]' if absolute_mi else '[bit/spike]'}", fontsize=label_size)
    axis.yaxis.set_label_coords(-0.2, 0.5)
    axis.xaxis.set_label_coords(0.5, -0.125)
    axis.legend((hom_avg_line, het_avg_line, het_minmax_line), ("hom. average", "het. average", "het. best/worst"),
                frameon=True, loc="lower right", bbox_to_anchor=[1.0, 0.0], edgecolor="none", labelspacing=0.2,
                handlelength=1.0)
    despine(axis, ["top", "right"], False)

    # look the reported sizes up explicitly instead of relying on their position in the arrays
    print("Homogeneous population @2: %.2f bit/s; @30 %.2f bit/s"
          % (_value_at_pop_size(all_hom_pop_sizes, hom_avg_mi, 2),
             _value_at_pop_size(all_hom_pop_sizes, hom_avg_mi, 30)))
    print("Heterogeneous population @2: %.2f bit/s; @30 %.2f bit/s"
          % (_value_at_pop_size(het_pop_size, het_mi_avgs, 2),
             _value_at_pop_size(het_pop_size, het_mi_avgs, 30)))

    for ps, label in zip(stats_pop_sizes, labels):
        # shaded bar marking a statistically compared population size; leaves room for the label
        r = Rectangle([ps - 0.4, 0.0], 0.8, 0.975 * max_y, linewidth=0.5, edgecolor="silver", fill=True,
                      facecolor="silver", alpha=0.5)
        axis.add_patch(r)
        axis.text(ps, max_y, label, ha="center", va="center")
def layout_figure():
    """Create the two-panel figure and add the panel labels "A" and "B".

    Returns:
        tuple: the figure and the array holding the two axes (left, right).
    """
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5.1, 2.5))
    # (panel letter, x-position of the letter in axes coordinates)
    panel_labels = (("A", -0.28), ("B", -0.225))
    for panel, (letter, x_pos) in zip(axes, panel_labels):
        panel.text(x_pos, 1.02, letter, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                   transform=panel.transAxes)
    fig.subplots_adjust(left=0.11, bottom=0.145, right=0.965, top=0.9, wspace=0.3)
    return fig, axes
def compare_population_encoding(args):
    """Compare mutual information carried by homogeneous and heterogeneous populations.

    Reads both result tables (semicolon separated csv files), runs the
    statistical comparison for the population sizes 5, 10 and 20, and draws
    the two panels of the figure. The figure is either shown or saved to
    ``args.outfile``; a missing output folder is created.

    Args:
        args [ArgumentParser] : command line arguments. Uses ``homogeneous_data``,
            ``heterogeneous_data``, ``absolute_info``, ``nosave`` and ``outfile``.
    """
    # load the data first so that no figure is left open if reading fails
    hom = pd.read_csv(args.homogeneous_data, sep=";", index_col=0)
    het = pd.read_csv(args.heterogeneous_data, sep=";", index_col=0)

    pop_sizes = [5, 10, 20]
    _, labels = statistical_comparison(hom, het, pop_sizes)

    fig, axes = layout_figure()
    plot_all_population_performances(hom, het, axes[0], pop_sizes, labels, args.absolute_info)
    plot_mi_at_similar_rate_modulation(hom, het, axes[1])

    if args.nosave:
        plt.show()
    else:
        out_dir = os.path.dirname(args.outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(args.outfile)
    # close exactly the figure created here
    plt.close(fig)
def command_line_parser(subparsers):
    """Register the ``population_coding`` sub-command and its options.

    The sub-command is bound to ``compare_population_encoding`` via the
    parser default ``func``.

    Args:
        subparsers: the object returned by ``ArgumentParser.add_subparsers`` to which the
            sub-command is added.
    """
    default_homogeneous_data_frame = os.path.join("derived_data", "homogeneous_populationcoding.csv")
    default_heterogeneous_data_frame = os.path.join("derived_data", "heterogeneous_populationcoding.csv")
    default_outfile = os.path.join("figures", "comparison_homogeneous_heterogeneous.pdf")

    parser = subparsers.add_parser("population_coding", help=fig3_help)
    parser.add_argument("-a", "--absolute_info", action="store_true", default=True,
                        help="Whether absolute information values in bit/s or bit/spike are plotted")
    # ``--absolute_info`` is already True by default and could never be switched off;
    # this flag writes False to the same destination.
    parser.add_argument("--relative_info", dest="absolute_info", action="store_false",
                        help="plot the information in bit/spike instead of bit/s")
    parser.add_argument("-homdf", "--homogeneous_data", default=default_homogeneous_data_frame,
                        help="csv file with the results of the homogeneous populations")
    parser.add_argument("-hetdf", "--heterogeneous_data", default=default_heterogeneous_data_frame,
                        help="csv file with the results of the heterogeneous populations")
    parser.add_argument("-o", "--outfile", default=default_outfile,
                        help="file name of the saved figure")
    parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
    parser.set_defaults(func=compare_population_encoding)
|