123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- #####################################################################################################
- ## methods plot with baseline and driven activity of an example cell ##
- import os
- import numpy as np
- import matplotlib.pyplot as plt
- from matplotlib import gridspec
- from matplotlib.patches import ArrowStyle
- from .figure_style import subfig_labelsize, subfig_labelweight, legend_fontsize, despine, label_size
- from .figure1_analysis import analyze_baseline_activity, analyze_driven_activity, get_firing_rate
# Help text registered for the "response_features" sub-command.
fig1_help = ("Plots a methods figure with baseline activity and stimulus driven activity. "
             "Depends on figure1_baseline_data.npz and figure1_whitenoise_data.npz files. "
             "These are expected in the ./derived_data folder")
def plot_isi(spikes, cv, burstiness, axis):
    """Draw the inter-spike-interval histogram annotated with CV and burstiness.

    Parameters
    ----------
    spikes : array
        Spike times in seconds (the x-axis is labelled in ms with ticks placed at
        multiples of 0.005); the ISIs are ``np.diff(spikes)``.
    cv : float
        Coefficient of variation of the ISIs, printed on the panel.
    burstiness : float
        Burstiness value, printed above the CV text.
    axis : matplotlib Axes
        Axes to draw into.
    """
    intervals = np.diff(spikes)
    # 0.2 ms bins from 0 up to (just below) 15 ms.
    isi_bins = np.arange(0.0, 0.015, 0.0002)
    axis.hist(intervals, bins=isi_bins, density=True, alpha=0.75)

    # Position the annotations relative to the limits the histogram produced.
    x_anchor = 0.95 * axis.get_xlim()[1]
    y_top = axis.get_ylim()[1]
    note_style = dict(fontsize=legend_fontsize, color="tab:blue", ha="right", va="bottom")
    axis.text(x_anchor, 0.85 * y_top, r"$CV_{ISI}$: %.2f" % (cv), **note_style)
    axis.text(x_anchor, 1.0 * y_top, "Burstiness: %.2f" % (burstiness), **note_style)

    despine(axis)
    axis.spines["bottom"].set_visible(True)
    # Tick positions are in seconds, labels in ms: minor every 1 ms, major every 5 ms.
    x_axis = axis.xaxis
    x_axis.set_ticks(np.arange(0.0, 0.016, 0.001), minor=True)
    x_axis.set_ticks(np.arange(0.0, 0.016, 0.005))
    x_axis.set_ticklabels(np.arange(0, 16, 5))
    axis.set_xlabel("ISI [ms]", fontsize=label_size)
    x_axis.set_label_coords(0.5, -0.225)
def plot_phase_locking(eod_waveform, eod_error, spike_phases, waveform_axis, vs, preferred_phase, eodf, ylim):
    """Overlay one EOD cycle with the spike-phase histogram and the firing delay.

    The EOD waveform (with a +/- ``eod_error`` band) is drawn over one cycle mapped
    to 0..2*pi. A density histogram of ``spike_phases`` goes on a twin y-axis. The
    vector strength is printed, and a bar-ended arrow from phase 0 to
    ``preferred_phase`` is labelled with the equivalent delay in ms.

    Parameters
    ----------
    eod_waveform, eod_error : 1-D arrays
        Mean EOD waveform over one cycle and its error, same length.
    spike_phases : array
        Spike phases in radians, histogrammed in pi/12 bins over 0..2*pi.
    waveform_axis : matplotlib Axes
        Axes for the waveform; a twin axis is created for the histogram.
    vs : number
        Vector strength; its absolute value is printed (may be complex — assumption).
    preferred_phase : float
        Preferred firing phase in radians.
    eodf : float
        EOD frequency in Hz.
    ylim : sequence of two floats
        Base y-limits for the waveform axis; the upper limit is scaled by 1.25.
    """
    # Phase -> time: a full cycle (2*pi) lasts one EOD period, 1 / eodf seconds.
    # NOTE(review): a negative preferred_phase (e.g. if it comes from np.angle)
    # yields a negative delay — confirm the phase range upstream.
    delay = preferred_phase / (2 * np.pi * eodf)

    phase = np.linspace(0.0, 2 * np.pi, len(eod_waveform))
    waveform_axis.plot(phase, eod_waveform, color="tab:orange", lw=0.75, zorder=2)
    waveform_axis.fill_between(phase, eod_waveform + eod_error, eod_waveform - eod_error,
                               color="tab:orange", alpha=0.25, zorder=1, lw=0.0)
    spike_phase_ax = waveform_axis.twinx()
    spike_phase_ax.hist(spike_phases, bins=np.arange(0.0, 2 * np.pi + 0.01, np.pi / 12),
                        density=True, alpha=0.5, color="tab:blue", edgecolor="steelblue")

    waveform_axis.text(0.95 * waveform_axis.get_xlim()[1], 1.25 * waveform_axis.get_ylim()[1],
                       "VS: %.2f" % np.abs(vs),
                       va="bottom", ha="right", color="tab:blue", fontsize=legend_fontsize)
    # Stretch the upper limit to leave room for the annotations above the waveform.
    waveform_axis.set_ylim([ylim[0], 1.25 * ylim[-1]])
    despine(waveform_axis, ["left", "top", "right", "bottom"], True)

    # Arrow from phase 0 to the preferred phase, just above the waveform maximum.
    waveform_peak = np.max(eod_waveform)
    arrow_y = 1.05 * waveform_peak
    arrow = ArrowStyle.BarAB(widthA=0.15, widthB=.15)
    waveform_axis.annotate(text='', xy=(0, arrow_y), xytext=(preferred_phase, arrow_y),
                           arrowprops=dict(arrowstyle=arrow, color="steelblue"))
    waveform_axis.text(preferred_phase/2, 1.1 * waveform_peak, f"Delay: {delay*1000:.2f}ms",
                       va="bottom", ha="center", color="tab:blue", fontsize=legend_fontsize)

    despine(spike_phase_ax, ["left", "top", "right"], True)
    spike_phase_ax.spines["bottom"].set_visible(True)
    spike_phase_ax.set_xticks(np.arange(0., 2*np.pi+0.1, np.pi/2))
    spike_phase_ax.set_xticks(np.arange(0., 2*np.pi+0.1, np.pi/4), minor=True)
    spike_phase_ax.set_xticklabels([r"0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$",
                                    r"$2\pi$"])
    waveform_axis.set_xlabel("phase [rad.]", fontsize=label_size)
    waveform_axis.xaxis.set_label_coords(0.5, -0.225)
    hist_ylim = spike_phase_ax.get_ylim()
    spike_phase_ax.set_ylim([hist_ylim[0], 1.25 * hist_ylim[-1]])
def plot_eod(time, eod, eodf, axis, start_idx, end_idx):
    """Plot an EOD snippet with a vertical scale bar and the EOD frequency.

    The samples ``start_idx:end_idx`` are drawn with time shifted to start at 0.
    The scale bar is 2 data units high and labelled "2 mV", so ``eod`` is
    presumably in mV (confirm).

    Parameters
    ----------
    time, eod : 1-D arrays
        Sample times (s) and EOD values.
    eodf : number
        EOD frequency in Hz, printed with "%i".
    axis : matplotlib Axes
        Axes to draw into.
    start_idx, end_idx : int
        Sample window; ``end_idx`` also sets the right x-limit.
    """
    t0 = time[start_idx]
    window = slice(start_idx, end_idx)
    axis.plot(time[window] - t0, eod[window], color="tab:orange", lw=0.5, label="EOD")

    # Scale bar just left of the trace, anchored at the autoscaled lower limit.
    ylim = axis.get_ylim()
    y_base = ylim[0]
    axis.plot([-0.0005, -0.0005], [y_base, y_base + 2.], color="k", lw=0.5)
    axis.text(-.0025, y_base + 1, r"2 mV", rotation=90, va="center", fontsize=legend_fontsize)

    axis.text(0.95 * axis.get_xlim()[1], 1.25 * ylim[1], "EOD frequency: %i Hz" % (eodf),
              color="tab:orange", fontsize=legend_fontsize, ha="right", va="top")
    axis.set_xlim([-0.001, time[end_idx] - t0])
    despine(axis)
def plot_voltage_spikes(time, voltage, baseline_rate, v_axis,
                        start_idx, end_idx):
    """Plot a baseline voltage snippet with "10 ms" / "5 mV" scale bars and the firing rate.

    Parameters
    ----------
    time, voltage : 1-D arrays
        Sample times (s) and recorded voltage (presumably mV, per the scale-bar label).
    baseline_rate : number
        Baseline firing rate in Hz, printed with "%i".
    v_axis : matplotlib Axes
        Axes to draw into.
    start_idx, end_idx : int
        Sample window; time is shifted so the window starts at 0.
    """
    t0 = time[start_idx]
    v_axis.plot(time[start_idx:end_idx] - t0, voltage[start_idx:end_idx],
                color="tab:blue", lw=0.9, label="voltage")

    ylim = v_axis.get_ylim()
    y_base = ylim[0]
    # Horizontal 0.01 s bar along the bottom and vertical 5-unit bar at the left edge.
    v_axis.plot([0.0, 0.01], [y_base, y_base], color="k")
    v_axis.plot([-0.0005, -0.0005], [y_base, y_base + 5], color="k", lw=0.5)
    v_axis.text(-.0025, y_base + 2.5, r"5 mV", rotation=90, va="center", fontsize=legend_fontsize)
    v_axis.text(0.005, y_base - 1.25, "10 ms", ha="center", va="top", fontsize=legend_fontsize)

    v_axis.text(0.95 * v_axis.get_xlim()[1], 1. * ylim[1],
                "Baseline firing rate: %i Hz" % (baseline_rate),
                color="tab:blue", fontsize=legend_fontsize, ha="right", va="bottom")
    v_axis.set_xlim([-0.001, time[end_idx] - t0])
    # Pin the y-range to the one recorded before the scale bars were added.
    v_axis.set_ylim(ylim)
    despine(v_axis)
def plot_baseline_activity(eod_axis, response_axis, vs_axis, isi_axis):
    """Fill the baseline panels: EOD, voltage trace, phase locking and ISI histogram.

    Reads ``derived_data/figure1_baseline_data.npz`` relative to the working directory.

    Parameters
    ----------
    eod_axis, response_axis, vs_axis, isi_axis : matplotlib Axes
        Target axes for the EOD snippet, the voltage snippet, the phase-locking
        plot and the ISI histogram.
    """
    # Sample window (indices into "time") shown in the EOD and voltage panels.
    start_idx, end_idx = 950, 1575
    path = os.path.join("derived_data", "figure1_baseline_data.npz")
    # Read every array once and close the archive: the NpzFile returned by np.load
    # keeps the file open and re-reads an array on every key access.
    with np.load(path) as npz:
        data = {key: npz[key] for key in npz.files}

    plot_eod(data["time"], data["eod"], data["eodf"], eod_axis, start_idx, end_idx)
    plot_voltage_spikes(data["time"], data["voltage"], data["baseline_rate"], response_axis,
                        start_idx, end_idx)
    plot_isi(data["spike_times"], data["cv"], data["burstiness"], isi_axis)
    # The phase-locking panel reuses the EOD panel's y-limits, so plot_eod must run first.
    plot_phase_locking(data["eod_template"], data["template_error"], data["spike_phases"],
                       vs_axis, data["vector_strength"], data["preferred_phase"], data["eodf"],
                       eod_axis.get_ylim())
def plot_whitenoise_activity(stimulus_axis, response_axis, gain_axis, coherence_axis):
    """Fill the white-noise panels: stimulus, response, gain and coherence spectra.

    Reads ``derived_data/figure1_whitenoise_data.npz`` relative to the working
    directory. Stimulus and responses are shown for the window starting at 1.0
    with an extent of 0.25 (time units of the stored data, presumably seconds).

    Parameters
    ----------
    stimulus_axis, response_axis, gain_axis, coherence_axis : matplotlib Axes
        Target axes for the four panels.
    """
    path = os.path.join("derived_data", "figure1_whitenoise_data.npz")
    # Read every array once and close the archive: the NpzFile returned by np.load
    # keeps the file open and re-reads an array on every key access (previously
    # "trial_data" was re-read on each loop iteration).
    with np.load(path) as npz:
        data = {key: npz[key] for key in npz.files}

    plot_white_noise_stim(data["stimulus"], stimulus_axis, 1.0, 0.25)

    # Split the pooled spike times into one array per trial id (np.unique order).
    trial_ids = data["trial_data"]
    all_spikes = data["spikes"]
    spikes = [all_spikes[trial_ids == trial] for trial in np.unique(trial_ids)]
    plot_white_noise_response(spikes, response_axis, 0.005, 1.0, 0.25)

    freq = data["frequency"]
    plot_spectrum(freq, data["gain_spectra"], gain_axis, r"gain [Hz mV$^{-1}$]", False, True)
    plot_spectrum(freq, data["coherence_spectra"], coherence_axis, r"coherence", False, False)
    # Only one frequency label is shown: the gain label is cleared and the coherence
    # label is moved left (axes coordinates), presumably to sit between both panels.
    gain_axis.set_xlabel("")
    coherence_axis.xaxis.set_label_coords(-0.25, -0.23)
def plot_white_noise_response(spikes, resp_axis, sigma, start_time, extent):
    """Plot the trial-averaged firing rate (mean +/- std) with a spike raster on top.

    Parameters
    ----------
    spikes : list of arrays
        Spike times of each trial.
    resp_axis : matplotlib Axes
        Axes for the rate; a twin axis is created for the raster.
    sigma : float
        Kernel width forwarded to ``get_firing_rate``.
    start_time, extent : float
        Only ``start_time <= t < start_time + extent`` is shown, shifted to start at 0.
    """
    spikes_axis = resp_axis.twinx()
    time, responses = get_firing_rate(spikes, 10., 1./20000., sigma)
    mean_rate = np.mean(responses, axis=0)
    spread = np.std(responses, axis=0)
    band_top = mean_rate + spread
    band_bottom = mean_rate - spread

    # Restrict everything to the requested time window.
    in_window = (time >= start_time) & (time < start_time + extent)
    t_shown = time[in_window] - start_time
    resp_axis.plot(t_shown, mean_rate[in_window], color='tab:blue', lw=0.75)
    resp_axis.fill_between(t_shown, band_top[in_window], band_bottom[in_window],
                           color="tab:blue", alpha=0.5, lw=0)
    # Round the upper y-limit up to the next multiple of 50.
    y_top = np.ceil(np.max(band_top[in_window]) / 50) * 50
    resp_axis.set_ylim([0, y_top])

    # Time scale bar ("100 ms") below the trace and rate scale bar ("200 Hz") on the left.
    resp_axis.plot([0.0, 0.1], [resp_axis.get_ylim()[0] - 0.5] * 2, color="k", lw=0.5)
    resp_axis.text(0.05, -125, r"100 ms", va="bottom", ha="center", fontsize=legend_fontsize)
    resp_axis.plot([-0.005, -0.005], [0, 200], color="k", lw=1, clip_on=False)
    resp_axis.text(-.02, -0.3, "200 Hz", fontsize=legend_fontsize, rotation=90)
    resp_axis.set_xlim([-0.0, extent])
    despine(resp_axis)

    # One raster row per trial on the twin axis.
    for row, trial_spikes in enumerate(spikes):
        shown = trial_spikes[(trial_spikes >= start_time) & (trial_spikes < start_time + extent)]
        shown -= start_time
        spikes_axis.scatter(shown, np.ones(len(shown)) * row, marker="|", s=30,
                            color="tab:blue", lw=0.2, alpha=0.5)
    spikes_axis.set_ylim([-0.5, 5.25])
    despine(spikes_axis)
def plot_white_noise_stim(stimulus, trace_axis, start_time, extent):
    """Plot the negated white-noise stimulus inside one time window.

    Parameters
    ----------
    stimulus : sequence of two arrays
        ``stimulus[0]`` holds the sample times, ``stimulus[1]`` the stimulus values.
        The values are plotted multiplied by -1. NOTE(review): the reason for the
        sign flip is not visible here — presumably a polarity convention; confirm.
    trace_axis : matplotlib Axes
        Axes to draw into; y is fixed to [-0.5, 0.5].
    start_time, extent : float
        Shows ``start_time <= t < start_time + extent``, shifted to start at 0.
    """
    time, stim = stimulus[0], stimulus[1]
    visible = (time >= start_time) & (time < start_time + extent)
    trace_axis.plot(time[visible] - start_time, -1 * stim[visible], color='tab:red', lw=0.5)
    trace_axis.set_ylim([-0.5, 0.5])
    trace_axis.set_xlim([0.0, extent])
    despine(trace_axis)
def plot_spectrum(freq, spectrum, axis, ylabel, yticklabels=False, logplot=True):
    """Plot the averaged spectrum with its peak and max/sqrt(2) cutoff markers.

    Each row of ``spectrum`` is rectified with ``np.abs`` and smoothed with a
    15-point moving average. The mean across rows is drawn for frequencies below
    293, with a +/- one standard deviation band. Red dashed lines and labels mark
    the maximum of the mean and the lower/upper frequencies at which it falls to
    max/sqrt(2).

    Parameters
    ----------
    freq : 1-D array
        Frequency axis (Hz, per the axis label); assumed ascending — TODO confirm.
    spectrum : 2-D array
        One spectrum per row (presumably one per trial), sampled at ``freq``.
    axis : matplotlib Axes
        Axes to draw into.
    ylabel : str
        Y-axis label.
    yticklabels : bool, optional
        Keep the y tick labels when True. When False (the default, and what the
        callers in this module pass) they are hidden.
    logplot : bool, optional
        True: logarithmic y-axis and "H(f)" annotations. False: linear y-axis
        fixed to [0, 1] and "coh" annotations.
    """
    # Rectify and smooth every row; mode="same" keeps each row's length
    # (for rows of at least 15 samples).
    kernel = np.ones(15) / 15
    smoothed_spectrum = np.zeros(spectrum.shape)
    for i, row in enumerate(spectrum):
        smoothed_spectrum[i, :] = np.convolve(np.abs(row), kernel, mode="same")

    freq_limit = 293  # only frequencies below this are drawn and searched for the peak
    in_band = freq < freq_limit
    avg_spectrum = np.mean(smoothed_spectrum, axis=0)
    std_spectrum = np.std(smoothed_spectrum, axis=0)
    avg_error_minus = avg_spectrum - std_spectrum
    avg_error_plus = avg_spectrum + std_spectrum

    band_freq = freq[in_band]
    band_avg = avg_spectrum[in_band]
    # argmax gives a single scalar peak frequency inside the band, even if the
    # maximum value repeats. An equality search could return several frequencies,
    # including ones outside the band.
    peak_idx = np.argmax(band_avg)
    max_spectrum = band_avg[peak_idx]
    peak_f = band_freq[peak_idx]

    # Cutoffs: the first frequency above and the last frequency below the peak at
    # which the mean has fallen to max/sqrt(2). Either may not exist (e.g. the mean
    # never drops that low below the peak); its markers are then skipped.
    cutoff_level = max_spectrum / np.sqrt(2)
    low_enough = avg_spectrum <= cutoff_level
    upper_idx = np.flatnonzero(low_enough & (freq > peak_f))
    lower_idx = np.flatnonzero(low_enough & (freq < peak_f))
    upper_cutoff = freq[upper_idx[0]] if upper_idx.size else None
    lower_cutoff = freq[lower_idx[-1]] if lower_idx.size else None

    if logplot:
        axis.semilogy(band_freq, band_avg, lw=0.5, color="tab:blue")
    else:
        axis.plot(band_freq, band_avg, lw=0.5, color="tab:blue")
        axis.set_ylim([0, 1])
        axis.set_yticks(np.arange(0, 1.01, 0.25))
    axis.fill_between(band_freq, avg_error_minus[in_band], avg_error_plus[in_band],
                      lw=0.0, color="tab:blue", alpha=0.5)
    despine(axis, ["right", "top"], False)
    if not yticklabels:
        axis.set_yticklabels([])
    axis.set_xticks(np.arange(0, 301, 150))
    axis.set_xticks(np.arange(0., 301, 50), minor=True)
    axis.set_xticklabels(np.arange(0, 301, 150))
    axis.set_xlabel("frequency [Hz]", fontsize=label_size)
    axis.set_xlim([0, 300])
    # Freeze the current y-limits so the marker lines below do not rescale the axis.
    axis.set_ylim(axis.get_ylim())
    axis.set_ylabel(ylabel)
    axis.yaxis.set_label_coords(-0.075, 0.5)

    marker_style = dict(color="r", ls="--", lw=0.75)
    y_floor = axis.get_ylim()[0]
    axis.plot([0, peak_f], [max_spectrum, max_spectrum], **marker_style)
    if upper_cutoff is not None:
        upper_value = avg_spectrum[upper_idx[0]]
        axis.plot([0, upper_cutoff], [upper_value, upper_value], **marker_style)
        axis.plot([upper_cutoff, upper_cutoff], [y_floor, upper_value], **marker_style)
    if lower_cutoff is not None:
        lower_value = avg_spectrum[lower_idx[-1]]
        axis.plot([lower_cutoff, lower_cutoff], [y_floor, lower_value], **marker_style)

    max_text = "max(H(f))" if logplot else "max(coh)"
    axis.annotate(max_text, (peak_f / 2, max_spectrum), (75, max_spectrum * 1.25), fontsize=6, color="r",
                  arrowprops={"arrowstyle": "-", "color": "r", "linewidth": 0.75}, annotation_clip=False)
    cutoff_text = r"$\frac{max(H(f))}{\sqrt{2}}$" if logplot else r"$\frac{max(coh)}{\sqrt{2}}$"
    axis.annotate(cutoff_text, (peak_f, cutoff_level), (150, max_spectrum * 0.75),
                  fontsize=8, color="r", annotation_clip=False,
                  arrowprops={"arrowstyle": "-", "color": "r", "linewidth": 0.75})
    text_ypos = max_spectrum / 2 if logplot else 0.5
    if lower_cutoff is not None:
        axis.text(lower_cutoff + 5, text_ypos, "lower cutoff", rotation=-90, fontsize=6, color='r',
                  ha="left", va="top")
    if upper_cutoff is not None:
        axis.text(upper_cutoff + 5, text_ypos, "upper cutoff", rotation=-90, fontsize=6, color='r',
                  ha="left", va="top")
def layout_figure():
    """Create the figure and its eight labelled panels.

    Top half (2x2): EOD (A) above the baseline response on the left, white-noise
    stimulus (D) above the white-noise response on the right. Bottom row (1x4):
    vector strength (B), ISI (C), gain (E), coherence (F).

    Returns
    -------
    fig : matplotlib Figure
    axes : dict
        Maps "eod", "baseline_response", "vector_strength", "isi",
        "whitenoise_stimulus", "whitenoise_response", "gain" and "coherence"
        to their Axes.
    """
    fig = plt.figure(figsize=(5.1, 3.75))
    outer = gridspec.GridSpec(2, 1, left=0.05, right=0.975, top=0.95, bottom=0.1,
                              wspace=0.15, height_ratios=[3, 2], hspace=0.4, figure=fig)
    top = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=outer[0], height_ratios=[1, 1], hspace=0.1)
    bottom = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=outer[1], wspace=0.4)

    def add_panel(spec, letter=None, x_offset=None):
        # Create one subplot and, if requested, put its panel letter above the top-left corner.
        panel = fig.add_subplot(spec)
        if letter is not None:
            panel.text(x_offset, 1.05, letter, fontsize=subfig_labelsize,
                       fontweight=subfig_labelweight, transform=panel.transAxes)
        return panel

    # Creation order matches the figure's axes order.
    eod_axis = add_panel(top[0, 0], "A", -0.1)
    baseline_resp_axis = add_panel(top[1, 0])
    white_noise_stimulus_axis = add_panel(top[0, 1], "D", -0.1)
    white_noise_response_axis = add_panel(top[1, 1])
    vector_strength_axis = add_panel(bottom[0], "B", -0.25)
    isi_axis = add_panel(bottom[1], "C", -0.25)
    gain_axis = add_panel(bottom[2], "E", -0.3)
    coherence_axis = add_panel(bottom[3], "F", -0.3)

    axes = dict(eod=eod_axis,
                baseline_response=baseline_resp_axis,
                vector_strength=vector_strength_axis,
                isi=isi_axis,
                whitenoise_stimulus=white_noise_stimulus_axis,
                whitenoise_response=white_noise_response_axis,
                gain=gain_axis,
                coherence=coherence_axis)
    return fig, axes
def methods_plot(args):
    """Build the methods figure, re-running analyses whose cached results are missing.

    The cached files are looked up in ``derived_data`` relative to the working
    directory. They are regenerated when ``args.redo`` is set or the file is absent.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``redo``, ``nosave`` and ``outfile`` (see command_line_parser);
        it is also passed on to the analysis functions.
    """
    derived_dir = "derived_data"
    if args.redo or not os.path.exists(os.path.join(derived_dir, "figure1_baseline_data.npz")):
        analyze_baseline_activity(args)
    if args.redo or not os.path.exists(os.path.join(derived_dir, "figure1_whitenoise_data.npz")):
        analyze_driven_activity(args)

    fig, axes = layout_figure()
    plot_baseline_activity(axes["eod"], axes["baseline_response"],
                           axes["vector_strength"], axes["isi"])
    plot_whitenoise_activity(axes["whitenoise_stimulus"], axes["whitenoise_response"],
                             axes["gain"], axes["coherence"])
    fig.subplots_adjust(bottom=0.1125, left=0.05, top=0.95, right=0.97)
    if args.nosave:
        plt.show()
    else:
        # Create the output directory (the default outfile lives in "figures/") so
        # saving does not fail after all the plotting work.
        out_dir = os.path.dirname(args.outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(args.outfile)
    plt.close(fig)
def command_line_parser(subparsers):
    """Register the "response_features" sub-command that renders this figure.

    Parameters
    ----------
    subparsers
        The action object returned by ``ArgumentParser.add_subparsers()``.
        The new sub-parser stores ``methods_plot`` as its ``func`` default.
    """
    parser = subparsers.add_parser("response_features", help=fig1_help)
    default_dataset = os.path.join("raw_data", "2018-09-06-ai-invivo-1.nix")
    # (flags, keyword arguments) for each option, in registration order.
    options = (
        (("-d", "--dataset"),
         dict(default=default_dataset,
              help="the dataset that contains baseline and driven data, only needed when redoing stuff from scratch")),
        (("-o", "--outfile"),
         dict(help="name of the output figure", default="figures/methods.pdf")),
        (("-n", "--nosave"),
         dict(action="store_true", help="just plot and show, no saving of the figure")),
        (("-r", "--redo"),
         dict(action="store_true", help="redo analysis of cell data")),
    )
    for flags, settings in options:
        parser.add_argument(*flags, **settings)
    parser.set_defaults(func=methods_plot)
if __name__ == "__main__":
    # methods_plot needs parsed arguments, so build the sub-command defined by
    # command_line_parser and dispatch to the function it registers.
    # NOTE: the relative imports at the top of this module mean it has to be run
    # as a module (python -m <package>.<module>), not as a plain script.
    import argparse

    parser = argparse.ArgumentParser(description=fig1_help)
    # dest is set because required sub-parsers without a dest produce a broken
    # error message on some Python 3 versions.
    command_line_parser(parser.add_subparsers(dest="command", required=True))
    cli_args = parser.parse_args()
    cli_args.func(cli_args)
|