phase_delays.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. from scipy.stats import stats
  6. from scipy.stats import t as tstat
  7. from .figure_style import subfig_labelsize, subfig_labelweight, despine
  8. from .supp_figure5_analysis import main as velocity_analysis
  9. def get_delays(df, wn_df, pos_df):
  10. datasets = pos_df.dataset_id.unique()
  11. rel_positions = []
  12. positions = []
  13. wn_delays = []
  14. phase_delays = []
  15. for dataset in datasets:
  16. rel_pos = pos_df.receptor_pos_relative[pos_df.dataset_id == dataset].values[0]
  17. pos = pos_df.receptor_pos_absolute[pos_df.dataset_id == dataset].values[0]
  18. trials = wn_df[wn_df.dataset_id == dataset]
  19. if len(trials) < 1:
  20. continue
  21. # Get delay based on whitenoise STAs
  22. wn_delay = np.mean(trials.delay)
  23. wn_delays.append(wn_delay)
  24. # Get delay based on shifted EOD phase lock and period
  25. phase = df[df.dataset_id == dataset].phase_shifted.values[0]
  26. eod_period = df[df.dataset_id == dataset].eod_period.values[0]
  27. phase_delay = phase * eod_period / (2 * np.pi)
  28. phase_delays.append(phase_delay)
  29. rel_positions.append(rel_pos)
  30. positions.append(pos)
  31. wn_delays = np.array(wn_delays)
  32. phase_delays = np.array(phase_delays)
  33. rel_positions = np.array(rel_positions)
  34. positions = np.array(positions)
  35. return rel_positions, positions, phase_delays, wn_delays
  36. def plot_phase_data(positions, phases, axis, color, show_centroids=False):
  37. axis.scatter(positions, phases, color=color, s=10, marker=".")
  38. if show_centroids:
  39. centroid_x = np.mean(positions)
  40. centroid_y = np.mean(phases)
  41. axis.scatter(centroid_x, centroid_y, marker="+", s=15, color=color, linewidth=.8)
  42. def plot_phases_raw(df, axis):
  43. colors = ["tab:blue", "tab:orange"]
  44. cluster_labels = [0, 1]
  45. for label, color in zip(cluster_labels, colors):
  46. selection_phases = df.phase[df.kmeans_label == label]
  47. selection_positions = df.receptor_pos_relative[df.kmeans_label == label]
  48. plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=True)
  49. axis.set_xlim([0, 1.0])
  50. axis.set_xlabel("receptor position [rel.]")
  51. axis.set_xticks(np.arange(0, 1.01, 0.05), minor=True)
  52. axis.set_ylim([0, 2*np.pi])
  53. axis.set_ylabel("phase [rad]")
  54. axis.set_yticks(np.arange(0, 2 * np.pi + 0.1, np.pi))
  55. axis.set_yticks(np.arange(0, 2 * np.pi + 0.1, np.pi/4), minor=True)
  56. axis.set_yticklabels([r"$0$", r"$\pi$", r"$2\pi$"])
  57. def plot_phases_shifted(df, axis):
  58. colors = ["tab:blue", "tab:orange"]
  59. cluster_labels = [0, 1]
  60. for label, color in zip(cluster_labels, colors):
  61. selection_phases = df.phase_shifted[df.kmeans_label == label]
  62. selection_positions = df.receptor_pos_relative[df.kmeans_label == label]
  63. plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=False)
  64. slope, intercept, *params = stats.linregress(df.receptor_pos_relative, df.phase_shifted)
  65. x_range = np.arange(df.receptor_pos_relative.min(), df.receptor_pos_relative.max(), 0.1)
  66. y_fit = intercept + slope * x_range
  67. axis.plot(x_range, y_fit, color="black", ls="-",
  68. label=f"r:{params[0]:.2f}, p:{params[1]:.3f}")
  69. axis.legend(loc='lower right')
  70. axis.set_xlim([0, 1.0])
  71. axis.set_xlabel("receptor position [rel.]")
  72. axis.set_xticks(np.arange(0, 1.01, 0.05), minor=True)
  73. axis.set_ylim([0, 4*np.pi])
  74. axis.set_ylabel("phase + " + r"$2\pi$" + " [rad]")
  75. axis.set_yticks(np.arange(0, 4 * np.pi + 0.1, np.pi))
  76. axis.set_yticks(np.arange(0, 4 * np.pi + 0.1, np.pi/4), minor=True)
  77. axis.set_yticklabels([r"$0$", r"$\pi$", r"$2\pi$", r"$3\pi$", r"$4\pi$"])
  78. def plot_baseline_delay(df, axis):
  79. colors = ["tab:blue", "tab:orange"]
  80. cluster_labels = [0, 1]
  81. for label, color in zip(cluster_labels, colors):
  82. selection_phases = df.phase_time[df.kmeans_label == label] * 1000
  83. selection_positions = df.receptor_pos_absolute[df.kmeans_label == label]
  84. plot_phase_data(selection_positions, selection_phases, axis, color, show_centroids=False)
  85. # Fit
  86. slope, intercept, *params = stats.linregress(df.receptor_pos_absolute, df.phase_time * 1000)
  87. x_range = np.arange(df.receptor_pos_absolute.min(), df.receptor_pos_absolute.max(), 0.1)
  88. y_fit = intercept + slope * x_range
  89. axis.plot(x_range, y_fit, color="black", ls="-",
  90. label=f"slope:{1./slope:.1f}m/s, r:{params[0]:.2f}, p:{params[1]:.3f}")
  91. axis.legend(loc='lower right')
  92. axis.set_xlim(20, 120)
  93. axis.set_xlabel("receptor position [mm]")
  94. axis.set_ylim([0, 3])
  95. axis.set_ylabel("phase delay [ms]")
  96. def plot_ci(ax, x, y, intercept, slope):
  97. """After Tomas Holderness' implementation:
  98. https://tomholderness.wordpress.com/2013/01/10/confidence_intervals/"""
  99. # linfit.py - example of confidence limit calculation for linear regression fitting.
  100. # References:
  101. # - Statistics in Geography by David Ebdon (ISBN: 978-0631136880)
  102. # - Reliability Engineering Resource Website:
  103. # - http://www.weibull.com/DOEWeb/confidence_intervals_in_simple_linear_regression.htm
  104. # - University of Glascow, Department of Statistics:
  105. # - http://www.stats.gla.ac.uk/steps/glossary/confidence_intervals.html#conflim
  106. # fit a curve to the data using a least squares 1st order polynomial fit
  107. fit = intercept + slope * x
  108. # predict y values of origional data using the fit
  109. p_y = slope * x + intercept
  110. # calculate the y-error (residuals)
  111. y_err = y - p_y
  112. # create series of new test x-values to predict for
  113. p_x = np.arange(np.min(x), np.max(x) + 1, 1)
  114. # now calculate confidence intervals for new test x-series
  115. mean_x = np.mean(x) # mean of x
  116. n = len(x) # number of samples in origional fit
  117. t = tstat.ppf(1. - 0.1 / 2, df=n - 1) # appropriate t value
  118. s_err = np.sum(np.power(y_err, 2)) # sum of the squares of the residuals
  119. confs = t * np.sqrt((s_err / (n - 2)) * (1.0 / n + (np.power((p_x - mean_x), 2) /
  120. ((np.sum(np.power(x, 2))) - n * (np.power(mean_x, 2))))))
  121. # now predict y based on test x-values
  122. p_y = slope * p_x + intercept
  123. # get lower and upper confidence limits based on predicted y and confidence intervals
  124. lower = p_y - abs(confs)
  125. upper = p_y + abs(confs)
  126. # plot confidence limits
  127. ax.plot(p_x, lower, linestyle='--', color='gray', linewidth=1.2)
  128. ax.plot(p_x, upper, linestyle='--', color='gray', linewidth=1.2)
  129. def plot_sta_delay(df, wn_df, pos_df, axis):
  130. rel_pos, pos, phase_delays, wn_delays = get_delays(df, wn_df, pos_df)
  131. axis.scatter(pos, wn_delays * 1000, s=10, marker=".")
  132. # Plot linear regression
  133. slope, intercept, *params = stats.linregress(pos, wn_delays * 1000)
  134. x_range = np.arange(pos.min(), pos.max(), 0.1)
  135. y_fit = intercept + slope * x_range
  136. axis.plot(x_range, y_fit, color="black", ls="-",
  137. label=f"slope:{1./slope:.1f}m/s, r:{params[0]:.2f}, p:{params[1]:.3f}")
  138. # Plot CI
  139. plot_ci(axis, pos, wn_delays, intercept, slope)
  140. axis.legend(loc='lower right')
  141. axis.set_xlim(20, 120)
  142. axis.set_xlabel("receptor position [mm]")
  143. axis.set_ylim(0, 8)
  144. axis.set_ylabel("STA delay [ms]")
  145. def plot_both_delays(df, wn_df, pos_df, axis):
  146. rel_pos, pos, phase_delays, wn_delays = get_delays(df, wn_df, pos_df)
  147. axis.scatter(phase_delays * 1000, wn_delays * 1000, s=10, marker=".")
  148. slope, intercept, *params = stats.linregress(phase_delays * 1000, wn_delays * 1000)
  149. x_range = np.arange(phase_delays.min() * 1000, phase_delays.max() * 1000, 0.1)
  150. y_fit = intercept + slope * x_range
  151. axis.plot(x_range, y_fit, color="black", ls="-",
  152. label=f"slope:{slope:.1f}, r:{params[0]:.2f}, p:{params[1]:.3f}")
  153. axis.legend(loc='lower right')
  154. axis.set_ylim(np.floor(wn_delays.min() * 1000), np.ceil(wn_delays.max() * 1000))
  155. axis.set_ylabel("STA delay [ms]")
  156. axis.set_xlabel('phase delay [ms]')
  157. axis.set_xlim(np.floor(phase_delays.min() * 1000), np.ceil(phase_delays.max() * 1000))
  158. def layout_figure():
  159. gs = plt.GridSpec(3, 2)
  160. fig = plt.figure(figsize=(6.5, 4.5))
  161. axes = []
  162. axes.append(fig.add_subplot(gs[0, 0]))
  163. axes.append(fig.add_subplot(gs[1, 0]))
  164. axes.append(fig.add_subplot(gs[2, 0]))
  165. axes.append(fig.add_subplot(gs[0, 1]))
  166. axes.append(fig.add_subplot(gs[1:, 1]))
  167. axes[0].text(-.15, 1.08, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  168. transform=axes[0].transAxes)
  169. axes[3].text(-.15, 1.08, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  170. transform=axes[3].transAxes)
  171. axes[4].text(-.15, 1.08, "C", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  172. transform=axes[4].transAxes)
  173. despine(axes[0], ["top", "right"], False)
  174. despine(axes[1], ["top", "right"], False)
  175. despine(axes[2], ["top", "right"], False)
  176. despine(axes[3], ["top", "right"], False)
  177. despine(axes[4], ["top", "right"], False)
  178. fig.subplots_adjust(left=0.07, top=0.94, bottom=0.10, right=0.98, hspace=0.5, wspace=0.3)
  179. return fig, axes
  180. def set_aspect_quad(ax):
  181. ylim = ax.get_ylim()
  182. xlim = ax.get_xlim()
  183. ax.set_aspect((xlim[1] - xlim[0]) / (ylim[1] - ylim[0]))
  184. def phase_analysis(args):
  185. if not os.path.exists(args.baseline_data_frame):
  186. raise ValueError(f"Baseline data could not be found! ({args.baseline_data_frame})")
  187. df = pd.read_csv(args.baseline_data_frame, sep=";", index_col=0)
  188. wn_df = pd.read_csv(args.whitenoise_data_frame, sep=";", index_col=0)
  189. pos_df = pd.read_csv(args.position_data_frame, sep=";", index_col=0)
  190. if args.redo:
  191. velocity_analysis()
  192. fig, axes = layout_figure()
  193. plot_phases_raw(df, axes[0])
  194. plot_phases_shifted(df, axes[1])
  195. plot_baseline_delay(df, axes[2])
  196. plot_sta_delay(df, wn_df, pos_df, axes[3])
  197. plot_both_delays(df, wn_df, pos_df, axes[4])
  198. # set_aspect_quad(axes[4])
  199. # fig.tight_layout()
  200. if args.nosave:
  201. plt.show()
  202. else:
  203. fig.savefig(args.outfile, dpi=500)
  204. plt.close()
  205. def command_line_parser(subparsers):
  206. parser = subparsers.add_parser("phase_delays", help="Phase delays figure: Plots illustration conduction delays based on phase lock shift and whitenoise STA")
  207. parser.add_argument("-bdf", "--baseline_data_frame", default=os.path.join("derived_data","figure2_baseline_properties.csv"))
  208. parser.add_argument("-wndf", "--whitenoise_data_frame", default=os.path.join("derived_data","whitenoise_trials.csv"))
  209. parser.add_argument("-posdf", "--position_data_frame", default=os.path.join("derived_data","receptivefield_positions.csv"))
  210. parser.add_argument("-r", "--redo", action="store_true", help="Redo the velocity analysis. Depends on figure2_baseline_properties.csv")
  211. parser.add_argument("-vel", "--velocity_data", default=os.path.join("derived_data", "suppfig5_velocities.npz"))
  212. parser.add_argument("-o", "--outfile", default=os.path.join("figures", "phase_delays.pdf"))
  213. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  214. parser.set_defaults(func=phase_analysis)