Browse Source

Add new phase delay figure

New phase delay figure includes both, baseline- and STA-based, illustrations of conduction delays in recorded P-units.
thladnik 1 year ago
parent
commit
17f1418815
2 changed files with 280 additions and 0 deletions
  1. 278 0
      code/plots/phase_delays.py
  2. 2 0
      plot_figures.py

+ 278 - 0
code/plots/phase_delays.py

@@ -0,0 +1,278 @@
+import os
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+from scipy.stats import stats
+from scipy.stats import t as tstat
+
+from .figure_style import subfig_labelsize, subfig_labelweight, despine
+from .supp_figure5_analysis import main as velocity_analysis
+
+
+def get_delays(df, wn_df, pos_df):
+
+    datasets = pos_df.dataset_id.unique()
+    rel_positions = []
+    positions = []
+    wn_delays = []
+    phase_delays = []
+    for dataset in datasets:
+        rel_pos = pos_df.receptor_pos_relative[pos_df.dataset_id == dataset].values[0]
+        pos = pos_df.receptor_pos_absolute[pos_df.dataset_id == dataset].values[0]
+        trials = wn_df[wn_df.dataset_id == dataset]
+        if len(trials) < 1:
+            continue
+
+        # Get delay based on whitenoise STAs
+        wn_delay = np.mean(trials.delay)
+        wn_delays.append(wn_delay)
+
+        # Get delay based on shifted EOD phase lock and period
+        phase = df[df.dataset_id == dataset].phase_shifted.values[0]
+        eod_period = df[df.dataset_id == dataset].eod_period.values[0]
+        phase_delay = phase * eod_period  / (2 * np.pi)
+        phase_delays.append(phase_delay)
+        rel_positions.append(rel_pos)
+        positions.append(pos)
+
+    wn_delays = np.array(wn_delays)
+    phase_delays = np.array(phase_delays)
+    rel_positions = np.array(rel_positions)
+    positions = np.array(positions)
+
+    return rel_positions, positions, phase_delays, wn_delays
+
+
+def plot_phase_data(positions, phases, axis, color, show_centroids=False):
+    axis.scatter(positions, phases, color=color, s=10, marker=".")
+    if show_centroids:
+        centroid_x = np.mean(positions)
+        centroid_y = np.mean(phases)
+        axis.scatter(centroid_x, centroid_y, marker="+", s=15, color=color, linewidth=.8)
+
+
+def plot_phases_raw(df, axis):
+    colors = ["tab:blue", "tab:orange"]
+    cluster_labels = [0, 1]
+
+    for label, color in zip(cluster_labels, colors):
+        selection_phases = df.phase[df.kmeans_label == label]
+        selection_positions = df.receptor_pos_relative[df.kmeans_label == label]
+        plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=True)
+
+    axis.set_xlim([0, 1.0])
+    axis.set_xlabel("receptor position [rel.]")
+    axis.set_xticks(np.arange(0, 1.01, 0.05), minor=True)
+
+    axis.set_ylim([0, 2*np.pi])
+    axis.set_ylabel("phase [rad]")
+    axis.set_yticks(np.arange(0, 2 * np.pi + 0.1, np.pi))
+    axis.set_yticks(np.arange(0, 2 * np.pi + 0.1, np.pi/4), minor=True)
+    axis.set_yticklabels([r"$0$", r"$\pi$", r"$2\pi$"])
+
+
+def plot_phases_shifted(df, axis):
+    colors = ["tab:blue", "tab:orange"]
+    cluster_labels = [0, 1]
+    for label, color in zip(cluster_labels, colors):
+        selection_phases = df.phase_shifted[df.kmeans_label == label]
+        selection_positions = df.receptor_pos_relative[df.kmeans_label == label]
+        plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=False)
+
+    slope, intercept, *params = stats.linregress(df.receptor_pos_relative, df.phase_shifted)
+    x_range = np.arange(df.receptor_pos_relative.min(), df.receptor_pos_relative.max(), 0.1)
+    y_fit = intercept + slope * x_range
+    axis.plot(x_range, y_fit, color="black", ls="-",
+              label=f"r:{params[0]:.2f}, p:{params[1]:.3f}")
+
+    axis.legend(loc='lower right')
+    axis.set_xlim([0, 1.0])
+    axis.set_xlabel("receptor position [rel.]")
+    axis.set_xticks(np.arange(0, 1.01, 0.05), minor=True)
+
+    axis.set_ylim([0, 4*np.pi])
+    axis.set_ylabel("phase + " + r"$2\pi$" + " [rad]")
+    axis.set_yticks(np.arange(0, 4 * np.pi + 0.1, np.pi))
+    axis.set_yticks(np.arange(0, 4 * np.pi + 0.1, np.pi/4), minor=True)
+    axis.set_yticklabels([r"$0$", r"$\pi$", r"$2\pi$", r"$3\pi$", r"$4\pi$"])
+
+
+def plot_baseline_delay(df, axis):
+    colors = ["tab:blue", "tab:orange"]
+    cluster_labels = [0, 1]
+    for label, color in zip(cluster_labels, colors):
+        selection_phases = df.phase_time[df.kmeans_label == label] * 1000
+        selection_positions = df.receptor_pos_absolute[df.kmeans_label == label]
+        plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=False)
+
+    # Fit
+    slope, intercept, *params = stats.linregress(df.receptor_pos_absolute, df.phase_time * 1000)
+    x_range = np.arange(df.receptor_pos_absolute.min(), df.receptor_pos_absolute.max(), 0.1)
+    y_fit = intercept + slope * x_range
+    axis.plot(x_range, y_fit, color="black", ls="-",
+              label=f"slope:{1./slope:.1f}m/s, r:{params[0]:.2f}, p:{params[1]:.3f}")
+
+    axis.legend(loc='lower right')
+    axis.set_xlim(20, 120)
+    axis.set_xlabel("receptor position [mm]")
+
+    axis.set_ylim([0, 3])
+    axis.set_ylabel("phase delay [ms]")
+
+
+def plot_ci(ax, x, y, intercept, slope):
+    """After Tomas Holderness' implementation:
+        https://tomholderness.wordpress.com/2013/01/10/confidence_intervals/"""
+    # linfit.py - example of confidence limit calculation for linear regression fitting.
+
+    # References:
+    # - Statistics in Geography by David Ebdon (ISBN: 978-0631136880)
+    # - Reliability Engineering Resource Website:
+    # - http://www.weibull.com/DOEWeb/confidence_intervals_in_simple_linear_regression.htm
+    # - University of Glascow, Department of Statistics:
+    # - http://www.stats.gla.ac.uk/steps/glossary/confidence_intervals.html#conflim
+
+    # fit a curve to the data using a least squares 1st order polynomial fit
+    fit = intercept + slope * x
+
+    # predict y values of origional data using the fit
+    p_y = slope * x + intercept
+
+    # calculate the y-error (residuals)
+    y_err = y - p_y
+
+    # create series of new test x-values to predict for
+    p_x = np.arange(np.min(x), np.max(x) + 1, 1)
+
+    # now calculate confidence intervals for new test x-series
+    mean_x = np.mean(x)  # mean of x
+    n = len(x)  # number of samples in origional fit
+    t = tstat.ppf(1. - 0.1 / 2, df=n - 1)  # appropriate t value
+    s_err = np.sum(np.power(y_err, 2))  # sum of the squares of the residuals
+
+    confs = t * np.sqrt((s_err / (n - 2)) * (1.0 / n + (np.power((p_x - mean_x), 2) /
+                                                        ((np.sum(np.power(x, 2))) - n * (np.power(mean_x, 2))))))
+
+    # now predict y based on test x-values
+    p_y = slope * p_x + intercept
+
+    # get lower and upper confidence limits based on predicted y and confidence intervals
+    lower = p_y - abs(confs)
+    upper = p_y + abs(confs)
+
+    # plot confidence limits
+    ax.plot(p_x, lower, linestyle='--', color='gray', linewidth=1.2)
+    ax.plot(p_x, upper, linestyle='--', color='gray', linewidth=1.2)
+
+
+def plot_sta_delay(df, wn_df, pos_df, axis):
+
+    rel_pos, pos, phase_delays, wn_delays = get_delays(df, wn_df, pos_df)
+
+    axis.scatter(pos, wn_delays * 1000, s=10, marker=".")
+
+    # Plot linear regression
+    slope, intercept, *params = stats.linregress(pos, wn_delays * 1000)
+    x_range = np.arange(pos.min(), pos.max(), 0.1)
+    y_fit = intercept + slope * x_range
+    axis.plot(x_range, y_fit, color="black", ls="-",
+              label=f"slope:{1./slope:.1f}m/s, r:{params[0]:.2f}, p:{params[1]:.3f}")
+
+    # Plot CI
+    plot_ci(axis, pos, wn_delays, intercept, slope)
+
+    axis.legend(loc='lower right')
+    axis.set_xlim(20, 120)
+    axis.set_xlabel("receptor position [mm]")
+
+    axis.set_ylim(0, 8)
+    axis.set_ylabel("STA delay [ms]")
+
+
+def plot_both_delays(df, wn_df, pos_df, axis):
+
+    rel_pos, pos, phase_delays, wn_delays = get_delays(df, wn_df, pos_df)
+
+    axis.scatter(phase_delays * 1000, wn_delays * 1000, s=10, marker=".")
+
+    slope, intercept, *params = stats.linregress(phase_delays * 1000, wn_delays * 1000)
+    x_range = np.arange(phase_delays.min() * 1000, phase_delays.max() * 1000, 0.1)
+    y_fit = intercept + slope * x_range
+    axis.plot(x_range, y_fit, color="black", ls="-",
+              label=f"slope:{slope:.1f}, r:{params[0]:.2f}, p:{params[1]:.3f}")
+
+    axis.legend(loc='lower right')
+    axis.set_ylim(np.floor(wn_delays.min() * 1000), np.ceil(wn_delays.max() * 1000))
+    axis.set_ylabel("STA delay [ms]")
+    axis.set_xlabel('phase delay [ms]')
+    axis.set_xlim(np.floor(phase_delays.min() * 1000), np.ceil(phase_delays.max() * 1000))
+
+
+def layout_figure():
+    gs = plt.GridSpec(3, 2)
+    fig = plt.figure(figsize=(6.5, 4.5))
+
+    axes = []
+    axes.append(fig.add_subplot(gs[0, 0]))
+    axes.append(fig.add_subplot(gs[1, 0]))
+    axes.append(fig.add_subplot(gs[2, 0]))
+    axes.append(fig.add_subplot(gs[0, 1]))
+    axes.append(fig.add_subplot(gs[1:, 1]))
+
+    axes[0].text(-.15, 1.08, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
+                    transform=axes[0].transAxes)
+    axes[3].text(-.15, 1.08, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
+                    transform=axes[3].transAxes)
+    axes[4].text(-.15, 1.08, "C", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
+                    transform=axes[4].transAxes)
+    despine(axes[0], ["top", "right"], False)
+    despine(axes[1], ["top", "right"], False)
+    despine(axes[2], ["top", "right"], False)
+    despine(axes[3], ["top", "right"], False)
+    despine(axes[4], ["top", "right"], False)
+
+    fig.subplots_adjust(left=0.07, top=0.94, bottom=0.10, right=0.98, hspace=0.5, wspace=0.3)
+    return fig, axes
+
+
+def set_aspect_quad(ax):
+    ylim = ax.get_ylim()
+    xlim = ax.get_xlim()
+    ax.set_aspect((xlim[1] - xlim[0]) / (ylim[1] - ylim[0]))
+
+
+def phase_analysis(args):
+    if not os.path.exists(args.baseline_data_frame):
+        raise ValueError(f"Baseline data could not be found! ({args.baseline_data_frame})")
+    df = pd.read_csv(args.baseline_data_frame, sep=";", index_col=0)
+    wn_df = pd.read_csv(args.whitenoise_data_frame, sep=";", index_col=0)
+    pos_df = pd.read_csv(args.position_data_frame, sep=";", index_col=0)
+    if args.redo:
+        velocity_analysis()
+
+    fig, axes = layout_figure()
+
+    plot_phases_raw(df, axes[0])
+    plot_phases_shifted(df, axes[1])
+    plot_baseline_delay(df, axes[2])
+    plot_sta_delay(df, wn_df, pos_df, axes[3])
+    plot_both_delays(df, wn_df, pos_df, axes[4])
+    # set_aspect_quad(axes[4])
+    # fig.tight_layout()
+    if args.nosave:
+        plt.show()
+    else:
+        fig.savefig(args.outfile, dpi=500)
+        plt.close()
+
+
+def command_line_parser(subparsers):
+    parser = subparsers.add_parser("phase_delays", help="Phase delays figure: Plots illustration conduction delays based on phase lock shift and whitenoise STA")
+    parser.add_argument("-bdf", "--baseline_data_frame", default=os.path.join("derived_data","figure2_baseline_properties.csv"))
+    parser.add_argument("-wndf", "--whitenoise_data_frame", default=os.path.join("derived_data","whitenoise_trials.csv"))
+    parser.add_argument("-posdf", "--position_data_frame", default=os.path.join("derived_data","receptivefield_positions.csv"))
+    parser.add_argument("-r", "--redo", action="store_true", help="Redo the velocity analysis. Depends on figure2_baseline_properties.csv")
+    parser.add_argument("-vel", "--velocity_data", default=os.path.join("derived_data", "suppfig5_velocities.npz"))
+    parser.add_argument("-o", "--outfile", default=os.path.join("figures", "phase_delays.pdf"))
+    parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing") 
+    parser.set_defaults(func=phase_analysis)

+ 2 - 0
plot_figures.py

@@ -11,6 +11,7 @@ from code.plots.supp_figure5 import command_line_parser as supfig5_parser
 from code.plots.intro_figure import command_line_parser as introfig_parser
 from code.plots.populations_method import command_line_parser as popmethods_parser
 from code.plots.intro_figure2 import command_line_parser as introfig2_parser
+from code.plots.phase_delays import command_line_parser as phase_delays_parser
 
 def create_parser():
     parser = argparse.ArgumentParser(description="Tool for plotting figures of the Hladnik & Grewe population coding project.")
@@ -20,6 +21,7 @@ def create_parser():
 
     introfig_parser(subparsers)
     introfig2_parser(subparsers)
+    phase_delays_parser(subparsers)
     analysis_parser(subparsers)
     correlations_parser(subparsers)
     popcoding_parser(subparsers)