'''
description: kiap bci main class, other kiap modules are imported
author: Ioannis Vlachos
date: 30.08.18

Copyright (c) 2018 Ioannis Vlachos
All rights reserved.
'''
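
# Process layout (wired up in the __main__ block below):
#   main process                           - Qt GUI (kiapGUI), owns all child processes
#   data.Data                              - acquisition loop, writes rates into the shared buffer
#   plot_results                           - optional live plot of rates, normalized rate, PCs (--plot)
#   audio_feedback_process / plot_feedback - auditory and visual neurofeedback
#   speller_matrix                         - optional tkinter speller grid (color speller)
# Inter-process state is shared via multiprocessing Value/Array; Pipe pairs carry
# messages between the GUI and the workers.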
import argparse
import ctypes
import datetime
import logging
import multiprocessing as mp
import os
import subprocess
import sys
import time
from datetime import date
from multiprocessing import Pipe, Value, current_process, Array
from subprocess import PIPE, Popen
from timeit import default_timer
import gc
import signal as signal2

import cerebus.cbpy as cb
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import psutil
import pyaudio
from scipy.io import wavfile
import tkinter as tk
from pyfiglet import Figlet
from PyQt5.QtCore import pyqtSignal  # remove later
from PyQt5.QtWidgets import QApplication
from scipy import signal, stats
from sklearn.decomposition import PCA  # sklearn.decomposition.pca is a private module; import from the package

import aux
from aux import static_vars
from modules import bci, daq, data
from mp_gui.kiapGUIToolBox import kiapGUI
from paradigms.feedback import note, load_wav, norm2freq
from paradigms import sp_matrix

matplotlib.use('Qt4Agg', warn=False, force=True)
# matplotlib.use('TkAgg', warn=False, force=True)
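# NOTE: plot_results() below uses fig.canvas.update() for blitting, a QWidget method
# that exists on the Qt Agg canvas; with the TkAgg backend, fig.canvas.blit() would
# be needed instead.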

def run_watchdog(processes, pids, pipe_path, data_obj):
    '''monitor status of all related running processes'''
    cp = current_process()
    processes.append(cp)
    pids.append(cp.pid)
    log.debug('Running process pids: {}'.format(pids))
    # log.debug('Running processes: {}'.format(processes))
    while 1:
        # write one status line per process into the named pipe (was: global PIPE_PATH,
        # which is only defined in the commented-out watchdog setup in __main__)
        with open(pipe_path, "w") as pipe1:
            for ii, p in enumerate(processes):
                # if os.path.exists('/proc/{}'.format(ii)):
                if psutil.pid_exists(pids[ii]):
                    pipe1.write('{} | {} (pid:{}) is alive.\n'.format(datetime.datetime.now(), p.name, pids[ii]))
                else:
                    pipe1.write('{} | {} (pid:{}) is dead.\n'.format(datetime.datetime.now(), p.name, pids[ii]))
            data_buffer = data_obj.read_buffer()
            pipe1.write('{} | buffer shape: {}, {:.2f} (MB)'.format(datetime.datetime.now(), data_buffer.shape, data_buffer.nbytes * 1e-6))
            pipe1.write('\n')
        time.sleep(1)
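
# Audio feedback: each pass of the inner loop maps the current normalized firing
# rate to a tone frequency via norm2freq(norm_rate, alpha, beta) and plays one tone
# of params.feedback.tone_length seconds. Every target_n_tones-th tone, the target
# frequency is played simultaneously on the second stereo channel as a saw wave,
# so the listener can compare the current pitch against the target pitch.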
def audio_feedback_process(audio_feedback_run, decoder_decision, audio_fb_target, normalized_frate, block_phase):
    RATE = 44100
    pa = pyaudio.PyAudio()
    log.info('audio feedback process started')
    stream = pa.open(output=True, channels=2, rate=RATE, format=pyaudio.paFloat32)
    alpha = params.feedback.alpha
    beta = params.feedback.beta
    length = params.feedback.tone_length
    while True:
        target_counter = 0
        target_maintained_counter = 0
        which_target = aux.decision.unclassified.value
        while audio_feedback_run.value == 1:
            norm_rate = normalized_frate.value
            freq_float = norm2freq(norm_rate, alpha, beta)
            log.info(f'Normalized rate: {norm_rate}, freq: {freq_float}')
            tone = note(freq_float, length, amp=0.5, rate=RATE)
            if (params.feedback.target_n_tones > 0) and (target_counter % params.feedback.target_n_tones == params.feedback.target_n_tones - 1) \
                    and (params.speller.type == 'feedback'):
                target_freq_float = norm2freq(audio_fb_target[0], alpha, beta)
                log.info(f'freq: {freq_float} target: {target_freq_float}')
                target_tone = note(target_freq_float, length, amp=.25, rate=RATE, mode='saw')
                tone = np.concatenate((tone.reshape((-1, 1)), target_tone.reshape((-1, 1))), axis=1).flatten()  # interleave: tone left, target right
            else:
                tone = np.repeat(tone.reshape((-1, 1)), 2, axis=1).flatten()  # need to copy, because we have 2 channels
            stream.write(tone.tobytes())  # tostring() is a deprecated alias of tobytes()
            # if 'reward_on_target' is True, check whether the normalized firing rate is inside the
            # target range; if it stays there long enough, report the decision and finish the trial
            if (params.speller.type in ('question', 'feedback', 'color')) and block_phase.value == 2 and (not params.classifier.online):
                # this is a poor man's classifier, used for questions when the online classifier is
                # not active (it would be set if a classifier had been trained)
                if params.paradigms.feedback.states.down[1] <= norm_rate <= params.paradigms.feedback.states.down[2]:
                    if which_target == aux.decision.no.value:
                        target_maintained_counter += 1
                    else:
                        which_target = aux.decision.no.value
                        target_maintained_counter = 0
                elif params.paradigms.feedback.states.up[1] <= norm_rate <= params.paradigms.feedback.states.up[2]:
                    if which_target == aux.decision.yes.value:
                        target_maintained_counter += 1
                    else:
                        which_target = aux.decision.yes.value
                        target_maintained_counter = 0
                else:
                    which_target = aux.decision.unclassified.value
                    target_maintained_counter = 0
                if target_maintained_counter >= params.feedback.hold_iterations:
                    target_maintained_counter = 0
                    # TODO: Fix the hack
                    # if params.speller.type == 'feedback' and not (params.paradigms.feedback.play_end_feedback.success and params.paradigms.feedback.play_end_feedback.fail) and
                    decoder_decision.value = which_target
                    audio_feedback_run.value = 0
            else:
                target_maintained_counter = 0
                which_target = aux.decision.unclassified.value
            target_counter += 1
        time.sleep(.1)
    return None
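
# Live plotting uses manual blitting: the empty axes backgrounds are captured once
# with copy_from_bbox(), restored every frame with restore_region(), and only the
# line artists are redrawn with draw_artist(). This avoids a full canvas redraw per
# frame and keeps the update loop close to params.plot.fps.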
def plot_results(data_buffer, child_conn2, global_buffer_idx, class_prob, normalized_frate):
    fig = plt.figure(1, figsize=(16, 4))
    plt.clf()
    params.plot.channels = [ch.id for ch in params.daq.normalization.channels]
    log.info(f'plot channels: {params.plot.channels}')
    min1 = -1
    max1 = max([ch.top for ch in params.daq.normalization.channels]) * 1.5  # set ymax according to channel dynamic range
    plt.subplot(311)
    ax1 = plt.gca()
    plt.ylim(min1, max1)
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    ax1.set_xticks([])
    plt.ylabel('Rates (sp/sec)')
    ax1.set_title(f'channels: {params.plot.channels}')
    plt.subplot(312)
    ax2 = plt.gca()
    plt.ylim(-0.05, 1.05)
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    plt.ylabel('Normalized\nrate')
    plt.subplot(313)
    ax3 = plt.gca()
    plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
    # plt.ylim(-100, 300)  # overridden by the line below, kept for reference
    plt.ylim(-30, 30)
    plt.ylabel('PC1, PC2')
    # ax2.set_title('Probabilities')
    # plt.subplot(211)
    # fig.show()
    # plt.draw()
    # We need to draw the canvas before we start animating...
    fig.canvas.draw()
    col = ['b', 'r', 'g']
    x_len = params.buffer.shape[0]
    lines1 = [ax1.plot(np.arange(x_len) * 0, alpha=0.5)[0] for zz in range(len(params.plot.channels))]  # rates
    # plot normalized firing rate and thresholds
    lines2 = [ax2.plot(np.arange(x_len) * 0, 'k.', alpha=0.5)[0]]
    lines2.append(ax2.plot(np.arange(0, x_len, x_len - 1), np.arange(0, x_len, x_len - 1) * 0 + params.paradigms.feedback.states.up[1], 'C3--', alpha=0.5, lw=1)[0])
    lines2.append(ax2.plot(np.arange(0, x_len, x_len - 1), np.arange(0, x_len, x_len - 1) * 0 + params.paradigms.feedback.states.down[2], 'C3--', alpha=0.5, lw=1)[0])
    lines3 = [ax3.plot(np.arange(params.buffer.shape[0]) * 0, alpha=0.5)[0] for zz in range(2)]
    background1 = fig.canvas.copy_from_bbox(ax1.bbox)
    # background1 = fig.canvas.copy_from_bbox(ax1.get_figure().bbox)
    background2 = fig.canvas.copy_from_bbox(ax2.bbox)
    background3 = fig.canvas.copy_from_bbox(ax3.bbox)
    # lines1.extend(ax1.plot(200, 1, 'ro'))
    title_str = ''
    plt.xlabel('sec')
    plt.title(title_str)
    for line in lines1:
        ax1.draw_artist(line)
    plt.pause(0.1)
    norm_rate_tot = np.zeros((params.buffer.shape[0],)) * np.nan
    while 1:
        cnt = 0
        tstart = time.time()
        while recording_status.value > 0:
            cnt += 1
            fig.canvas.restore_region(background1)
            # fig.canvas.restore_region(fig.canvas.copy_from_bbox(ax1.get_figure().bbox))
            fig.canvas.restore_region(background2)
            fig.canvas.restore_region(background3)
            b_idx = global_buffer_idx.value
            if b_idx >= params.buffer.shape[0] - 5:
                norm_rate_tot = np.zeros((params.buffer.shape[0],)) * np.nan  # reset when the end of the buffer is reached
            # subplot 1: raw rates per channel
            xx = np.arange(b_idx) * params.daq.spike_rates.loop_interval / 1000.
            for ii, jj in enumerate(params.plot.channels):
                # log.warning(data_buffer.shape)
                lines1[ii].set_xdata(xx)
                lines1[ii].set_ydata(data_buffer[:b_idx, jj])
                ax1.draw_artist(lines1[ii])
            # subplot 2: normalized rate and thresholds
            norm_rate_tot[b_idx] = normalized_frate.value
            # log.error(f'{norm_rate_tot[-20:]}, {normalized_frate.value}')
            idx = np.isfinite(norm_rate_tot[:b_idx])  # was isfinite(norm_rate_tot[b_idx]); the mask must cover the plotted slice
            lines2[0].set_xdata(xx[idx])
            lines2[0].set_ydata(norm_rate_tot[:b_idx][idx])
            for zz in range(3):
                ax2.draw_artist(lines2[zz])
            # fig.canvas.blit(ax1.bbox)
            # fig.canvas.blit(ax2.bbox)
            # fig.canvas.blit(ax3.bbox)
            fig.canvas.update()
            fig.canvas.flush_events()
            # time.sleep(0.00001)
            if child_conn2.poll():
                decision_history = child_conn2.recv()  # drain messages from the GUI end of the pipe
            time.sleep(1. / params.plot.fps)  # to avoid high cpu usage
        # print('FPS:', cnt / (time.time() - tstart))
        time.sleep(0.1)
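
# Visual counterpart of the audio feedback: a FuncAnimation moves a cursor bar
# (the current normalized rate) against a target band defined by
# audio_fb_target[1] and audio_fb_target[2]. The patches are hidden while
# audio_feedback_run is 0, so the window blanks out between trials.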
def plot_feedback(audio_feedback_run, audio_fb_target, normalized_frate, block_phase):
    """Plot feedback in real time. Uses the normalized rate and audio_fb_target to plot cursor and target."""
    log.info("Starting visual feedback")
    # config_file = 'paradigm.yaml'
    # config = self._read_config(config_file)
    # config = config.feedback
    fig = plt.figure(10, figsize=(4, 4), facecolor='black')
    ax = fig.add_subplot(1, 1, 1)
    ax.set_facecolor((0.02, 0.02, 0.02))
    target_x = 0.1
    target_w = 0.8
    cursor_x = 0.2
    cursor_y = 0.5  # variable
    cursor_w = 0.6
    cursor_h = 0.05
    target_ra = ax.add_patch(plt.Rectangle((target_x, 0.01), target_w, 0.01, fill=True, edgecolor=None, facecolor=(.7, .2, .4), linewidth=0, zorder=1))
    cursor_ra = ax.add_patch(plt.Rectangle((cursor_x, 0.5), cursor_w, cursor_h, fill=True, facecolor=(.4, .5, 1.0), linewidth=0, zorder=10))
    targ_lines = ax.hlines([params.paradigms.feedback.states.down[2], params.paradigms.feedback.states.up[1]], 0, 1, colors='r', linestyles='solid')
    plt.show(block=False)
    plt.draw()

    def init_plot():
        ax.set_clip_on(False)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        for n, s in ax.spines.items():
            s.set_color((.2, .2, .2))
        return (cursor_ra, target_ra, targ_lines)

    @static_vars(is_plotting=False)
    def update_plot(i):
        if (not update_plot.is_plotting) and audio_feedback_run.value == 1:
            update_plot.is_plotting = True
            cursor_ra.set_visible(True)
            target_ra.set_visible(True)
            init_plot()
        elif update_plot.is_plotting and (audio_feedback_run.value == 0):
            update_plot.is_plotting = False
            cursor_ra.set_visible(False)
            target_ra.set_visible(False)
        if update_plot.is_plotting:
            cursor_y = normalized_frate.value - cursor_h / 2.0
            target_y = audio_fb_target[1]
            target_h = audio_fb_target[2] - audio_fb_target[1]
            if (audio_fb_target[1] <= normalized_frate.value) and (normalized_frate.value <= audio_fb_target[2]):
                target_ra.set_fc((.3, 1.0, .6))
            else:
                target_ra.set_fc((.7, .2, .4))
            cursor_ra.set_y(cursor_y)
            target_ra.set_y(target_y)
            target_ra.set_height(target_h)
            log.debug(f'cursor_y: {cursor_y}, target_y: {target_y}, target_h: {target_h}')
        return (cursor_ra, target_ra, targ_lines)

    ani = FuncAnimation(fig, update_plot, frames=None,
                        init_func=init_plot, blit=True, interval=1000.0 / params.plot.fps)
    plt.show()
    return None
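
# The speller grid runs in its own process because tkinter's mainloop() blocks;
# a SIGINT handler tears the window down cleanly when the main app exits.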
def speller_matrix(child_conn):
    def _terminate_process(*args):
        # pid = os.getpid()
        # pname = psutil.Process(pid)
        # log.error(f'{pname} received SIGINT')
        log.error('speller matrix received SIGINT')
        root.destroy()
        return None

    root = tk.Tk()
    sp_matrix.sp_matrix(root, child_conn)
    # root.lift()
    root.attributes("-topmost", True)
    signal2.signal(signal2.SIGINT, _terminate_process)
    root.mainloop()
    return None

log = aux.log  # the aux module contains the logging configuration
args = aux.args

if __name__ == '__main__':
    print('-' * 100)
    print(Figlet(font='standard', width=120).renderText('K I A P - B C I\nby\nW Y S S - C E N T E R'))
    print('-' * 100)
    # TESTS
    # pa = pyaudio.PyAudio()
    # try:
    #     pa.get_default_output_device_info()
    # except Exception as e:
    #     log.error(e)
    #     os.system("sed -i 's/audio: True/audio: False/g' config.yaml")
    #     # sys.exit()
    # pa.terminate()
    start_gui = args.gui
    live_plot = args.plot
    # start_audio_feedback = args.audio_feedback
    recording_status = Value('i', 0)
    decoder_decision = Value('i', 0)
    global_buffer_idx = Value('i', 0)
    block_phase = Value('i', 0)         # 0: baseline, 1: stimulus, 2: response
    audio_feedback_run = Value('i', 0)  # 1: feedback loop running, 0: feedback paused
    audio_fb_target = Array('d', 3)     # normalized feedback target in [0, 1]: [0] target value, [1] lower bound, [2] upper bound
    normalized_frate = Value('d', 0)
    # recording_type = Value('i', 0)  # 0: idle, 1: baseline, 2: recording
    # audio_feedback_process(audio_fb_freq)
    # time.sleep(10)
    # recording_type = Value(ctypes.c_char_p, b'MAIN()')
    # rec_signal = pyqtSignal()
    parent_conn, child_conn = Pipe()
    parent_conn2, child_conn2 = Pipe()
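    # parent_conn/child_conn link the GUI to the data process; parent_conn2/child_conn2
    # link the GUI to the live plot (decision history)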
    try:
        params = aux.load_config()
    except Exception as e:
        log.error(e)
        sys.exit(1)
    # n_channels = params.daq.n_channels_max - len(params.daq.exclude_channels)
    n_channels = params.daq.n_channels
    # shared buffers: spike rates (buffer.length x n_channels) and class probabilities
    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * n_channels)
    data_buffer = np.ctypeslib.as_array(shared_array_base.get_obj())
    data_buffer = data_buffer.reshape(params.buffer.length, n_channels)
    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * params.classifier.n_classes)
    class_prob = np.ctypeslib.as_array(shared_array_base.get_obj())
    class_prob = class_prob.reshape(params.buffer.length, params.classifier.n_classes)
    class_prob[:] = np.nan
    data_obj = data.Data(data_buffer, params, recording_status, decoder_decision, child_conn, global_buffer_idx, class_prob, block_phase, normalized_frate)
    data_obj.daemon = True
    data_obj.start()
    main_pid = os.getpid()
    pids = [mp.current_process().pid, data_obj.pid]
    # WATCHDOG
    # --------
    # PIPE_PATH = "/tmp/my_pipe"
    # xterm = Popen('xterm -fa "Monospace" -fs 10 -e tail -f %s' % PIPE_PATH, shell=True)
    # setattr(xterm, 'name', 'xterm')
    # processes = [xterm, data_obj]
    # pids = [xterm.pid, data_obj.pid]
    # watchdog = mp.Process(name='watchdog', target=run_watchdog, args=(processes, pids, PIPE_PATH, data_obj))
    # watchdog.start()
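    # Spawn the optional worker processes; all are daemons, so they are killed
    # automatically if the main process dies.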
    if live_plot:
        visual = mp.Process(name='visual', target=plot_results, args=(data_buffer, child_conn2, global_buffer_idx, class_prob, normalized_frate))
        visual.daemon = True  # kill visualization if main app terminates
        visual.start()
        pids += (visual.pid,)
    if params.feedback.feedback_tone:
        feedback_p = mp.Process(name='audio_feedback', target=audio_feedback_process, args=(audio_feedback_run, decoder_decision, audio_fb_target, normalized_frate, block_phase))
        feedback_p.daemon = True  # kill audio feedback if main app terminates
        feedback_p.start()
        vfeedback_p = mp.Process(name='vfeedback', target=plot_feedback, args=(audio_feedback_run, audio_fb_target, normalized_frate, block_phase))
        vfeedback_p.daemon = True  # kill visual feedback if main app terminates
        vfeedback_p.start()
        pids += (feedback_p.pid, vfeedback_p.pid)
        log.info('audio feedback process started')
    if params.speller.type == 'color' and params.speller.speller_matrix:  # start speller matrix as separate process
        log.warning('speller matrix started')
        parent_conn3, child_conn3 = Pipe()
        matrix = mp.Process(name='matrix', target=speller_matrix, args=(child_conn3,))
        matrix.daemon = True
        matrix.start()
        pids += (matrix.pid,)
        parent_conn3.send(['', '', ''])
    else:
        parent_conn3 = []
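    # Pin each process to its own core with taskset (Linux-only), so the plotting
    # load does not interfere with the acquisition loop's timing.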
    # processes = [data_obj, visual, feedback_p]
    print(f'pids: {pids}')
    os.system(f'taskset -p -c 0 {mp.current_process().pid}')  # set process affinity to core 0
    os.system(f'taskset -p -c 1 {data_obj.pid}')  # set process affinity to core 1
    if live_plot:
        os.system(f'taskset -p -c 2 {visual.pid}')  # set process affinity to core 2 (visual only exists with --plot)
    # processes.append(watchdog)
    # pids.append(watchdog.pid)
    # if not os.path.exists(PIPE_PATH):
    #     os.mkfifo(PIPE_PATH)
    freq = np.random.randint(50, 250)
    if start_gui == 1:
        log.info("Starting GUI")
        app = QApplication(sys.argv)
        gui = kiapGUI(recording_status, log, params, decoder_decision, parent_conn, parent_conn2, parent_conn3, block_phase, audio_feedback_run, audio_fb_target)
        gui.StartGUI()
    else:
        log.warning("GUI start disabled.")
    # terminate all related processes cleanly
    cur_pids = psutil.pids()
    # input('\nDone. Press enter to exit app')
    # pids.remove(xterm.pid)  # keep terminal for debugging after app exits
    for pid in pids:
        if pid in cur_pids:
            p = psutil.Process(pid)
            log.info('Terminating pid: {}'.format(pid))
            p.terminate()
        else:
            log.debug('Process {} has already been terminated'.format(pid))
    print('Exiting app.')