Browse Source

Split property_correlations into two separate figures for baseline and driven correlations;

Tim 9 months ago
parent
commit
366f77b638
1 changed files with 109 additions and 76 deletions
  1. 109 76
      code/plots/property_correlations.py

+ 109 - 76
code/plots/property_correlations.py

@@ -10,6 +10,7 @@ import matplotlib.image as mpimg
 import matplotlib.gridspec as gridspec
 
 from matplotlib.ticker import FormatStrFormatter
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes
 
 from .figure_style import *
 label_size = 8
@@ -57,7 +58,7 @@ def plotLinregress(ax, x, y, bonferroni_factor, color="tab:red", label="", featu
         p_star = 1
     l = r'%s r: %.2f, %s: %.2f' % ("%s, " % label if len(label) > 0 else "", params[0], "p" if bonferroni_factor == 1 else r"p$^*$", params[1] if bonferroni_factor == 1 else p_star)
     ax.plot(xRange, intercept + slope * xRange, color=color, label=l, ls="-")
-    ax.legend(loc='upper left', frameon=False, ncol=1,fontsize=6, bbox_to_anchor=(0., 1.3))
+    ax.legend(loc='upper left', frameon=False, ncol=1, fontsize=6, bbox_to_anchor=(0.0, 1.15))
     print("%s: minimum: %.2f, maximum: %.2f, mean: %.2f, std: %.2f" % (feature, np.min(y), np.max(y), np.mean(y), np.std(y)))
     return slope, intercept
 
@@ -75,6 +76,11 @@ def plotMeanProperty(ax, bins, x, y):
             sy[i] = np.nan
     centers = bins[:-1]+(bins[1]-bins[0])/2
 
+    # Filter out nans
+    centers = centers[np.isfinite(my)]
+    sy = sy[np.isfinite(my)]
+    my = my[np.isfinite(my)]
+
     ax.plot(centers, my, '--', lw=0.5, color='black')
     ax.fill_between(centers, my+sy, my-sy, color='black', alpha=0.1)
 
@@ -97,38 +103,49 @@ def remove_bounds(x, y, label="firing rate"):
 def plot_baseline_properties(axes, df, markersize, correctionFactor):
     xPositions = df.receptor_pos_relative.values
 
-    posBins = np.linspace(0, 1, 10)
-    counts, posBins, _ = axes[0].hist(xPositions, np.linspace(0, 1, 10), color="tab:blue", histtype='stepfilled', alpha=0.75)
-    posCenters = posBins[:-1]+(posBins[1]-posBins[0])/2
-    twinAx = axes[0].twinx()
+    if os.path.exists("figures/fishsketch.png"):
+        img = mpimg.imread("figures/fishsketch.png")
+    else:
+        raise FileNotFoundError("Missing fishsketch.png file.")
+
+    img_aspect = img.shape[0]/img.shape[1]
+
+    # Create receptor location inset
+    size_inch = 0.40
+    inset_ax = inset_axes(axes[0], height=size_inch, width=size_inch/img_aspect)
+    # Plot fish background
+    inset_ax.imshow(img, alpha=1.0, extent=[0.0, 1.0, 0.0, 1.0], aspect=img_aspect)
+    # Plot histogram
+    counts, posBins = np.histogram(xPositions, bins=np.linspace(0, 1, 20))
+    binWidth = posBins[1]-posBins[0]
+    posCenters = posBins[:-1]+binWidth/2
+    inset_ax.bar(posCenters, counts / counts.max(), color="tab:blue", alpha=0.5, width=binWidth)
+    # Plot cumsum
+    twinAx = inset_ax.twinx()
     cumCounts = np.cumsum(counts)
-    cumCounts /= cumCounts.max()
-    twinAx.plot(posCenters, cumCounts, color="tab:red", alpha=0.75)
+    cumCounts = cumCounts / np.max(cumCounts)
+    twinAx.plot(posCenters, cumCounts, color="tab:red", alpha=0.6)
     twinAx.set_yticks([])
-    twinAx.set_ylim([0, 1.5])
     twinAx.patch.set_visible(False)
-    despine(twinAx, ["top","right", "bottom", "left"], True)
 
-    ax = axes[0]
-    ax.set_xticks(np.arange(0, 1.1, 0.2))
-    ax.set_xlim([0, 1.])
-    ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
-    ax.set_xticklabels([])
-    ax.set_ylabel("# of cells", fontsize=label_size)
-    ax.set_ylim([0, 60])
-    ax.set_yticks([0, 20, 40])
-    ax.set_yticks([0, 10, 20, 30, 40], minor=True)
-    ax.yaxis.set_label_coords(-0.15, 0.5)
+    # Format
+    inset_ax.set_title('receptor distribution', fontsize=7)
+    inset_ax.set_xticks(np.arange(0, 1.1, 0.2))
+    inset_ax.set_xlim([0, 1.])
+    inset_ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
+    inset_ax.set_xticklabels([])
+    # inset_ax.set_ylabel("# cells", fontsize=label_size)
+    inset_ax.set_yticks([])
 
     # firing rate
     y_values = df.firing_rate.values
     remove_bounds(xPositions.copy(), y_values.copy())
-    ax = axes[1]
+    ax = axes[0]
     plotMeanProperty(ax, posBins, xPositions, y_values)
     plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="firing rate")
     ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
     ax.set_ylabel('firing rate [Hz]', fontsize=label_size)
-    maxFR = np.ceil(df.firing_rate.max()/200)*200
+    maxFR = np.ceil(df.firing_rate.max()/200)*200 + 200
     ax.set_ylim(0, maxFR)
     ax.set_yticks(np.arange(0, maxFR+1, 200))
     ax.set_yticks(np.arange(0, maxFR+1, 100), minor=True)
@@ -141,7 +158,7 @@ def plot_baseline_properties(axes, df, markersize, correctionFactor):
 
     # CV of the interspike interval
     y_values = df.cv.values
-    ax = axes[2]
+    ax = axes[1]
     plotMeanProperty(ax, posBins, xPositions, y_values)
     plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="CV")
     ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
@@ -157,11 +174,15 @@ def plot_baseline_properties(axes, df, markersize, correctionFactor):
 
     # burst fraction
     y_values = df.burst_fraction.values
-    ax = axes[3]
+    ax = axes[2]
     remove_bounds(xPositions.copy(), y_values.copy(), label="burst fraction")
     plotMeanProperty(ax, posBins, xPositions, y_values)
     plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="burst fraction")
-    ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
+    # Move values at 1 and 0 a bit inward for visualization, so they don't get cropped by axes
+    y_values_vis = np.copy(y_values)
+    y_values_vis[y_values > 0.99] = 0.99
+    y_values_vis[y_values < 0.02] = 0.02
+    ax.scatter(xPositions, y_values_vis, color="tab:blue", s=markersize)
     ax.set_ylabel('burst fraction', fontsize=label_size)
     ax.set_ylim(0, 1.0)
     ax.set_yticks(np.arange(0, 1.01, 0.5))
@@ -169,25 +190,23 @@ def plot_baseline_properties(axes, df, markersize, correctionFactor):
     ax.set_xlim([0, 1.0])
     ax.set_xticks(np.arange(0.0, 1.01, 0.2))
     ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
-    ax.set_xticklabels([])
-    ax.yaxis.set_label_coords(-0.15, 0.5)
+    ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
 
     # Vector strength
     y_values = df.vector_strength.values
-    ax = axes[4]
+    ax = axes[3]
     plotMeanProperty(ax, posBins, xPositions, y_values)
     plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="vector strength")
     ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
     ax.set_ylabel('vector strength', fontsize=label_size)
     ax.yaxis.set_label_coords(-0.15, 0.5)
-    ax.set_ylim(0.6, 1.0)
-    ax.set_yticks(np.arange(0.6, 1.01, 0.2))
-    ax.set_yticks(np.arange(0.6, 1.01, 0.1), minor=True)
+    ax.set_ylim(0.7, 1.0)
+    ax.set_yticks(np.arange(0.7, 1.01, 0.1))
+    # ax.set_yticks(np.arange(0.6, 1.01, 0.1), minor=True)
     ax.set_xlim([0, 1.0])
     ax.set_xticks(np.arange(0.0, 1.01, 0.2))
     ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
     ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
-    ax.xaxis.set_label_coords(0.5, -0.35)
 
 
 def plot_driven_properties(axes, df, markersize, correctionFactor):
@@ -218,24 +237,35 @@ def plot_driven_properties(axes, df, markersize, correctionFactor):
     ax.set_xticks(np.arange(0, 1.1, 0.2))
     ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
     ax.set_xticklabels([])
-    ax.set_ylim([0, 300])
-    ax.set_yticks(np.arange(0, 301, 100))
-    ax.set_yticks(np.arange(0, 300, 50), minor=True)
+    ax.set_ylim([0, 200])
+    ax.set_yticks(np.arange(0, 201, 50))
+    ax.set_yticks(np.arange(0, 200, 25), minor=True)
     ax.set_ylabel("variability [Hz]", fontsize=label_size)
     ax.yaxis.set_label_coords(-0.15, 0.5)
 
-    # lower and upper cutoff
+    # lower cutoff
     ax = axes[2]
     ax.scatter(xPositions, df.gain_cutoff_lower.values, color="tab:blue", s=markersize)
-    plotLinregress(ax, xPositions, df.gain_cutoff_lower.values, correctionFactor, color="tab:blue", label="lower", feature="lower cutoff")
-    plotMeanProperty(ax, posBins, xPositions, df.gain_cutoff_lower.values)   
-    ax.scatter(xPositions, df.gain_cutoff_upper.values, color="tab:red", s=markersize)
-    plotLinregress(ax, xPositions, df.gain_cutoff_upper.values, correctionFactor, color="tab:red", label="upper", feature="upper cutoff")
+    plotLinregress(ax, xPositions, df.gain_cutoff_lower.values, correctionFactor, color="tab:red", feature="lower cutoff")
+    plotMeanProperty(ax, posBins, xPositions, df.gain_cutoff_lower.values)
+    ax.set_xticks(np.arange(0, 1.1, 0.2))
+    ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
+    ax.set_xticklabels([])
+    ax.set_ylim([0, 100])
+    ax.set_yticks(np.arange(0, 101, 20))
+    ax.set_yticks(np.arange(0, 101, 10), minor=True)
+    ax.set_ylabel("cutoff [Hz]", fontsize=label_size)
+    ax.yaxis.set_label_coords(-0.15, 0.5)
+
+    # upper cutoff
+    ax = axes[3]
+    ax.scatter(xPositions, df.gain_cutoff_upper.values, color="tab:blue", s=markersize)
+    plotLinregress(ax, xPositions, df.gain_cutoff_upper.values, correctionFactor, color="tab:red", feature="upper cutoff")
     plotMeanProperty(ax, posBins, xPositions, df.gain_cutoff_upper.values)
     ax.set_xticks(np.arange(0, 1.1, 0.2))
     ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
     ax.set_xticklabels([])
-    ax.set_ylim([0, 325])
+    ax.set_ylim([0, 300])
     ax.set_yticks(np.arange(0, 301, 100))
     ax.set_yticks(np.arange(0, 300, 50), minor=True)
     ax.set_ylabel("cutoff [Hz]", fontsize=label_size)
@@ -243,23 +273,25 @@ def plot_driven_properties(axes, df, markersize, correctionFactor):
 
     # maximum gain
     y_values = df.gain_maximum.values / 1000   # convert ot kHz/mV
-    ax = axes[3]
+    ax = axes[4]
     ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
     plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature="gain")
     plotMeanProperty(ax, posBins, xPositions, y_values)
     ax.set_xticks(np.arange(0, 1.1, 0.2))
     ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
-    ax.set_xticklabels([])
+    # ax.set_xticklabels([])
     ax.set_ylim([0, 4])
     ax.set_yticks(np.arange(0, 4.1, 2))
     ax.set_yticks(np.arange(0, 4.1, 1), minor=True)
     ax.set_ylabel("gain [kHz/mV]", fontsize=label_size)
+    ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
     ax.yaxis.set_label_coords(-0.15, 0.5)
     ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
+    ax.xaxis.set_label_coords(0.5, -0.20)
 
     # mututal information
     y_values = df.mi.values
-    ax = axes[4]
+    ax = axes[5]
     ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
     plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature="mutual info.")
     plotMeanProperty(ax, posBins, xPositions, y_values)
@@ -272,43 +304,38 @@ def plot_driven_properties(axes, df, markersize, correctionFactor):
     ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
     ax.yaxis.set_label_coords(-0.15, 0.5)
     ax.yaxis.set_major_formatter(FormatStrFormatter('%i'))
-    ax.xaxis.set_label_coords(0.5, -0.35)
+    ax.xaxis.set_label_coords(0.5, -0.20)
 
 
-def layout_figure():
-    baseline_subfig_labels = ["A", "B", "C", "D", "E"]
-    driven_subfig_labels = ["F", "G", "H", "I", "J"]
-    subplots_per_col = max([len(baseline_subfig_labels), len(driven_subfig_labels)])
+def layout_figure_baseline():
+    subfig_labels = ["A", "B", "C", "D"]
 
-    driven_axes = list(range(len(baseline_subfig_labels)))
-    baseline_axes = list(range(len(driven_subfig_labels)))
+    fig, axes = plt.subplots(2, 2, figsize=(6.1, 4.25))
+    axes = axes.flatten()
 
-    fig = plt.figure(figsize=(5.1, 5.75))
-    subfigs = fig.subfigures(1, 2, wspace=0.05)
-    gr = gridspec.GridSpec(subplots_per_col, 1, hspace=0.5, left=0.225)
+    for i, ax in enumerate(axes):
 
-    for i in range(subplots_per_col):
-        baseline_label = baseline_subfig_labels[i]
-        driven_label = driven_subfig_labels[i]
-        baseline_axes[i] = subfigs[0].add_subplot(gr[i, 0])
-        driven_axes[i] = subfigs[1].add_subplot(gr[i, 0])
+        ax.text(-0.25, 1.05, subfig_labels[i], fontsize=subfig_labelsize, fontweight=subfig_labelweight,
+                              transform=ax.transAxes, ha="center", va="bottom")
+        despine(ax, ["top", "right"], False)
 
-        baseline_axes[i].text(-0.25, 1.05, baseline_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight, 
-                              transform=baseline_axes[i].transAxes, ha="center", va="bottom")
-        driven_axes[i].text(-0.25, 1.05, driven_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight, 
-                              transform=driven_axes[i].transAxes, ha="center", va="bottom")
-        despine(baseline_axes[i], ["top", "right"], False)
-        despine(driven_axes[i], ["top", "right"], False)
+    return fig, axes
 
-    pic_ax = subfigs[0].add_axes((0.23, 0.88, 0.75, 0.15))
-    if os.path.exists("figures/fishsketch.png"):
-        img = mpimg.imread("figures/fishsketch.png")
-        pic_ax.imshow(img)
-        despine(pic_ax, ["top", "bottom", "left", "right"], True)
-    else:
-        raise FileNotFoundError("Missing fishsketch.png file.")
 
-    return fig, baseline_axes, driven_axes
+def layout_figure_driven():
+
+    subfig_labels = ["A", "B", "C", "D", "E", "F"]
+
+    fig, axes = plt.subplots(3, 2, figsize=(6.1, 6.25))
+    axes = axes.flatten()
+
+    for i, ax in enumerate(axes):
+
+        ax.text(-0.25, 1.05, subfig_labels[i], fontsize=subfig_labelsize, fontweight=subfig_labelweight,
+                              transform=ax.transAxes, ha="center", va="bottom")
+        despine(ax, ["top", "right"], False)
+
+    return fig, axes
 
 
 def plot_position_correlations(args):
@@ -318,18 +345,21 @@ def plot_position_correlations(args):
     driven_df = pd.read_csv(args.driven_frame, sep=";", index_col=0)
     baseline_df = pd.read_csv(args.baseline_frame, sep=";", index_col=0)
     markersize = 0.5
-    fig, baseline_axes, driven_axes = layout_figure()
+    fig_baseline, baseline_axes = layout_figure_baseline()
+    fig_driven, driven_axes = layout_figure_driven()
 
     print("Baseline response:")
     plot_baseline_properties(baseline_axes, baseline_df, markersize, correctionFactor=4)
     print("Driven response:")
     plot_driven_properties(driven_axes, driven_df, markersize, correctionFactor=5)
 
-    fig.subplots_adjust(left=0.2, right=0.975, top=0.95, bottom=0.075)
+    fig_baseline.subplots_adjust(left=0.12, right=0.975, top=0.92, bottom=0.095, hspace=0.30, wspace=0.3)
+    fig_driven.subplots_adjust(left=0.12, right=0.975, top=0.95, bottom=0.075, hspace=0.35, wspace=0.3)
     if args.nosave:
         plt.show()
     else:
-        fig.savefig(args.outfile, dpi=500)
+        fig_baseline.savefig(args.outfile_baseline, dpi=500)
+        fig_driven.savefig(args.outfile_driven, dpi=500)
         plt.close()
     plt.close()
 
@@ -343,7 +373,10 @@ def command_line_parser(subparsers):
                         help=f"Full file name of a CSV table readable with pandas that contains the position data and coding properties ({default_df})")
     parser.add_argument("-bf", "--baseline_frame", default=default_bf,
                         help=f"Full file name of a  CSV table readable with pandas that holds the baseline properties and positions (defaults to {default_bf}).")
-    parser.add_argument("-o", "--outfile", type=str, default=os.path.join("figures", "property_position_relation.pdf"),
+    parser.add_argument("-ob", "--outfile_baseline", type=str, default=os.path.join("figures", "property_position_relation_baseline.pdf"),
+                        help="The filename of the figure")
+    parser.add_argument("-od", "--outfile_driven", type=str,
+                        default=os.path.join("figures", "property_position_relation_driven.pdf"),
                         help="The filename of the figure")
     parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing") 
     parser.set_defaults(func=plot_position_correlations)