###############################################################################
## intro_figure2.py -- introductory plot ##
  3. import os
  4. import numpy as np
  5. import scipy.stats as stats
  6. import matplotlib.pyplot as plt
  7. from matplotlib import gridspec
  8. from matplotlib.patches import Ellipse
  9. from .figure_style import subfig_labelsize, subfig_labelweight, despine, label_size, tick_label_size
  10. class Neuron:
  11. def __init__(self, x, y, w, h, **kwargs) -> None:
  12. if "text" in kwargs:
  13. del kwargs["text"]
  14. print("Text found!")
  15. self._ellipse = Ellipse([x,y], w, h, **kwargs)
  16. self._height = h
  17. @property
  18. def patch(self):
  19. return self._ellipse
  20. @property
  21. def x(self):
  22. return self._ellipse.center[0]
  23. @property
  24. def y(self):
  25. return self._ellipse.center[1]
  26. @property
  27. def height(self):
  28. return self._height
  29. def connect(self, other, axis, num=1, maxnum=1, **kwargs):
  30. dest = self.destination(other, num, max_connections=maxnum)
  31. axis.plot([self.x, dest[0]],[dest[1], dest[1]], **kwargs)
  32. axis.plot([self.x, self.x], [self.y, dest[1]], **kwargs)
  33. def destination(self, other, num=1, max_connections=1):
  34. y_positions = np.linspace(other.y + other.height/2*2/3, other.y - other.height/2*2/3, max_connections)
  35. dest = [other.x, y_positions[num-1]]
  36. return dest
  37. def multivariate_gaussian(pos, mu, sigma):
  38. """Return the multivariate Gaussian distribution on array pos.
  39. stolen from https://scipython.com/blog/visualizing-the-bivariate-gaussian-distribution/
  40. """
  41. n = mu.shape[0]
  42. sigma_det = np.linalg.det(sigma)
  43. sigma_inv = np.linalg.inv(sigma)
  44. N = np.sqrt((2*np.pi)**n * sigma_det)
  45. # This einsum call calculates (x-mu)T.sigma-1.(x-mu) in a vectorized
  46. # way across all the input variables.
  47. fac = np.einsum('...k,kl,...l->...', pos-mu, sigma_inv, pos-mu)
  48. return np.exp(-fac / 2) / N
  49. def plot_neuron(x, y, w, h, axis):
  50. neuron = Neuron(x, y, w, h, **{"facecolor": "white", "edgecolor":"black", "linewidth": 0.5})
  51. axis.add_patch(neuron.patch)
  52. return neuron
  53. def plot_receptivefield(width, num_cells, axis=None, xpos=0.0):
  54. rf = Ellipse([xpos, 0.0], width=width, height=width/2,clip_on=False, **{"facecolor":"tab:blue", "alpha":0.5})
  55. axis.add_patch(rf)
  56. axis.set_xlim([-2, 2])
  57. axis.set_ylim([-2, 2])
  58. xpositions = np.linspace(rf.center[0]-rf.width/2, rf.center[0]+rf.width/2, num_cells)
  59. axis_ratio = axis.bbox.bounds[2] / axis.bbox.bounds[3]
  60. post_neuron = Neuron(-1.75, -1.1, 0.5, 0.5 * axis_ratio, **{"facecolor": "white", "edgecolor":"black", "linewidth": 0.5})
  61. axis.add_patch(post_neuron.patch)
  62. axis.text(-1.75, -1.2, r"$\Sigma$", fontsize=9, ha="center", va="center",)
  63. neurons = []
  64. for i in range(num_cells):
  65. n = plot_neuron(xpositions[i], 0.0, 0.2, 0.2*axis_ratio, axis,)
  66. n.connect(post_neuron, axis, i+1, num_cells, **{"color": "gray", "linewidth":.5, "zorder":0})
  67. neurons.append(n)
  68. return rf, neurons
  69. def plot_pdfs(axis, width):
  70. #axis.spines["left"].set_visible(True)
  71. y = np.linspace(-5, 5, 100)
  72. axis.set_ylim([-6, 6])
  73. axis.plot(stats.norm.pdf(y, 0, width), y)
  74. axis.axhline(ls="--", lw=0.5, color="black")
  75. axis.set_xticklabels([])
  76. #y, x1, x2=0, where=None, step=None,
  77. axis.fill_betweenx(y[y <= 0.0], np.zeros_like(y[y <= 0.0]), stats.norm.pdf(y[y <= 0.0], 0, width),
  78. color='tab:green', alpha=.2)
  79. axis.fill_betweenx(y[y > 0.0], np.zeros_like(y[y > 0.0]), stats.norm.pdf(y[y > 0.0], 0, width),
  80. color='tab:red', alpha=.2)
  81. axis.text(0.3, 0.2, "lead", transform=axis.transAxes, fontsize=7, va="center", rotation=-90, color="tab:green")
  82. axis.text(0.3, 0.8, "lag", transform=axis.transAxes, fontsize=7, va="center", rotation=-90, color="tab:red")
  83. def plot_stims(axis, receptive_field, neurons):
  84. axis.text(receptive_field.center[0], receptive_field.center[1] + 1.5, "s(t)",
  85. color="tab:red", fontsize=9, ha="center", va="bottom")
  86. minx = 0
  87. maxx = 0
  88. for n in neurons:
  89. minx = n.x if n.x < minx else minx
  90. maxx = n.x if n.x > maxx else maxx
  91. axis.arrow(n.x, receptive_field.center[1] + 1.25, 0.0, -0.4, width=0.005, ec="tab:red", lw=0.5,
  92. head_width=10*0.005, head_length=25*0.005)
  93. axis.plot([minx, maxx], [receptive_field.center[1] + 1.25]*2, color="tab:red", lw=0.6)
  94. axis.plot([receptive_field.center[0]]*2, [receptive_field.center[1] + 1.45, receptive_field.center[1] + 1.25],
  95. color="tab:red", lw=0.6)
  96. def plot_delay(axis, num_cells, cellxmin=0.05, cellxmax=3.05):
  97. xmin = 0.0
  98. xmax = 3.5
  99. x = np.array([cellxmin, cellxmax])
  100. n = 0.0
  101. m = 0.25
  102. axis.set_clip_on(False)
  103. axis.spines["bottom"].set_visible(True)
  104. axis.spines["left"].set_visible(True)
  105. axis.set_yticks(np.arange(0.0, 1.6, 0.5))
  106. axis.set_yticks(np.arange(0.0, 1.6, 0.25), minor=True)
  107. axis.set_yticklabels([0, "","",""], fontsize=tick_label_size)
  108. axis.set_xticks(np.arange(xmin,xmax+0.1, 0.5))
  109. axis.set_xticks(np.arange(xmin, xmax+0.1, 0.25), minor=True)
  110. labels = [""] * len(axis.get_xticks(minor=False))
  111. labels[0] = "0"
  112. axis.set_xticklabels(labels, fontsize=tick_label_size)
  113. axis.plot(x, m*x+n, lw=1.0)
  114. axis.set_xlabel("distance to target", fontsize=label_size)
  115. axis.set_ylabel("delay", fontsize=label_size)
  116. axis.yaxis.set_label_coords(-0.15, 0.5)
  117. axis.set_ylim([0.0, 1.0])
  118. axis.set_xlim([xmin, xmax])
  119. axis.plot([0.6*xmax]*2, [0.2*xmax*m+n, 0.6*xmax*m+n], color="tab:red", lw=0.5)
  120. axis.plot([0.2*xmax, 0.6*xmax], [0.2*xmax*m+n]*2, color="tab:red", lw=0.5)
  121. axis.text(0.4*xmax, 0.18*xmax*m+n, r"$\Delta_{dist.}$", fontsize=7, va="top", ha="center")
  122. axis.text(0.625*xmax, 0.4*xmax*m+n, r"$\Delta_{delay}$", fontsize=7, va="center", ha="left")
  123. axis.text(0.05, 0.9, r"$v_{cond.}=\Delta_{dist.}/\Delta_{delay}$", fontsize=7, ha="left")
  124. positions = np.linspace(x[0], x[1], num_cells)
  125. axis.scatter(positions, positions*m+n, s=10, fc="white", ec="k")
  126. def plot_delayrf_center(axis, num_cells, cellxmin, cellxmax, cellcenter, xmin=-2, xmax=2, setlabel=False):
  127. x = np.array([cellxmin, cellxmax])
  128. positions = np.linspace(cellxmin, cellxmax, num_cells)
  129. axis.spines["left"].set_visible(True)
  130. axis.spines["bottom"].set_visible(True)
  131. axis.axhline(color='black', lw=0.5, ls="--")
  132. axis.axvline(x=positions[num_cells//2], color='black', lw=0.5, ls="--")
  133. m = 0.25
  134. n = m*positions[num_cells//2]
  135. ymin = m*x[0] - n
  136. ymax = m*x[-1] - n
  137. axis.plot(x+[positions[num_cells//2]], m*x - n, lw=1.0)
  138. axis.scatter(positions+[positions[num_cells//2]], positions*m - n, s=10, fc="white", ec="k")
  139. if setlabel:
  140. axis.set_xlabel("distance relative to receptive field center", fontsize=label_size)
  141. axis.set_ylim([-0.75, 0.75])
  142. axis.set_xlim([xmin, xmax])
  143. axis.set_yticks(np.arange(-0.5, 0.6, 0.5))
  144. axis.set_yticks(np.arange(-0.5, 0.6, 0.25), minor=True)
  145. axis.set_yticklabels(["", 0, ""], fontsize=tick_label_size)
  146. axis.fill_between([xmin, 0], [0.0]*2, [ymin]*2, color="tab:green", alpha=0.2)
  147. axis.fill_between([0, xmax], [0.0]*2, [ymax]*2, color="tab:red", alpha=0.2)
  148. axis.set_xticks(positions)
  149. labels = [""] * num_cells
  150. labels[num_cells//2] = "0"
  151. axis.set_xticklabels(labels, fontsize=tick_label_size)
  152. def layout_figure(num_rfs=3):
  153. fig = plt.figure(figsize=(5.1, 2.5))
  154. sketch_axes = []
  155. delay_axes = []
  156. pdf_axes = []
  157. sketch_labels = [r"B$_i$", r"C$_i$", r"D$_i$"]
  158. delay_labels = ["A", r"B$_{ii}$", r"C$_{ii}$", r"D$_{ii}$"]
  159. gs = gridspec.GridSpec(1, 4, left=0.05, right=0.975, top=0.95, bottom=0.15,
  160. wspace=0.15, figure=fig)
  161. for i in range(num_rfs + 1):
  162. sgs = gridspec.GridSpecFromSubplotSpec(2, 2, wspace=0.0, subplot_spec=gs[0, i], width_ratios=[4, 1],
  163. height_ratios=[1, 2])
  164. if i > 0:
  165. axis = fig.add_subplot(sgs[0, 0])
  166. despine(axis, ["top", "left", "right", "bottom"])
  167. axis.text(-0.3, 1., sketch_labels[i-1], transform=axis.transAxes, ha="left",
  168. va="top", fontsize=subfig_labelsize, fontweight=subfig_labelweight)
  169. sketch_axes.append(axis)
  170. axis = fig.add_subplot(sgs[1, 0])
  171. despine(axis, ["top", "right"])
  172. axis.text(-0.3, 1.15, delay_labels[i], transform=axis.transAxes, ha="left",
  173. va="top", fontsize=subfig_labelsize, fontweight=subfig_labelweight)
  174. delay_axes.append(axis)
  175. axis = fig.add_subplot(sgs[1, 1])
  176. despine(axis, ["left", "top", "right", "bottom"])
  177. pdf_axes.append(axis)
  178. return fig, sketch_axes, delay_axes, pdf_axes
  179. def introductory_figure(args):
  180. fig, sketch_axes, delay_axes, pdf_axes = layout_figure()
  181. rfs = [1, 2., 3]
  182. num_cells = [3, 5, 9]
  183. sigmas = [0.5, 0.9, 1.3]
  184. for i, (rf, nc) in enumerate(zip(rfs, num_cells)):
  185. receptivefield, neurons = plot_receptivefield(rf, nc, axis=sketch_axes[i])
  186. plot_stims(sketch_axes[i], receptivefield, neurons)
  187. plot_delayrf_center(delay_axes[i+1], nc, cellxmin=neurons[0].x, cellxmax=neurons[-1].x,
  188. cellcenter=neurons[nc//2].x, setlabel=i+1==2)
  189. plot_pdfs(pdf_axes[i+1], sigmas[i])
  190. plot_delay(delay_axes[0], num_cells[-1])
  191. if args.nosave:
  192. plt.show()
  193. else:
  194. fig.savefig(args.outfile, dpi=500)
  195. plt.close()
  196. def command_line_parser(subparsers):
  197. parser = subparsers.add_parser("intro_figure2", help="Introductory figure (figure 1).")
  198. parser.add_argument("-o", "--outfile", type=str, default=os.path.join("figures","intro_figure.pdf"), help="The filename of the figure")
  199. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  200. parser.set_defaults(func=introductory_figure)