123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- import os
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- from .figure_style import subfig_labelsize, subfig_labelweight, despine
- from .property_correlations import plotLinregress
- from .supp_figure5_analysis import main as velocity_analysis
class Colorspace:
    """Map scalar values onto the colour table of a matplotlib colormap.

    The colormap's ``N`` colour entries are paired with ``N`` evenly spaced
    values from ``valMin`` to ``valMax``.  A value is mapped to the colour of
    the first table entry whose associated value is >= it (``searchsorted``
    semantics); values outside the range print a warning and are clamped to
    the nearest end of the table.

    Attributes
    ----------
    cMap : matplotlib.colors.Colormap
        Colormap returned by ``plt.get_cmap(schema)``.
    colors : numpy.ndarray
        Colour table, one row per colormap entry.
    values : numpy.ndarray
        Increasing values associated with the rows of ``colors``.
    """

    def __init__(self, valMin=0, valMax=1, schema='viridis'):
        self.cMap = plt.get_cmap(schema)
        colors = getattr(self.cMap, 'colors', None)
        if colors is None:
            # Only ListedColormap exposes ``.colors``.  For other colormaps
            # (e.g. LinearSegmentedColormap such as 'jet') read every lookup
            # table entry by integer index and keep the RGB channels.
            colors = self.cMap(np.arange(self.cMap.N))[:, :3]
        self.colors = np.asarray(colors)
        self.values = np.linspace(valMin, valMax, self.cMap.N)

    def getColor(self, val):
        """Return the colour(s) for ``val``.

        Parameters
        ----------
        val : int, float, numpy scalar, list, tuple or numpy.ndarray
            A single value or a sequence of values.

        Returns
        -------
        A row of ``colors`` for a scalar, a list with one entry per element for
        a sequence (an empty list for an empty sequence), or ``None`` for any
        other input type.
        """
        if isinstance(val, (int, float, np.integer, np.floating)):
            if val < self.values[0] or val > self.values[-1]:
                print('WARNING: exceeding valid range')
            # searchsorted yields len(values) for val > values[-1]; clamp so an
            # out-of-range value maps to the last colour instead of raising
            # IndexError.
            idx = min(np.searchsorted(self.values, val), len(self.values) - 1)
            return self.colors[idx]
        if isinstance(val, (list, tuple, np.ndarray)):
            # Iterate rather than recurse on head/tail: no recursion-depth
            # limit, no quadratic slicing, and empty input returns [].
            return [self.getColor(v) for v in val]
        return None

    def getLinspacedColors(self, num=None):
        """Return ``num`` colours for values evenly spaced over the full range.

        ``num`` defaults to the number of entries in the colour table.
        """
        if num is None:
            num = self.colors.shape[0]
        return self.getColor(np.linspace(self.values[0], self.values[-1], num))

    def plotColorbar(self, ax, vertical=True):
        """Draw the colour scale as an image on ``ax``.

        With ``vertical=True`` the values run along the y axis (extent bottom is
        ``values[-1]``, top is ``values[0]``) and x ticks are removed; otherwise
        they run along the x axis and y ticks are removed.
        """
        y = np.asarray([self.values, self.values])
        if vertical:
            y = y.T
            extent = [0, 1, self.values[-1], self.values[0]]
        else:
            extent = [self.values[0], self.values[-1], 0, 1]

        ax.imshow(
            y,
            cmap=self.cMap,
            aspect='auto',
            extent=extent
        )
        if vertical:
            ax.set_xticks([])
        else:
            ax.set_yticks([])
def plot_velocities(args, axis):
    """Scatter velocity estimates against position difference on ``axis``.

    Reads the ``.npz`` archive at ``args.velocity_data`` and draws:
    one scatter series per row of ``position_differences`` (finite entries
    only, paired with the same row of ``velocities``), the
    ``average_velocities`` curve over ``position_difference_centers``, and a
    dashed horizontal reference line at 48.3 m/s labelled ``v_m``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``velocity_data``, the path of the ``.npz`` file.
    axis : matplotlib.axes.Axes
        Axes to draw on (log-scaled y axis, fixed limits).

    Raises
    ------
    ValueError
        If ``args.velocity_data`` does not exist.
    """
    if not os.path.exists(args.velocity_data):
        raise ValueError(f"Velocity data file not found! {args.velocity_data}")
    # NpzFile keeps the archive open until closed; the context manager releases
    # the file handle once the arrays have been read.
    with np.load(args.velocity_data) as data:
        position_differences = data["position_differences"]
        velocities = data["velocities"]
        average_velocities = data["average_velocities"]
        position_difference_centers = data["position_difference_centers"]

    reference_velocity = 48.3  # m/s, drawn as the dashed v_m line
    n_rows = position_differences.shape[0]
    cspace = Colorspace()
    # Keep the 10-step palette for up to 10 rows, but never provide fewer
    # colours than rows (indexing past 10 colours raised IndexError before).
    velcolors = cspace.getLinspacedColors(max(10, n_rows))
    for i, row_diffs in enumerate(position_differences):
        finite = np.isfinite(row_diffs)
        axis.scatter(row_diffs[finite], velocities[i, finite],
                     s=5, color=velcolors[i])
    axis.axhline(reference_velocity, linestyle='--', color='black', linewidth=1,
                 label=rf'$v_{{m}} =$ {reference_velocity:.1f} $\frac{{m}}{{s}}$')
    axis.plot(position_difference_centers, average_velocities, color='black', linewidth=1)
    axis.set_yscale("log")
    axis.set_ylim([5, 3000])
    axis.set_xlim([0, 100])
    # raw string: '\D' is not a valid escape sequence in a normal string
    axis.set_xlabel(r'$\Delta$position [mm]')
    axis.set_ylabel('Velocity [m/s]')
    axis.legend(frameon=False, loc='upper right')
    axis.set_xticks(np.arange(0, 101, 5), minor=True)
def plot_phase_data(positions, phases, axis, color, show_centroids=False):
    """Scatter ``phases`` against ``positions`` on ``axis`` in one colour.

    When ``show_centroids`` is true, the mean position / mean phase point is
    additionally marked with a '+' in the same colour.
    """
    axis.scatter(positions, phases, marker=".", s=10, color=color)
    if not show_centroids:
        return
    mean_pos, mean_phase = np.mean(positions), np.mean(phases)
    axis.scatter(mean_pos, mean_phase, color=color, s=15, marker="+")
def plot_phases_raw(df, axis):
    """Plot the raw phase of k-means clusters 0 and 1 versus receptor position.

    Each cluster is drawn with its centroid marked.  Reads the columns
    ``phase``, ``receptor_pos_relative`` and ``kmeans_label`` of ``df``.
    The x range is [0, 1] with tick labels hidden (NOTE(review): presumably
    the panel below carries the shared x label), and the y axis spans
    [0, 2*pi] with ticks labelled 0, pi, 2pi.
    """
    for label, color in ((0, "tab:blue"), (1, "tab:orange")):
        in_cluster = df.kmeans_label == label
        plot_phase_data(df.receptor_pos_relative[in_cluster], df.phase[in_cluster],
                        axis, color, show_centroids=True)
    axis.set_ylabel("phase [rad]")

    # x axis: relative position in [0, 1], ticks without labels
    axis.set_xlim([0, 1.0])
    axis.set_xticks(np.arange(0, 1.01, 0.2))
    axis.set_xticks(np.arange(0, 1.01, 0.05), minor=True)
    axis.set_xticklabels([])

    # y axis: phase in [0, 2*pi], major ticks every pi, minor every pi/4
    two_pi = 2 * np.pi
    axis.set_ylim([0, two_pi])
    axis.set_yticks(np.arange(0, two_pi + 0.1, np.pi))
    axis.set_yticks(np.arange(0, two_pi + 0.1, np.pi / 4), minor=True)
    axis.set_yticklabels([r"$0$", r"$\pi$", r"$2\pi$"])
def plot_phases_shifted(df, axis):
    """Plot the shifted phase of k-means clusters 0 and 1 versus receptor position.

    Clusters are drawn without centroid markers; ``plotLinregress`` is then
    applied to all points pooled across clusters.  Reads the columns
    ``phase_shifted``, ``receptor_pos_relative`` and ``kmeans_label`` of ``df``.
    """
    cluster_colors = {0: "tab:blue", 1: "tab:orange"}
    for label in cluster_colors:
        member = df.kmeans_label == label
        plot_phase_data(df.receptor_pos_relative[member], df.phase_shifted[member],
                        axis, cluster_colors[label], show_centroids=False)
    plotLinregress(axis, df.receptor_pos_relative, df.phase_shifted, 1, "black")
    axis.set_xlabel("receptor position [rel.]")
    axis.set_ylabel("phase shifted [rad]")
def layout_figure():
    """Create the 3-row figure with panel labels and despined axes.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : numpy.ndarray
        The three stacked axes, labelled "A", "B" and "C" from top to bottom.
    """
    fig, axes = plt.subplots(ncols=1, nrows=3, figsize=(3.5, 5))
    for ax, panel_label in zip(axes, "ABC"):
        # Attach each panel label to its own axes and position it in that
        # axes' coordinates (label "B" used to be added to axes[0]).
        ax.text(-.25, 1.125, panel_label, fontsize=subfig_labelsize,
                fontweight=subfig_labelweight, transform=ax.transAxes)
        despine(ax, ["top", "right"], False)
    fig.subplots_adjust(left=0.2, top=0.97, right=0.9, hspace=0.35)
    return fig, axes
def phase_analysis(args):
    """Build supplementary figure 5 and either show it or save it.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options: ``baseline_data_frame`` (';'-separated CSV),
        ``redo``, ``velocity_data``, ``outfile`` and ``nosave``.

    Raises
    ------
    ValueError
        If the baseline data frame file does not exist, or (raised by
        ``plot_velocities``) if the velocity data file does not exist.
    """
    if not os.path.exists(args.baseline_data_frame):
        raise ValueError(f"Baseline data could not be found! ({args.baseline_data_frame})")
    df = pd.read_csv(args.baseline_data_frame, sep=";", index_col=0)
    if args.redo:
        # NOTE(review): presumably regenerates the file read by plot_velocities
        # (args.velocity_data) — verify in supp_figure5_analysis.
        velocity_analysis()
    fig, axes = layout_figure()
    try:
        plot_velocities(args, axes[2])
        plot_phases_raw(df, axes[0])
        plot_phases_shifted(df, axes[1])
        if args.nosave:
            plt.show()
        else:
            # savefig does not create missing directories (default: figures/)
            out_dir = os.path.dirname(args.outfile)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(args.outfile, dpi=500)
    finally:
        # close this figure even when plotting or saving raises
        plt.close(fig)
def command_line_parser(subparsers):
    """Register the ``supfig5`` sub-command on an argparse ``subparsers`` object.

    The sub-command dispatches to ``phase_analysis`` via ``func``.
    """
    derived = "derived_data"
    parser = subparsers.add_parser(
        "supfig5",
        help="Supplementary figure 5: Plots clustering and conduction velocity estimation.",
    )
    parser.add_argument(
        "-bdf", "--baseline_data_frame",
        default=os.path.join(derived, "figure2_baseline_properties.csv"),
    )
    parser.add_argument(
        "-r", "--redo", action="store_true",
        help="Redo the velocity analysis. Depends on figure2_baseline_properties.csv",
    )
    parser.add_argument(
        "-vel", "--velocity_data",
        default=os.path.join(derived, "suppfig5_velocities.npz"),
    )
    parser.add_argument(
        "-o", "--outfile",
        default=os.path.join("figures", "phase_analysis.pdf"),
    )
    parser.add_argument(
        "-n", "--nosave", action='store_true',
        help="no saving of the figure, just showing",
    )
    parser.set_defaults(func=phase_analysis)
|