'''
description: kiap bci main class, other kiap modules are imported
author: Ioannis Vlachos
date: 30.08.18

Copyright (c) 2018 Ioannis Vlachos
All rights reserved.'''
import argparse
import ctypes
import datetime
import logging
import multiprocessing as mp
import os
import subprocess
import sys
import time
from datetime import date
from multiprocessing import Pipe, Value, current_process, Array
from subprocess import PIPE, Popen
from timeit import default_timer
import gc
import signal as signal2

import cerebus.cbpy as cb
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import psutil
import pyaudio
from scipy.io import wavfile
import tkinter as tk
from pyfiglet import Figlet
from PyQt5.QtCore import pyqtSignal  # remove later
from PyQt5.QtWidgets import QApplication
from scipy import signal, stats
from sklearn.decomposition.pca import PCA

import aux
from aux import static_vars
from modules import bci, daq, data
from mp_gui.kiapGUIToolBox import kiapGUI
from paradigms.feedback import note, load_wav, norm2freq
from paradigms import sp_matrix

matplotlib.use('Qt4Agg', warn=False, force=True)
# matplotlib.use('TkAgg', warn=False, force=True)


def run_watchdog(processes, pids, pipe_path, data_obj):
    """Monitor the status of all related running processes.

    Once per second, writes one status line per process plus the shape and
    size of the data buffer to the file/pipe at ``pipe_path``.

    Args:
        processes: list of objects exposing a ``name`` attribute; the current
            process is appended to it.
        pids: list of process ids, parallel to ``processes``; the current pid
            is appended to it.
        pipe_path: path that is (re)opened for writing on every iteration.
        data_obj: object whose ``read_buffer()`` returns an array-like with
            ``shape`` and ``nbytes``.

    This function loops forever and never returns.
    """
    cp = current_process()
    processes.append(cp)
    pids.append(cp.pid)
    log.debug('Running process pids: {}'.format(pids))
    # log.debug('Running processes: {}'.format(processes))

    while 1:
        # Use the path handed in by the caller; the module-level PIPE_PATH
        # only exists in commented-out code and would raise NameError here.
        with open(pipe_path, "w") as pipe1:
            for proc, pid in zip(processes, pids):
                # if os.path.exists('/proc/{}'.format(ii)):
                state = 'alive' if psutil.pid_exists(pid) else 'dead'
                pipe1.write('{} | {} (pid:{}) is {}. \n'.format(
                    datetime.datetime.now(), proc.name, pid, state))

            data_buffer = data_obj.read_buffer()
            pipe1.write('{} | buffer shape: {}, {:.2f} (MB)'.format(
                datetime.datetime.now(), data_buffer.shape, data_buffer.nbytes * 1e-6))
            pipe1.write('\n')
        time.sleep(1)


def audio_feedback_process(audio_feedback_run, decoder_decision, audio_fb_target, normalized_frate, block_phase):
    """Play audio feedback tones and run a threshold-based yes/no decision.

    While ``audio_feedback_run.value == 1`` a tone derived from
    ``normalized_frate.value`` is written to a 2-channel PyAudio output
    stream. When the paradigm conditions below hold, the normalized rate is
    compared against the configured ``up``/``down`` state ranges; once the
    same target has been held for ``params.feedback.hold_iterations``
    consecutive tones, the decision is stored in ``decoder_decision`` and
    ``audio_feedback_run`` is reset to 0.

    Args:
        audio_feedback_run: shared value; 1 enables tone playback.
        decoder_decision: shared value receiving the detected decision.
        audio_fb_target: shared array; element 0 is used as the normalized
            value of the target tone.
        normalized_frate: shared value holding the normalized firing rate.
        block_phase: shared value; the decision logic runs only when it is 2.

    Relies on the module-level ``params`` and ``log`` objects.
    This function loops forever; the trailing ``return`` is never reached.
    """
    RATE = 44100
    pa = pyaudio.PyAudio()
    log.info('audio feedback processes started')
    stream = pa.open(output=True, channels=2, rate=RATE, format=pyaudio.paFloat32)

    alpha = params.feedback.alpha
    beta = params.feedback.beta
    length = params.feedback.tone_length

    while True:
        # Counters are reset every time playback is (re)armed.
        target_counter = 0
        target_maintained_counter = 0
        which_target = aux.decision.unclassified.value

        while audio_feedback_run.value == 1:
            norm_rate = normalized_frate.value
            freq_float = norm2freq(norm_rate, alpha, beta)
            log.info(f'Normalized rate: {norm_rate}, freq: {freq_float}')
            tone = note(freq_float, length, amp=0.5, rate=RATE)

            if (params.feedback.target_n_tones > 0) and (target_counter % params.feedback.target_n_tones == params.feedback.target_n_tones - 1) and (params.speller.type == 'feedback'):
                # Every n-th tone: feedback tone on one channel, target tone on the other.
                target_freq_float = norm2freq(audio_fb_target[0], alpha, beta)
                log.info(f'freq: {freq_float} target: {target_freq_float}')
                target_tone = note(target_freq_float, length, amp=.25, rate=RATE, mode='saw')
                tone = np.concatenate((tone.reshape((-1, 1)), target_tone.reshape((-1, 1))), axis=1).flatten()
            else:
                tone = np.repeat(tone.reshape((-1, 1)), 2, axis=1).flatten()  # need to copy, because we have 2 channels

            # tobytes() returns the same bytes as the deprecated tostring().
            stream.write(tone.tobytes())

            # if 'reward_on_target' is True, we will check if normalized firing rate is in the target range. If so,
            # play reward tone and finish trial
            if (params.speller.type in ('question', 'feedback', 'color')) and block_phase.value == 2 and (not params.classifier.online):
                # this is a poor man's classifier – in case of question and not using online classifier (which would be set if classifier is trained)
                if params.paradigms.feedback.states.down[1] <= norm_rate and norm_rate <= params.paradigms.feedback.states.down[2]:
                    if which_target == aux.decision.no.value:
                        target_maintained_counter += 1
                    else:
                        which_target = aux.decision.no.value
                        target_maintained_counter = 0
                elif params.paradigms.feedback.states.up[1] <= norm_rate and norm_rate <= params.paradigms.feedback.states.up[2]:
                    if which_target == aux.decision.yes.value:
                        target_maintained_counter += 1
                    else:
                        which_target = aux.decision.yes.value
                        target_maintained_counter = 0
                else:
                    which_target = aux.decision.unclassified.value
                    target_maintained_counter = 0

                if target_maintained_counter >= params.feedback.hold_iterations:
                    target_maintained_counter = 0
                    # TODO: Fix the hack
                    # if params.speller.type == 'feedback' and not (params.paradigms.feedback.play_end_feedback.success and params.paradigms.feedback.play_end_feedback.fail) and
                    decoder_decision.value = which_target
                    audio_feedback_run.value = 0
            else:
                target_maintained_counter = 0
                which_target = aux.decision.unclassified.value

            target_counter += 1

        time.sleep(.1)
    return None


def plot_results(data_buffer, child_conn2, global_buffer_idx, class_prob, normalized_frate):
    """Live-plot channel firing rates and the normalized firing rate.

    Builds a three-row figure (rates, normalized rate with the up/down
    thresholds, and a third axis labelled 'PC1, PC2' that is never updated
    here) and redraws the first two rows while recording is active.

    Args:
        data_buffer: 2-D array indexed as ``[sample, channel]``.
        child_conn2: connection that is drained on every frame; the received
            object is not used for plotting.
        global_buffer_idx: shared value with the current write index.
        class_prob: accepted for interface compatibility; not used here.
        normalized_frate: shared value holding the normalized firing rate.

    Relies on the module-level ``params``, ``log`` and ``recording_status``
    objects (the latter is not passed in as an argument).
    This function loops forever and never returns.
    """
    fig = plt.figure(1, figsize=(16, 4))
    plt.clf()

    params.plot.channels = [ch.id for ch in params.daq.normalization.channels]
    log.info(f'plot channels: {params.plot.channels}')

    min1 = -1
    max1 = max([ch.top for ch in params.daq.normalization.channels])*1.5  # set ymax according to channel dynamic range

    plt.subplot(311)
    ax1 = plt.gca()
    plt.ylim(min1, max1)
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    ax1.set_xticks([])
    plt.ylabel('Rates (sp/sec)')
    ax1.set_title(f'channels: {params.plot.channels}')

    plt.subplot(312)
    ax2 = plt.gca()
    plt.ylim(-0.05, 1.05)
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    plt.ylabel('Normalized\nrate')

    plt.subplot(313)
    ax3 = plt.gca()
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    plt.ylim(-100, 300)
    plt.ylim(-30, 30)
    plt.ylabel('PC1, PC2')
    # ax2.set_title('Probabilities')
    # plt.subplot(211)
    # fig.show()
    # plt.draw()

    # We need to draw the canvas before we start animating...
    fig.canvas.draw()

    x_len = params.buffer.shape[0]
    lines1 = [ax1.plot(np.arange(x_len) * 0, alpha=0.5)[0] for zz in range(len(params.plot.channels))]  # rates

    # plot normalized firing rate and thresholds
    lines2 = [ax2.plot(np.arange(x_len) * 0, 'k.', alpha=0.5)[0] for zz in range(1)]
    lines2.append(ax2.plot(np.arange(0, x_len, x_len-1), np.arange(0, x_len, x_len-1)*0+params.paradigms.feedback.states.up[1], 'C3--', alpha=0.5, lw=1)[0])
    lines2.append(ax2.plot(np.arange(0, x_len, x_len-1), np.arange(0, x_len, x_len-1)*0+params.paradigms.feedback.states.down[2], 'C3--', alpha=0.5, lw=1)[0])

    lines3 = [ax3.plot(np.arange(params.buffer.shape[0]) * 0, alpha=0.5)[0] for zz in range(2)]

    # Cached backgrounds are restored at the start of every frame.
    background1 = fig.canvas.copy_from_bbox(ax1.bbox)
    # background1 = ax1.get_figure().bbox)
    background2 = fig.canvas.copy_from_bbox(ax2.bbox)
    background3 = fig.canvas.copy_from_bbox(ax3.bbox)
    # lines1.extend(ax1.plot(200, 1, 'ro'))

    title_str = ''
    plt.xlabel('sec')
    plt.title(title_str)

    for line in lines1:
        ax1.draw_artist(line)
    plt.pause(0.1)

    norm_rate_tot = np.zeros((params.buffer.shape[0], ))*np.nan

    while 1:
        cnt = 0
        tstart = time.time()
        while recording_status.value > 0:
            cnt += 1
            fig.canvas.restore_region(background1)
            # fig.canvas.restore_region(fig.canvas.copy_from_bbox(ax1.get_figure().bbox))
            fig.canvas.restore_region(background2)
            fig.canvas.restore_region(background3)

            b_idx = global_buffer_idx.value
            if b_idx >= params.buffer.shape[0]-5:
                norm_rate_tot = np.zeros((params.buffer.shape[0], ))*np.nan  # reset if reached at the end

            # subplot 1
            xx = np.arange(b_idx) * params.daq.spike_rates.loop_interval / 1000.
            for ii, jj in enumerate(params.plot.channels):
                # log.warning(data_buffer.shape)
                lines1[ii].set_xdata(xx)
                # lines1[1].set_xdata(xx)
                lines1[ii].set_ydata(data_buffer[:b_idx, jj])
                # lines1[1].set_ydata(data_buffer[:b_idx, 1])
                # ax1.draw_artist(lines1[1])
                ax1.draw_artist(lines1[ii])

            # subplot 2
            # Guard the write so an index at the very end of the buffer
            # cannot raise IndexError and kill the plotting process.
            if b_idx < norm_rate_tot.shape[0]:
                norm_rate_tot[b_idx] = normalized_frate.value
            # log.error(f'{norm_rate_tot[-20:]}, {normalized_frate.value}')

            # Mask over the plotted window (same length as xx); the previous
            # scalar np.isfinite(norm_rate_tot[b_idx]) was not a valid
            # element-wise mask for these arrays.
            idx = np.isfinite(norm_rate_tot[:b_idx])
            lines2[0].set_xdata(xx[idx])
            lines2[0].set_ydata(norm_rate_tot[:b_idx][idx])
            for artist in lines2:
                ax2.draw_artist(artist)

            # fig.canvas.blit(ax1.bbox)
            # fig.canvas.blit(ax2.bbox)
            # fig.canvas.blit(ax3.bbox)
            fig.canvas.update()
            fig.canvas.flush_events()
            # time.sleep(0.00001)

            if child_conn2.poll():
                # Drain the pipe; the payload is currently unused.
                decision_history = child_conn2.recv()

            time.sleep(1./params.plot.fps)  # to avoid high cpu usage
            # print('FPS:', cnt / (time.time() - tstart))
        time.sleep(0.1)


def plot_feedback(audio_feedback_run, audio_fb_target, normalized_frate, block_phase):
    """Plot feedback in real time.

    Uses the normalized firing rate to position a cursor rectangle and
    ``audio_fb_target[1]``/``audio_fb_target[2]`` as the lower/upper edge of
    the target rectangle. Both rectangles are shown only while
    ``audio_feedback_run.value == 1``.

    Args:
        audio_feedback_run: shared value; 1 shows the cursor and target.
        audio_fb_target: shared array; elements 1 and 2 bound the target.
        normalized_frate: shared value holding the normalized firing rate.
        block_phase: accepted for interface compatibility; not used here.

    Relies on the module-level ``params`` and ``log`` objects.

    Returns:
        None, once the blocking ``plt.show()`` call returns.
    """
    log.info("Starting visual feedback")
    # config_file = 'paradigm.yaml'
    # config = self._read_config(config_file)
    # config = config.feedback

    fig = plt.figure(10, figsize=(4, 4), facecolor='black')
    ax = fig.add_subplot(1, 1, 1)
    ax.set_facecolor((0.02, 0.02, 0.02))

    target_x = 0.1
    target_w = 0.8
    cursor_x = 0.2
    cursor_y = 0.5  # variable
    cursor_w = 0.6
    cursor_h = 0.05

    target_ra = ax.add_patch(plt.Rectangle((target_x, 0.01), target_w, 0.01, fill=True, edgecolor=None, facecolor=(.7, .2, .4), linewidth=0, zorder=1))
    cursor_ra = ax.add_patch(plt.Rectangle((cursor_x, 0.5), cursor_w, cursor_h, fill=True, facecolor=(.4, .5, 1.0), linewidth=0, zorder=10))
    targ_lines = ax.hlines([params.paradigms.feedback.states.down[2], params.paradigms.feedback.states.up[1]], 0, 1, colors='r', linestyles='solid')

    # `block` must be passed by keyword on newer matplotlib releases.
    plt.show(block=False)
    plt.draw()

    def init_plot():
        """Reset axis limits/ticks and return the animated artists."""
        ax.set_clip_on(False)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        for spine in ax.spines.values():
            spine.set_color((.2, .2, .2))
        return (cursor_ra, target_ra, targ_lines)

    @static_vars(is_plotting=False)
    def update_plot(i):
        """Animation callback: toggle visibility and move cursor/target."""
        if (not update_plot.is_plotting) and audio_feedback_run.value == 1:
            update_plot.is_plotting = True
            cursor_ra.set_visible(True)
            target_ra.set_visible(True)
            init_plot()
        elif update_plot.is_plotting and (audio_feedback_run.value == 0):
            update_plot.is_plotting = False
            cursor_ra.set_visible(False)
            target_ra.set_visible(False)

        if update_plot.is_plotting:
            cursor_y = normalized_frate.value - cursor_h/2.0
            target_y = audio_fb_target[1]
            target_h = audio_fb_target[2] - audio_fb_target[1]

            # Recolor the target while the rate lies inside the target range.
            if (audio_fb_target[1] <= normalized_frate.value) and (normalized_frate.value <= audio_fb_target[2]):
                target_ra.set_fc((.3, 1.0, .6))
            else:
                target_ra.set_fc((.7, .2, .4))

            cursor_ra.set_y(cursor_y)
            target_ra.set_y(target_y)
            target_ra.set_height(target_h)
            log.debug(f'cursor_y: {cursor_y}, target_y: {target_y}, target_h: {target_h}')
        return (cursor_ra, target_ra, targ_lines)

    # Keep a reference to the animation object for as long as the figure is shown.
    ani = FuncAnimation(fig, update_plot, frames=None, init_func=init_plot, blit=True, interval=1000.0 / params.plot.fps)
    plt.show()
    return None


def speller_matrix(child_conn):
    """Run the speller matrix Tk window until it is closed or SIGINT arrives.

    Args:
        child_conn: connection object handed to ``sp_matrix.sp_matrix``.

    Returns:
        None, after the Tk main loop has ended.
    """
    def _terminate_process(*args):
        """SIGINT handler: log the event and destroy the Tk root window."""
        # pid = os.getpid()
        # pname = psutil.Process(pid)
        # log.error(f'{pname} received SIGINT')
        log.error('speller matrix received SIGINT')
        root.destroy()
        return None

    root = tk.Tk()
    sp_matrix.sp_matrix(root, child_conn)
    # root.lift()
    root.attributes("-topmost", True)
    signal2.signal(signal2.SIGINT, _terminate_process)
    root.mainloop()
    return None


log = aux.log  # aux file contains logging configuration
args = aux.args

# NOTE: the statements below intentionally stay at module level. The worker
# functions above read `params` and `recording_status` as module globals, so
# moving this code into a function would hide those names from them.
if __name__ == '__main__':
    print('-' * 100)
    print(Figlet(font='standard', width=120).renderText('K I A P - B C I\nby\nW Y S S - C E N T E R'))
    print('-' * 100)

    # TESTS
    # pa = pyaudio.PyAudio()
    # try:
    #     pa.get_default_output_device_info()
    # except Exception as e:
    #     log.error(e)
    #     os.system("sed -i 's/audio: True/audio: False/g' config.yaml")
    #     # sys.exit()
    # pa.terminate()

    start_gui = args.gui
    live_plot = args.plot
    # start_audio_feedback = args.audio_feedback

    # Shared state handed to the worker processes.
    recording_status = Value('i', 0)
    decoder_decision = Value('i', 0)
    global_buffer_idx = Value('i', 0)
    block_phase = Value('i', 0)  # 0: baseline, 1: stimulus, 2:response
    audio_feedback_run = Value('i', 0)  # 0: baseline, 1: stimulus, 2:response
    audio_fb_target = Array('d', 3)  # Target for normalized audio feedback activity [0, 1]. Give upper and lower bound
    normalized_frate = Value('d', 0)
    # recording_type = Value('i',0) # 0: idle, 1: baseline, 2:recording
    # audio_feedback_process(audio_fb_freq)
    # time.sleep(10)
    # recording_type = Value(ctypes.c_char_p, b'MAIN()')
    # rec_signal = pyqtSignal()

    parent_conn, child_conn = Pipe()
    parent_conn2, child_conn2 = Pipe()

    try:
        params = aux.load_config()
    except Exception as e:
        log.error(e)
        sys.exit(1)

    # n_channels = params.daq.n_channels_max - len(params.daq.exclude_channels)
    n_channels = params.daq.n_channels

    # Shared-memory buffers exposed as numpy arrays.
    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * n_channels)
    data_buffer = np.ctypeslib.as_array(shared_array_base.get_obj())
    data_buffer = data_buffer.reshape(params.buffer.length, n_channels)

    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * params.classifier.n_classes)
    class_prob = np.ctypeslib.as_array(shared_array_base.get_obj())
    class_prob = class_prob.reshape(params.buffer.length, params.classifier.n_classes)
    class_prob[:] = np.nan

    data_obj = data.Data(data_buffer, params, recording_status, decoder_decision, child_conn, global_buffer_idx, class_prob, block_phase, normalized_frate)
    data_obj.daemon = True
    data_obj.start()

    main_pid = os.getpid()
    pids = [mp.current_process().pid, data_obj.pid]

    # WATCHDOG
    # ----------
    # PIPE_PATH = "/tmp/my_pipe"
    # xterm = Popen('xterm -fa "Monospace" -fs 10 -e tail -f %s' % PIPE_PATH, shell=True)
    # setattr(xterm, 'name', 'xterm')
    #
    # processes = [xterm, data_obj]
    # pids = [xterm.pid, data_obj.pid]
    # watchdog = mp.Process(name='watchdog', target=run_watchdog, args=(processes, pids, PIPE_PATH, data_obj))
    #
    # watchdog.start()

    if live_plot:
        visual = mp.Process(name='visual', target=plot_results, args=(data_buffer, child_conn2, global_buffer_idx, class_prob, normalized_frate))
        visual.daemon = True  # kill visualization if main app terminates
        visual.start()
        pids.append(visual.pid)

    if params.feedback.feedback_tone:
        feedback_p = mp.Process(name='audio_feedback', target=audio_feedback_process, args=(audio_feedback_run, decoder_decision, audio_fb_target, normalized_frate, block_phase))
        feedback_p.daemon = True  # kill audio feedback if main app terminates
        feedback_p.start()

        vfeedback_p = mp.Process(name='vfeedback', target=plot_feedback, args=(audio_feedback_run, audio_fb_target, normalized_frate, block_phase))
        vfeedback_p.daemon = True  # kill visual feedback if main app terminates
        vfeedback_p.start()

        pids += (feedback_p.pid, vfeedback_p.pid)
        log.info('audio feedback process started')

    if params.speller.type == 'color' and params.speller.speller_matrix:
        # start speller matrix as separate process
        log.warning('speller matrix started')
        parent_conn3, child_conn3 = Pipe()
        matrix = mp.Process(name='matrix', target=speller_matrix, args=(child_conn3,))
        matrix.daemon = True
        matrix.start()
        pids.append(matrix.pid)
        parent_conn3.send(['', '', ''])
    else:
        parent_conn3 = []

    # processes = [data_obj, visual, feedback_p]
    print(f'pids: {pids}')

    os.system(f'taskset -p -c 0 {mp.current_process().pid}')  # set process affinity to core 0
    os.system(f'taskset -p -c 1 {data_obj.pid}')  # set process affinity to core 1
    if live_plot:
        # `visual` only exists when live plotting was requested.
        os.system(f'taskset -p -c 2 {visual.pid}')  # set process affinity to core 2

    # processes.append(watchdog)
    # pids.append(watchdog.pid)
    # if not os.path.exists(PIPE_PATH):
    #     os.mkfifo(PIPE_PATH)

    freq = np.random.randint(50, 250)

    if start_gui == 1:
        log.info("Starting GUI")
        app = QApplication(sys.argv)
        gui = kiapGUI(recording_status, log, params, decoder_decision, parent_conn, parent_conn2, parent_conn3, block_phase, audio_feedback_run, audio_fb_target)
        gui.StartGUI()
    else:
        log.warning("GUI start disabled.")

    # terminate correctly all processes
    cur_pids = psutil.pids()
    # input('\nDone. Press enter to exit app')
    # pids.remove(xterm.pid)  # keep terminal for debugging after app exits
    for pid in pids:
        if pid == main_pid:
            # The first entry is this process itself; terminating it here
            # would signal the main process before any child is handled.
            continue
        if pid in cur_pids:
            p = psutil.Process(pid)
            log.info('Terminating pid: {}'.format(pid))
            p.terminate()
        else:
            log.debug('Process {} has been already terminated'.format(pid))

    print('Exiting app.')