###############################################################################
## Plot the baseline and driven response properties as a function of the    ##
## receptive field position along the fish's rostro-caudal axis.             ##
###############################################################################
import os
import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from .figure_style import *

label_size = 8
fig2_help = ("Plots the properties of the baseline response and the stimulus driven response as a function "
             "of the receptive field position along the fish's rostro-caudal axis. Depends on fishsketch.png "
             "(expected in the folder figures/) and the data files passed via --baseline_frame and "
             "--driven_frame (by default derived_data/figure2_baseline_properties.csv and "
             "derived_data/figure2_driven_properties.csv).")


def get_feature_values(df, features):
    """Average the given features per dataset and stimulus contrast, using single-cell rows only.

    Rows containing any NaN are dropped first. Of the remaining rows only those
    with ``pop_size == 1`` are used. For every dataset and every contrast that
    occurs for it, the mean of each feature over the matching rows is collected.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with at least the columns ``dataset_id``, ``pop_size``,
        ``contrast`` and every name listed in ``features``.
    features : iterable of str
        Column names whose values should be averaged.

    Returns
    -------
    feature_dict : dict of str -> numpy.ndarray
        For each feature, one mean per (dataset, contrast) combination, in the
        order in which datasets and contrasts first appear in ``df``.
    datasets : numpy.ndarray
        Unique dataset ids of ``df``. They are taken *before* NaN rows are
        dropped, so datasets that contributed no values may be included.
    """
    datasets = df.dataset_id.unique()
    df = df.dropna()
    single_cells = df[df.pop_size == 1]
    feature_dict = {f: [] for f in features}
    for d in datasets:
        dataset_rows = single_cells[single_cells.dataset_id == d]
        for c in dataset_rows.contrast.unique():
            trials = dataset_rows[dataset_rows.contrast == c]
            for f in feature_dict:
                feature_dict[f].append(np.mean(trials[f].values))
    return {f: np.array(values) for f, values in feature_dict.items()}, datasets


def plotLinregress(ax, x, y, bonferroni_factor, color="tab:red", label="", feature=""):
    """Fit a linear regression of y on x, draw the fit line and report r and p in the legend.

    Also prints min/max/mean/std of ``y`` prefixed with ``feature``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    x, y : numpy.ndarray
        Positions and property values (``x.min()``/``x.max()`` define the drawn line range).
    bonferroni_factor : int or float
        Multiplier applied to the p-value, clipped at 1. With a factor of 1 the
        legend shows "p", otherwise "p*" (corrected p-value).
    color : str
        Line color.
    label : str
        Optional prefix for the legend entry.
    feature : str
        Name used in the printed summary statistics.

    Returns
    -------
    slope, intercept : float
        Parameters of the fitted line.
    """
    xRange = np.asarray([x.min(), x.max()])
    slope, intercept, r_value, p_value, *_ = stats.linregress(x, y)
    shown_p = min(p_value * bonferroni_factor, 1)
    p_name = "p" if bonferroni_factor == 1 else r"p$^*$"
    prefix = f"{label}, " if label else ""
    legend_label = f"{prefix}r: {r_value:.2f}, {p_name}: {shown_p:.2f}"
    ax.plot(xRange, intercept + slope * xRange, color=color, label=legend_label, ls="-")
    ax.legend(loc='upper left', frameon=False, ncol=1, fontsize=6, bbox_to_anchor=(0.0, 1.15))
    print("%s: minimum: %.2f, maximum: %.2f, mean: %.2f, std: %.2f" %
          (feature, np.min(y), np.max(y), np.mean(y), np.std(y)))
    return slope, intercept


def plotMeanProperty(ax, bins, x, y):
    """Plot the binned mean of y (dashed line) with a +- one standard deviation band.

    Bins are half-open ``[start, end)``, except the last one, which also
    contains its right edge (the same convention as ``np.histogram``). Values
    lying exactly on the upper bound are therefore not dropped. Bins without
    samples are omitted from the plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    bins : array-like
        Bin edges. Bin centers are computed from the width of the first bin,
        so the edges are assumed to be equally spaced.
    x, y : numpy.ndarray
        Positions and corresponding property values.

    Returns
    -------
    numpy.ndarray
        The mean of y for every bin with a finite mean (empty bins removed).
    """
    bins = np.asarray(bins)
    n_bins = len(bins) - 1
    my = np.full(n_bins, np.nan)
    sy = np.full(n_bins, np.nan)
    for i, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
        # the last bin is closed so that x == bins[-1] is still counted
        below_end = (x <= end) if i == n_bins - 1 else (x < end)
        yy = y[(x >= start) & below_end]
        if len(yy) > 0:
            my[i] = np.mean(yy)
            sy[i] = np.std(yy)
    centers = bins[:-1] + (bins[1] - bins[0]) / 2
    # drop empty bins (NaN) before plotting
    valid = np.isfinite(my)
    centers, my, sy = centers[valid], my[valid], sy[valid]
    ax.plot(centers, my, '--', lw=0.5, color='black')
    ax.fill_between(centers, my + sy, my - sy, color='black', alpha=0.1)
    return my


def remove_bounds(x, y, label="firing rate"):
    """Print r and p of a linear regression of y on x without the extreme positions.

    All samples located at the minimum and at the maximum of ``x`` (including
    ties) are excluded before the regression. The inputs are not modified.

    Parameters
    ----------
    x, y : numpy.ndarray
        Positions and property values.
    label : str
        Prefix of the printed line.
    """
    ordered_x = np.sort(x)
    keep = (x != ordered_x[0]) & (x != ordered_x[-1])
    _, _, r_value, p_value, *_ = stats.linregress(x[keep], y[keep])
    print(f"{label}: removed min and max position, r: {r_value:.2f}, p: {p_value:.2f}")


def _plot_receptor_distribution_inset(ax, xPositions):
    """Draw the fish sketch with the receptor position histogram into an inset of ax; return the bin edges.

    Raises
    ------
    FileNotFoundError
        If figures/fishsketch.png (relative to the working directory) does not exist.
    """
    sketch_file = "figures/fishsketch.png"
    if not os.path.exists(sketch_file):
        raise FileNotFoundError("Missing fishsketch.png file.")
    img = mpimg.imread(sketch_file)
    img_aspect = img.shape[0] / img.shape[1]

    # inset whose width follows the aspect ratio of the sketch
    size_inch = 0.40
    inset_ax = inset_axes(ax, height=size_inch, width=size_inch / img_aspect)
    inset_ax.imshow(img, alpha=1.0, extent=[0.0, 1.0, 0.0, 1.0], aspect=img_aspect)

    # histogram of receptor positions, normalized to its maximum
    counts, posBins = np.histogram(xPositions, bins=np.linspace(0, 1, 20))
    binWidth = posBins[1] - posBins[0]
    posCenters = posBins[:-1] + binWidth / 2
    inset_ax.bar(posCenters, counts / counts.max(), color="tab:blue", alpha=0.5, width=binWidth)

    # cumulative distribution on a twin axis
    twinAx = inset_ax.twinx()
    cumCounts = np.cumsum(counts)
    cumCounts = cumCounts / np.max(cumCounts)
    twinAx.plot(posCenters, cumCounts, color="tab:red", alpha=0.6)
    twinAx.set_yticks([])
    twinAx.patch.set_visible(False)

    inset_ax.set_title('receptor distribution', fontsize=7)
    inset_ax.set_xticks(np.arange(0, 1.1, 0.2))
    inset_ax.set_xlim([0, 1.])
    inset_ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
    inset_ax.set_xticklabels([])
    inset_ax.set_yticks([])
    return posBins


def _format_baseline_xaxis(ax):
    """Fix the x-axis of a baseline panel to the relative position range [0, 1]."""
    ax.set_xlim([0, 1.0])
    ax.set_xticks(np.arange(0.0, 1.01, 0.2))
    ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)


def plot_baseline_properties(axes, df, markersize, correctionFactor):
    """Plot firing rate, CV of the ISI, burst fraction and vector strength versus receptor position.

    Parameters
    ----------
    axes : sequence of matplotlib.axes.Axes
        Four axes (panels A-D). Panel A additionally receives the receptor distribution inset.
    df : pandas.DataFrame
        Needs the columns ``receptor_pos_relative``, ``firing_rate``, ``cv``,
        ``burst_fraction`` and ``vector_strength``.
    markersize : float
        Scatter marker size.
    correctionFactor : int or float
        Bonferroni factor passed to ``plotLinregress``.
    """
    xPositions = df.receptor_pos_relative.values
    posBins = _plot_receptor_distribution_inset(axes[0], xPositions)

    # A: firing rate
    ax = axes[0]
    y_values = df.firing_rate.values
    remove_bounds(xPositions, y_values)
    plotMeanProperty(ax, posBins, xPositions, y_values)
    plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="firing rate")
    ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
    ax.set_ylabel('firing rate [Hz]', fontsize=label_size)
    # upper limit: next multiple of 200 above the maximum plus one extra 200 Hz of headroom
    maxFR = np.ceil(df.firing_rate.max() / 200) * 200 + 200
    ax.set_ylim(0, maxFR)
    ax.set_yticks(np.arange(0, maxFR + 1, 200))
    ax.set_yticks(np.arange(0, maxFR + 1, 100), minor=True)
    _format_baseline_xaxis(ax)
    ax.set_xticklabels([])
    ax.yaxis.set_label_coords(-0.15, 0.5)

    # B: CV of the interspike interval
    ax = axes[1]
    y_values = df.cv.values
    plotMeanProperty(ax, posBins, xPositions, y_values)
    plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="CV")
    ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
    ax.set_ylabel(r'CV$_{ISI}$', fontsize=label_size)
    ax.set_ylim(0, 1.5)
    ax.set_yticks(np.arange(0, 1.51, 0.5))
    ax.set_yticks(np.arange(0, 1.51, 0.25), minor=True)
    _format_baseline_xaxis(ax)
    ax.set_xticklabels([])
    ax.yaxis.set_label_coords(-0.15, 0.5)

    # C: burst fraction
    ax = axes[2]
    y_values = df.burst_fraction.values
    remove_bounds(xPositions, y_values, label="burst fraction")
    plotMeanProperty(ax, posBins, xPositions, y_values)
    plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="burst fraction")
    # move values at 1 and 0 slightly inward for display only, so the markers are not cropped by the axes
    y_values_vis = np.clip(y_values, 0.02, 0.99)
    ax.scatter(xPositions, y_values_vis, color="tab:blue", s=markersize)
    ax.set_ylabel('burst fraction', fontsize=label_size)
    ax.set_ylim(0, 1.0)
    ax.set_yticks(np.arange(0, 1.01, 0.5))
    ax.set_yticks(np.arange(0, 1.01, 0.25), minor=True)
    _format_baseline_xaxis(ax)
    ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)
    # align the y-label with the other panels
    ax.yaxis.set_label_coords(-0.15, 0.5)

    # D: vector strength
    ax = axes[3]
    y_values = df.vector_strength.values
    plotMeanProperty(ax, posBins, xPositions, y_values)
    plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="vector strength")
    ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
    ax.set_ylabel('vector strength', fontsize=label_size)
    ax.yaxis.set_label_coords(-0.15, 0.5)
    ax.set_ylim(0.7, 1.0)
    ax.set_yticks(np.arange(0.7, 1.01, 0.1))
    _format_baseline_xaxis(ax)
    ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)


def _plot_driven_panel(ax, xPositions, y_values, posBins, markersize, correctionFactor, feature):
    """Scatter one driven-response property against position, overlay fit and binned mean, set x ticks."""
    ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
    plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature=feature)
    plotMeanProperty(ax, posBins, xPositions, y_values)
    ax.set_xticks(np.arange(0, 1.1, 0.2))
    ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
    ax.yaxis.set_label_coords(-0.15, 0.5)


def plot_driven_properties(axes, df, markersize, correctionFactor):
    """Plot modulation, variability, gain cutoffs, maximum gain and mutual information versus receptor position.

    Parameters
    ----------
    axes : sequence of matplotlib.axes.Axes
        Six axes (panels A-F); only the last two get x tick labels and an x-axis label.
    df : pandas.DataFrame
        Needs the columns ``receptor_pos_relative``, ``response_modulation``,
        ``response_variability``, ``gain_cutoff_lower``, ``gain_cutoff_upper``,
        ``gain_maximum`` and ``mi``.
    markersize : float
        Scatter marker size.
    correctionFactor : int or float
        Bonferroni factor passed to ``plotLinregress``.
    """
    xPositions = df.receptor_pos_relative.values
    posBins = np.linspace(0, 1, 10)

    # A: response modulation
    ax = axes[0]
    _plot_driven_panel(ax, xPositions, df.response_modulation.values, posBins, markersize,
                       correctionFactor, "modulation")
    ax.set_xticklabels([])
    ax.set_ylim([0, 300])
    ax.set_yticks(np.arange(0, 301, 100))
    ax.set_yticks(np.arange(0, 300, 50), minor=True)
    ax.set_ylabel("modulation [Hz]", fontsize=label_size)

    # B: response variability
    ax = axes[1]
    _plot_driven_panel(ax, xPositions, df.response_variability.values, posBins, markersize,
                       correctionFactor, "variability")
    ax.set_xticklabels([])
    ax.set_ylim([0, 200])
    ax.set_yticks(np.arange(0, 201, 50))
    ax.set_yticks(np.arange(0, 200, 25), minor=True)
    ax.set_ylabel("variability [Hz]", fontsize=label_size)

    # C: lower cutoff frequency of the gain
    ax = axes[2]
    _plot_driven_panel(ax, xPositions, df.gain_cutoff_lower.values, posBins, markersize,
                       correctionFactor, "lower cutoff")
    ax.set_xticklabels([])
    ax.set_ylim([0, 100])
    ax.set_yticks(np.arange(0, 101, 20))
    ax.set_yticks(np.arange(0, 101, 10), minor=True)
    ax.set_ylabel("cutoff [Hz]", fontsize=label_size)

    # D: upper cutoff frequency of the gain
    ax = axes[3]
    _plot_driven_panel(ax, xPositions, df.gain_cutoff_upper.values, posBins, markersize,
                       correctionFactor, "upper cutoff")
    ax.set_xticklabels([])
    ax.set_ylim([0, 300])
    ax.set_yticks(np.arange(0, 301, 100))
    ax.set_yticks(np.arange(0, 300, 50), minor=True)
    ax.set_ylabel("cutoff [Hz]", fontsize=label_size)

    # E: maximum gain, divided by 1000 to convert to kHz/mV (assumes the column is stored in Hz/mV)
    ax = axes[4]
    _plot_driven_panel(ax, xPositions, df.gain_maximum.values / 1000, posBins, markersize,
                       correctionFactor, "gain")
    ax.set_ylim([0, 4])
    ax.set_yticks(np.arange(0, 4.1, 2))
    ax.set_yticks(np.arange(0, 4.1, 1), minor=True)
    ax.set_ylabel("gain [kHz/mV]", fontsize=label_size)
    ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.xaxis.set_label_coords(0.5, -0.20)

    # F: mutual information
    ax = axes[5]
    _plot_driven_panel(ax, xPositions, df.mi.values, posBins, markersize,
                       correctionFactor, "mutual info.")
    ax.set_ylim([0, 600])
    ax.set_yticks(np.arange(0, 601, 200))
    ax.set_yticks(np.arange(0, 601, 100), minor=True)
    ax.set_ylabel("mutual info. [bit/s]", fontsize=label_size)
    ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%i'))
    ax.xaxis.set_label_coords(0.5, -0.20)


def _layout_figure(nrows, ncols, figsize):
    """Create an nrows x ncols grid of axes labelled A, B, C, ... in reading order.

    ``subfig_labelsize``, ``subfig_labelweight`` and ``despine`` come from the
    ``figure_style`` star import.
    """
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.text(-0.25, 1.05, chr(ord("A") + i), fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                transform=ax.transAxes, ha="center", va="bottom")
        despine(ax, ["top", "right"], False)
    return fig, axes


def layout_figure_baseline():
    """Create the 2 x 2 baseline figure (panels A-D); return the figure and the flattened axes."""
    return _layout_figure(2, 2, (6.1, 4.25))


def layout_figure_driven():
    """Create the 3 x 2 driven-response figure (panels A-F); return the figure and the flattened axes."""
    return _layout_figure(3, 2, (6.1, 6.25))


def plot_position_correlations(args):
    """Entry point of the ``property_correlations`` sub-command: create, then show or save both figures.

    Raises
    ------
    ValueError
        If one of the two result data frames does not exist.
    """
    if not os.path.exists(args.driven_frame) or not os.path.exists(args.baseline_frame):
        raise ValueError(f"Results data frame(s) not found! ({args.driven_frame} or {args.baseline_frame})")
    driven_df = pd.read_csv(args.driven_frame, sep=";", index_col=0)
    baseline_df = pd.read_csv(args.baseline_frame, sep=";", index_col=0)

    markersize = 0.5
    fig_baseline, baseline_axes = layout_figure_baseline()
    fig_driven, driven_axes = layout_figure_driven()

    print("Baseline response:")
    plot_baseline_properties(baseline_axes, baseline_df, markersize, correctionFactor=4)
    print("Driven response:")
    # NOTE(review): plot_driven_properties runs six regressions but the Bonferroni factor is 5 — confirm intended.
    plot_driven_properties(driven_axes, driven_df, markersize, correctionFactor=5)

    fig_baseline.subplots_adjust(left=0.12, right=0.975, top=0.92, bottom=0.095, hspace=0.30, wspace=0.3)
    fig_driven.subplots_adjust(left=0.12, right=0.975, top=0.95, bottom=0.075, hspace=0.35, wspace=0.3)
    if args.nosave:
        plt.show()
    else:
        fig_baseline.savefig(args.outfile_baseline, dpi=500)
        fig_driven.savefig(args.outfile_driven, dpi=500)
        # close the figures explicitly instead of relying on pyplot's "current figure"
        plt.close(fig_baseline)
        plt.close(fig_driven)


def command_line_parser(subparsers):
    """Register the ``property_correlations`` sub-command on ``subparsers`` and return its parser."""
    default_bf = os.path.join("derived_data", "figure2_baseline_properties.csv")
    default_df = os.path.join("derived_data", "figure2_driven_properties.csv")
    parser = subparsers.add_parser("property_correlations", help=fig2_help)
    parser.add_argument("-df", "--driven_frame", default=default_df,
                        help=f"Full file name of a CSV table readable with pandas that contains the position data and coding properties (defaults to {default_df}).")
    parser.add_argument("-bf", "--baseline_frame", default=default_bf,
                        help=f"Full file name of a CSV table readable with pandas that holds the baseline properties and positions (defaults to {default_bf}).")
    parser.add_argument("-ob", "--outfile_baseline", type=str,
                        default=os.path.join("figures", "property_position_relation_baseline.pdf"),
                        help="The filename of the figure")
    parser.add_argument("-od", "--outfile_driven", type=str,
                        default=os.path.join("figures", "property_position_relation_driven.pdf"),
                        help="The filename of the figure")
    parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
    parser.set_defaults(func=plot_position_correlations)
    return parser