Просмотр исходного кода

Update work with points

Create class Points
Move points selection to StateRecorder
vogdb 5 лет назад
Родитель
Сommit
f6bcb1220d
4 измененных файлов с 83 добавлено и 64 удалено
  1. 25 32
      epileptor/display_2d.py
  2. 2 20
      epileptor/model_2d_full.py
  3. 25 0
      epileptor/points.py
  4. 31 12
      epileptor/state_recorder.py

+ 25 - 32
epileptor/display_2d.py

@@ -13,6 +13,7 @@ from epileptor import display_0d
 from epileptor import util
 from epileptor.state import State
 from epileptor.state_recorder import StateRecorder
+from epileptor.points import Points
 
 
 def parse_cmd_args():
@@ -68,13 +69,13 @@ def make_animation_video(dt, state, data_filename):
     fig_2d.clear()
 
 
-def make_animation_board(dt, state, data_filename, point_list):
+def make_animation_board(dt, state, data_filename, points):
     def add_to_board(shot_i, ax):
         ax.set_axis_off()
         ax.set_title('{}'.format(util.dt_to_sec(shot_i, dt)))
         state_img = ax.imshow(state.values[shot_i, :, :], cmap='jet')
-        for point in point_list:
-            ax.text(point[0] - 2, point[1] + 1, 'x', {'color': 'w', 'fontsize': 5, 'weight': 'bold'})
+        for point in points:
+            ax.text(point['x'] - 2, point['y'] + 1, 'x', {'color': 'w', 'fontsize': 5, 'weight': 'bold'})
         state_img.set_clim(state.min, state.max)
         return state_img
 
@@ -95,10 +96,10 @@ def make_animation_board(dt, state, data_filename, point_list):
         im = add_to_board(t_shots_dt[i], ax)
         if i == 0:
             ax.set_title(ax.get_title() + 's')
-            s1 = point_list[0]
-            ax.text(s1[0] - 18, s1[1] - 2, 'S1', {'color': 'w', 'fontsize': 10})
-            s2 = point_list[1]
-            ax.text(s2[0] - 15, s2[1] + 2, 'S2', {'color': 'w', 'fontsize': 10})
+            s1 = points.get(0)
+            ax.text(s1['x'] - 18, s1['y'] - 2, 'S1', {'color': 'w', 'fontsize': 10})
+            s2 = points.get(1)
+            ax.text(s2['x'] - 15, s2['y'] + 2, 'S2', {'color': 'w', 'fontsize': 10})
     fig.subplots_adjust(wspace=0.05)
     fig.colorbar(
         mappable=im, ax=axes.ravel().tolist(), orientation='horizontal',
@@ -112,14 +113,11 @@ def make_animation_board(dt, state, data_filename, point_list):
     fig.clear()
 
 
-def display_h_points(dt, data, data_filename):
-    # profile is two lists of x, y coords of recorded points
-    points = data.pop('points', None)
-    points_x, points_y = points[0], points[1]
+def display_h_points(dt, points_data, data_filename, points):
     # for each recorded point
-    for i in range(len(points_x)):
-        title = 'x' + str(points_x[i]) + 'y' + str(points_y[i]) + '_' + data_filename + '.' + img_ext
-        point_state_list = [State(k, v[:, i]) for k, v in data.items()]
+    for i, point in enumerate(points):
+        title = 'x' + str(point['x']) + 'y' + str(point['y']) + '_' + data_filename + '.' + img_ext
+        point_state_list = [State(k, v[:, i]) for k, v in points_data.items()]
         display_0d.plot_data(dt, point_state_list, title)
 
 
@@ -146,18 +144,19 @@ def get_point_save_index(point, data):
     points_save_index = data['points']
     points_x, points_y = points_save_index[0].tolist(), points_save_index[1].tolist()
     for i in range(len(points_x)):
-        if points_x[i] == point[0] and points_y[i] == point[1]:
+        if points_x[i] == point['x'] and points_y[i] == point['y']:
             return i
     raise ValueError('No point index for {}'.format(point))
 
 
-def compare_h_points(dt, data, data_filename, point_list):
+def compare_h_points(dt, data, data_filename, points):
     '''
     Keep the same number of discharges/waves for points
     '''
-    point_title = '_'.join(['{},{}'.format(point[0], point[1]) for point in point_list])
+    points = [points.get(0), points.get(1)]
+    point_title = '_'.join(['{},{}'.format(point['x'], point['y']) for point in points])
     title_list = ['S1', 'S2']
-    point_index_list = [get_point_save_index(point, data) for point in point_list]
+    point_index_list = [0, 1]
     data_filename = 'points{}_{}'.format(point_title, data_filename)
 
     display_state_at_points(
@@ -267,14 +266,10 @@ data_filename = os.path.basename(filepath)
 params, plain_data, points_data = StateRecorder.load(filepath)
 dt = params['dt']
 
-point_list = None
-if 'points' in points_data and len(points_data['points'][0]) > 1:
-    points_x = points_data['points'][0]
-    points_y = points_data['points'][1]
-    point_list = [
-        (points_x[0], points_y[0]),
-        (points_x[1], points_y[1]),
-    ]
+points = None
+if 'x_coords' in points_data:
+    points_data = dict(points_data)
+    points = Points(points_data.pop('x_coords'), points_data.pop('y_coords'))
 
 # plain videos
 for name in plain_data.files:
@@ -284,11 +279,9 @@ for name in plain_data.files:
     make_animation_video(dt, state, data_filename)
     display_t_points(state, data_filename)
     if state.name == 'K':
-        # we expect here 80x80 grid
-        point_list = [(40, 46), (40, 66)]
-        make_animation_board(dt, state, data_filename, point_list)
+        make_animation_board(dt, state, data_filename, points)
 
 # points plots
-display_h_points(dt, dict(points_data), data_filename)
-if point_list is not None:
-    compare_h_points(dt, points_data, data_filename, point_list)
+display_h_points(dt, points_data, data_filename, points)
+if points and points.len() > 1:
+    compare_h_points(dt, points_data, data_filename, points)

+ 2 - 20
epileptor/model_2d_full.py

@@ -65,30 +65,11 @@ def boundary_conditions(f):
     f[-1, -1] = f[-2, -2]
 
 
-def get_record_points():
-    x_coords = [x_nh // 2]
-    y_coords = [y_nh // 2]
-    if x_nh == y_nh == 81:
-        x_coords = [33, 40, 41, 48, 40, 41, 37, 37, 47, 46, 47]
-        y_coords = [40, 40, 41, 41, 33, 48, 34, 47, 36, 45, 45]
-    if x_nh == y_nh == 40:
-        x_coords = [20, 20, 24, 20, 35]
-        y_coords = [20, 24, 20, 35, 20]
-    if x_nh == y_nh == 80:
-        x_coords = [40, 40, 65, 40, 44]
-        y_coords = [40, 65, 40, 44, 40]
-    if x_nh == 4 and y_nh == 100:
-        x_coords = [2, 2]
-        y_coords = [10, 90]
-    return np.s_[y_coords, x_coords]
-
-
 def solve():
     t = np.linspace(0, t_end, int(t_end / dt) + 1)
     nt = len(t)
 
-    points = get_record_points()
-    recorder = StateRecorder(nt, extract_params_to_dict(), points)
+    recorder = StateRecorder(nt, extract_params_to_dict())
     # init state
     K = np.ones((y_nh, x_nh)) * K0
     U = np.ones((y_nh, x_nh)) * U0
@@ -112,6 +93,7 @@ def solve():
         w[Uspike] = w[Uspike] + delta_w
 
         recorder.record_plain(dt_i, 'K', K)
+        recorder.record_plain(dt_i, 'V', V)
         recorder.record_points(dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi)
     return recorder
 

+ 25 - 0
epileptor/points.py

@@ -0,0 +1,25 @@
+import numpy as np
+
+
+class Points:
+
+    def __init__(self, x_coords, y_coords):
+        self._np_index = np.s_[y_coords, x_coords]
+        assert len(x_coords) == len(y_coords)
+        self.x_coords = x_coords
+        self.y_coords = y_coords
+
+    def as_np_index(self):
+        return self._np_index
+
+    def get(self, idx):
+        return {'x': self.x_coords[idx], 'y': self.y_coords[idx]}
+
+    def len(self):
+        return len(self.x_coords)
+
+    def __getitem__(self, index):
+        return self.get(index)
+
+    def __len__(self):
+        return self.len()

+ 31 - 12
epileptor/state_recorder.py

@@ -3,37 +3,55 @@ import os
 
 import numpy as np
 import yaml
+from epileptor.points import Points
 
 
 class StateRecorder:
 
-    def __init__(self, nt, used_params, points):
+    def __init__(self, nt: int, used_params: dict):
         self._nt = nt
         self._used_params = used_params
-        self._points = points
+        self._points = self.get_record_points()
         self._plain_dict = {}
         self._points_dict = {}
 
-    def _init_points(self, name):
-        points_num = len(self._points[0])
-        self._points_dict[name] = np.zeros((self._nt, points_num), dtype=np.float32)
+    def get_record_points(self):
+        y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
+
+        x_coords = [x_nh // 2]
+        y_coords = [y_nh // 2]
+        if x_nh == y_nh == 81:
+            x_coords = [33, 40, 41, 48, 40, 41, 37, 37, 47, 46, 47]
+            y_coords = [40, 40, 41, 41, 33, 48, 34, 47, 36, 45, 45]
+        if x_nh == y_nh == 40:
+            x_coords = [20, 20, 24, 20, 35]
+            y_coords = [20, 24, 20, 35, 20]
+        if x_nh == y_nh == 80:
+            x_coords = [40, 40, 65, 40, 44]
+            y_coords = [46, 66, 40, 44, 40]
+        if x_nh == 4 and y_nh == 100:
+            x_coords = [2, 2]
+            y_coords = [10, 90]
+        return Points(x_coords, y_coords)
+
+    def _init_points_state(self, state_name):
+        self._points_dict[state_name] = np.empty((self._nt, self._points.len()), dtype=np.float32)
 
-    def _init_plain(self, name):
+    def _init_plain_state(self, state_name):
         y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
-        self._plain_dict[name] = np.zeros((self._nt, y_nh, x_nh), dtype=np.float32)
+        self._plain_dict[state_name] = np.empty((self._nt, y_nh, x_nh), dtype=np.float32)
 
     def record_points(self, dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi):
         for name, value in list(locals().items()):
             if not any((value is self, value is dt_i)):
                 if name not in self._points_dict:
-                    self._init_points(name)
-                self._points_dict[name][dt_i] = value[self._points]
+                    self._init_points_state(name)
+                self._points_dict[name][dt_i] = value[self._points.as_np_index()]
 
     def record_plain(self, dt_i, name, val):
         if name not in self._plain_dict:
-            self._init_plain(name)
+            self._init_plain_state(name)
         self._plain_dict[name][dt_i] = val.astype(np.float32)
-        # self._plain_dict[name][dt_i] = val
 
     def generate_filename(self):
         current = datetime.datetime.now()
@@ -51,7 +69,8 @@ class StateRecorder:
         )
 
     def _save_points(self, filename):
-        self._points_dict['points'] = self._points
+        self._points_dict['x_coords'] = self._points.x_coords
+        self._points_dict['y_coords'] = self._points.y_coords
         np.savez_compressed(
             '{}_points'.format(filename), **self._points_dict
         )