intro_figure.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. ###############################################################################
  2. ## introductory plot ##
  3. import os
  4. import numpy as np
  5. import pandas as pd
  6. import scipy.stats as stats
  7. import matplotlib.pyplot as plt
  8. import matplotlib.image as mpimg
  9. from .figure_style import subfig_labelsize, subfig_labelweight, despine
  10. def plot_linregress(ax, x, y, bonferroni_factor, color="tab:red", label="", feature=""):
  11. xRange = np.asarray([x.min(), x.max()])
  12. slope, intercept, *params = stats.linregress(x, y)
  13. p_star = params[1] * bonferroni_factor
  14. if p_star > 1:
  15. p_star = 1
  16. l = r'%s r: %.2f, %s: %.2f' % ("%s, " % label if len(label) > 0 else "", params[0], "p" if bonferroni_factor == 1 else r"p", params[1] if bonferroni_factor == 1 else p_star)
  17. ax.plot(xRange, intercept + slope * xRange, color=color, label=l, ls="-")
  18. ax.legend(loc='upper left', frameon=False, ncol=1,fontsize=6, bbox_to_anchor=(0., 1.25))
  19. print("%s: minimum: %.2f, maximum: %.2f, mean: %.2f, std: %.2f" % (feature, np.min(y), np.max(y), np.mean(y), np.std(y)))
  20. return slope, intercept
  21. def plotMeanProperty(ax, bins, x, y):
  22. my = np.zeros(len(bins)-1)
  23. sy = np.zeros(len(bins)-1)
  24. for i, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
  25. yy = y[(x >= start) & (x < end)]
  26. if len(yy) > 0:
  27. my[i] = np.mean(yy)
  28. sy[i] = np.std(yy)
  29. else:
  30. my[i] = np.nan
  31. sy[i] = np.nan
  32. centers = bins[:-1]+(bins[1]-bins[0])/2
  33. ax.plot(centers, my, '--', lw=0.5, color='black')
  34. ax.fill_between(centers, my+sy, my-sy, color='black', alpha=0.1)
  35. return my
  36. def add_model_sketch(axis):
  37. if os.path.exists(os.path.join("figures", "delay_problem_sketch.png")):
  38. img = mpimg.imread(os.path.join("figures", "delay_problem_sketch.png"))
  39. axis.imshow(img)
  40. axis.set_xticklabels([])
  41. axis.set_yticklabels([])
  42. despine(axis, ["top", "bottom", "left", "right"], True)
  43. def add_fish_sketch(axis):
  44. if os.path.exists(os.path.join("figures", "fishsketch.png")):
  45. img = mpimg.imread(os.path.join("figures", "fishsketch_tight.png"))
  46. axis.imshow(img)
  47. despine(axis, ["top", "bottom", "left", "right"], True)
  48. else:
  49. raise FileNotFoundError("Missing fishsketch.png file.")
  50. def plot_delay(axis, args):
  51. if not os.path.exists(args.baseline_frame):
  52. raise ValueError(f"Results data frame not found! {args.baseline_frame})")
  53. df = pd.read_csv(args.baseline_frame, sep=";", index_col=0)
  54. absolute_positions = df.receptor_pos_absolute.values
  55. delay_times = df.phase_time.values * 1000
  56. axis.scatter(absolute_positions, delay_times, color="tab:blue", s=1)
  57. m, n = plot_linregress(axis, absolute_positions, delay_times, bonferroni_factor=4, feature="phase")
  58. # plotMeanProperty(axis, np.linspace(0, 120, 12), absolute_positions, delay_times)
  59. axis.set_yticks([0.0, 1.0, 2.0])
  60. axis.set_xlim(0.0, 120.0)
  61. despine(axis, ["top", "right"], False)
  62. axis.set_xlabel("receptor position [mm]", fontsize=8)
  63. axis.set_ylabel("delay [ms]", fontsize=8)
  64. def layout_figure():
  65. fig = plt.figure(figsize=(3.42, 1.4)) # , constrained_layout=True)
  66. axes = []
  67. fig_grid = (8, 8)
  68. axes.append(plt.subplot2grid(fig_grid, (0, 0), 8, 4))
  69. axes.append(plt.subplot2grid(fig_grid, (0, 5), 2, 3))
  70. axes.append(plt.subplot2grid(fig_grid, (2, 5), 3, 3))
  71. axes[0].text(0.05, 1.00, "A", transform=axes[0].transAxes, ha="center",fontsize=subfig_labelsize, fontweight=subfig_labelweight)
  72. axes[1].text(-0.35, 0.9, "B", transform=axes[1].transAxes, ha="center",fontsize=subfig_labelsize, fontweight=subfig_labelweight)
  73. fig.subplots_adjust(left=-0.0, top=0.95, bottom=0.00, right=0.95)
  74. return fig, axes
  75. def introductory_figure(args):
  76. fig, axes = layout_figure()
  77. add_model_sketch(axes[0])
  78. #add_distributions(axes[1], axes[2], axes[3])
  79. add_fish_sketch(axes[1])
  80. plot_delay(axes[-1], args)
  81. if args.nosave:
  82. plt.show()
  83. else:
  84. fig.savefig(args.outfile, dpi=500)
  85. plt.close()
  86. def command_line_parser(subparsers):
  87. default_bf = os.path.join("derived_data", "figure2_baseline_properties.csv")
  88. parser = subparsers.add_parser("intro_figure", help="Introductory plot (figure 1).")
  89. parser.add_argument("-bf", "--baseline_frame", default=default_bf,
  90. help=f"Full file name of a CSV table readable with pandas that holds the baseline properties and positions (defaults to {default_bf}).")
  91. parser.add_argument("-o", "--outfile", type=str, default=os.path.join("figures","delay_problem.pdf"), help="The filename of the figure")
  92. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  93. parser.set_defaults(func=introductory_figure)