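"""Data acquisition process for the kiap BCI.

Reads spike rates or spike band power from the NSP via cere_conn, keeps a
running normalization of firing rates (DataNormalizer), feeds the online
classifier, forwards triggers and comments, and writes data, history and
event files.
"""
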
import datetime
import glob
import logging
import multiprocessing
import os
import pathlib
import pickle
import random
import time
from datetime import datetime as dt

import cere_conn as cc
import matplotlib.pyplot as plt
import munch
import numpy as np
from scipy import io
import yaml

import aux
from aux import log
from modules import classifier as clf
from helpers import data_management as dm
from modules.ringbuffer import RingBuffer


class DataNormalizer:
    def __init__(self, params, initial_data=None):
        self.params = params
        self.norm_rate = {}
        self.norm_rate['ch_ids'] = [ch.id for ch in self.params.daq.normalization.channels]
        self.norm_rate['bottoms'] = np.asarray([ch.bottom for ch in self.params.daq.normalization.channels])
        self.norm_rate['tops'] = np.asarray([ch.top for ch in self.params.daq.normalization.channels])
        self.norm_rate['invs'] = [ch.invert for ch in self.params.daq.normalization.channels]

        n_norm_buffer = int(self.params.daq.normalization.len * (1000.0 / self.params.daq.spike_rates.loop_interval))
        self.norm_buffer = RingBuffer(n_norm_buffer, dtype=(float, self.params.daq.n_channels), allow_overwrite=True)
        self.last_update = time.time()
        if initial_data is not None:
            self.norm_buffer.extend(initial_data)

    def _update_norm_range(self):
        buf_vals = self.norm_buffer[:, self.norm_rate['ch_ids']]
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        self.norm_rate['bottoms'] = centiles[0, :]
        self.norm_rate['tops'] = centiles[1, :]
        log.info(f"Updated normalization ranges for channels {self.norm_rate['ch_ids']} to bottoms: {self.norm_rate['bottoms']}, tops: {self.norm_rate['tops']}")

    def _update_norm_range_all(self):
        buf_vals = np.mean(self.norm_buffer, axis=1)
        centiles = np.percentile(buf_vals, self.params.daq.normalization.range, axis=0)
        # log.info(f"Centiles: {centiles}")
        self.params.daq.normalization.all_channels.bottom = centiles[0]
        self.params.daq.normalization.all_channels.top = centiles[1]
        log.info(f"Updated normalization range for all channels to [{self.params.daq.normalization.all_channels.bottom}, {self.params.daq.normalization.all_channels.top}]")

    def update_norm_range(self, data=None, force=False):
        if data is not None and data.size > 0:
            self.norm_buffer.extend(data)
        if force or (self.params.daq.normalization.do_update and
                     (time.time() - self.last_update >= self.params.daq.normalization.update_interval)):
            if self.params.daq.normalization.use_all_channels:
                self._update_norm_range_all()
            else:
                self._update_norm_range()
            self.last_update = time.time()
            log.info(f"New channel normalization setting: {yaml.dump(self._format_current_config(), sort_keys=False, default_flow_style=None)}")

    def _format_current_config(self):
        if self.params.daq.normalization.use_all_channels:
            out_dict = {'all_channels': {'bottom': float(self.params.daq.normalization.all_channels.bottom),
                                         'top': float(self.params.daq.normalization.all_channels.top),
                                         'invert': bool(self.params.daq.normalization.all_channels.invert)}}
        else:
            out_dict = {'channels': []}
            for ii in range(len(self.norm_rate['ch_ids'])):
                out_dict['channels'].append({'id': int(self.norm_rate['ch_ids'][ii]),
                                             'bottom': float(self.norm_rate['bottoms'][ii]),
                                             'top': float(self.norm_rate['tops'][ii]),
                                             'invert': self.norm_rate['invs'][ii]})
        return out_dict
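
    # Example of the structure returned by _format_current_config (values
    # illustrative), as it appears in the yaml.dump log message above:
    #     channels:
    #     - {id: 1, bottom: 2.5, top: 30.0, invert: false}
    #     - {id: 4, bottom: 1.0, top: 22.5, invert: true}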

    def _calculate_all_norm_rate(self, buf_item):
        avg_r = np.mean(buf_item, axis=1)
        if self.params.daq.normalization.clamp_firing_rates:
            avg_r = np.clip(avg_r, self.params.daq.normalization.all_channels.bottom, self.params.daq.normalization.all_channels.top)
        norm_rate = (avg_r - self.params.daq.normalization.all_channels.bottom) / (self.params.daq.normalization.all_channels.top - self.params.daq.normalization.all_channels.bottom)
        if self.params.daq.normalization.all_channels.invert:
            norm_rate = 1 - norm_rate
        return norm_rate

    def _calculate_individual_norm_rate(self, buf_items):
        """Calculate the normalized firing rate, as determined by the feedback settings."""
        if self.params.daq.normalization.clamp_firing_rates:
            clamped_rates = np.clip(buf_items[:, self.norm_rate['ch_ids']], self.norm_rate['bottoms'], self.norm_rate['tops'])
        else:
            clamped_rates = buf_items[:, self.norm_rate['ch_ids']]
        denom = self.norm_rate['tops'] - self.norm_rate['bottoms']
        denom[denom == 0] = 1  # avoid division by zero for channels with a degenerate range
        norm_rates = (clamped_rates - self.norm_rate['bottoms']) / denom
        norm_rates[:, self.norm_rate['invs']] = 1 - norm_rates[:, self.norm_rate['invs']]
        norm_rate = np.nanmean(norm_rates, axis=1)
        if not self.params.daq.normalization.clamp_firing_rates:
            norm_rate = np.maximum(norm_rate, 0.0)
        return norm_rate

    def calculate_norm_rate(self, buf_item):
        if buf_item.ndim == 1:
            buf_item = buf_item.reshape(1, -1)
        if self.params.daq.normalization.use_all_channels:
            return self._calculate_all_norm_rate(buf_item)
        else:
            return self._calculate_individual_norm_rate(buf_item)
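
# Illustrative usage of DataNormalizer (assumes a loaded `params` object with
# the daq.normalization fields accessed above; variable names hypothetical):
#
#     normalizer = DataNormalizer(params, initial_data=history_rates)
#     normalizer.update_norm_range(data=new_rates)       # extend buffer, refresh range if due
#     nrate = normalizer.calculate_norm_rate(new_rates)  # one value in [0, 1] per sample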


class Data(multiprocessing.Process):
    def __init__(self, data_buffer, params, recording_status, decoder_decision, child_conn, global_buffer_idx, class_prob, block_phase, normalized_frate):
        super().__init__()
        self.buffer = data_buffer
        self.global_buffer_idx = global_buffer_idx
        self.buffer_idx_total = 0
        self.bl = []
        self.class_prob = class_prob
        self.time_stamps = np.zeros(data_buffer.shape[0], dtype=np.uint64)
        self.params = aux.load_config()
        self.buffer[:] = 0
        self.events = []
        self.recording_status = recording_status
        self.decoder_decision = decoder_decision
        self.block_phase = block_phase
        self.child_conn = child_conn
        self.clf = clf.Classifier(params, block_phase=block_phase)
        self.rate_fct = 10  # factor used to simulate rates

        # initialize normalized-rate calculation
        self.normalized_frate = normalized_frate
        try:
            history_data, _, _ = dm.get_session(self.params.file_handling.filename_history)
            log.info(f'history data shape: {history_data.shape}')
        except FileNotFoundError:
            log.warning('No history data file detected!')
            history_data = None
        self.normalizer = DataNormalizer(params, initial_data=history_data)

    def run(self):
        log.info('Started data as process')
        plt.figure(2)
        self.buffer_idx = 0
        self.buffer_idx_prev = 0  # keep track of the previous idx
        self.init_conn()
        self.read_raw()  # put reading into standby mode
        log.debug('standby mode')
        return

    def init_conn(self):
        self.ck = cc.CereConn(withSRE=self.params.daq.data_source == 'spike_rates',
                              withSBPE=self.params.daq.data_source == 'band_power')
        self.ck.send_open()

        # wait until the connection is established
        t = time.time()
        while self.ck.get_state() != cc.ccS_Idle:
            time.sleep(0.005)
        print("It took {:5.3f}s to open CereConn\n".format(time.time() - t))

        # get only unit 0; caution: switch off spike sorting in Central
        self.ck.set_spike_rate_estimator_loop_interval_ms(self.params.daq.spike_rates.loop_interval)
        self.ck.set_spike_band_power_estimator_loop_interval_ms(self.params.daq.spike_band_power.loop_interval)
        self.ck.set_spike_band_power_estimator_integrated_samples(self.params.daq.spike_band_power.integrated_samples)
        self.ck.set_spike_band_power_estimator_avg_bins(self.params.daq.spike_band_power.average_n_bins)

        # set spike band power filter coefficients
        if self.params.daq.spike_band_power.filter:
            self.ck.set_spike_band_power_estimator_filter_coefficients(np.asarray(self.params.daq.spike_band_power.filter.b),
                                                                       np.asarray(self.params.daq.spike_band_power.filter.a))
        self.ck.set_spike_band_power_estimator_use_sample_group(self.params.daq.spike_band_power.sample_group)

        ch_u_list = [(ii, 0) for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        ch_list = [ii for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        self.ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
        self.ck.set_spike_band_power_estimator_ch_list(ch_list)

        if self.params.daq.spike_rates.method == 'exponential':
            self.ck.set_spike_rate_estimation_method_exponential(self.params.daq.spike_rates.decay_factor, self.params.daq.spike_rates.max_bins)
        else:
            self.ck.set_spike_rate_estimation_method_boxcar(self.params.daq.spike_rates.max_bins)
        self.ch_rec_list = list(zip(*ch_u_list))[0]

        # set CAR, both for LFP (1 kHz data) and for raw data (used for SBP)
        if self.params.daq.car_channels:
            self.ck.set_car_channels(2, self.params.daq.car_channels)
            self.ck.set_car_channels(6, self.params.daq.car_channels)
        self.update_ch_map()
        log.warning(ch_u_list)

        time.sleep(0.5)
        self.gids = [5]
        self.bin_width = 0.005
        self.run_time = 1
        return None

    # TODO: this duplicates the channel-list code in init_conn
    def update_ch_map(self):
        ch_u_list = [(ii, 0) for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        self.ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
        ch_list = [ii for ii in range(1, self.params.daq.n_channels_max + 1) if ii not in self.params.daq.exclude_channels]
        self.ck.set_spike_band_power_estimator_ch_list(ch_list)
        ch_map = self.ck.get_spike_rate_estimator_ch_u_map()
        log.debug(ch_map)
        log.info(f"# of channels in ch_map: {len(ch_map['list'])}")
        log.warning('Only unit 0 will be returned. Check spike-sorting status in Central.')
        return None

    def get_raw(self):
        cd = self.ck.get_cont_data(sample_groups=self.gids)
        for gid in self.gids:
            data = cd['sample_groups'][gid]['data']
        # note: cd['ts'] is a single NSP time stamp (int); per-bin duplicate
        # time stamp checks are done in get_rates/get_sb_power instead
        time_stamp2 = cd['ts']
        return data, time_stamp2

    def get_rates(self):
        data = self.ck.get_spike_rate_data()
        ts = data['ts']
        rates = data['rates']
        # rates = rates + np.random.randn(rates.shape[0], rates.shape[1]) * 1  # noise, only for EMG !!!
        time_stamps = data['rate_ts']
        if np.any(np.diff(time_stamps) == 0):
            log.critical(time_stamps)
            log.critical('Identical time stamps')
        return rates, ts, time_stamps
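
    # Shapes below are inferred from how these arrays are used in read_raw
    # (illustrative, not taken from cere_conn documentation):
    #     'ts'            -> NSP time stamp of this read
    #     'rates' / 'sbp' -> ndarray (n_new_bins, n_channels)
    #     'rate_ts'       -> ndarray (n_new_bins,), one time stamp per bin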

    def get_sb_power(self):
        data = self.ck.get_spike_band_power_data()
        ts = data['ts']
        sbp = data['sbp']
        time_stamps = data['rate_ts']
        if np.any(np.diff(time_stamps) == 0):
            log.critical(time_stamps)
            log.critical('Identical time stamps')
        return sbp, ts, time_stamps

    def get_rates_sim(self):
        data = self.ck.get_spike_rate_data()
        ts = data['ts']
        rates = data['rates']
        rates = np.random.randn(rates.shape[0], rates.shape[1]) * 5 + self.rate_fct  # simulate data
        log.debug(f'ts: {ts}, rates: {rates.shape}, rate_fct: {self.rate_fct}')
        time_stamps = data['rate_ts']
        if np.any(np.diff(time_stamps) == 0):
            log.critical(time_stamps)
            log.critical('Identical time stamps')
        return rates, ts, time_stamps

    def correct_bl(self, raw):
        '''Correct for baseline changes; baseline: first t_baseline_1 sec in each block.'''
        bl_idx = int(self.params.recording.timing.t_baseline_1 / self.params.daq.spike_rates.loop_interval * 1000.)
        if isinstance(self.bl, list) and self.buffer_idx_total >= bl_idx:  # save baseline data only once, before trial 1
            bl = np.max(self.buffer[:bl_idx, :], axis=0) - np.min(self.buffer[:bl_idx, :], axis=0)
            log.debug(f'raw and baseline shape: {raw.shape}, {bl.shape}, idx: {self.buffer_idx}')
            self.bl = np.copy(bl)
            np.save(self.params.file_handling.filename_baseline, self.bl)
        if self.buffer_idx_total >= bl_idx:
            raw = (raw - self.bl) / (self.bl + self.params.daq.spike_rates.bl_offset)
        return raw
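
    # Illustrative arithmetic for correct_bl (hypothetical values): with a
    # per-channel baseline range bl = 20 Hz and bl_offset = 1, a raw rate of
    # 41 Hz maps to (41 - 20) / (20 + 1) = 1.0, i.e. rates are expressed
    # relative to the baseline range.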

    def get_comments(self, store_comments=True):
        '''Get comments from the NSP.'''
        comments = self.ck.get_comment_data()
        if not store_comments:  # just read comments to empty the buffer
            return None
        rates = [40, 20]  # simulated rates for yes/no responses (used by get_rates_sim)
        rates_exploration = [60, 30, 20]  # simulated rates for exploration conditions
        for item in comments['comments']:
            self.events.append((item['ts'], item['text']))
            log.warning(item)
            comment = str(item['text'])
            if 'question, Training, yes, response, start' in comment or \
                    'training_color, Training, yes, response, start' in comment:
                self.rate_fct = rates[0]
            elif 'question, Training, no, response, start' in comment or \
                    'training_color, Training, no, response, start' in comment:
                self.rate_fct = rates[1]
            elif ('Validation' in comment and 'response, start' in comment) or \
                    ('color, Free' in comment and 'response, start' in comment):
                self.rate_fct = rates[random.randint(0, 1)]
            elif 'ruhe' in comment and 'response, start' in comment:
                self.rate_fct = rates_exploration[0]
            elif 'kopf' in comment and 'response, start' in comment:
                self.rate_fct = rates_exploration[1]
            elif 'fuss' in comment and 'response, start' in comment:
                self.rate_fct = rates_exploration[2]
            else:
                self.rate_fct = self.params.sim_data.rate_bl
        return None

    def send_triggers(self, send_triggers=True):
        '''Forward triggers from the bci process to the NSP.'''
        while self.child_conn.poll():  # the bci process sent a trigger
            comment = self.child_conn.recv()
            cb_time = self.ck.get_cb_time()['ts']
            self.ck.send_comment(comment)
            log.debug(f'Comment "{comment}" at {cb_time} forwarded to NSP')
        return None

    def get_msec(self, time_now):
        '''Get the current millisecond; used only for plotting.'''
        return int(time_now.microsecond / 1000.)

    def read_raw(self):
        plt.clf()
        plt.ylim(-0.2, 1.2)
        plt.xlim(0, 9000)

        win_past = self.params.classifier.template.max()  # number of samples to include from the past
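
        # Main acquisition loop: while a block is recording, each pass
        # (i) forwards triggers and reads NSP comments, (ii) fetches new
        # spike-rate / band-power bins, (iii) optionally baseline-corrects and
        # buffers them, (iv) updates the normalized rate, and (v) runs the
        # online classifier on the most recent win_past samples.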
        while True:
            while self.recording_status.value:
                if self.ck.get_state() == cc.ccS_Idle:
                    self.params = aux.load_config()  # update parameters in the data instance
                    try:
                        self.clf.set_params(self.params)  # update parameters in the classifier instance
                    except Exception as e:
                        log.error(e)
                        log.error('block stopped')
                        self.recording_status.value = 0
                    self.update_ch_map()  # add or remove channels to record from
                    self.ck.send_record()
                    self.get_comments(store_comments=False)  # drain comments that may still be in the buffer
                time.sleep(0.05)
                self.send_triggers()
                self.get_comments()
                if self.params.daq.data_source == 'spike_rates':
                    raw, time_stamp2, time_stamps = self.get_rates()  # spike rates, received data
                else:
                    raw, time_stamp2, time_stamps = self.get_sb_power()  # spike band power, received data
                # raw, time_stamp2, time_stamps = self.get_rates_sim()  # spike rates, simulated data
                if self.params.daq.spike_rates.correct_bl:  # correct baseline
                    raw = self.correct_bl(raw)
                self.write_buffer(raw, time_stamp2, time_stamps)
                if raw.shape[0] > 0:
                    log.debug(f'raw shape: {raw.shape}, buffer idx: {self.buffer_idx}, {raw[0, :10]}')
                else:
                    log.debug(f'raw shape: {raw.shape}, buffer idx: {self.buffer_idx}')
                time_now = datetime.datetime.now()
                self.normalizer.update_norm_range(data=raw)
                self.normalized_frate.value = self.normalizer.calculate_norm_rate(self.buffer[self.global_buffer_idx.value - 1, :])
                if self.params.classifier.online:
                    if self.buffer_idx - win_past >= 0:
                        self.clf.init_buffer = 0
                        tmp_buffer = np.take(self.buffer, range(self.buffer_idx - win_past, self.buffer_idx), axis=0)
                        self.clf.get_class2(tmp_buffer, self.get_msec(time_now), self.decoder_decision)  # CLASSIFY RESPONSE
                        self.class_prob[self.buffer_idx, :] = self.clf.online_sig
                        if self.buffer_idx - self.buffer_idx_prev > 1:
                            log.warning(f'data: buffer idx: {self.buffer_idx}, {self.buffer_idx_prev}')
                        self.buffer_idx_prev = self.buffer_idx
                else:
                    self.class_prob[:] = [0] * self.params.classifier.n_classes
                time.sleep(self.params.recording.timing.recording_loop_interval - 0.01)
            else:
                time.sleep(1)
                if self.ck.get_state() == cc.ccS_Recording:
                    log.warning('reading final comments from NSP')
                    self.send_triggers(send_triggers=True)
                    time.sleep(0.2)
                    self.get_comments()  # get last comments to clear the buffer
                    if self.buffer_idx > 0:
                        self.write_file()
                time.sleep(self.params.recording.timing.recording_loop_interval_data)
                if self.ck.get_state() == cc.ccS_Recording:
                    self.ck.send_idle()
                self.bl = []  # so that baseline data is saved only once, before trial 1
                self.buffer_idx_total = 0
        return None

    def buffer_to_array(self, data_trial):
        '''Copy single-trial continuous data to a numpy array.

        Parameters
        ----------
        data_trial : list
            continuous data as provided by cbpy, [[ch_id, ndarray]],
            e.g. raw = cp.trial_data(); data_trial = raw[2]

        Returns
        -------
        res : ndarray, shape (n_channels, n_samples)
        '''
        res = np.zeros((len(data_trial), data_trial[0][1].size), dtype=np.float32)
        for ii, val in enumerate(data_trial):
            res[ii] = val[1] / 4.
        return res

    def get_params(self):
        return self.params

    def read_buffer(self):
        """Return the buffer object."""
        return self.buffer

    def write_buffer(self, data, time_stamp2, time_stamps):
        idx = self.buffer_idx  # index to keep track of how full the buffer is
        if data.size == 0:  # nothing received, nothing to write
            return None
        try:
            if idx + data.shape[0] >= self.buffer.shape[0]:  # if the new data would overflow, write to file first
                self.write_file()
                idx = self.buffer_idx
                log.warning(f'writing to file to avoid buffer overflow: {idx} {data.shape}')
            self.buffer[idx:idx + data.shape[0], :data.shape[1]] = data
            self.time_stamps[idx:idx + data.shape[0]] = time_stamps
            self.buffer_idx = idx + data.shape[0]
            self.buffer_idx_total += data.shape[0]
            self.global_buffer_idx.value = np.copy(self.buffer_idx)
            self.time_stamp2 = time_stamp2
        except ValueError as e:
            log.error(e)
            log.error('If exclude_channels changed, restart kiap_bci!')
        return None

    def reset_buffer(self):
        """Reset the data buffer index."""
        # self.buffer[:] = 0  # verify whether zeroing is needed here
        self.buffer_idx = 0
        return None

    def get_buffer_idx(self):
        return self.buffer_idx

    def _write_block(self, fh):
        '''Write one header + data block. Layout, in write order:
        PC time stamp (datetime64[us]), NSP time stamp (int64), n_bytes,
        n_samples, n_ch (int64 each), byte length of the channel list (int16),
        channel ids (int16), data bytes, per-sample time stamps (uint64).'''
        db_bytes = self.buffer[:self.buffer_idx, :].tobytes()
        n_bytes = np.int64(len(db_bytes))
        n_samples = np.int64(self.buffer_idx)
        n_ch = np.int64(self.buffer.shape[1])
        t_now1 = np.array(dt.now(), dtype='datetime64[us]')  # PC time stamp
        t_now2 = np.int64(self.time_stamp2).tobytes()  # NSP time stamp
        fh.write(t_now1.tobytes())
        fh.write(t_now2)
        fh.write(n_bytes)
        fh.write(n_samples)
        fh.write(n_ch)
        ch_rec_list = np.int16(self.ch_rec_list).tobytes()
        fh.write(np.int16(len(ch_rec_list)))
        fh.write(ch_rec_list)
        fh.write(db_bytes)
        fh.write(self.time_stamps[:self.buffer_idx].tobytes())
        return t_now1, n_bytes, n_samples, n_ch

    def write_file(self):
        if not self.params.file_handling.save_data:  # don't save anything
            self.reset_buffer()
            return None
        log.info(f'Files will be written in mode: {self.params.file_handling.mode}')
        # identical blocks go to the data file and to the history file
        with open(self.params.file_handling.filename_data, self.params.file_handling.mode) as fh:  # data file
            t_now1, n_bytes, n_samples, n_ch = self._write_block(fh)
        with open(self.params.file_handling.filename_history, self.params.file_handling.mode) as fh:  # history file
            self._write_block(fh)
        with open(self.params.file_handling.filename_events, self.params.file_handling.mode[0]) as fh:  # event file (text mode)
            cnt = sum(8 + len(ev[1]) for ev in self.events)
            for ev in self.events:
                log.info(f'EVENTS: {ev}, cnt: {cnt}')
                fh.write(f'{ev[0]}, {ev[1]}\n')
        self.events = []
        log.info('write buffer, idx: {} shape: {}, ts: {}'.format(self.buffer_idx, self.buffer.shape, t_now1))
        log.info(f'write buffer: ts: {self.time_stamp2} bytes: {n_bytes} samples: {n_samples} ch: {n_ch} {self.buffer[:10, 0]}')
        log.info(f'write buffer, ts: {self.time_stamps[self.buffer_idx - 3:self.buffer_idx]}')
        if self.buffer_idx < self.buffer.shape[0]:
            log.warning('Writing buffer to file, but buffer is not full')
        self.reset_buffer()
        return None
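
    # Sketch of reading one record back (assumes the dtypes used in
    # _write_block; the data dtype must match the shared buffer, float64 is
    # assumed here for illustration):
    #
    #     with open(filename_data, 'rb') as fh:
    #         t_pc = np.frombuffer(fh.read(8), dtype='datetime64[us]')[0]
    #         t_nsp = np.frombuffer(fh.read(8), dtype=np.int64)[0]
    #         n_bytes, n_samples, n_ch = np.frombuffer(fh.read(24), dtype=np.int64)
    #         ch_len = int(np.frombuffer(fh.read(2), dtype=np.int16)[0])  # byte length
    #         ch_ids = np.frombuffer(fh.read(ch_len), dtype=np.int16)
    #         data = np.frombuffer(fh.read(int(n_bytes)), dtype=np.float64).reshape(n_samples, n_ch)
    #         ts = np.frombuffer(fh.read(8 * int(n_samples)), dtype=np.uint64)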

    # def exit(self):
    #     self.fh1.close()
    #     fh2.close()