###############################################################################
## introductory plot ##
import os

import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import Ellipse

from .figure_style import subfig_labelsize, subfig_labelweight, despine, label_size, tick_label_size


class Neuron:
    """A neuron drawn as an ellipse that can be wired to other neurons.

    Parameters
    ----------
    x, y : float
        Center of the ellipse in data coordinates.
    w, h : float
        Width and height of the ellipse in data coordinates.
    **kwargs
        Forwarded to ``matplotlib.patches.Ellipse``. A ``text`` entry is not
        passed on; it is dropped and a notice is printed.
    """

    def __init__(self, x, y, w, h, **kwargs) -> None:
        if "text" in kwargs:
            # Ellipse does not take a "text" keyword, so discard it.
            del kwargs["text"]
            print("Text found!")
        self._ellipse = Ellipse([x, y], w, h, **kwargs)
        self._height = h

    @property
    def patch(self):
        """The underlying ``Ellipse`` patch (add it to an axis to draw it)."""
        return self._ellipse

    @property
    def x(self):
        """x coordinate of the ellipse center."""
        return self._ellipse.center[0]

    @property
    def y(self):
        """y coordinate of the ellipse center."""
        return self._ellipse.center[1]

    @property
    def height(self):
        """Ellipse height as given to the constructor."""
        return self._height

    def connect(self, other, axis, num=1, maxnum=1, **kwargs):
        """Draw an L-shaped connection from this neuron to ``other``.

        A vertical segment runs from this neuron's center to the height of
        the chosen contact point, and a horizontal segment runs from there
        to ``other``'s center x.

        Parameters
        ----------
        other : Neuron
            Target neuron.
        axis : matplotlib.axes.Axes
            Axis to draw into.
        num : int
            1-based index of this connection among ``maxnum`` connections.
        maxnum : int
            Total number of connections converging on ``other``.
        **kwargs
            Forwarded to ``axis.plot`` (e.g. color, linewidth, zorder).
        """
        dest_x, dest_y = self.destination(other, num, max_connections=maxnum)
        axis.plot([self.x, dest_x], [dest_y, dest_y], **kwargs)
        axis.plot([self.x, self.x], [self.y, dest_y], **kwargs)

    def destination(self, other, num=1, max_connections=1):
        """Return the ``[x, y]`` contact point of connection ``num`` on ``other``.

        Contact points are spread evenly, top to bottom, over the central two
        thirds of ``other``'s height; ``num`` is 1-based. A single connection
        lands on ``other``'s center.
        """
        if max_connections <= 1:
            # np.linspace(top, bottom, 1) would yield only the top point.
            return [other.x, other.y]
        half_span = other.height / 2 * 2 / 3
        y_positions = np.linspace(other.y + half_span, other.y - half_span, max_connections)
        return [other.x, y_positions[num - 1]]


def multivariate_gaussian(pos, mu, sigma):
    """Evaluate the multivariate normal density N(mu, sigma) at the points ``pos``.

    ``pos`` is an array whose last axis has length ``n = mu.shape[0]``; the
    result has ``pos``'s shape without that last axis.

    Adapted from
    https://scipython.com/blog/visualizing-the-bivariate-gaussian-distribution/
    """
    n = mu.shape[0]
    sigma_det = np.linalg.det(sigma)
    sigma_inv = np.linalg.inv(sigma)
    N = np.sqrt((2 * np.pi) ** n * sigma_det)
    # This einsum call calculates (x-mu)T.sigma-1.(x-mu) in a vectorized
    # way across all the input variables.
    fac = np.einsum('...k,kl,...l->...', pos - mu, sigma_inv, pos - mu)
    return np.exp(-fac / 2) / N


def plot_neuron(x, y, w, h, axis):
    """Draw a white, black-edged neuron ellipse on ``axis`` and return the Neuron."""
    neuron = Neuron(x, y, w, h, facecolor="white", edgecolor="black", linewidth=0.5)
    axis.add_patch(neuron.patch)
    return neuron


def plot_receptivefield(width, num_cells, axis=None, xpos=0.0):
    """Sketch a receptive field, the cells tiling it, and their summing target.

    The field is an unclipped blue ellipse centered at ``(xpos, 0)`` with the
    given ``width`` and half that height. ``num_cells`` neurons are placed
    evenly across its horizontal extent at y = 0. Each is wired to a summing
    neuron (labelled Sigma) at (-1.75, -1.1). Axis limits are fixed to
    [-2, 2] in both directions.

    Parameters
    ----------
    width : float
        Receptive-field width in data units.
    num_cells : int
        Number of neurons tiling the field.
    axis : matplotlib.axes.Axes, optional
        Target axis; the current axes (``plt.gca()``) when omitted.
    xpos : float
        x position of the receptive-field center.

    Returns
    -------
    rf : matplotlib.patches.Ellipse
        The receptive-field patch.
    neurons : list of Neuron
        The tiling neurons, ordered left to right.
    """
    if axis is None:
        axis = plt.gca()
    rf = Ellipse([xpos, 0.0], width=width, height=width / 2, clip_on=False,
                 facecolor="tab:blue", alpha=0.5)
    axis.add_patch(rf)
    axis.set_xlim([-2, 2])
    axis.set_ylim([-2, 2])

    xpositions = np.linspace(rf.center[0] - rf.width / 2, rf.center[0] + rf.width / 2, num_cells)
    # On-screen width/height ratio of the axis. Both limits span the same data
    # range, so scaling ellipse heights by it makes the neurons look circular.
    axis_ratio = axis.bbox.bounds[2] / axis.bbox.bounds[3]
    post_neuron = Neuron(-1.75, -1.1, 0.5, 0.5 * axis_ratio,
                         facecolor="white", edgecolor="black", linewidth=0.5)
    axis.add_patch(post_neuron.patch)
    axis.text(-1.75, -1.2, r"$\Sigma$", fontsize=9, ha="center", va="center")

    neurons = []
    for index, xposition in enumerate(xpositions, start=1):
        neuron = plot_neuron(xposition, 0.0, 0.2, 0.2 * axis_ratio, axis)
        neuron.connect(post_neuron, axis, index, num_cells, color="gray", linewidth=.5, zorder=0)
        neurons.append(neuron)
    return rf, neurons


def plot_pdfs(axis, width):
    """Plot a zero-mean normal pdf with std. dev. ``width`` sideways (along y).

    The density is evaluated on y in [-5, 5]. The y <= 0 half is shaded green
    and labelled "lead"; the y > 0 half is shaded red and labelled "lag".
    """
    y = np.linspace(-5, 5, 100)
    lead = y <= 0.0
    lag = y > 0.0
    axis.set_ylim([-6, 6])
    axis.plot(stats.norm.pdf(y, 0, width), y)
    axis.axhline(ls="--", lw=0.5, color="black")
    axis.set_xticklabels([])
    axis.fill_betweenx(y[lead], np.zeros_like(y[lead]), stats.norm.pdf(y[lead], 0, width),
                       color='tab:green', alpha=.2)
    axis.fill_betweenx(y[lag], np.zeros_like(y[lag]), stats.norm.pdf(y[lag], 0, width),
                       color='tab:red', alpha=.2)
    axis.text(0.3, 0.2, "lead", transform=axis.transAxes, fontsize=7, va="center",
              rotation=-90, color="tab:green")
    axis.text(0.3, 0.8, "lag", transform=axis.transAxes, fontsize=7, va="center",
              rotation=-90, color="tab:red")


def plot_stims(axis, receptive_field, neurons):
    """Sketch the stimulus s(t) branching from above onto every neuron.

    A short stem below the "s(t)" label meets a horizontal bar spanning all
    neurons. From the bar, one arrow points down toward each neuron.
    """
    center_x, center_y = receptive_field.center[0], receptive_field.center[1]
    bar_y = center_y + 1.25
    axis.text(center_x, center_y + 1.5, "s(t)", color="tab:red", fontsize=9,
              ha="center", va="bottom")
    for n in neurons:
        axis.arrow(n.x, bar_y, 0.0, -0.4, width=0.005, ec="tab:red", lw=0.5,
                   head_width=10 * 0.005, head_length=25 * 0.005)
    # The bar covers exactly the neuron positions. Do not seed the extent
    # with 0: the receptive field need not be centered on the origin.
    xs = [n.x for n in neurons]
    if xs:
        axis.plot([min(xs), max(xs)], [bar_y] * 2, color="tab:red", lw=0.6)
    axis.plot([center_x] * 2, [center_y + 1.45, bar_y], color="tab:red", lw=0.6)


def plot_delay(axis, num_cells, cellxmin=0.05, cellxmax=3.05):
    """Plot the linear relation between distance to target and delay (panel A).

    ``num_cells`` cells are drawn evenly between ``cellxmin`` and
    ``cellxmax`` on the line delay = 0.25 * distance. A slope triangle
    illustrates v_cond = Delta_dist / Delta_delay.
    """
    xmin = 0.0
    xmax = 3.5
    x = np.array([cellxmin, cellxmax])
    n = 0.0   # intercept: delay at zero distance
    m = 0.25  # slope: delay per unit distance
    axis.set_clip_on(False)
    axis.spines["bottom"].set_visible(True)
    axis.spines["left"].set_visible(True)
    axis.set_yticks(np.arange(0.0, 1.6, 0.5))
    axis.set_yticks(np.arange(0.0, 1.6, 0.25), minor=True)
    axis.set_yticklabels([0, "", "", ""], fontsize=tick_label_size)
    axis.set_xticks(np.arange(xmin, xmax + 0.1, 0.5))
    axis.set_xticks(np.arange(xmin, xmax + 0.1, 0.25), minor=True)
    labels = [""] * len(axis.get_xticks(minor=False))
    labels[0] = "0"
    axis.set_xticklabels(labels, fontsize=tick_label_size)
    axis.plot(x, m * x + n, lw=1.0)
    axis.set_xlabel("distance to target", fontsize=label_size)
    axis.set_ylabel("delay", fontsize=label_size)
    axis.yaxis.set_label_coords(-0.15, 0.5)
    axis.set_ylim([0.0, 1.0])
    axis.set_xlim([xmin, xmax])

    # Slope triangle between 20% and 60% of the x range.
    axis.plot([0.6 * xmax] * 2, [0.2 * xmax * m + n, 0.6 * xmax * m + n], color="tab:red", lw=0.5)
    axis.plot([0.2 * xmax, 0.6 * xmax], [0.2 * xmax * m + n] * 2, color="tab:red", lw=0.5)
    axis.text(0.4 * xmax, 0.18 * xmax * m + n, r"$\Delta_{dist.}$", fontsize=7, va="top", ha="center")
    axis.text(0.625 * xmax, 0.4 * xmax * m + n, r"$\Delta_{delay}$", fontsize=7, va="center", ha="left")
    axis.text(0.05, 0.9, r"$v_{cond.}=\Delta_{dist.}/\Delta_{delay}$", fontsize=7, ha="left")

    positions = np.linspace(x[0], x[1], num_cells)
    axis.scatter(positions, positions * m + n, s=10, fc="white", ec="k")


def plot_delayrf_center(axis, num_cells, cellxmin, cellxmax, cellcenter, xmin=-2, xmax=2,
                        setlabel=False):
    """Plot each cell's delay relative to the receptive-field center.

    Cells are spaced evenly between ``cellxmin`` and ``cellxmax``. They are
    shifted so that ``cellcenter`` sits at 0 on the x axis, and their delay
    follows the same slope (0.25) as in ``plot_delay``, so the center has
    zero delay. Negative delays ("lead") are shaded green and positive
    delays ("lag") red.

    Parameters
    ----------
    axis : matplotlib.axes.Axes
        Target axis.
    num_cells : int
        Number of cells (>= 1).
    cellxmin, cellxmax : float
        Absolute x positions of the outermost cells.
    cellcenter : float
        Absolute x position of the receptive-field center (the reference point).
    xmin, xmax : float
        x limits of the relative-distance axis.
    setlabel : bool
        Whether to draw the x-axis label.
    """
    m = 0.25  # delay per unit distance, as in plot_delay
    # Positions relative to the receptive-field center.
    positions = np.linspace(cellxmin, cellxmax, num_cells) - cellcenter
    x = np.array([cellxmin, cellxmax]) - cellcenter
    ymin = m * x[0]
    ymax = m * x[-1]

    axis.spines["left"].set_visible(True)
    axis.spines["bottom"].set_visible(True)
    axis.axhline(color='black', lw=0.5, ls="--")
    axis.axvline(x=0.0, color='black', lw=0.5, ls="--")
    axis.plot(x, m * x, lw=1.0)
    axis.scatter(positions, positions * m, s=10, fc="white", ec="k")
    if setlabel:
        axis.set_xlabel("distance relative to receptive field center", fontsize=label_size)
    axis.set_ylim([-0.75, 0.75])
    axis.set_xlim([xmin, xmax])
    axis.set_yticks(np.arange(-0.5, 0.6, 0.5))
    axis.set_yticks(np.arange(-0.5, 0.6, 0.25), minor=True)
    axis.set_yticklabels(["", 0, ""], fontsize=tick_label_size)
    axis.fill_between([xmin, 0], [0.0] * 2, [ymin] * 2, color="tab:green", alpha=0.2)
    axis.fill_between([0, xmax], [0.0] * 2, [ymax] * 2, color="tab:red", alpha=0.2)

    # One tick per cell; only the cell closest to the center is labelled "0".
    axis.set_xticks(positions)
    labels = [""] * num_cells
    labels[int(np.argmin(np.abs(positions)))] = "0"
    axis.set_xticklabels(labels, fontsize=tick_label_size)


def layout_figure(num_rfs=3):
    """Create the figure and its axes grid.

    The figure has ``num_rfs + 1`` columns. Every column gets a delay axis
    (bottom left) and a narrow pdf axis (bottom right). The first column's
    delay axis is panel "A". Each further column additionally gets a sketch
    axis (top left). These columns are labelled B, C, D, ... with subscript
    i (sketch) and ii (delay).

    Parameters
    ----------
    num_rfs : int
        Number of receptive-field columns to the right of panel A.

    Returns
    -------
    fig : matplotlib.figure.Figure
    sketch_axes : list of Axes
        ``num_rfs`` entries, one per receptive-field column.
    delay_axes : list of Axes
        ``num_rfs + 1`` entries; index 0 is panel A.
    pdf_axes : list of Axes
        ``num_rfs + 1`` entries; index 0 sits in the panel-A column.
    """
    fig = plt.figure(figsize=(5.1, 2.5))
    sketch_axes = []
    delay_axes = []
    pdf_axes = []
    letters = [chr(ord("B") + i) for i in range(num_rfs)]
    sketch_labels = [letter + r"$_i$" for letter in letters]
    delay_labels = ["A"] + [letter + r"$_{ii}$" for letter in letters]
    gs = gridspec.GridSpec(1, num_rfs + 1, left=0.05, right=0.975, top=0.95, bottom=0.15,
                           wspace=0.15, figure=fig)
    for i in range(num_rfs + 1):
        # 2x2 sub-grid per column: wide left / narrow right, short top / tall bottom.
        sgs = gridspec.GridSpecFromSubplotSpec(2, 2, wspace=0.0, subplot_spec=gs[0, i],
                                               width_ratios=[4, 1], height_ratios=[1, 2])
        if i > 0:
            axis = fig.add_subplot(sgs[0, 0])
            despine(axis, ["top", "left", "right", "bottom"])
            axis.text(-0.3, 1., sketch_labels[i - 1], transform=axis.transAxes, ha="left",
                      va="top", fontsize=subfig_labelsize, fontweight=subfig_labelweight)
            sketch_axes.append(axis)

        axis = fig.add_subplot(sgs[1, 0])
        despine(axis, ["top", "right"])
        axis.text(-0.3, 1.15, delay_labels[i], transform=axis.transAxes, ha="left", va="top",
                  fontsize=subfig_labelsize, fontweight=subfig_labelweight)
        delay_axes.append(axis)

        axis = fig.add_subplot(sgs[1, 1])
        despine(axis, ["left", "top", "right", "bottom"])
        pdf_axes.append(axis)
    return fig, sketch_axes, delay_axes, pdf_axes


def introductory_figure(args):
    """Build the introductory figure, then show or save it.

    ``args`` is the parsed command line from ``command_line_parser``. When
    ``args.nosave`` is set, the figure is shown interactively; otherwise it is
    written to ``args.outfile`` at 500 dpi.
    """
    rfs = [1, 2., 3]          # receptive-field widths, one per column
    num_cells = [3, 5, 9]     # cells tiling each receptive field
    sigmas = [0.5, 0.9, 1.3]  # std. dev. of the lead/lag pdf per column
    fig, sketch_axes, delay_axes, pdf_axes = layout_figure(num_rfs=len(rfs))

    for i, (rf, nc) in enumerate(zip(rfs, num_cells)):
        receptivefield, neurons = plot_receptivefield(rf, nc, axis=sketch_axes[i])
        plot_stims(sketch_axes[i], receptivefield, neurons)
        # Only the second receptive-field column carries the x-axis label.
        plot_delayrf_center(delay_axes[i + 1], nc, cellxmin=neurons[0].x, cellxmax=neurons[-1].x,
                            cellcenter=neurons[nc // 2].x, setlabel=(i == 1))
        plot_pdfs(pdf_axes[i + 1], sigmas[i])
    plot_delay(delay_axes[0], num_cells[-1])

    if args.nosave:
        plt.show()
    else:
        fig.savefig(args.outfile, dpi=500)
    plt.close(fig)


def command_line_parser(subparsers):
    """Register the ``intro_figure2`` sub-command.

    ``subparsers`` is expected to be the object returned by argparse's
    ``ArgumentParser.add_subparsers()``. The sub-command dispatches to
    ``introductory_figure``.
    """
    parser = subparsers.add_parser("intro_figure2", help="Introductory figure (figure 1).")
    parser.add_argument("-o", "--outfile", type=str,
                        default=os.path.join("figures", "intro_figure.pdf"),
                        help="The filename of the figure")
    parser.add_argument("-n", "--nosave", action='store_true',
                        help="no saving of the figure, just showing")
    parser.set_defaults(func=introductory_figure)