analytics1.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. '''
  2. description: offline analysis of kiap data
  3. author: Ioannis Vlachos
  4. date: 12.12.2018
  5. Copyright (c) 2018 Ioannis Vlachos.
  6. All rights reserved.'''
  7. from __future__ import division, print_function
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. from mpl_toolkits.mplot3d import Axes3D
  11. from sklearn.decomposition.pca import PCA
  12. import scipy.signal as signal
  13. from scipy import fftpack
  14. from aux import log
  15. import os
  16. from scipy import stats
  17. def exclude_channels(data, exclude_channels):
  18. '''Note: ids in exclude channels correspond to recorded data channels, which may be less then max channels'''
  19. channel_mask = [ii for ii in range(data[0,0].shape[1]) if ii not in exclude_channels]
  20. for ii in range(data.size):
  21. # log.info(f'channel_mask: {channel_mask}')
  22. if len(channel_mask) <= data[ii,0].shape[1]:
  23. data[ii, 0] = data[ii, 0][:,channel_mask]
  24. else:
  25. log.warning(f'channel mask bigger than number of channels in data. channel_mask: {len(channel_mask)}, data channels: {data[0,0].shape[1]}')
  26. return data
  27. def correct_init_baseline(data, idx1, offset):
  28. for ii in range(data.size):
  29. tmp = np.max(data[ii,0][:idx1,:],axis=0)-np.min(data[ii,0][:idx1,:],axis=0)
  30. # data[ii, 0][idx1:,:] = data[ii,0][idx1:,:]*(tmp+offset)+tmp
  31. data[ii, 0][idx1:,:] = (data[ii,0][idx1:,:]-tmp)/(tmp+offset)
  32. data[ii, 0][:idx1,:] = (data[ii,0][:idx1,:]-tmp)/(tmp+offset)
  33. # data[ii, 0][:idx1,:] = (data[ii,0][:idx1,:]-tmp)/(tmp+20)
  34. return data
  35. def plot_channels(data, ch_ids, fs, array_msk = [], i_start=0, i_stop=-1, step=1, color='C0'):
  36. '''plot all data from array1 or array2 in the time domain'''
  37. data = stats.zscore(data)
  38. if ch_ids == []:
  39. ch_ids = range(data.shape[1])
  40. if i_stop == -1:
  41. i_stop = data.shape[0]
  42. var1 = 4 * np.var(data[:, ch_ids], axis=0)[:, np.newaxis]
  43. var1[np.isnan(var1)] = 0
  44. offset = np.cumsum(var1).T
  45. plt.figure()
  46. plt.plot(np.arange(i_start, i_stop) / fs, data[i_start:i_stop, ch_ids] + offset, color=color, lw=1)
  47. if array_msk != []: # highlight excluded channels
  48. plt.plot(np.arange(i_start, i_stop) / fs, data[i_start:i_stop, array_msk] + offset[array_msk], color='C3', lw=1)
  49. plt.xlabel('Time (sec)')
  50. plt.yticks(offset[::step], range(0, len(ch_ids), step))
  51. plt.ylim(0, offset[-1] + 4)
  52. # plt.title(f'raw data from array {arr_id}')
  53. plt.tight_layout()
  54. return None
  55. def plot_channel_pdf(data, xlim=50, ylim=5000, arr_id=1):
  56. d0,d1 = data.shape
  57. for ii in range(64):
  58. ax = plt.subplot(8,8,ii+1)
  59. if arr_id == 1:
  60. label_id = ii + 32
  61. else:
  62. if ii < 32:
  63. label_id = ii
  64. else:
  65. label_id = ii + 64
  66. plt.hist(data[:,ii], label=f'{label_id}')
  67. plt.xlim(0, xlim)
  68. plt.ylim(0, ylim)
  69. plt.legend()
  70. if ii<56:
  71. # if (ii // 8) < (d1 // 8):
  72. ax.set_xticks([])
  73. if np.mod(ii, 8) != 0:
  74. ax.set_yticks([])
  75. plt.draw()
  76. plt.show()
  77. return None
  78. def plot_pca1(clf):
  79. plt.figure()
  80. plt.clf()
  81. for ii in range(clf.psth[0].shape[1]): # steps in time
  82. aa = clf.psth[0][:,ii,:]
  83. bb = clf.psth[1][:,ii,:]
  84. # cc = clf.psth[2][:,ii,:2]
  85. data = np.vstack((aa,bb))
  86. if len(clf.psth) == 3:
  87. cc = clf.psth[2][:,ii,:2]
  88. data = np.vstack((data,cc))
  89. res = PCA(n_components=2).fit_transform(data)
  90. plt.cla()
  91. plt.plot(res[:6,0], res[:6,1],'C0o')
  92. plt.plot(res[6:17,0], res[6:17,1],'C1o')
  93. plt.plot(res[17:,0], res[17:,1],'C2o')
  94. plt.title(f'sample idx: {ii}')
  95. # plt.xlim(-100,100)
  96. # plt.ylim(-100,100)
  97. plt.draw()
  98. input()
  99. plt.show()
  100. return None
  101. def plot_pca_2D(clf, ch_list, stim_id=0):
  102. fig = plt.figure(2)
  103. plt.clf()
  104. # ax = fig.add_subplot(111, projection='3d')
  105. # for ii in range(clf.psth[0].shape[1]):
  106. aa = clf.psth[0][stim_id,:,ch_list]
  107. bb = clf.psth[1][stim_id,:,ch_list]
  108. data = np.vstack((aa,bb))
  109. if len(clf.psth) == 3:
  110. cc = clf.psth[2][:,ii,:2]
  111. data = np.vstack((data,cc))
  112. res = PCA(n_components=2).fit_transform(data.T)
  113. plt.plot(clf.psth_xx, res[:,0])
  114. plt.plot(clf.psth_xx, res[:,1])
  115. plt.title(f'PC1, PC2 vs time, stim_id: {stim_id}')
  116. # plt.xlim(0,ii)
  117. plt.ylim(-100,100)
  118. # plt.set_zlim(-100,100)
  119. plt.draw()
  120. # input()
  121. plt.show()
  122. return None
  123. def plot_pca_3D(clf, ch_list):
  124. fig = plt.figure(2)
  125. plt.clf()
  126. ax = fig.add_subplot(111, projection='3d')
  127. for ii in range(clf.psth[0].shape[1]):
  128. aa = clf.psth[0][:,ii,ch_list]
  129. bb = clf.psth[1][:,ii,ch_list]
  130. data = np.vstack((aa,bb))
  131. if len(clf.psth) == 3:
  132. cc = clf.psth[2][:,ii,:2]
  133. data = np.vstack((data,cc))
  134. res = PCA(n_components=2).fit_transform(data)
  135. ax.scatter(res[:6,0]*0+ii, res[:6,0], res[:6,1],color = 'C0')
  136. ax.scatter(res[6:17,0]*0+ii, res[6:17,0], res[6:17,1],color = 'C1')
  137. ax.scatter(res[17:,0]*0+ii, res[17:,0], res[17:,1],color = 'C2')
  138. plt.title(f'sample idx: {ii}')
  139. plt.xlim(0,ii)
  140. plt.ylim(-100,100)
  141. ax.set_zlim(-100,100)
  142. plt.draw()
  143. # input()
  144. plt.show()
  145. return None
  146. def plot_psth(clf, ylim=[0,10], ch_ids=[0,1,2,3]):
  147. plt.figure(2) # plot psth of first four channels
  148. plt.clf()
  149. try:
  150. ax1 = plt.subplot(221)
  151. tmp = clf.psth[0][:,:,0]
  152. # tmp = signal.savgol_filter(tmp,51,3)
  153. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[0]].T,'C1', lw=1, alpha=0.5)
  154. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[0]].mean(0),'C1', lw=2)
  155. # plt.plot(clf.psth_xx, tmp.T,'C1', lw=1)
  156. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[0]].T,'C0', lw=1, alpha=0.5)
  157. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[0]].mean(0),'C0', lw=2)
  158. ax1.set_title('Channel 0')
  159. # plt.ylim(ylim)
  160. ax2 = plt.subplot(222)
  161. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[1]].T,'C1', lw=1, alpha=0.5)
  162. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[1]].mean(0),'C1', lw=2)
  163. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[1]].T,'C0', lw=1, alpha=0.5)
  164. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[1]].mean(0),'C0', lw=2)
  165. # plt.ylim(ylim)
  166. ax2 = plt.subplot(223)
  167. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[2]].T,'C1', lw=1, alpha=0.5)
  168. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[2]].mean(0),'C1', lw=2)
  169. # plt.plot(clf.psth_xx, tmp.T,'C1', lw=1)
  170. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[2]].T,'C0', lw=1, alpha=0.5)
  171. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[2]].mean(0),'C0', lw=2)
  172. ax1.set_title('Channel 0')
  173. # plt.ylim(ylim)
  174. ax2 = plt.subplot(224)
  175. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[3]].T,'C1', lw=1, alpha=0.5)
  176. plt.plot(clf.psth_xx, clf.psth[0][:,:,ch_ids[3]].mean(0),'C1', lw=2)
  177. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[3]].T,'C0', lw=1, alpha=0.5)
  178. plt.plot(clf.psth_xx, clf.psth[1][:,:,ch_ids[3]].mean(0),'C0', lw=2)
  179. # plt.ylim(ylim)
  180. except Exception as e:
  181. log.warning(e)
  182. log.warning('Not all channels plotted')
  183. plt.draw()
  184. plt.show()
  185. return None
  186. def plot_psth2(clf, ch_ids, fig_name=''):
  187. psth1 = clf.psth[0] # trials x tt xx n_ch
  188. psth2 = clf.psth[1]
  189. psth_tot = np.vstack((psth1,psth2))
  190. mu1 = np.mean(psth1[:,:,ch_ids],axis=2) # average accross channels yes
  191. mu2 = np.mean(psth2[:,:,ch_ids],axis=2) # average accross channels no
  192. mu3 = np.mean(psth1[:,:,:],axis=0) # average accross trials yes
  193. mu4 = np.mean(psth2[:,:,:],axis=0) # average accross trials no
  194. mu5 = np.mean(psth_tot, axis=0) # stack yes and no trials
  195. res1 = PCA(n_components=2).fit_transform(mu3) # pca space of channels yes
  196. res2 = PCA(n_components=2).fit_transform(mu4) # pca space of channels no
  197. res3 = PCA(n_components=2).fit_transform(mu5) # pca space of channels yes and no
  198. plt.clf()
  199. plt.subplot(321)
  200. plt.title(f'PSTH, yes, ch={ch_ids[0]}, n={psth1.shape[0]}')
  201. plt.plot(clf.psth_xx, mu1.T, alpha=0.5) # plot all trials averaged accross channels - yes
  202. plt.plot(clf.psth_xx, mu1.mean(axis=0),'k') # plot mean of above
  203. plt.plot(clf.psth_xx, np.median(mu1, axis=0),'k--') # plot mean of above
  204. plt.ylabel('all trials,\naverage accross channels')
  205. plt.subplot(322)
  206. plt.title(f'PSTH, no, {ch_ids[0]}, n={psth2.shape[0]}')
  207. plt.plot(clf.psth_xx, mu2.T, alpha=0.5) # plot all trials average accross channels - no
  208. plt.plot(clf.psth_xx, mu2.mean(axis=0), 'k') # plot mean of above
  209. plt.plot(clf.psth_xx, np.median(mu2, axis=0),'k--') # plot mean of above
  210. plt.subplot(323)
  211. plt.plot(clf.psth_xx, np.mean(psth1[:,:,ch_ids], axis=0), 'k', alpha=1, label='mean')
  212. plt.plot(clf.psth_xx, np.median(psth1[:,:,ch_ids], axis=0), 'k--', alpha=1, label='median')
  213. plt.plot(clf.psth_xx, np.min(psth1[:,:,ch_ids], axis=0), 'k', alpha=0.5, label='min')
  214. plt.plot(clf.psth_xx, np.max(psth1[:,:,ch_ids], axis=0), 'k', alpha=0.5, label='max')
  215. # plt.plot(clf.psth_xx, psth1[:,:,ch_ids].mean(0).mean(1),'k', label='mean channels')
  216. # plt.plot(clf.psth_xx, np.median(np.median(psth1[:,:,ch_ids],axis=0), axis=1),'k--', label='median')
  217. if len(ch_ids)>2:
  218. plt.plot(clf.psth_xx,res1[:,0],'C3', label='PC1')
  219. plt.plot(clf.psth_xx,res1[:,1],'C4', label='PC2')
  220. plt.ylabel('Selected channels,\naverage accross trials')
  221. plt.legend(loc=1, prop={'size': 6})
  222. plt.subplot(324)
  223. plt.plot(clf.psth_xx, np.mean(psth2[:,:,ch_ids], axis=0), 'k', alpha=1, label='mean')
  224. plt.plot(clf.psth_xx, np.median(psth2[:,:,ch_ids], axis=0), 'k--', alpha=1, label='median')
  225. plt.plot(clf.psth_xx, np.min(psth2[:,:,ch_ids], axis=0), 'k--', alpha=0.5, label='min')
  226. plt.plot(clf.psth_xx, np.max(psth2[:,:,ch_ids], axis=0), 'k--', alpha=0.5, label='max')
  227. # plt.plot(clf.psth_xx, psth2[:,:,ch_ids].mean(0).mean(1),'k', label='average')
  228. # plt.plot(clf.psth_xx, np.median(np.median(psth2[:,:,ch_ids],axis=0), axis=1),'k--', label='median')
  229. if len(ch_ids)>2:
  230. plt.plot(clf.psth_xx,res2[:,0],'C3', label='PC1')
  231. plt.plot(clf.psth_xx,res2[:,1],'C4', label='PC2')
  232. plt.legend(loc=1, prop={'size': 6})
  233. plt.subplot(325)
  234. plt.plot(clf.psth_xx, mu1.mean(axis=0), 'C2', label='yes mean')
  235. plt.plot(clf.psth_xx, mu2.mean(axis=0), 'C1', label='no mean')
  236. plt.plot(clf.psth_xx, np.median(mu1, axis=0), 'C2--', label='yes median')
  237. plt.plot(clf.psth_xx, np.median(mu2, axis=0), 'C1--', label='no median')
  238. plt.legend(loc=1, prop={'size': 6})
  239. plt.subplot(326)
  240. plt.plot(clf.psth_xx, psth1[:,:,ch_ids].mean(0).mean(1), 'C2', label='yes mean')
  241. plt.plot(clf.psth_xx, psth2[:,:,ch_ids].mean(0).mean(1), 'C1', label='no mean')
  242. plt.legend(loc=1, prop={'size': 6})
  243. if fig_name !='':
  244. fig_name = f'{fig_name}_ch_{ch_ids[0]}'
  245. print(f'saving figure as {fig_name}')
  246. plt.savefig(fig_name)
  247. return None
# Code written by Marcos Duarte (https://github.com/demotu/BMC), MIT license;
# presumably this covers the detect_peaks function that follows.
# NOTE(review): these dunders are module-level, so they describe this whole
# module rather than only the vendored peak-detection code.
# NOTE(review): detect_peaks' docstring lists a 1.0.6 change (peak width),
# while __version__ still reads 1.0.5.
__author__ = "Marcos Duarte, https://github.com/demotu/BMC"
__version__ = "1.0.5"
__license__ = "MIT"
  252. def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
  253. kpsh=False, valley=False, show=False, ax=None, width=1):
  254. """Detect peaks in data based on their amplitude and other features.
  255. Parameters
  256. ----------
  257. x : 1D array_like
  258. data.
  259. mph : {None, number}, optional (default = None)
  260. detect peaks that are greater than minimum peak height (if parameter
  261. `valley` is False) or peaks that are smaller than maximum peak height
  262. (if parameter `valley` is True).
  263. mpd : positive integer, optional (default = 1)
  264. detect peaks that are at least separated by minimum peak distance (in
  265. number of data).
  266. threshold : positive number, optional (default = 0)
  267. detect peaks (valleys) that are greater (smaller) than `threshold`
  268. in relation to their immediate neighbors.
  269. edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
  270. for a flat peak, keep only the rising edge ('rising'), only the
  271. falling edge ('falling'), both edges ('both'), or don't detect a
  272. flat peak (None).
  273. kpsh : bool, optional (default = False)
  274. keep peaks with same height even if they are closer than `mpd`.
  275. valley : bool, optional (default = False)
  276. if True (1), detect valleys (local minima) instead of peaks.
  277. show : bool, optional (default = False)
  278. if True (1), plot data in matplotlib figure.
  279. ax : a matplotlib.axes.Axes instance, optional (default = None).
  280. width : positive integer, optional (default = 1)
  281. Required width of peaks in samples above or equal to mph.
  282. Returns
  283. -------
  284. ind : 1D array_like
  285. indeces of the peaks in `x`.
  286. Notes
  287. -----
  288. The detection of valleys instead of peaks is performed internally by simply
  289. negating the data: `ind_valleys = detect_peaks(-x)`
  290. The function can handle NaN's
  291. See this IPython Notebook [1]_.
  292. References
  293. ----------
  294. .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
  295. Examples
  296. --------
  297. >>> from detect_peaks import detect_peaks
  298. >>> x = np.random.randn(100)
  299. >>> x[60:81] = np.nan
  300. >>> # detect all peaks and plot data
  301. >>> ind = detect_peaks(x, show=True)
  302. >>> print(ind)
  303. >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
  304. >>> # set minimum peak height = 0 and minimum peak distance = 20
  305. >>> detect_peaks(x, mph=0, mpd=20, show=True)
  306. >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
  307. >>> # set minimum peak distance = 2
  308. >>> detect_peaks(x, mpd=2, show=True)
  309. >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
  310. >>> # detection of valleys instead of peaks
  311. >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)
  312. >>> x = [0, 1, 1, 0, 1, 1, 0]
  313. >>> # detect both edges
  314. >>> detect_peaks(x, edge='both', show=True)
  315. >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
  316. >>> # set threshold = 2
  317. >>> detect_peaks(x, threshold = 2, show=True)
  318. Version history
  319. ---------------
  320. '1.0.5':
  321. The sign of `mph` is inverted if parameter `valley` is True
  322. '1.0.6':
  323. @Espinosa: Peak width added
  324. """
  325. x = np.atleast_1d(x).astype('float64')
  326. if x.size < 3:
  327. return np.array([], dtype=int)
  328. if valley:
  329. x = -x
  330. if mph is not None:
  331. mph = -mph
  332. # find indices of all peaks
  333. dx = x[1:] - x[:-1]
  334. # handle NaN's
  335. indnan = np.where(np.isnan(x))[0]
  336. if indnan.size:
  337. x[indnan] = np.inf
  338. dx[np.where(np.isnan(dx))[0]] = np.inf
  339. ine, ire, ife = np.array([[], [], []], dtype=int)
  340. if not edge:
  341. ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
  342. else:
  343. if edge.lower() in ['rising', 'both']:
  344. ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
  345. if edge.lower() in ['falling', 'both']:
  346. ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
  347. ind = np.unique(np.hstack((ine, ire, ife)))
  348. # handle NaN's
  349. if ind.size and indnan.size:
  350. # NaN's and values close to NaN's cannot be peaks
  351. ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
  352. # first and last values of x cannot be peaks
  353. if ind.size and ind[0] == 0:
  354. ind = ind[1:]
  355. if ind.size and ind[-1] == x.size-1:
  356. ind = ind[:-1]
  357. # remove peaks < minimum peak height
  358. if ind.size and mph is not None:
  359. ind = ind[x[ind] >= mph]
  360. # remove peaks - neighbors < threshold
  361. if ind.size and threshold > 0:
  362. dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
  363. ind = np.delete(ind, np.where(dx < threshold)[0])
  364. # enforce peaks above mph have minimum width (added by Espinosa)
  365. if ind.size and width > 1:
  366. # ind_in = np.ones((ind.size, 1), dtype=bool)
  367. ind_in = ind < x.size - width
  368. for index in range(np.sum(ind_in)):
  369. ind_in[index] = np.all(x[ ind[index]:ind[index] + width ] >= mph)
  370. ind = ind[ind_in]
  371. ind += width
  372. # detect small peaks closer than minimum peak distance
  373. if ind.size and mpd > 1:
  374. # ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height
  375. idel = np.zeros(ind.size, dtype=bool)
  376. for i in range(ind.size):
  377. if not idel[i]:
  378. # keep peaks with the same height if kpsh is True
  379. idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
  380. & (x[ind[i]] > x[ind] if kpsh else True)
  381. idel[i] = 0 # Keep current peak
  382. # remove the small peaks and sort back the indices by their occurrence
  383. ind = np.sort(ind[~idel])
  384. if show:
  385. if indnan.size:
  386. x[indnan] = np.nan
  387. if valley:
  388. x = -x
  389. if mph is not None:
  390. mph = -mph
  391. _plot(x, mph, mpd, threshold, edge, valley, ax, ind)
  392. return ind
  393. def _plot(x, mph, mpd, threshold, edge, valley, ax, ind):
  394. """Plot results of the detect_peaks function, see its help."""
  395. try:
  396. import matplotlib.pyplot as plt
  397. except ImportError:
  398. print('matplotlib is not available.')
  399. else:
  400. if ax is None:
  401. _, ax = plt.subplots(1, 1, figsize=(8, 4))
  402. ax.plot(x, 'b', lw=1)
  403. if ind.size:
  404. label = 'valley' if valley else 'peak'
  405. label = label + 's' if ind.size > 1 else label
  406. ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
  407. label='%d %s' % (ind.size, label))
  408. ax.legend(loc='best', framealpha=.5, numpoints=1)
  409. ax.set_xlim(-.02*x.size, x.size*1.02-1)
  410. ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
  411. yrange = ymax - ymin if ymax > ymin else 1
  412. ax.set_ylim(ymin - 0.1*yrange, ymax + 0.1*yrange)
  413. ax.set_xlabel('Data #', fontsize=14)
  414. ax.set_ylabel('Amplitude', fontsize=14)
  415. mode = 'Valley detection' if valley else 'Peak detection'
  416. ax.set_title("%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"
  417. % (mode, str(mph), mpd, str(threshold), edge))
  418. # plt.grid()
  419. plt.show()
  420. # SPECTRAL METHODS
  421. def bw_filter_coeff(fc, fs, order=5, btype='low'):
  422. nyq = 0.5 * fs
  423. fc_norm = np.array(fc) / nyq
  424. b, a = signal.butter(order, fc_norm, btype=btype, analog=False)
  425. return b, a
  426. def bw_filter(data, fc, fs, order=5, plot=False):
  427. if fc[1] == 0:
  428. b, a = bw_filter_coeff(fc[0], fs, order=order, btype='lowpass')
  429. elif fc[0] == 0:
  430. b, a = bw_filter_coeff(fc[1], fs, order=order, btype='highpass')
  431. else:
  432. b, a = bw_filter_coeff(fc, fs, order=order, btype='bandpass')
  433. y = signal.filtfilt(b, a, data)
  434. if plot:
  435. plot_freq_resp(fc, fs, b, a)
  436. # print(f'filter coeffs: {b}\n{a}')
  437. return y.T, b, a
  438. def calc_fft(x, fs, i_stop=1e5, axis=-1):
  439. '''x: ndarray, (samples x channels)
  440. fs: sampling frequency
  441. i_stop: calculate fft up to this sample index
  442. '''
  443. i_stop = int(min(i_stop, len(x)))
  444. x = x[:i_stop, :]
  445. # hann = np.hanning(x.shape[0])[:,np.newaxis]
  446. # hann = np.repeat(hann,x.shape[1],axis=1)
  447. # x = x * hann
  448. x_fft = np.abs(fftpack.fft((x - np.mean(x, axis=axis)), axis=axis) / x.shape[axis])**2
  449. n = x_fft.shape[axis]
  450. x_fft_freq = fftpack.fftfreq(n, d=1. / fs)
  451. x_fft = fftpack.fftshift(x_fft, axes=axis)
  452. x_fft_freq = fftpack.fftshift(x_fft_freq)
  453. return x_fft, x_fft_freq
  454. def plot_freq_resp(fc, fs, b, a):
  455. w, h = signal.freqz(b, a, worN=8000)
  456. plt.figure()
  457. plt.clf()
  458. # plt.subplot(311)
  459. plt.plot(0.5*fs*w/np.pi, np.abs(h), 'b')
  460. if fc[0]>0:
  461. plt.plot(fc[0], 0.5*np.sqrt(2), 'ko')
  462. plt.axvline(fc[0], color='k',ls='--', alpha=0.5)
  463. if fc[1]>0:
  464. plt.plot(fc[1], 0.5*np.sqrt(2), 'ko')
  465. plt.axvline(fc[1], color='k', ls='--', alpha=0.5)
  466. plt.xlim(0, 0.5*fs)
  467. plt.title(f'Filter Frequency Response, [{fc[0]}-{fc[1]}] Hz')
  468. plt.xlabel('Frequency [Hz]')
  469. plt.ylabel('H(e^jw)')
  470. plt.grid()
  471. return None