Explorar el Código

remove false optimisation from StateRecorder

vogdb hace 5 años
padre
commit
f53ebae2fa
Se han modificado 2 ficheros con 39 adiciones y 27 borrados
  1. 7 8
      epileptor/model_2d_full.py
  2. 32 19
      epileptor/state_recorder.py

+ 7 - 8
epileptor/model_2d_full.py

@@ -84,10 +84,11 @@ def get_record_points():
 
 
 def solve():
+    t = np.linspace(0, t_end, int(t_end / dt) + 1)
+    nt = len(t)
+
     points = get_record_points()
-    recorder = StateRecorder(
-        extract_params_to_dict(), ['K'], points, ['K', 'Na', 'INaKpump', 'V', 'U', 'xD', 'uu', 'nu', 'phi'],
-    )
+    recorder = StateRecorder(nt, extract_params_to_dict(), points)
     # init state
     K = np.ones((y_nh, x_nh)) * K0
     U = np.ones((y_nh, x_nh)) * U0
@@ -96,9 +97,7 @@ def solve():
     xD = np.ones((y_nh, x_nh))
     V = np.zeros((y_nh, x_nh))
 
-    t = np.linspace(0, t_end, int(t_end / dt) + 1)
-    nt = len(t)
-    for ti in range(nt - 1):
+    for dt_i in range(nt - 1):
         dKdt, dNadt, dVdt, dUdt, dwdt, dxDdt, uu, nu, INaKpump, phi = ode_step(K, Na, V, U, w, xD)
 
         K = K_step(K, dKdt)
@@ -112,8 +111,8 @@ def solve():
         U[Uspike] = Ureset
         w[Uspike] = w[Uspike] + delta_w
 
-        recorder.record_plain('K', K)
-        recorder.record_points(K, Na, INaKpump, V, U, xD, uu, nu, phi)
+        recorder.record_plain(dt_i, 'K', K)
+        recorder.record_points(dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi)
     return recorder
 
 

+ 32 - 19
epileptor/state_recorder.py

@@ -7,25 +7,20 @@ import yaml
 
 class StateRecorder:
 
-    def __init__(self, used_params, plain_states, points, points_states):
+    def __init__(self, nt, used_params, points):
+        self._nt = nt
         self._used_params = used_params
-        self._init_points(points_states, points)
-        self._init_plain(plain_states)
-
-    def _init_points(self, points_states, points):
-        self._points_dict = {}
         self._points = points
+        self._plain_dict = {}
+        self._points_dict = {}
 
-        for name in points_states:
-            self._points_dict[name] = []
-        self._points_dict['points'] = self._points
-
-    def _init_plain(self, plain_states):
-        self._plain_dict = {name: [] for name in plain_states}
-        self._plain_dict['names'] = plain_states
+    def _init_points(self, name):
+        points_num = len(self._points[0])
+        self._points_dict[name] = np.zeros((self._nt, points_num), dtype=np.float32)
 
-    def record_plain(self, name, val):
-        self._plain_dict[name].append(val.astype(np.float32))
+    def _init_plain(self, name):
+        y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
+        self._plain_dict[name] = np.zeros((self._nt, y_nh, x_nh), dtype=np.float32)
 
     # deprecated
     def _init_plain_memmap(self, plain_states):
@@ -43,10 +38,18 @@ class StateRecorder:
     def record_plain_memmap(self, name, val, ti):
         self._plain_dict[name][ti] = val.astype(np.float32)
 
-    def record_points(self, K, Na, INaKpump, V, U, xD, uu, nu, phi):
-        for k, v in list(locals().items()):
-            if v is not self:
-                self._points_dict[k].append(v[self._points])
+    def record_points(self, dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi):
+        for name, value in list(locals().items()):
+            if not any((value is self, value is dt_i)):
+                if name not in self._points_dict:
+                    self._init_points(name)
+                self._points_dict[name][dt_i] = value[self._points]
+
+    def record_plain(self, dt_i, name, val):
+        if name not in self._plain_dict:
+            self._init_plain(name)
+        # self._plain_dict[name][dt_i] = val.astype(np.float32)
+        self._plain_dict[name][dt_i] = val
 
     def generate_filename(self):
         current = datetime.datetime.now()
@@ -54,12 +57,22 @@ class StateRecorder:
 
     def save(self):
         filename = self.generate_filename()
+        self._save_plain(filename)
+        self._save_points(filename)
+        self._save_params(filename)
+
+    def _save_plain(self, filename):
         np.savez_compressed(
             '{}_plain'.format(filename), **self._plain_dict
         )
+
+    def _save_points(self, filename):
+        self._points_dict['points'] = self._points
         np.savez_compressed(
             '{}_points'.format(filename), **self._points_dict
         )
+
+    def _save_params(self, filename):
         with open('{}_params.yml'.format(filename), 'w') as outfile:
             yaml.dump(self._used_params, outfile, default_flow_style=False)