123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- '''
- description: kiap bci main class, other kiap modules are imported
- author: Ioannis Vlachos
- date: 30.08.18
- Copyright (c) 2018 Ioannis Vlachos
- All rights reserved.'''
- # import argparse
- import ctypes
- # import logging
- import multiprocessing as mp
- import os
- import subprocess
- import sys
- import time
- from datetime import date
- from multiprocessing import Pipe, Value, current_process
- from subprocess import PIPE, Popen
- from timeit import default_timer
- import cere_conn as cc
- import matplotlib
- import matplotlib.pyplot as plt
- import numpy as np
- # import psutil
- from matplotlib import gridspec
- from pyfiglet import Figlet
- # from scipy import signal, stats
- from sklearn.decomposition.pca import PCA
- import aux
- # from modules import bci, daq, data
- # from mp_gui.kiapGUIToolBox import kiapGUI
- # matplotlib.use('Qt4Agg', warn=False, force=True)
- matplotlib.use('TkAgg', warn=False, force=True)
def update_ch_map():
    """Send the active (channel, unit) list to the rate estimator and log the map.

    One ``(channel, 0)`` pair is built for every channel number from 1 to
    ``params.daq.n_channels_max`` (inclusive) that is not contained in
    ``params.daq.exclude_channels``.  The list is handed to the module-level
    ``ck`` connection, and the channel/unit map reported back by ``ck`` is
    logged.

    Relies on the module-level names ``params``, ``ck`` and ``log``.

    Returns:
        None
    """
    ch_u_list = []
    for channel in range(1, params.daq.n_channels_max + 1):
        if channel not in params.daq.exclude_channels:
            # Unit index is fixed to 0 for every channel.
            ch_u_list.append((channel, 0))

    ck.set_spike_rate_estimator_ch_u_list(ch_u_list)

    ch_map = ck.get_spike_rate_estimator_ch_u_map()
    log.debug(ch_map)
    n_mapped = len(ch_map['list'])
    log.info(f"# of channels in ch_map: {n_mapped}")
    log.warning('Only unit 0 will be returned. Check spike-sorting status in Central.')
def plot_results(data_buffer, global_buffer_idx, ch_ids):
    """Continuously plot spike rates and their first two principal components.

    The figure has three stacked panels:

    1. one trace per selected channel (label 'Rates (sp/sec)'),
    2. the two PCA components over time (label 'PC1, PC2'),
    3. PC1 against PC2, with the most recent sample drawn as a separate marker.

    The panels are refreshed with canvas blitting while
    ``recording_status.value > 0``; otherwise the function idles.  It loops
    forever and never returns, so it is meant to run in its own process.

    Relies on the module-level names ``params``, ``recording_status`` and
    ``t_sleep``.

    Args:
        data_buffer: 2-D array indexed as ``[sample, channel]``.
        global_buffer_idx: object whose ``.value`` is the number of valid
            rows currently in ``data_buffer``.
        ch_ids: two-element sequence ``[first, last]``; columns
            ``first:last`` (``last`` exclusive) of ``data_buffer`` are shown.
    """
    loop_interval = params.daq.spike_rates.loop_interval
    t_axis_max = params.buffer.length * loop_interval / 1000.

    fig = plt.figure(1, figsize=(8, 8))
    plt.clf()
    gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 2])

    # Panel 1: stacked rate traces.
    ax1 = plt.subplot(gs[0])
    plt.ylim(-2, 300)
    plt.xlim(-2, t_axis_max)
    ax1.set_xticks([])
    ax1.set_yticks([])
    plt.ylabel('Rates (sp/sec)')

    # Panel 2: principal components over time.
    ax2 = plt.subplot(gs[1])
    plt.ylim(-100, 200)
    plt.xlim(-2, t_axis_max)
    plt.ylabel('PC1, PC2')
    plt.xlabel('sec')

    # Panel 3: PC1 against PC2.
    ax3 = plt.subplot(gs[2])
    plt.xlim(-200, 200)
    plt.ylim(-200, 200)
    plt.ylabel('PC1 vs PC2')

    fig.canvas.draw()

    n_channels_plot = len(range(ch_ids[0], ch_ids[1]))
    placeholder = np.zeros(params.buffer.shape[0])
    lines1 = [ax1.plot(placeholder, 'C0', alpha=0.5)[0] for _ in range(n_channels_plot)]  # rates
    lines2 = [ax2.plot(placeholder, alpha=0.5)[0] for _ in range(2)]  # PC1, PC2
    lines3 = [
        ax3.plot([0], '.', alpha=0.5)[0],       # PC trajectory
        ax3.plot([0], 'C3o', mec='k')[0],       # most recent sample
    ]

    # Empty-axes snapshots that are restored before every redraw (blitting).
    background1 = fig.canvas.copy_from_bbox(ax1.bbox)
    background2 = fig.canvas.copy_from_bbox(ax2.bbox)
    background3 = fig.canvas.copy_from_bbox(ax3.bbox)

    # Iterate over the lines that actually exist; indexing with
    # params.plot.n_channels raised IndexError when it exceeded the number of
    # selected channels.
    for line in lines1:
        ax1.draw_artist(line)
    plt.pause(0.1)

    print(ch_ids, n_channels_plot)

    # Vertical spacing between stacked traces.  The acquisition loop in this
    # file adds column_index * 40 to every column before writing it to the
    # buffer, so the same spacing is used here to undo it.
    ch_spacing = 40
    base_offset = ch_ids[0] * ch_spacing
    offset = np.arange(n_channels_plot) * ch_spacing + base_offset
    min1, max1 = 0, 0
    pca = PCA(n_components=2, svd_solver='arpack')

    while True:
        while recording_status.value > 0:
            fig.canvas.restore_region(background1)
            fig.canvas.restore_region(background2)
            fig.canvas.restore_region(background3)

            b_idx = global_buffer_idx.value

            # Every 20 samples, update the running min/max of the offset-free
            # rates of the selected channels and show it in the title.
            if b_idx > 20 and b_idx % 20 == 0:
                print(b_idx)
                recent = data_buffer[b_idx - 20:b_idx, ch_ids[0]:ch_ids[1]] - offset
                min1 = min(min1, int(np.min(recent)))
                max1 = max(max1, int(np.max(recent)))
                ax1.set_title(f'min/max rate: {min1}, {max1} sp/sec')
                plt.draw()

            # subplot 1
            xx = np.arange(b_idx) * loop_interval / 1000.
            selected = data_buffer[:b_idx, ch_ids[0]:ch_ids[1]]
            # Plot the *selected* columns (the original always plotted columns
            # 0..n-1, which disagreed with the min/max title and the PCA when
            # ch_ids[0] != 0).  The stacking offset is re-based so the first
            # selected channel sits at the bottom of the fixed y-range; for
            # ch_ids[0] == 0 the picture is unchanged.
            for line, trace in zip(lines1, selected.T):
                line.set_xdata(xx)
                line.set_ydata(trace - base_offset)
                ax1.draw_artist(line)

            if b_idx > 2 and np.var(data_buffer[:, 0]) > 0:
                pca_res = pca.fit_transform(selected)

                # subplot 2
                for component, line in enumerate(lines2):
                    line.set_xdata(xx)
                    line.set_ydata(pca_res[:, component])

                # subplot 3
                lines3[0].set_xdata(pca_res[:, 0])
                lines3[0].set_ydata(pca_res[:, 1])
                # Line2D data must be a sequence; newer Matplotlib rejects
                # bare scalars, so wrap the single point in a list.
                lines3[1].set_xdata([pca_res[-1, 0]])
                lines3[1].set_ydata([pca_res[-1, 1]])

                for line in lines2:
                    ax2.draw_artist(line)
                for line in lines3:
                    ax3.draw_artist(line)

            fig.canvas.blit(ax1.bbox)
            fig.canvas.blit(ax2.bbox)
            fig.canvas.blit(ax3.bbox)
            time.sleep(t_sleep)  # to avoid high cpu usage
        time.sleep(t_sleep)
log = aux.log  # aux file contains logging configuration
t_sleep = 0.05  # seconds slept per iteration of the acquisition and plotting loops


if __name__ == '__main__':
    # Two integer arguments select the buffer columns to plot; they are used
    # as slice bounds (first inclusive, last exclusive) by plot_results.
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} <first_channel> <last_channel>')
    ch_ids = [int(sys.argv[1]), int(sys.argv[2])]

    print('-' * 100)
    print(Figlet(font='standard', width=120).renderText('K I A P - B C I\nby\nW Y S S - C E N T E R\n\nlive-plot'))
    print('-' * 100)

    params = aux.load_config()

    # State shared with the plotting process.
    recording_status = Value('i', 1)
    global_buffer_idx = Value('i', 0)

    n_channels = params.daq.n_channels

    # Shared-memory float buffer, viewed as (params.buffer.length, n_channels).
    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * n_channels)
    data_buffer = np.ctypeslib.as_array(shared_array_base.get_obj())
    data_buffer = data_buffer.reshape(params.buffer.length, n_channels)

    visual = mp.Process(name='visual', target=plot_results, args=(data_buffer, global_buffer_idx, ch_ids))
    visual.daemon = True  # kill visualization if main app terminates
    visual.start()

    pids = [mp.current_process().pid, visual.pid]
    print(f'pids: {pids}')

    os.system(f'taskset -p -c 3 {mp.current_process().pid}')  # set process affinity to core 3
    os.system(f'taskset -p -c 4 {visual.pid}')  # set process affinity to core 4

    ck = cc.CereConn()
    ck.send_open()

    # Everything after send_open() runs inside try/finally so the connection
    # is closed when the endless acquisition loop is left through an exception
    # (e.g. Ctrl-C).  The original send_close() call was unreachable.
    try:
        while ck.get_state() != cc.ccS_Idle:
            time.sleep(0.005)

        # get only unit 0, caution: switch off spike sorting in central
        ck.set_spike_rate_estimator_loop_interval_ms(params.daq.spike_rates.loop_interval)
        ch_u_list = [(ii, 0) for ii in range(1, params.daq.n_channels_max + 1)
                     if ii not in params.daq.exclude_channels]
        ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
        ck.set_spike_rate_estimation_method_exponential(params.daq.spike_rates.decay_factor,
                                                        params.daq.spike_rates.max_bins)

        # Set CAR, both for LFP (1kHz data) and for raw data (used for SBP)
        if params.daq.car_channels:
            ck.set_car_channels(2, params.daq.car_channels)
            ck.set_car_channels(6, params.daq.car_channels)

        update_ch_map()
        time.sleep(0.5)

        ck.send_record()

        # Per-column vertical offset so the traces are stacked in the plot.
        offset = np.arange(0, params.daq.n_channels) * 40

        while True:
            idx = global_buffer_idx.value
            rates = ck.get_spike_rate_data()
            data = rates['rates']
            n_new = data.shape[0]

            if n_new > 0:
                data = data + offset
                # Restart at the top of the buffer when the new rows would
                # reach its end.
                if idx + n_new >= data_buffer.shape[0]:
                    idx = 0
                data_buffer[idx:idx + n_new, :data.shape[1]] = data
                global_buffer_idx.value = idx + n_new

            time.sleep(t_sleep)
    finally:
        ck.send_close()