123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- ###############################################################################
- ## introductory plot ##
- import os
- import numpy as np
- import scipy.stats as stats
- import matplotlib.pyplot as plt
- from matplotlib import gridspec
- from matplotlib.patches import Ellipse
- from .figure_style import subfig_labelsize, subfig_labelweight, despine, label_size, tick_label_size
class Neuron:
    """A neuron drawn as an ellipse patch that can be wired to other neurons.

    Parameters
    ----------
    x, y : float
        Centre of the ellipse in data coordinates.
    w, h : float
        Width and height of the ellipse in data coordinates.
    **kwargs
        Forwarded to ``matplotlib.patches.Ellipse`` (e.g. ``facecolor``,
        ``edgecolor``). A ``text`` keyword is removed with a warning instead
        of being forwarded to Ellipse.
    """

    def __init__(self, x, y, w, h, **kwargs) -> None:
        if "text" in kwargs:
            import warnings
            del kwargs["text"]
            warnings.warn("Neuron ignores the 'text' keyword argument.", stacklevel=2)
        self._ellipse = Ellipse([x, y], w, h, **kwargs)
        self._height = h

    @property
    def patch(self):
        """The underlying Ellipse patch (add it to an axis to draw the neuron)."""
        return self._ellipse

    @property
    def x(self):
        """x coordinate of the ellipse centre."""
        return self._ellipse.center[0]

    @property
    def y(self):
        """y coordinate of the ellipse centre."""
        return self._ellipse.center[1]

    @property
    def height(self):
        """Height ``h`` the neuron was created with (data units)."""
        return self._height

    def connect(self, other, axis, num=1, maxnum=1, **kwargs):
        """Draw an L-shaped wire from this neuron to input slot ``num`` of ``other``.

        The wire runs vertically at this neuron's x from its centre to the
        slot's height, then horizontally to ``other``'s centre x.

        Parameters
        ----------
        other : Neuron
            Target neuron.
        axis : matplotlib Axes
            Axis to draw on.
        num : int
            1-based slot index on ``other`` (see :meth:`destination`).
        maxnum : int
            Total number of slots on ``other``.
        **kwargs
            Line properties forwarded to ``axis.plot``.
        """
        dest = self.destination(other, num, max_connections=maxnum)
        axis.plot([self.x, dest[0]], [dest[1], dest[1]], **kwargs)
        axis.plot([self.x, self.x], [self.y, dest[1]], **kwargs)

    def destination(self, other, num=1, max_connections=1):
        """Return the ``[x, y]`` end point of input slot ``num`` on ``other``.

        ``max_connections`` slots are spaced evenly, top to bottom, from
        ``other.y + other.height/3`` to ``other.y - other.height/3`` (the
        middle two thirds of the neuron). With a single slot it lies at the
        top of that range.

        Raises
        ------
        ValueError
            If ``num`` is not in ``1..max_connections``. Without this check,
            ``num=0`` would silently select the last slot via negative indexing.
        """
        if not 1 <= num <= max_connections:
            raise ValueError(f"num must be between 1 and {max_connections}, got {num}")
        half_span = other.height / 3  # == other.height / 2 * 2 / 3
        y_positions = np.linspace(other.y + half_span, other.y - half_span, max_connections)
        return [other.x, y_positions[num - 1]]
def multivariate_gaussian(pos, mu, sigma):
    """Evaluate the multivariate normal density N(mu, sigma) at the points ``pos``.

    Adapted from
    https://scipython.com/blog/visualizing-the-bivariate-gaussian-distribution/

    Parameters
    ----------
    pos : array_like, shape (..., k)
        Evaluation points; the last axis holds the k coordinates.
    mu : array_like, shape (k,)
        Mean vector.
    sigma : array_like, shape (k, k)
        Covariance matrix; must be invertible (``np.linalg.inv`` raises
        ``LinAlgError`` otherwise).

    Returns
    -------
    ndarray, shape (...)
        Density value at every point of ``pos``.
    """
    # Accept lists/tuples as well as arrays; ``mu.shape`` needs an ndarray.
    pos = np.asarray(pos)
    mu = np.asarray(mu)
    sigma = np.asarray(sigma)
    k = mu.shape[0]
    norm_const = np.sqrt((2 * np.pi) ** k * np.linalg.det(sigma))
    diff = pos - mu
    # Quadratic form (x-mu)^T sigma^-1 (x-mu), vectorised over all leading axes.
    fac = np.einsum('...k,kl,...l->...', diff, np.linalg.inv(sigma), diff)
    return np.exp(-fac / 2) / norm_const
def plot_neuron(x, y, w, h, axis):
    """Add a white, thinly outlined neuron ellipse to ``axis``.

    ``x``/``y`` give the centre and ``w``/``h`` the size in data units.
    Returns the created :class:`Neuron` so callers can wire it up.
    """
    cell = Neuron(x, y, w, h, facecolor="white", edgecolor="black", linewidth=0.5)
    axis.add_patch(cell.patch)
    return cell
def plot_receptivefield(width, num_cells, axis=None, xpos=0.0):
    """Sketch a receptive field, its input cells and their summing target neuron.

    A translucent blue ellipse (the receptive field) is centred at
    ``(xpos, 0)``. ``num_cells`` neurons are spread evenly across its width,
    and each is wired to one post-synaptic neuron, marked with a Sigma, at
    the lower left of the axis. The axis limits are fixed to [-2, 2] on both
    axes.

    Parameters
    ----------
    width : float
        Width of the receptive-field ellipse in data units (its height is width/2).
    num_cells : int
        Number of input neurons placed along the receptive field.
    axis : matplotlib Axes, optional
        Axis to draw into; defaults to the current axes (``plt.gca()``).
    xpos : float
        x position of the receptive-field centre.

    Returns
    -------
    (Ellipse, list of Neuron)
        The receptive-field patch and the input neurons, in the order of
        their x positions from ``xpos - width/2`` to ``xpos + width/2``.
    """
    if axis is None:
        axis = plt.gca()
    rf = Ellipse([xpos, 0.0], width=width, height=width / 2, clip_on=False,
                 facecolor="tab:blue", alpha=0.5)
    axis.add_patch(rf)
    axis.set_xlim([-2, 2])
    axis.set_ylim([-2, 2])
    xpositions = np.linspace(rf.center[0] - rf.width / 2, rf.center[0] + rf.width / 2, num_cells)
    # Both data ranges are [-2, 2], so scaling heights by the axes' display
    # aspect ratio (pixel width / pixel height) renders w == h neurons as circles.
    axis_ratio = axis.bbox.bounds[2] / axis.bbox.bounds[3]
    post_neuron = plot_neuron(-1.75, -1.1, 0.5, 0.5 * axis_ratio, axis)
    axis.text(-1.75, -1.2, r"$\Sigma$", fontsize=9, ha="center", va="center")
    neurons = []
    for slot, x_cell in enumerate(xpositions, start=1):
        cell = plot_neuron(x_cell, 0.0, 0.2, 0.2 * axis_ratio, axis)
        # Slots on the target neuron are 1-based, one per input cell.
        cell.connect(post_neuron, axis, slot, num_cells, color="gray", linewidth=0.5, zorder=0)
        neurons.append(cell)
    return rf, neurons
def plot_pdfs(axis, width):
    """Draw a sideways Gaussian density of standard deviation ``width`` on ``axis``.

    The density over y in [-5, 5] is plotted along x (so it stands vertically).
    The part at y <= 0 is shaded green and labelled "lead"; the part at y > 0
    is shaded red and labelled "lag". A dashed line marks y = 0, and the x
    tick labels are cleared.
    """
    support = np.linspace(-5, 5, 100)
    density = stats.norm.pdf(support, 0, width)
    axis.set_ylim([-6, 6])
    axis.plot(density, support)
    axis.axhline(ls="--", lw=0.5, color="black")
    axis.set_xticklabels([])

    below = support <= 0.0
    above = support > 0.0
    axis.fill_betweenx(support[below], np.zeros_like(support[below]), density[below],
                       color='tab:green', alpha=.2)
    axis.fill_betweenx(support[above], np.zeros_like(support[above]), density[above],
                       color='tab:red', alpha=.2)

    # Rotated labels, placed in axes-fraction coordinates.
    for y_frac, caption, colour in ((0.2, "lead", "tab:green"), (0.8, "lag", "tab:red")):
        axis.text(0.3, y_frac, caption, transform=axis.transAxes, fontsize=7, va="center",
                  rotation=-90, color=colour)
def plot_stims(axis, receptive_field, neurons):
    """Annotate a receptive-field sketch with the common stimulus s(t).

    Draws an "s(t)" label above the receptive field, a short vertical stem
    at its centre, a horizontal bar spanning the input neurons' x positions,
    and one downward arrow from that bar onto each neuron.

    Parameters
    ----------
    axis : matplotlib Axes
        Axis to draw on.
    receptive_field : object with ``center`` (x, y)
        Typically the Ellipse returned by ``plot_receptivefield``.
    neurons : iterable of objects with an ``x`` attribute
        Input cells receiving the stimulus.
    """
    center_x, center_y = receptive_field.center[0], receptive_field.center[1]
    bar_y = center_y + 1.25
    axis.text(center_x, center_y + 1.5, "s(t)",
              color="tab:red", fontsize=9, ha="center", va="bottom")
    cell_xs = [n.x for n in neurons]
    for cell_x in cell_xs:
        axis.arrow(cell_x, bar_y, 0.0, -0.4, width=0.005, ec="tab:red", lw=0.5,
                   head_width=10*0.005, head_length=25*0.005)
    # Span exactly the cells; seeding min/max with 0 would wrongly stretch the
    # bar to x = 0 whenever the receptive field does not straddle the origin.
    if cell_xs:
        axis.plot([min(cell_xs), max(cell_xs)], [bar_y] * 2, color="tab:red", lw=0.6)
    axis.plot([center_x] * 2, [center_y + 1.45, bar_y], color="tab:red", lw=0.6)
def plot_delay(axis, num_cells, cellxmin=0.05, cellxmax=3.05):
    """Plot delay as a linear function of distance to the target.

    The delay line has slope 0.25 and zero offset and is drawn from
    ``cellxmin`` to ``cellxmax``. A red slope triangle between 20 % and 60 %
    of the x range is labelled with Delta_dist and Delta_delay, together with
    the relation v_cond = Delta_dist / Delta_delay. ``num_cells`` evenly
    spaced open markers are placed on the line. The view is fixed to
    x in [0, 3.5] and y in [0, 1].

    Parameters
    ----------
    axis : matplotlib Axes
        Axis to draw into.
    num_cells : int
        Number of cell markers placed on the delay line.
    cellxmin, cellxmax : float
        Distance range covered by the line and the markers.
    """
    slope, offset = 0.25, 0.0
    x_lo, x_hi = 0.0, 3.5

    def delay(dist):
        return slope * dist + offset

    ends = np.array([cellxmin, cellxmax])

    axis.set_clip_on(False)
    for side in ("bottom", "left"):
        axis.spines[side].set_visible(True)
    axis.set_yticks(np.arange(0.0, 1.6, 0.5))
    axis.set_yticks(np.arange(0.0, 1.6, 0.25), minor=True)
    axis.set_yticklabels([0, "", "", ""], fontsize=tick_label_size)
    axis.set_xticks(np.arange(x_lo, x_hi + 0.1, 0.5))
    axis.set_xticks(np.arange(x_lo, x_hi + 0.1, 0.25), minor=True)
    # Only the first major x tick (the origin) is labelled.
    tick_texts = [""] * len(axis.get_xticks(minor=False))
    tick_texts[0] = "0"
    axis.set_xticklabels(tick_texts, fontsize=tick_label_size)

    axis.plot(ends, delay(ends), lw=1.0)
    axis.set_xlabel("distance to target", fontsize=label_size)
    axis.set_ylabel("delay", fontsize=label_size)
    axis.yaxis.set_label_coords(-0.15, 0.5)
    axis.set_ylim([0.0, 1.0])
    axis.set_xlim([x_lo, x_hi])

    # Slope triangle: vertical leg at 60 % of the range, horizontal leg from 20 % to 60 %.
    tri_left, tri_right = 0.2 * x_hi, 0.6 * x_hi
    axis.plot([tri_right] * 2, [delay(tri_left), delay(tri_right)], color="tab:red", lw=0.5)
    axis.plot([tri_left, tri_right], [delay(tri_left)] * 2, color="tab:red", lw=0.5)
    axis.text(0.4 * x_hi, delay(0.18 * x_hi), r"$\Delta_{dist.}$", fontsize=7, va="top", ha="center")
    axis.text(0.625 * x_hi, delay(0.4 * x_hi), r"$\Delta_{delay}$", fontsize=7, va="center", ha="left")
    axis.text(0.05, 0.9, r"$v_{cond.}=\Delta_{dist.}/\Delta_{delay}$", fontsize=7, ha="left")

    markers = np.linspace(ends[0], ends[1], num_cells)
    axis.scatter(markers, delay(markers), s=10, fc="white", ec="k")
def plot_delayrf_center(axis, num_cells, cellxmin, cellxmax, cellcenter, xmin=-2, xmax=2, setlabel=False):
    """Plot each cell's delay against its distance to the receptive-field centre.

    Distances are measured relative to ``cellcenter``, so the centre sits at
    x = 0 with zero delay and delay grows linearly with slope 0.25. The region
    left of 0 (negative delay) is shaded green and the region right of 0
    (positive delay) is shaded red. Dashed lines mark x = 0 and y = 0.

    Parameters
    ----------
    axis : matplotlib Axes
        Axis to draw into.
    num_cells : int
        Number of cells, evenly spaced from ``cellxmin`` to ``cellxmax``.
    cellxmin, cellxmax : float
        Absolute x positions of the outermost cells.
    cellcenter : float
        Absolute x position of the receptive-field centre; all plotted
        distances are relative to it.
    xmin, xmax : float
        x limits of the plot (relative distance).
    setlabel : bool
        If True, set the x-axis label.
    """
    slope = 0.25
    positions = np.linspace(cellxmin, cellxmax, num_cells)
    # Shift x and delay by the same centre, so (x, delay) stays on delay = slope * x.
    rel_positions = positions - cellcenter
    rel_ends = np.array([cellxmin, cellxmax]) - cellcenter
    delay_min = slope * rel_ends[0]
    delay_max = slope * rel_ends[-1]

    axis.spines["left"].set_visible(True)
    axis.spines["bottom"].set_visible(True)
    axis.axhline(color='black', lw=0.5, ls="--")
    axis.axvline(x=0.0, color='black', lw=0.5, ls="--")
    axis.plot(rel_ends, slope * rel_ends, lw=1.0)
    axis.scatter(rel_positions, slope * rel_positions, s=10, fc="white", ec="k")
    if setlabel:
        axis.set_xlabel("distance relative to receptive field center", fontsize=label_size)
    axis.set_ylim([-0.75, 0.75])
    axis.set_xlim([xmin, xmax])
    axis.set_yticks(np.arange(-0.5, 0.6, 0.5))
    axis.set_yticks(np.arange(-0.5, 0.6, 0.25), minor=True)
    axis.set_yticklabels(["", 0, ""], fontsize=tick_label_size)
    axis.fill_between([xmin, 0], [0.0] * 2, [delay_min] * 2, color="tab:green", alpha=0.2)
    axis.fill_between([0, xmax], [0.0] * 2, [delay_max] * 2, color="tab:red", alpha=0.2)
    axis.set_xticks(rel_positions)
    labels = [""] * num_cells
    # For an odd number of cells this is the centre cell.
    labels[num_cells // 2] = "0"
    axis.set_xticklabels(labels, fontsize=tick_label_size)
def layout_figure(num_rfs=3):
    """Create the figure and its grid of axes.

    The figure has ``num_rfs + 1`` columns. Column 0 holds panel "A" (a delay
    axis plus an unused pdf axis). Each further column holds a sketch axis
    (top left, panel "<X>_i"), a delay axis (bottom left, panel "<X>_ii") and
    a narrow pdf axis (bottom right), where X runs B, C, D, ...

    Parameters
    ----------
    num_rfs : int
        Number of receptive-field columns after panel A.

    Returns
    -------
    fig, sketch_axes, delay_axes, pdf_axes
        ``sketch_axes`` has ``num_rfs`` entries; ``delay_axes`` and
        ``pdf_axes`` have ``num_rfs + 1`` (index 0 belongs to panel A).
    """
    fig = plt.figure(figsize=(5.1, 2.5))
    sketch_axes = []
    delay_axes = []
    pdf_axes = []
    # Derive panel letters from num_rfs instead of fixed-length lists.
    rf_letters = [chr(ord("B") + i) for i in range(num_rfs)]
    sketch_labels = [letter + "$_i$" for letter in rf_letters]
    delay_labels = ["A"] + [letter + "$_{ii}$" for letter in rf_letters]
    gs = gridspec.GridSpec(1, num_rfs + 1, left=0.05, right=0.975, top=0.95, bottom=0.15,
                           wspace=0.15, figure=fig)
    for i in range(num_rfs + 1):
        # 2x2 sub-grid: wide left column (sketch above delay), narrow right column (pdf).
        sgs = gridspec.GridSpecFromSubplotSpec(2, 2, wspace=0.0, subplot_spec=gs[0, i],
                                               width_ratios=[4, 1], height_ratios=[1, 2])
        if i > 0:
            axis = fig.add_subplot(sgs[0, 0])
            despine(axis, ["top", "left", "right", "bottom"])
            axis.text(-0.3, 1., sketch_labels[i - 1], transform=axis.transAxes, ha="left",
                      va="top", fontsize=subfig_labelsize, fontweight=subfig_labelweight)
            sketch_axes.append(axis)
        axis = fig.add_subplot(sgs[1, 0])
        despine(axis, ["top", "right"])
        axis.text(-0.3, 1.15, delay_labels[i], transform=axis.transAxes, ha="left",
                  va="top", fontsize=subfig_labelsize, fontweight=subfig_labelweight)
        delay_axes.append(axis)
        axis = fig.add_subplot(sgs[1, 1])
        despine(axis, ["left", "top", "right", "bottom"])
        pdf_axes.append(axis)
    return fig, sketch_axes, delay_axes, pdf_axes
def introductory_figure(args):
    """Build the introductory figure, then show it or save it to ``args.outfile``.

    Three receptive fields (widths 1, 2 and 3 with 3, 5 and 9 cells) are
    sketched with their stimulus, relative delay plot and a Gaussian pdf
    (sigma 0.5, 0.9, 1.3). Panel A shows the delay-versus-distance relation
    for the largest cell count.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``nosave`` (bool) and ``outfile`` (path), as registered by
        ``command_line_parser``.
    """
    rfs = [1, 2., 3]
    num_cells = [3, 5, 9]
    sigmas = [0.5, 0.9, 1.3]
    fig, sketch_axes, delay_axes, pdf_axes = layout_figure(num_rfs=len(rfs))
    for i, (rf, nc, sigma) in enumerate(zip(rfs, num_cells, sigmas)):
        receptivefield, neurons = plot_receptivefield(rf, nc, axis=sketch_axes[i])
        plot_stims(sketch_axes[i], receptivefield, neurons)
        # delay_axes[0] is panel A, so RF column i uses delay_axes[i + 1];
        # only the middle column gets the x-axis label.
        plot_delayrf_center(delay_axes[i + 1], nc, cellxmin=neurons[0].x, cellxmax=neurons[-1].x,
                            cellcenter=neurons[nc // 2].x, setlabel=(i == 1))
        plot_pdfs(pdf_axes[i + 1], sigma)
    plot_delay(delay_axes[0], num_cells[-1])
    if args.nosave:
        plt.show()
    else:
        # The default outfile lives in "figures/", which may not exist yet.
        out_dir = os.path.dirname(args.outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(args.outfile, dpi=500)
    plt.close(fig)
def command_line_parser(subparsers):
    """Register the ``intro_figure2`` sub-command on an argparse ``subparsers`` object.

    Options: ``-o/--outfile`` (output path, default ``figures/intro_figure.pdf``)
    and ``-n/--nosave`` (show the figure instead of saving it). The sub-command
    dispatches to ``introductory_figure`` through the ``func`` default.
    """
    cmd = subparsers.add_parser("intro_figure2", help="Introductory figure (figure 1).")
    option_specs = (
        (("-o", "--outfile"),
         dict(type=str, default=os.path.join("figures", "intro_figure.pdf"),
              help="The filename of the figure")),
        (("-n", "--nosave"),
         dict(action='store_true', help="no saving of the figure, just showing")),
    )
    for flags, spec in option_specs:
        cmd.add_argument(*flags, **spec)
    cmd.set_defaults(func=introductory_figure)
|