display_2d.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. from __future__ import division
  2. import argparse
  3. import os
  4. from mpl_toolkits.mplot3d import Axes3D
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from matplotlib import animation, ticker, patches
  9. from epileptor import display_0d
  10. from epileptor import util
  11. from epileptor.state import State
  12. from epileptor.state_recorder import StateRecorder
  13. from epileptor.points import Points
  def parse_cmd_args():
      """Parse the command line.

      Returns:
          tuple: ``(data_filepath, image_extension)`` -- the value of ``-d``
          (required, since the script cannot run without a data file) and the
          value of ``-i`` (defaults to ``'png'``).
      """
      parser = argparse.ArgumentParser(description='')
      # Without -d the script later fails inside os.path.basename(None); let
      # argparse reject the invocation with a usage message instead.
      parser.add_argument('-d', required=True, help='data filepath')
      parser.add_argument('-i', help='image extension', default='png')
      args = parser.parse_args()
      return args.d, args.i
  def make_animation_video(dt, state, data_filename):
      """Render ``state`` as two videos: a 3D surface and a 2D heat map.

      Each frame shows the mean of ``state.values`` over a window of
      ``int(200 / dt)`` consecutive time steps. Both ``.avi`` files are written
      into the module-level ``media_fpath`` directory as
      ``<name>_3d_<data_filename>.avi`` and ``<name>_2d_<data_filename>.avi``.

      Args:
          dt: integration time step; determines how many steps are averaged
              into one frame.
          state: object with ``values`` (indexed ``[t, y, x]``), ``min``,
              ``max`` and ``name`` attributes.
          data_filename: data file name embedded in the output file names.
      """
      nt, y_nh, x_nh = state.values.shape
      speed = int(200 / dt)  # time steps averaged per frame (original note: "speed 200 is for 1ms")
      frame_num = int(nt / speed)
      X, Y = np.meshgrid(range(x_nh), range(y_nh))

      def frame_mean(i):
          """Mean over the i-th window of ``speed`` time steps."""
          return np.mean(state.values[i * speed:(i + 1) * speed, :, :], axis=0)

      # 3D surface figure. Figure.gca(projection=...) was deprecated in
      # Matplotlib 3.4 and removed in 3.6; add_subplot is the supported spelling.
      fig_3d = plt.figure(1, figsize=(8, 5))
      ax_3d = fig_3d.add_subplot(111, projection='3d')
      ax_3d.set_xlabel('X')
      ax_3d.set_ylabel('Y')
      ax_3d.set_zlim(state.min, state.max)
      ax_3d.set_title(state.name)
      state_surf = ax_3d.plot_surface(X, Y, state.values[0], cmap='plasma', vmin=state.min, vmax=state.max)
      plt.colorbar(mappable=state_surf, ax=ax_3d)

      # 2D heat-map figure.
      fig_2d = plt.figure(2, figsize=(8, 5))
      ax_2d = fig_2d.add_subplot(111)
      ax_2d.set_axis_off()
      ax_2d.set_title(state.name)
      state_img = ax_2d.imshow(state.values[0], cmap='plasma')
      state_img.set_clim(state.min, state.max)
      plt.colorbar(mappable=state_img, ax=ax_2d)

      # One updater per figure, so saving one video does not redraw the other
      # figure on every frame.
      def animate_3d(i):
          # The surface is rebuilt from scratch each frame (remove + re-plot).
          nonlocal state_surf
          state_surf.remove()
          state_surf = ax_3d.plot_surface(X, Y, frame_mean(i), cmap='plasma', vmin=state.min, vmax=state.max)

      def animate_2d(i):
          state_img.set_data(frame_mean(i))

      anim_3d = animation.FuncAnimation(fig_3d, animate_3d, frames=frame_num)
      anim_3d_fname = '{}_3d_{}.avi'.format(state.name, data_filename)
      anim_3d.save(os.path.join(media_fpath, anim_3d_fname))

      anim_2d = animation.FuncAnimation(fig_2d, animate_2d, frames=frame_num)
      anim_2d_fname = '{}_2d_{}.avi'.format(state.name, data_filename)
      anim_2d.save(os.path.join(media_fpath, anim_2d_fname))

      fig_3d.clear()
      fig_2d.clear()
  def _board_shot_indices(peak_times, half_window_dt, ncols, n_frames):
      """Return ``ncols`` evenly spaced frame indices around each peak, row by row.

      Each window ``[peak - half_window_dt, peak + half_window_dt]`` is clamped
      to ``[0, n_frames - 2]``. Without the lower clamp, a peak closer than
      ``half_window_dt`` to the start produces negative indices, which NumPy
      silently wraps around to the end of the recording.
      """
      shots = []
      for peak_time in peak_times:
          start = max(peak_time - half_window_dt, 0)
          stop = min(peak_time + half_window_dt, n_frames - 2)
          shots.extend(int(t) for t in np.linspace(start, stop, ncols))
      return shots


  def _add_board_shot(ax, state, shot_dt, dt, markers):
      """Draw frame ``shot_dt`` of ``state`` on ``ax`` with 'x' markers; return the image."""
      ax.set_axis_off()
      ax.set_title('{}'.format(util.dt_to_sec(shot_dt, dt)))
      img = ax.imshow(state.values[shot_dt, :, :], cmap='jet')
      for point in markers:
          ax.text(point['x'] - 2, point['y'] + 1, 'x', {'color': 'w', 'fontsize': 5, 'weight': 'bold'})
      img.set_clim(state.min, state.max)
      return img


  def make_animation_board(dt, state, data_filename, points):
      """Save a grid of snapshots of ``state`` around each detected peak.

      Peaks come from ``util.get_peak_times(state.values, util.sec_to_dt(20, dt))``
      (see util for the meaning of the 20 s argument). Each peak gets one row of
      8 evenly spaced snapshots spanning up to 27 s either side of it. When
      ``points`` is given, its first two points are marked on every snapshot and
      labelled S1/S2 on the first one. The image is saved to
      ``media_fpath/<name>_board_<data_filename>.<img_ext>``.

      Args:
          dt: time step used to convert between seconds and step indices.
          state: object with ``values`` (indexed ``[t, y, x]``), ``min``,
              ``max`` and ``name``.
          data_filename: data file name embedded in the output image name.
          points: recording points (``.get(i)`` returns a dict with 'x'/'y'),
              or None when no coordinates were recorded -- then no markers.

      Raises:
          ValueError: if no peaks are found (there would be no rows to plot).
      """
      markers = [] if points is None else [points.get(0), points.get(1)]
      peak_times = util.get_peak_times(state.values, util.sec_to_dt(20, dt))
      nrows = len(peak_times)
      ncols = 8
      if nrows == 0:
          raise ValueError('No peaks found in {}; nothing to put on the board'.format(state.name))
      t_shots_dt = _board_shot_indices(peak_times, util.sec_to_dt(27, dt), ncols, len(state.values))

      fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 5))
      for ax, shot_dt in zip(axes.flat, t_shots_dt):
          # All images share the same clim, so any of them can feed the colorbar.
          im = _add_board_shot(ax, state, shot_dt, dt, markers)

      # Only the first snapshot carries the unit and the S1/S2 labels.
      first_ax = axes.flat[0]
      first_ax.set_title(first_ax.get_title() + 's')
      if len(markers) == 2:
          s1, s2 = markers
          first_ax.text(s1['x'] - 15, s1['y'] - 2, 'S1', {'color': 'w', 'fontsize': 10})
          first_ax.text(s2['x'] - 15, s2['y'] + 2, 'S2', {'color': 'w', 'fontsize': 10})

      fig.subplots_adjust(wspace=0.05)
      fig.colorbar(
          mappable=im, ax=axes.ravel().tolist(), orientation='horizontal',
          pad=0.075, shrink=0.4, anchor=(0.9, 1.0),
          ticks=[state.min, np.mean([state.min, state.max]), state.max],
          format='%.1f',
      )
      fig_fname = '{}_board_{}.{}'.format(state.name, data_filename, img_ext)
      fig.savefig(os.path.join(media_fpath, fig_fname), bbox_inches='tight')
      # plt.subplots registers a new figure on every call; release it.
      plt.close(fig)
  def display_h_points(dt, points_data, data_filename, points):
      """Plot every recorded variable over time, once per recording point.

      For point ``i``, column ``i`` of each array in ``points_data`` becomes a
      ``State`` and all of them are handed to ``display_0d.plot_data`` under
      the title ``x<x>y<y>_<data_filename>.<img_ext>``.

      Args:
          dt: time step forwarded to ``display_0d.plot_data``.
          points_data: mapping of variable name -> array indexed ``[t, point]``.
          data_filename: data file name embedded in each title.
          points: iterable of dicts with 'x'/'y' keys, or None when the data
              file carried no point coordinates -- then nothing is plotted.
      """
      if points is None:
          # The script calls this unconditionally; without coordinates there is
          # nothing to plot (and enumerate(None) would raise TypeError).
          return
      for i, point in enumerate(points):
          title = 'x{}y{}_{}.{}'.format(point['x'], point['y'], data_filename, img_ext)
          point_state_list = [State(name, series[:, i]) for name, series in points_data.items()]
          display_0d.plot_data(dt, point_state_list, title)
  def display_t_points(state, filename, dt=None):
      """Plot the y-profile of ``state`` through the middle column at its first peak.

      The first time index returned by
      ``util.get_peak_times(state.values, util.sec_to_dt(20, dt))`` is used. The
      plot is saved as ``media_fpath/t_along_y[<t>]_<name>_<filename>.<img_ext>``.

      Args:
          state: object with ``values`` (indexed ``[t, y, x]``) and ``name``.
          filename: data file name embedded in the output image name.
          dt: time step. When omitted, the module-level ``dt`` global is used,
              which is what this function always read before the parameter
              existed; raises KeyError if no such global has been set.
      """
      if dt is None:
          dt = globals()['dt']
      _, y_nh, x_nh = state.values.shape
      peak_times = util.get_peak_times(state.values, util.sec_to_dt(20, dt))
      t1_dt = peak_times[0]
      t1_s = util.dt_to_sec(t1_dt, dt)
      profile_t1 = np.s_[t1_dt, :, int(x_nh / 2)]

      fig = plt.figure(figsize=(14, 10))
      ax = fig.gca()
      ax.plot(range(y_nh), state.values[profile_t1], label='{} at {}sec'.format(state.name, t1_s))
      ax.legend()
      fig_fname = 't_along_y[{}]_{}_{}.{}'.format(t1_s, state.name, filename, img_ext)
      fig.savefig(os.path.join(media_fpath, fig_fname))
      plt.close(fig)
  def get_point_save_index(point, data):
      """Return the position of ``point`` among the saved point coordinates.

      Args:
          point: mapping with 'x' and 'y' coordinates.
          data: mapping whose 'points' entry holds x coordinates at ``[0]`` and
              y coordinates at ``[1]`` (each with ``.tolist()``, e.g. NumPy rows).

      Returns:
          int: index of the first saved point whose coordinates match.

      Raises:
          ValueError: if no saved point matches.
      """
      saved = data['points']
      xs, ys = saved[0].tolist(), saved[1].tolist()
      for index, (x, y) in enumerate(zip(xs, ys)):
          if x == point['x'] and y == point['y']:
              return index
      raise ValueError('No point index for {}'.format(point))
  def compare_h_points(dt, data, data_filename, points):
      """Overlay the time series of the first two recording points, labelled S1 and S2.

      Two figures are produced through ``display_state_at_points``: one for
      ``data['K']`` and one for ``data['U']``. The coordinates of both points are
      prefixed to ``data_filename`` so that the output names identify them.
      """
      first, second = points.get(0), points.get(1)
      coords = '_'.join('{},{}'.format(p['x'], p['y']) for p in (first, second))
      tagged_filename = 'points{}_{}'.format(coords, data_filename)
      labels = ['S1', 'S2']
      columns = [0, 1]
      # (axis label, data key, line colours for S1 and S2), plotted in this order.
      panels = (
          (r'$[K]_{O}$(mm)', 'K', ['b', '#a9a9a9']),
          (r'U(mV)', 'U', ['#009900', '#a9a9a9']),
      )
      for label, key, colours in panels:
          display_state_at_points(dt, State(label, data[key]), tagged_filename, columns, colours, labels)
  def display_state_at_points(dt, state, data_filename, point_index_list, line_list, title_list,
                              t_begin_sec=1, t_end_sec=140):
      """Plot ``state`` over time at several recording points on one axes and save it.

      The output is ``media_fpath/<name>_<data_filename>.<img_ext>``.

      Args:
          dt: time step used to convert seconds to step indices and back.
          state: object with ``name`` and ``values`` indexed ``[t, point]``.
          data_filename: data file name embedded in the output image name.
          point_index_list: column of ``state.values`` to plot, one per line.
          line_list: matplotlib format string (colour) for each line.
          title_list: legend label for each line.
          t_begin_sec: start of the plotted window in seconds (default 1).
          t_end_sec: end of the plotted window in seconds (default 140).
      """
      dt_begin = util.sec_to_dt(t_begin_sec, dt)
      dt_end = util.sec_to_dt(t_end_sec, dt)
      value_list = [state.values[dt_begin:dt_end, point_index] for point_index in point_index_list]
      # One x value (a step index) per sample actually present. The slice is
      # shorter than requested when the recording ends early, and
      # np.linspace(dt_begin, dt_end, dt_end - dt_begin) spaced the samples
      # slightly more than one step apart.
      n_samples = len(value_list[0]) if value_list else 0
      t = np.arange(dt_begin, dt_begin + n_samples)

      fig = plt.figure(figsize=(14, 10))
      ax = fig.add_subplot(111)
      # Ticks are step indices; FuncFormatter passes (value, position) and the
      # label shows the value converted to seconds.
      sec_formatter = matplotlib.ticker.FuncFormatter(lambda t_dt, _pos: util.dt_to_sec(t_dt, dt))
      ax.xaxis.set_major_formatter(sec_formatter)
      ax.tick_params(axis='both', labelsize=34)
      ax.locator_params(axis='y', tight=True, nbins=3)
      ax.set_xlabel('t(s)', fontsize=36)
      ax.set_ylabel(state.name, fontsize=36)
      for values, line, label in zip(value_list, line_list, title_list):
          ax.plot(t, values, line, label='{}'.format(label), linewidth=2)
      ax.legend(fontsize=32)

      fig_fname = '{}_{}.{}'.format(state.name, data_filename, img_ext)
      fig.savefig(os.path.join(media_fpath, fig_fname))
      # Release the figure; each call otherwise leaves a 14x10 figure open.
      plt.close(fig)
  def center_cut_screenshot(dt, t_dt, state, data_filename):
      """Save frame ``t_dt`` of ``state`` with a red strip outline and 1mm scale bars.

      Annotation positions are hard-coded in image pixels. The output is
      ``media_fpath/cut_H_<name>_at<t>s_<data_filename>.<img_ext>``, where
      ``<t>`` is ``util.dt_to_sec(t_dt, dt)``.
      """
      ink = '000'  # matplotlib reads a numeric string in [0, 1] as a grey level: '000' -> black
      canvas = plt.figure(figsize=(5, 5))
      frame_ax = plt.gca()
      frame_ax.axis('off')
      frame_ax.imshow(state.values[t_dt, :, :], cmap='jet')

      # Red outline of the vertical strip starting at column 40, 5 pixels wide.
      # Presumably the strip sampled by the centre-cut plot (x_nh // 2 on an
      # 80-column grid) -- confirm if the grid size changes.
      frame_ax.add_patch(patches.Rectangle(
          (40, 0), 5, 79, linewidth=4, edgecolor='#FF0000', clip_on=False, fill=False))

      # Scale bars, 14 pixels long and labelled '1mm':
      # (bar corner, bar width, bar height, label position, label rotation)
      scale_bars = (
          ((-5, 63), 2, 14, (-12, 67), 90),   # vertical bar
          ((0, 81), 14, 2, (0, 89), None),    # horizontal bar
      )
      for corner, width, height, (label_x, label_y), rotation in scale_bars:
          frame_ax.add_patch(patches.Rectangle(corner, width, height, clip_on=False, facecolor=ink))
          label_kwargs = {} if rotation is None else {'rotation': rotation}
          frame_ax.text(label_x, label_y, '1mm', color=ink, fontsize=18, **label_kwargs)

      # Panel letter for the composite figure.
      frame_ax.text(-10, -5, 'A', color=ink, fontsize=32, fontweight='bold')

      seconds = util.dt_to_sec(t_dt, dt)
      out_name = 'cut_H_{}_at{}s_{}.{}'.format(state.name, seconds, data_filename, img_ext)
      plt.savefig(os.path.join(media_fpath, out_name))
      frame_ax.clear()
      canvas.clear()
  def center_cut_t_evolution(dt, state, data_filename, t_begin=63300, t_end=63700):
      """Save a space-time image of a vertical strip of ``state``, then a matching snapshot.

      The strip is the 5 columns starting at ``x_nh // 2``. Time steps
      ``int(t_begin / dt)`` up to ``int(t_end / dt)`` are laid out left to right,
      with 5 image columns per time step (one per strip column); the vertical
      axis is y. Afterwards ``center_cut_screenshot`` is called for the first
      time step of the window.

      Args:
          dt: time step; ``t_begin`` and ``t_end`` are divided by it to get indices.
          state: object with ``values`` (indexed ``[t, y, x]``) and ``name``.
          data_filename: data file name embedded in the output image name.
          t_begin: start of the window, in the same time units as ``dt``.
          t_end: end of the window, in the same time units as ``dt``.
              The defaults reproduce the window previously hard-coded for the
              2019-04-04_19.16 run. Other windows noted for that run were
              32800-33100, 51380-52000 and 163000-163400.
      """
      values = state.values
      dt_begin = int(t_begin / dt)
      dt_end = int(t_end / dt)
      _, y_nh, x_nh = values.shape
      x = x_nh // 2
      center_cut = values[dt_begin:dt_end, :, x:(x + 5)]
      # (t, y, 5) -> (t, 5, y) -> (t * 5, y): each time step becomes 5 consecutive rows.
      center_cut = center_cut.transpose(0, 2, 1).reshape(-1, y_nh)

      # Fresh figure: plt.gcf() reused whatever figure happened to be current
      # (e.g. the 14x5 board figure), making the output size order-dependent.
      fig = plt.figure()
      ax = fig.add_subplot(111, xticks=[], yticks=[])
      img = ax.imshow(center_cut.T, cmap='jet', aspect=4.0)
      # Time scale bar: 500 image columns (= 100 time steps) labelled '0.1s'.
      # NOTE(review): the label does not depend on dt, so it is right for one dt only -- confirm.
      ax.add_patch(patches.Rectangle((0, 90), 500, 5, clip_on=False, facecolor='000'))
      ax.text(160, 115, '0.1s', color='000', fontsize=10)
      # Red outline left of the image, presumably echoing the strip outline in the snapshot.
      ax.add_patch(patches.Rectangle((-50, 1), 35, 77, linewidth=2, edgecolor='#FF0000', clip_on=False, fill=False))
      ax.text(-80, -15, 'B', color='000', fontsize=16, fontweight='bold')

      # The colorbar fills a small, invisible helper axes (fraction=1.0).
      cbar_ax = fig.add_axes([0.5, 0.25, .4, .1], xticks=[], yticks=[], frameon=False)
      cbar = fig.colorbar(img, ax=cbar_ax, orientation='horizontal', fraction=1.0, pad=0.01)
      ticks = [int(np.min(center_cut)), int(np.max(center_cut))]
      cbar.set_ticks(ticks)
      cbar_labels = [str(tick) for tick in ticks]
      cbar_labels[-1] += 'mV'
      cbar.set_ticklabels(cbar_labels)

      # NOTE(review): dt_begin is a step index, so dt_begin / 1000 equals seconds
      # only for one particular dt -- confirm the intended label.
      title = 'cut_{}_at{}s_{}'.format(state.name, int(dt_begin / 1000), data_filename)
      fig_fname = '{}.{}'.format(title, img_ext)
      fig.savefig(os.path.join(media_fpath, fig_fname))
      plt.close(fig)

      center_cut_screenshot(dt, dt_begin, state, data_filename)
  # Script entry: load one StateRecorder file and write figures/videos to media_fpath.
  # media_fpath, img_ext and dt stay module-level globals because the plotting
  # functions read them.
  media_fpath = 'media'
  filepath, img_ext = parse_cmd_args()
  data_filename = os.path.basename(filepath)
  params, plain_data, points_data = StateRecorder.load(filepath)
  dt = params['dt']
  # savefig / Animation.save do not create missing directories.
  os.makedirs(media_fpath, exist_ok=True)

  points = None
  if 'x_coords' in points_data:
      # Copy into a plain dict so the coordinate arrays can be popped off
      # without touching the loaded object.
      points_data = dict(points_data)
      points = Points(points_data.pop('x_coords'), points_data.pop('y_coords'))

  # Whole-field outputs, selected per state name.
  for name in plain_data.files:
      state = State(name, plain_data[name])
      if state.name == 'V':
          center_cut_t_evolution(dt, state, data_filename)
          # Optional outputs, disabled by default:
          # make_animation_video(dt, state, data_filename)
          # display_t_points(state, data_filename)
      if state.name == 'K':
          make_animation_board(dt, state, data_filename, points)

  # Point time-series plots need recorded point coordinates.
  if points is not None:
      display_h_points(dt, points_data, data_filename, points)
      if points.len() > 1:
          compare_h_points(dt, points_data, data_filename, points)
  235. display_h_points(dt, points_data, data_filename, points)
  236. if points and points.len() > 1:
  237. compare_h_points(dt, points_data, data_filename, points)