live_plot3.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. '''
  2. description: kiap bci main class, other kiap modules are imported
  3. author: Ioannis Vlachos
  4. date: 30.08.18
  5. Copyright (c) 2018 Ioannis Vlachos
  6. All rights reserved.'''
  7. # import argparse
  8. import ctypes
  9. # import logging
  10. import multiprocessing as mp
  11. import os
  12. import subprocess
  13. import sys
  14. import time
  15. from datetime import date
  16. from multiprocessing import Pipe, Value, current_process
  17. from subprocess import PIPE, Popen
  18. from timeit import default_timer
  19. import cere_conn as cc
  20. import matplotlib
  21. import matplotlib.pyplot as plt
  22. import numpy as np
  23. # import psutil
  24. from matplotlib import gridspec
  25. from pyfiglet import Figlet
  26. # from scipy import signal, stats
  27. from sklearn.decomposition.pca import PCA
  28. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
  29. import pyaudio
  30. import aux
  31. from aux import load_config
  32. # from modules import bci, daq, data
  33. # from mp_gui.kiapGUIToolBox import kiapGUI
  34. # matplotlib.use('Qt4Agg', warn=False, force=True)
  35. matplotlib.use('TkAgg', warn=False, force=True)
  36. def callback(in_data, frame_count, time_info, status):
  37. # data = wf.readframes(frame_count)
  38. global freq, length
  39. # freq = 300
  40. data = note(freq, length, amp=amp, rate=RATE)
  41. return (data, pyaudio.paContinue)
  42. def note(freq, len, amp=1, rate=44100):
  43. t = np.linspace(0, len, len * rate)
  44. data = np.sin(2 * np.pi * freq * t) * amp
  45. return data.astype(np.int16) # two byte integers
  46. def update_ch_map():
  47. ch_u_list = [(ii, 0) for ii in range(1, params.daq.n_channels_max + 1) if ii not in params.daq.exclude_channels]
  48. ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
  49. ch_map = ck.get_spike_rate_estimator_ch_u_map()
  50. log.debug(ch_map)
  51. log.info(f"# of channels in ch_map: {len(ch_map['list'])}")
  52. log.warning('Only unit 0 will be returned. Check spike-sorting status in Central.')
  53. return None
  54. def audio_feedback(freq):
  55. RATE = 44100
  56. pa = pyaudio.PyAudio()
  57. s = pa.open(output=True,
  58. channels=2,
  59. rate=RATE,
  60. format=pyaudio.paInt16)
  61. # stream_callback=callback)
  62. while True:
  63. length = 200 / freq.value
  64. print(freq.value, length, RATE)
  65. tone = note(freq.value, length, amp=10000, rate=RATE)[1:]
  66. s.write(tone)
  67. # freq = np.random.randint(50,250)
  68. # freq += 10
  69. return None
  70. def plot_results(data_buffer, global_buffer_idx):
  71. fig = plt.figure(1, figsize=(8, 10))
  72. plt.clf()
  73. gs = gridspec.GridSpec(3, 1, height_ratios=[8, 1, 2])
  74. plt.subplot(gs[0])
  75. ax1 = plt.gca()
  76. plt.ylim(-20, 5000)
  77. # plt.ylim(2000, 3500)
  78. plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
  79. ax1.set_xticks([])
  80. ax1.set_yticks([])
  81. plt.ylabel('Rates (sp/sec)')
  82. # ax1.set_title('Rates')
  83. plt.subplot(gs[1])
  84. ax2 = plt.gca()
  85. plt.ylim(-200, 200)
  86. plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
  87. plt.ylabel('PC1, PC2')
  88. plt.xlabel('sec')
  89. plt.subplot(gs[2])
  90. ax3 = plt.gca()
  91. plt.xlim(-200, 200)
  92. plt.ylim(-200, 200)
  93. # plt.xlim(-2, params.buffer.length * params.daq.spike_rates.loop_interval / 1000.)
  94. plt.ylabel('PC1 vs PC2')
  95. fig.canvas.draw()
  96. col = ['b', 'r', 'g']
  97. # ch_range1 = list(range(ch_ids[0],ch_ids[1]))
  98. # ch_range2 = list(range(ch_ids[2],ch_ids[3]))
  99. ch_range1 = list(range(32)) + list(range(96, 128))
  100. # ch_range1 = [10,11,12,13]
  101. ch_range2 = list(range(32,96))
  102. # ch_range1 = list(range(0,128))
  103. # ch_range1 = list(range(0,128))
  104. ch_range = ch_range1+ch_range2
  105. n_channels_plot = len(ch_range)
  106. log.info(f'{ch_range}, {n_channels_plot}')
  107. lines1 = [ax1.plot(np.arange(params.buffer.shape[0]) * 0, 'C0', alpha=0.5)[0] for zz in range(len(ch_range1))] # rates
  108. lines1.extend([ax1.plot(np.arange(params.buffer.shape[0]) * 0, 'C1', alpha=0.5)[0] for zz in range(len(ch_range2))]) # rates
  109. lines2 = [ax2.plot(np.arange(params.buffer.shape[0]) * 0, 'C0',alpha=0.5)[0] for zz in range(2)]
  110. lines2.extend([ax2.plot(np.arange(params.buffer.shape[0]) * 0, 'C1',alpha=0.5)[0] for zz in range(2)])
  111. lines3 = [ax3.plot(np.arange(1) * 0,'C0.', alpha=0.2)[0] for zz in range(1)]
  112. lines3.append(ax3.plot(np.arange(1) * 0,'C1.', alpha=0.2)[0])
  113. lines3.append(ax3.plot(np.arange(1),'C0o',mec='k', alpha=0.8)[0])
  114. lines3.append(ax3.plot(np.arange(1),'C1o',mec='k', alpha=0.8)[0])
  115. background1 = fig.canvas.copy_from_bbox(ax1.bbox)
  116. background2 = fig.canvas.copy_from_bbox(ax2.bbox)
  117. background3 = fig.canvas.copy_from_bbox(ax3.bbox)
  118. for ii in range(n_channels_plot):
  119. ax1.draw_artist(lines1[ii])
  120. plt.pause(0.1)
  121. # plt.ion()
  122. print(n_channels_plot)
  123. offset = np.arange(0,params.daq.n_channels)*40
  124. min1,max1 = 0,0
  125. pca = PCA(n_components=2, svd_solver = 'arpack')
  126. # lda = LDA(n_components=2, solver = 'lsqr')
  127. while 1:
  128. cnt = 0
  129. tstart = time.time()
  130. while recording_status.value > 0:
  131. cnt += 1
  132. fig.canvas.restore_region(background1)
  133. fig.canvas.restore_region(background2)
  134. fig.canvas.restore_region(background3)
  135. # subplot 1
  136. b_idx = global_buffer_idx.value
  137. # log.error(f'b_idx: {b_idx}')
  138. if b_idx>20 and b_idx%20==0:
  139. min11 = min(min1, int(np.min(data_buffer[b_idx-20:b_idx, ch_range]-offset[ch_range])))
  140. min12 = min(min1, int(np.min(data_buffer[b_idx-20:b_idx, ch_range]-offset[ch_range])))
  141. max11 = max(max1, int(np.min(data_buffer[b_idx-20:b_idx, ch_range]-offset[ch_range])))
  142. max12 = max(max1, int(np.min(data_buffer[b_idx-20:b_idx, ch_range]-offset[ch_range])))
  143. ax1.set_title(f'min/max rate: {min11}, {max11}, {min12}, {max12} sp/sec')
  144. plt.draw()
  145. # log.warning(offset[ch_range1])
  146. # log.warning(offset[ch_range2])
  147. xx = np.arange(b_idx) * params.daq.spike_rates.loop_interval / 1000.
  148. # for ii in range(params.plot.n_channels):
  149. for ii in range(n_channels_plot):
  150. lines1[ii].set_xdata(xx)
  151. # if ch_range[ii] >= ch_ids[2]:
  152. # offset2 = -offset[ch_ids[2]]+offset[ch_ids[1]]
  153. # else:
  154. offset2 = 0
  155. lines1[ii].set_ydata(data_buffer[:b_idx, ch_range[ii]]+offset2)
  156. ax1.draw_artist(lines1[ii])
  157. # log.warning(f'{data_buffer[b_idx, ch_range]}')
  158. # if b_idx >2 and np.var(data_buffer[:,0])>0:
  159. if b_idx >2:
  160. pca_res1 = pca.fit_transform(data_buffer[:b_idx, ch_range1])
  161. pca_res2 = pca.fit_transform(data_buffer[:b_idx, ch_range2])
  162. # subplot 2
  163. [lines2[zz].set_xdata(xx) for zz in range(2)]
  164. [lines2[zz].set_ydata(pca_res1[:,zz]) for zz in range(2)]
  165. [lines2[zz].set_xdata(xx) for zz in range(2,4)]
  166. [lines2[zz].set_ydata(pca_res2[:,zz-2]) for zz in range(2,4)]
  167. # subplot 3
  168. # log.warning(f'{ch_range1}, {ch_range2}, {pca_res1}')
  169. lines3[0].set_xdata(pca_res1[:,0])
  170. lines3[0].set_ydata(pca_res1[:,1])
  171. lines3[1].set_xdata(pca_res2[:,0])
  172. lines3[1].set_ydata(pca_res2[:,1])
  173. lines3[2].set_xdata(pca_res1[-1,0])
  174. lines3[2].set_ydata(pca_res1[-1,1])
  175. lines3[3].set_xdata(pca_res2[-1,0])
  176. lines3[3].set_ydata(pca_res2[-1,1])
  177. [ax2.draw_artist(lines2[zz]) for zz in range(4)]
  178. [ax3.draw_artist(lines3[zz]) for zz in range(4)]
  179. fig.canvas.blit(ax1.bbox)
  180. fig.canvas.blit(ax2.bbox)
  181. fig.canvas.blit(ax3.bbox)
  182. time.sleep(t_sleep) # to avoid high cpu usage
  183. time.sleep(t_sleep)
  184. return None
  185. log = aux.log # aux file contains logging configuration
  186. args = aux.args
  187. t_sleep = 0.1
  188. # RATE = 44100
  189. # FREQ = 262
  190. # freq = 100
  191. # length = 0.5*freq/1000.
  192. # amp = 10000
  193. # pa = pyaudio.PyAudio()
  194. # s = pa.open(output=True,
  195. # channels=2,
  196. # rate=RATE,
  197. # format=pyaudio.paInt16)
  198. # # stream_callback=callback)
  199. if __name__ == '__main__':
  200. # ch_ids = [int(item) for item in args.list.split(',')]
  201. # print(ch_ids)
  202. print('-' * 100)
  203. print(Figlet(font='standard', width=120).renderText('K I A P - B C I\nby\nW Y S S - C E N T E R\n\nlive-plot'))
  204. print('-' * 100)
  205. params = aux.load_config()
  206. recording_status = Value('i', 1)
  207. global_buffer_idx = Value('i', 0)
  208. # n_channels = params.daq.n_channels_max - len(params.daq.exclude_channels)
  209. n_channels = params.daq.n_channels
  210. # ch_range1 = list(range(ch_ids[0],ch_ids[1]))
  211. # ch_range2 = list(range(ch_ids[2],ch_ids[3]))
  212. # ch_range1 = list(range(32)) + list(range(96, 128))
  213. ch_range1 = [10,11]
  214. ch_range2 = list(range(32, 96))
  215. print(len(ch_range1), len(ch_range2))
  216. ch_range = ch_range1 + ch_range2
  217. n_channels_plot = len(ch_range)
  218. # n_channels_plot = 10
  219. shared_array_base = mp.Array(ctypes.c_float, params.buffer.length * n_channels)
  220. data_buffer = np.ctypeslib.as_array(shared_array_base.get_obj())
  221. data_buffer = data_buffer.reshape(params.buffer.length, n_channels)
  222. visual = mp.Process(name='visual', target=plot_results, args=(data_buffer, global_buffer_idx))
  223. visual.daemon = True # kill visualization if main app terminates
  224. visual.start()
  225. freq = Value('i', 100)
  226. feedback = mp.Process(name='feedback', target=audio_feedback, args=(freq,))
  227. feedback.daemon = True # kill visualization if main app terminates
  228. feedback.start()
  229. # time.sleep(10)
  230. processes = [visual]
  231. pids = [mp.current_process().pid, visual.pid]
  232. print(f'pids: {pids}')
  233. os.system(f'taskset -p -c 3 {mp.current_process().pid}') # set process affinity to core 3
  234. os.system(f'taskset -p -c 4 {visual.pid}') # set process affinity to core 4
  235. ck = cc.CereConn(withSRE=True)
  236. ck.send_open()
  237. t = time.time()
  238. while ck.get_state() != cc.ccS_Idle:
  239. time.sleep(0.005)
  240. # get only unit 0, caution: switch of spike sorting in central
  241. ck.set_spike_rate_estimator_loop_interval_ms(params.daq.spike_rates.loop_interval)
  242. ck.set_spike_band_power_estimator_loop_interval_ms(params.daq.spike_rates.loop_interval)
  243. ch_u_list = [(ii, 0) for ii in range(1, params.daq.n_channels_max + 1) if ii not in params.daq.exclude_channels]
  244. ch_list = [ii for ii in range(1,params.daq.n_channels_max + 1) if ii not in params.daq.exclude_channels]
  245. ck.set_spike_rate_estimator_ch_u_list(ch_u_list)
  246. ck.set_spike_band_power_estimator_ch_list(ch_list)
  247. ck.set_spike_rate_estimation_method_exponential(params.daq.spike_rates.decay_factor, params.daq.spike_rates.max_bins)
  248. # Set CAR, both for LFP (1kHz data) and for raw data (used for SBP)
  249. if params.daq.car_channels:
  250. ck.set_car_channels(2, params.daq.car_channels)
  251. ck.set_car_channels(6, params.daq.car_channels)
  252. update_ch_map()
  253. time.sleep(0.5)
  254. # gids = [5]
  255. # bin_width = 0.005
  256. # run_time = 1
  257. ck.send_record()
  258. offset = np.arange(0,params.daq.n_channels)*40
  259. while True:
  260. idx = global_buffer_idx.value
  261. rates = ck.get_spike_rate_data()
  262. # ts = data['ts']
  263. data = rates['rates']
  264. # print(data.shape)
  265. if data.shape[0]>0:
  266. data = data+offset
  267. if idx + data.shape[0] >= data_buffer.shape[0]:
  268. idx = 0
  269. data_buffer[idx:idx + data.shape[0], :data.shape[1]] = data
  270. global_buffer_idx.value = idx + data.shape[0]
  271. if data.shape[0] == 0:
  272. freq.value = 100
  273. else:
  274. freq.value = int(data[1,1])*2
  275. print(freq.value)
  276. # print(data.shape)
  277. # print(data[1,1])
  278. # time.sleep(0.01)
  279. # freq.value = np.random.randint(50,250)
  280. time.sleep(t_sleep)
  281. ck.send_close()