123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- from __future__ import division
- import argparse
- import os
- from mpl_toolkits.mplot3d import Axes3D
- import matplotlib
- import matplotlib.pyplot as plt
- import numpy as np
- from matplotlib import animation, ticker, patches
- from epileptor import display_0d
- from epileptor import util
- from epileptor.state import State
- from epileptor.state_recorder import StateRecorder
- from epileptor.points import Points
def parse_cmd_args(argv=None):
    """Parse the command-line options of the 2-D display script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        A ``(data_filepath, image_extension)`` tuple. The extension defaults
        to ``'png'``.

    Raises:
        SystemExit: Raised by argparse when ``-d`` is missing or an option is
            malformed.
    """
    parser = argparse.ArgumentParser(
        description='Render plots and animations from recorded simulation data.'
    )
    # The path is required: the script passes it straight to
    # os.path.basename(), which fails with an obscure TypeError on None.
    parser.add_argument('-d', required=True, help='data filepath')
    parser.add_argument('-i', help='image extension', default='png')
    args = parser.parse_args(argv)
    return args.d, args.i
def make_animation_video(dt, state, data_filename):
    """Save a 3-D surface animation and a 2-D image animation of ``state``.

    Every frame shows the mean of ``speed`` consecutive time slices of
    ``state.values``. Both animations are written into the module-level
    ``media_fpath`` directory as ``<name>_3d_<data_filename>.avi`` and
    ``<name>_2d_<data_filename>.avi``.

    Args:
        dt: Simulation time step; ``int(200 / dt)`` slices are averaged into
            one frame.
        state: Object exposing ``values`` (3-D array indexed as
            ``[time, y, x]``), ``min``, ``max`` and ``name``.
        data_filename: Suffix used in the output file names.
    """
    def create_state_ax(fig, state):
        """Add a labelled 3-D axes for ``state`` to ``fig`` and return it."""
        # Figure.gca() stopped accepting keyword arguments in Matplotlib 3.6;
        # add_subplot() is the supported way to request a 3-D projection.
        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlim(state.min, state.max)
        ax.set_title(state.name)
        return ax

    nt, y_nh, x_nh = state.values.shape
    # Number of time slices averaged per frame (original note: "speed 200 is
    # for 1ms"). Clamped to 1 so a large dt cannot lead to division by zero.
    speed = max(int(200 / dt), 1)
    frame_num = nt // speed
    X, Y = np.meshgrid(range(x_nh), range(y_nh))

    # 3-D surface view.
    fig_3d = plt.figure(1, figsize=(8, 5))
    ax_3d = create_state_ax(fig_3d, state)
    state_surf = ax_3d.plot_surface(
        X, Y, state.values[0], cmap='plasma', vmin=state.min, vmax=state.max
    )
    plt.colorbar(mappable=state_surf, ax=ax_3d)

    # 2-D image view.
    fig_2d = plt.figure(2, figsize=(8, 5))
    ax_2d = fig_2d.gca()
    ax_2d.set_axis_off()
    ax_2d.set_title(state.name)
    state_img = ax_2d.imshow(state.values[0], cmap='plasma')
    state_img.set_clim(state.min, state.max)
    plt.colorbar(mappable=state_img, ax=ax_2d)

    def animate(i):
        """Draw frame ``i`` on both the surface and the image."""
        nonlocal state_surf
        # A surface cannot be updated in place, so it is dropped and redrawn.
        state_surf.remove()
        state_mean = np.mean(state.values[i * speed:(i + 1) * speed, :, :], axis=0)
        state_surf = ax_3d.plot_surface(
            X, Y, state_mean, cmap='plasma', vmin=state.min, vmax=state.max
        )
        state_img.set_data(state_mean)

    # The same callback drives both animations; it updates both figures.
    anim_3d = animation.FuncAnimation(fig_3d, animate, frames=frame_num)
    anim_3d_fname = '{}_3d_{}.avi'.format(state.name, data_filename)
    anim_3d.save(os.path.join(media_fpath, anim_3d_fname))

    anim_2d = animation.FuncAnimation(fig_2d, animate, frames=frame_num)
    anim_2d_fname = '{}_2d_{}.avi'.format(state.name, data_filename)
    anim_2d.save(os.path.join(media_fpath, anim_2d_fname))

    fig_3d.clear()
    fig_2d.clear()
def make_animation_board(dt, state, data_filename, points):
    """Save a grid of snapshots of ``state`` taken around each detected peak.

    One row is drawn per peak time returned by ``util.get_peak_times`` and
    each row holds eight snapshots spread evenly over the window
    ``peak +/- util.sec_to_dt(27, dt)``. The figure is written into the
    module-level ``media_fpath`` directory as
    ``<name>_board_<data_filename>.<img_ext>``.

    Args:
        dt: Simulation time step, forwarded to the ``util`` conversions.
        state: Object exposing ``values`` (indexed ``[time, y, x]``), ``min``,
            ``max`` and ``name``.
        data_filename: Suffix used in the output file name.
        points: Object whose ``get(0)`` / ``get(1)`` return mappings with
            ``'x'`` and ``'y'`` keys, or ``None`` when the recording has no
            points; in that case no markers are drawn.
    """
    # The script passes None when the recording carries no point coordinates.
    markers = [] if points is None else [points.get(0), points.get(1)]

    def add_to_board(shot_i, ax):
        """Draw time slice ``shot_i`` on ``ax`` and return the image artist."""
        ax.set_axis_off()
        ax.set_title('{}'.format(util.dt_to_sec(shot_i, dt)))
        state_img = ax.imshow(state.values[shot_i, :, :], cmap='jet')
        for point in markers:
            ax.text(point['x'] - 2, point['y'] + 1, 'x',
                    {'color': 'w', 'fontsize': 5, 'weight': 'bold'})
        state_img.set_clim(state.min, state.max)
        return state_img

    peak_times = util.get_peak_times(state.values, util.sec_to_dt(20, dt))
    nrows = len(peak_times)
    if nrows == 0:
        # Nothing to show; plt.subplots() rejects a grid with zero rows.
        return
    ncols = 8
    time_range_dt = util.sec_to_dt(27, dt)
    last_shot = len(state.values) - 2

    t_shots_dt = []
    for peak_time in peak_times:
        # Clamp both ends of the window. A negative start would otherwise be
        # read as an index counted from the END of the recording.
        first = max(peak_time - time_range_dt, 0)
        last = min(peak_time + time_range_dt, last_shot)
        t_shots_dt.extend(int(shot) for shot in np.linspace(first, last, ncols))

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(14, 5), squeeze=False)
    im = None
    for i, (shot_i, ax) in enumerate(zip(t_shots_dt, axes.flat)):
        im = add_to_board(shot_i, ax)
        if i == 0:
            # Only the first panel carries the time unit and the point labels.
            ax.set_title(ax.get_title() + 's')
            if markers:
                s1, s2 = markers
                ax.text(s1['x'] - 15, s1['y'] - 2, 'S1', {'color': 'w', 'fontsize': 10})
                ax.text(s2['x'] - 15, s2['y'] + 2, 'S2', {'color': 'w', 'fontsize': 10})

    fig.subplots_adjust(wspace=0.05)
    fig.colorbar(
        mappable=im, ax=axes.ravel().tolist(), orientation='horizontal',
        pad=0.075, shrink=0.4, anchor=(0.9, 1.0),
        ticks=[state.min, np.mean([state.min, state.max]), state.max],
        format='%.1f'
    )
    fig_fname = '{}_board_{}.{}'.format(state.name, data_filename, img_ext)
    fig.savefig(os.path.join(media_fpath, fig_fname), bbox_inches='tight')
    fig.clear()
def display_h_points(dt, points_data, data_filename, points):
    """Plot every recorded state variable for each recorded point.

    For the ``i``-th point, column ``i`` of every array in ``points_data`` is
    wrapped in a ``State`` and the resulting list is handed to
    ``display_0d.plot_data``.

    Args:
        dt: Simulation time step, forwarded to ``display_0d.plot_data``.
        points_data: Mapping from state name to a 2-D array whose second axis
            indexes the recorded points.
        data_filename: Suffix used in the plot title.
        points: Iterable of mappings with ``'x'`` and ``'y'`` keys, or
            ``None`` when the recording has no points (nothing is plotted).
    """
    if points is None:
        # The script passes None for recordings without point coordinates.
        return
    for i, point in enumerate(points):
        title = 'x{}y{}_{}.{}'.format(point['x'], point['y'], data_filename, img_ext)
        point_state_list = [State(k, v[:, i]) for k, v in points_data.items()]
        display_0d.plot_data(dt, point_state_list, title)
def display_t_points(state, filename):
    """Plot the profile of ``state`` along y at the first detected peak.

    The profile is taken through the middle column (``x_nh // 2``) at the
    first time returned by ``util.get_peak_times``. The figure is written
    into the module-level ``media_fpath`` directory.

    NOTE: unlike its siblings this function has no ``dt`` parameter and reads
    the module-level ``dt`` instead.

    Args:
        state: Object exposing ``values`` (indexed ``[time, y, x]``) and
            ``name``.
        filename: Suffix used in the output file name.
    """
    _, y_nh, x_nh = state.values.shape
    peak_times = util.get_peak_times(state.values, util.sec_to_dt(20, dt))
    if len(peak_times) == 0:
        # No peak means there is no time slice to plot.
        return
    t1_dt = peak_times[0]
    t1_s = util.dt_to_sec(t1_dt, dt)
    profile_t1 = np.s_[t1_dt, :, x_nh // 2]

    fig = plt.figure(figsize=(14, 10))
    ax = fig.gca()
    ax.plot(range(y_nh), state.values[profile_t1],
            label='{} at {}sec'.format(state.name, t1_s))
    ax.legend()
    fig_fname = 't_along_y[{}]_{}_{}.{}'.format(t1_s, state.name, filename, img_ext)
    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(os.path.join(media_fpath, fig_fname))
    fig.clear()
def get_point_save_index(point, data):
    """Return the position at which ``point`` was saved in ``data['points']``.

    Args:
        point: Mapping with ``'x'`` and ``'y'`` keys.
        data: Mapping whose ``'points'`` entry holds two parallel sequences:
            index 0 is the x coordinates, index 1 the y coordinates. NumPy
            arrays and plain sequences are both accepted.

    Returns:
        The index of the first saved point whose coordinates equal ``point``.

    Raises:
        ValueError: If no saved point matches.
    """
    saved = data['points']
    for index, (x, y) in enumerate(zip(saved[0], saved[1])):
        if x == point['x'] and y == point['y']:
            return index
    raise ValueError('No point index for {}'.format(point))
def compare_h_points(dt, data, data_filename, points):
    """Plot the 'K' and 'U' traces of the first two points on shared axes.

    Original author's note: keep the same number of discharges/waves for
    the compared points.

    Args:
        dt: Simulation time step, forwarded to ``display_state_at_points``.
        data: Mapping providing the ``'K'`` and ``'U'`` arrays.
        data_filename: Suffix used in the output file names.
        points: Object whose ``get(0)`` / ``get(1)`` return mappings with
            ``'x'`` and ``'y'`` keys.
    """
    compared = [points.get(position) for position in (0, 1)]
    coord_tags = ['{},{}'.format(p['x'], p['y']) for p in compared]
    out_name = 'points{}_{}'.format('_'.join(coord_tags), data_filename)

    labels = ['S1', 'S2']
    # The two compared points are taken to be columns 0 and 1 of each array.
    columns = [0, 1]

    # (axis label, key in ``data``, colour of the first trace); the second
    # trace is always drawn in grey.
    traces = (
        (r'$[K]_{O}$(mm)', 'K', 'b'),
        (r'U(mV)', 'U', '#009900'),
    )
    for axis_label, key, first_colour in traces:
        display_state_at_points(
            dt, State(axis_label, data[key]), out_name, columns,
            [first_colour, '#a9a9a9'], labels
        )
def display_state_at_points(dt, state, data_filename, point_index_list, line_list, title_list):
    """Plot the time course of ``state`` at several points on one figure.

    The plotted window runs from ``util.sec_to_dt(1, dt)`` to
    ``util.sec_to_dt(140, dt)``, shortened to the recording length when the
    recording ends earlier. The figure is written into the module-level
    ``media_fpath`` directory as ``<state.name>_<data_filename>.<img_ext>``.

    Args:
        dt: Simulation time step, forwarded to the ``util`` conversions.
        state: Object exposing ``values`` (indexed ``[time, point]``) and
            ``name``.
        data_filename: Suffix used in the output file name.
        point_index_list: Column indices of the points to plot.
        line_list: Matplotlib format/colour string for each point.
        title_list: Legend label for each point.
    """
    dt_begin = util.sec_to_dt(1, dt)
    # Never index past the recording; otherwise the time axis and the data
    # have different lengths and ax.plot() raises.
    dt_end = min(util.sec_to_dt(140, dt), len(state.values))
    # One x value per sample index. linspace(begin, end, end - begin) would
    # include ``end`` and stretch the axis by one step.
    t = np.arange(dt_begin, dt_end)
    window = state.values[dt_begin:dt_end]

    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(111)
    # The x axis holds step indices; tick labels are converted for display.
    sec_formatter = matplotlib.ticker.FuncFormatter(lambda t_dt, pos: util.dt_to_sec(t_dt, dt))
    ax.xaxis.set_major_formatter(sec_formatter)
    plt.xticks(fontsize=34)
    plt.yticks(fontsize=34)
    plt.locator_params(axis='y', tight=True, nbins=3)
    ax.set_xlabel('t(s)', fontsize=36)
    ax.set_ylabel(state.name, fontsize=36)

    for point_index, line, label in zip(point_index_list, line_list, title_list):
        ax.plot(t, window[:, point_index], line, label='{}'.format(label), linewidth=2)
    ax.legend(fontsize=32)

    title = '{}_{}'.format(state.name, data_filename)
    fig_fname = '{}.{}'.format(title, img_ext)
    fig.savefig(os.path.join(media_fpath, fig_fname))
    # Release the artists, as the other plotting functions in this file do.
    fig.clear()
def center_cut_screenshot(dt, t_dt, state, data_filename):
    """Save one annotated time slice of ``state`` as panel 'A'.

    The slice is drawn with a red outlined rectangle, two '1mm' bars (one
    vertical, one horizontal) and the panel letter. All coordinates are
    hard-coded. The image is written into the module-level ``media_fpath``
    directory.

    Args:
        dt: Simulation time step, used only to label the file name.
        t_dt: Index of the time slice to draw.
        state: Object exposing ``values`` (indexed ``[time, y, x]``) and
            ``name``.
        data_filename: Suffix used in the output file name.
    """
    black = '000'

    fig = plt.figure(figsize=(5, 5))
    ax = plt.gca()
    ax.axis('off')
    ax.imshow(state.values[t_dt, :, :], cmap='jet')

    # Red outline (presumably the strip shown by center_cut_t_evolution;
    # TODO confirm).
    outline = patches.Rectangle(
        (40, 0), 5, 79, linewidth=4, edgecolor='#FF0000', clip_on=False, fill=False
    )
    ax.add_patch(outline)

    # Vertical bar with its rotated label, left of the image.
    vertical_bar = patches.Rectangle((-5, 63), 2, 14, clip_on=False, facecolor=black)
    ax.add_patch(vertical_bar)
    ax.text(-12, 67, '1mm', color=black, fontsize=18, rotation=90)

    # Horizontal bar with its label, below the image.
    horizontal_bar = patches.Rectangle((0, 81), 14, 2, clip_on=False, facecolor=black)
    ax.add_patch(horizontal_bar)
    ax.text(0, 89, '1mm', color=black, fontsize=18)

    # Panel letter.
    ax.text(-10, -5, 'A', color=black, fontsize=32, fontweight='bold')

    shot_time = util.dt_to_sec(t_dt, dt)
    title = 'cut_H_{}_at{}s_{}'.format(state.name, shot_time, data_filename)
    fig_fname = '{}.{}'.format(title, img_ext)
    plt.savefig(os.path.join(media_fpath, fig_fname))

    ax.clear()
    fig.clear()
def center_cut_t_evolution(dt, state, data_filename, t_begin=63300, t_end=63700):
    """Save panel 'B': the time evolution of a five-column central strip.

    For every time slice in the window, the five columns starting at
    ``x_nh // 2`` are laid side by side, so the image shows y (vertical)
    against time (horizontal). ``center_cut_screenshot`` is then called for
    the first slice of the window to produce the matching panel 'A'.

    Args:
        dt: Simulation time step; the window bounds are divided by it to get
            slice indices.
        state: Object exposing ``values`` (indexed ``[time, y, x]``) and
            ``name``.
        data_filename: Suffix used in the output file names.
        t_begin: Start of the window, in the unit ``dt`` is expressed in.
        t_end: End of the window, same unit. The defaults were tuned by hand
            for the recording 2019-04-04_19.16; other windows noted for that
            recording: 32800-33100, 51380-52000, 163000-163400.

    Raises:
        ValueError: If the window contains no time slices of the recording.
    """
    values = state.values
    dt_begin = int(t_begin / dt)
    dt_end = int(t_end / dt)
    _, y_nh, x_nh = values.shape
    x = x_nh // 2

    center_cut = values[dt_begin:dt_end, :, x:(x + 5)]
    if center_cut.shape[0] == 0:
        # Fail with a clear message instead of np.min() on an empty array.
        raise ValueError(
            'Time window [{}, {}) is outside the recording'.format(dt_begin, dt_end)
        )
    # (time, y, 5) -> (time * 5, y): the five columns of each slice become
    # five consecutive rows, transposed below for display.
    center_cut = center_cut.transpose(0, 2, 1).reshape(-1, y_nh)

    fig = plt.gcf()
    ax = fig.add_subplot(111, xticks=[], yticks=[])
    img = ax.imshow(center_cut.T, cmap='jet', aspect=4.0)

    # Time bar under the image, labelled '0.1s'.
    rect = patches.Rectangle((0, 90), 500, 5, clip_on=False, facecolor='000')
    ax.add_patch(rect)
    ax.text(160, 115, '0.1s', color='000', fontsize=10)

    # Red outline left of the image.
    rect = patches.Rectangle((-50, 1), 35, 77, linewidth=2, edgecolor='#FF0000',
                             clip_on=False, fill=False)
    ax.add_patch(rect)
    ax.text(-80, -15, 'B', color='000', fontsize=16, fontweight='bold')

    # Invisible helper axes that only positions the horizontal colour bar.
    cbar_ax = fig.add_axes([0.5, 0.25, .4, .1], xticks=[], yticks=[], frameon=False)
    cbar = plt.colorbar(
        mappable=img, ax=cbar_ax, orientation='horizontal',
        fraction=1.0, pad=0.01
    )
    ticks = [int(np.min(center_cut)), int(np.max(center_cut))]
    cbar.set_ticks(ticks)
    cbar_labels = [str(tick) for tick in ticks]
    cbar_labels[-1] += 'mV'
    cbar.set_ticklabels(cbar_labels)

    # NOTE(review): the file name labels int(dt_begin / 1000) as seconds,
    # while center_cut_screenshot uses util.dt_to_sec(); confirm which is
    # intended before changing the naming.
    title = 'cut_{}_at{}s_{}'.format(state.name, int(dt_begin / 1000), data_filename)
    fig_fname = '{}.{}'.format(title, img_ext)
    plt.savefig(os.path.join(media_fpath, fig_fname))
    fig.clear()

    center_cut_screenshot(dt, dt_begin, state, data_filename)
# Directory that receives every figure and animation; read as a global by the
# plotting functions in this file, together with ``img_ext`` and ``dt``.
media_fpath = 'media'

filepath, img_ext = parse_cmd_args()
data_filename = os.path.basename(filepath)
params, plain_data, points_data = StateRecorder.load(filepath)
dt = params['dt']

# The save calls below do not create missing directories.
os.makedirs(media_fpath, exist_ok=True)

# Point coordinates are stored next to the per-point traces; split them off so
# that ``points_data`` only holds state variables.
points = None
if 'x_coords' in points_data:
    points_data = dict(points_data)
    points = Points(points_data.pop('x_coords'), points_data.pop('y_coords'))

# plain videos
for name in plain_data.files:
    state = State(name, plain_data[name])
    if state.name == 'V':
        center_cut_t_evolution(dt, state, data_filename)
        # Optional outputs, enable manually:
        # make_animation_video(dt, state, data_filename)
        # display_t_points(state, data_filename)
    if state.name == 'K':
        make_animation_board(dt, state, data_filename, points)

# points plots
# Without recorded coordinates there is nothing to iterate over;
# display_h_points() would fail on enumerate(None).
if points is not None:
    display_h_points(dt, points_data, data_filename, points)
    if points and points.len() > 1:
        compare_h_points(dt, points_data, data_filename, points)
|