# plot_method_comparison.py
  1. import argparse
  2. import pandas as pd
  3. import numpy as np
  4. import seaborn as sns
  5. import matplotlib.pyplot as plt
  6. import matplotlib as mpl
  7. from types import SimpleNamespace
  8. import sys
  9. import copy
  10. from pathlib import Path
  11. from cobrawap.pipeline.utils.neo_utils import analogsignal_to_imagesequence
  12. from cobrawap.pipeline.utils.io_utils import load_neo
  13. sys.path.append(str(Path.cwd().parent))
  14. from plotting_utils import filter_velocity_local, load_df, colormap
  15. def plot_signal_trace(asig, event, label, ax, pixel=(40,30)):
  16. x_coord = pixel[0]
  17. y_coord = pixel[1]
  18. channel = np.where((asig.array_annotations['x_coords']==x_coord)
  19. & (asig.array_annotations['y_coords']==y_coord))[0]
  20. ax.plot(asig.times, asig.as_array()[:,channel], color='0.7', linewidth=3,
  21. label='signal')
  22. event_channels = np.where((event.array_annotations['x_coords']==x_coord)
  23. & (event.array_annotations['y_coords']==y_coord))
  24. for event_time, event_label in zip(event[event_channels].times,
  25. event[event_channels].labels):
  26. ls = ':' if event_label=='-1' else '-'
  27. ax.axvline(event_time, color='k', linestyle=ls,
  28. label='trigger')
  29. return ax
  30. def plot_planarity(waves_event, vector_field, times, wave_label,
  31. t_limits=None, ax=None):
  32. dim_t, dim_y, dim_x = vector_field.shape
  33. skip_step = int(min([dim_x, dim_y]) / 50) + 1
  34. idx = np.where(wave_label == waves_event.labels)[0]
  35. t_idx = np.array([np.argmax(times >= t) for t
  36. in waves_event.times[idx]])
  37. x = waves_event.array_annotations['x_coords'][idx]
  38. y = waves_event.array_annotations['y_coords'][idx]
  39. wave_directions = vector_field[t_idx, y.astype(int), x.astype(int)]
  40. norm = np.array([np.linalg.norm(w) for w in wave_directions])
  41. wave_directions /= norm
  42. planarity = np.linalg.norm(np.mean(wave_directions))
  43. if t_limits is None:
  44. nbr_colors = len(np.unique(t_idx))
  45. else:
  46. t_steps = np.arange(t_limits[0], t_limits[1]+0.04, 0.04)
  47. nbr_colors = len(t_steps)
  48. palette = sns.husl_palette(nbr_colors+1, h=0.3, l=0.4)[:-1]
  49. if ax is None:
  50. fig, ax = plt.subplots()
  51. area = copy.copy(np.real(vector_field[0]))
  52. area[np.where(np.isfinite(area))] = 0
  53. ax.imshow(area, interpolation='nearest', origin='lower',
  54. vmin=-1, vmax=1, cmap='RdBu')
  55. for i, frame_t in enumerate(np.unique(t_idx)):
  56. frame_i = np.where(frame_t == t_idx)[0].astype(int)
  57. xi = x[frame_i]
  58. yi = y[frame_i]
  59. ti = t_idx[frame_i]
  60. frame = vector_field[ti, yi.astype(int), xi.astype(int)].magnitude
  61. norm = np.array([np.linalg.norm(w) for w in frame])
  62. frame /= norm
  63. frame_time = times[frame_t].rescale('s').magnitude
  64. color_idx = np.argmin(np.abs(np.array(t_steps) - frame_time))
  65. ax.quiver(xi, yi, np.real(frame), np.imag(frame),
  66. scale=45, width=0.003,
  67. # units='width', scale=max(frame.shape)/(10*skip_step),
  68. # width=0.15/max(frame.shape),
  69. color=palette[color_idx], alpha=0.8,
  70. label=f'{frame_time:.2f} s')
  71. ax.axis('image')
  72. ax.set_xticks([])
  73. ax.set_yticks([])
  74. # start_t = np.min(waves_event.times[idx]).rescale('s').magnitude
  75. # stop_t = np.max(waves_event.times[idx]).rescale('s').magnitude
  76. ax.set_xlabel(f'planarity {planarity:.2f}')
  77. patches = [mpl.patches.Patch(color=c) for c in palette]
  78. legend = ax.legend(patches, [f'{t:.2f} s' for t in t_steps], frameon=False,
  79. bbox_to_anchor=(1,.5), loc='center left', fontsize=12)
  80. return ax
  81. def plot_figure(local_data, global_data, signal_path, alt_signal_path):
  82. # """
  83. # A1 A1 B1 B2 B2
  84. # A2 A2 B1 B2 B2
  85. # X X X X X
  86. # C1 C2 C3 C4 C5a
  87. # C1 C2 C3 C4 C5b
  88. # """
  89. sns.set(style='ticks', palette='deep', context='talk')
  90. fig = plt.figure(figsize=(25,14), constrained_layout=True)
  91. ax = SimpleNamespace()
  92. gs = mpl.gridspec.GridSpec(nrows=4, ncols=4, wspace=0.2, hspace=.17,
  93. height_ratios=(.5,.5,.1,1))
  94. ax.A1 = fig.add_subplot(gs[0, 0:2])
  95. ax.A2 = fig.add_subplot(gs[1, 0:2])
  96. ax.B1 = fig.add_subplot(gs[:2, 2])
  97. ax.B2 = fig.add_subplot(gs[:2, 3])
  98. ax.C1 = fig.add_subplot(gs[3, 0])
  99. ax.C2 = fig.add_subplot(gs[3, 1])
  100. ax.C3 = fig.add_subplot(gs[3, 2])
  101. ax.C4 = fig.add_subplot(gs[3, 3])
  102. axes = vars(ax)
  103. # SIGNAL TRACES + WAVE FIELDS
  104. # wave_labels = {'hilbert phase': '22', 'minima': '38'}
  105. wave_labels = {'hilbert phase': '30', 'minima': '51'}
  106. # t_limits = (19.12, 19.56)
  107. t_limits = (29.00, 29.48)
  108. pixel = (30,50)
  109. for ax.Ai, ax.Bi, path, method in zip([ax.A1, ax.A2],
  110. [ax.B1, ax.B2],
  111. [signal_path, alt_signal_path],
  112. ['hilbert phase', 'minima']):
  113. block = load_neo(str(path))
  114. asig = block.segments[0].analogsignals[0]
  115. flow_asig = block.filter(name='optical_flow', objects="AnalogSignal")[0]
  116. flow = analogsignal_to_imagesequence(flow_asig)
  117. event = block.filter(name='wavefronts', objects="Event")[0]
  118. plot_signal_trace(asig=asig, event=event, label=method,
  119. pixel=pixel, ax=ax.Ai)
  120. plot_planarity(waves_event=event, vector_field=flow, times=asig.times,
  121. wave_label=wave_labels[method],
  122. t_limits=t_limits, ax=ax.Bi)
  123. ax.Bi.scatter([pixel[0]], [pixel[1]], c='k', marker='s', s=15)
  124. del block, asig, flow, event
  125. # VIOLIN PLOTS
  126. violin_params = dict(orient='h', palette=colormap, inner='quartile', cut=0,
  127. alpha=.8, linewidth=1, scale='width', saturation=1,
  128. y='method', hue='anesthetic', split='True',
  129. hue_order=['isoflurane', 'ketamine'])
  130. bw = .2
  131. ## wave characteristics
  132. sns.violinplot(data=filter_velocity_local(local_data), x='velocity_local',
  133. **violin_params, ax=ax.C1)
  134. sns.violinplot(data=local_data, x='inter_wave_interval_local', bw=bw,
  135. **violin_params, ax=ax.C2)
  136. sns.violinplot(data=global_data, x='planarity',
  137. **violin_params, ax=ax.C3)
  138. # sns.violinplot(data=global_data, x='number_of_triggers',
  139. # **violin_params, ax=ax.C4)
  140. ## number of waves
  141. bar_width = .1
  142. for j, method in enumerate(['hilbert phase', 'minima']):
  143. for i, anesthetic in enumerate(['ketamine', 'isoflurane']):
  144. plot_data = global_data[(global_data.anesthetic == anesthetic) \
  145. & (global_data.method == method)]
  146. shift = bar_width/2 - i*bar_width
  147. ax.C4.barh(j+shift, len(plot_data), align='center',
  148. height=bar_width, color=colormap[anesthetic],
  149. edgecolor='k', linewidth=.5)
  150. ## legends
  151. for panel in ['B1', 'C2', 'C3', 'C4']:
  152. if axes[panel].get_legend() is not None:
  153. axes[panel].get_legend().remove()
  154. handles, labels = ax.A2.get_legend_handles_labels()
  155. ax.A2.legend(handles[:2], labels[:2], frameon=True, loc='upper center',
  156. bbox_to_anchor=(.1, 1, .2, .35), framealpha=1)
  157. handles, labels = ax.C1.get_legend_handles_labels()
  158. ax.C1.legend(handles, labels, frameon=True, loc='center right', title='')
  159. ## ticks & labels
  160. ax.A1.set_xticks([])
  161. ax.A1.set_xlabel('')
  162. for panel in ['C2', 'C3', 'C4']:
  163. axes[panel].set_yticks([])
  164. axes[panel].set_ylabel('')
  165. for panel, label in zip(['A1', 'A2'], ['hilbert phase', 'minima']):
  166. axes[panel].set_yticks([0])
  167. axes[panel].set_yticklabels([label], rotation=90, va='center')
  168. ax.A2.set_xlabel('time [s]')
  169. ax.C1.set_ylabel('')
  170. ax.C1.set_yticklabels(ax.C1.get_yticklabels(), rotation=90, va='center')
  171. ax.C1.set_xlabel('velocity [mm/s]')
  172. ax.C2.set_xlabel('inter-wave interval [s]')
  173. ax.C3.set_xlabel('planarity')
  174. ax.C4.set_xlabel('number of waves')
  175. ax.B1.set_title('hilbert phase')
  176. ax.B2.set_title('minima')
  177. ## limits
  178. ax.A1.set_xlim((10,30))
  179. ax.A2.set_xlim((10,30))
  180. ax.C2.set_xlim((0, 2))
  181. ax.C4.set_ylim((ax.C1.get_ylim()[0],ax.C1.get_ylim()[1]))
  182. ## spines
  183. for panel in ['A1', 'A2', 'C1', 'C2', 'C3', 'C4']:
  184. sns.despine(left=True, ax=axes[panel])
  185. sns.despine(bottom=True, left=True, ax=ax.A1)
  186. ## panel letters
  187. dx, dy = .05, .05
  188. text_params = dict(ha='right', va='center', fontsize=20, fontweight='bold')
  189. for panel in ['B', 'C']:
  190. axi = axes[f'{panel}1']
  191. axi.text(s=panel, x=-dx, y=1+dy, transform=axi.transAxes, **text_params)
  192. ax.A1.text(s='A', x=-dx/2, y=1+2*dy, transform=ax.A1.transAxes, **text_params)
  193. # fig.align_labels()
  194. return fig, ax
  195. if __name__ == '__main__':
  196. CLI = argparse.ArgumentParser()
  197. CLI.add_argument("--local_data", nargs='?', type=Path, required=True)
  198. CLI.add_argument("--global_data", nargs='?', type=Path, required=True)
  199. CLI.add_argument("--alt_local_data", nargs='?', type=Path, required=True)
  200. CLI.add_argument("--alt_global_data", nargs='?', type=Path, required=True)
  201. CLI.add_argument("--signal", nargs='?', type=Path, required=True)
  202. CLI.add_argument("--alt_signal", nargs='?', type=Path, required=True)
  203. CLI.add_argument("--output", nargs='?', type=Path, required=True)
  204. args, unknown = CLI.parse_known_args()
  205. data = {}
  206. for measure_type in ['local', 'global']:
  207. dfs = []
  208. for path, method in zip([args.__dict__[f'{measure_type}_data'],
  209. args.__dict__[f'alt_{measure_type}_data']],
  210. ['hilbert phase', 'minima']):
  211. df = load_df(path)
  212. df = df[(df['technique'] == 'calcium imaging')]
  213. df['method'] = method
  214. dfs.append(df)
  215. data[f'{measure_type}_data'] = pd.concat(dfs, ignore_index=True)
  216. # hotfix
  217. data['global_data']['number_of_triggers'] = \
  218. data['global_data']['number_of_triggers'].astype(float)
  219. plot_figure(**data,
  220. signal_path = args.signal,
  221. alt_signal_path = args.alt_signal)
  222. plt.savefig(args.output, bbox_inches='tight')