# data.py

import datetime
import glob
import logging
import os
import pathlib
import pickle
import random
import time
from datetime import datetime as dt

import matplotlib.pyplot as plt
import munch
import numpy as np
import yaml
from scipy import io

from .kaux import log
from .ringbuffer import RingBuffer


class DataNormalizer:
    """Maps raw firing rates to normalized rates in [0, 1].

    Normalization ranges (bottom/top per channel, or one range for all channels)
    come from the configuration and can be updated from percentiles of a ring
    buffer of recent data.
    """

    def __init__(self, params, initial_data=None):
        self.params = params
        self.norm_rate = {}
        self.norm_rate['ch_ids'] = [ch.id for ch in self.params.daq.normalization.channels]
        self.norm_rate['bottoms'] = np.asarray([ch.bottom for ch in self.params.daq.normalization.channels])
        self.norm_rate['tops'] = np.asarray([ch.top for ch in self.params.daq.normalization.channels])
        self.norm_rate['invs'] = [ch.invert for ch in self.params.daq.normalization.channels]

        # Buffer size: number of samples covering the normalization window,
        # assuming loop_interval is given in milliseconds.
        n_norm_buffer = int(self.params.daq.normalization.len * (1000.0 / self.params.daq.spike_rates.loop_interval))
        self.norm_buffer = RingBuffer(n_norm_buffer, dtype=(float, self.params.daq.n_channels), allow_overwrite=True)
        self.last_update = time.time()
        if initial_data is not None:
            self.norm_buffer.extend(initial_data)
    def _update_norm_range(self):
        # Per-channel ranges: percentiles of the buffered rates for the configured channels.
        buf_vals = self.norm_buffer[:, self.norm_rate['ch_ids']]
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        self.norm_rate['bottoms'] = centiles[0, :]
        self.norm_rate['tops'] = centiles[1, :]
        log.info(f"Updated normalization ranges for channels {self.norm_rate['ch_ids']} to bottoms: {self.norm_rate['bottoms']}, tops: {self.norm_rate['tops']}")

    def _update_norm_range_all(self):
        # All-channel range: percentiles of the across-channel mean of the buffered rates.
        buf_vals = np.mean(self.norm_buffer, axis=1)
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        # log.info(f"Centiles: {centiles}")
        self.params.daq.normalization.all_channels.bottom = centiles[0]
        self.params.daq.normalization.all_channels.top = centiles[1]
        log.info(f"Updated normalization range for all channels to [{self.params.daq.normalization.all_channels.bottom}, {self.params.daq.normalization.all_channels.top}]")
    def update_norm_range(self, data=None, force=False):
        if data is not None and data.size > 0:
            self.norm_buffer.extend(data)
        # Re-derive the normalization range if updates are enabled and the update
        # interval has elapsed, or if an update is forced.
        if (self.params.daq.normalization.do_update
                and (time.time() - self.last_update >= self.params.daq.normalization.update_interval)) or force:
            if self.params.daq.normalization.use_all_channels:
                self._update_norm_range_all()
            else:
                self._update_norm_range()
            self.last_update = time.time()
            log.info(f"New channel normalization setting: {yaml.dump(self._format_current_config(), sort_keys=False, default_flow_style=None)}")
    def _format_current_config(self):
        if self.params.daq.normalization.use_all_channels:
            out_dict = {'all_channels': {'bottom': float(self.params.daq.normalization.all_channels.bottom),
                                         'top': float(self.params.daq.normalization.all_channels.top),
                                         'invert': bool(self.params.daq.normalization.all_channels.invert)}}
        else:
            out_dict = {'channels': []}
            for ii in range(len(self.norm_rate['ch_ids'])):
                out_dict['channels'].append({'id': int(self.norm_rate['ch_ids'][ii]),
                                             'bottom': float(self.norm_rate['bottoms'][ii]),
                                             'top': float(self.norm_rate['tops'][ii]),
                                             'invert': self.norm_rate['invs'][ii]})
        return out_dict
    def _calculate_all_norm_rate(self, buf_item):
        # Average across channels, then scale by the global (all-channel) range.
        avg_r = np.mean(buf_item, axis=1)
        if self.params.daq.normalization.clamp_firing_rates:
            avg_r = np.maximum(np.minimum(avg_r, self.params.daq.normalization.all_channels.top),
                               self.params.daq.normalization.all_channels.bottom)
        norm_rate = ((avg_r - self.params.daq.normalization.all_channels.bottom)
                     / (self.params.daq.normalization.all_channels.top - self.params.daq.normalization.all_channels.bottom))
        if self.params.daq.normalization.all_channels.invert:
            norm_rate = 1 - norm_rate
        return norm_rate
    def _calculate_individual_norm_rate(self, buf_items):
        """Calculate normalized firing rate, determined by feedback settings."""
        if self.params.daq.normalization.clamp_firing_rates:
            clamped_rates = np.maximum(np.minimum(buf_items[:, self.norm_rate['ch_ids']], self.norm_rate['tops']),
                                       self.norm_rate['bottoms'])
        else:
            clamped_rates = buf_items[:, self.norm_rate['ch_ids']]
        denom = self.norm_rate['tops'] - self.norm_rate['bottoms']
        if np.all(denom == 0):
            # Guard against division by zero while no range has been set.
            denom[:] = 1
        norm_rates = (clamped_rates - self.norm_rate['bottoms']) / denom
        # Channels flagged with `invert` map lower firing rates to higher output.
        norm_rates[:, self.norm_rate['invs']] = 1 - norm_rates[:, self.norm_rate['invs']]
        norm_rate = np.nanmean(norm_rates, axis=1)
        if not self.params.daq.normalization.clamp_firing_rates:
            norm_rate = np.maximum(norm_rate, 0.0)
        return norm_rate
    def calculate_norm_rate(self, buf_item):
        # Ensure 2-D input of shape (n_samples, n_channels); note the reshape is in place.
        if buf_item.ndim == 1:
            buf_item.shape = (1, buf_item.shape[0])
        if self.params.daq.normalization.use_all_channels:
            return self._calculate_all_norm_rate(buf_item)
        else:
            return self._calculate_individual_norm_rate(buf_item)
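

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): builds a minimal
# parameter tree with the fields DataNormalizer reads and runs one normalization
# pass. The layout is inferred from the attribute accesses above; all numeric
# values, units, and channel ids are placeholders.
# ---------------------------------------------------------------------------
def _example_usage():
    params = munch.munchify({
        'daq': {
            'n_channels': 4,
            'spike_rates': {'loop_interval': 50},  # placeholder, assumed to be in ms
            'normalization': {
                'len': 30,                         # placeholder window length, assumed seconds
                'do_update': True,
                'update_interval': 10,             # placeholder interval between automatic updates
                'range': [5, 95],                  # percentiles used as bottom/top
                'use_all_channels': False,
                'clamp_firing_rates': True,
                'all_channels': {'bottom': 0.0, 'top': 50.0, 'invert': False},
                'channels': [
                    {'id': 0, 'bottom': 0.0, 'top': 40.0, 'invert': False},
                    {'id': 2, 'bottom': 5.0, 'top': 60.0, 'invert': True},
                ],
            },
        },
    })
    normalizer = DataNormalizer(params)

    # Feed simulated firing rates (n_samples x n_channels) and force a range update.
    rates = np.random.rand(200, params.daq.n_channels) * 50.0
    normalizer.update_norm_range(data=rates, force=True)

    # Normalize the most recent sample to [0, 1].
    return normalizer.calculate_norm_rate(rates[-1])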