1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import datetime
- import os
- import numpy as np
- import yaml
- from epileptor.points import Points, Point
class StateRecorder:
    """Collect per-step snapshots of named state arrays and persist them.

    Two kinds of recordings are kept:

    * "points" recordings (``record_points``): for each state, only the values
      at the recording sites chosen by ``get_record_points``, stored as
      ``(nt, n_points)`` float32 arrays.
    * "plain" recordings (``record_plain``): a whole state grid per step,
      stored as ``(nt, y_nh, x_nh)`` float32 arrays.

    Buffers are allocated lazily on first use and pre-filled with NaN, so rows
    whose step index was never recorded are recognisable after loading.

    Args:
        nt: Number of rows preallocated per state; every ``dt_i`` passed to the
            ``record_*`` methods must lie in ``range(nt)``.
        used_params: Run parameters. Must contain ``'y_nh'`` and ``'x_nh'``,
            which size the plain buffers and select the recording points; the
            whole dict is written to ``<stem>_params.yml`` by ``save``.
    """

    # Hand-picked recording sites for known grid sizes, keyed by (x_nh, y_nh).
    # Each pair is passed to Point in the same argument order as the fallback
    # Point(x_nh // 2, y_nh // 2) below.
    _PRESET_RECORD_POINTS = {
        (81, 81): ((33, 40), (40, 40), (41, 41)),
        (40, 40): ((20, 20), (20, 24), (24, 20), (20, 35)),
        (80, 80): ((40, 40), (40, 46), (40, 56), (40, 66), (40, 76)),
        (4, 100): ((2, 10), (2, 90)),
    }

    def __init__(self, nt: int, used_params: dict):
        self._nt = nt
        self._used_params = used_params
        self._points = self.get_record_points()
        self._plain_dict = {}   # state name -> (nt, y_nh, x_nh) float32 buffer
        self._points_dict = {}  # state name -> (nt, n_points) float32 buffer

    def get_record_points(self):
        """Return the ``Points`` at which ``record_points`` samples each state.

        Grid sizes listed in ``_PRESET_RECORD_POINTS`` use their preset sites;
        any other grid falls back to the single point
        ``Point(x_nh // 2, y_nh // 2)``.
        """
        y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
        preset = self._PRESET_RECORD_POINTS.get((x_nh, y_nh))
        if preset is None:
            return Points(Point(x_nh // 2, y_nh // 2))
        return Points(*(Point(first, second) for first, second in preset))

    def _init_points_state(self, state_name):
        """Allocate the NaN-filled ``(nt, n_points)`` buffer for ``state_name``."""
        self._points_dict[state_name] = np.full(
            (self._nt, self._points.len()), np.nan, dtype=np.float32
        )

    def _init_plain_state(self, state_name):
        """Allocate the NaN-filled ``(nt, y_nh, x_nh)`` buffer for ``state_name``."""
        y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
        self._plain_dict[state_name] = np.full(
            (self._nt, y_nh, x_nh), np.nan, dtype=np.float32
        )

    def record_points(self, dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi):
        """Store every state's values at the recording points into row ``dt_i``.

        Each state argument must be indexable by ``self._points.as_np_index()``
        (presumably a grid-shaped ndarray — confirm against callers); the
        selected values go into that state's ``(nt, n_points)`` buffer,
        stored under the argument's name.
        """
        # Explicit mapping instead of introspecting locals(): the recorded
        # names are exactly the state parameters, in signature order.
        states = {
            'K': K,
            'Na': Na,
            'INaKpump': INaKpump,
            'V': V,
            'U': U,
            'xD': xD,
            'uu': uu,
            'nu': nu,
            'phi': phi,
        }
        index = self._points.as_np_index()  # same index for every state
        for name, value in states.items():
            if name not in self._points_dict:
                self._init_points_state(name)
            self._points_dict[name][dt_i] = value[index]

    def record_plain(self, dt_i, name, val):
        """Store the whole grid ``val`` of state ``name`` into row ``dt_i``.

        ``val`` must be array-like and broadcastable to ``(y_nh, x_nh)``; it is
        converted to float32.
        """
        if name not in self._plain_dict:
            self._init_plain_state(name)
        self._plain_dict[name][dt_i] = np.asarray(val, dtype=np.float32)

    def generate_filename(self):
        """Return a ``results/<YYYY-mm-dd_HH.MM>`` stem for the output files.

        NOTE: the timestamp has minute resolution, so two saves within the same
        minute produce the same stem and the later one overwrites the earlier.
        """
        current = datetime.datetime.now()
        return os.path.join('results', current.strftime('%Y-%m-%d_%H.%M'))

    def save(self):
        """Write all recordings and the parameters under a timestamped stem.

        Creates the stem's directory if it is missing, then writes
        ``<stem>_plain.npz``, ``<stem>_points.npz`` and ``<stem>_params.yml``.

        Returns:
            The stem path, suitable for passing to ``StateRecorder.load``.
        """
        filename = self.generate_filename()
        directory = os.path.dirname(filename)
        if directory:
            # Neither np.savez_compressed nor open() creates parent directories.
            os.makedirs(directory, exist_ok=True)
        self._save_plain(filename)
        self._save_points(filename)
        self._save_params(filename)
        return filename

    def _save_plain(self, filename):
        """Write all plain buffers to ``<filename>_plain.npz``, one key per state."""
        np.savez_compressed(
            '{}_plain'.format(filename), **self._plain_dict
        )

    def _save_points(self, filename):
        """Write point buffers plus the ``Points`` object to ``<filename>_points.npz``.

        The ``Points`` object is stored under the key ``'points'``; as a
        non-array object it is presumably saved as a pickled object array,
        which is why ``load`` opens this file with ``allow_pickle=True``.
        """
        # Build a fresh mapping rather than inserting into _points_dict, so
        # saving does not leave a non-array entry in the recorder's own state.
        np.savez_compressed(
            '{}_points'.format(filename),
            **{**self._points_dict, 'points': self._points},
        )

    def _save_params(self, filename):
        """Dump ``used_params`` as block-style YAML to ``<filename>_params.yml``."""
        with open('{}_params.yml'.format(filename), 'w', encoding='utf-8') as outfile:
            yaml.dump(self._used_params, outfile, default_flow_style=False)

    @staticmethod
    def load(filepath):
        """Load the files written by ``save`` for the stem ``filepath``.

        Returns:
            ``(params, plain, points)``: the parameter dict read from YAML, and
            the ``np.load`` results for the plain and points ``.npz`` files.
            The ``.npz`` results are lazily-read archives; close them when done.

        Raises:
            NameError: If any of the three files cannot be read or parsed; the
                original exception is attached as ``__cause__``.
        """
        try:
            plain = np.load(filepath + '_plain.npz')
            # allow_pickle is required for the pickled 'points' entry. It can
            # execute arbitrary code, so only load files you produced yourself.
            points = np.load(filepath + '_points.npz', allow_pickle=True)
            with open(filepath + '_params.yml', 'r', encoding='utf-8') as params_file:
                # PyYAML >= 6 requires an explicit Loader. FullLoader matches the
                # former default while refusing arbitrary object construction;
                # switch to yaml.safe_load if params never need Python tags.
                params = yaml.load(params_file, Loader=yaml.FullLoader)
            return params, plain, points
        except Exception as e:
            # NameError is kept for backward compatibility with existing callers.
            raise NameError('Invalid state recorder file. Please look `def save` for it.', e) from e
|