123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
'''
description: KIAP BCI live-plot script; acquires spike rates, plots them
             live and drives an audio feedback tone. Other KIAP modules
             are imported.
author: Ioannis Vlachos
date: 30.08.18
Copyright (c) 2018 Ioannis Vlachos
All rights reserved.'''
- # import argparse
- import ctypes
- # import logging
- import multiprocessing as mp
- import os
- import subprocess
- import sys
- import time
- from datetime import date
- from multiprocessing import Pipe, Value, current_process
- from subprocess import PIPE, Popen
- from timeit import default_timer
- import cere_conn as cc
- import matplotlib
- import matplotlib.pyplot as plt
- import numpy as np
- # import psutil
- from matplotlib import gridspec
- from pyfiglet import Figlet
- # from scipy import signal, stats
- from sklearn.decomposition.pca import PCA
- from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
- import pyaudio
- import aux
- from aux import load_config
- # from modules import bci, daq, data
- # from mp_gui.kiapGUIToolBox import kiapGUI
- # matplotlib.use('Qt4Agg', warn=False, force=True)
- matplotlib.use('TkAgg', warn=False, force=True)
def callback(in_data, frame_count, time_info, status):
    """Stream callback for PyAudio that synthesizes the next tone.

    The four parameters follow the PyAudio callback signature and are not
    used here. The tone is built from the module-level names ``freq``,
    ``length``, ``amp`` and ``RATE``.

    NOTE(review): those module-level names are only present as
    commented-out code in this file, and the ``stream_callback`` argument
    is commented out where the stream is opened, so this function appears
    to be unused at the moment.

    Returns:
        Tuple of (sample data, ``pyaudio.paContinue``).
    """
    global freq, length
    samples = note(freq, length, amp=amp, rate=RATE)
    return samples, pyaudio.paContinue
def note(freq, len, amp=1, rate=44100):
    """Synthesize a sine tone as 16-bit integer samples.

    Args:
        freq: tone frequency, in cycles per unit of ``len``.
        len: tone duration; may be a float. (The name shadows the builtin
            but is kept so existing keyword callers continue to work.)
        amp: peak amplitude applied before conversion to int16.
        rate: number of samples per unit of ``len``.

    Returns:
        1-D ``numpy.int16`` array with ``round(len * rate)`` samples.
    """
    # np.linspace requires an integer sample count; len * rate is a float
    # whenever the duration is fractional (e.g. 200 / freq), which raises
    # a TypeError on current NumPy releases.
    n_samples = int(round(len * rate))
    t = np.linspace(0, len, n_samples)
    data = np.sin(2 * np.pi * freq * t) * amp
    return data.astype(np.int16)  # two byte integers
def update_ch_map():
    """Register all non-excluded channels (unit 0) with the rate estimator.

    Builds the list of ``(channel, 0)`` pairs for channels
    ``1..params.daq.n_channels_max`` that are not listed in
    ``params.daq.exclude_channels``, hands it to the module-level ``ck``
    connection, then logs the channel/unit map reported back by ``ck``.
    """
    last_channel = params.daq.n_channels_max
    ch_u_list = []
    for channel in range(1, last_channel + 1):
        if channel not in params.daq.exclude_channels:
            ch_u_list.append((channel, 0))

    ck.set_spike_rate_estimator_ch_u_list(ch_u_list)

    ch_map = ck.get_spike_rate_estimator_ch_u_map()
    log.debug(ch_map)
    log.info(f"# of channels in ch_map: {len(ch_map['list'])}")
    log.warning('Only unit 0 will be returned. Check spike-sorting status in Central.')
def audio_feedback(freq):
    """Continuously play a sine tone whose frequency follows ``freq``.

    The loop never finishes on its own; it is meant to run in a separate
    process that is terminated from outside.

    Args:
        freq: shared object exposing the current tone frequency through a
            ``.value`` attribute (the caller passes a
            ``multiprocessing.Value('i', ...)``).
    """
    RATE = 44100
    pa = pyaudio.PyAudio()
    s = pa.open(output=True,
                channels=2,
                rate=RATE,
                format=pyaudio.paInt16)
    # NOTE(review): the stream is opened with two channels while note()
    # produces a single sample sequence - confirm this is intended.
    try:
        while True:
            # Snapshot the shared value so all uses below agree even if
            # the writer updates it in the middle of an iteration.
            current = freq.value
            if current <= 0:
                # A zero frequency would divide by zero and a negative one
                # would request a negative number of samples; wait for a
                # usable value instead of crashing the feedback process.
                time.sleep(0.01)
                continue
            length = 200 / current
            print(current, length, RATE)
            tone = note(current, length, amp=10000, rate=RATE)[1:]
            s.write(tone)
    finally:
        # Release the audio device if the loop is left through an
        # exception (e.g. KeyboardInterrupt).
        s.stop_stream()
        s.close()
        pa.terminate()
def plot_results(data_buffer, global_buffer_idx):
    """Live-plot buffered rates and their first two principal components.

    Draws three stacked panels using blitting: (1) one trace per channel,
    (2) PC1/PC2 over time for each of the two channel groups, and
    (3) PC1 vs PC2 for each group with the latest point highlighted.
    The function loops forever and is meant to run in its own process.
    It reads the module-level names ``params``, ``log``,
    ``recording_status`` and ``t_sleep``.

    Args:
        data_buffer: 2-D array indexed as ``[sample, channel]`` that the
            acquisition loop fills.
        global_buffer_idx: shared object whose ``.value`` is the number of
            valid rows currently in ``data_buffer``.
    """
    fig = plt.figure(1, figsize=(8, 10))
    plt.clf()
    gs = gridspec.GridSpec(3, 1, height_ratios=[8, 1, 2])
    t_max = params.buffer.length * params.daq.spike_rates.loop_interval / 1000.

    # Panel 1: per-channel rate traces.
    plt.subplot(gs[0])
    ax1 = plt.gca()
    plt.ylim(-20, 5000)
    plt.xlim(-2, t_max)
    ax1.set_xticks([])
    ax1.set_yticks([])
    plt.ylabel('Rates (sp/sec)')

    # Panel 2: principal components over time.
    plt.subplot(gs[1])
    ax2 = plt.gca()
    plt.ylim(-200, 200)
    plt.xlim(-2, t_max)
    plt.ylabel('PC1, PC2')
    plt.xlabel('sec')

    # Panel 3: PC1 against PC2.
    plt.subplot(gs[2])
    ax3 = plt.gca()
    plt.xlim(-200, 200)
    plt.ylim(-200, 200)
    plt.ylabel('PC1 vs PC2')
    fig.canvas.draw()

    # Two hard-coded channel groups, drawn in colors C0 and C1.
    ch_range1 = list(range(32)) + list(range(96, 128))
    ch_range2 = list(range(32, 96))
    ch_range = ch_range1 + ch_range2
    n_channels_plot = len(ch_range)
    log.info(f'{ch_range}, {n_channels_plot}')

    # Placeholder artists; their data is replaced on every refresh.
    lines1 = [ax1.plot(np.arange(params.buffer.shape[0]) * 0, 'C0', alpha=0.5)[0]
              for _ in ch_range1]
    lines1.extend(ax1.plot(np.arange(params.buffer.shape[0]) * 0, 'C1', alpha=0.5)[0]
                  for _ in ch_range2)

    lines2 = [ax2.plot(np.arange(params.buffer.shape[0]) * 0, 'C0', alpha=0.5)[0]
              for _ in range(2)]
    lines2.extend(ax2.plot(np.arange(params.buffer.shape[0]) * 0, 'C1', alpha=0.5)[0]
                  for _ in range(2))

    lines3 = [
        ax3.plot(np.arange(1) * 0, 'C0.', alpha=0.2)[0],
        ax3.plot(np.arange(1) * 0, 'C1.', alpha=0.2)[0],
        ax3.plot(np.arange(1), 'C0o', mec='k', alpha=0.8)[0],
        ax3.plot(np.arange(1), 'C1o', mec='k', alpha=0.8)[0],
    ]

    # Saved backgrounds are restored before each redraw (blitting).
    background1 = fig.canvas.copy_from_bbox(ax1.bbox)
    background2 = fig.canvas.copy_from_bbox(ax2.bbox)
    background3 = fig.canvas.copy_from_bbox(ax3.bbox)

    for line in lines1:
        ax1.draw_artist(line)
    plt.pause(0.1)
    print(n_channels_plot)

    # The acquisition loop adds this per-channel offset before storing the
    # data; it is subtracted again for the min/max read-out below.
    offset = np.arange(0, params.daq.n_channels) * 40
    min1, max1 = 0, 0  # baselines the displayed min/max are clamped against
    pca = PCA(n_components=2, svd_solver='arpack')

    while 1:
        while recording_status.value > 0:
            fig.canvas.restore_region(background1)
            fig.canvas.restore_region(background2)
            fig.canvas.restore_region(background3)

            b_idx = global_buffer_idx.value

            # Every 20 samples, refresh the title with the min/max over the
            # last 20 samples.
            if b_idx > 20 and b_idx % 20 == 0:
                # NOTE(review): the original computed the same np.min over
                # all channels four times (also for the "max" values). The
                # max now uses np.max, and the two value pairs are taken per
                # channel group, which is assumed to be the intent - confirm.
                recent1 = data_buffer[b_idx - 20:b_idx, ch_range1] - offset[ch_range1]
                recent2 = data_buffer[b_idx - 20:b_idx, ch_range2] - offset[ch_range2]
                min11 = min(min1, int(np.min(recent1)))
                max11 = max(max1, int(np.max(recent1)))
                min12 = min(min1, int(np.min(recent2)))
                max12 = max(max1, int(np.max(recent2)))

                ax1.set_title(f'min/max rate: {min11}, {max11}, {min12}, {max12} sp/sec')
                plt.draw()

            # Panel 1: time axis in seconds for the rows filled so far.
            xx = np.arange(b_idx) * params.daq.spike_rates.loop_interval / 1000.
            for line, channel in zip(lines1, ch_range):
                line.set_xdata(xx)
                line.set_ydata(data_buffer[:b_idx, channel])
                ax1.draw_artist(line)

            if b_idx > 2:
                # The PCA is refitted on all rows filled so far, separately
                # for each channel group.
                pca_res1 = pca.fit_transform(data_buffer[:b_idx, ch_range1])
                pca_res2 = pca.fit_transform(data_buffer[:b_idx, ch_range2])

                # Panel 2: both components of each group over time.
                for zz in range(2):
                    lines2[zz].set_xdata(xx)
                    lines2[zz].set_ydata(pca_res1[:, zz])
                    lines2[zz + 2].set_xdata(xx)
                    lines2[zz + 2].set_ydata(pca_res2[:, zz])

                # Panel 3: trajectory in PC space plus the latest point.
                lines3[0].set_xdata(pca_res1[:, 0])
                lines3[0].set_ydata(pca_res1[:, 1])
                lines3[1].set_xdata(pca_res2[:, 0])
                lines3[1].set_ydata(pca_res2[:, 1])

                # Wrap the single point in a list: newer Matplotlib
                # releases reject scalar x/y data.
                lines3[2].set_xdata([pca_res1[-1, 0]])
                lines3[2].set_ydata([pca_res1[-1, 1]])
                lines3[3].set_xdata([pca_res2[-1, 0]])
                lines3[3].set_ydata([pca_res2[-1, 1]])

                for line in lines2:
                    ax2.draw_artist(line)
                for line in lines3:
                    ax3.draw_artist(line)

            fig.canvas.blit(ax1.bbox)
            fig.canvas.blit(ax2.bbox)
            fig.canvas.blit(ax3.bbox)
            time.sleep(t_sleep)  # to avoid high cpu usage
        # Not recording: idle without spinning.
        time.sleep(t_sleep)
# The aux module holds the logging configuration and the parsed arguments.
log, args = aux.log, aux.args

# Pause in seconds between iterations of the acquisition and plot loops.
t_sleep = 0.1

# Disabled settings for the callback-driven audio path (see callback()):
# RATE = 44100
# FREQ = 262
# freq = 100
# length = 0.5*freq/1000.
# amp = 10000
# pa = pyaudio.PyAudio()
# s = pa.open(output=True,
#             channels=2,
#             rate=RATE,
#             format=pyaudio.paInt16)
#             # stream_callback=callback)
if __name__ == '__main__':
    # Entry point: start the plotting and audio-feedback processes, open
    # the acquisition connection, then copy incoming rates into the shared
    # buffer until interrupted.
    print('-' * 100)
    print(Figlet(font='standard', width=120).renderText('K I A P - B C I\nby\nW Y S S - C E N T E R\n\nlive-plot'))
    print('-' * 100)

    params = aux.load_config()
    recording_status = Value('i', 1)
    global_buffer_idx = Value('i', 0)
    n_channels = params.daq.n_channels

    ch_range1 = [10, 11]
    ch_range2 = list(range(32, 96))
    print(len(ch_range1), len(ch_range2))

    # Shared float buffer viewed as a (samples, channels) array so the
    # plotting process can read what this process writes.
    shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * n_channels)
    data_buffer = np.ctypeslib.as_array(shared_array_base.get_obj())
    data_buffer = data_buffer.reshape(params.buffer.length, n_channels)

    visual = mp.Process(name='visual', target=plot_results, args=(data_buffer, global_buffer_idx))
    visual.daemon = True  # kill visualization if main app terminates
    visual.start()

    freq = Value('i', 100)
    feedback = mp.Process(name='feedback', target=audio_feedback, args=(freq,))
    feedback.daemon = True  # kill audio feedback if main app terminates
    feedback.start()

    pids = [mp.current_process().pid, visual.pid]
    print(f'pids: {pids}')
    os.system(f'taskset -p -c 3 {mp.current_process().pid}')  # set process affinity to core 3
    os.system(f'taskset -p -c 4 {visual.pid}')  # set process affinity to core 4

    ck = cc.CereConn(withSRE=True)
    ck.send_open()
    try:
        while ck.get_state() != cc.ccS_Idle:
            time.sleep(0.005)

        # get only unit 0, caution: switch off spike sorting in Central
        ck.set_spike_rate_estimator_loop_interval_ms(params.daq.spike_rates.loop_interval)
        ck.set_spike_band_power_estimator_loop_interval_ms(params.daq.spike_rates.loop_interval)
        ch_u_list = [(ii, 0) for ii in range(1, params.daq.n_channels_max + 1)
                     if ii not in params.daq.exclude_channels]
        ch_list = [ii for ii in range(1, params.daq.n_channels_max + 1)
                   if ii not in params.daq.exclude_channels]
        ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
        ck.set_spike_band_power_estimator_ch_list(ch_list)
        ck.set_spike_rate_estimation_method_exponential(params.daq.spike_rates.decay_factor,
                                                        params.daq.spike_rates.max_bins)

        # Set CAR, both for LFP (1kHz data) and for raw data (used for SBP)
        if params.daq.car_channels:
            ck.set_car_channels(2, params.daq.car_channels)
            ck.set_car_channels(6, params.daq.car_channels)

        update_ch_map()
        time.sleep(0.5)
        ck.send_record()

        # Per-channel vertical offset so the traces do not overlap.
        offset = np.arange(0, params.daq.n_channels) * 40
        while True:
            idx = global_buffer_idx.value
            data = ck.get_spike_rate_data()['rates']
            n_new = data.shape[0]
            if n_new > 0:
                data = data + offset

            # Restart at the top of the buffer when the chunk does not fit.
            if idx + n_new >= data_buffer.shape[0]:
                idx = 0
            data_buffer[idx:idx + n_new, :data.shape[1]] = data
            global_buffer_idx.value = idx + n_new

            if n_new == 0:
                freq.value = 100
            else:
                # The original always read row 1, which raises IndexError
                # when a chunk holds a single sample; fall back to row 0.
                row = min(1, n_new - 1)
                freq.value = int(data[row, 1]) * 2
            print(freq.value)
            time.sleep(t_sleep)
    except KeyboardInterrupt:
        log.info('Interrupted by user, closing connection.')
    finally:
        # The original send_close() sat after the infinite loop and was
        # never reached; close the connection on every exit path.
        ck.send_close()
|