'''
description: KIAP BCI main module; the other KIAP modules are imported here
author: Ioannis Vlachos
date: 30.08.18
Copyright (c) 2018 Ioannis Vlachos
All rights reserved.'''

import argparse
import ctypes
import datetime
import logging
import multiprocessing as mp
import os
import subprocess
import sys
import time
from datetime import date
from multiprocessing import Pipe, Value, current_process, Array
from subprocess import PIPE, Popen
from timeit import default_timer
import gc
import signal as signal2

import cerebus.cbpy as cb
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import psutil
import pyaudio
from scipy.io import wavfile
import tkinter as tk
from pyfiglet import Figlet
from PyQt5.QtCore import pyqtSignal  # remove later
from PyQt5.QtWidgets import QApplication
from scipy import signal, stats
from sklearn.decomposition import PCA  # sklearn.decomposition.pca is a private module; import from the public API

import aux
from aux import static_vars
from modules import bci, daq, data
from mp_gui.kiapGUIToolBox import kiapGUI
from paradigms.feedback import note, load_wav, norm2freq
from paradigms import sp_matrix

matplotlib.use('Qt5Agg', force=True)  # PyQt5 is imported above, so the Qt5 backend is needed; the 'warn' kwarg was removed in matplotlib 3.x
# matplotlib.use('TkAgg', force=True)


def run_watchdog(processes, pids, pipe_path, data_obj):
    '''Monitor the status of all related running processes.'''
    cp = current_process()
    processes.append(cp)
    pids.append(cp.pid)
    log.debug('Running process pids: {}'.format(pids))
    # log.debug('Running processes: {}'.format(processes))
    while 1:
        with open(pipe_path, "w") as pipe1:  # was the global PIPE_PATH; use the argument instead
            for ii, p in enumerate(processes):
                # if os.path.exists('/proc/{}'.format(ii)):
                if psutil.pid_exists(pids[ii]):
                    pipe1.write('{} | {} (pid:{}) is alive.\n'.format(datetime.datetime.now(), p.name, pids[ii]))
                else:
                    pipe1.write('{} | {} (pid:{}) is dead.\n'.format(datetime.datetime.now(), p.name, pids[ii]))
            data_buffer = data_obj.read_buffer()
            pipe1.write('{} | buffer shape: {}, {:.2f} (MB)'.format(datetime.datetime.now(), data_buffer.shape, data_buffer.nbytes * 1e-6))
            pipe1.write('\n')
        time.sleep(1)
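
# NOTE: run_watchdog writes its status lines into a named pipe so they can be
# followed from a separate terminal. The commented-out block in __main__ below
# suggests the intended usage; a minimal sketch (path and terminal options are
# examples, not fixed by this module):
#
#   PIPE_PATH = '/tmp/my_pipe'
#   if not os.path.exists(PIPE_PATH):
#       os.mkfifo(PIPE_PATH)
#   xterm = Popen('xterm -e tail -f %s' % PIPE_PATH, shell=True)
#   watchdog = mp.Process(name='watchdog', target=run_watchdog,
#                         args=(processes, pids, PIPE_PATH, data_obj))
#   watchdog.start()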


def audio_feedback_process(audio_feedback_run, decoder_decision, audio_fb_target, normalized_frate, block_phase):
    RATE = 44100

    pa = pyaudio.PyAudio()
    log.info('audio feedback process started')
    stream = pa.open(output=True, channels=2, rate=RATE, format=pyaudio.paFloat32)

    alpha = params.feedback.alpha
    beta = params.feedback.beta
    length = params.feedback.tone_length

    while True:
        target_counter = 0
        target_maintained_counter = 0
        which_target = aux.decision.unclassified.value
        while audio_feedback_run.value == 1:
            norm_rate = normalized_frate.value
            freq_float = norm2freq(norm_rate, alpha, beta)

            log.info(f'Normalized rate: {norm_rate}, freq: {freq_float}')
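            # NOTE: norm2freq maps the normalized firing rate onto a tone frequency using
            # the alpha/beta parameters from the config; its implementation lives in
            # paradigms.feedback. A plausible linear form (an assumption for illustration,
            # not necessarily the actual formula) would be:
            #   freq = alpha * norm_rate + beta   # e.g. norm_rate=0 -> beta Hz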
            tone = note(freq_float, length, amp=0.5, rate=RATE)
            if (params.feedback.target_n_tones > 0) \
                    and (target_counter % params.feedback.target_n_tones == params.feedback.target_n_tones - 1) \
                    and (params.speller.type == 'feedback'):
                target_freq_float = norm2freq(audio_fb_target[0], alpha, beta)
                log.info(f'freq: {freq_float} target: {target_freq_float}')

                target_tone = note(target_freq_float, length, amp=.25, rate=RATE, mode='saw')
                # interleave feedback and target tones on the two stereo channels
                tone = np.concatenate((tone.reshape((-1, 1)), target_tone.reshape((-1, 1))), axis=1).flatten()
            else:
                tone = np.repeat(tone.reshape((-1, 1)), 2, axis=1).flatten()  # duplicate, because the stream has 2 channels
            stream.write(tone.tobytes())  # ndarray.tostring() is deprecated; tobytes() is equivalent

            # if 'reward_on_target' is set, check whether the normalized firing rate lies in
            # the target range; if it does for long enough, emit the decision and end the trial
            if (params.speller.type in ('question', 'feedback', 'color')) and block_phase.value == 2 and (not params.classifier.online):
                # poor man's classifier, used for questions when no online classifier is
                # available (the online classifier would be set once a classifier is trained)
                if params.paradigms.feedback.states.down[1] <= norm_rate <= params.paradigms.feedback.states.down[2]:
                    if which_target == aux.decision.no.value:
                        target_maintained_counter += 1
                    else:
                        which_target = aux.decision.no.value
                        target_maintained_counter = 0
                elif params.paradigms.feedback.states.up[1] <= norm_rate <= params.paradigms.feedback.states.up[2]:
                    if which_target == aux.decision.yes.value:
                        target_maintained_counter += 1
                    else:
                        which_target = aux.decision.yes.value
                        target_maintained_counter = 0
                else:
                    which_target = aux.decision.unclassified.value
                    target_maintained_counter = 0
                if target_maintained_counter >= params.feedback.hold_iterations:
                    target_maintained_counter = 0
                    # TODO: Fix the hack
                    # if params.speller.type == 'feedback' and not (params.paradigms.feedback.play_end_feedback.success and params.paradigms.feedback.play_end_feedback.fail) and
                    decoder_decision.value = which_target
                    audio_feedback_run.value = 0
            else:
                target_maintained_counter = 0
                which_target = aux.decision.unclassified.value

            target_counter += 1
        time.sleep(.1)  # idle between checks while feedback is not running
    return None
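
# NOTE: the 'poor man's classifier' above implements a dwell rule: the normalized
# rate must stay inside the 'up' (yes) or 'down' (no) range for
# params.feedback.hold_iterations consecutive tones before a decision is emitted;
# leaving the range resets the counter. For example (illustrative numbers, not
# from the config): with a tone length of 0.5 s and hold_iterations = 6, the rate
# has to be held in range for roughly 3 s to select an answer.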


def plot_results(data_buffer, child_conn2, global_buffer_idx, class_prob, normalized_frate):
    fig = plt.figure(1, figsize=(16, 4))
    plt.clf()

    params.plot.channels = [ch.id for ch in params.daq.normalization.channels]
    log.info(f'plot channels: {params.plot.channels}')
    min1 = -1
    max1 = max([ch.top for ch in params.daq.normalization.channels]) * 1.5  # set ymax according to channel dynamic range

    plt.subplot(311)
    ax1 = plt.gca()
    plt.ylim(min1, max1)
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    ax1.set_xticks([])
    plt.ylabel('Rates (sp/sec)')
    ax1.set_title(f'channels: {params.plot.channels}')

    plt.subplot(312)
    ax2 = plt.gca()
    plt.ylim(-0.05, 1.05)
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    plt.ylabel('Normalized\nrate')

    plt.subplot(313)
    ax3 = plt.gca()
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    # plt.ylim(-100, 300)  # superseded by the limits below
    plt.ylim(-30, 30)
    plt.ylabel('PC1, PC2')
    # ax2.set_title('Probabilities')
    # plt.subplot(211)
    # fig.show()
    # plt.draw()

    # We need to draw the canvas before we start animating...
    fig.canvas.draw()

    col = ['b', 'r', 'g']
    x_len = params.buffer.shape[0]
    lines1 = [ax1.plot(np.arange(x_len) * 0, alpha=0.5)[0] for zz in range(len(params.plot.channels))]  # rates
    # plot normalized firing rate and thresholds
    lines2 = [ax2.plot(np.arange(x_len) * 0, 'k.', alpha=0.5)[0] for zz in range(1)]
    lines2.append(ax2.plot(np.arange(0, x_len, x_len - 1), np.arange(0, x_len, x_len - 1) * 0 + params.paradigms.feedback.states.up[1], 'C3--', alpha=0.5, lw=1)[0])
    lines2.append(ax2.plot(np.arange(0, x_len, x_len - 1), np.arange(0, x_len, x_len - 1) * 0 + params.paradigms.feedback.states.down[2], 'C3--', alpha=0.5, lw=1)[0])
    lines3 = [ax3.plot(np.arange(params.buffer.shape[0]) * 0, alpha=0.5)[0] for zz in range(2)]

    background1 = fig.canvas.copy_from_bbox(ax1.bbox)
    # background1 = fig.canvas.copy_from_bbox(ax1.get_figure().bbox)
    background2 = fig.canvas.copy_from_bbox(ax2.bbox)
    background3 = fig.canvas.copy_from_bbox(ax3.bbox)
    # lines1.extend(ax1.plot(200, 1, 'ro'))

    title_str = ''
    plt.xlabel('sec')
    plt.title(title_str)
    for line in lines1:
        ax1.draw_artist(line)
    plt.pause(0.1)

    norm_rate_tot = np.full(params.buffer.shape[0], np.nan)
    while 1:
        cnt = 0
        tstart = time.time()
        while recording_status.value > 0:
            cnt += 1

            fig.canvas.restore_region(background1)
            fig.canvas.restore_region(background2)
            fig.canvas.restore_region(background3)
            b_idx = global_buffer_idx.value
            if b_idx >= params.buffer.shape[0] - 5:
                norm_rate_tot = np.full(params.buffer.shape[0], np.nan)  # reset once the end of the buffer is reached

            # subplot 1
            xx = np.arange(b_idx) * params.daq.spike_rates.loop_interval / 1000.
            for ii, jj in enumerate(params.plot.channels):
                lines1[ii].set_xdata(xx)
                lines1[ii].set_ydata(data_buffer[:b_idx, jj])
                ax1.draw_artist(lines1[ii])

            # subplot 2
            norm_rate_tot[b_idx] = normalized_frate.value
            idx = np.isfinite(norm_rate_tot[:b_idx])  # was isfinite(norm_rate_tot[b_idx]), a scalar; a boolean mask over the filled samples is needed
            [lines2[zz].set_xdata(xx[idx]) for zz in range(1)]
            [lines2[zz].set_ydata(norm_rate_tot[:b_idx][idx]) for zz in range(1)]
            [ax2.draw_artist(lines2[zz]) for zz in range(3)]

            # canvas.update() is Qt-specific; blitting each axes is backend-agnostic
            fig.canvas.blit(ax1.bbox)
            fig.canvas.blit(ax2.bbox)
            fig.canvas.blit(ax3.bbox)
            fig.canvas.flush_events()
            if child_conn2.poll():
                decision_history = child_conn2.recv()
            time.sleep(1. / params.plot.fps)  # to avoid high cpu usage
        # print('FPS:', cnt / (time.time() - tstart))
        time.sleep(0.1)
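
# NOTE: plot_results uses matplotlib's blitting pattern for fast live updates: a
# static background per axes is captured once with canvas.copy_from_bbox(), and
# every frame restores that background, redraws only the changed Line2D artists
# with draw_artist(), and blits the axes bounding box. This avoids a full
# canvas.draw() per frame, which would be too slow at typical plot fps values.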


def plot_feedback(audio_feedback_run, audio_fb_target, normalized_frate, block_phase):
    """Plot feedback in real time. Uses the normalized rate and audio_fb_target to draw the cursor and target."""
    log.info('Starting visual feedback')
    # config_file = 'paradigm.yaml'
    # config = self._read_config(config_file)
    # config = config.feedback

    fig = plt.figure(10, figsize=(4, 4), facecolor='black')
    ax = fig.add_subplot(1, 1, 1)
    ax.set_facecolor((0.02, 0.02, 0.02))

    target_x = 0.1
    target_w = 0.8

    cursor_x = 0.2
    cursor_y = 0.5  # variable
    cursor_w = 0.6
    cursor_h = 0.05

    target_ra = ax.add_patch(plt.Rectangle((target_x, 0.01), target_w, 0.01, fill=True, edgecolor=None, facecolor=(.7, .2, .4), linewidth=0, zorder=1))
    cursor_ra = ax.add_patch(plt.Rectangle((cursor_x, 0.5), cursor_w, cursor_h, fill=True, facecolor=(.4, .5, 1.0), linewidth=0, zorder=10))

    targ_lines = ax.hlines([params.paradigms.feedback.states.down[2], params.paradigms.feedback.states.up[1]], 0, 1, colors='r', linestyles='solid')

    plt.show(block=False)  # plt.show(False) relies on a deprecated positional argument
    plt.draw()

    def init_plot():
        ax.set_clip_on(False)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        for n, s in ax.spines.items():
            s.set_color((.2, .2, .2))

        return (cursor_ra, target_ra, targ_lines)

    @static_vars(is_plotting=False)
    def update_plot(i):
        if (not update_plot.is_plotting) and audio_feedback_run.value == 1:
            update_plot.is_plotting = True
            cursor_ra.set_visible(True)
            target_ra.set_visible(True)
            init_plot()
        elif update_plot.is_plotting and (audio_feedback_run.value == 0):
            update_plot.is_plotting = False
            cursor_ra.set_visible(False)
            target_ra.set_visible(False)

        if update_plot.is_plotting:
            cursor_y = normalized_frate.value - cursor_h / 2.0
            target_y = audio_fb_target[1]
            target_h = audio_fb_target[2] - audio_fb_target[1]
            if audio_fb_target[1] <= normalized_frate.value <= audio_fb_target[2]:
                target_ra.set_fc((.3, 1.0, .6))  # green while the cursor is inside the target
            else:
                target_ra.set_fc((.7, .2, .4))

            cursor_ra.set_y(cursor_y)
            target_ra.set_y(target_y)
            target_ra.set_height(target_h)

            log.debug(f'cursor_y: {cursor_y}, target_y: {target_y}, target_h: {target_h}')
        # with blit=True the update function must always return the artists, even when idle
        return (cursor_ra, target_ra, targ_lines)

    ani = FuncAnimation(fig, update_plot, frames=None,
                        init_func=init_plot, blit=True, interval=1000.0 / params.plot.fps)
    plt.show()
    return None
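
# NOTE: FuncAnimation with blit=True redraws only the artists returned by
# update_plot. Keeping a reference to the animation object (ani) until plt.show()
# returns is required; without it, the animation can be garbage collected and stop.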


def speller_matrix(child_conn):

    def _terminate_process(*args):
        # pid = os.getpid()
        # pname = psutil.Process(pid)
        # log.error(f'{pname} received SIGINT')
        log.error('speller matrix received SIGINT')
        root.destroy()
        return None

    root = tk.Tk()
    sp_matrix.sp_matrix(root, child_conn)
    # root.lift()
    root.attributes("-topmost", True)
    signal2.signal(signal2.SIGINT, _terminate_process)
    root.mainloop()
    return None
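
# NOTE: speller_matrix runs in its own mp.Process, so the SIGINT handler installed
# here is per-process and does not affect the main application's signal handling;
# destroying the Tk root ends mainloop() and lets the child exit cleanly.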


log = aux.log  # the aux module holds the logging configuration
args = aux.args

if __name__ == '__main__':
    print('-' * 100)
    print(Figlet(font='standard', width=120).renderText('K I A P - B C I\nby\nW Y S S - C E N T E R'))
    print('-' * 100)

    # TESTS
    # pa = pyaudio.PyAudio()
    # try:
    #     pa.get_default_output_device_info()
    # except Exception as e:
    #     log.error(e)
    #     os.system("sed -i 's/audio: True/audio: False/g' config.yaml")
    #     # sys.exit()
    # pa.terminate()

    start_gui = args.gui
    live_plot = args.plot
    # start_audio_feedback = args.audio_feedback

    recording_status = Value('i', 0)
    decoder_decision = Value('i', 0)
    global_buffer_idx = Value('i', 0)
    block_phase = Value('i', 0)         # 0: baseline, 1: stimulus, 2: response
    audio_feedback_run = Value('i', 0)  # 0: feedback off, 1: feedback running
    audio_fb_target = Array('d', 3)     # normalized audio feedback target in [0, 1]: [target value, lower bound, upper bound]
    normalized_frate = Value('d', 0)
    # recording_type = Value('i', 0)    # 0: idle, 1: baseline, 2: recording
    # audio_feedback_process(audio_fb_freq)
    # time.sleep(10)
    # recording_type = Value(ctypes.c_char_p, b'MAIN()')
    # rec_signal = pyqtSignal()

    parent_conn, child_conn = Pipe()
    parent_conn2, child_conn2 = Pipe()
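
    # NOTE: these mp.Value/mp.Array objects live in shared memory, so the child
    # processes started below (data acquisition, plotting, feedback) all see the
    # same state; each Value carries its own lock, and single .value reads and
    # writes are synchronized through it.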

    try:
        params = aux.load_config()
    except Exception as e:
        log.error(e)
        sys.exit(1)

    # n_channels = params.daq.n_channels_max - len(params.daq.exclude_channels)
    n_channels = params.daq.n_channels

    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * n_channels)
    data_buffer = np.ctypeslib.as_array(shared_array_base.get_obj())
    data_buffer = data_buffer.reshape(params.buffer.length, n_channels)

    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * params.classifier.n_classes)
    class_prob = np.ctypeslib.as_array(shared_array_base.get_obj())
    class_prob = class_prob.reshape(params.buffer.length, params.classifier.n_classes)
    class_prob[:] = np.nan  # 0*np.nan is just np.nan
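
    # NOTE: the buffers are allocated as mp.Array (shared memory) and then wrapped
    # with np.ctypeslib.as_array(), so the numpy views in every forked child process
    # point at the same underlying memory; writes by the data process are visible
    # to the plotting process without copying. A minimal sketch of the pattern:
    #
    #   base = mp.Array(ctypes.c_float, rows * cols)
    #   arr = np.ctypeslib.as_array(base.get_obj()).reshape(rows, cols)
    #   # pass 'arr' to mp.Process targets; all of them share the same data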

    data_obj = data.Data(data_buffer, params, recording_status, decoder_decision, child_conn, global_buffer_idx, class_prob, block_phase, normalized_frate)
    data_obj.daemon = True
    data_obj.start()

    main_pid = os.getpid()
    pids = [mp.current_process().pid, data_obj.pid]

    # WATCHDOG
    # ----------
    # PIPE_PATH = "/tmp/my_pipe"
    # xterm = Popen('xterm -fa "Monospace" -fs 10 -e tail -f %s' % PIPE_PATH, shell=True)
    # setattr(xterm, 'name', 'xterm')
    # processes = [xterm, data_obj]
    # pids = [xterm.pid, data_obj.pid]
    # watchdog = mp.Process(name='watchdog', target=run_watchdog, args=(processes, pids, PIPE_PATH, data_obj))
    # watchdog.start()

    if live_plot:
        visual = mp.Process(name='visual', target=plot_results, args=(data_buffer, child_conn2, global_buffer_idx, class_prob, normalized_frate))
        visual.daemon = True  # kill visualization if the main app terminates
        visual.start()
        pids += (visual.pid,)

    if params.feedback.feedback_tone:
        feedback_p = mp.Process(name='audio_feedback', target=audio_feedback_process, args=(audio_feedback_run, decoder_decision, audio_fb_target, normalized_frate, block_phase))
        feedback_p.daemon = True  # kill audio feedback if the main app terminates
        feedback_p.start()

        vfeedback_p = mp.Process(name='vfeedback', target=plot_feedback, args=(audio_feedback_run, audio_fb_target, normalized_frate, block_phase))
        vfeedback_p.daemon = True  # kill visual feedback if the main app terminates
        vfeedback_p.start()
        pids += (feedback_p.pid, vfeedback_p.pid)
        log.info('audio feedback process started')

    if params.speller.type == 'color' and params.speller.speller_matrix:  # start the speller matrix as a separate process
        log.warning('speller matrix started')
        parent_conn3, child_conn3 = Pipe()
        matrix = mp.Process(name='matrix', target=speller_matrix, args=(child_conn3,))
        matrix.daemon = True
        matrix.start()
        pids += (matrix.pid,)
        parent_conn3.send(['', '', ''])
    else:
        parent_conn3 = []
    # processes = [data_obj, visual, feedback_p]

    print(f'pids: {pids}')
    os.system(f'taskset -p -c 0 {mp.current_process().pid}')  # pin the main process to core 0
    os.system(f'taskset -p -c 1 {data_obj.pid}')              # pin the data process to core 1
    if live_plot:
        os.system(f'taskset -p -c 2 {visual.pid}')            # pin the plot process to core 2; 'visual' only exists with --plot
    # processes.append(watchdog)
    # pids.append(watchdog.pid)
    # if not os.path.exists(PIPE_PATH):
    #     os.mkfifo(PIPE_PATH)
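
    # NOTE: taskset pins each process to a dedicated core so that data acquisition
    # and plotting do not compete for CPU time. The same effect could be achieved
    # without shelling out, e.g. with psutil (already imported here):
    #   psutil.Process(data_obj.pid).cpu_affinity([1])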

    freq = np.random.randint(50, 250)

    if start_gui == 1:
        log.info("Starting GUI")
        app = QApplication(sys.argv)
        gui = kiapGUI(recording_status, log, params, decoder_decision, parent_conn, parent_conn2, parent_conn3, block_phase, audio_feedback_run, audio_fb_target)
        gui.StartGUI()
    else:
        log.warning("GUI start disabled.")

    # terminate all related processes cleanly
    cur_pids = psutil.pids()
    # input('\nDone. Press enter to exit app')
    # pids.remove(xterm.pid)  # keep terminal for debugging after app exits
    for pid in pids:
        if pid in cur_pids:
            p = psutil.Process(pid)
            log.info('Terminating pid: {}'.format(pid))
            p.terminate()
        else:
            log.debug('Process {} has already been terminated'.format(pid))
    print('Exiting app.')