123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
#include "sre/SpikeRateEstimator.h"

#include <exception>
#include <utility>
- namespace cc
- {
- namespace sre
- {
- SpikeRateEstimator::SpikeRateEstimator()
- : m_history_length(2'000)
- , m_t_spike_loop(10ms)
- , m_normalize_rates(true)
- , m_keep_going(true)
- , m_thread_loop_over_time_flag(false)
- , m_spkr_rb_needs_reset(false)
- {
- init_logger();
- // mp_logger->debug("SpikeRateEstimator()");
- fillChannelUnitList(cc::constants::MaxFEChannels, cc::constants::MaxUnits);
- updateRateNormalizationFactor();
- assignSpikeQueue();
- }
- SpikeRateEstimator::SpikeRateEstimator(size_type n_channels, size_type n_units, size_type history_length,
- rate_interval_type t_spike_loop)
- : m_history_length(history_length)
- , m_t_spike_loop(t_spike_loop)
- , m_normalize_rates(true)
- , m_keep_going(true)
- , m_thread_loop_over_time_flag(false)
- , m_spkr_rb_needs_reset(false)
- {
- init_logger();
- // mp_logger->debug("SpikeRateEstimator(ch u hl sl)");
-
- fillChannelUnitList(n_channels, n_units);
- updateRateNormalizationFactor();
- assignSpikeQueue();
- }
- SpikeRateEstimator::SpikeRateEstimator(ch_un_list_type ch_un_list, size_type history_length,
- rate_interval_type t_spike_loop)
- : m_ch_un_list(ch_un_list)
- , m_history_length(history_length)
- , m_t_spike_loop(t_spike_loop)
- , m_normalize_rates(true)
- , m_keep_going(true)
- , m_thread_loop_over_time_flag(false)
- , m_spkr_rb_needs_reset(false)
- {
- init_logger();
- // mp_logger->debug("SpikeRateEstimator(chunl hl sl)");
-
- applyChannelUnitList();
- updateRateNormalizationFactor();
- assignSpikeQueue();
- }
- SpikeRateEstimator::~SpikeRateEstimator ()
- {
- // mp_logger->debug("~SpikeRateEstimator()");
- stopProcessThread();
- }
- bool SpikeRateEstimator::addSpike(const Channel_t channel, const Unit_t unit)
- {
- return addSpike(spike{0, channel, unit});
- }
- void SpikeRateEstimator::processSpikes()
- {
- consumeQueue();
- applyCallback();
- }
- SpikeRateEstimator::size_type SpikeRateEstimator::consumeQueue()
- {
- size_type spks_consumed = 0;
- size_type unit_idx = 0;
-
- while (!mp_sq->isEmpty())
- {
- auto const& p_spk = mp_sq->frontPtr();
- if (ch_u_offset(p_spk->channel, p_spk->unit, unit_idx))
- {
- auto & uh = (*mp_unit_hist_container)[unit_idx];
- ++ uh;
- }
- mp_sq->popFront();
- ++ spks_consumed;
- }
- for (auto & uh: *mp_unit_hist_container)
- {
- uh.step();
- }
- return spks_consumed;
- }
- void SpikeRateEstimator::bindCallback(const rate_callback_type& rcb)
- {
- m_rate_callback = rcb;
- }
- void SpikeRateEstimator::applyCallback()
- {
- if (mp_spkr_rb)
- {
- if (m_spkr_rb_needs_reset)
- {
- mp_logger->info("Reset spike rate ring buffer because of channel-unit map change");
-
- mp_spkr_rb->reset();
- m_spkr_rb_needs_reset = false;
- }
-
- Timestamp_t lastNSPTime = 0;
- if (m_timestamp_callback)
- {
- lastNSPTime = m_timestamp_callback();
- // mp_logger->debug("applyCallback: got latest NSP timestamp: {}", lastNSPTime);
- }
- rate_ring_buffer_item_type time_and_rates{lastNSPTime, rate_item_type(n_ch_u(), 0)};
-
- if (m_rate_callback)
- {
- for (auto p = std::make_pair(mp_unit_hist_container->cbegin(), time_and_rates.second.begin());
- p.first != mp_unit_hist_container->cend() && p.second != time_and_rates.second.end();
- ++p.first, ++p.second)
- {
- if ((*p.first).has_spikes())
- {
- (*p.second) = m_rate_callback(*p.first);
-
- if (m_normalize_rates)
- {
- (*p.second) *= m_rate_norm_factor;
- }
- // mp_logger->debug("{} -> rate {}", *p.first, *p.second);
-
- }
- else
- {
- (*p.second) = 0;
- // mp_logger->debug("{} -> rate {} [no new spikes]", *p.first, *p.second);
- }
- }
- }
- else
- {
- mp_logger->warn("no rate callback");
- }
- mp_spkr_rb->push(time_and_rates);
- }
- else
- {
- mp_logger->warn("no rate ring buffer");
- }
- }
- void SpikeRateEstimator::attachSpikeRateRB(std::shared_ptr<rate_ring_buffer_type> sr_rb)
- {
- // mp_logger->debug("attachSpikeRateRB");
- mp_spkr_rb = sr_rb;
- }
- void SpikeRateEstimator::signalStart()
- {
- stopProcessThread();
- // mp_logger->debug("Launching spike processing thread");
- m_keep_going = true;
- m_spkr_rb_needs_reset = true;
- consumeQueue();
- m_process_thread = std::thread(&SpikeRateEstimator::processThreadFun, this);
- }
- void SpikeRateEstimator::signalStop()
- {
- m_keep_going = false;
- }
- void SpikeRateEstimator::stopProcessThread()
- {
- m_keep_going = false;
- if (m_process_thread.joinable())
- {
- // mp_logger->debug("Stopping spike processing thread");
- m_process_thread.join();
- }
- }
- void SpikeRateEstimator::setSpikeRateNormalization(bool enabled)
- {
- m_normalize_rates = enabled;
- }
- bool SpikeRateEstimator::getSpikeRateNormalization() const
- {
- return m_normalize_rates;
- }
- void SpikeRateEstimator::set_logger(std::shared_ptr<spdlog::logger> logger)
- {
- mp_logger = logger->clone("SpikeRateEstimator");
- }
- void SpikeRateEstimator::setChannelUnitList(const ch_un_list_type & ch_un_list)
- {
- m_ch_un_list.clear();
- std::copy(ch_un_list.cbegin(), ch_un_list.cend(), std::back_inserter(m_ch_un_list));
- applyChannelUnitList();
- }
- void SpikeRateEstimator::applyChannelUnitList()
- {
- m_ch_un_map.clear();
- size_type n_ch_un = m_ch_un_list.size();
-
- for (size_type i = 0; i < n_ch_un; ++ i)
- {
- m_ch_un_map[m_ch_un_list[i]] = i;
- }
- mp_unit_hist_container = std::make_unique<unit_hist_container_type>(n_ch_un, m_history_length);
- assignSpikeQueue();
- m_spkr_rb_needs_reset = true;
- mp_logger->debug("Set new (channel, unit) list, with {} items", m_ch_un_list.size());
- }
- void SpikeRateEstimator::fillChannelUnitList(const size_type n_channels, const size_type n_units)
- {
- ch_un_list_type ch_un_list;
- ch_un_list.reserve(n_channels * n_units);
- for (size_type ch = 1; ch <= n_channels; ++ ch)
- {
- for (size_type u = 0; u < n_units; ++ u)
- {
- ch_un_list.push_back(std::make_pair(ch, u));
- }
- }
- setChannelUnitList(ch_un_list);
- }
- std::pair<SpikeRateEstimator::ch_un_list_type, SpikeRateEstimator::ch_un_map_type> SpikeRateEstimator::getChannelUnitMapping() const
- {
- return std::make_pair(m_ch_un_list, m_ch_un_map);
- }
- SpikeRateEstimator::ch_un_list_type SpikeRateEstimator::getChannelUnitList() const
- {
- return m_ch_un_list;
- }
- void SpikeRateEstimator::setSpikeRateLoopInterval(const SpikeRateEstimator::rate_interval_type & new_interval)
- {
- if (m_t_spike_loop != new_interval)
- {
- m_t_spike_loop = new_interval;
- updateRateNormalizationFactor();
- assignSpikeQueue();
- m_spkr_rb_needs_reset = true;
- }
- }
- SpikeRateEstimator::rate_interval_type SpikeRateEstimator::getSpikeRateLoopInterval() const
- {
- return m_t_spike_loop;
- }
- void SpikeRateEstimator::bindTimestampCallback(const SpikeRateEstimator::timestamp_callback_type & tscb)
- {
- m_timestamp_callback = tscb;
- }
- bool SpikeRateEstimator::ringBufferInvalid() const
- {
- return !mp_spkr_rb || m_spkr_rb_needs_reset;
- }
- void SpikeRateEstimator::updateRateNormalizationFactor()
- {
- m_rate_norm_factor = ((rate_type)rate_interval_type::period::den / ((rate_type) m_t_spike_loop.count() * (rate_type) rate_interval_type::period::num));
- mp_logger->debug("Update rate normalization factor: {} ( {} / {})", m_rate_norm_factor, (rate_type)rate_interval_type::period::num, (rate_type) rate_interval_type::period::den);
- }
- void SpikeRateEstimator::assignSpikeQueue()
- {
- auto msbins = (rate_type)rate_interval_type::period::num * 1000 / (rate_type) rate_interval_type::period::den * m_t_spike_loop.count();
- mp_sq = std::make_unique<spike_queue_type>(msbins * m_ch_un_list.size() * 2 + 1);
- // mp_logger->debug("Assigned new spike queue");
- }
- void SpikeRateEstimator::processThreadFun()
- {
- m_thread_loop_over_time_flag.store(false);
- clock::time_point currentStartTime{clock::now()};
- clock::time_point nextStartTime{clock::now()};
- clock::duration taskDuration;
- // mp_logger->debug("SpikeRateEstimator in processThreadFun");
-
- while (m_keep_going.load())
- {
- currentStartTime = clock::now();
- processSpikes();
- taskDuration = clock::now() - currentStartTime;
- if (taskDuration > m_t_spike_loop)
- {
- m_thread_loop_over_time_flag = true;
- m_thread_loop_over_time = taskDuration;
- mp_logger->warn("Spike rate processing loop went over time! {}μs vs. {}ms", duration_cast<microseconds>(taskDuration).count(), m_t_spike_loop.count());
- }
- nextStartTime += m_t_spike_loop;
- std::this_thread::sleep_until(nextStartTime);
- }
- // mp_logger->debug("Ending processThreadFun");
- }
- void SpikeRateEstimator::init_logger()
- {
- auto l = spdlog::get("SpikeRateEstimator");
- if (l)
- {
- mp_logger = l;
- }
- else
- {
- mp_logger = spdlog::stdout_color_mt<spdlog::async_factory>("SpikeRateEstimator");
- }
- // mp_logger->debug("SpikeRateEstimator initiated");
- }
- } // namespace sre
- } // namespace cc
|