supp_figure5.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. from .figure_style import subfig_labelsize, subfig_labelweight, despine
  6. from .property_correlations import plotLinregress
  7. from .supp_figure5_analysis import main as velocity_analysis
  8. class Colorspace():
  9. def __init__(self, valMin=0, valMax=1, schema='viridis'):
  10. # define colorspace
  11. self.cMap = plt.get_cmap(schema)
  12. self.colors = np.asarray(self.cMap.colors)
  13. self.values = np.linspace(valMin, valMax, self.cMap.N)
  14. def getColor(self, val):
  15. if isinstance(val, (float, int)):
  16. if val < self.values[0] or val > self.values[-1]:
  17. print('WARNING: exceeding valid range')
  18. return self.colors[np.searchsorted(self.values, val)]
  19. elif isinstance(val, (list, np.ndarray)):
  20. if len(val) > 1:
  21. return [self.getColor(val[0]), *self.getColor(val[1:])]
  22. else:
  23. return [self.getColor(val[0])]
  24. else:
  25. return None
  26. def getLinspacedColors(self, num=None):
  27. if num is None:
  28. num = self.colors.shape[0]
  29. return self.getColor(np.linspace(self.values[0], self.values[-1], num))
  30. def plotColorbar(self, ax, vertical=True):
  31. y = np.asarray([self.values, self.values])
  32. if vertical:
  33. y = y.T
  34. extent = [0, 1, self.values[-1], self.values[0]]
  35. else:
  36. extent = [self.values[0], self.values[-1], 0, 1]
  37. ax.imshow(
  38. y,
  39. cmap=self.cMap,
  40. aspect='auto',
  41. extent=extent
  42. )
  43. if vertical:
  44. ax.set_xticks([])
  45. else:
  46. ax.set_yticks([])
  47. def plot_velocities(args, axis):
  48. if not os.path.exists(args.velocity_data):
  49. raise ValueError(f"Velocity data file not found! {args.velocity_data}")
  50. data = np.load(args.velocity_data)
  51. position_differences = data["position_differences"]
  52. velocities = data["velocities"]
  53. average_velocities = data["average_velocities"]
  54. position_difference_centers = data["position_difference_centers"]
  55. cspace = Colorspace()
  56. velcolors = cspace.getLinspacedColors(10)
  57. for i in range(position_differences.shape[0]):
  58. axis.scatter(position_differences[i, np.isfinite(position_differences[i,:])],
  59. velocities[i,np.isfinite(position_differences[i,:])],
  60. s=5, color=velcolors[i]
  61. )
  62. axis.axhline(48.3, linestyle='--', color='black', linewidth=1, label=rf'$v_{{m}} =$ {48.3:.1f} $\frac{{m}}{{s}}$')
  63. axis.plot(position_difference_centers, average_velocities, color='black', linewidth=1)
  64. axis.set_yscale("log")
  65. axis.set_ylim([5, 3000])
  66. axis.set_xlim([0, 100])
  67. axis.set_xlabel('$\Delta$position [mm]')
  68. axis.set_ylabel('Velocity [m/s]')
  69. axis.legend(frameon=False, loc='upper right')
  70. axis.set_xticks(np.arange(0, 101, 5), minor=True)
  71. def plot_phase_data(positions, phases, axis, color, show_centroids=False):
  72. axis.scatter(positions, phases, color=color, s=10, marker=".")
  73. if show_centroids:
  74. centroid_x = np.mean(positions)
  75. centroid_y = np.mean(phases)
  76. axis.scatter(centroid_x, centroid_y, marker="+", s=15, color=color)
  77. def plot_phases_raw(df, axis):
  78. colors = ["tab:blue", "tab:orange"]
  79. cluster_labels = [0, 1]
  80. for label, color in zip(cluster_labels, colors):
  81. selection_phases = df.phase[df.kmeans_label == label]
  82. selection_positions = df.receptor_pos_relative[df.kmeans_label == label]
  83. plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=True)
  84. # axis.set_xlabel("receptor position [rel.]")
  85. axis.set_ylabel("phase [rad]")
  86. axis.set_xlim([0, 1.0])
  87. axis.set_ylim([0, 2*np.pi])
  88. axis.set_yticks(np.arange(0, 2 * np.pi +0.1, np.pi))
  89. axis.set_yticks(np.arange(0, 2 * np.pi +0.1, np.pi/4), minor=True)
  90. axis.set_yticklabels([r"$0$", r"$\pi$", r"$2\pi$"])
  91. axis.set_xticks(np.arange(0, 1.01, 0.2))
  92. axis.set_xticks(np.arange(0, 1.01, 0.05), minor=True)
  93. axis.set_xticklabels([])
  94. def plot_phases_shifted(df, axis):
  95. colors = ["tab:blue", "tab:orange"]
  96. cluster_labels = [0, 1]
  97. for label, color in zip(cluster_labels, colors):
  98. selection_phases = df.phase_shifted[df.kmeans_label == label]
  99. selection_positions = df.receptor_pos_relative[df.kmeans_label == label]
  100. plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=False)
  101. plotLinregress(axis, df.receptor_pos_relative, df.phase_shifted, 1, "black")
  102. axis.set_xlabel("receptor position [rel.]")
  103. axis.set_ylabel("phase shifted [rad]")
  104. def layout_figure():
  105. fig, axes = plt.subplots(ncols=1, nrows=3, figsize=(3.5, 5))
  106. axes[0].text(-.25, 1.125, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  107. transform=axes[0].transAxes)
  108. axes[0].text(-.25, 1.125, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  109. transform=axes[1].transAxes)
  110. axes[2].text(-.25, 1.125, "C", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  111. transform=axes[2].transAxes)
  112. despine(axes[0], ["top", "right"], False)
  113. despine(axes[1], ["top", "right"], False)
  114. despine(axes[2], ["top", "right"], False)
  115. fig.subplots_adjust(left=0.2, top=0.97, right=0.9, hspace=0.35)
  116. return fig, axes
  117. def phase_analysis(args):
  118. if not os.path.exists(args.baseline_data_frame):
  119. raise ValueError(f"Baseline data could not be found! ({args.baseline_data_frame})")
  120. df = pd.read_csv(args.baseline_data_frame, sep=";", index_col=0)
  121. if args.redo:
  122. velocity_analysis()
  123. fig, axes = layout_figure()
  124. plot_velocities(args, axes[2])
  125. plot_phases_raw(df, axes[0])
  126. plot_phases_shifted(df, axes[1])
  127. if args.nosave:
  128. plt.show()
  129. else:
  130. fig.savefig(args.outfile, dpi=500)
  131. plt.close()
  132. def command_line_parser(subparsers):
  133. parser = subparsers.add_parser("supfig5", help="Supplementary figure 5: Plots clustering and conduction velocity estimation.")
  134. parser.add_argument("-bdf", "--baseline_data_frame", default=os.path.join("derived_data","figure2_baseline_properties.csv"))
  135. parser.add_argument("-r", "--redo", action="store_true", help="Redo the velocity analysis. Depends on figure2_baseline_properties.csv")
  136. parser.add_argument("-vel", "--velocity_data", default=os.path.join("derived_data", "suppfig5_velocities.npz"))
  137. parser.add_argument("-o", "--outfile", default=os.path.join("figures", "phase_analysis.pdf"))
  138. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  139. parser.set_defaults(func=phase_analysis)