SpikeRateEstimator.cpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
#include "sre/SpikeRateEstimator.h"

#include <chrono>
#include <exception>
#include <ratio>
#include <utility>
  2. namespace cc
  3. {
  4. namespace sre
  5. {
  6. SpikeRateEstimator::SpikeRateEstimator()
  7. : m_history_length(2'000)
  8. , m_t_spike_loop(10ms)
  9. , m_normalize_rates(true)
  10. , m_keep_going(true)
  11. , m_thread_loop_over_time_flag(false)
  12. , m_spkr_rb_needs_reset(false)
  13. {
  14. init_logger();
  15. // mp_logger->debug("SpikeRateEstimator()");
  16. fillChannelUnitList(cc::constants::MaxFEChannels, cc::constants::MaxUnits);
  17. updateRateNormalizationFactor();
  18. assignSpikeQueue();
  19. }
  20. SpikeRateEstimator::SpikeRateEstimator(size_type n_channels, size_type n_units, size_type history_length,
  21. rate_interval_type t_spike_loop)
  22. : m_history_length(history_length)
  23. , m_t_spike_loop(t_spike_loop)
  24. , m_normalize_rates(true)
  25. , m_keep_going(true)
  26. , m_thread_loop_over_time_flag(false)
  27. , m_spkr_rb_needs_reset(false)
  28. {
  29. init_logger();
  30. // mp_logger->debug("SpikeRateEstimator(ch u hl sl)");
  31. fillChannelUnitList(n_channels, n_units);
  32. updateRateNormalizationFactor();
  33. assignSpikeQueue();
  34. }
  35. SpikeRateEstimator::SpikeRateEstimator(ch_un_list_type ch_un_list, size_type history_length,
  36. rate_interval_type t_spike_loop)
  37. : m_ch_un_list(ch_un_list)
  38. , m_history_length(history_length)
  39. , m_t_spike_loop(t_spike_loop)
  40. , m_normalize_rates(true)
  41. , m_keep_going(true)
  42. , m_thread_loop_over_time_flag(false)
  43. , m_spkr_rb_needs_reset(false)
  44. {
  45. init_logger();
  46. // mp_logger->debug("SpikeRateEstimator(chunl hl sl)");
  47. applyChannelUnitList();
  48. updateRateNormalizationFactor();
  49. assignSpikeQueue();
  50. }
  51. SpikeRateEstimator::~SpikeRateEstimator ()
  52. {
  53. // mp_logger->debug("~SpikeRateEstimator()");
  54. stopProcessThread();
  55. }
  56. bool SpikeRateEstimator::addSpike(const Channel_t channel, const Unit_t unit)
  57. {
  58. return addSpike(spike{0, channel, unit});
  59. }
  60. void SpikeRateEstimator::processSpikes()
  61. {
  62. consumeQueue();
  63. applyCallback();
  64. }
  65. SpikeRateEstimator::size_type SpikeRateEstimator::consumeQueue()
  66. {
  67. size_type spks_consumed = 0;
  68. size_type unit_idx = 0;
  69. while (!mp_sq->isEmpty())
  70. {
  71. auto const& p_spk = mp_sq->frontPtr();
  72. if (ch_u_offset(p_spk->channel, p_spk->unit, unit_idx))
  73. {
  74. auto & uh = (*mp_unit_hist_container)[unit_idx];
  75. ++ uh;
  76. }
  77. mp_sq->popFront();
  78. ++ spks_consumed;
  79. }
  80. for (auto & uh: *mp_unit_hist_container)
  81. {
  82. uh.step();
  83. }
  84. return spks_consumed;
  85. }
  86. void SpikeRateEstimator::bindCallback(const rate_callback_type& rcb)
  87. {
  88. m_rate_callback = rcb;
  89. }
  90. void SpikeRateEstimator::applyCallback()
  91. {
  92. if (mp_spkr_rb)
  93. {
  94. if (m_spkr_rb_needs_reset)
  95. {
  96. mp_logger->info("Reset spike rate ring buffer because of channel-unit map change");
  97. mp_spkr_rb->reset();
  98. m_spkr_rb_needs_reset = false;
  99. }
  100. Timestamp_t lastNSPTime = 0;
  101. if (m_timestamp_callback)
  102. {
  103. lastNSPTime = m_timestamp_callback();
  104. // mp_logger->debug("applyCallback: got latest NSP timestamp: {}", lastNSPTime);
  105. }
  106. rate_ring_buffer_item_type time_and_rates{lastNSPTime, rate_item_type(n_ch_u(), 0)};
  107. if (m_rate_callback)
  108. {
  109. for (auto p = std::make_pair(mp_unit_hist_container->cbegin(), time_and_rates.second.begin());
  110. p.first != mp_unit_hist_container->cend() && p.second != time_and_rates.second.end();
  111. ++p.first, ++p.second)
  112. {
  113. if ((*p.first).has_spikes())
  114. {
  115. (*p.second) = m_rate_callback(*p.first);
  116. if (m_normalize_rates)
  117. {
  118. (*p.second) *= m_rate_norm_factor;
  119. }
  120. // mp_logger->debug("{} -> rate {}", *p.first, *p.second);
  121. }
  122. else
  123. {
  124. (*p.second) = 0;
  125. // mp_logger->debug("{} -> rate {} [no new spikes]", *p.first, *p.second);
  126. }
  127. }
  128. }
  129. else
  130. {
  131. mp_logger->warn("no rate callback");
  132. }
  133. mp_spkr_rb->push(time_and_rates);
  134. }
  135. else
  136. {
  137. mp_logger->warn("no rate ring buffer");
  138. }
  139. }
  140. void SpikeRateEstimator::attachSpikeRateRB(std::shared_ptr<rate_ring_buffer_type> sr_rb)
  141. {
  142. // mp_logger->debug("attachSpikeRateRB");
  143. mp_spkr_rb = sr_rb;
  144. }
  145. void SpikeRateEstimator::signalStart()
  146. {
  147. stopProcessThread();
  148. // mp_logger->debug("Launching spike processing thread");
  149. m_keep_going = true;
  150. m_spkr_rb_needs_reset = true;
  151. consumeQueue();
  152. m_process_thread = std::thread(&SpikeRateEstimator::processThreadFun, this);
  153. }
  154. void SpikeRateEstimator::signalStop()
  155. {
  156. m_keep_going = false;
  157. }
  158. void SpikeRateEstimator::stopProcessThread()
  159. {
  160. m_keep_going = false;
  161. if (m_process_thread.joinable())
  162. {
  163. // mp_logger->debug("Stopping spike processing thread");
  164. m_process_thread.join();
  165. }
  166. }
  167. void SpikeRateEstimator::setSpikeRateNormalization(bool enabled)
  168. {
  169. m_normalize_rates = enabled;
  170. }
  171. bool SpikeRateEstimator::getSpikeRateNormalization() const
  172. {
  173. return m_normalize_rates;
  174. }
  175. void SpikeRateEstimator::set_logger(std::shared_ptr<spdlog::logger> logger)
  176. {
  177. mp_logger = logger->clone("SpikeRateEstimator");
  178. }
  179. void SpikeRateEstimator::setChannelUnitList(const ch_un_list_type & ch_un_list)
  180. {
  181. m_ch_un_list.clear();
  182. std::copy(ch_un_list.cbegin(), ch_un_list.cend(), std::back_inserter(m_ch_un_list));
  183. applyChannelUnitList();
  184. }
  185. void SpikeRateEstimator::applyChannelUnitList()
  186. {
  187. m_ch_un_map.clear();
  188. size_type n_ch_un = m_ch_un_list.size();
  189. for (size_type i = 0; i < n_ch_un; ++ i)
  190. {
  191. m_ch_un_map[m_ch_un_list[i]] = i;
  192. }
  193. mp_unit_hist_container = std::make_unique<unit_hist_container_type>(n_ch_un, m_history_length);
  194. assignSpikeQueue();
  195. m_spkr_rb_needs_reset = true;
  196. mp_logger->debug("Set new (channel, unit) list, with {} items", m_ch_un_list.size());
  197. }
  198. void SpikeRateEstimator::fillChannelUnitList(const size_type n_channels, const size_type n_units)
  199. {
  200. ch_un_list_type ch_un_list;
  201. ch_un_list.reserve(n_channels * n_units);
  202. for (size_type ch = 1; ch <= n_channels; ++ ch)
  203. {
  204. for (size_type u = 0; u < n_units; ++ u)
  205. {
  206. ch_un_list.push_back(std::make_pair(ch, u));
  207. }
  208. }
  209. setChannelUnitList(ch_un_list);
  210. }
  211. std::pair<SpikeRateEstimator::ch_un_list_type, SpikeRateEstimator::ch_un_map_type> SpikeRateEstimator::getChannelUnitMapping() const
  212. {
  213. return std::make_pair(m_ch_un_list, m_ch_un_map);
  214. }
  215. SpikeRateEstimator::ch_un_list_type SpikeRateEstimator::getChannelUnitList() const
  216. {
  217. return m_ch_un_list;
  218. }
  219. void SpikeRateEstimator::setSpikeRateLoopInterval(const SpikeRateEstimator::rate_interval_type & new_interval)
  220. {
  221. if (m_t_spike_loop != new_interval)
  222. {
  223. m_t_spike_loop = new_interval;
  224. updateRateNormalizationFactor();
  225. assignSpikeQueue();
  226. m_spkr_rb_needs_reset = true;
  227. }
  228. }
  229. SpikeRateEstimator::rate_interval_type SpikeRateEstimator::getSpikeRateLoopInterval() const
  230. {
  231. return m_t_spike_loop;
  232. }
  233. void SpikeRateEstimator::bindTimestampCallback(const SpikeRateEstimator::timestamp_callback_type & tscb)
  234. {
  235. m_timestamp_callback = tscb;
  236. }
  237. bool SpikeRateEstimator::ringBufferInvalid() const
  238. {
  239. return !mp_spkr_rb || m_spkr_rb_needs_reset;
  240. }
  241. void SpikeRateEstimator::updateRateNormalizationFactor()
  242. {
  243. m_rate_norm_factor = ((rate_type)rate_interval_type::period::den / ((rate_type) m_t_spike_loop.count() * (rate_type) rate_interval_type::period::num));
  244. mp_logger->debug("Update rate normalization factor: {} ( {} / {})", m_rate_norm_factor, (rate_type)rate_interval_type::period::num, (rate_type) rate_interval_type::period::den);
  245. }
  246. void SpikeRateEstimator::assignSpikeQueue()
  247. {
  248. auto msbins = (rate_type)rate_interval_type::period::num * 1000 / (rate_type) rate_interval_type::period::den * m_t_spike_loop.count();
  249. mp_sq = std::make_unique<spike_queue_type>(msbins * m_ch_un_list.size() * 2 + 1);
  250. // mp_logger->debug("Assigned new spike queue");
  251. }
  252. void SpikeRateEstimator::processThreadFun()
  253. {
  254. m_thread_loop_over_time_flag.store(false);
  255. clock::time_point currentStartTime{clock::now()};
  256. clock::time_point nextStartTime{clock::now()};
  257. clock::duration taskDuration;
  258. // mp_logger->debug("SpikeRateEstimator in processThreadFun");
  259. while (m_keep_going.load())
  260. {
  261. currentStartTime = clock::now();
  262. processSpikes();
  263. taskDuration = clock::now() - currentStartTime;
  264. if (taskDuration > m_t_spike_loop)
  265. {
  266. m_thread_loop_over_time_flag = true;
  267. m_thread_loop_over_time = taskDuration;
  268. mp_logger->warn("Spike rate processing loop went over time! {}μs vs. {}ms", duration_cast<microseconds>(taskDuration).count(), m_t_spike_loop.count());
  269. }
  270. nextStartTime += m_t_spike_loop;
  271. std::this_thread::sleep_until(nextStartTime);
  272. }
  273. // mp_logger->debug("Ending processThreadFun");
  274. }
  275. void SpikeRateEstimator::init_logger()
  276. {
  277. auto l = spdlog::get("SpikeRateEstimator");
  278. if (l)
  279. {
  280. mp_logger = l;
  281. }
  282. else
  283. {
  284. mp_logger = spdlog::stdout_color_mt<spdlog::async_factory>("SpikeRateEstimator");
  285. }
  286. // mp_logger->debug("SpikeRateEstimator initiated");
  287. }
  288. } // namespace sre
  289. } // namespace cc