# data.py

import datetime
import glob
import logging
import multiprocessing
import os
import pathlib
import pickle
import random
import time
from datetime import datetime as dt

import cere_conn as cc
import matplotlib.pyplot as plt
import munch
import numpy as np
from scipy import io
import yaml

import aux
from aux import log
from modules import classifier as clf
from helpers import data_management as dm
from modules.ringbuffer import RingBuffer

class DataNormalizer:
    def __init__(self, params, initial_data=None):
        self.params = params
        self.norm_rate = {}
        self.norm_rate['ch_ids'] = [ch.id for ch in self.params.daq.normalization.channels]
        self.norm_rate['bottoms'] = np.asarray([ch.bottom for ch in self.params.daq.normalization.channels])
        self.norm_rate['tops'] = np.asarray([ch.top for ch in self.params.daq.normalization.channels])
        self.norm_rate['invs'] = [ch.invert for ch in self.params.daq.normalization.channels]
        # buffer length: normalization window (s) divided by the spike-rate loop interval (ms)
        n_norm_buffer = int(self.params.daq.normalization.len * (1000.0 / self.params.daq.spike_rates.loop_interval))
        self.norm_buffer = RingBuffer(n_norm_buffer, dtype=(float, self.params.daq.n_channels), allow_overwrite=True)
        self.last_update = time.time()
        if initial_data is not None:
            self.norm_buffer.extend(initial_data)

    def _update_norm_range(self):
        buf_vals = self.norm_buffer[:, self.norm_rate['ch_ids']]
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        self.norm_rate['bottoms'] = centiles[0, :]
        self.norm_rate['tops'] = centiles[1, :]
        log.info(f"Updated normalization ranges for channels {self.norm_rate['ch_ids']} to bottoms: {self.norm_rate['bottoms']}, tops: {self.norm_rate['tops']}")

    def _update_norm_range_all(self):
        buf_vals = np.mean(self.norm_buffer, axis=1)
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        # log.info(f"Centiles: {centiles}")
        self.params.daq.normalization.all_channels.bottom = centiles[0]
        self.params.daq.normalization.all_channels.top = centiles[1]
        log.info(f"Updated normalization range for all channels to [{self.params.daq.normalization.all_channels.bottom}, {self.params.daq.normalization.all_channels.top}]")

    def update_norm_range(self, data=None, force=False):
        if data is not None and data.size > 0:
            self.norm_buffer.extend(data)
        if (self.params.daq.normalization.do_update and (time.time() - self.last_update >= self.params.daq.normalization.update_interval)) or force:
            if self.params.daq.normalization.use_all_channels:
                self._update_norm_range_all()
            else:
                self._update_norm_range()
            self.last_update = time.time()
            log.info(f"New channel normalization setting: {yaml.dump(self._format_current_config(), sort_keys=False, default_flow_style=None)}")

    def _format_current_config(self):
        if self.params.daq.normalization.use_all_channels:
            out_dict = {'all_channels': {'bottom': float(self.params.daq.normalization.all_channels.bottom),
                                         'top': float(self.params.daq.normalization.all_channels.top),
                                         'invert': bool(self.params.daq.normalization.all_channels.invert)}}
        else:
            out_dict = {'channels': []}
            for ii in range(len(self.norm_rate['ch_ids'])):
                out_dict['channels'].append({'id': int(self.norm_rate['ch_ids'][ii]),
                                             'bottom': float(self.norm_rate['bottoms'][ii]),
                                             'top': float(self.norm_rate['tops'][ii]),
                                             'invert': self.norm_rate['invs'][ii]})
        return out_dict

    def _calculate_all_norm_rate(self, buf_item):
        avg_r = np.mean(buf_item, axis=1)
        if self.params.daq.normalization.clamp_firing_rates:
            avg_r = np.maximum(np.minimum(avg_r, self.params.daq.normalization.all_channels.top), self.params.daq.normalization.all_channels.bottom)
        norm_rate = (avg_r - self.params.daq.normalization.all_channels.bottom) / (self.params.daq.normalization.all_channels.top - self.params.daq.normalization.all_channels.bottom)
        if self.params.daq.normalization.all_channels.invert:
            norm_rate = 1 - norm_rate
        return norm_rate

    def _calculate_individual_norm_rate(self, buf_items):
        """Calculate the normalized firing rate, as determined by the feedback settings."""
        if self.params.daq.normalization.clamp_firing_rates:
            clamped_rates = np.maximum(np.minimum(buf_items[:, self.norm_rate['ch_ids']], self.norm_rate['tops']), self.norm_rate['bottoms'])
        else:
            clamped_rates = buf_items[:, self.norm_rate['ch_ids']]
        denom = self.norm_rate['tops'] - self.norm_rate['bottoms']
        if np.all(denom == 0):
            denom[:] = 1  # avoid division by zero if the range has not been estimated yet
        norm_rates = (clamped_rates - self.norm_rate['bottoms']) / denom
        norm_rates[:, self.norm_rate['invs']] = 1 - norm_rates[:, self.norm_rate['invs']]
        norm_rate = np.nanmean(norm_rates, axis=1)
        if not self.params.daq.normalization.clamp_firing_rates:
            norm_rate = np.maximum(norm_rate, 0.0)
        return norm_rate

    def calculate_norm_rate(self, buf_item):
        if buf_item.ndim == 1:
            buf_item.shape = (1, buf_item.shape[0])
        if self.params.daq.normalization.use_all_channels:
            return self._calculate_all_norm_rate(buf_item)
        else:
            return self._calculate_individual_norm_rate(buf_item)
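
# A minimal usage sketch (hypothetical values; `params` comes from the YAML config via
# aux.load_config() and `rates` is an (n_samples, n_channels) ndarray, neither defined
# here). With bottoms=[0, 10], tops=[100, 50] and no inversion, a sample [50., 30.]
# maps per channel to (50-0)/(100-0)=0.5 and (30-10)/(50-10)=0.5, and the channel
# mean 0.5 is returned:
#
#     normalizer = DataNormalizer(params)
#     normalizer.update_norm_range(data=rates)             # extend buffer, maybe re-estimate ranges
#     frate = normalizer.calculate_norm_rate(rates[-1:])   # ~[0, 1] normalized rate per sample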


class Data(multiprocessing.Process):
    def __init__(self, data_buffer, params, recording_status, decoder_decision, child_conn, global_buffer_idx, class_prob, block_phase, normalized_frate):
        super(Data, self).__init__()
        self.buffer = data_buffer
        self.global_buffer_idx = global_buffer_idx
        self.buffer_idx_total = 0
        self.bl = []
        self.class_prob = class_prob
        self.time_stamps = np.zeros(data_buffer.shape[0], dtype=np.uint64)
        # self.load_config()
        self.params = aux.load_config()
        self.buffer[:] = 0
        self.events = []
        self.recording_status = recording_status
        self.decoder_decision = decoder_decision
        self.block_phase = block_phase
        self.child_conn = child_conn
        self.clf = clf.Classifier(params, block_phase=block_phase)
        self.rate_fct = 10  # factor used to simulate rates
        # initialize normalized-rate calculation
        self.normalized_frate = normalized_frate
        try:
            history_data, _, _ = dm.get_session(self.params.file_handling.filename_history)
            log.info(f'history data shape: {history_data.shape}')
        except FileNotFoundError:
            log.warning('No history data file detected!')
            history_data = None
        self.normalizer = DataNormalizer(params, initial_data=history_data)

    def run(self):
        log.info('Started data as process')
        # self.load_config()
        # self.reset_buffer()
        session_id = 7
        # data_tot = io.loadmat('/kiap/data/tom/model/trainData_rates2.mat')['trainData']
        # self.cur_data = data_tot[session_id:session_id + 1]
        plt.figure(2)
        self.buffer_idx = 0
        self.buffer_idx_prev = 0  # keep track of the previous index
        self.init_conn()
        self.read_raw()  # put reading in standby mode
        log.debug('standby mode')
        return

    # def set_recording_status(self, flag):
    #     self.recording_status.value = 0
    #     return None

    def init_conn(self):
        self.ck = cc.CereConn(withSRE=self.params.daq.data_source == 'spike_rates', withSBPE=self.params.daq.data_source == 'band_power')
        self.ck.send_open()
        # wait until the connection is established
        t = time.time()
        while self.ck.get_state() != cc.ccS_Idle:
            time.sleep(0.005)
        print("It took {:5.3f}s to open CereConn\n".format(time.time() - t))
        # get only unit 0; caution: switch off spike sorting in Central
        self.ck.set_spike_rate_estimator_loop_interval_ms(self.params.daq.spike_rates.loop_interval)
        self.ck.set_spike_band_power_estimator_loop_interval_ms(self.params.daq.spike_band_power.loop_interval)
        self.ck.set_spike_band_power_estimator_integrated_samples(self.params.daq.spike_band_power.integrated_samples)
        self.ck.set_spike_band_power_estimator_avg_bins(self.params.daq.spike_band_power.average_n_bins)
        # set spike band power filter coefficients
        if self.params.daq.spike_band_power.filter:
            self.ck.set_spike_band_power_estimator_filter_coefficients(np.asarray(self.params.daq.spike_band_power.filter.b), np.asarray(self.params.daq.spike_band_power.filter.a))
        self.ck.set_spike_band_power_estimator_use_sample_group(self.params.daq.spike_band_power.sample_group)
        # self.ck.fill_spike_rate_estimator_ch_u_slist(self.params.daq.n_channels, self.params.daq.spike_rates.n_units)
        ch_u_list = [(ii, 0) for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        ch_list = [ii for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        self.ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
        self.ck.set_spike_band_power_estimator_ch_list(ch_list)
        if self.params.daq.spike_rates.method == 'exponential':
            self.ck.set_spike_rate_estimation_method_exponential(self.params.daq.spike_rates.decay_factor, self.params.daq.spike_rates.max_bins)
        else:
            self.ck.set_spike_rate_estimation_method_boxcar(self.params.daq.spike_rates.max_bins)
        self.ch_rec_list = list(zip(*ch_u_list))[0]
        # set CAR, both for LFP (1 kHz data) and for raw data (used for SBP)
        if self.params.daq.car_channels:
            self.ck.set_car_channels(2, self.params.daq.car_channels)
            self.ck.set_car_channels(6, self.params.daq.car_channels)
        self.update_ch_map()
        log.warning(ch_u_list)
        # self.ck.send_record()
        # time.sleep(0.1)
        # self.ck.send_idle()
        # self.ck.get_spike_rate_data()
        # self.ck.get_comment_data()
        time.sleep(0.5)
        self.gids = [5]
        self.bin_width = 0.005
        self.run_time = 1
        # self.ck.send_record()
        return None

    # TODO: This function duplicates code?!
    def update_ch_map(self):
        ch_u_list = [(ii, 0) for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        self.ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
        ch_list = [ii for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        self.ck.set_spike_band_power_estimator_ch_list(ch_list)
        ch_map = self.ck.get_spike_rate_estimator_ch_u_map()
        log.debug(ch_map)
        log.info(f"# of channels in ch_map: {len(ch_map['list'])}")
        log.warning('Only unit 0 will be returned. Check spike-sorting status in Central.')
        return None

    def get_raw(self):
        cd = self.ck.get_cont_data(sample_groups=self.gids)
        for gid in self.gids:
            data = cd['sample_groups'][gid]['data']
        time_stamp2 = cd['ts']  # int type
        # return data[:3000, :], time_stamp2
        # note: no per-sample time stamps are available for continuous data,
        # so no duplicate-time-stamp check here (unlike get_rates / get_sb_power)
        return data, time_stamp2

    def get_rates(self):
        data = self.ck.get_spike_rate_data()
        ts = data['ts']
        rates = data['rates']
        # rates = rates + np.random.randn(rates.shape[0], rates.shape[1])*1  # noise only for EMG !!!
        # log.warning(f'ts: {ts}, rates: {rates.shape}')
        time_stamps = data['rate_ts']
        if np.any(np.diff(time_stamps) == 0):
            log.critical(time_stamps)
            log.critical('Identical time stamps')
        return rates, ts, time_stamps

    def get_sb_power(self):
        data = self.ck.get_spike_band_power_data()
        ts = data['ts']
        sbp = data['sbp']
        # sbp = sbp + np.random.randn(sbp.shape[0], sbp.shape[1])*1  # noise only for EMG !!!
        # log.warning(f'ts: {ts}, sbp: {sbp.shape}')
        time_stamps = data['rate_ts']
        if np.any(np.diff(time_stamps) == 0):
            log.critical(time_stamps)
            log.critical('Identical time stamps')
        return sbp, ts, time_stamps

    def get_rates_sim(self):
        data = self.ck.get_spike_rate_data()
        ts = data['ts']
        rates = data['rates']
        rates = np.random.randn(rates.shape[0], rates.shape[1]) * 5 + self.rate_fct  # simulate data
        # rates = rates*0
        # log.info('sim data: started')
        # aa = datetime.datetime.now()
        # ts = int(str(aa.timestamp()).replace('.',''))
        # log.info(f'sim data: {ts}')
        # t0 = int(f'{aa.minute}{aa.second}{aa.microsecond}')
        # timestamps = [t0, t0-1]
        # log.info(f'sim data: {t0}, {ts}, {timestamps}')
        # rates = np.random.randn(100, self.params.daq.n_channels) + self.rate_fct  # simulate data
        log.debug(f'ts: {ts}, rates: {rates.shape}, rate_fct: {self.rate_fct}')
        time_stamps = data['rate_ts']
        if np.any(np.diff(time_stamps) == 0):
            log.critical(time_stamps)
            log.critical('Identical time stamps')
        return rates, ts, time_stamps

    def correct_bl(self, raw):
        '''Correct for baseline changes; baseline: first t_baseline_1 sec in each block.'''
        bl_idx = int(self.params.recording.timing.t_baseline_1 / self.params.daq.spike_rates.loop_interval * 1000.)
        # log.warning(f'buffer_idx: {bl_idx}, {self.buffer_idx_total}, {self.buffer_idx}')
        if len(self.bl) == 0 and self.buffer_idx_total >= bl_idx:  # save baseline data only once, before trial 1
            bl = np.max(self.buffer[:bl_idx, :], axis=0) - np.min(self.buffer[:bl_idx, :], axis=0)
            log.debug(f'raw and baseline shape: {raw.shape}, {bl.shape}, idx: {self.buffer_idx}')
            self.bl = np.copy(bl)
            # self.bl = np.delete(self.bl, self.params.daq.exclude_channels)
            np.save(self.params.file_handling.filename_baseline, self.bl)
        if self.buffer_idx_total >= bl_idx:
            raw = (raw - self.bl) / (self.bl + self.params.daq.spike_rates.bl_offset)
        return raw
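
    # Worked example for correct_bl (illustrative numbers, not from the config): with
    # t_baseline_1 = 10 s and a 50 ms loop interval, bl_idx = int(10 / 50 * 1000) = 200
    # samples. If a channel ranges between 10 and 30 Hz over those samples, bl = 20,
    # and with bl_offset = 1 a raw rate of 25 Hz maps to (25 - 20) / (20 + 1) ~= 0.24,
    # i.e. rates are rescaled to relative deviations from the baseline peak-to-peak range.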

    def get_comments(self, store_comments=True):
        '''Get comments from NSP.'''
        # log.info('Reading comments from NSP...')
        comments = self.ck.get_comment_data()
        if not store_comments:  # just read comments to empty the buffer
            return None
        rates = [40, 20]
        rates_exploration = [60, 30, 20]
        if len(comments['comments']) > 0:
            for ii in range(len(comments['comments'])):
                self.events.append((comments['comments'][ii]['ts'], comments['comments'][ii]['text']))
                log.warning(comments['comments'][ii])
                comment = str(comments['comments'][ii]['text'])
                # print(comments['comments'][ii])
                # print(type(comments['comments'][ii]))
                # if str(comments['comments'][ii]['text']) == 'question, yes, response, start':
                if 'question, Training, yes, response, start' in comment or \
                        'training_color, Training, yes, response, start' in comment:
                    self.rate_fct = rates[0]
                elif 'question, Training, no, response, start' in comment or \
                        'training_color, Training, no, response, start' in comment:
                    self.rate_fct = rates[1]
                elif ('Validation' in comment and 'response, start' in comment) or \
                        ('color, Free' in comment and 'response, start' in comment):
                    self.rate_fct = rates[random.randint(0, 1)]
                elif 'ruhe' in comment and 'response, start' in comment:
                    self.rate_fct = rates_exploration[0]
                elif 'kopf' in comment and 'response, start' in comment:
                    self.rate_fct = rates_exploration[1]
                elif 'fuss' in comment and 'response, start' in comment:
                    self.rate_fct = rates_exploration[2]
                else:
                    self.rate_fct = self.params.sim_data.rate_bl
        return None
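
    # Example of the comment protocol handled above (strings as sent by the task
    # software): an NSP comment 'question, Training, yes, response, start' sets
    # rate_fct = 40 and 'question, Training, no, response, start' sets rate_fct = 20,
    # so get_rates_sim() can emulate distinguishable "yes"/"no" rate levels in simulation.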

    def send_triggers(self, send_triggers=True):
        '''Forward triggers from the bci process to the NSP.'''
        while self.child_conn.poll():  # bci process sent a trigger
            # comment = '{:_<{}}'.format(self.child_conn.recv(), self.params.daq.trigger_len)  # use padding
            comment = self.child_conn.recv()
            cb_time = self.ck.get_cb_time()['ts']
            self.ck.send_comment(comment)
            log.debug(f'Comment "{comment}" at {cb_time} forwarded to NSP')
        return None

    # def get_raw2(self, jj):
    #     return self.clf.get_class(self.cur_data[0, 0][jj - 600:jj], jj)

    def get_msec(self, time_now):
        '''Get the current millisecond; used only for plotting.'''
        tbin = int(time_now.microsecond / 1000.)
        # log.debug(tbin)
        return tbin

    def read_raw(self):
        jj = 600
        plt.clf()
        plt.ylim(-0.2, 1.2)
        plt.xlim(0, 9000)
        # win_past = len(self.params.classifier.template)*2  # number of samples to include from the past
        win_past = self.params.classifier.template.max()  # number of samples to include from the past
        # log.warning(f'win_past: {win_past}')
        while 1:
            while self.recording_status.value:
                # log.error(self.buffer_idx)
                if self.ck.get_state() == cc.ccS_Idle:
                    self.params = aux.load_config()  # update parameters in the data instance
                    try:
                        self.clf.set_params(self.params)  # update parameters in the classifier instance
                    except Exception as e:
                        log.error(e)
                        log.error('block stopped')
                        self.recording_status.value = 0
                    self.update_ch_map()  # add or remove channels to record from
                    self.ck.send_record()
                    self.get_comments(store_comments=False)  # get any comments that may still be in the buffer
                time.sleep(.05)
                self.send_triggers()
                self.get_comments()
                if self.params.daq.data_source == 'spike_rates':
                    raw, time_stamp2, time_stamps = self.get_rates()  # spike rates, received data
                else:
                    raw, time_stamp2, time_stamps = self.get_sb_power()  # spike band power, received data
                # raw, time_stamp2, time_stamps = self.get_rates_sim()  # spike rates, simulated data
                if self.params.daq.spike_rates.correct_bl:  # correct baseline
                    raw = self.correct_bl(raw)
                self.write_buffer(raw, time_stamp2, time_stamps)
                # log.warning(f'AFTER: {time_stamp2}, {raw.shape}, {self.buffer_idx}, {time_stamp2 - self.buffer_idx}')
                if raw.shape[0] > 0:
                    log.debug(f'raw shape: {raw.shape}, buffer idx: {self.buffer_idx}, {raw[0, :10]}')
                else:
                    log.debug(f'raw shape: {raw.shape}, buffer idx: {self.buffer_idx}')
                time_now = datetime.datetime.now()
                self.normalizer.update_norm_range(data=raw)
                self.normalized_frate.value = self.normalizer.calculate_norm_rate(self.buffer[self.global_buffer_idx.value - 1, :])
                if self.params.classifier.online:
                    # if self.buffer_idx - win_past < 0:  # get the last elements
                    #     tmp_buffer = np.roll(self.buffer, -(self.buffer_idx - win_past), axis=0)[:win_past, :]
                    if self.buffer_idx - win_past >= 0:
                        self.clf.init_buffer = 0
                        tmp_buffer = np.take(self.buffer, range(self.buffer_idx - win_past, self.buffer_idx), axis=0)
                        # log.error(f'indices: {range(self.buffer_idx - win_past, self.buffer_idx)}')
                        # t1 = time.time()
                        self.clf.get_class2(tmp_buffer, self.get_msec(time_now), self.decoder_decision)  # CLASSIFY RESPONSE
                        # log.error(time.time() - t1)
                        # log.warning(f'data process: decision: {self.decoder_decision.value}')
                        # log.warning(f'raw: {raw.shape}, class prob: {self.class_prob}')
                        # self.clf.get_class2(self.buffer[self.buffer_idx - win_past:self.buffer_idx, :], self.get_msec(time_now), self.decoder_decision)  # data from NSP
                        self.class_prob[self.buffer_idx, :] = self.clf.online_sig
                        if self.buffer_idx - self.buffer_idx_prev > 1:
                            log.warning(f'data: buffer idx: {self.buffer_idx}, {self.buffer_idx_prev}')
                            # log.warning(self.class_prob[:self.buffer_idx, :])
                        self.buffer_idx_prev = self.buffer_idx
                        # log.warning(self.class_prob[:self.buffer_idx, :])
                else:
                    self.class_prob[:] = [0] * self.params.classifier.n_classes
                    # self.decoder_decision.value = -1
                # log.debug('elapsed time: {}'.format(time.time() - tic))
                time.sleep(self.params.recording.timing.recording_loop_interval - 0.01)
            else:
                time.sleep(1)
                if self.ck.get_state() == cc.ccS_Recording:
                    log.warning('reading final comments from NSP')
                    # time.sleep(1)
                    self.send_triggers(send_triggers=True)
                    time.sleep(0.2)
                    self.get_comments()  # get last comments to clear the buffer
                    # time.sleep(0.5)
                    # self.send_triggers(send_triggers=False)
                    if self.buffer_idx > 0:
                        self.write_file()
                    time.sleep(self.params.recording.timing.recording_loop_interval_data)
                    if self.ck.get_state() == cc.ccS_Recording:
                        self.ck.send_idle()
                    self.bl = []  # used to save baseline data only once, before trial 1
                    self.buffer_idx_total = 0
        return None

    def buffer_to_array(self, data_trial):
        '''Copy single-trial continuous data to a numpy array.

        Parameters
        ----------
        data_trial : list
            continuous data as provided by cbpy, [[ch_id, ndarray]], e.g. raw = cp.trial_data(), data_trial = raw[2]

        Returns
        -------
        res : ndarray, shape (n_channels, n_samples)
        '''
        res = np.zeros((len(data_trial), data_trial[0][1].size), dtype=np.float32)
        for ii, val in enumerate(data_trial):
            res[ii] = val[1] / 4.
        return res
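
    # Hypothetical example for buffer_to_array: cbpy delivers continuous data as
    # [[ch_id, samples], ...]; two channels of 3000 samples each become a (2, 3000)
    # float32 array. The division by 4 presumably converts raw ADC counts to
    # microvolts (0.25 uV per bit on the NSP front end):
    #
    #     data_trial = [[1, np.arange(3000)], [2, np.arange(3000)]]
    #     arr = data.buffer_to_array(data_trial)   # `data` is a Data instance; arr.shape == (2, 3000)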

    def get_params(self):
        return self.params

    def read_buffer(self):
        """Return the buffer object."""
        return self.buffer

    def write_buffer(self, data, time_stamp2, time_stamps):
        idx = self.buffer_idx  # index to keep track of how full the buffer is
        # diff = self.buffer.shape[0] - idx
        # data = data[:diff, :]
        if data.shape == (0, 0):
            pass
            # log.warning('Buffer is empty. Not writing to file')
        else:
            try:
                if idx + data.shape[0] >= self.buffer.shape[0]:  # if new data would overflow the buffer, first write to file
                    self.write_file()
                    idx = self.buffer_idx
                    log.warning(f'writing buffer to file to avoid overflow: {idx} {data.shape}')
                    # log.error(data.shape)
                self.buffer[idx:idx + data.shape[0], :data.shape[1]] = data
                self.time_stamps[idx:idx + data.shape[0]] = time_stamps
                self.buffer_idx = idx + data.shape[0]
                self.buffer_idx_total += data.shape[0]
                self.global_buffer_idx.value = np.copy(self.buffer_idx)
                self.time_stamp2 = time_stamp2
            except ValueError as e:
                log.error(e)
                log.error('If exclude_channels changed, restart kiap_bci !')
        # if np.mod(self.buffer_idx, data.shape[0]*5) == 0:
        #     log.debug('write buffer, idx: {}'.format(self.buffer_idx))
        # if self.buffer_idx == self.buffer.shape[0] and self.params.file_handling.save_data:  # when buffer is full write to disk
        #     self.write_file()
        return None

    def reset_buffer(self):
        """Reset the data buffer index."""
        # self.buffer[:] = 0  # verify that this is correct
        self.buffer_idx = 0
        return None

    def get_buffer_idx(self):
        return self.buffer_idx

    def write_file(self):
        if not self.params.file_handling.save_data:  # don't save anything
            self.reset_buffer()
            return None
        else:
            log.info(f'Files will be written in mode: {self.params.file_handling.mode}')
            with open(self.params.file_handling.filename_data, self.params.file_handling.mode) as fh:  # data file
                # fh.write(self.buffer[:self.buffer_idx, :].tobytes())
                db_bytes = self.buffer[:self.buffer_idx, :].tobytes()
                n_bytes = np.int64(len(db_bytes))
                n_samples = np.int64(self.buffer[:self.buffer_idx, :].shape[0])
                n_ch = np.int64(self.buffer[:self.buffer_idx, :].shape[1])
                t_now1 = np.array(dt.now(), dtype='datetime64[us]')
                # t_now1 = np.array(self.time_stamp1, dtype='datetime64[us]')
                t_now2 = np.int64(self.time_stamp2).tobytes()  # NSP time stamp
                fh.write(t_now1.tobytes())  # one from PC
                fh.write(t_now2)  # one from NSP
                # fh.write(np.array(self.recording_type.value, dtype=np.int8).tobytes())  # 1 byte
                fh.write(n_bytes)
                fh.write(n_samples)
                fh.write(n_ch)
                ch_rec_list = np.int16(self.ch_rec_list).tobytes()
                fh.write(np.int16(len(ch_rec_list)))
                fh.write(ch_rec_list)
                fh.write(db_bytes)
                fh.write(self.time_stamps[:self.buffer_idx].tobytes())
            with open(self.params.file_handling.filename_history, self.params.file_handling.mode) as fh:  # history file
                db_bytes = self.buffer[:self.buffer_idx, :].tobytes()
                n_bytes = np.int64(len(db_bytes))
                n_samples = np.int64(self.buffer[:self.buffer_idx, :].shape[0])
                n_ch = np.int64(self.buffer[:self.buffer_idx, :].shape[1])
                t_now1 = np.array(dt.now(), dtype='datetime64[us]')
                # t_now1 = np.array(self.time_stamp1, dtype='datetime64[us]')
                t_now2 = np.int64(self.time_stamp2).tobytes()  # NSP time stamp
                fh.write(t_now1.tobytes())  # one from PC
                fh.write(t_now2)  # one from NSP
                # fh.write(np.array(self.recording_type.value, dtype=np.int8).tobytes())  # 1 byte
                fh.write(n_bytes)
                fh.write(n_samples)
                fh.write(n_ch)
                ch_rec_list = np.int16(self.ch_rec_list).tobytes()
                fh.write(np.int16(len(ch_rec_list)))
                fh.write(ch_rec_list)
                fh.write(db_bytes)
                fh.write(self.time_stamps[:self.buffer_idx].tobytes())
            with open(self.params.file_handling.filename_events, self.params.file_handling.mode[0]) as fh:  # event file
                # write events
                cnt = sum([(8 + len(ev[1])) for ev in self.events])
                # log.info(self.events)
                # fh.write(np.int8(cnt))
                for ev in self.events:
                    log.info(f'EVENTS: {ev}, cnt: {cnt}')
                    # log.info(f'8, {len(ev[1])}')
                    # fh.write(np.int64(ev[0]).tobytes())
                    fh.write(f'{ev[0]}, {ev[1]}\n')
                    # fh.write(ev[1])
                self.events = []
            # self.recording_type.value = 'DATA'
            # log.info(self.recording_type.value)
            log.info('write buffer, idx: {} shape: {}, ts: {}'.format(self.buffer_idx, self.buffer.shape, t_now1))
            log.info(f'write buffer: ts: {self.time_stamp2} bytes: {n_bytes} samples: {n_samples} ch: {n_ch} {self.buffer[:10, 0]}')
            log.info(f'write buffer, ts: {self.time_stamps[self.buffer_idx - 3:self.buffer_idx]}')
            if self.buffer_idx < self.params.buffer.shape[0]:
                log.warning('Writing buffer to file, but buffer is not full')
            self.reset_buffer()
            # if self.child_conn.poll():
            #     self.child_conn.recv()
            #     self.child_conn.send(t_now1)  # send signal to unlock flow in bci
            #     log.info("Buffer written. Sent ack to bci")
        return None
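
    # A minimal sketch of reading one record back from the data/history files written
    # above (field order inferred from write_file; the buffer dtype is assumed to be
    # float64 here -- adjust to the actual dtype of the shared data_buffer):
    #
    #     with open(filename_data, 'rb') as fh:
    #         t_pc = np.frombuffer(fh.read(8), dtype='datetime64[us]')[0]   # PC time stamp
    #         t_nsp = np.frombuffer(fh.read(8), dtype=np.int64)[0]          # NSP time stamp
    #         n_bytes, n_samples, n_ch = np.frombuffer(fh.read(24), dtype=np.int64)
    #         n_list_bytes = np.frombuffer(fh.read(2), dtype=np.int16)[0]   # bytes in channel list
    #         ch_rec_list = np.frombuffer(fh.read(int(n_list_bytes)), dtype=np.int16)
    #         data = np.frombuffer(fh.read(int(n_bytes)), dtype=np.float64).reshape(int(n_samples), int(n_ch))
    #         ts = np.frombuffer(fh.read(int(n_samples) * 8), dtype=np.uint64)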

    # def exit(self):
    #     self.fh1.close()
    #     fh2.close()