123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- import argparse
- import pandas as pd
- import numpy as np
- import seaborn as sns
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- from types import SimpleNamespace
- import sys
- import copy
- from pathlib import Path
- from cobrawap.pipeline.utils.neo_utils import analogsignal_to_imagesequence
- from cobrawap.pipeline.utils.io_utils import load_neo
- sys.path.append(str(Path.cwd().parent))
- from plotting_utils import filter_velocity_local, load_df, colormap
def plot_signal_trace(asig, event, label, ax, pixel=(40,30)):
    """Draw the signal of a single pixel and mark its trigger times.

    The pixel is located by comparing ``pixel`` with the ``x_coords`` /
    ``y_coords`` array annotations of ``asig`` (and, separately, of
    ``event``). Every trigger of that pixel becomes a black vertical line,
    dotted when the event label is ``'-1'`` and solid otherwise.

    Parameters
    ----------
    asig : neo.AnalogSignal
        Signal providing ``times``, ``as_array()`` and the coordinate
        array annotations.
    event : neo.Event
        Trigger events with ``x_coords`` / ``y_coords`` array annotations
        and ``labels``.
    label : str
        Not used by this function; accepted for interface compatibility.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    pixel : tuple
        ``(x, y)`` coordinates of the pixel to display.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax`` that was passed in.
    """
    def _at_pixel(annotations):
        # Boolean mask of the entries that sit on the requested pixel.
        return ((annotations['x_coords'] == pixel[0])
                & (annotations['y_coords'] == pixel[1]))

    signal_channel = np.where(_at_pixel(asig.array_annotations))[0]
    ax.plot(asig.times, asig.as_array()[:, signal_channel],
            color='0.7', linewidth=3, label='signal')

    # np.where(...) is kept as a tuple index, exactly as the events are sliced.
    pixel_events = event[np.where(_at_pixel(event.array_annotations))]
    line_styles = [':' if lbl == '-1' else '-' for lbl in pixel_events.labels]
    for trigger_time, style in zip(pixel_events.times, line_styles):
        ax.axvline(trigger_time, color='k', linestyle=style,
                   label='trigger')
    return ax
def plot_planarity(waves_event, vector_field, times, wave_label,
                   t_limits=None, ax=None, t_step=0.04):
    """Draw the normalized flow directions along the trigger points of one wave.

    The events of ``waves_event`` whose label equals ``wave_label`` are
    mapped to the first sample time not earlier than each event. The
    complex flow vector at each event's (frame, y, x) position is drawn as a
    unit-length arrow. Arrows are colored per frame and labelled in the
    legend by time. The x-label reports the planarity: the absolute value of
    the mean unit direction vector.

    Parameters
    ----------
    waves_event : neo.Event
        Wavefront events with ``labels`` and ``x_coords`` / ``y_coords``
        array annotations.
    vector_field : neo.ImageSequence
        Complex-valued flow field indexed as ``[frame, y, x]``.
    times : quantities array
        Sample times of the frames (rescalable to seconds).
    wave_label : str
        Label of the wave to draw.
    t_limits : tuple of float, optional
        ``(start, stop)`` in seconds defining a fixed grid of legend
        time steps (stop inclusive). When None, one legend entry is made for
        each frame that actually contains triggers of the wave.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure is created when None.
    t_step : float, optional
        Spacing in seconds of the legend time grid used with ``t_limits``
        (default 0.04, the previously hard-coded value).

    Returns
    -------
    matplotlib.axes.Axes
    """
    idx = np.where(wave_label == waves_event.labels)[0]
    # Frame of each trigger: first sample time >= the trigger time.
    t_idx = np.array([np.argmax(times >= t) for t in waves_event.times[idx]])
    x = waves_event.array_annotations['x_coords'][idx]
    y = waves_event.array_annotations['y_coords'][idx]

    # Planarity = |mean of unit direction vectors|. Zero-length flow vectors
    # would yield NaN here (division by zero).
    wave_directions = vector_field[t_idx, y.astype(int), x.astype(int)].magnitude
    wave_directions = wave_directions / np.abs(wave_directions)
    planarity = np.linalg.norm(np.mean(wave_directions))

    frame_indices = np.unique(t_idx)
    if t_limits is None:
        # One color / legend entry per frame that contains triggers.
        t_steps = times[frame_indices].rescale('s').magnitude
    else:
        # Half-step margin keeps `stop` inclusive without float rounding
        # adding a spurious extra step beyond it.
        t_steps = np.arange(t_limits[0], t_limits[1] + t_step / 2, t_step)
    nbr_colors = len(t_steps)
    # One extra hue is generated and dropped, presumably so the first and
    # last time steps do not get near-identical colors on the hue circle.
    palette = sns.husl_palette(nbr_colors + 1, h=0.3, l=0.4)[:-1]

    if ax is None:
        _, ax = plt.subplots()

    # Background: finite pixels of the first frame are set to 0 (colormap
    # midpoint); non-finite pixels stay NaN and get the colormap's 'bad' color.
    area = copy.copy(np.real(vector_field[0]))
    area[np.where(np.isfinite(area))] = 0
    ax.imshow(area, interpolation='nearest', origin='lower',
              vmin=-1, vmax=1, cmap='RdBu')

    for frame_t in frame_indices:
        in_frame = t_idx == frame_t
        xi = x[in_frame]
        yi = y[in_frame]
        ti = t_idx[in_frame]
        frame = vector_field[ti, yi.astype(int), xi.astype(int)].magnitude
        frame = frame / np.abs(frame)
        frame_time = times[frame_t].rescale('s').magnitude
        # Color of the legend time step closest to this frame's time.
        color_idx = np.argmin(np.abs(np.asarray(t_steps) - frame_time))
        ax.quiver(xi, yi, np.real(frame), np.imag(frame),
                  scale=45, width=0.003,
                  color=palette[color_idx], alpha=0.8,
                  label=f'{frame_time:.2f} s')

    ax.axis('image')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(f'planarity {planarity:.2f}')
    patches = [mpl.patches.Patch(color=c) for c in palette]
    ax.legend(patches, [f'{t:.2f} s' for t in t_steps], frameon=False,
              bbox_to_anchor=(1, .5), loc='center left', fontsize=12)
    return ax
def plot_figure(local_data, global_data, signal_path, alt_signal_path,
                wave_labels=None, t_limits=(29.00, 29.48), pixel=(30, 50)):
    """Compose the multi-panel figure comparing two trigger-detection methods.

    Layout (4 x 4 GridSpec, row 2 is a thin empty spacer)::

        A1 A1 B1 B2
        A2 A2 B1 B2
        .  .  .  .
        C1 C2 C3 C4

    - A1 / A2: signal trace of ``pixel`` with its trigger times, for the
      'hilbert phase' and 'minima' method respectively.
    - B1 / B2: flow directions along one example wave per method
      (``plot_planarity``); the pixel of panel A is marked with a square.
    - C1-C3: split violin plots (isoflurane vs ketamine, per method) of
      local velocity (after ``filter_velocity_local``), local inter-wave
      interval, and planarity.
    - C4: number of rows of ``global_data`` per method and anesthetic,
      labelled "number of waves".

    Parameters
    ----------
    local_data, global_data : pandas.DataFrame
        Tables with 'method' and 'anesthetic' columns. ``local_data``
        needs 'velocity_local' and 'inter_wave_interval_local';
        ``global_data`` needs 'planarity'.
    signal_path, alt_signal_path : path-like
        neo files for the 'hilbert phase' and 'minima' method; each must
        contain an 'optical_flow' AnalogSignal and a 'wavefronts' Event.
    wave_labels : dict, optional
        Method name -> wave label drawn in panels B. Defaults to
        ``{'hilbert phase': '30', 'minima': '51'}``.
    t_limits : tuple of float, optional
        Time window in seconds for the color legend of panels B.
    pixel : tuple of int, optional
        ``(x, y)`` pixel shown in panels A and marked in panels B.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : types.SimpleNamespace
        Attributes A1, A2, B1, B2, C1, C2, C3, C4 holding the panel axes.
    """
    if wave_labels is None:
        wave_labels = {'hilbert phase': '30', 'minima': '51'}
    methods = ['hilbert phase', 'minima']

    sns.set(style='ticks', palette='deep', context='talk')
    fig = plt.figure(figsize=(25, 14), constrained_layout=True)
    gs = mpl.gridspec.GridSpec(nrows=4, ncols=4, wspace=0.2, hspace=.17,
                               height_ratios=(.5, .5, .1, 1))
    ax = SimpleNamespace()
    ax.A1 = fig.add_subplot(gs[0, 0:2])
    ax.A2 = fig.add_subplot(gs[1, 0:2])
    ax.B1 = fig.add_subplot(gs[:2, 2])
    ax.B2 = fig.add_subplot(gs[:2, 3])
    ax.C1 = fig.add_subplot(gs[3, 0])
    ax.C2 = fig.add_subplot(gs[3, 1])
    ax.C3 = fig.add_subplot(gs[3, 2])
    ax.C4 = fig.add_subplot(gs[3, 3])
    # Panel name -> axes lookup (vars() returns the namespace's own __dict__).
    axes = vars(ax)

    # SIGNAL TRACES + WAVE FIELDS
    # Plain local loop targets: do not store temporaries on the namespace.
    for trace_ax, field_ax, path, method in zip([ax.A1, ax.A2],
                                                [ax.B1, ax.B2],
                                                [signal_path, alt_signal_path],
                                                methods):
        block = load_neo(str(path))
        asig = block.segments[0].analogsignals[0]
        flow_asig = block.filter(name='optical_flow', objects="AnalogSignal")[0]
        flow = analogsignal_to_imagesequence(flow_asig)
        event = block.filter(name='wavefronts', objects="Event")[0]
        plot_signal_trace(asig=asig, event=event, label=method,
                          pixel=pixel, ax=trace_ax)
        plot_planarity(waves_event=event, vector_field=flow, times=asig.times,
                       wave_label=wave_labels[method],
                       t_limits=t_limits, ax=field_ax)
        field_ax.scatter([pixel[0]], [pixel[1]], c='k', marker='s', s=15)
        # Drop references to the loaded data before reading the next file.
        del block, asig, flow_asig, flow, event

    # VIOLIN PLOTS
    violin_params = dict(orient='h', palette=colormap, inner='quartile', cut=0,
                         alpha=.8, linewidth=1, scale='width', saturation=1,
                         y='method', hue='anesthetic', split=True,
                         hue_order=['isoflurane', 'ketamine'])
    bw = .2
    sns.violinplot(data=filter_velocity_local(local_data), x='velocity_local',
                   **violin_params, ax=ax.C1)
    sns.violinplot(data=local_data, x='inter_wave_interval_local', bw=bw,
                   **violin_params, ax=ax.C2)
    sns.violinplot(data=global_data, x='planarity',
                   **violin_params, ax=ax.C3)

    # NUMBER OF WAVES: one thin bar per (method, anesthetic) around y = j.
    # Assumes seaborn put the method categories of C1 at y = 0, 1 in order of
    # appearance; the y-limits are copied from C1 below to align them.
    bar_width = .1
    for j, method in enumerate(methods):
        for i, anesthetic in enumerate(['ketamine', 'isoflurane']):
            in_group = ((global_data.anesthetic == anesthetic)
                        & (global_data.method == method))
            shift = bar_width/2 - i*bar_width
            ax.C4.barh(j + shift, int(in_group.sum()), align='center',
                       height=bar_width, color=colormap[anesthetic],
                       edgecolor='k', linewidth=.5)

    # LEGENDS
    for panel in ['B1', 'C2', 'C3', 'C4']:
        if axes[panel].get_legend() is not None:
            axes[panel].get_legend().remove()
    # Only the first two entries: 'signal' and the first 'trigger' line.
    handles, labels = ax.A2.get_legend_handles_labels()
    ax.A2.legend(handles[:2], labels[:2], frameon=True, loc='upper center',
                 bbox_to_anchor=(.1, 1, .2, .35), framealpha=1)
    handles, labels = ax.C1.get_legend_handles_labels()
    ax.C1.legend(handles, labels, frameon=True, loc='center right', title='')

    # TICKS & LABELS
    ax.A1.set_xticks([])
    ax.A1.set_xlabel('')
    for panel in ['C2', 'C3', 'C4']:
        axes[panel].set_yticks([])
        axes[panel].set_ylabel('')
    for panel, label in zip(['A1', 'A2'], methods):
        axes[panel].set_yticks([0])
        axes[panel].set_yticklabels([label], rotation=90, va='center')
    ax.A2.set_xlabel('time [s]')
    ax.C1.set_ylabel('')
    ax.C1.set_yticklabels(ax.C1.get_yticklabels(), rotation=90, va='center')
    ax.C1.set_xlabel('velocity [mm/s]')
    ax.C2.set_xlabel('inter-wave interval [s]')
    ax.C3.set_xlabel('planarity')
    ax.C4.set_xlabel('number of waves')
    ax.B1.set_title('hilbert phase')
    ax.B2.set_title('minima')

    # LIMITS
    ax.A1.set_xlim((10, 30))
    ax.A2.set_xlim((10, 30))
    ax.C2.set_xlim((0, 2))
    ax.C4.set_ylim((ax.C1.get_ylim()[0], ax.C1.get_ylim()[1]))

    # SPINES
    for panel in ['A1', 'A2', 'C1', 'C2', 'C3', 'C4']:
        sns.despine(left=True, ax=axes[panel])
    sns.despine(bottom=True, left=True, ax=ax.A1)

    # PANEL LETTERS
    dx, dy = .05, .05
    text_params = dict(ha='right', va='center', fontsize=20, fontweight='bold')
    for panel in ['B', 'C']:
        axi = axes[f'{panel}1']
        axi.text(s=panel, x=-dx, y=1+dy, transform=axi.transAxes, **text_params)
    ax.A1.text(s='A', x=-dx/2, y=1+2*dy, transform=ax.A1.transAxes,
               **text_params)
    return fig, ax
if __name__ == '__main__':
    CLI = argparse.ArgumentParser()
    CLI.add_argument("--local_data", nargs='?', type=Path, required=True)
    CLI.add_argument("--global_data", nargs='?', type=Path, required=True)
    CLI.add_argument("--alt_local_data", nargs='?', type=Path, required=True)
    CLI.add_argument("--alt_global_data", nargs='?', type=Path, required=True)
    CLI.add_argument("--signal", nargs='?', type=Path, required=True)
    CLI.add_argument("--alt_signal", nargs='?', type=Path, required=True)
    CLI.add_argument("--output", nargs='?', type=Path, required=True)
    # parse_known_args: extra command-line arguments are tolerated and ignored.
    args, _ = CLI.parse_known_args()

    # Build 'local_data' / 'global_data' tables that stack the 'hilbert phase'
    # (--*_data) and 'minima' (--alt_*_data) results, calcium imaging only.
    data = {}
    for measure_type in ['local', 'global']:
        dfs = []
        for path, method in zip([getattr(args, f'{measure_type}_data'),
                                 getattr(args, f'alt_{measure_type}_data')],
                                ['hilbert phase', 'minima']):
            df = load_df(path)
            # .assign returns a new frame instead of writing a column into a
            # boolean-filtered slice (SettingWithCopyWarning).
            df = df[df['technique'] == 'calcium imaging'].assign(method=method)
            dfs.append(df)
        data[f'{measure_type}_data'] = pd.concat(dfs, ignore_index=True)

    # hotfix: force a float dtype for 'number_of_triggers'
    # (NOTE(review): the reason for the cast is not visible here — confirm).
    data['global_data']['number_of_triggers'] = \
        data['global_data']['number_of_triggers'].astype(float)

    fig, _ = plot_figure(**data,
                         signal_path=args.signal,
                         alt_signal_path=args.alt_signal)
    fig.savefig(args.output, bbox_inches='tight')
|