data_analyses.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. #####################################################################################################
  2. ## methods plot with baseline and driven activity of an example cell ##
  3. import os
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from matplotlib import gridspec
  7. from matplotlib.patches import ArrowStyle
  8. from .figure_style import subfig_labelsize, subfig_labelweight, legend_fontsize, despine, label_size
  9. from .figure1_analysis import analyze_baseline_activity, analyze_driven_activity, get_firing_rate
  10. fig1_help="Plots a methods figure with baseline activity and stimulus driven activity. Depends on figure1_baseline_data.npz and figure1_whitenoise_data.npz files. These are expected in the ./derived_data folder"
  11. def plot_isi(spikes, cv, burstiness, axis):
  12. axis.hist(np.diff(spikes), bins=np.arange(0.0, 0.015, 0.0002), density=True,
  13. alpha=0.75)
  14. axis.text(0.95 * axis.get_xlim()[1], 0.85 * axis.get_ylim()[1], r"$CV_{ISI}$: %.2f" % (cv),
  15. fontsize=legend_fontsize, color="tab:blue", ha="right", va="bottom")
  16. axis.text(0.95 * axis.get_xlim()[1], 1.0 * axis.get_ylim()[1],
  17. "Burstiness: %.2f" % (burstiness),
  18. fontsize=legend_fontsize, color="tab:blue", ha="right", va="bottom")
  19. despine(axis)
  20. axis.spines["bottom"].set_visible(True)
  21. axis.xaxis.set_ticks(np.arange(0.0, 0.016, 0.001), minor=True)
  22. axis.xaxis.set_ticks(np.arange(0.0, 0.016, 0.005))
  23. axis.xaxis.set_ticklabels(np.arange(0, 16, 5))
  24. axis.set_xlabel("ISI [ms]", fontsize=label_size)
  25. axis.xaxis.set_label_coords(0.5, -0.225)
  26. def plot_phase_locking(eod_waveform, eod_error, spike_phases, waveform_axis, vs, preferred_phase, eodf, ylim):
  27. delay = 2* np.pi/preferred_phase / eodf
  28. x = np.linspace(0.0, 2 * np.pi, len(eod_waveform))
  29. waveform_axis.plot(x, eod_waveform, color="tab:orange", lw=0.75, zorder=2)
  30. waveform_axis.fill_between(x, eod_waveform + eod_error, eod_waveform - eod_error,
  31. color="tab:orange", alpha=0.25, zorder=1, lw=0.0)
  32. spike_phase_ax = waveform_axis.twinx()
  33. spike_phase_ax.hist(spike_phases, bins=np.arange(0.0, 2 * np.pi + 0.01, np.pi / 12),
  34. density=True, alpha=0.5, color="tab:blue", edgecolor="steelblue")
  35. waveform_axis.text(0.95 * waveform_axis.get_xlim()[1], 1.25 * waveform_axis.get_ylim()[1],
  36. "VS: %.2f" % np.abs(vs),
  37. va="bottom", ha="right", color="tab:blue", fontsize=legend_fontsize)
  38. waveform_axis.set_ylim([ylim[0], 1.25 * ylim[-1]])
  39. despine(waveform_axis, ["left", "top", "right", "bottom"], True)
  40. arrow = ArrowStyle.BarAB(widthA=0.15, widthB=.15)
  41. waveform_axis.annotate(text='', xy=(0, 1.05 * np.max(eod_waveform)), xytext=(preferred_phase, 1.05 * np.max(eod_waveform)), arrowprops=dict(arrowstyle=arrow, color="steelblue"))
  42. waveform_axis.text(preferred_phase/2, 1.1 * np.max(eod_waveform), f"Delay: {delay*1000:.2f}ms",
  43. va="bottom", ha="center", color="tab:blue", fontsize=legend_fontsize)
  44. despine(spike_phase_ax, ["left", "top", "right"], True)
  45. spike_phase_ax.spines["bottom"].set_visible(True)
  46. spike_phase_ax.set_xticks(np.arange(0., 2*np.pi+0.1, np.pi/2))
  47. spike_phase_ax.set_xticks(np.arange(0., 2*np.pi+0.1, np.pi/4), minor=True)
  48. spike_phase_ax.set_xticklabels([r"0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$",
  49. r"$2\pi$"])
  50. waveform_axis.set_xlabel("phase [rad.]", fontsize=label_size)
  51. waveform_axis.xaxis.set_label_coords(0.5, -0.225)
  52. ylim = spike_phase_ax.get_ylim()
  53. spike_phase_ax.set_ylim([ylim[0], 1.25 * ylim[-1]])
  54. def plot_eod(time, eod, eodf, axis, start_idx, end_idx):
  55. axis.plot(time[start_idx:end_idx] - time[start_idx],
  56. eod[start_idx:end_idx],
  57. color="tab:orange", lw=0.5, label="EOD")
  58. ylim = axis.get_ylim()
  59. axis.plot([-0.0005, -0.0005], [ylim[0], ylim[0] + 2.], color="k", lw=0.5)
  60. axis.text(-.0025, ylim[0] + 1, r"2 mV", rotation=90, va="center", fontsize=legend_fontsize)
  61. axis.text(0.95 * axis.get_xlim()[1], 1.25 * ylim[1],
  62. "EOD frequency: %i Hz" % (eodf),
  63. color="tab:orange", fontsize=legend_fontsize, ha="right", va="top")
  64. axis.set_xlim([-0.001, time[end_idx] - time[start_idx]])
  65. despine(axis)
  66. def plot_voltage_spikes(time, voltage, baseline_rate, v_axis,
  67. start_idx, end_idx):
  68. #s_axis = v_axis.twinx()
  69. v_axis.plot(time[start_idx:end_idx] - time[start_idx],
  70. voltage[start_idx:end_idx],
  71. color="tab:blue", lw=0.9, label="voltage")
  72. ylim = v_axis.get_ylim()
  73. v_axis.plot([0.0, 0.01], [ylim[0], ylim[0]], color="k")
  74. v_axis.plot([-0.0005, -0.0005], [ylim[0], ylim[0] + 5], color="k", lw=0.5)
  75. v_axis.text(-.0025, ylim[0] + 2.5, r"5 mV", rotation=90, va="center", fontsize=legend_fontsize)
  76. v_axis.text(0.005, ylim[0] - 1.25, "10 ms", ha="center", va="top", fontsize=legend_fontsize)
  77. v_axis.text(0.95 * v_axis.get_xlim()[1], 1. * ylim[1],
  78. "Baseline firing rate: %i Hz" % (baseline_rate),
  79. color="tab:blue", fontsize=legend_fontsize, ha="right", va="bottom")
  80. v_axis.set_xlim([-0.001, time[end_idx] - time[start_idx]])
  81. v_axis.set_ylim(ylim)
  82. despine(v_axis)
  83. def plot_baseline_activity(eod_axis, response_axis, vs_axis, isi_axis):
  84. data = np.load(os.path.join("derived_data", "figure1_baseline_data.npz"))
  85. plot_eod(data["time"], data["eod"], data["eodf"], eod_axis, 950, 1575)
  86. plot_voltage_spikes(data["time"], data["voltage"], data["baseline_rate"], response_axis, 950, 1575)
  87. plot_isi(data["spike_times"], data["cv"], data["burstiness"], isi_axis)
  88. plot_phase_locking(data["eod_template"], data["template_error"], data["spike_phases"], vs_axis, data["vector_strength"], data["preferred_phase"], data["eodf"], eod_axis.get_ylim())
  89. def plot_whitenoise_activity(stimulus_axis, response_axis, gain_axis, coherence_axis):
  90. data = np.load(os.path.join("derived_data", "figure1_whitenoise_data.npz"))
  91. plot_white_noise_stim(data["stimulus"], stimulus_axis, 1.0, 0.25)
  92. spikes = []
  93. temp = data["spikes"]
  94. for i in np.unique(data["trial_data"]):
  95. spikes.append(temp[data["trial_data"] == i])
  96. plot_white_noise_response(spikes, response_axis, 0.005, 1.0, 0.25)
  97. f = data["frequency"]
  98. gain = data["gain_spectra"]
  99. coh = data["coherence_spectra"]
  100. plot_spectrum(f, gain, gain_axis, r"gain [Hz mV$^{-1}$]", False, True)
  101. gain_axis.set_xlabel("")
  102. plot_spectrum(f, coh, coherence_axis, r"coherence", False, False)
  103. coherence_axis.xaxis.set_label_coords(-0.25, -0.23)
  104. def plot_white_noise_response(spikes, resp_axis, sigma, start_time, extent):
  105. spikes_axis = resp_axis.twinx()
  106. time, responses = get_firing_rate(spikes, 10., 1./20000., sigma)
  107. avg_response = np.mean(responses, axis=0)
  108. std_responses = np.std(responses, axis=0)
  109. error_plus = avg_response + std_responses
  110. error_minus = avg_response - std_responses
  111. resp_axis.plot(time[(time >= start_time) & (time < start_time + extent)] - start_time,
  112. avg_response[(time >= start_time) & (time < start_time + extent)],
  113. color='tab:blue', lw=0.75)
  114. resp_axis.fill_between(time[(time >= start_time) & (time < start_time + extent)] - start_time,
  115. error_plus[(time >= start_time) & (time < start_time + extent)],
  116. error_minus[(time >= start_time) & (time < start_time + extent)],
  117. color="tab:blue", alpha=0.5, lw=0)
  118. ylim = [0, np.ceil(np.max(error_plus[(time >= start_time) & (time < start_time + extent)])/50) * 50]
  119. resp_axis.set_ylim(ylim)
  120. resp_axis.plot([0.0, 0.1], [resp_axis.get_ylim()[0] -0.5]*2, color="k", lw=0.5)
  121. resp_axis.text(0.05, -125, r"100 ms", va="bottom", ha="center", fontsize=legend_fontsize)
  122. resp_axis.plot([-0.005, -0.005], [0, 200], color="k", lw=1, clip_on=False)
  123. resp_axis.text(-.02, -0.3, "200 Hz", fontsize=legend_fontsize, rotation=90)
  124. resp_axis.set_xlim([-0.0, extent])
  125. despine(resp_axis)
  126. for i in range(len(spikes)):
  127. spike_times = spikes[i]
  128. spike_times = spike_times[(spike_times >= start_time) & (spike_times < start_time + extent)]
  129. spike_times -= start_time
  130. spikes_axis.scatter(spike_times, np.ones(len(spike_times)) * i, marker="|", s=30,
  131. color="tab:blue", lw=0.2, alpha=0.5)
  132. spikes_axis.set_ylim([-0.5, 5.25])
  133. despine(spikes_axis)
  134. def plot_white_noise_stim(stimulus, trace_axis, start_time, extent):
  135. time = stimulus[0]
  136. stim = stimulus[1]
  137. trace_axis.plot(time[(time >= start_time) & (time < start_time + extent)] - start_time,
  138. -1*stim[(time >= start_time) & (time < start_time + extent)],
  139. color='tab:red', lw=0.5)
  140. trace_axis.set_ylim([-0.5, 0.5])
  141. trace_axis.set_xlim([0.0, extent])
  142. #trace_axis.plot([0.0, 0.1], [trace_axis.get_ylim()[1], trace_axis.get_ylim()[1]], color="k")
  143. #trace_axis.text(0.05, trace_axis.get_ylim()[1], r"100 ms", va="bottom", ha="center", fontsize=7)
  144. despine(trace_axis)
  145. def plot_spectrum(freq, spectrum, axis, ylabel, yticklabels=False, logplot=True):
  146. k = np.ones(15) / 15
  147. smoothed_spectrum = np.zeros(spectrum.shape)
  148. for i in range(spectrum.shape[0]):
  149. smoothed_spectrum[i, :] = np.convolve(np.abs(spectrum[i, :]), k, mode="same")
  150. # smoothed_gain[i, :] = np.abs(gain[i, :])
  151. freq_limit = 293
  152. avg_spectrum = np.mean(smoothed_spectrum, axis=0)
  153. std_spectrum = np.std(smoothed_spectrum, axis=0)
  154. avg_error_minus = avg_spectrum - std_spectrum
  155. avg_error_plus = avg_spectrum + std_spectrum
  156. max_spectum = np.max(avg_spectrum[freq < freq_limit])
  157. peak_f = freq[np.where(avg_spectrum == max_spectum)]
  158. upper_cutoff = freq[(avg_spectrum <= max_spectum/np.sqrt(2)) & (freq > peak_f)][0]
  159. lower_cutoff = freq[(avg_spectrum <= max_spectum/np.sqrt(2)) & (freq < peak_f)][-1]
  160. if logplot:
  161. axis.semilogy(freq[freq < freq_limit], avg_spectrum[freq < freq_limit], lw=0.5, color="tab:blue")
  162. else:
  163. axis.plot(freq[freq < freq_limit], avg_spectrum[freq < freq_limit], lw=0.5, color="tab:blue")
  164. axis.set_ylim([0, 1])
  165. axis.set_yticks(np.arange(0, 1.01, 0.25))
  166. axis.fill_between(freq[freq < freq_limit], avg_error_minus[freq < freq_limit], avg_error_plus[freq < freq_limit],
  167. lw=0.0, color="tab:blue", alpha=0.5)
  168. despine(axis, ["right", "top"], False)
  169. axis.set_yticklabels([])
  170. axis.set_xticks(np.arange(0, 301, 150))
  171. axis.set_xticks(np.arange(0., 301, 50), minor=True)
  172. axis.set_xticklabels(np.arange(0, 301, 150))
  173. axis.set_xlabel("frequency [Hz]", fontsize=label_size)
  174. axis.set_xlim([0, 300])
  175. axis.set_ylim(axis.get_ylim())
  176. axis.set_ylabel(ylabel)
  177. axis.yaxis.set_label_coords(-0.075, 0.5)
  178. axis.plot([0, peak_f[0]], [max_spectum, max_spectum], color="r", ls="--", lw=0.75)
  179. axis.plot([0, upper_cutoff], [avg_spectrum[freq == upper_cutoff], avg_spectrum[freq == upper_cutoff]],
  180. color='r', ls="--", lw=0.75)
  181. axis.plot([upper_cutoff, upper_cutoff], [axis.get_ylim()[0], avg_spectrum[freq == upper_cutoff][0]],
  182. color='r', ls="--", lw=0.75)
  183. axis.plot([lower_cutoff, lower_cutoff], [axis.get_ylim()[0], avg_spectrum[freq == lower_cutoff][0]],
  184. color='r', ls="--", lw=0.75)
  185. max_text = "max(H(f))" if logplot else "max(coh)"
  186. axis.annotate(max_text, (peak_f/2, max_spectum), (75, max_spectum*1.25), fontsize=6, color="r",
  187. arrowprops={"arrowstyle": "-", "color": "r", "linewidth": 0.75}, annotation_clip=False)
  188. cutoff_text = r"$\frac{max(H(f))}{\sqrt{2}}$" if logplot else r"$\frac{max(coh)}{\sqrt{2}}$"
  189. axis.annotate(cutoff_text, (peak_f, max_spectum/np.sqrt(2)), (150, max_spectum * 0.75),
  190. fontsize=8, color="r", annotation_clip=False,
  191. arrowprops={"arrowstyle": "-", "color": "r", "linewidth": 0.75})
  192. text_ypos = max_spectum/2 if logplot else 0.5
  193. axis.text(lower_cutoff + 5, text_ypos, "lower cutoff", rotation=-90, fontsize=6, color='r',
  194. ha="left", va="top")
  195. axis.text(upper_cutoff + 5, text_ypos, "upper cutoff", rotation=-90, fontsize=6, color='r',
  196. ha="left", va="top")
  197. def layout_figure():
  198. fig = plt.figure(figsize=(5.1, 3.75))
  199. gs = gridspec.GridSpec(2, 1, left=0.05, right=0.975, top=0.95, bottom=0.1,
  200. wspace=0.15, height_ratios=[3, 2],hspace=0.4, figure=fig)
  201. sgs_top = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0], height_ratios=[1, 1], hspace=0.1)
  202. sgs_bottom = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=gs[1], wspace=0.4)
  203. eod_axis = fig.add_subplot(sgs_top[0, 0])
  204. eod_axis.text(-0.1, 1.05, "A", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  205. transform=eod_axis.transAxes)
  206. baseline_resp_axis = fig.add_subplot(sgs_top[1, 0])
  207. white_noise_stimulus_axis = fig.add_subplot(sgs_top[0, 1])
  208. white_noise_stimulus_axis.text(-0.1, 1.05, "D", fontsize=subfig_labelsize, fontweight=subfig_labelweight, transform=white_noise_stimulus_axis.transAxes)
  209. white_noise_response_axis = fig.add_subplot(sgs_top[1, 1])
  210. vector_strength_axis = fig.add_subplot(sgs_bottom[0])
  211. vector_strength_axis.text(-0.25, 1.05, "B", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  212. transform=vector_strength_axis.transAxes)
  213. isi_axis = fig.add_subplot(sgs_bottom[1])
  214. isi_axis.text(-0.25, 1.05, "C", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  215. transform=isi_axis.transAxes)
  216. gain_axis = fig.add_subplot(sgs_bottom[2])
  217. gain_axis.text(-0.3, 1.05, "E", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  218. transform=gain_axis.transAxes)
  219. coherence_axis = fig.add_subplot(sgs_bottom[3])
  220. coherence_axis.text(-0.3, 1.05, "F", fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  221. transform=coherence_axis.transAxes)
  222. axes = {"eod": eod_axis, "baseline_response": baseline_resp_axis, "vector_strength": vector_strength_axis,
  223. "isi": isi_axis, "whitenoise_stimulus": white_noise_stimulus_axis,
  224. "whitenoise_response": white_noise_response_axis, "gain": gain_axis, "coherence": coherence_axis}
  225. return fig, axes
  226. def methods_plot(args):
  227. if args.redo or not os.path.exists(os.path.join("derived_data", "figure1_baseline_data.npz")):
  228. analyze_baseline_activity(args)
  229. if args.redo or not os.path.exists(os.path.join("derived_data", "figure1_whitenoise_data.npz")):
  230. analyze_driven_activity(args)
  231. fig, axes = layout_figure()
  232. plot_baseline_activity(axes["eod"], axes["baseline_response"],
  233. axes["vector_strength"], axes["isi"])
  234. plot_whitenoise_activity(axes["whitenoise_stimulus"], axes["whitenoise_response"],
  235. axes["gain"], axes["coherence"])
  236. fig.subplots_adjust(bottom=0.1125, left=0.05, top=0.95, right=0.97)
  237. if args.nosave:
  238. plt.show()
  239. else:
  240. plt.savefig(args.outfile)
  241. plt.close()
  242. def command_line_parser(subparsers):
  243. methods_parser = subparsers.add_parser("response_features", help=fig1_help)
  244. methods_parser.add_argument("-d", "--dataset", default=os.path.join("raw_data", "2018-09-06-ai-invivo-1.nix"), help="the dataset that contains baseline and driven data, only needed when redoing stuff from scratch")
  245. methods_parser.add_argument("-o", "--outfile", help="name of the output figure", default="figures/methods.pdf")
  246. methods_parser.add_argument("-n", "--nosave", action="store_true", help="just plot and show, no saving of the figure")
  247. methods_parser.add_argument("-r", "--redo", action="store_true", help="redo analysis of cell data")
  248. methods_parser.set_defaults(func=methods_plot)
  249. if __name__ == "__main__":
  250. methods_plot()