SpikeBandPowerEstimator.cpp 12 KB


  #include "sbpe/SpikeBandPowerEstimator.h"
  #include <algorithm>
  #include <iostream>
  #include <utility>
  4. namespace cc
  5. {
  6. namespace sbpe
  7. {
  8. SpikeBandPowerEstimator::SpikeBandPowerEstimator()
  9. : m_ch_list{}
  10. , mp_cont_data(std::make_unique<cont_sample_rb_type>(1))
  11. , m_iir_filt_coeff_b{0.190616600097498, 0, -0.381233200194995, 0, 0.190616600097498}
  12. , m_iir_filt_coeff_a{1, -2.33508240207651, 1.9509646897792, -0.819263685312402, 0.206671985166565}
  13. , mp_iir_filt_rms(std::make_unique<iir_filt_RMS_type>(1, 600, m_iir_filt_coeff_b, m_iir_filt_coeff_a))
  14. , m_t_sbp_loop(10ms)
  15. , m_keep_going(false)
  16. , m_thread_loop_over_time_flag(false)
  17. , m_sbp_rb_needs_reset(true)
  18. , m_iir_filt_need_reset(true)
  19. , m_accepted_sample_group(6)
  20. {
  21. initLogger();
  22. fillChannelList(constants::MaxFEChannels);
  23. }
  24. SpikeBandPowerEstimator::SpikeBandPowerEstimator(size_type n_channels,
  25. sbp_interval_type t_sbp_loop)
  26. : m_ch_list{}
  27. , mp_cont_data(std::make_unique<cont_sample_rb_type>(1))
  28. , m_iir_filt_coeff_b{0.190616600097498, 0, -0.381233200194995, 0, 0.190616600097498}
  29. , m_iir_filt_coeff_a{1, -2.33508240207651, 1.9509646897792, -0.819263685312402, 0.206671985166565}
  30. , mp_iir_filt_rms(std::make_unique<iir_filt_RMS_type>(n_channels, 600, m_iir_filt_coeff_b, m_iir_filt_coeff_a))
  31. , m_t_sbp_loop(t_sbp_loop)
  32. , m_keep_going(false)
  33. , m_thread_loop_over_time_flag(false)
  34. , m_sbp_rb_needs_reset(true)
  35. , m_iir_filt_need_reset(true)
  36. , m_accepted_sample_group(6)
  37. {
  38. initLogger();
  39. fillChannelList(n_channels);
  40. }
  41. SpikeBandPowerEstimator::SpikeBandPowerEstimator(ch_list_type ch_list,
  42. sbp_interval_type t_sbp_loop)
  43. : m_ch_list(ch_list)
  44. , mp_cont_data(std::make_unique<cont_sample_rb_type>(1))
  45. , m_iir_filt_coeff_b{0.190616600097498, 0, -0.381233200194995, 0, 0.190616600097498}
  46. , m_iir_filt_coeff_a{1, -2.33508240207651, 1.9509646897792, -0.819263685312402, 0.206671985166565}
  47. , mp_iir_filt_rms(std::make_unique<iir_filt_RMS_type>(ch_list.size(), 600, m_iir_filt_coeff_b, m_iir_filt_coeff_a))
  48. , m_t_sbp_loop(t_sbp_loop)
  49. , m_keep_going(false)
  50. , m_thread_loop_over_time_flag(false)
  51. , m_sbp_rb_needs_reset(true)
  52. , m_iir_filt_need_reset(true)
  53. , m_accepted_sample_group(6)
  54. {
  55. initLogger();
  56. applyChannelList();
  57. }
  58. SpikeBandPowerEstimator::~SpikeBandPowerEstimator ()
  59. {
  60. stopProcessThread();
  61. }
  62. void SpikeBandPowerEstimator::attachSpikeBandPowerRB(std::shared_ptr<sbp_ring_buffer_type> sbp_rb)
  63. {
  64. mp_logger->debug("attachSpikeBandPowerRB");
  65. mp_sbp_rb = sbp_rb;
  66. }
  67. void SpikeBandPowerEstimator::setChannelList(const ch_list_type & ch_list)
  68. {
  69. m_ch_list.clear();
  70. std::copy(ch_list.cbegin(), ch_list.cend(), std::back_inserter(m_ch_list));
  71. applyChannelList();
  72. }
  73. void SpikeBandPowerEstimator::applyChannelList()
  74. {
  75. m_ch_map.clear();
  76. size_type n_ch = m_ch_list.size();
  77. for (size_type i = 0; i < n_ch; ++ i)
  78. {
  79. m_ch_map[m_ch_list[i]] = i;
  80. }
  81. updateChannelPairList(mp_cont_data->ch_list_cbegin(), mp_cont_data->ch_list_cend());
  82. m_sbp_rb_needs_reset = true;
  83. m_iir_filt_need_reset = true;
  84. mp_logger->debug("Set new channel list, with {} items", m_ch_list.size());
  85. }
  86. void SpikeBandPowerEstimator::fillChannelList(const size_type n_channels)
  87. {
  88. ch_list_type ch_list(n_channels);
  89. std::iota (std::begin(ch_list), std::end(ch_list), 1);
  90. setChannelList(ch_list);
  91. }
  92. std::pair<SpikeBandPowerEstimator::ch_list_type, SpikeBandPowerEstimator::ch_map_type> SpikeBandPowerEstimator::getChannelMapping() const
  93. {
  94. return std::make_pair(m_ch_list, m_ch_map);
  95. }
  96. SpikeBandPowerEstimator::ch_list_type SpikeBandPowerEstimator::getChannelList() const
  97. {
  98. return m_ch_list;
  99. }
  100. void SpikeBandPowerEstimator::updateChannelPairList(const ChannelNumList_t & channel_ids)
  101. {
  102. updateChannelPairList(channel_ids.cbegin(), channel_ids.cend());
  103. }
  104. void SpikeBandPowerEstimator::setLoopInterval(const SpikeBandPowerEstimator::sbp_interval_type & new_interval)
  105. {
  106. if (m_t_sbp_loop != new_interval)
  107. {
  108. m_t_sbp_loop = new_interval;
  109. m_sbp_rb_needs_reset = true;
  110. }
  111. }
  112. SpikeBandPowerEstimator::sbp_interval_type SpikeBandPowerEstimator::getLoopInterval() const
  113. {
  114. return m_t_sbp_loop;
  115. }
  116. bool SpikeBandPowerEstimator::ringBufferInvalid() const
  117. {
  118. return !mp_sbp_rb || m_sbp_rb_needs_reset || m_iir_filt_need_reset;
  119. }
  120. void SpikeBandPowerEstimator::setAcceptedSampleGroup(const size_t group)
  121. {
  122. if (group < constants::NSampleGroups)
  123. {
  124. m_accepted_sample_group = group;
  125. }
  126. }
  127. size_t SpikeBandPowerEstimator::getAcceptedSampleGroup() const
  128. {
  129. return m_accepted_sample_group;
  130. }
  131. void SpikeBandPowerEstimator::setIIRFilterCoeffs(const iir_filt_RMS_type::iir_filter_coeff_container_type &b, const iir_filt_RMS_type::iir_filter_coeff_container_type &a)
  132. {
  133. if (m_iir_filt_coeff_b.n_rows != b.n_rows || m_iir_filt_coeff_b.n_cols != b.n_cols || any(m_iir_filt_coeff_b != b))
  134. {
  135. m_iir_filt_coeff_b = b;
  136. m_iir_filt_need_reset = true;
  137. }
  138. if (m_iir_filt_coeff_a.n_rows != a.n_rows || m_iir_filt_coeff_a.n_cols != a.n_cols ||any(m_iir_filt_coeff_a != a))
  139. {
  140. m_iir_filt_coeff_a = a;
  141. m_iir_filt_need_reset = true;
  142. }
  143. }
  144. std::pair<SpikeBandPowerEstimator::iir_filt_RMS_type::iir_filter_coeff_container_type, SpikeBandPowerEstimator::iir_filt_RMS_type::iir_filter_coeff_container_type> SpikeBandPowerEstimator::getIIRFilterCoeffs() const
  145. {
  146. return std::make_pair<>(m_iir_filt_coeff_b, m_iir_filt_coeff_a);
  147. }
  148. void SpikeBandPowerEstimator::setIIRFilterNSamples(const size_t n_samples)
  149. {
  150. mp_iir_filt_rms->setNSamples(n_samples);
  151. }
  152. size_t SpikeBandPowerEstimator::getIIRFilterNSamples() const
  153. {
  154. return mp_iir_filt_rms->getNSamples();
  155. }
  156. void SpikeBandPowerEstimator::setRMSNBins(const size_t n_bins)
  157. {
  158. mp_iir_filt_rms->setNAvgBins(n_bins);
  159. }
  160. size_t SpikeBandPowerEstimator::getRMSNBins() const
  161. {
  162. return mp_iir_filt_rms->getNAvgBins();
  163. }
  164. bool SpikeBandPowerEstimator::reset(const size_t group, const ChannelNumList_t &channel_list)
  165. {
  166. if (group != m_accepted_sample_group)
  167. {
  168. return false;
  169. }
  170. mp_cont_data->reset(channel_list);
  171. return true;
  172. }
  173. bool SpikeBandPowerEstimator::reset(const ChannelNumList_t &channel_list)
  174. {
  175. mp_cont_data->reset(channel_list);
  176. return true;
  177. }
  178. bool SpikeBandPowerEstimator::addSamplesFromPacket(const Timestamp_t timestamp, const SampleGroupRingBuffer::ContinuousDataSampleT * data, const unsigned int maxdatalen)
  179. {
  180. if (!m_keep_going)
  181. {
  182. return false;
  183. }
  184. return mp_cont_data->addSamplesFromPacket(timestamp, data, maxdatalen);
  185. }
  186. bool SpikeBandPowerEstimator::addSamplesFromPacket(const size_t group, const Timestamp_t timestamp, const SampleGroupRingBuffer::ContinuousDataSampleT * data, const unsigned int maxdatalen)
  187. {
  188. if (!m_keep_going || group != m_accepted_sample_group)
  189. {
  190. return false;
  191. }
  192. return mp_cont_data->addSamplesFromPacket(timestamp, data, maxdatalen);
  193. }
  194. bool SpikeBandPowerEstimator::addSamplesFromPacket(const cbPKT_GROUP *pkt_grp)
  195. {
  196. if (!m_keep_going || pkt_grp->type != m_accepted_sample_group)
  197. {
  198. return false;
  199. }
  200. return mp_cont_data->addSamplesFromPacket(pkt_grp);
  201. }
  202. bool SpikeBandPowerEstimator::setCARChannels(const size_t group, const ChannelNumList_t &carChannelList)
  203. {
  204. if (group != m_accepted_sample_group)
  205. {
  206. return false;
  207. }
  208. return mp_cont_data->setCARChannels(carChannelList);
  209. }
  210. bool SpikeBandPowerEstimator::setCARChannels(const size_t group, ChannelNumList_t &&carChannelList)
  211. {
  212. if (group != m_accepted_sample_group)
  213. {
  214. return false;
  215. }
  216. return mp_cont_data->setCARChannels(carChannelList);
  217. }
  218. void SpikeBandPowerEstimator::disableCAR(const size_t group)
  219. {
  220. if (group != m_accepted_sample_group)
  221. {
  222. return;
  223. }
  224. mp_cont_data->disableCAR();
  225. }
  226. bool SpikeBandPowerEstimator::applyFilters()
  227. {
  228. if (mp_sbp_rb)
  229. {
  230. if (m_iir_filt_need_reset)
  231. {
  232. mp_iir_filt_rms->resize(m_ch_list.size());
  233. mp_iir_filt_rms->setIIRFilterCoeffs(m_iir_filt_coeff_b, m_iir_filt_coeff_a);
  234. m_iir_filt_need_reset = false;
  235. }
  236. if (m_sbp_rb_needs_reset)
  237. {
  238. mp_sbp_rb->reset();
  239. m_sbp_rb_needs_reset = false;
  240. }
  241. SampleGroupChunk v;
  242. mp_cont_data->consume(v);
  243. if (v.channel_ids_changed)
  244. {
  245. mp_logger->debug("channel ids changed (in applyFilters())");
  246. updateChannelPairList(v.channel_ids);
  247. }
  248. const size_t num_channels = v.channel_ids.size();
  249. if (num_channels == 0)
  250. {
  251. mp_logger->debug("No channels");
  252. return false;
  253. }
  254. const size_t num_samples = v.buffer.size() / num_channels;
  255. std::vector<sbp_type> datav(m_ch_list.size(), 0);
  256. bool rms_complete;
  257. for (size_t i_sample = 0; i_sample < num_samples; ++ i_sample)
  258. {
  259. for (auto &p: m_ch_pair_v)
  260. {
  261. datav[p.second] = v.buffer[p.first + i_sample * num_channels];
  262. }
  263. rms_complete = (*mp_iir_filt_rms)(datav.cbegin(), datav.cend());
  264. if (rms_complete)
  265. {
  266. sbp_ring_buffer_item_type time_and_rates{v.last_timestamp - constants::SampleGroupPeriods[m_accepted_sample_group - 1] * (num_samples - i_sample - 1), sbp_item_type(mp_iir_filt_rms->cbegin(), mp_iir_filt_rms->cend())};
  267. mp_sbp_rb->push(time_and_rates);
  268. }
  269. }
  270. return true;
  271. }
  272. return false;
  273. }
  274. void SpikeBandPowerEstimator::signalStart()
  275. {
  276. stopProcessThread();
  277. // mp_logger->debug("Launching spike processing thread");
  278. m_keep_going = true;
  279. m_sbp_rb_needs_reset = true;
  280. m_iir_filt_need_reset = true;
  281. SampleGroupChunk v;
  282. mp_cont_data->consume(v);
  283. if (v.channel_ids_changed)
  284. {
  285. mp_logger->debug("channel ids changed (in signalStart())");
  286. updateChannelPairList(v.channel_ids);
  287. }
  288. m_process_thread = std::thread(&SpikeBandPowerEstimator::processThreadFun, this);
  289. }
  290. void SpikeBandPowerEstimator::signalStop()
  291. {
  292. m_keep_going = false;
  293. }
  294. void SpikeBandPowerEstimator::stopProcessThread()
  295. {
  296. m_keep_going = false;
  297. if (m_process_thread.joinable())
  298. {
  299. mp_logger->info("Stopping spike band power processing thread");
  300. m_process_thread.join();
  301. }
  302. }
  303. void SpikeBandPowerEstimator::processThreadFun()
  304. {
  305. m_thread_loop_over_time_flag = false;
  306. clock::time_point currentStartTime{clock::now()};
  307. clock::time_point nextStartTime{clock::now()};
  308. clock::duration taskDuration;
  309. size_t loop_counter = 0;
  310. while (m_keep_going)
  311. {
  312. ++ loop_counter;
  313. currentStartTime = clock::now();
  314. applyFilters();
  315. taskDuration = clock::now() - currentStartTime;
  316. if (taskDuration > m_t_sbp_loop)
  317. {
  318. m_thread_loop_over_time_flag = true;
  319. m_thread_loop_over_time = taskDuration;
  320. mp_logger->warn("Spike band power processing loop went over time! {}μs vs. {}ms. loop {}", duration_cast<microseconds>(taskDuration).count(), m_t_sbp_loop.count(), loop_counter);
  321. }
  322. nextStartTime += m_t_sbp_loop;
  323. std::this_thread::sleep_until(nextStartTime);
  324. }
  325. }
  326. void SpikeBandPowerEstimator::initLogger()
  327. {
  328. auto l = spdlog::get("SpikeBandPowerEstimator");
  329. if (l)
  330. {
  331. mp_logger = l;
  332. }
  333. else
  334. {
  335. mp_logger = spdlog::stdout_color_mt<spdlog::async_factory>("SpikeBandPowerEstimator");
  336. }
  337. }
  338. void SpikeBandPowerEstimator::setLogger(std::shared_ptr<spdlog::logger> logger)
  339. {
  340. mp_logger = logger->clone("SpikeBandPowerEstimator");
  341. }
  342. void SpikeBandPowerEstimator::bindTimestampCallback(const SpikeBandPowerEstimator::timestamp_callback_type & tscb)
  343. {
  344. m_timestamp_callback = tscb;
  345. }
  346. } // namespace sbpe
  347. } // namespace cc