#####################################################################################################
## methods plot with baseline and driven activity of an example cell
##
import argparse
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import ArrowStyle

from .figure_style import subfig_labelsize, subfig_labelweight, legend_fontsize, despine, label_size
from .figure1_analysis import analyze_baseline_activity, analyze_driven_activity, get_firing_rate

fig1_help = ("Plots a methods figure with baseline activity and stimulus driven activity. "
             "Depends on figure1_baseline_data.npz and figure1_whitenoise_data.npz files. "
             "These are expected in the ./derived_data folder")

# Locations of the pre-computed analysis results read by the plotting functions.
_BASELINE_FILE = os.path.join("derived_data", "figure1_baseline_data.npz")
_WHITENOISE_FILE = os.path.join("derived_data", "figure1_whitenoise_data.npz")


def plot_isi(spikes, cv, burstiness, axis):
    """Draw the interspike-interval histogram and annotate CV and burstiness.

    Parameters
    ----------
    spikes : array-like
        Spike times; consecutive differences are histogrammed. The tick
        positions (0 .. 0.015) are relabelled 0 .. 15 under an "ISI [ms]"
        label, so the times are presumably given in seconds.
    cv : float
        Value printed as CV of the ISI distribution.
    burstiness : float
        Value printed as burstiness.
    axis : matplotlib axes
        Axes to draw into.
    """
    axis.hist(np.diff(spikes), bins=np.arange(0.0, 0.015, 0.0002), density=True, alpha=0.75)

    # Text positions are derived from the limits the histogram produced.
    text_x = 0.95 * axis.get_xlim()[1]
    y_top = axis.get_ylim()[1]
    axis.text(text_x, 0.85 * y_top, r"$CV_{ISI}$: %.2f" % (cv), fontsize=legend_fontsize,
              color="tab:blue", ha="right", va="bottom")
    axis.text(text_x, 1.0 * y_top, "Burstiness: %.2f" % (burstiness), fontsize=legend_fontsize,
              color="tab:blue", ha="right", va="bottom")

    despine(axis)
    axis.spines["bottom"].set_visible(True)
    axis.xaxis.set_ticks(np.arange(0.0, 0.016, 0.001), minor=True)
    axis.xaxis.set_ticks(np.arange(0.0, 0.016, 0.005))
    axis.xaxis.set_ticklabels(np.arange(0, 16, 5))
    axis.set_xlabel("ISI [ms]", fontsize=label_size)
    axis.xaxis.set_label_coords(0.5, -0.225)


def plot_phase_locking(eod_waveform, eod_error, spike_phases, waveform_axis, vs, preferred_phase,
                       eodf, ylim):
    """Plot the average EOD cycle together with the spike phase histogram.

    Parameters
    ----------
    eod_waveform, eod_error : ndarray
        Average EOD cycle and its error band; plotted over phases 0 .. 2 pi.
    spike_phases : array-like
        Spike phases in radians, histogrammed on a twin y-axis.
    waveform_axis : matplotlib axes
        Axes for the waveform; a twin axes is created for the histogram.
    vs : float or complex
        Vector strength; its absolute value is printed.
    preferred_phase : float
        Preferred firing phase in radians; marked with a bar starting at 0.
    eodf : float
        EOD frequency, used to convert the preferred phase into a time delay
        (printed in ms, so presumably given in Hz).
    ylim : sequence of float
        Reference y-limits for the waveform axes; the upper limit is
        stretched by 25 % to make room for the annotations.
    """
    # A phase phi corresponds to the fraction phi / (2 pi) of one EOD period
    # (1 / eodf). The previous expression, 2 pi / phi / eodf, inverted that
    # ratio and divided by zero for a preferred phase of 0.
    delay = preferred_phase / (2 * np.pi * eodf)

    phase = np.linspace(0.0, 2 * np.pi, len(eod_waveform))
    waveform_axis.plot(phase, eod_waveform, color="tab:orange", lw=0.75, zorder=2)
    waveform_axis.fill_between(phase, eod_waveform + eod_error, eod_waveform - eod_error,
                               color="tab:orange", alpha=0.25, zorder=1, lw=0.0)

    spike_phase_ax = waveform_axis.twinx()
    spike_phase_ax.hist(spike_phases, bins=np.arange(0.0, 2 * np.pi + 0.01, np.pi / 12),
                        density=True, alpha=0.5, color="tab:blue", edgecolor="steelblue")

    # Positioned relative to the autoscaled limits, i.e. before set_ylim below.
    waveform_axis.text(0.95 * waveform_axis.get_xlim()[1], 1.25 * waveform_axis.get_ylim()[1],
                       "VS: %.2f" % np.abs(vs), va="bottom", ha="right", color="tab:blue",
                       fontsize=legend_fontsize)
    waveform_axis.set_ylim([ylim[0], 1.25 * ylim[-1]])
    despine(waveform_axis, ["left", "top", "right", "bottom"], True)

    # Bar from phase 0 to the preferred phase, labelled with the matching delay.
    waveform_peak = np.max(eod_waveform)
    arrow = ArrowStyle.BarAB(widthA=0.15, widthB=.15)
    waveform_axis.annotate(text='', xy=(0, 1.05 * waveform_peak),
                           xytext=(preferred_phase, 1.05 * waveform_peak),
                           arrowprops=dict(arrowstyle=arrow, color="steelblue"))
    waveform_axis.text(preferred_phase / 2, 1.1 * waveform_peak, f"Delay: {delay*1000:.2f}ms",
                       va="bottom", ha="center", color="tab:blue", fontsize=legend_fontsize)

    despine(spike_phase_ax, ["left", "top", "right"], True)
    spike_phase_ax.spines["bottom"].set_visible(True)
    spike_phase_ax.set_xticks(np.arange(0., 2 * np.pi + 0.1, np.pi / 2))
    spike_phase_ax.set_xticks(np.arange(0., 2 * np.pi + 0.1, np.pi / 4), minor=True)
    spike_phase_ax.set_xticklabels([r"0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$",
                                    r"$2\pi$"])
    waveform_axis.set_xlabel("phase [rad.]", fontsize=label_size)
    waveform_axis.xaxis.set_label_coords(0.5, -0.225)

    # Leave headroom above the histogram as well.
    hist_ylim = spike_phase_ax.get_ylim()
    spike_phase_ax.set_ylim([hist_ylim[0], 1.25 * hist_ylim[-1]])


def plot_eod(time, eod, eodf, axis, start_idx, end_idx):
    """Plot a snippet of the EOD trace with an amplitude scale bar.

    Parameters
    ----------
    time, eod : ndarray
        Time axis and EOD trace; samples ``start_idx:end_idx`` are shown with
        the time axis shifted to start at zero.
    eodf : float
        EOD frequency printed in the panel.
    axis : matplotlib axes
        Axes to draw into.
    start_idx, end_idx : int
        Sample range of the snippet. ``time[end_idx]`` is read for the x-limit,
        so ``end_idx`` must be a valid index.
    """
    axis.plot(time[start_idx:end_idx] - time[start_idx], eod[start_idx:end_idx],
              color="tab:orange", lw=0.5, label="EOD")
    ylim = axis.get_ylim()
    # Vertical scale bar of length 2 (labelled "2 mV") left of the trace.
    axis.plot([-0.0005, -0.0005], [ylim[0], ylim[0] + 2.], color="k", lw=0.5)
    axis.text(-.0025, ylim[0] + 1, r"2 mV", rotation=90, va="center", fontsize=legend_fontsize)
    axis.text(0.95 * axis.get_xlim()[1], 1.25 * ylim[1], "EOD frequency: %i Hz" % (eodf),
              color="tab:orange", fontsize=legend_fontsize, ha="right", va="top")
    axis.set_xlim([-0.001, time[end_idx] - time[start_idx]])
    despine(axis)


def plot_voltage_spikes(time, voltage, baseline_rate, v_axis, start_idx, end_idx):
    """Plot a snippet of the recorded voltage with time and amplitude scale bars.

    Parameters
    ----------
    time, voltage : ndarray
        Time axis and voltage trace; samples ``start_idx:end_idx`` are shown
        with the time axis shifted to start at zero.
    baseline_rate : float
        Baseline firing rate printed in the panel.
    v_axis : matplotlib axes
        Axes to draw into.
    start_idx, end_idx : int
        Sample range of the snippet. ``time[end_idx]`` is read for the x-limit,
        so ``end_idx`` must be a valid index.
    """
    v_axis.plot(time[start_idx:end_idx] - time[start_idx], voltage[start_idx:end_idx],
                color="tab:blue", lw=0.9, label="voltage")
    ylim = v_axis.get_ylim()
    # Horizontal (labelled "10 ms") and vertical (labelled "5 mV") scale bars.
    v_axis.plot([0.0, 0.01], [ylim[0], ylim[0]], color="k")
    v_axis.plot([-0.0005, -0.0005], [ylim[0], ylim[0] + 5], color="k", lw=0.5)
    v_axis.text(-.0025, ylim[0] + 2.5, r"5 mV", rotation=90, va="center",
                fontsize=legend_fontsize)
    v_axis.text(0.005, ylim[0] - 1.25, "10 ms", ha="center", va="top", fontsize=legend_fontsize)
    v_axis.text(0.95 * v_axis.get_xlim()[1], 1. * ylim[1],
                "Baseline firing rate: %i Hz" % (baseline_rate), color="tab:blue",
                fontsize=legend_fontsize, ha="right", va="bottom")
    v_axis.set_xlim([-0.001, time[end_idx] - time[start_idx]])
    # Restore the limits of the trace itself; the scale bars must not rescale it.
    v_axis.set_ylim(ylim)
    despine(v_axis)


def plot_baseline_activity(eod_axis, response_axis, vs_axis, isi_axis):
    """Fill the baseline panels from ``derived_data/figure1_baseline_data.npz``.

    Parameters
    ----------
    eod_axis, response_axis, vs_axis, isi_axis : matplotlib axes
        Axes for the EOD snippet, the voltage snippet, the phase-locking
        panel and the ISI histogram.
    """
    # The context manager closes the npz archive once all arrays are read.
    with np.load(_BASELINE_FILE) as data:
        time = data["time"]
        eodf = data["eodf"]
        plot_eod(time, data["eod"], eodf, eod_axis, 950, 1575)
        plot_voltage_spikes(time, data["voltage"], data["baseline_rate"], response_axis,
                            950, 1575)
        plot_isi(data["spike_times"], data["cv"], data["burstiness"], isi_axis)
        plot_phase_locking(data["eod_template"], data["template_error"], data["spike_phases"],
                           vs_axis, data["vector_strength"], data["preferred_phase"], eodf,
                           eod_axis.get_ylim())


def plot_whitenoise_activity(stimulus_axis, response_axis, gain_axis, coherence_axis):
    """Fill the driven-activity panels from ``derived_data/figure1_whitenoise_data.npz``.

    Parameters
    ----------
    stimulus_axis, response_axis, gain_axis, coherence_axis : matplotlib axes
        Axes for the stimulus snippet, the firing rate with spike raster,
        the gain spectrum and the coherence spectrum.
    """
    with np.load(_WHITENOISE_FILE) as data:
        plot_white_noise_stim(data["stimulus"], stimulus_axis, 1.0, 0.25)

        # Read each array once: every access to an npz member re-reads the archive.
        all_spikes = data["spikes"]
        trial_ids = data["trial_data"]
        spikes = [all_spikes[trial_ids == trial] for trial in np.unique(trial_ids)]
        plot_white_noise_response(spikes, response_axis, 0.005, 1.0, 0.25)

        f = data["frequency"]
        gain = data["gain_spectra"]
        coh = data["coherence_spectra"]

    plot_spectrum(f, gain, gain_axis, r"gain [Hz mV$^{-1}$]", False, True)
    gain_axis.set_xlabel("")
    plot_spectrum(f, coh, coherence_axis, r"coherence", False, False)
    # Shift the x-label of the coherence panel to the left; the gain panel's
    # own x-label was cleared above.
    coherence_axis.xaxis.set_label_coords(-0.25, -0.23)


def plot_white_noise_response(spikes, resp_axis, sigma, start_time, extent):
    """Plot the trial-averaged firing rate and a spike raster for a time window.

    Parameters
    ----------
    spikes : list of ndarray
        Spike times, one array per trial.
    resp_axis : matplotlib axes
        Axes for the firing rate; a twin axes is created for the raster.
    sigma : float
        Passed on to ``get_firing_rate`` (presumably the kernel width).
    start_time, extent : float
        Start and duration of the displayed window, in the unit of the spike
        times.
    """
    spikes_axis = resp_axis.twinx()
    time, responses = get_firing_rate(spikes, 10., 1./20000., sigma)
    avg_response = np.mean(responses, axis=0)
    std_responses = np.std(responses, axis=0)
    error_plus = avg_response + std_responses
    error_minus = avg_response - std_responses

    window = (time >= start_time) & (time < start_time + extent)
    window_time = time[window] - start_time
    resp_axis.plot(window_time, avg_response[window], color='tab:blue', lw=0.75)
    resp_axis.fill_between(window_time, error_plus[window], error_minus[window],
                           color="tab:blue", alpha=0.5, lw=0)
    # Upper limit: the largest displayed value rounded up to a multiple of 50.
    ylim = [0, np.ceil(np.max(error_plus[window]) / 50) * 50]
    resp_axis.set_ylim(ylim)

    # Scale bars, labelled "100 ms" and "200 Hz".
    resp_axis.plot([0.0, 0.1], [resp_axis.get_ylim()[0] - 0.5] * 2, color="k", lw=0.5)
    resp_axis.text(0.05, -125, r"100 ms", va="bottom", ha="center", fontsize=legend_fontsize)
    resp_axis.plot([-0.005, -0.005], [0, 200], color="k", lw=1, clip_on=False)
    resp_axis.text(-.02, -0.3, "200 Hz", fontsize=legend_fontsize, rotation=90)
    resp_axis.set_xlim([-0.0, extent])
    despine(resp_axis)

    # Raster: one row of tick marks per trial.
    for trial, trial_spikes in enumerate(spikes):
        in_window = (trial_spikes >= start_time) & (trial_spikes < start_time + extent)
        spike_times = trial_spikes[in_window] - start_time
        spikes_axis.scatter(spike_times, np.ones(len(spike_times)) * trial, marker="|", s=30,
                            color="tab:blue", lw=0.2, alpha=0.5)
    spikes_axis.set_ylim([-0.5, 5.25])
    despine(spikes_axis)


def plot_white_noise_stim(stimulus, trace_axis, start_time, extent):
    """Plot a window of the white-noise stimulus.

    Parameters
    ----------
    stimulus : sequence of two ndarrays
        ``stimulus[0]`` is the time axis, ``stimulus[1]`` the stimulus values.
        The values are plotted with inverted sign.
    trace_axis : matplotlib axes
        Axes to draw into.
    start_time, extent : float
        Start and duration of the displayed window, in the unit of the time
        axis.
    """
    time = stimulus[0]
    stim = stimulus[1]
    window = (time >= start_time) & (time < start_time + extent)
    trace_axis.plot(time[window] - start_time, -1 * stim[window], color='tab:red', lw=0.5)
    trace_axis.set_ylim([-0.5, 0.5])
    trace_axis.set_xlim([0.0, extent])
    despine(trace_axis)


def plot_spectrum(freq, spectrum, axis, ylabel, yticklabels=False, logplot=True):
    """Plot a trial-averaged, smoothed spectrum and mark its peak and cutoffs.

    Each row of ``spectrum`` is rectified and smoothed with a 15-point
    running average. Mean +- standard deviation across rows is drawn for
    frequencies below 293. The peak within that band and the nearest
    frequencies on either side where the average has dropped to
    ``peak / sqrt(2)`` or below are marked in red. A cutoff that does not
    exist in the data is simply not drawn.

    Parameters
    ----------
    freq : ndarray
        Frequency axis, same length as the second dimension of ``spectrum``.
    spectrum : 2-D ndarray
        One spectrum per row (may be complex; the magnitude is used).
    axis : matplotlib axes
        Axes to draw into.
    ylabel : str
        Label of the y-axis.
    yticklabels : bool, optional
        Keep the y tick labels if True. The default (False) removes them.
    logplot : bool, optional
        Use a logarithmic y-axis if True; otherwise a linear axis fixed to
        0 .. 1. Also selects the "H(f)" versus "coh" annotation texts.

    Raises
    ------
    ValueError
        If no frequency lies below the display limit.
    """
    kernel = np.ones(15) / 15
    smoothed_spectrum = np.zeros(spectrum.shape)
    for i, trial_spectrum in enumerate(spectrum):
        smoothed_spectrum[i, :] = np.convolve(np.abs(trial_spectrum), kernel, mode="same")

    freq_limit = 293
    in_band = freq < freq_limit
    avg_spectrum = np.mean(smoothed_spectrum, axis=0)
    std_spectrum = np.std(smoothed_spectrum, axis=0)
    avg_error_minus = avg_spectrum - std_spectrum
    avg_error_plus = avg_spectrum + std_spectrum

    # Locate the peak by index inside the displayed band so that peak_f is a
    # single scalar even if the same value occurs again elsewhere.
    peak_idx = np.flatnonzero(in_band)[np.argmax(avg_spectrum[in_band])]
    peak_f = freq[peak_idx]
    max_spectrum = avg_spectrum[peak_idx]
    cutoff_level = max_spectrum / np.sqrt(2)

    below_cutoff = avg_spectrum <= cutoff_level
    upper_candidates = np.flatnonzero(below_cutoff & (freq > peak_f))
    lower_candidates = np.flatnonzero(below_cutoff & (freq < peak_f))
    upper_idx = upper_candidates[0] if upper_candidates.size else None
    lower_idx = lower_candidates[-1] if lower_candidates.size else None

    if logplot:
        axis.semilogy(freq[in_band], avg_spectrum[in_band], lw=0.5, color="tab:blue")
    else:
        axis.plot(freq[in_band], avg_spectrum[in_band], lw=0.5, color="tab:blue")
        axis.set_ylim([0, 1])
        axis.set_yticks(np.arange(0, 1.01, 0.25))
    axis.fill_between(freq[in_band], avg_error_minus[in_band], avg_error_plus[in_band],
                      lw=0.0, color="tab:blue", alpha=0.5)
    despine(axis, ["right", "top"], False)
    if not yticklabels:
        axis.set_yticklabels([])
    axis.set_xticks(np.arange(0, 301, 150))
    axis.set_xticks(np.arange(0., 301, 50), minor=True)
    axis.set_xticklabels(np.arange(0, 301, 150))
    axis.set_xlabel("frequency [Hz]", fontsize=label_size)
    axis.set_xlim([0, 300])
    # Freeze the current y-limits so the marker lines below cannot rescale the axis.
    axis.set_ylim(axis.get_ylim())
    axis.set_ylabel(ylabel)
    axis.yaxis.set_label_coords(-0.075, 0.5)

    marker_style = dict(color="r", ls="--", lw=0.75)
    y_bottom = axis.get_ylim()[0]
    axis.plot([0, peak_f], [max_spectrum, max_spectrum], **marker_style)
    if upper_idx is not None:
        upper_cutoff = freq[upper_idx]
        upper_value = avg_spectrum[upper_idx]
        axis.plot([0, upper_cutoff], [upper_value, upper_value], **marker_style)
        axis.plot([upper_cutoff, upper_cutoff], [y_bottom, upper_value], **marker_style)
    if lower_idx is not None:
        lower_cutoff = freq[lower_idx]
        axis.plot([lower_cutoff, lower_cutoff], [y_bottom, avg_spectrum[lower_idx]],
                  **marker_style)

    connector = {"arrowstyle": "-", "color": "r", "linewidth": 0.75}
    max_text = "max(H(f))" if logplot else "max(coh)"
    axis.annotate(max_text, (peak_f / 2, max_spectrum), (75, max_spectrum * 1.25), fontsize=6,
                  color="r", arrowprops=connector, annotation_clip=False)
    cutoff_text = r"$\frac{max(H(f))}{\sqrt{2}}$" if logplot else r"$\frac{max(coh)}{\sqrt{2}}$"
    axis.annotate(cutoff_text, (peak_f, cutoff_level), (150, max_spectrum * 0.75), fontsize=8,
                  color="r", annotation_clip=False, arrowprops=connector)

    text_ypos = max_spectrum / 2 if logplot else 0.5
    if lower_idx is not None:
        axis.text(lower_cutoff + 5, text_ypos, "lower cutoff", rotation=-90, fontsize=6,
                  color='r', ha="left", va="top")
    if upper_idx is not None:
        axis.text(upper_cutoff + 5, text_ypos, "upper cutoff", rotation=-90, fontsize=6,
                  color='r', ha="left", va="top")


def _add_panel_label(axis, label, x_offset):
    """Write the bold panel letter above the top-left corner of ``axis``."""
    axis.text(x_offset, 1.05, label, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
              transform=axis.transAxes)


def layout_figure():
    """Create the figure and its eight panels.

    The upper part is a 2x2 grid (EOD and baseline response on the left,
    white-noise stimulus and response on the right); the lower part is a row
    of four panels (vector strength, ISI, gain, coherence). Panel letters A-F
    are added here.

    Returns
    -------
    fig : matplotlib figure
    axes : dict
        Maps the keys "eod", "baseline_response", "vector_strength", "isi",
        "whitenoise_stimulus", "whitenoise_response", "gain" and "coherence"
        to the corresponding axes.
    """
    fig = plt.figure(figsize=(5.1, 3.75))
    gs = gridspec.GridSpec(2, 1, left=0.05, right=0.975, top=0.95, bottom=0.1, wspace=0.15,
                           height_ratios=[3, 2], hspace=0.4, figure=fig)
    sgs_top = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0], height_ratios=[1, 1],
                                               hspace=0.1)
    sgs_bottom = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=gs[1], wspace=0.4)

    eod_axis = fig.add_subplot(sgs_top[0, 0])
    _add_panel_label(eod_axis, "A", -0.1)
    baseline_resp_axis = fig.add_subplot(sgs_top[1, 0])

    white_noise_stimulus_axis = fig.add_subplot(sgs_top[0, 1])
    _add_panel_label(white_noise_stimulus_axis, "D", -0.1)
    white_noise_response_axis = fig.add_subplot(sgs_top[1, 1])

    vector_strength_axis = fig.add_subplot(sgs_bottom[0])
    _add_panel_label(vector_strength_axis, "B", -0.25)
    isi_axis = fig.add_subplot(sgs_bottom[1])
    _add_panel_label(isi_axis, "C", -0.25)
    gain_axis = fig.add_subplot(sgs_bottom[2])
    _add_panel_label(gain_axis, "E", -0.3)
    coherence_axis = fig.add_subplot(sgs_bottom[3])
    _add_panel_label(coherence_axis, "F", -0.3)

    axes = {"eod": eod_axis,
            "baseline_response": baseline_resp_axis,
            "vector_strength": vector_strength_axis,
            "isi": isi_axis,
            "whitenoise_stimulus": white_noise_stimulus_axis,
            "whitenoise_response": white_noise_response_axis,
            "gain": gain_axis,
            "coherence": coherence_axis}
    return fig, axes


def methods_plot(args):
    """Create the methods figure and either show or save it.

    Parameters
    ----------
    args : argparse.Namespace
        Needs the attributes ``redo``, ``nosave`` and ``outfile``. The whole
        namespace is handed to the analysis functions whenever the derived
        data files are missing or ``redo`` is set.
    """
    if args.redo or not os.path.exists(_BASELINE_FILE):
        analyze_baseline_activity(args)
    if args.redo or not os.path.exists(_WHITENOISE_FILE):
        analyze_driven_activity(args)

    fig, axes = layout_figure()
    plot_baseline_activity(axes["eod"], axes["baseline_response"], axes["vector_strength"],
                           axes["isi"])
    plot_whitenoise_activity(axes["whitenoise_stimulus"], axes["whitenoise_response"],
                             axes["gain"], axes["coherence"])
    fig.subplots_adjust(bottom=0.1125, left=0.05, top=0.95, right=0.97)

    if args.nosave:
        plt.show()
    else:
        # Saving fails if the target folder (default "figures") does not exist.
        out_dir = os.path.dirname(args.outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(args.outfile)
    plt.close(fig)


def _add_arguments(parser):
    """Register the options of the methods figure on ``parser``."""
    parser.add_argument("-d", "--dataset",
                        default=os.path.join("raw_data", "2018-09-06-ai-invivo-1.nix"),
                        help="the dataset that contains baseline and driven data, only needed when redoing stuff from scratch")
    parser.add_argument("-o", "--outfile", help="name of the output figure",
                        default="figures/methods.pdf")
    parser.add_argument("-n", "--nosave", action="store_true",
                        help="just plot and show, no saving of the figure")
    parser.add_argument("-r", "--redo", action="store_true", help="redo analysis of cell data")
    parser.set_defaults(func=methods_plot)


def command_line_parser(subparsers):
    """Register the ``response_features`` sub-command on ``subparsers``.

    Parameters
    ----------
    subparsers
        The object returned by ``ArgumentParser.add_subparsers()``. The new
        sub-parser stores ``methods_plot`` as its ``func`` default.
    """
    methods_parser = subparsers.add_parser("response_features", help=fig1_help)
    _add_arguments(methods_parser)


if __name__ == "__main__":
    # methods_plot needs the parsed options; calling it without arguments
    # raised a TypeError.
    cli_parser = argparse.ArgumentParser(description=fig1_help)
    _add_arguments(cli_parser)
    methods_plot(cli_parser.parse_args())