import argparse
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from types import SimpleNamespace
import sys
import copy
from pathlib import Path
from cobrawap.pipeline.utils.neo_utils import analogsignal_to_imagesequence
from cobrawap.pipeline.utils.io_utils import load_neo

# NOTE(review): this path is resolved relative to the current working
# directory, not to this file, so the script must be started from a directory
# whose parent contains plotting_utils.py.
sys.path.append(str(Path.cwd().parent))
from plotting_utils import filter_velocity_local, load_df, colormap


def plot_signal_trace(asig, event, label, ax, pixel=(40,30)):
    """Plot the signal of a single pixel together with its trigger times.

    Parameters
    ----------
    asig : AnalogSignal
        Signal whose channels carry ``x_coords``/``y_coords`` array
        annotations.
    event : Event
        Trigger event whose entries carry ``x_coords``/``y_coords`` array
        annotations and one label per trigger.
    label : str
        Currently unused; kept for interface compatibility.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    pixel : tuple
        ``(x, y)`` coordinates of the pixel to show.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn into.
    """
    x_coord, y_coord = pixel[0], pixel[1]

    channel = np.where((asig.array_annotations['x_coords'] == x_coord)
                       & (asig.array_annotations['y_coords'] == y_coord))[0]
    ax.plot(asig.times, asig.as_array()[:, channel],
            color='0.7', linewidth=3, label='signal')

    event_channels = np.where((event.array_annotations['x_coords'] == x_coord)
                              & (event.array_annotations['y_coords'] == y_coord))
    pixel_events = event[event_channels]
    for event_time, event_label in zip(pixel_events.times, pixel_events.labels):
        # triggers labelled '-1' are drawn dotted, all others solid
        ls = ':' if event_label == '-1' else '-'
        ax.axvline(event_time, color='k', linestyle=ls, label='trigger')
    return ax


def plot_planarity(waves_event, vector_field, times, wave_label,
                   t_limits=None, ax=None, t_step=0.04):
    """Draw the normalised flow directions along the wavefront of one wave.

    For every trigger of wave ``wave_label`` the flow vector at the trigger's
    pixel and frame is normalised to unit length and drawn as an arrow whose
    colour encodes the frame time. The length of the mean unit vector
    ("planarity") is written as the x-label.

    Parameters
    ----------
    waves_event : Event
        Wavefront triggers; ``labels`` holds the wave label of each trigger,
        ``array_annotations`` hold ``x_coords``/``y_coords``.
    vector_field : ImageSequence-like, indexed as ``[t, y, x]``
        Complex-valued flow field; arrows use the real part as x and the
        imaginary part as y component.
    times : Quantity array
        Frame times; a trigger is mapped to the first frame whose time is
        >= the trigger time (frame 0 if there is none).
    wave_label : str
        Label of the wave to draw.
    t_limits : tuple of float, optional
        ``(start, stop)`` in seconds for the colour legend. If None, one
        colour per frame that contains triggers of this wave is used.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into; a new figure is created if None.
    t_step : float, optional
        Spacing in seconds between colour legend entries when ``t_limits``
        is given (default 0.04).

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn into.

    Raises
    ------
    ValueError
        If no trigger carries ``wave_label``.
    """
    idx = np.where(wave_label == waves_event.labels)[0]
    if idx.size == 0:
        raise ValueError(f"no triggers with wave label {wave_label!r}")

    t_idx = np.array([np.argmax(times >= t) for t in waves_event.times[idx]])
    x = waves_event.array_annotations['x_coords'][idx]
    y = waves_event.array_annotations['y_coords'][idx]

    # planarity = length of the mean of all unit direction vectors of the wave
    wave_directions = vector_field[t_idx, y.astype(int), x.astype(int)]
    wave_directions /= np.array([np.linalg.norm(w) for w in wave_directions])
    planarity = np.linalg.norm(np.mean(wave_directions))

    frames = np.unique(t_idx)
    if t_limits is None:
        # one legend entry / colour per frame that contains triggers
        t_steps = times[frames].rescale('s').magnitude
    else:
        t_steps = np.arange(t_limits[0], t_limits[1] + t_step, t_step)
    palette = sns.husl_palette(len(t_steps) + 1, h=0.3, l=0.4)[:-1]

    if ax is None:
        _, ax = plt.subplots()

    # background: finite entries set to 0; non-finite entries are left as they
    # are, which imshow masks as invalid
    area = copy.copy(np.real(vector_field[0]))
    area[np.where(np.isfinite(area))] = 0
    ax.imshow(area, interpolation='nearest', origin='lower',
              vmin=-1, vmax=1, cmap='RdBu')

    for frame_t in frames:
        frame_i = np.where(t_idx == frame_t)[0]
        xi = x[frame_i]
        yi = y[frame_i]
        frame = vector_field[t_idx[frame_i],
                             yi.astype(int), xi.astype(int)].magnitude
        frame /= np.array([np.linalg.norm(w) for w in frame])
        frame_time = times[frame_t].rescale('s').magnitude
        # colour of the legend entry closest in time to this frame
        color_idx = np.argmin(np.abs(t_steps - frame_time))
        ax.quiver(xi, yi, np.real(frame), np.imag(frame),
                  scale=45, width=0.003,
                  color=palette[color_idx], alpha=0.8,
                  label=f'{frame_time:.2f} s')

    ax.axis('image')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(f'planarity {planarity:.2f}')

    patches = [mpl.patches.Patch(color=c) for c in palette]
    ax.legend(patches, [f'{t:.2f} s' for t in t_steps], frameon=False,
              bbox_to_anchor=(1, .5), loc='center left', fontsize=12)
    return ax


def plot_figure(local_data, global_data, signal_path, alt_signal_path):
    """Compose the comparison figure of the 'hilbert phase' and 'minima' methods.

    Layout (4 x 4 grid, the third row is a thin spacer)::

        A1 A1 B1 B2
        A2 A2 B1 B2
        -- -- -- --
        C1 C2 C3 C4

    A: signal trace and triggers of one pixel per method.
    B: flow directions along one example wave per method.
    C: violin plots of local velocity, local inter-wave interval and
       planarity (split by anesthetic), plus a bar chart of the number of
       rows of ``global_data`` per method and anesthetic ("number of waves").

    Parameters
    ----------
    local_data : pandas.DataFrame
        Needs the columns ``method``, ``anesthetic``, ``velocity_local`` and
        ``inter_wave_interval_local``.
    global_data : pandas.DataFrame
        Needs the columns ``method``, ``anesthetic`` and ``planarity``.
    signal_path, alt_signal_path : path-like
        Neo files for the 'hilbert phase' and the 'minima' method.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : types.SimpleNamespace
        The panels, as attributes ``A1`` ... ``C4``.
    """
    sns.set(style='ticks', palette='deep', context='talk')
    fig = plt.figure(figsize=(25, 14), constrained_layout=True)
    # create the grid through the figure so it is attached to it and handled
    # by constrained_layout
    gs = fig.add_gridspec(nrows=4, ncols=4, wspace=0.2, hspace=.17,
                          height_ratios=(.5, .5, .1, 1))
    ax = SimpleNamespace(
        A1=fig.add_subplot(gs[0, 0:2]),
        A2=fig.add_subplot(gs[1, 0:2]),
        B1=fig.add_subplot(gs[:2, 2]),
        B2=fig.add_subplot(gs[:2, 3]),
        C1=fig.add_subplot(gs[3, 0]),
        C2=fig.add_subplot(gs[3, 1]),
        C3=fig.add_subplot(gs[3, 2]),
        C4=fig.add_subplot(gs[3, 3]),
    )
    axes = vars(ax)

    # SIGNAL TRACES + WAVE FIELDS
    # example wave per method and the time window of its colour legend
    # (an alternative example: wave labels 22 / 38 with t_limits (19.12, 19.56))
    wave_labels = {'hilbert phase': '30', 'minima': '51'}
    t_limits = (29.00, 29.48)
    pixel = (30, 50)
    for trace_ax, field_ax, path, method in zip(
            [ax.A1, ax.A2], [ax.B1, ax.B2],
            [signal_path, alt_signal_path], ['hilbert phase', 'minima']):
        block = load_neo(str(path))
        asig = block.segments[0].analogsignals[0]
        flow_asig = block.filter(name='optical_flow', objects="AnalogSignal")[0]
        flow = analogsignal_to_imagesequence(flow_asig)
        event = block.filter(name='wavefronts', objects="Event")[0]

        plot_signal_trace(asig=asig, event=event, label=method, pixel=pixel,
                          ax=trace_ax)
        plot_planarity(waves_event=event, vector_field=flow, times=asig.times,
                       wave_label=wave_labels[method], t_limits=t_limits,
                       ax=field_ax)
        # mark the pixel whose trace is shown in the A panels
        field_ax.scatter([pixel[0]], [pixel[1]], c='k', marker='s', s=15)
        # release the loaded data before the next file is read
        del block, asig, flow_asig, flow, event

    # VIOLIN PLOTS
    # NOTE(review): seaborn 0.13 deprecates `scale` (-> `density_norm`) and
    # `bw` (-> `bw_method`/`bw_adjust`); update once the supported seaborn
    # version allows it.
    violin_params = dict(orient='h', palette=colormap, inner='quartile', cut=0,
                         alpha=.8, linewidth=1, scale='width', saturation=1,
                         y='method', hue='anesthetic', split=True,
                         hue_order=['isoflurane', 'ketamine'])
    bw = .2

    ## wave characteristics
    sns.violinplot(data=filter_velocity_local(local_data), x='velocity_local',
                   **violin_params, ax=ax.C1)
    sns.violinplot(data=local_data, x='inter_wave_interval_local', bw=bw,
                   **violin_params, ax=ax.C2)
    sns.violinplot(data=global_data, x='planarity',
                   **violin_params, ax=ax.C3)

    ## number of waves: one thin bar per (method, anesthetic) pair, the two
    ## anesthetics shifted around the method's y position
    bar_width = .1
    for j, method in enumerate(['hilbert phase', 'minima']):
        for i, anesthetic in enumerate(['ketamine', 'isoflurane']):
            plot_data = global_data[(global_data.anesthetic == anesthetic)
                                    & (global_data.method == method)]
            shift = bar_width / 2 - i * bar_width
            ax.C4.barh(j + shift, len(plot_data), align='center',
                       height=bar_width, color=colormap[anesthetic],
                       edgecolor='k', linewidth=.5)

    ## legends
    for panel in ['B1', 'C2', 'C3', 'C4']:
        panel_legend = axes[panel].get_legend()
        if panel_legend is not None:
            panel_legend.remove()
    # only the first two handles: the 'signal' line and (if any) the first
    # 'trigger' line
    handles, labels = ax.A2.get_legend_handles_labels()
    ax.A2.legend(handles[:2], labels[:2], frameon=True, loc='upper center',
                 bbox_to_anchor=(.1, 1, .2, .35), framealpha=1)
    handles, labels = ax.C1.get_legend_handles_labels()
    ax.C1.legend(handles, labels, frameon=True, loc='center right', title='')

    ## ticks & labels
    ax.A1.set_xticks([])
    ax.A1.set_xlabel('')
    for panel in ['C2', 'C3', 'C4']:
        axes[panel].set_yticks([])
        axes[panel].set_ylabel('')
    for panel, label in zip(['A1', 'A2'], ['hilbert phase', 'minima']):
        axes[panel].set_yticks([0])
        axes[panel].set_yticklabels([label], rotation=90, va='center')
    ax.A2.set_xlabel('time [s]')
    ax.C1.set_ylabel('')
    ax.C1.set_yticklabels(ax.C1.get_yticklabels(), rotation=90, va='center')
    ax.C1.set_xlabel('velocity [mm/s]')
    ax.C2.set_xlabel('inter-wave interval [s]')
    ax.C3.set_xlabel('planarity')
    ax.C4.set_xlabel('number of waves')
    ax.B1.set_title('hilbert phase')
    ax.B2.set_title('minima')

    ## limits
    ax.A1.set_xlim((10, 30))
    ax.A2.set_xlim((10, 30))
    ax.C2.set_xlim((0, 2))
    # align the bars in C4 with the violin rows of C1
    ax.C4.set_ylim(ax.C1.get_ylim())

    ## spines
    for panel in ['A1', 'A2', 'C1', 'C2', 'C3', 'C4']:
        sns.despine(left=True, ax=axes[panel])
    sns.despine(bottom=True, left=True, ax=ax.A1)

    ## panel letters
    dx, dy = .05, .05
    text_params = dict(ha='right', va='center', fontsize=20, fontweight='bold')
    for panel in ['B', 'C']:
        axi = axes[f'{panel}1']
        axi.text(s=panel, x=-dx, y=1 + dy, transform=axi.transAxes,
                 **text_params)
    ax.A1.text(s='A', x=-dx / 2, y=1 + 2 * dy, transform=ax.A1.transAxes,
               **text_params)

    return fig, ax


if __name__ == '__main__':
    CLI = argparse.ArgumentParser()
    CLI.add_argument("--local_data", nargs='?', type=Path, required=True,
                     help="local measures for the 'hilbert phase' method")
    CLI.add_argument("--global_data", nargs='?', type=Path, required=True,
                     help="global measures for the 'hilbert phase' method")
    CLI.add_argument("--alt_local_data", nargs='?', type=Path, required=True,
                     help="local measures for the 'minima' method")
    CLI.add_argument("--alt_global_data", nargs='?', type=Path, required=True,
                     help="global measures for the 'minima' method")
    CLI.add_argument("--signal", nargs='?', type=Path, required=True,
                     help="neo file for the 'hilbert phase' method")
    CLI.add_argument("--alt_signal", nargs='?', type=Path, required=True,
                     help="neo file for the 'minima' method")
    CLI.add_argument("--output", nargs='?', type=Path, required=True,
                     help="path of the figure to write")
    args, _ = CLI.parse_known_args()
    arg_values = vars(args)

    data = {}
    for measure_type in ['local', 'global']:
        dfs = []
        for path, method in zip([arg_values[f'{measure_type}_data'],
                                 arg_values[f'alt_{measure_type}_data']],
                                ['hilbert phase', 'minima']):
            df = load_df(path)
            # keep calcium imaging rows only; assign() returns a new frame, so
            # no column is written into a filtered view of `df`
            df = df[df['technique'] == 'calcium imaging'].assign(method=method)
            dfs.append(df)
        data[f'{measure_type}_data'] = pd.concat(dfs, ignore_index=True)

    # hotfix: cast to float.
    # NOTE(review): plot_figure does not use this column — confirm whether the
    # cast is still needed.
    data['global_data']['number_of_triggers'] = \
        data['global_data']['number_of_triggers'].astype(float)

    fig, _ = plot_figure(**data, signal_path=args.signal,
                         alt_signal_path=args.alt_signal)
    fig.savefig(args.output, bbox_inches='tight')