'''
description: offline analysis of kiap data
author: Ioannis Vlachos
date: 12.12.2018

Copyright (c) 2018 Ioannis Vlachos.
All rights reserved.
'''

from __future__ import division, print_function

import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as signal
from mpl_toolkits.mplot3d import Axes3D
from scipy import fftpack, stats
from sklearn.decomposition import PCA

from aux import log

def exclude_channels(data, exclude_channels):
    '''Note: ids in exclude_channels correspond to recorded data channels, which may be fewer than the maximum number of channels.'''
    channel_mask = [ii for ii in range(data[0, 0].shape[1]) if ii not in exclude_channels]
    for ii in range(data.size):
        # log.info(f'channel_mask: {channel_mask}')
        if len(channel_mask) <= data[ii, 0].shape[1]:
            data[ii, 0] = data[ii, 0][:, channel_mask]
        else:
            log.warning(f'channel mask bigger than number of channels in data. channel_mask: {len(channel_mask)}, data channels: {data[0,0].shape[1]}')

    return data

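# Illustrative usage sketch (not part of the original analysis): exclude_channels
# expects `data` as a numpy object array of trials, where each cell data[ii, 0]
# holds a (samples x channels) array. The trial, sample and channel counts below
# are made-up placeholders.
def _demo_exclude_channels():
    n_trials, n_samples, n_ch = 3, 1000, 128
    data = np.empty((n_trials, 1), dtype=object)
    for ii in range(n_trials):
        data[ii, 0] = np.random.randn(n_samples, n_ch)
    data = exclude_channels(data, [0, 5])       # drop two channels by recorded-channel index
    print(data[0, 0].shape)                     # -> (1000, 126)
    return data
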
def correct_init_baseline(data, idx1, offset):
    '''Shift each channel by its baseline peak-to-peak range (first idx1 samples) and rescale by that range plus offset.'''
    for ii in range(data.size):
        tmp = np.max(data[ii, 0][:idx1, :], axis=0) - np.min(data[ii, 0][:idx1, :], axis=0)
        # data[ii, 0][idx1:,:] = data[ii,0][idx1:,:]*(tmp+offset)+tmp
        data[ii, 0][idx1:, :] = (data[ii, 0][idx1:, :] - tmp) / (tmp + offset)
        data[ii, 0][:idx1, :] = (data[ii, 0][:idx1, :] - tmp) / (tmp + offset)
        # data[ii, 0][:idx1,:] = (data[ii,0][:idx1,:]-tmp)/(tmp+20)

    return data

def plot_channels(data, ch_ids, fs, array_msk=[], i_start=0, i_stop=-1, step=1, color='C0'):
    '''plot all data from array1 or array2 in the time domain'''
    data = stats.zscore(data)

    if ch_ids == []:
        ch_ids = range(data.shape[1])
    if i_stop == -1:
        i_stop = data.shape[0]

    var1 = 4 * np.var(data[:, ch_ids], axis=0)[:, np.newaxis]
    var1[np.isnan(var1)] = 0
    offset = np.cumsum(var1).T

    plt.figure()
    plt.plot(np.arange(i_start, i_stop) / fs, data[i_start:i_stop, ch_ids] + offset, color=color, lw=1)

    if array_msk != []:     # highlight excluded channels
        plt.plot(np.arange(i_start, i_stop) / fs, data[i_start:i_stop, array_msk] + offset[array_msk], color='C3', lw=1)

    plt.xlabel('Time (sec)')
    plt.yticks(offset[::step], range(0, len(ch_ids), step))
    plt.ylim(0, offset[-1] + 4)
    # plt.title(f'raw data from array {arr_id}')
    plt.tight_layout()

    return None

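# Illustrative usage sketch (not part of the original analysis): plots a few
# synthetic channels stacked by the variance-based offset computed above. The
# sampling rate and channel count are made-up placeholders.
def _demo_plot_channels():
    fs = 1000.
    data = np.random.randn(5000, 16)
    plot_channels(data, ch_ids=[], fs=fs, i_start=0, i_stop=2000, step=4)
    plt.show()
    return None
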
def plot_channel_pdf(data, xlim=50, ylim=5000, arr_id=1):
    d0, d1 = data.shape

    for ii in range(64):
        ax = plt.subplot(8, 8, ii + 1)
        if arr_id == 1:
            label_id = ii + 32
        else:
            if ii < 32:
                label_id = ii
            else:
                label_id = ii + 64

        plt.hist(data[:, ii], label=f'{label_id}')
        plt.xlim(0, xlim)
        plt.ylim(0, ylim)
        plt.legend()

        if ii < 56:
            # if (ii // 8) < (d1 // 8):
            ax.set_xticks([])
        if np.mod(ii, 8) != 0:
            ax.set_yticks([])

    plt.draw()
    plt.show()

    return None

def plot_pca1(clf):
    plt.figure()
    plt.clf()

    for ii in range(clf.psth[0].shape[1]):      # steps in time
        aa = clf.psth[0][:, ii, :]
        bb = clf.psth[1][:, ii, :]
        # cc = clf.psth[2][:,ii,:2]
        data = np.vstack((aa, bb))

        if len(clf.psth) == 3:
            cc = clf.psth[2][:, ii, :2]
            data = np.vstack((data, cc))

        res = PCA(n_components=2).fit_transform(data)
        plt.cla()
        plt.plot(res[:6, 0], res[:6, 1], 'C0o')
        plt.plot(res[6:17, 0], res[6:17, 1], 'C1o')
        plt.plot(res[17:, 0], res[17:, 1], 'C2o')
        plt.title(f'sample idx: {ii}')
        # plt.xlim(-100,100)
        # plt.ylim(-100,100)
        plt.draw()
        input()

    plt.show()

    return None

def plot_pca_2D(clf, ch_list, stim_id=0):
    fig = plt.figure(2)
    plt.clf()
    # ax = fig.add_subplot(111, projection='3d')

    # for ii in range(clf.psth[0].shape[1]):
    aa = clf.psth[0][stim_id, :, ch_list]
    bb = clf.psth[1][stim_id, :, ch_list]
    data = np.vstack((aa, bb))

    if len(clf.psth) == 3:
        cc = clf.psth[2][stim_id, :, ch_list]   # index the third condition consistently with aa and bb
        data = np.vstack((data, cc))

    res = PCA(n_components=2).fit_transform(data.T)
    plt.plot(clf.psth_xx, res[:, 0])
    plt.plot(clf.psth_xx, res[:, 1])
    plt.title(f'PC1, PC2 vs time, stim_id: {stim_id}')
    # plt.xlim(0,ii)
    plt.ylim(-100, 100)
    # plt.set_zlim(-100,100)
    plt.draw()
    # input()
    plt.show()

    return None

def plot_pca_3D(clf, ch_list):
    fig = plt.figure(2)
    plt.clf()
    ax = fig.add_subplot(111, projection='3d')

    for ii in range(clf.psth[0].shape[1]):
        aa = clf.psth[0][:, ii, ch_list]
        bb = clf.psth[1][:, ii, ch_list]
        data = np.vstack((aa, bb))

        if len(clf.psth) == 3:
            cc = clf.psth[2][:, ii, :2]
            data = np.vstack((data, cc))

        res = PCA(n_components=2).fit_transform(data)
        ax.scatter(res[:6, 0] * 0 + ii, res[:6, 0], res[:6, 1], color='C0')
        ax.scatter(res[6:17, 0] * 0 + ii, res[6:17, 0], res[6:17, 1], color='C1')
        ax.scatter(res[17:, 0] * 0 + ii, res[17:, 0], res[17:, 1], color='C2')
        plt.title(f'sample idx: {ii}')
        plt.xlim(0, ii)
        plt.ylim(-100, 100)
        ax.set_zlim(-100, 100)

        plt.draw()
        # input()

    plt.show()

    return None

def plot_psth(clf, ylim=[0, 10], ch_ids=[0, 1, 2, 3]):
    plt.figure(2)       # plot psth of first four channels
    plt.clf()

    try:
        ax1 = plt.subplot(221)
        tmp = clf.psth[0][:, :, 0]
        # tmp = signal.savgol_filter(tmp,51,3)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[0]].T, 'C1', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[0]].mean(0), 'C1', lw=2)
        # plt.plot(clf.psth_xx, tmp.T,'C1', lw=1)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[0]].T, 'C0', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[0]].mean(0), 'C0', lw=2)
        ax1.set_title(f'Channel {ch_ids[0]}')
        # plt.ylim(ylim)

        ax2 = plt.subplot(222)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[1]].T, 'C1', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[1]].mean(0), 'C1', lw=2)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[1]].T, 'C0', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[1]].mean(0), 'C0', lw=2)
        # plt.ylim(ylim)

        ax2 = plt.subplot(223)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[2]].T, 'C1', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[2]].mean(0), 'C1', lw=2)
        # plt.plot(clf.psth_xx, tmp.T,'C1', lw=1)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[2]].T, 'C0', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[2]].mean(0), 'C0', lw=2)
        ax2.set_title(f'Channel {ch_ids[2]}')
        # plt.ylim(ylim)

        ax2 = plt.subplot(224)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[3]].T, 'C1', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[0][:, :, ch_ids[3]].mean(0), 'C1', lw=2)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[3]].T, 'C0', lw=1, alpha=0.5)
        plt.plot(clf.psth_xx, clf.psth[1][:, :, ch_ids[3]].mean(0), 'C0', lw=2)
        # plt.ylim(ylim)
    except Exception as e:
        log.warning(e)
        log.warning('Not all channels plotted')

    plt.draw()
    plt.show()

    return None

def plot_psth2(clf, ch_ids, fig_name=''):
    psth1 = clf.psth[0]                 # trials x time x channels
    psth2 = clf.psth[1]
    psth_tot = np.vstack((psth1, psth2))

    mu1 = np.mean(psth1[:, :, ch_ids], axis=2)      # average across channels, yes
    mu2 = np.mean(psth2[:, :, ch_ids], axis=2)      # average across channels, no
    mu3 = np.mean(psth1[:, :, :], axis=0)           # average across trials, yes
    mu4 = np.mean(psth2[:, :, :], axis=0)           # average across trials, no
    mu5 = np.mean(psth_tot, axis=0)                 # stacked yes and no trials

    res1 = PCA(n_components=2).fit_transform(mu3)   # pca space of channels, yes
    res2 = PCA(n_components=2).fit_transform(mu4)   # pca space of channels, no
    res3 = PCA(n_components=2).fit_transform(mu5)   # pca space of channels, yes and no

    plt.clf()
    plt.subplot(321)
    plt.title(f'PSTH, yes, ch={ch_ids[0]}, n={psth1.shape[0]}')
    plt.plot(clf.psth_xx, mu1.T, alpha=0.5)                 # plot all trials averaged across channels - yes
    plt.plot(clf.psth_xx, mu1.mean(axis=0), 'k')            # plot mean of the above
    plt.plot(clf.psth_xx, np.median(mu1, axis=0), 'k--')    # plot median of the above
    plt.ylabel('all trials,\naverage across channels')

    plt.subplot(322)
    plt.title(f'PSTH, no, {ch_ids[0]}, n={psth2.shape[0]}')
    plt.plot(clf.psth_xx, mu2.T, alpha=0.5)                 # plot all trials averaged across channels - no
    plt.plot(clf.psth_xx, mu2.mean(axis=0), 'k')            # plot mean of the above
    plt.plot(clf.psth_xx, np.median(mu2, axis=0), 'k--')    # plot median of the above

    plt.subplot(323)
    plt.plot(clf.psth_xx, np.mean(psth1[:, :, ch_ids], axis=0), 'k', alpha=1, label='mean')
    plt.plot(clf.psth_xx, np.median(psth1[:, :, ch_ids], axis=0), 'k--', alpha=1, label='median')
    plt.plot(clf.psth_xx, np.min(psth1[:, :, ch_ids], axis=0), 'k', alpha=0.5, label='min')
    plt.plot(clf.psth_xx, np.max(psth1[:, :, ch_ids], axis=0), 'k', alpha=0.5, label='max')
    # plt.plot(clf.psth_xx, psth1[:,:,ch_ids].mean(0).mean(1),'k', label='mean channels')
    # plt.plot(clf.psth_xx, np.median(np.median(psth1[:,:,ch_ids],axis=0), axis=1),'k--', label='median')
    if len(ch_ids) > 2:
        plt.plot(clf.psth_xx, res1[:, 0], 'C3', label='PC1')
        plt.plot(clf.psth_xx, res1[:, 1], 'C4', label='PC2')
    plt.ylabel('Selected channels,\naverage across trials')
    plt.legend(loc=1, prop={'size': 6})

    plt.subplot(324)
    plt.plot(clf.psth_xx, np.mean(psth2[:, :, ch_ids], axis=0), 'k', alpha=1, label='mean')
    plt.plot(clf.psth_xx, np.median(psth2[:, :, ch_ids], axis=0), 'k--', alpha=1, label='median')
    plt.plot(clf.psth_xx, np.min(psth2[:, :, ch_ids], axis=0), 'k--', alpha=0.5, label='min')
    plt.plot(clf.psth_xx, np.max(psth2[:, :, ch_ids], axis=0), 'k--', alpha=0.5, label='max')
    # plt.plot(clf.psth_xx, psth2[:,:,ch_ids].mean(0).mean(1),'k', label='average')
    # plt.plot(clf.psth_xx, np.median(np.median(psth2[:,:,ch_ids],axis=0), axis=1),'k--', label='median')
    if len(ch_ids) > 2:
        plt.plot(clf.psth_xx, res2[:, 0], 'C3', label='PC1')
        plt.plot(clf.psth_xx, res2[:, 1], 'C4', label='PC2')
    plt.legend(loc=1, prop={'size': 6})

    plt.subplot(325)
    plt.plot(clf.psth_xx, mu1.mean(axis=0), 'C2', label='yes mean')
    plt.plot(clf.psth_xx, mu2.mean(axis=0), 'C1', label='no mean')
    plt.plot(clf.psth_xx, np.median(mu1, axis=0), 'C2--', label='yes median')
    plt.plot(clf.psth_xx, np.median(mu2, axis=0), 'C1--', label='no median')
    plt.legend(loc=1, prop={'size': 6})

    plt.subplot(326)
    plt.plot(clf.psth_xx, psth1[:, :, ch_ids].mean(0).mean(1), 'C2', label='yes mean')
    plt.plot(clf.psth_xx, psth2[:, :, ch_ids].mean(0).mean(1), 'C1', label='no mean')
    plt.legend(loc=1, prop={'size': 6})

    if fig_name != '':
        fig_name = f'{fig_name}_ch_{ch_ids[0]}'
        print(f'saving figure as {fig_name}')
        plt.savefig(fig_name)

    return None

# Code below written by Marcos Duarte (peak detection)
__author__ = "Marcos Duarte, https://github.com/demotu/BMC"
__version__ = "1.0.5"
__license__ = "MIT"

def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
                 kpsh=False, valley=False, show=False, ax=None, width=1):
    """Detect peaks in data based on their amplitude and other features.

    Parameters
    ----------
    x : 1D array_like
        data.
    mph : {None, number}, optional (default = None)
        detect peaks that are greater than minimum peak height (if parameter
        `valley` is False) or peaks that are smaller than maximum peak height
        (if parameter `valley` is True).
    mpd : positive integer, optional (default = 1)
        detect peaks that are at least separated by minimum peak distance (in
        number of data points).
    threshold : positive number, optional (default = 0)
        detect peaks (valleys) that are greater (smaller) than `threshold`
        in relation to their immediate neighbors.
    edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
        for a flat peak, keep only the rising edge ('rising'), only the
        falling edge ('falling'), both edges ('both'), or don't detect a
        flat peak (None).
    kpsh : bool, optional (default = False)
        keep peaks with same height even if they are closer than `mpd`.
    valley : bool, optional (default = False)
        if True (1), detect valleys (local minima) instead of peaks.
    show : bool, optional (default = False)
        if True (1), plot data in matplotlib figure.
    ax : a matplotlib.axes.Axes instance, optional (default = None).
    width : positive integer, optional (default = 1)
        required width of peaks in samples above or equal to mph.

    Returns
    -------
    ind : 1D array_like
        indices of the peaks in `x`.

    Notes
    -----
    The detection of valleys instead of peaks is performed internally by simply
    negating the data: `ind_valleys = detect_peaks(-x)`

    The function can handle NaN's.

    See this IPython Notebook [1]_.

    References
    ----------
    .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb

    Examples
    --------
    >>> from detect_peaks import detect_peaks
    >>> x = np.random.randn(100)
    >>> x[60:81] = np.nan
    >>> # detect all peaks and plot data
    >>> ind = detect_peaks(x, show=True)
    >>> print(ind)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # set minimum peak height = 0 and minimum peak distance = 20
    >>> detect_peaks(x, mph=0, mpd=20, show=True)

    >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
    >>> # set minimum peak distance = 2
    >>> detect_peaks(x, mpd=2, show=True)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # detection of valleys instead of peaks
    >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)

    >>> x = [0, 1, 1, 0, 1, 1, 0]
    >>> # detect both edges
    >>> detect_peaks(x, edge='both', show=True)

    >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
    >>> # set threshold = 2
    >>> detect_peaks(x, threshold = 2, show=True)

    Version history
    ---------------
    '1.0.5':
        The sign of `mph` is inverted if parameter `valley` is True
    '1.0.6':
        @Espinosa: Peak width added
    """
    x = np.atleast_1d(x).astype('float64')
    if x.size < 3:
        return np.array([], dtype=int)
    if valley:
        x = -x
        if mph is not None:
            mph = -mph

    # find indices of all peaks
    dx = x[1:] - x[:-1]

    # handle NaN's
    indnan = np.where(np.isnan(x))[0]
    if indnan.size:
        x[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf

    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ['rising', 'both']:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ['falling', 'both']:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))

    # handle NaN's
    if ind.size and indnan.size:
        # NaN's and values close to NaN's cannot be peaks
        ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)]

    # first and last values of x cannot be peaks
    if ind.size and ind[0] == 0:
        ind = ind[1:]
    if ind.size and ind[-1] == x.size - 1:
        ind = ind[:-1]

    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[x[ind] >= mph]

    # remove peaks - neighbors < threshold
    if ind.size and threshold > 0:
        dx = np.min(np.vstack([x[ind] - x[ind - 1], x[ind] - x[ind + 1]]), axis=0)
        ind = np.delete(ind, np.where(dx < threshold)[0])

    # enforce peaks above mph have minimum width (added by Espinosa)
    if ind.size and width > 1:
        # ind_in = np.ones((ind.size, 1), dtype=bool)
        ind_in = ind < x.size - width
        for index in range(np.sum(ind_in)):
            ind_in[index] = np.all(x[ind[index]:ind[index] + width] >= mph)
        ind = ind[ind_in]
        ind += width

    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        # ind = ind[np.argsort(x[ind])][::-1]  # sort ind by peak height
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # keep peaks with the same height if kpsh is True
                idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
                       & (x[ind[i]] > x[ind] if kpsh else True)
                idel[i] = 0     # Keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    if show:
        if indnan.size:
            x[indnan] = np.nan
        if valley:
            x = -x
            if mph is not None:
                mph = -mph
        _plot(x, mph, mpd, threshold, edge, valley, ax, ind)

    return ind

def _plot(x, mph, mpd, threshold, edge, valley, ax, ind):
    """Plot results of the detect_peaks function, see its help."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib is not available.')
    else:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(8, 4))

        ax.plot(x, 'b', lw=1)
        if ind.size:
            label = 'valley' if valley else 'peak'
            label = label + 's' if ind.size > 1 else label
            ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
                    label='%d %s' % (ind.size, label))
            ax.legend(loc='best', framealpha=.5, numpoints=1)
        ax.set_xlim(-.02 * x.size, x.size * 1.02 - 1)
        ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
        yrange = ymax - ymin if ymax > ymin else 1
        ax.set_ylim(ymin - 0.1 * yrange, ymax + 0.1 * yrange)
        ax.set_xlabel('Data #', fontsize=14)
        ax.set_ylabel('Amplitude', fontsize=14)
        mode = 'Valley detection' if valley else 'Peak detection'
        ax.set_title("%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"
                     % (mode, str(mph), mpd, str(threshold), edge))
        # plt.grid()
        plt.show()

# SPECTRAL METHODS

def bw_filter_coeff(fc, fs, order=5, btype='low'):
    nyq = 0.5 * fs
    fc_norm = np.array(fc) / nyq
    b, a = signal.butter(order, fc_norm, btype=btype, analog=False)
    return b, a


def bw_filter(data, fc, fs, order=5, plot=False):
    if fc[1] == 0:
        b, a = bw_filter_coeff(fc[0], fs, order=order, btype='lowpass')
    elif fc[0] == 0:
        b, a = bw_filter_coeff(fc[1], fs, order=order, btype='highpass')
    else:
        b, a = bw_filter_coeff(fc, fs, order=order, btype='bandpass')
    y = signal.filtfilt(b, a, data)

    if plot:
        plot_freq_resp(fc, fs, b, a)
    # print(f'filter coeffs: {b}\n{a}')

    return y.T, b, a

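# Illustrative usage sketch (not part of the original analysis): shows the
# (fc_low, fc_high) convention used by bw_filter above: fc[1] == 0 gives a
# low-pass at fc[0], fc[0] == 0 gives a high-pass at fc[1], anything else a
# band-pass. The sampling rate and test signal are made-up placeholders.
def _demo_bw_filter():
    fs = 1000.
    tt = np.arange(0, 2, 1. / fs)
    raw = np.sin(2 * np.pi * 10 * tt) + np.sin(2 * np.pi * 200 * tt)
    y_lp, b, a = bw_filter(raw, fc=[50, 0], fs=fs)      # low-pass, 50 Hz cutoff
    y_bp, b, a = bw_filter(raw, fc=[5, 30], fs=fs)      # band-pass, 5-30 Hz
    return y_lp, y_bp
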
def calc_fft(x, fs, i_stop=1e5, axis=-1):
    '''x: ndarray, (samples x channels)
    fs: sampling frequency
    i_stop: calculate fft up to this sample index
    axis: axis along which the fft is computed (0 for samples x channels data)
    '''
    i_stop = int(min(i_stop, len(x)))
    x = x[:i_stop, :]

    # hann = np.hanning(x.shape[0])[:,np.newaxis]
    # hann = np.repeat(hann,x.shape[1],axis=1)
    # x = x * hann
    x_fft = np.abs(fftpack.fft((x - np.mean(x, axis=axis, keepdims=True)), axis=axis) / x.shape[axis])**2
    n = x_fft.shape[axis]
    x_fft_freq = fftpack.fftfreq(n, d=1. / fs)

    x_fft = fftpack.fftshift(x_fft, axes=axis)
    x_fft_freq = fftpack.fftshift(x_fft_freq)

    return x_fft, x_fft_freq

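# Illustrative usage sketch (not part of the original analysis): computes the
# spectrum of a noisy 25 Hz sine along the time axis (axis=0). The sampling
# rate and signal length are made-up placeholders.
def _demo_calc_fft():
    fs = 1000.
    tt = np.arange(0, 2, 1. / fs)[:, np.newaxis]
    xx = np.sin(2 * np.pi * 25 * tt) + 0.1 * np.random.randn(len(tt), 1)
    x_fft, x_fft_freq = calc_fft(xx, fs, i_stop=1e5, axis=0)
    plt.plot(x_fft_freq, x_fft)
    plt.xlim(0, 100)
    plt.xlabel('Frequency (Hz)')
    plt.show()
    return x_fft, x_fft_freq
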
def plot_freq_resp(fc, fs, b, a):
    w, h = signal.freqz(b, a, worN=8000)

    plt.figure()
    plt.clf()
    # plt.subplot(311)
    plt.plot(0.5 * fs * w / np.pi, np.abs(h), 'b')
    if fc[0] > 0:
        plt.plot(fc[0], 0.5 * np.sqrt(2), 'ko')
        plt.axvline(fc[0], color='k', ls='--', alpha=0.5)
    if fc[1] > 0:
        plt.plot(fc[1], 0.5 * np.sqrt(2), 'ko')
        plt.axvline(fc[1], color='k', ls='--', alpha=0.5)
    plt.xlim(0, 0.5 * fs)
    plt.title(f'Filter Frequency Response, [{fc[0]}-{fc[1]}] Hz')
    plt.xlabel('Frequency [Hz]')
    plt.ylabel('|H(e^jw)|')
    plt.grid()

    return None