state_recorder.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import datetime
  2. import os
  3. import numpy as np
  4. import yaml
  5. from epileptor.points import Points, Point
  6. class StateRecorder:
  7. def __init__(self, nt: int, used_params: dict):
  8. self._nt = nt
  9. self._used_params = used_params
  10. self._points = self.get_record_points()
  11. self._plain_dict = {}
  12. self._points_dict = {}
  13. def get_record_points(self):
  14. y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
  15. points = Points(Point(x_nh // 2, y_nh // 2))
  16. if x_nh == y_nh == 81:
  17. points = Points(Point(33, 40), Point(40, 40), Point(41, 41))
  18. if x_nh == y_nh == 40:
  19. points = Points(Point(20, 20), Point(20, 24), Point(24, 20), Point(20, 35))
  20. if x_nh == y_nh == 80:
  21. points = Points(Point(40, 40), Point(40, 46), Point(40, 56), Point(40, 66), Point(40, 76))
  22. if x_nh == 4 and y_nh == 100:
  23. points = Points(Point(2, 10), Point(2, 90))
  24. return points
  25. def _init_points_state(self, state_name):
  26. self._points_dict[state_name] = np.empty((self._nt, self._points.len()), dtype=np.float32)
  27. def _init_plain_state(self, state_name):
  28. y_nh, x_nh = self._used_params['y_nh'], self._used_params['x_nh']
  29. self._plain_dict[state_name] = np.empty((self._nt, y_nh, x_nh), dtype=np.float32)
  30. def record_points(self, dt_i, K, Na, INaKpump, V, U, xD, uu, nu, phi):
  31. for name, value in list(locals().items()):
  32. if not any((value is self, value is dt_i)):
  33. if name not in self._points_dict:
  34. self._init_points_state(name)
  35. self._points_dict[name][dt_i] = value[self._points.as_np_index()]
  36. def record_plain(self, dt_i, name, val):
  37. if name not in self._plain_dict:
  38. self._init_plain_state(name)
  39. self._plain_dict[name][dt_i] = val.astype(np.float32)
  40. def generate_filename(self):
  41. current = datetime.datetime.now()
  42. return os.path.join('results', current.strftime('%Y-%m-%d_%H.%M'))
  43. def save(self):
  44. filename = self.generate_filename()
  45. self._save_plain(filename)
  46. self._save_points(filename)
  47. self._save_params(filename)
  48. return filename
  49. def _save_plain(self, filename):
  50. np.savez_compressed(
  51. '{}_plain'.format(filename), **self._plain_dict
  52. )
  53. def _save_points(self, filename):
  54. self._points_dict['points'] = self._points
  55. np.savez_compressed(
  56. '{}_points'.format(filename), **self._points_dict
  57. )
  58. def _save_params(self, filename):
  59. with open('{}_params.yml'.format(filename), 'w') as outfile:
  60. yaml.dump(self._used_params, outfile, default_flow_style=False)
  61. @staticmethod
  62. def load(filepath):
  63. try:
  64. plain = np.load(filepath + '_plain.npz')
  65. points = np.load(filepath + '_points.npz', allow_pickle=True)
  66. with open(filepath + '_params.yml', 'r') as params_file:
  67. params = yaml.load(params_file)
  68. return params, plain, points
  69. except Exception as e:
  70. raise NameError('Invalid state recorder file. Please look `def save` for it.', e)