'''
description: offline analysis of kiap data
author: Ioannis Vlachos
date: 12.12.2018
Copyright (c) 2018 Ioannis Vlachos. All rights reserved.
'''

from __future__ import division, print_function

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- needed on older matplotlib for projection='3d' (plot_pca_3D)
from sklearn.decomposition import PCA
import scipy.signal as signal
from scipy import fftpack
from aux import log
import os
from scipy import stats


def exclude_channels(data, exclude_channels):
    '''Remove the listed channel columns from every trial in `data`.

    Note: ids in exclude_channels correspond to recorded data channels, which may
    be fewer than the maximum number of channels.

    Parameters
    ----------
    data : array indexed as data[ii, 0]
        Each entry is a 2D array whose columns are channels. The column mask is
        built from data[0, 0]. (presumably an object array of shape (n, 1) --
        TODO confirm against callers)
    exclude_channels : iterable of int
        Column indices to drop.

    Returns
    -------
    data, with its entries replaced in place (also returned for convenience).
    Trials with fewer columns than the mask length are left untouched and a
    warning is logged.
    '''
    excluded = set(exclude_channels)  # O(1) membership instead of scanning the list per channel
    channel_mask = [ii for ii in range(data[0, 0].shape[1]) if ii not in excluded]

    for ii in range(data.size):
        n_channels = data[ii, 0].shape[1]
        if len(channel_mask) <= n_channels:
            data[ii, 0] = data[ii, 0][:, channel_mask]
        else:
            log.warning(f'channel mask bigger than number of channels in data. channel_mask: {len(channel_mask)}, data channels: {n_channels}')
    return data


def correct_init_baseline(data, idx1, offset):
    '''Rescale every trial using the peak-to-peak range of its first idx1 rows.

    For each entry data[ii, 0] (2D; rows presumably samples, columns channels --
    TODO confirm), the per-column range r = max - min over rows [:idx1] is
    computed once, and then all rows are transformed as (x - r) / (r + offset).
    `offset` is added to the denominator (presumably to avoid dividing by ~0
    for flat channels).

    Returns data, with entries modified in place.
    '''
    for ii in range(data.size):
        initial = data[ii, 0][:idx1, :]
        tmp = np.max(initial, axis=0) - np.min(initial, axis=0)  # computed before any row is modified
        data[ii, 0][idx1:, :] = (data[ii, 0][idx1:, :] - tmp) / (tmp + offset)
        data[ii, 0][:idx1, :] = (data[ii, 0][:idx1, :] - tmp) / (tmp + offset)
    return data


def plot_channels(data, ch_ids, fs, array_msk=None, i_start=0, i_stop=-1, step=1, color='C0'):
    '''Plot z-scored channels stacked vertically in the time domain (new figure).

    Parameters
    ----------
    data : 2D ndarray, samples x channels
    ch_ids : sequence of int
        Channels to plot; an empty sequence (or None) plots all channels.
    fs : float
        Sampling frequency; the x axis is sample index / fs (seconds).
    array_msk : sequence of int, optional
        Channels overplotted in red (e.g. excluded channels). Their vertical
        offsets are looked up by these ids, i.e. as positions within ch_ids.
    i_start, i_stop : int
        Sample range to plot; i_stop=-1 means up to the last sample.
    step : int
        Stride of the y tick labels.
    color : str
        Matplotlib color of the regular traces.
    '''
    data = stats.zscore(data)
    if ch_ids is None or len(ch_ids) == 0:
        ch_ids = range(data.shape[1])
    if array_msk is None:
        array_msk = []
    if i_stop == -1:
        i_stop = data.shape[0]

    # vertical spacing: 4 * variance per channel; NaN (e.g. constant channel after z-score) -> 0
    var1 = 4 * np.var(data[:, ch_ids], axis=0)[:, np.newaxis]
    var1[np.isnan(var1)] = 0
    offset = np.cumsum(var1).T
    t_axis = np.arange(i_start, i_stop) / fs

    plt.figure()
    plt.plot(t_axis, data[i_start:i_stop, ch_ids] + offset, color=color, lw=1)
    if len(array_msk) > 0:  # highlight excluded channels
        plt.plot(t_axis, data[i_start:i_stop, array_msk] + offset[array_msk], color='C3', lw=1)
    plt.xlabel('Time (sec)')
    plt.yticks(offset[::step], range(0, len(ch_ids), step))
    plt.ylim(0, offset[-1] + 4)
    plt.tight_layout()
    return None


def plot_channel_pdf(data, xlim=50, ylim=5000, arr_id=1):
    '''Plot value histograms of up to 64 channels (columns of `data`) in an 8x8 grid.

    Legend labels: for arr_id == 1, column ii is labelled ii + 32; otherwise
    columns 0-31 keep their index and columns >= 32 are labelled ii + 64
    (presumably the recording's channel numbering -- not visible here).
    Only the bottom row keeps x ticks and only the left column keeps y ticks.
    '''
    n_plot = min(64, data.shape[1])  # grid holds 64 panels; fewer columns must not raise
    for ii in range(n_plot):
        ax = plt.subplot(8, 8, ii + 1)
        if arr_id == 1:
            label_id = ii + 32
        elif ii < 32:
            label_id = ii
        else:
            label_id = ii + 64
        plt.hist(data[:, ii], label=f'{label_id}')
        plt.xlim(0, xlim)
        plt.ylim(0, ylim)
        plt.legend()
        if ii < 56:
            ax.set_xticks([])
        if np.mod(ii, 8) != 0:
            ax.set_yticks([])
    plt.draw()
    plt.show()
    return None


def plot_pca1(clf):
    '''Step through time samples, scattering trials in a 2-component PCA space.

    For each time index ii, the rows clf.psth[k][:, ii, :] of psth[0], psth[1]
    (and psth[2] if present) are stacked and projected with PCA. Rows are colored
    by hard-coded ranges [:6] C0, [6:17] C1, [17:] C2 (assumes those trial
    counts -- TODO confirm). Waits for Enter (input()) after each frame.
    '''
    plt.figure()
    plt.clf()
    for ii in range(clf.psth[0].shape[1]):  # steps in time
        aa = clf.psth[0][:, ii, :]
        bb = clf.psth[1][:, ii, :]
        data = np.vstack((aa, bb))
        if len(clf.psth) == 3:
            # same channel selection as aa/bb so the column counts match in vstack
            cc = clf.psth[2][:, ii, :]
            data = np.vstack((data, cc))
        res = PCA(n_components=2).fit_transform(data)
        plt.cla()
        plt.plot(res[:6, 0], res[:6, 1], 'C0o')
        plt.plot(res[6:17, 0], res[6:17, 1], 'C1o')
        plt.plot(res[17:, 0], res[17:, 1], 'C2o')
        plt.title(f'sample idx: {ii}')
        plt.draw()
        input()
    plt.show()
    return None


def plot_pca_2D(clf, ch_list, stim_id=0):
    '''Plot the first two PCA components over time for one stimulus (figure 2).

    Selects clf.psth[k][stim_id, :, ch_list] from psth[0], psth[1] (and psth[2]
    if present). With numpy's mixed advanced indexing, each selection is
    (len(ch_list), n_time) when ch_list is a list. The stacked rows are treated
    as features and time samples as observations (data.T). PC1 and PC2 are
    plotted against clf.psth_xx.
    '''
    plt.figure(2)
    plt.clf()
    aa = clf.psth[0][stim_id, :, ch_list]
    bb = clf.psth[1][stim_id, :, ch_list]
    data = np.vstack((aa, bb))
    if len(clf.psth) == 3:
        # same stimulus / channel selection as aa and bb
        cc = clf.psth[2][stim_id, :, ch_list]
        data = np.vstack((data, cc))
    res = PCA(n_components=2).fit_transform(data.T)
    plt.plot(clf.psth_xx, res[:, 0])
    plt.plot(clf.psth_xx, res[:, 1])
    plt.title(f'PC1, PC2 vs time, stim_id: {stim_id}')
    plt.ylim(-100, 100)
    plt.draw()
    plt.show()
    return None


def plot_pca_3D(clf, ch_list):
    '''Scatter trials in 2-component PCA space per time sample, stacked along x in 3D (figure 2).

    For each time index ii, clf.psth[k][:, ii, ch_list] of psth[0], psth[1] (and
    psth[2] if present) are stacked and projected with PCA; points are placed at
    x = ii and colored by the hard-coded row ranges [:6], [6:17], [17:].
    '''
    fig = plt.figure(2)
    plt.clf()
    ax = fig.add_subplot(111, projection='3d')
    for ii in range(clf.psth[0].shape[1]):
        aa = clf.psth[0][:, ii, ch_list]
        bb = clf.psth[1][:, ii, ch_list]
        data = np.vstack((aa, bb))
        if len(clf.psth) == 3:
            # same channel selection as aa/bb so the column counts match in vstack
            cc = clf.psth[2][:, ii, ch_list]
            data = np.vstack((data, cc))
        res = PCA(n_components=2).fit_transform(data)
        ax.scatter(res[:6, 0] * 0 + ii, res[:6, 0], res[:6, 1], color='C0')
        ax.scatter(res[6:17, 0] * 0 + ii, res[6:17, 0], res[6:17, 1], color='C1')
        ax.scatter(res[17:, 0] * 0 + ii, res[17:, 0], res[17:, 1], color='C2')
        plt.title(f'sample idx: {ii}')
        plt.xlim(0, ii)
        plt.ylim(-100, 100)
        ax.set_zlim(-100, 100)
        plt.draw()
    plt.show()
    return None


def plot_psth(clf, ylim=(0, 10), ch_ids=(0, 1, 2, 3)):
    '''Plot single-trial and mean PSTHs of four channels in a 2x2 grid (figure 2).

    clf.psth[0] is drawn in C1 and clf.psth[1] in C0; thin lines are single
    trials, thick lines the trial mean. Indexing psth[k][:, :, ch] assumes
    trials x time x channels. Each panel is titled with its channel id.

    ylim : accepted for backward compatibility; currently not applied.
    ch_ids : the four channel indices to plot. If fewer are given or an index is
        out of range, the error is logged and the remaining panels are skipped.
    '''
    plt.figure(2)
    plt.clf()
    try:
        for panel in range(4):
            ch = ch_ids[panel]
            ax = plt.subplot(2, 2, panel + 1)
            for psth, color in ((clf.psth[0], 'C1'), (clf.psth[1], 'C0')):
                plt.plot(clf.psth_xx, psth[:, :, ch].T, color, lw=1, alpha=0.5)
                plt.plot(clf.psth_xx, psth[:, :, ch].mean(0), color, lw=2)
            ax.set_title(f'Channel {ch}')
    except Exception as e:  # best effort: keep whatever panels were drawn
        log.warning(e)
        log.warning('Not all channels plotted')
    plt.draw()
    plt.show()
    return None


def plot_psth2(clf, ch_ids, fig_name=''):
    '''Draw a 3x2 summary comparing 'yes' (clf.psth[0]) and 'no' (clf.psth[1]) PSTHs.

    Assumes each clf.psth[k] is trials x time x channels and clf.psth_xx is the
    matching time axis. Draws on the current figure (cleared first).

    Row 1: every trial averaged across ch_ids, plus mean (solid) and median (dashed).
    Row 2: mean/median/min/max across trials for each selected channel; when more
        than two channels are selected, also PC1/PC2 time courses from a PCA of
        the trial-averaged activity of ALL channels.
    Row 3: left, mean/median of the row-1 curves for yes vs no; right, the grand
        mean over trials and selected channels for yes vs no.
    If fig_name is non-empty the figure is saved as '<fig_name>_ch_<ch_ids[0]>'.
    '''
    psth1 = clf.psth[0]  # trials x tt xx n_ch
    psth2 = clf.psth[1]
    mu1 = np.mean(psth1[:, :, ch_ids], axis=2)  # average across channels, yes
    mu2 = np.mean(psth2[:, :, ch_ids], axis=2)  # average across channels, no
    show_pcs = len(ch_ids) > 2

    plt.clf()
    plt.subplot(321)
    plt.title(f'PSTH, yes, ch={ch_ids[0]}, n={psth1.shape[0]}')
    plt.plot(clf.psth_xx, mu1.T, alpha=0.5)  # all trials averaged across channels - yes
    plt.plot(clf.psth_xx, mu1.mean(axis=0), 'k')  # mean of the above
    plt.plot(clf.psth_xx, np.median(mu1, axis=0), 'k--')  # median of the above
    plt.ylabel('all trials,\naverage accross channels')

    plt.subplot(322)
    plt.title(f'PSTH, no, {ch_ids[0]}, n={psth2.shape[0]}')
    plt.plot(clf.psth_xx, mu2.T, alpha=0.5)  # all trials averaged across channels - no
    plt.plot(clf.psth_xx, mu2.mean(axis=0), 'k')
    plt.plot(clf.psth_xx, np.median(mu2, axis=0), 'k--')

    plt.subplot(323)
    sel1 = psth1[:, :, ch_ids]
    plt.plot(clf.psth_xx, np.mean(sel1, axis=0), 'k', alpha=1, label='mean')
    plt.plot(clf.psth_xx, np.median(sel1, axis=0), 'k--', alpha=1, label='median')
    plt.plot(clf.psth_xx, np.min(sel1, axis=0), 'k', alpha=0.5, label='min')
    plt.plot(clf.psth_xx, np.max(sel1, axis=0), 'k', alpha=0.5, label='max')
    if show_pcs:
        # PCA over channels of the trial-averaged 'yes' data (time x channels)
        res1 = PCA(n_components=2).fit_transform(np.mean(psth1, axis=0))
        plt.plot(clf.psth_xx, res1[:, 0], 'C3', label='PC1')
        plt.plot(clf.psth_xx, res1[:, 1], 'C4', label='PC2')
    plt.ylabel('Selected channels,\naverage accross trials')
    plt.legend(loc=1, prop={'size': 6})

    plt.subplot(324)
    sel2 = psth2[:, :, ch_ids]
    plt.plot(clf.psth_xx, np.mean(sel2, axis=0), 'k', alpha=1, label='mean')
    plt.plot(clf.psth_xx, np.median(sel2, axis=0), 'k--', alpha=1, label='median')
    plt.plot(clf.psth_xx, np.min(sel2, axis=0), 'k--', alpha=0.5, label='min')
    plt.plot(clf.psth_xx, np.max(sel2, axis=0), 'k--', alpha=0.5, label='max')
    if show_pcs:
        # PCA over channels of the trial-averaged 'no' data (time x channels)
        res2 = PCA(n_components=2).fit_transform(np.mean(psth2, axis=0))
        plt.plot(clf.psth_xx, res2[:, 0], 'C3', label='PC1')
        plt.plot(clf.psth_xx, res2[:, 1], 'C4', label='PC2')
    plt.legend(loc=1, prop={'size': 6})

    plt.subplot(325)
    plt.plot(clf.psth_xx, mu1.mean(axis=0), 'C2', label='yes mean')
    plt.plot(clf.psth_xx, mu2.mean(axis=0), 'C1', label='no mean')
    plt.plot(clf.psth_xx, np.median(mu1, axis=0), 'C2--', label='yes median')
    plt.plot(clf.psth_xx, np.median(mu2, axis=0), 'C1--', label='no median')
    plt.legend(loc=1, prop={'size': 6})

    plt.subplot(326)
    plt.plot(clf.psth_xx, psth1[:, :, ch_ids].mean(0).mean(1), 'C2', label='yes mean')
    plt.plot(clf.psth_xx, psth2[:, :, ch_ids].mean(0).mean(1), 'C1', label='no mean')
    plt.legend(loc=1, prop={'size': 6})

    if fig_name != '':
        fig_name = f'{fig_name}_ch_{ch_ids[0]}'
        print(f'saving figure as {fig_name}')
        plt.savefig(fig_name)
    return None


# CODE written by Marcos Duarte (BMC), with a peak-width extension by Espinosa
__author__ = "Marcos Duarte, https://github.com/demotu/BMC"
__version__ = "1.0.5"
__license__ = "MIT"


def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False, show=False, ax=None, width=1):
    """Detect peaks in data based on their amplitude and other features.

    Parameters
    ----------
    x : 1D array_like
        data.
    mph : {None, number}, optional (default = None)
        detect peaks that are greater than minimum peak height (if parameter
        `valley` is False) or peaks that are smaller than maximum peak height
         (if parameter `valley` is True).
    mpd : positive integer, optional (default = 1)
        detect peaks that are at least separated by minimum peak distance (in
        number of data).
    threshold : positive number, optional (default = 0)
        detect peaks (valleys) that are greater (smaller) than `threshold`
        in relation to their immediate neighbors.
    edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
        for a flat peak, keep only the rising edge ('rising'), only the
        falling edge ('falling'), both edges ('both'), or don't detect a
        flat peak (None).
    kpsh : bool, optional (default = False)
        keep peaks with same height even if they are closer than `mpd`.
    valley : bool, optional (default = False)
        if True (1), detect valleys (local minima) instead of peaks.
    show : bool, optional (default = False)
        if True (1), plot data in matplotlib figure.
    ax : a matplotlib.axes.Axes instance, optional (default = None).
    width : positive integer, optional (default = 1)
        Required width of peaks in samples above or equal to mph (no height
        requirement when mph is None). NOTE: when width > 1 the returned
        indices are shifted forward by `width` samples.

    Returns
    -------
    ind : 1D array_like
        indices of the peaks in `x`.

    Notes
    -----
    The detection of valleys instead of peaks is performed internally by simply
    negating the data: `ind_valleys = detect_peaks(-x)`

    The function can handle NaN's

    See this IPython Notebook [1]_.

    References
    ----------
    .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb

    Examples
    --------
    >>> from detect_peaks import detect_peaks
    >>> x = np.random.randn(100)
    >>> x[60:81] = np.nan
    >>> # detect all peaks and plot data
    >>> ind = detect_peaks(x, show=True)
    >>> print(ind)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # set minimum peak height = 0 and minimum peak distance = 20
    >>> detect_peaks(x, mph=0, mpd=20, show=True)

    >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
    >>> # set minimum peak distance = 2
    >>> detect_peaks(x, mpd=2, show=True)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # detection of valleys instead of peaks
    >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)

    >>> x = [0, 1, 1, 0, 1, 1, 0]
    >>> # detect both edges
    >>> detect_peaks(x, edge='both', show=True)

    >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
    >>> # set threshold = 2
    >>> detect_peaks(x, threshold = 2, show=True)

    Version history
    ---------------
    '1.0.5':
        The sign of `mph` is inverted if parameter `valley` is True
    '1.0.6': @Espinosa:
        Peak width added
    """
    x = np.atleast_1d(x).astype('float64')
    if x.size < 3:
        return np.array([], dtype=int)
    if valley:
        x = -x
        if mph is not None:
            mph = -mph
    # find indices of all peaks
    dx = x[1:] - x[:-1]
    # handle NaN's
    indnan = np.where(np.isnan(x))[0]
    if indnan.size:
        x[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf
    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ['rising', 'both']:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ['falling', 'both']:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))
    # handle NaN's
    if ind.size and indnan.size:
        # NaN's and values close to NaN's cannot be peaks
        ind = ind[np.isin(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)]
    # first and last values of x cannot be peaks
    if ind.size and ind[0] == 0:
        ind = ind[1:]
    if ind.size and ind[-1] == x.size - 1:
        ind = ind[:-1]
    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[x[ind] >= mph]
    # remove peaks - neighbors < threshold
    if ind.size and threshold > 0:
        dx = np.min(np.vstack([x[ind] - x[ind - 1], x[ind] - x[ind + 1]]), axis=0)
        ind = np.delete(ind, np.where(dx < threshold)[0])
    # enforce peaks above mph have minimum width (added by Espinosa)
    if ind.size and width > 1:
        # without mph there is no height requirement (comparing with None would raise)
        min_height = -np.inf if mph is None else mph
        # ind is sorted, so the True entries of this mask form a prefix
        ind_in = ind < x.size - width
        for index in range(np.sum(ind_in)):
            ind_in[index] = np.all(x[ind[index]:ind[index] + width] >= min_height)
        ind = ind[ind_in]
        # NOTE(review): shifts each kept peak to the sample just after its width window -- confirm intended
        ind += width
    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        # NOTE(review): upstream sorts ind by peak height here; with that sort disabled,
        # the earliest peak of each cluster is kept rather than the highest.
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # keep peaks with the same height if kpsh is True
                idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
                    & (x[ind[i]] > x[ind] if kpsh else True)
                idel[i] = 0  # Keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    if show:
        if indnan.size:
            x[indnan] = np.nan
        if valley:
            x = -x
            if mph is not None:
                mph = -mph
        _plot(x, mph, mpd, threshold, edge, valley, ax, ind)

    return ind


def _plot(x, mph, mpd, threshold, edge, valley, ax, ind):
    """Plot results of the detect_peaks function, see its help."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib is not available.')
    else:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(8, 4))

        ax.plot(x, 'b', lw=1)
        if ind.size:
            label = 'valley' if valley else 'peak'
            label = label + 's' if ind.size > 1 else label
            ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
                    label='%d %s' % (ind.size, label))
            ax.legend(loc='best', framealpha=.5, numpoints=1)
        ax.set_xlim(-.02 * x.size, x.size * 1.02 - 1)
        ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
        yrange = ymax - ymin if ymax > ymin else 1
        ax.set_ylim(ymin - 0.1 * yrange, ymax + 0.1 * yrange)
        ax.set_xlabel('Data #', fontsize=14)
        ax.set_ylabel('Amplitude', fontsize=14)
        mode = 'Valley detection' if valley else 'Peak detection'
        ax.set_title("%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"
                     % (mode, str(mph), mpd, str(threshold), edge))
        plt.show()


# SPECTRAL METHODS

def bw_filter_coeff(fc, fs, order=5, btype='low'):
    '''Return digital Butterworth (b, a) coefficients.

    fc is given in Hz (scalar, or a pair for band filters) and normalised by
    the Nyquist frequency fs / 2 before calling scipy.signal.butter.
    '''
    nyq = 0.5 * fs
    fc_norm = np.array(fc) / nyq
    b, a = signal.butter(order, fc_norm, btype=btype, analog=False)
    return b, a


def bw_filter(data, fc, fs, order=5, plot=False):
    '''Zero-phase (filtfilt) Butterworth filtering of `data` along its last axis.

    fc : pair (fc[0], fc[1]) in Hz
        fc[1] == 0 -> lowpass at fc[0]
        fc[0] == 0 -> highpass at fc[1]
        otherwise  -> bandpass [fc[0], fc[1]]
        NOTE(review): the single-cutoff cases look swapped relative to the
        bandpass convention and to plot_freq_resp's '[fc0-fc1] Hz' title
        (e.g. fc=(0, 100) yields a 100 Hz highpass). Kept as-is because
        callers may depend on it -- confirm.
    plot : if True, also draw the filter's frequency response.

    Returns (y.T, b, a): the filtered data transposed, and the coefficients.
    '''
    if fc[1] == 0:
        b, a = bw_filter_coeff(fc[0], fs, order=order, btype='lowpass')
    elif fc[0] == 0:
        b, a = bw_filter_coeff(fc[1], fs, order=order, btype='highpass')
    else:
        b, a = bw_filter_coeff(fc, fs, order=order, btype='bandpass')

    y = signal.filtfilt(b, a, data)

    if plot:
        plot_freq_resp(fc, fs, b, a)

    return y.T, b, a


def calc_fft(x, fs, i_stop=1e5, axis=-1):
    '''Mean-removed power spectrum |FFT / N|**2 with zero frequency centred.

    x: ndarray, (samples x channels)
    fs: sampling frequency
    i_stop: calculate fft up to this sample index (rows of x are truncated to
            min(i_stop, len(x)) regardless of `axis`)
    axis: axis along which the mean is removed and the FFT is taken

    Returns (x_fft, x_fft_freq), both fftshift-ed along the FFT axis.
    '''
    i_stop = int(min(i_stop, len(x)))
    x = x[:i_stop, :]

    # keepdims so the mean broadcasts back along `axis` for any axis (axis=-1 failed before)
    x_demeaned = x - np.mean(x, axis=axis, keepdims=True)
    x_fft = np.abs(fftpack.fft(x_demeaned, axis=axis) / x.shape[axis]) ** 2
    n = x_fft.shape[axis]
    x_fft_freq = fftpack.fftfreq(n, d=1. / fs)

    x_fft = fftpack.fftshift(x_fft, axes=axis)
    x_fft_freq = fftpack.fftshift(x_fft_freq)

    return x_fft, x_fft_freq


def plot_freq_resp(fc, fs, b, a):
    '''Plot the magnitude response of filter (b, a) in a new figure.

    Cutoffs fc[0] / fc[1] (when > 0) are marked with a dashed line and a dot
    at gain 1/sqrt(2) (= 0.5*sqrt(2), the -3 dB level).
    '''
    w, h = signal.freqz(b, a, worN=8000)
    plt.figure()
    plt.clf()
    plt.plot(0.5 * fs * w / np.pi, np.abs(h), 'b')
    if fc[0] > 0:
        plt.plot(fc[0], 0.5 * np.sqrt(2), 'ko')
        plt.axvline(fc[0], color='k', ls='--', alpha=0.5)
    if fc[1] > 0:
        plt.plot(fc[1], 0.5 * np.sqrt(2), 'ko')
        plt.axvline(fc[1], color='k', ls='--', alpha=0.5)
    plt.xlim(0, 0.5 * fs)
    plt.title(f'Filter Frequency Response, [{fc[0]}-{fc[1]}] Hz')
    plt.xlabel('Frequency [Hz]')
    plt.ylabel('H(e^jw)')
    plt.grid()
    return None