123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import datetime
- import glob
- import logging
- import os
- import pathlib
- import pickle
- import random
- import time
- from datetime import datetime as dt
- import matplotlib.pyplot as plt
- import munch
- import numpy as np
- from scipy import io
- import yaml
- from .kaux import log
- from .ringbuffer import RingBuffer
class DataNormalizer:
    """Map raw per-channel rates onto a normalized feedback value.

    Configuration is read from ``params.daq.normalization``. Two modes exist,
    selected by ``normalization.use_all_channels``:

    * all-channels mode: the per-sample mean across channels is scaled between
      ``all_channels.bottom`` and ``all_channels.top`` (optionally inverted);
    * per-channel mode: each channel in ``normalization.channels`` is scaled
      between its own bottom/top (optionally inverted) and the scaled values
      are averaged with ``np.nanmean``.

    The bounds can be re-estimated from percentiles of a rolling buffer of
    recent samples via :meth:`update_norm_range`.
    """

    def __init__(self, params, initial_data=None):
        """Read the normalization config and allocate the rolling buffer.

        Args:
            params: Attribute-style config object exposing ``daq.normalization``,
                ``daq.spike_rates.loop_interval`` and ``daq.n_channels``.
            initial_data: Optional samples used to pre-fill the buffer
                (presumably shaped ``(n_samples, n_channels)`` -- confirm).
        """
        self.params = params
        norm_cfg = self.params.daq.normalization
        channels = norm_cfg.channels
        self.norm_rate = {
            'ch_ids': [ch.id for ch in channels],
            'bottoms': np.asarray([ch.bottom for ch in channels]),
            'tops': np.asarray([ch.top for ch in channels]),
            'invs': [ch.invert for ch in channels],
        }

        # Capacity = normalization.len * (1000 / loop_interval) samples;
        # presumably `len` is in seconds and `loop_interval` in ms -- confirm.
        n_norm_buffer = int(norm_cfg.len * (1000.0 / self.params.daq.spike_rates.loop_interval))
        self.norm_buffer = RingBuffer(n_norm_buffer, dtype=(float, self.params.daq.n_channels), allow_overwrite=True)
        self.last_update = time.time()
        if initial_data is not None:
            self.norm_buffer.extend(initial_data)

    def _update_norm_range(self):
        """Re-estimate per-channel bottoms/tops from the buffered samples.

        ``normalization.range`` supplies the percentiles: the first becomes
        each configured channel's new bottom, the second its new top.
        """
        buf_vals = self.norm_buffer[:, self.norm_rate['ch_ids']]
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        self.norm_rate['bottoms'] = centiles[0, :]
        self.norm_rate['tops'] = centiles[1, :]
        log.info(f"Updated normalization ranges for channels {self.norm_rate['ch_ids']} to bottoms: {self.norm_rate['bottoms']}, tops: {self.norm_rate['tops']}")

    def _update_norm_range_all(self):
        """Re-estimate the shared bottom/top used in all-channels mode.

        Percentiles (``normalization.range``) are taken over the per-sample
        mean across channels and written back into
        ``params.daq.normalization.all_channels``.
        """
        buf_vals = np.mean(self.norm_buffer, axis=1)
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        self.params.daq.normalization.all_channels.bottom = centiles[0]
        self.params.daq.normalization.all_channels.top = centiles[1]
        log.info(f"Updated normalization range for all channels to [{self.params.daq.normalization.all_channels.bottom}, {self.params.daq.normalization.all_channels.top}]")

    def update_norm_range(self, data=None, force=False):
        """Buffer new samples and, when due, re-estimate the normalization bounds.

        Args:
            data: Optional array of new samples appended to the rolling buffer;
                ignored when ``None`` or empty.
            force: Recompute the bounds now, regardless of ``do_update`` and
                ``update_interval``.
        """
        norm_cfg = self.params.daq.normalization
        if data is not None and data.size > 0:
            self.norm_buffer.extend(data)
        # update_interval is compared with a time.time() delta, i.e. seconds.
        due = norm_cfg.do_update and (time.time() - self.last_update >= norm_cfg.update_interval)
        if not (due or force):
            return
        if norm_cfg.use_all_channels:
            self._update_norm_range_all()
        else:
            self._update_norm_range()
        self.last_update = time.time()
        log.info(f"New channel normalization setting: {yaml.dump(self._format_current_config(), sort_keys=False, default_flow_style=None)}")

    def _format_current_config(self):
        """Return the active bounds as a dict of plain Python scalars.

        The values are cast to ``int``/``float``/``bool`` so the result can be
        passed straight to ``yaml.dump``.
        """
        norm_cfg = self.params.daq.normalization
        if norm_cfg.use_all_channels:
            all_cfg = norm_cfg.all_channels
            return {'all_channels': {'bottom': float(all_cfg.bottom),
                                     'top': float(all_cfg.top),
                                     'invert': bool(all_cfg.invert)}}
        channels = [
            {'id': int(ch_id), 'bottom': float(bottom), 'top': float(top), 'invert': bool(inv)}
            for ch_id, bottom, top, inv in zip(self.norm_rate['ch_ids'],
                                               self.norm_rate['bottoms'],
                                               self.norm_rate['tops'],
                                               self.norm_rate['invs'])
        ]
        return {'channels': channels}

    def _calculate_all_norm_rate(self, buf_item):
        """Normalize the per-sample mean across channels (all-channels mode).

        Args:
            buf_item: 2-D array; the mean is taken along axis 1.

        Returns:
            1-D array with one normalized value per row of ``buf_item``.
        """
        norm_cfg = self.params.daq.normalization
        bottom = norm_cfg.all_channels.bottom
        top = norm_cfg.all_channels.top
        avg_r = np.mean(buf_item, axis=1)
        if norm_cfg.clamp_firing_rates:
            avg_r = np.maximum(np.minimum(avg_r, top), bottom)
        span = top - bottom
        if span == 0:
            # Degenerate range: avoid nan/inf, matching the per-channel path.
            span = 1
        norm_rate = (avg_r - bottom) / span
        if norm_cfg.all_channels.invert:
            norm_rate = 1 - norm_rate
        return norm_rate

    def _calculate_individual_norm_rate(self, buf_items):
        """Calculate normalized firing rate, determined by feedback settings.

        Each configured channel is scaled to its own [bottom, top] range
        (optionally clamped, optionally inverted) and the channels are
        averaged with ``np.nanmean``.

        Args:
            buf_items: 2-D array indexed as ``[sample, channel]``.

        Returns:
            1-D array with one normalized value per row of ``buf_items``.
        """
        clamp = self.params.daq.normalization.clamp_firing_rates
        bottoms = self.norm_rate['bottoms']
        tops = self.norm_rate['tops']
        invs = self.norm_rate['invs']

        rates = buf_items[:, self.norm_rate['ch_ids']]
        if clamp:
            rates = np.maximum(np.minimum(rates, tops), bottoms)

        span = tops - bottoms
        zero_span = span == 0
        if np.all(zero_span):
            # Every range is degenerate: fall back to dividing by 1.
            span = np.ones_like(span)
        else:
            # Only some ranges are degenerate: turn those channels into NaN so
            # nanmean skips them instead of propagating +/-inf.
            span = np.where(zero_span, np.nan, span)
        norm_rates = (rates - bottoms) / span
        norm_rates[:, invs] = 1 - norm_rates[:, invs]
        norm_rate = np.nanmean(norm_rates, axis=1)
        if not clamp:
            norm_rate = np.maximum(norm_rate, 0.0)
        return norm_rate

    def calculate_norm_rate(self, buf_item):
        """Return the normalized feedback value(s) for ``buf_item``.

        A 1-D input is treated as a single sample; the caller's array is left
        untouched (a reshaped view is used internally).

        Returns:
            1-D array with one normalized value per sample.
        """
        buf_item = np.asarray(buf_item)
        if buf_item.ndim == 1:
            buf_item = buf_item.reshape(1, -1)
        if self.params.daq.normalization.use_all_channels:
            return self._calculate_all_norm_rate(buf_item)
        return self._calculate_individual_norm_rate(buf_item)
-
|