|
@@ -13,6 +13,7 @@ from epileptor import display_0d
|
|
|
from epileptor import util
|
|
|
from epileptor.state import State
|
|
|
from epileptor.state_recorder import StateRecorder
|
|
|
+from epileptor.points import Points
|
|
|
|
|
|
|
|
|
def parse_cmd_args():
|
|
@@ -68,13 +69,13 @@ def make_animation_video(dt, state, data_filename):
|
|
|
fig_2d.clear()
|
|
|
|
|
|
|
|
|
-def make_animation_board(dt, state, data_filename, point_list):
|
|
|
+def make_animation_board(dt, state, data_filename, points):
|
|
|
def add_to_board(shot_i, ax):
|
|
|
ax.set_axis_off()
|
|
|
ax.set_title('{}'.format(util.dt_to_sec(shot_i, dt)))
|
|
|
state_img = ax.imshow(state.values[shot_i, :, :], cmap='jet')
|
|
|
- for point in point_list:
|
|
|
- ax.text(point[0] - 2, point[1] + 1, 'x', {'color': 'w', 'fontsize': 5, 'weight': 'bold'})
|
|
|
+ for point in points:
|
|
|
+ ax.text(point['x'] - 2, point['y'] + 1, 'x', {'color': 'w', 'fontsize': 5, 'weight': 'bold'})
|
|
|
state_img.set_clim(state.min, state.max)
|
|
|
return state_img
|
|
|
|
|
@@ -95,10 +96,10 @@ def make_animation_board(dt, state, data_filename, point_list):
|
|
|
im = add_to_board(t_shots_dt[i], ax)
|
|
|
if i == 0:
|
|
|
ax.set_title(ax.get_title() + 's')
|
|
|
- s1 = point_list[0]
|
|
|
- ax.text(s1[0] - 18, s1[1] - 2, 'S1', {'color': 'w', 'fontsize': 10})
|
|
|
- s2 = point_list[1]
|
|
|
- ax.text(s2[0] - 15, s2[1] + 2, 'S2', {'color': 'w', 'fontsize': 10})
|
|
|
+ s1 = points.get(0)
|
|
|
+ ax.text(s1['x'] - 18, s1['y'] - 2, 'S1', {'color': 'w', 'fontsize': 10})
|
|
|
+ s2 = points.get(1)
|
|
|
+ ax.text(s2['x'] - 15, s2['y'] + 2, 'S2', {'color': 'w', 'fontsize': 10})
|
|
|
fig.subplots_adjust(wspace=0.05)
|
|
|
fig.colorbar(
|
|
|
mappable=im, ax=axes.ravel().tolist(), orientation='horizontal',
|
|
@@ -112,14 +113,11 @@ def make_animation_board(dt, state, data_filename, point_list):
|
|
|
fig.clear()
|
|
|
|
|
|
|
|
|
-def display_h_points(dt, data, data_filename):
|
|
|
- # profile is two lists of x, y coords of recorded points
|
|
|
- points = data.pop('points', None)
|
|
|
- points_x, points_y = points[0], points[1]
|
|
|
+def display_h_points(dt, points_data, data_filename, points):
|
|
|
# for each recorded point
|
|
|
- for i in range(len(points_x)):
|
|
|
- title = 'x' + str(points_x[i]) + 'y' + str(points_y[i]) + '_' + data_filename + '.' + img_ext
|
|
|
- point_state_list = [State(k, v[:, i]) for k, v in data.items()]
|
|
|
+ for i, point in enumerate(points):
|
|
|
+ title = 'x' + str(point['x']) + 'y' + str(point['y']) + '_' + data_filename + '.' + img_ext
|
|
|
+ point_state_list = [State(k, v[:, i]) for k, v in points_data.items()]
|
|
|
display_0d.plot_data(dt, point_state_list, title)
|
|
|
|
|
|
|
|
@@ -146,18 +144,19 @@ def get_point_save_index(point, data):
|
|
|
points_save_index = data['points']
|
|
|
points_x, points_y = points_save_index[0].tolist(), points_save_index[1].tolist()
|
|
|
for i in range(len(points_x)):
|
|
|
- if points_x[i] == point[0] and points_y[i] == point[1]:
|
|
|
+ if points_x[i] == point['x'] and points_y[i] == point['y']:
|
|
|
return i
|
|
|
raise ValueError('No point index for {}'.format(point))
|
|
|
|
|
|
|
|
|
-def compare_h_points(dt, data, data_filename, point_list):
|
|
|
+def compare_h_points(dt, data, data_filename, points):
|
|
|
'''
|
|
|
Keep the same number of discharges/waves for points
|
|
|
'''
|
|
|
- point_title = '_'.join(['{},{}'.format(point[0], point[1]) for point in point_list])
|
|
|
+ points = [points.get(0), points.get(1)]
|
|
|
+ point_title = '_'.join(['{},{}'.format(point['x'], point['y']) for point in points])
|
|
|
title_list = ['S1', 'S2']
|
|
|
- point_index_list = [get_point_save_index(point, data) for point in point_list]
|
|
|
+ point_index_list = [0, 1]
|
|
|
data_filename = 'points{}_{}'.format(point_title, data_filename)
|
|
|
|
|
|
display_state_at_points(
|
|
@@ -267,14 +266,10 @@ data_filename = os.path.basename(filepath)
|
|
|
params, plain_data, points_data = StateRecorder.load(filepath)
|
|
|
dt = params['dt']
|
|
|
|
|
|
-point_list = None
|
|
|
-if 'points' in points_data and len(points_data['points'][0]) > 1:
|
|
|
- points_x = points_data['points'][0]
|
|
|
- points_y = points_data['points'][1]
|
|
|
- point_list = [
|
|
|
- (points_x[0], points_y[0]),
|
|
|
- (points_x[1], points_y[1]),
|
|
|
- ]
|
|
|
+points = None
|
|
|
+if 'x_coords' in points_data:
|
|
|
+ points_data = dict(points_data)
|
|
|
+ points = Points(points_data.pop('x_coords'), points_data.pop('y_coords'))
|
|
|
|
|
|
# plain videos
|
|
|
for name in plain_data.files:
|
|
@@ -284,11 +279,9 @@ for name in plain_data.files:
|
|
|
make_animation_video(dt, state, data_filename)
|
|
|
display_t_points(state, data_filename)
|
|
|
if state.name == 'K':
|
|
|
- # we expect here 80x80 grid
|
|
|
- point_list = [(40, 46), (40, 66)]
|
|
|
- make_animation_board(dt, state, data_filename, point_list)
|
|
|
+ make_animation_board(dt, state, data_filename, points)
|
|
|
|
|
|
# points plots
|
|
|
-display_h_points(dt, dict(points_data), data_filename)
|
|
|
-if point_list is not None:
|
|
|
- compare_h_points(dt, points_data, data_filename, point_list)
|
|
|
+display_h_points(dt, points_data, data_filename, points)
|
|
|
+if points and points.len() > 1:
|
|
|
+ compare_h_points(dt, points_data, data_filename, points)
|