|
@@ -7,25 +7,20 @@ import yaml
|
|
|
|
|
|
class StateRecorder:
|
|
|
|
|
|
- def __init__(self, used_params, plain_states, points, points_states):
|
|
|
+ def __init__(self, nt, used_params, points):
|
|
|
+ self._nt = nt
|
|
|
self._used_params = used_params
|
|
|
- self._init_points(points_states, points)
|
|
|
- self._init_plain(plain_states)
|
|
|
-
|
|
|
- def _init_points(self, points_states, points):
|
|
|
- self._points_dict = {}
|
|
|
self._points = points
|
|
|
+ self._plain_dict = {}
|
|
|
+ self._points_dict = {}
|
|
|
|
|
|
- for name in points_states:
|
|
|
- self._points_dict[name] = []
|
|
|
- self._points_dict['points'] = self._points
|
|
|
-
|
|
|
- def _init_plain(self, plain_states):
|
|
|
- self._plain_dict = {name: [] for name in plain_states}
|
|
|
- self._plain_dict['names'] = plain_states
|
|
|
+ def _init_points(self, name):
|
|
|
+ points_num = len(self._points[0])
|
|
|
+ self._points_dict[name] = np.zeros((self._nt, points_num), dtype=np.float32)
|
|
|
|
|
|
- def record_plain(self, name, val):
|
|
|
- self._plain_dict[name].append(val.astype(np.float32))
|
|
|
+ def _init_plain(self, name):
|
|
|
+ y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
|
|
|
+ self._plain_dict[name] = np.zeros((self._nt, y_nh, x_nh), dtype=np.float32)
|
|
|
|
|
|
# deprecated
|
|
|
def _init_plain_memmap(self, plain_states):
|
|
@@ -43,10 +38,18 @@ class StateRecorder:
|
|
|
def record_plain_memmap(self, name, val, ti):
|
|
|
self._plain_dict[name][ti] = val.astype(np.float32)
|
|
|
|
|
|
- def record_points(self, K, Na, INaKpump, V, U, xD, uu, nu, phi):
|
|
|
- for k, v in list(locals().items()):
|
|
|
- if v is not self:
|
|
|
- self._points_dict[k].append(v[self._points])
|
|
|
+ def record_points(self, dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi):
|
|
|
+ for name, value in list(locals().items()):
|
|
|
+ if not any((value is self, value is dt_i)):
|
|
|
+ if name not in self._points_dict:
|
|
|
+ self._init_points(name)
|
|
|
+ self._points_dict[name][dt_i] = value[self._points]
|
|
|
+
|
|
|
+ def record_plain(self, dt_i, name, val):
|
|
|
+ if name not in self._plain_dict:
|
|
|
+ self._init_plain(name)
|
|
|
+ # self._plain_dict[name][dt_i] = val.astype(np.float32)
|
|
|
+ self._plain_dict[name][dt_i] = val
|
|
|
|
|
|
def generate_filename(self):
|
|
|
current = datetime.datetime.now()
|
|
@@ -54,12 +57,22 @@ class StateRecorder:
|
|
|
|
|
|
def save(self):
|
|
|
filename = self.generate_filename()
|
|
|
+ self._save_plain(filename)
|
|
|
+ self._save_points(filename)
|
|
|
+ self._save_params(filename)
|
|
|
+
|
|
|
+ def _save_plain(self, filename):
|
|
|
np.savez_compressed(
|
|
|
'{}_plain'.format(filename), **self._plain_dict
|
|
|
)
|
|
|
+
|
|
|
+ def _save_points(self, filename):
|
|
|
+ self._points_dict['points'] = self._points
|
|
|
np.savez_compressed(
|
|
|
'{}_points'.format(filename), **self._points_dict
|
|
|
)
|
|
|
+
|
|
|
+ def _save_params(self, filename):
|
|
|
with open('{}_params.yml'.format(filename), 'w') as outfile:
|
|
|
yaml.dump(self._used_params, outfile, default_flow_style=False)
|
|
|
|