#include "sre/SpikeRateEstimator.h"

#include <exception>
#include <utility>

namespace cc {
namespace sre {

// Default constructor.
// Uses a history length of 2'000, a 10ms processing loop and rate normalization
// enabled. Registers one (channel, unit) entry for every channel in
// 1..cc::constants::MaxFEChannels and every unit in 0..cc::constants::MaxUnits-1.
SpikeRateEstimator::SpikeRateEstimator()
    : m_history_length(2'000)
    , m_t_spike_loop(10ms)
    , m_normalize_rates(true)
    , m_keep_going(true)
    , m_thread_loop_over_time_flag(false)
    , m_spkr_rb_needs_reset(false)
{
    init_logger();
    // mp_logger->debug("SpikeRateEstimator()");
    // fillChannelUnitList() ends in applyChannelUnitList(), which already sizes
    // the spike queue for the current loop interval and list length; a second
    // assignSpikeQueue() call here would only allocate the same queue again.
    fillChannelUnitList(cc::constants::MaxFEChannels, cc::constants::MaxUnits);
    updateRateNormalizationFactor();
}

// Constructs an estimator for channels 1..n_channels and units 0..n_units-1.
// history_length is forwarded to the unit-history container; t_spike_loop is the
// period of the processing thread and the basis of the rate normalization factor.
SpikeRateEstimator::SpikeRateEstimator(size_type n_channels, size_type n_units,
                                       size_type history_length,
                                       rate_interval_type t_spike_loop)
    : m_history_length(history_length)
    , m_t_spike_loop(t_spike_loop)
    , m_normalize_rates(true)
    , m_keep_going(true)
    , m_thread_loop_over_time_flag(false)
    , m_spkr_rb_needs_reset(false)
{
    init_logger();
    // mp_logger->debug("SpikeRateEstimator(ch u hl sl)");
    // The spike queue is allocated inside fillChannelUnitList() (via
    // applyChannelUnitList()), so it is not reassigned here.
    fillChannelUnitList(n_channels, n_units);
    updateRateNormalizationFactor();
}

// Constructs an estimator for an explicit list of (channel, unit) pairs; the
// position of each pair in ch_un_list becomes its index in the rate vectors.
SpikeRateEstimator::SpikeRateEstimator(ch_un_list_type ch_un_list,
                                       size_type history_length,
                                       rate_interval_type t_spike_loop)
    : m_ch_un_list(std::move(ch_un_list))  // by-value sink parameter: move, don't copy again
    , m_history_length(history_length)
    , m_t_spike_loop(t_spike_loop)
    , m_normalize_rates(true)
    , m_keep_going(true)
    , m_thread_loop_over_time_flag(false)
    , m_spkr_rb_needs_reset(false)
{
    init_logger();
    // mp_logger->debug("SpikeRateEstimator(chunl hl sl)");
    // applyChannelUnitList() builds the index map, the unit histories and the
    // spike queue in one go.
    applyChannelUnitList();
    updateRateNormalizationFactor();
}

// Stops and joins the processing thread (if running) before members are destroyed.
SpikeRateEstimator::~SpikeRateEstimator()
{
    // mp_logger->debug("~SpikeRateEstimator()");
    stopProcessThread();
}

// Convenience overload: enqueues a spike with timestamp 0 for (channel, unit).
// Returns whatever the spike-taking addSpike() overload returns.
bool SpikeRateEstimator::addSpike(const Channel_t channel, const Unit_t unit)
{
    return addSpike(spike{0, channel, unit});
}

// One processing iteration: drain the spike queue into the unit histories, then
// compute rates and push a sample to the ring buffer.
void SpikeRateEstimator::processSpikes()
{
    consumeQueue();
    applyCallback();
}

// Pops every queued spike and increments the history of its (channel, unit)
// entry, then calls step() once on every unit history.
// Returns the number of spikes popped, including spikes that were dropped
// because ch_u_offset() rejected their (channel, unit) pair.
SpikeRateEstimator::size_type SpikeRateEstimator::consumeQueue()
{
    size_type spks_consumed = 0;
    size_type unit_idx = 0;
    while (!mp_sq->isEmpty()) {
        auto const& p_spk = mp_sq->frontPtr();
        // On success ch_u_offset() presumably stores the history index of
        // (channel, unit) in unit_idx (contract lives in the header - confirm).
        if (ch_u_offset(p_spk->channel, p_spk->unit, unit_idx)) {
            auto & uh = (*mp_unit_hist_container)[unit_idx];
            ++uh;
        }
        mp_sq->popFront();
        ++spks_consumed;
    }
    // Every history advances once per drain, whether or not it received spikes.
    for (auto & uh : *mp_unit_hist_container) {
        uh.step();
    }
    return spks_consumed;
}

// Stores the function used to turn one unit history into a rate value.
void SpikeRateEstimator::bindCallback(const rate_callback_type& rcb)
{
    m_rate_callback = rcb;
}

// Builds one (timestamp, rates) sample and pushes it into the attached ring buffer.
// - If m_spkr_rb_needs_reset is set (channel-unit list change, loop-interval
//   change or signalStart()), the ring buffer is reset first.
// - The timestamp comes from the bound timestamp callback, or is 0 if none is bound.
// - For each unit history with spikes the rate is m_rate_callback(history),
//   multiplied by m_rate_norm_factor when normalization is enabled; histories
//   without spikes get 0 and the callback is not invoked.
// - Without a rate callback a warning is logged and an all-zero sample is pushed.
// Logs a warning and does nothing if no ring buffer is attached.
void SpikeRateEstimator::applyCallback()
{
    if (!mp_spkr_rb) {
        mp_logger->warn("no rate ring buffer");
        return;
    }

    if (m_spkr_rb_needs_reset) {
        mp_logger->info("Reset spike rate ring buffer because of channel-unit map change");
        mp_spkr_rb->reset();
        m_spkr_rb_needs_reset = false;
    }

    Timestamp_t lastNSPTime = 0;
    if (m_timestamp_callback) {
        lastNSPTime = m_timestamp_callback();
        // mp_logger->debug("applyCallback: got latest NSP timestamp: {}", lastNSPTime);
    }

    rate_ring_buffer_item_type time_and_rates{lastNSPTime, rate_item_type(n_ch_u(), 0)};

    if (m_rate_callback) {
        // Walk histories and output slots in lockstep, stopping at the shorter range.
        auto hist_it = mp_unit_hist_container->cbegin();
        auto rate_it = time_and_rates.second.begin();
        for (; hist_it != mp_unit_hist_container->cend() && rate_it != time_and_rates.second.end();
             ++hist_it, ++rate_it) {
            if ((*hist_it).has_spikes()) {
                (*rate_it) = m_rate_callback(*hist_it);
                if (m_normalize_rates) {
                    (*rate_it) *= m_rate_norm_factor;
                }
                // mp_logger->debug("{} -> rate {}", *hist_it, *rate_it);
            } else {
                (*rate_it) = 0;
                // mp_logger->debug("{} -> rate {} [no new spikes]", *hist_it, *rate_it);
            }
        }
    } else {
        mp_logger->warn("no rate callback");
    }

    mp_spkr_rb->push(time_and_rates);
}

// Attaches the ring buffer that receives one rate sample per processing iteration.
void SpikeRateEstimator::attachSpikeRateRB(std::shared_ptr sr_rb)
{
    // mp_logger->debug("attachSpikeRateRB");
    mp_spkr_rb = std::move(sr_rb);  // sink parameter: avoid an extra refcount round-trip
}

// (Re)starts the processing thread: joins any previous thread, requests a ring
// buffer reset, drains spikes queued while stopped, then launches processThreadFun().
void SpikeRateEstimator::signalStart()
{
    stopProcessThread();
    // mp_logger->debug("Launching spike processing thread");
    m_keep_going = true;
    m_spkr_rb_needs_reset = true;
    consumeQueue();
    m_process_thread = std::thread(&SpikeRateEstimator::processThreadFun, this);
}

// Asks the processing loop to exit after its current iteration; does not join.
void SpikeRateEstimator::signalStop()
{
    m_keep_going = false;
}

// Asks the processing loop to exit and blocks until the thread has finished.
void SpikeRateEstimator::stopProcessThread()
{
    m_keep_going = false;
    if (m_process_thread.joinable()) {
        // mp_logger->debug("Stopping spike processing thread");
        m_process_thread.join();
    }
}

// Enables or disables scaling of callback rates by m_rate_norm_factor.
void SpikeRateEstimator::setSpikeRateNormalization(bool enabled)
{
    m_normalize_rates = enabled;
}

// Returns whether callback rates are scaled by m_rate_norm_factor.
bool SpikeRateEstimator::getSpikeRateNormalization() const
{
    return m_normalize_rates;
}

// Replaces the logger with a clone of `logger` named "SpikeRateEstimator".
// A null logger is ignored (with a warning) instead of being dereferenced.
void SpikeRateEstimator::set_logger(std::shared_ptr logger)
{
    if (!logger) {
        mp_logger->warn("set_logger called with a null logger; keeping the current logger");
        return;
    }
    mp_logger = logger->clone("SpikeRateEstimator");
}

// Replaces the (channel, unit) list and rebuilds the index map, unit histories
// and spike queue (see applyChannelUnitList()).
void SpikeRateEstimator::setChannelUnitList(const ch_un_list_type & ch_un_list)
{
    // Plain assignment is self-assignment safe, unlike clear()-then-copy.
    m_ch_un_list = ch_un_list;
    applyChannelUnitList();
}

// Rebuilds all state derived from m_ch_un_list:
// - m_ch_un_map maps each (channel, unit) pair to its position in the list;
// - a fresh unit-history container with one history of m_history_length per pair;
// - a fresh spike queue sized for the new list;
// and flags the ring buffer for reset.
// NOTE(review): mp_unit_hist_container and mp_sq are replaced without any
// synchronization; calling this while the processing thread runs looks unsafe.
void SpikeRateEstimator::applyChannelUnitList()
{
    m_ch_un_map.clear();
    const size_type n_ch_un = m_ch_un_list.size();
    for (size_type i = 0; i < n_ch_un; ++i) {
        m_ch_un_map[m_ch_un_list[i]] = i;
    }
    mp_unit_hist_container = std::make_unique(n_ch_un, m_history_length);
    assignSpikeQueue();
    m_spkr_rb_needs_reset = true;
    mp_logger->debug("Set new (channel, unit) list, with {} items", m_ch_un_list.size());
}

// Builds the full grid of channels 1..n_channels (1-based) by units
// 0..n_units-1 (0-based), channel-major, and installs it via setChannelUnitList().
void SpikeRateEstimator::fillChannelUnitList(const size_type n_channels, const size_type n_units)
{
    ch_un_list_type ch_un_list;
    ch_un_list.reserve(n_channels * n_units);
    for (size_type ch = 1; ch <= n_channels; ++ch) {
        for (size_type u = 0; u < n_units; ++u) {
            ch_un_list.push_back(std::make_pair(ch, u));
        }
    }
    setChannelUnitList(ch_un_list);
}

// Returns copies of the (channel, unit) list and of the pair -> index map.
std::pair SpikeRateEstimator::getChannelUnitMapping() const
{
    return std::make_pair(m_ch_un_list, m_ch_un_map);
}

// Returns a copy of the (channel, unit) list.
SpikeRateEstimator::ch_un_list_type SpikeRateEstimator::getChannelUnitList() const
{
    return m_ch_un_list;
}

// Changes the processing-loop period. When the value actually changes, the
// normalization factor and spike queue are recomputed and the ring buffer is
// flagged for reset. Non-positive intervals are rejected with a warning: they
// would divide by zero (or flip the sign of every rate) in
// updateRateNormalizationFactor() and make processThreadFun() spin without sleeping.
// NOTE(review): m_t_spike_loop and mp_sq are also read by the processing
// thread without synchronization.
void SpikeRateEstimator::setSpikeRateLoopInterval(const SpikeRateEstimator::rate_interval_type & new_interval)
{
    if (new_interval <= rate_interval_type::zero()) {
        mp_logger->warn("Ignoring non-positive spike rate loop interval ({} ticks)", new_interval.count());
        return;
    }
    if (m_t_spike_loop != new_interval) {
        m_t_spike_loop = new_interval;
        updateRateNormalizationFactor();
        assignSpikeQueue();
        m_spkr_rb_needs_reset = true;
    }
}

// Returns the processing-loop period.
SpikeRateEstimator::rate_interval_type SpikeRateEstimator::getSpikeRateLoopInterval() const
{
    return m_t_spike_loop;
}

// Stores the function queried once per iteration for the sample timestamp.
void SpikeRateEstimator::bindTimestampCallback(const SpikeRateEstimator::timestamp_callback_type & tscb)
{
    m_timestamp_callback = tscb;
}

// True when no ring buffer is attached or a ring-buffer reset is still pending.
bool SpikeRateEstimator::ringBufferInvalid() const
{
    return !mp_spkr_rb || m_spkr_rb_needs_reset;
}

// Sets m_rate_norm_factor = 1 / (loop interval in seconds), i.e.
// period::den / (count * period::num).
void SpikeRateEstimator::updateRateNormalizationFactor()
{
    const auto period_num = static_cast<rate_type>(rate_interval_type::period::num);
    const auto period_den = static_cast<rate_type>(rate_interval_type::period::den);
    m_rate_norm_factor = period_den / (static_cast<rate_type>(m_t_spike_loop.count()) * period_num);
    mp_logger->debug("Update rate normalization factor: {} ( {} / {})",
                     m_rate_norm_factor, period_num, period_den);
}

// Allocates a new spike queue with capacity
//   (loop interval in ms) * (number of (channel, unit) pairs) * 2 + 1.
void SpikeRateEstimator::assignSpikeQueue()
{
    auto msbins = static_cast<rate_type>(rate_interval_type::period::num) * 1000
                  / static_cast<rate_type>(rate_interval_type::period::den)
                  * m_t_spike_loop.count();
    mp_sq = std::make_unique(msbins * m_ch_un_list.size() * 2 + 1);
    // mp_logger->debug("Assigned new spike queue");
}

// Body of the processing thread. Runs processSpikes() on a fixed schedule of
// one iteration per m_t_spike_loop (start times advance by the interval, so an
// overrun is followed by iterations without sleeping until back on schedule).
// Sets the over-time flag/duration and warns when an iteration takes longer than
// the interval. An exception from processSpikes() (e.g. from a bound callback)
// is logged and ends the loop instead of escaping the thread and calling
// std::terminate.
void SpikeRateEstimator::processThreadFun()
{
    m_thread_loop_over_time_flag.store(false);
    clock::time_point nextStartTime{clock::now()};
    // mp_logger->debug("SpikeRateEstimator in processThreadFun");
    try {
        while (m_keep_going.load()) {
            const clock::time_point currentStartTime{clock::now()};
            processSpikes();
            const clock::duration taskDuration = clock::now() - currentStartTime;
            if (taskDuration > m_t_spike_loop) {
                m_thread_loop_over_time_flag = true;
                m_thread_loop_over_time = taskDuration;
                mp_logger->warn("Spike rate processing loop went over time! {}μs vs. {}ms",
                                duration_cast(taskDuration).count(), m_t_spike_loop.count());
            }
            nextStartTime += m_t_spike_loop;
            std::this_thread::sleep_until(nextStartTime);
        }
    } catch (const std::exception& e) {
        m_keep_going = false;
        mp_logger->error("Spike rate processing thread stopped by exception: {}", e.what());
    } catch (...) {
        m_keep_going = false;
        mp_logger->error("Spike rate processing thread stopped by unknown exception");
    }
    // mp_logger->debug("Ending processThreadFun");
}

// Reuses the registered spdlog logger "SpikeRateEstimator" if one exists,
// otherwise creates it with spdlog::stdout_color_mt.
void SpikeRateEstimator::init_logger()
{
    auto l = spdlog::get("SpikeRateEstimator");
    if (l) {
        mp_logger = l;
    } else {
        mp_logger = spdlog::stdout_color_mt("SpikeRateEstimator");
    }
    // mp_logger->debug("SpikeRateEstimator initiated");
}

} // namespace sre
} // namespace cc