sbp_analysis1.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. '''
  2. description: offline analysis of kiap data
  3. author: Ioannis Vlachos
  4. date: 02.04.2019
  5. Copyright (c) 2019 Ioannis Vlachos.
  6. All rights reserved.'''
  7. import re
  8. import os
  9. import glob
  10. import importlib
  11. import urllib
  12. import time
  13. import numpy as np
  14. import matplotlib.pyplot as plt
  15. import matplotlib as mpl
  16. from scipy import signal
  17. from scipy import stats
  18. from sklearn.decomposition.pca import PCA
  19. from cachetools import cached, LRUCache, TTLCache
  20. from neo.io.blackrockio import BlackrockIO
  21. import aux
  22. from aux import log
  23. from analytics import analytics1
  24. importlib.reload(analytics1)
  25. # @cached(cache={})
  26. def read_data():
  27. '''returns first array1 and then array2 channels'''
  28. data = reader.get_analogsignal_chunk(i_stop=eval(str(params.lfp.i_stop)),
  29. channel_indexes=list(set(params.lfp.array1) | set(params.lfp.array2)))
  30. return data
  31. def get_triggers(n_triggers=10, verbose=True):
  32. try:
  33. triggers_ts = [tt[0] for tt in reader.nev_data['Comments'][0]]
  34. triggers_txt = [tt[5].decode('utf-8') for tt in reader.nev_data['Comments'][0]]
  35. except:
  36. log.error('Triggers could not be loaded. Does Nev file exist?')
  37. return None, None
  38. if verbose:
  39. print(reader.nev_data['Comments'])
  40. triggers_yes = []
  41. triggers_no = []
  42. for ii in range(len(triggers_txt)):
  43. if 'yes, response, start' in triggers_txt[ii]:
  44. triggers_yes.append(triggers_ts[ii])
  45. elif 'no, response, start' in triggers_txt[ii]:
  46. triggers_no.append(triggers_ts[ii])
  47. triggers_yes = np.array(triggers_yes)
  48. triggers_no = np.array(triggers_no)
  49. # get timestamps for ns2
  50. triggers_yes_ns2 = triggers_yes // params.lfp.sampling_ratio
  51. triggers_no_ns2 = triggers_no // params.lfp.sampling_ratio
  52. return triggers_no_ns2[:n_triggers], triggers_yes_ns2[:n_triggers]
  53. def get_valid_trials(triggers_no_ns2, triggers_yes_ns2):
  54. trials_msk = np.ones((10, 2), dtype=bool)
  55. for ii, tt in enumerate(triggers_no_ns2):
  56. tmp = data1[tt + psth_win[0]:tt + psth_win[1], :]
  57. if (np.any(tmp > params.lfp.artifact_thr) or np.any(tmp < -params.lfp.artifact_thr)):
  58. trials_msk[ii, 0] = False
  59. for ii, tt in enumerate(triggers_yes_ns2):
  60. tmp = data1[tt + psth_win[0]:tt + psth_win[1], :]
  61. if (np.any(tmp > params.lfp.artifact_thr) or np.any(tmp < -params.lfp.artifact_thr)):
  62. trials_msk[ii, 1] = False
  63. print(ii)
  64. print('\nTrials excluded', np.where(trials_msk[:, 0] == False)[0], np.where(trials_msk[:, 1] == False)[0])
  65. return trials_msk
  66. def calc_cc():
  67. '''calculate time resolved correlation coefficients betweech channels of each array'''
  68. cc_step = 5000
  69. rng = range(0, data1.shape[0], cc_step)
  70. cc1 = np.zeros((len(rng), 3))
  71. cc2 = np.zeros((len(rng), 3))
  72. hist1 = np.zeros((len(rng), 100))
  73. hist2 = np.zeros((len(rng), 100))
  74. for ii, idx in enumerate(rng):
  75. # cc1[ii, 0] = np.min(np.abs(np.triu(np.corrcoef(data01[idx:idx + cc_step, :].T), k=1)))
  76. # cc1[ii, 1] = np.max(np.abs(np.triu(np.corrcoef(data01[idx:idx + cc_step, :].T), k=1)))
  77. # cc1[ii, 2] = np.mean(np.abs(np.triu(np.corrcoef(data01[idx:idx + cc_step, :].T), k=1)))
  78. cc1[ii, 0] = np.min(np.triu(np.corrcoef(data01[idx:idx + cc_step, :].T), k=1))
  79. cc1[ii, 1] = np.max(np.triu(np.corrcoef(data01[idx:idx + cc_step, :].T), k=1))
  80. cc1[ii, 2] = np.mean(np.triu(np.corrcoef(data01[idx:idx + cc_step, :].T), k=1))
  81. # hist1[ii, :] = np.histogram(np.triu(np.corrcoef(data0[idx:idx + cc_step, array1].T)).flatten(), 100)[1][:-1]
  82. # cc2[ii, 0] = np.min(np.abs(np.triu(np.corrcoef(data02[idx:idx + cc_step, :].T), k=1)))
  83. # cc2[ii, 1] = np.max(np.abs(np.triu(np.corrcoef(data02[idx:idx + cc_step, :].T), k=1)))
  84. # cc2[ii, 2] = np.mean(np.abs(np.triu(np.corrcoef(data02[idx:idx + cc_step, :].T), k=1)))
  85. cc2[ii, 0] = np.min(np.triu(np.corrcoef(data02[idx:idx + cc_step, :].T), k=1))
  86. cc2[ii, 1] = np.max(np.triu(np.corrcoef(data02[idx:idx + cc_step, :].T), k=1))
  87. cc2[ii, 2] = np.mean(np.triu(np.corrcoef(data02[idx:idx + cc_step, :].T), k=1))
  88. # hist2[ii, :] = np.histogram(np.triu(np.corrcoef(data0[idx:idx + cc_step, array2].T)), 100)[1][:-1]
  89. # print(ii)
  90. if params.lfp.plot.general:
  91. plt.figure()
  92. plt.plot(cc1)
  93. plt.gca().set_prop_cycle(None)
  94. plt.plot(cc2, '-.')
  95. plt.show()
  96. return None
  97. def compute_psth(triggers_yes_ns2, triggers_no_ns2):
  98. psth_len = np.diff(psth_win)[0]
  99. psth_no = np.zeros((len(triggers_no_ns2), psth_len, ch_nr1))
  100. psth_yes = np.zeros((len(triggers_yes_ns2), psth_len, ch_nr1))
  101. for ii in range(len(triggers_no_ns2)):
  102. if (triggers_no_ns2[ii] + psth_win[0] > 0) & (triggers_no_ns2[ii] + psth_win[1] < len(data2)):
  103. tmp = data01[(triggers_no_ns2[ii] + psth_win[0]):(triggers_no_ns2[ii] + psth_win[1]), :]
  104. # if not (np.any(tmp > params.lfp.artifact_thr) or np.any(tmp < -params.lfp.artifact_thr)):
  105. psth_no[ii, :, :] = tmp
  106. # else:
  107. # print(f'Excluded no-trial: {ii}')
  108. for ii in range(len(triggers_yes_ns2)):
  109. if (triggers_yes_ns2[ii] + psth_win[0] > 0) & (triggers_yes_ns2[ii] + psth_win[1] < len(data2)):
  110. tmp = data01[(triggers_yes_ns2[ii] + psth_win[0]):(triggers_yes_ns2[ii] + psth_win[1]), :]
  111. # if not (np.any(tmp > params.lfp.artifact_thr) or np.any(tmp < -params.lfp.artifact_thr)):
  112. psth_yes[ii, :, :] = tmp
  113. # else:
  114. # print(f'Excluded yes-trial: {ii}')
  115. psth_yes, psth_no = remove_trials(psth_yes, psth_no)
  116. return psth_no, psth_yes
  117. def remove_trials(psth_yes, psth_no):
  118. var1 = np.var(psth_yes, axis=1)
  119. var2 = np.var(psth_no, axis=1)
  120. L1 = np.percentile(var1, 5, axis=0, keepdims=True)
  121. L2 = np.percentile(var2, 5, axis=0, keepdims=True)
  122. U1 = np.percentile(var1, 95, axis=0, keepdims=True)
  123. U2 = np.percentile(var2, 95, axis=0, keepdims=True)
  124. w = 10
  125. valid1 = [L1 - w * (U1 - L1), U1 + w * (U1 - L1)]
  126. valid2 = [L2 - w * (U2 - L2), U2 + w * (U2 - L2)]
  127. # array1_msk = np.unique((np.where(var1 < valid1[0])[0] or np.where(var1 > valid1[1])[0]))
  128. # array2_msk = np.unique((np.where(var2 < valid2[0])[0] or np.where(var2 > valid2[1])[0]))
  129. array1_msk = ((var1 < valid1[0]) | (var1 > valid1[1])).nonzero()[0]
  130. array2_msk = ((var2 < valid2[0]) | (var2 > valid2[1])).nonzero()[0]
  131. print(f'excluded yes-trials: {array1_msk}')
  132. print(f'excluded no-trials: {array2_msk}')
  133. psth_yes = np.delete(psth_yes, array1_msk, axis=0)
  134. psth_no = np.delete(psth_no, array2_msk, axis=0)
  135. return psth_yes, psth_no
  136. def plot_channels(data, ch_ids, fs, i_start=0, i_stop=10000, step=5, color='C0'):
  137. if ch_ids == []:
  138. ch_ids = range(data.shape[1])
  139. if i_stop == -1:
  140. i_stop = data.shape[0]
  141. # data = data
  142. offset = np.cumsum(4 * np.var(data[:, ch_ids], axis=0)[:, np.newaxis]).T
  143. plt.figure()
  144. plt.plot(np.arange(i_start, i_stop) / fs, data[i_start:i_stop, ch_ids] + offset, color=color, lw=1)
  145. plt.xlabel('Time (sec)')
  146. plt.yticks(offset[::step], range(1, len(ch_ids), step))
  147. plt.ylim(0, offset[-1] + 4)
  148. plt.tight_layout()
  149. return None
  150. def plot_psth(ff=[], ch_id=0, trial_id=0, step=10):
  151. yy = data01.std()
  152. # plot psth for visual inspection
  153. xx1 = np.arange(data1.shape[0]) / params.lfp.fs
  154. xx2 = np.arange(psth_win[0], psth_win[1]) / params.lfp.fs
  155. f_idx = ff < 150
  156. plt.figure(2, figsize=(10, 10))
  157. plt.clf()
  158. plt.subplot(311)
  159. plt.plot(xx1[::step], data01[::step, 0])
  160. plt.plot(xx1[::step], data1[::step, :, 2], 'C1', alpha=0.2) # low band
  161. plt.stem(triggers_yes_ns2 / params.lfp.fs, triggers_yes_ns2 * 0 + yy * 0.8, markerfmt='C2o') # green
  162. plt.stem(triggers_no_ns2 / params.lfp.fs, triggers_no_ns2 * 0 + yy * 0.8, markerfmt='C3o') # red
  163. plt.ylim(-yy, yy)
  164. plt.subplot(323)
  165. plt.plot(xx2, psth_no.mean(0)[:, params.lfp.plot.ch_ids]) # average accross trials, 1 channel
  166. plt.plot(xx2, psth_no.mean(0).mean(1), 'C3', alpha=0.5)
  167. plt.subplot(324)
  168. plt.plot(xx2, psth_yes.mean(0)[:, params.lfp.plot.ch_ids])
  169. plt.plot(xx2, psth_yes.mean(0).mean(1), 'C3', alpha=0.5)
  170. plt.subplot(325)
  171. # plt.pcolormesh(tt, ff[f_idx], Sxx1[f_idx, ch_id, :])
  172. if params.lfp.normalize:
  173. plt.pcolormesh(tt, ff[f_idx], Sxx1[trial_id, f_idx, ch_id, :] / Sxx1_norm[f_idx, :, ch_id])
  174. else:
  175. plt.pcolormesh(tt, ff[f_idx], Sxx1[trial_id, f_idx, ch_id, :])
  176. vmax = max(Sxx1[trial_id, f_idx, ch_id, :].max(), Sxx2[trial_id, f_idx, ch_id, :].max())
  177. plt.clim(vmax=vmax)
  178. plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
  179. plt.subplot(326)
  180. if params.lfp.normalize:
  181. plt.pcolormesh(tt, ff[f_idx], Sxx2[trial_id, f_idx, ch_id, :] / Sxx2_norm[f_idx, :, ch_id])
  182. else:
  183. plt.pcolormesh(tt, ff[f_idx], Sxx2[trial_id, f_idx, ch_id, :])
  184. plt.clim(vmax=vmax)
  185. plt.colorbar(orientation='horizontal', fraction=0.05, aspect=50)
  186. plt.tight_layout()
  187. plt.show()
  188. return None
  189. params = aux.load_config()
  190. params.lfp.array1 = np.delete(params.lfp.array1, params.lfp.array1_exclude)
  191. params.lfp.array2 = np.delete(params.lfp.array2, params.lfp.array2_exclude)
  192. ch_nr1 = len(params.lfp.array1)
  193. ch_nr2 = len(params.lfp.array2)
  194. psth_win = params.lfp.psth_win
  195. # file_name = '/home/vlachos/data_kiap/20190321-135311/20190321-135311-001'
  196. # path = '/home/kiap/Documents/kiap/k01/20190414-123429/'
  197. path = '/media/vlachos/kiap_backup/Recordings/K01/Recordings/20190415-145905/'
  198. # file_name = '/media/vlachos/kiap_backup/Recordings/K01/bci_sessions/20190321-140130/20190321-140130-003'
  199. # file_name = path + '20190414-123429-001.ns2'
  200. file_name = path + '20190415-145905-004'
  201. if not 'data0' in locals():
  202. reader = BlackrockIO(filename=file_name, nsx_to_load=6)
  203. data0 = read_data() # (samples, channels)
  204. data0 = data0.copy().astype(np.float64)
  205. triggers_no_ns2, triggers_yes_ns2 = get_triggers(n_triggers=20)
  206. # trials_msk = get_valid_trials(triggers_no_ns2, triggers_yes_ns2)
  207. array1 = params.lfp.array1
  208. array2 = params.lfp.array2
  209. # data1[data1 > params.lfp.artifact_thr] = 0
  210. # data1[data1 < -params.lfp.artifact_thr] = 0
  211. # data0 = data0 - np.mean(data0, 1)[:, np.newaxis]
  212. # reader.nev_data['Comments'][0]
  213. data0 = stats.zscore(data0)
  214. data01 = data0[:, :len(array1)]
  215. data02 = data0[:, len(array1):len(array1) + len(array2)]
  216. data01 = signal.decimate(data01, 2, axis=0) # downsample to 15 kHz
  217. data01, _, _ = analytics1.bw_filter(data0.T, [250, 5000], fs=15000) # bandpass
  218. nperseg = 10000
  219. ff, tt, Sxx1 = signal.spectrogram(data01, 15000, axis=0, nperseg=nperseg, noverlap=int(nperseg * 0.5))
  220. f_idx = (ff > 0) & (ff < 5000)
  221. width = 10000
  222. rms = np.empty((0, data01.shape[1]))
  223. for ii in range(0, data01.shape[0], width):
  224. # rms = np.vstack((rms, np.var(data01[ii: ii + width], axis=0)))
  225. rms = np.vstack((rms, np.sqrt(np.mean(data01[ii: ii + width], axis=0)**2)))
  226. print(ii, ii + width)
  227. res = PCA(n_components=3).fit_transform(rms)
  228. plt.clf()
  229. # plt.pcolormesh(tt, ff[f_idx], np.log10(Sxx1[f_idx, 0]))
  230. # plt.plot(tt, Sxx1[:, 3, :].sum(axis=0))
  231. # plt.plot(rms)
  232. plt.plot(res[:, 0], res[:, 1], '.')
  233. # data1 = signal.resample_poly(data0, 1, 2, axis=0)
  234. # y1, ff1 = analytics1.calc_fft(data1, fs=30000, axis=0)
  235. # y11, ff11 = analytics1.calc_fft(data11, fs=15000, axis=0)
  236. # f_idx1 = (ff1 > 0) & (ff1 < 10000)
  237. # f_idx11 = (ff11 > 0) & (ff11 < 10000)
  238. # plt.clf()
  239. # # plt.plot(ff1[f_idx1], y1[f_idx1, 0])
  240. # plt.plot(ff11[f_idx11], y11[f_idx11, 0])
  241. # COMPUTE PSTH
  242. # ----------------------------------
  243. # offset = 600 * params.lfp.fs
  244. # psth_no, psth_yes = compute_psth(triggers_yes_ns2, triggers_no_ns2)
  245. # psth_norm1, psth_norm2 = compute_psth(triggers_yes_ns2 + offset, triggers_no_ns2 + offset)
  246. # # PSTH SPECTROGRAMS
  247. # # ----------------------------------
  248. # nperseg = params.lfp.spectra.spgr_len
  249. # nperseg = min(nperseg, psth_no.shape[1])
  250. # # ff, tt, Sxx1 = signal.spectrogram(psth_no.mean(0), params.lfp.fs, axis=0, nperseg=nperseg, noverlap=int(nperseg * 0.9))
  251. # # ff, tt, Sxx2 = signal.spectrogram(psth_yes.mean(0), params.lfp.fs, axis=0, nperseg=nperseg, noverlap=int(nperseg * 0.9))
  252. # ff, tt, Sxx1 = signal.spectrogram(psth_no, params.lfp.fs, axis=1, nperseg=nperseg, noverlap=int(nperseg * 0.9))
  253. # ff, tt, Sxx2 = signal.spectrogram(psth_yes, params.lfp.fs, axis=1, nperseg=nperseg, noverlap=int(nperseg * 0.9))
  254. # ff, tt, Sxx1_norm = signal.spectrogram(psth_norm1.mean(0), params.lfp.fs, axis=0, nperseg=nperseg, noverlap=int(nperseg * 0.9))
  255. # ff, tt, Sxx2_norm = signal.spectrogram(psth_norm2.mean(0), params.lfp.fs, axis=0, nperseg=nperseg, noverlap=int(nperseg * 0.9))
  256. # # Sxx1 = Sxx1.mean(1)
  257. # # Sxx2 = Sxx2.mean(1)
  258. # Sxx1_norm = Sxx1_norm.mean(2)
  259. # Sxx2_norm = Sxx2_norm.mean(2)
  260. # tt = tt + psth_win[0] / params.lfp.fs
  261. # t_idx = tt < 1
  262. # Sxx1_norm = Sxx1_norm[:, np.newaxis].repeat(len(tt), 1)
  263. # Sxx2_norm = Sxx2_norm[:, np.newaxis].repeat(len(tt), 1)
  264. # # if params.lfp.normalize:
  265. # # Sxx1 = Sxx1 / Sxx1_norm
  266. # # Sxx2 = Sxx2 / Sxx2_norm
  267. # # PLOT RESULTS
  268. # # ----------------------------------
  269. # if params.lfp.plot.general:
  270. # # plot_spectra()
  271. # plot_psth(ff=ff, ch_id=0, trial_id=0, step=100)
  272. # pass
  273. # # xxx
  274. # plt.ioff()
  275. # for trial_id in range(4, 5):
  276. # for ch_id in range(ch_nr1):
  277. # plot_psth(ff=ff, ch_id=ch_id, trial_id=0, step=100)
  278. # fig2 = plt.gcf()
  279. # fname = path + f'results/trial_{trial_id}_ch_{ch_id}.png'
  280. # fig2.savefig(fname)
  281. # print(ch_id)