SWTTEO.m 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  function [spikepos, out_] = SWTTEO(in,params)
  %SWTTEO Detects spike locations using a modified WTEO approach
  % Usage: spikepos = SWTTEO(in);
  %        spikepos = SWTTEO(in,params);
  %        [spikepos, out_] = SWTTEO(in,params);
  %
  % Input parameters:
  %   in:     Input structure which contains
  %           M:    Matrix with data, stored columnwise (a row vector is
  %                 treated as a single channel)
  %           SaRa: Sampling frequency
  % Optional input parameters (fields of params; defaults are set in
  % parseInput unless noted otherwise):
  %   method:      'auto' (default), 'auto2', 'numspikes', 'lambda' or
  %                'energy'
  %   numspikes:   number of spikes to extract, one entry per channel
  %                (required for method 'numspikes')
  %   lambda:      hard threshold applied to the TEO output (required for
  %                method 'lambda'); scalar or one value per channel
  %   global_fac:  factor applied to the wavelet noise estimate to obtain
  %                the threshold for 'auto' (default 1.11e3) and 'auto2'
  %                (default 9.064e2)
  %   p, rel_norm: forwarded to getSpikePositions for method 'energy'
  %                (defaults 0.80 and 5.718e-3)
  %   wavLevel, wavelet, winlength, smoothing, normalize_smoothingwindow,
  %   filter, F1:  wavelet transform, smoothing and prefilter settings
  % Output parameters:
  %   spikepos: Spike positions as returned by getSpikePositions. For
  %             multichannel input a cell array with one entry per channel
  %             (c-by-1, except 1-by-c for method 'numspikes').
  %   out_:     Sum over all SWT levels of the (optionally smoothed)
  %             absolute TEO output, columnwise, computed on the zero
  %             padded signal
  %
  % Description:
  %   SWTTEO(in,params) computes the location of action potentials in
  %   noisy MEA measurements. This method is based on the work of N.
  %   Nabar and K. Rajgopal "A Wavelet based Teager Energy Operator for
  %   Spike Detection in Microelectrode Array Recordings". The algorithm
  %   therein was further improved by using a stationary wavelet
  %   transform and a different thresholding concept.
  %   For an unsupervised usage the sensitivity of the algorithm can be
  %   adapted via params.global_fac. A larger value results in fewer
  %   detected spikes but also fewer false positives. Decreasing this
  %   factor makes the detection more sensitive.
  %
  % References:
  %   tbd.
  %
  % Author: F. Lieb, February 2016
  %
  if nargin<2
      params = struct;
  end
  %parse inputs
  [params,s,fs] = parseInput(in,params);
  %Teager energy operator with lag k along the first dimension
  TEO = @(x,k) (x.^2 - myTEOcircshift(x,[-k, 0]).*myTEOcircshift(x,[k, 0]));
  [L,c] = size(s);
  if L==1
      %row vector input: treat it as a single channel
      s = s';
      L = c;
      c = 1;
  end
  %zero padding so that L is divisible by 2^wavLevel
  pow = 2^params.wavLevel;
  if rem(L,pow) > 0
      Lok = ceil(L/pow)*pow;
      Ldiff = Lok - L;
      s = [s; zeros(Ldiff,c)];
  end
  %optional highpass prefilter (testing showed it didn't improve results)
  if params.filter
      if ~isfield(params,'F1')
          params.Fstop = 100;
          params.Fpass = 200;
          Apass = 0.2;
          Astop = 80;
          params.F1 = designfilt( 'highpassiir',...
              'StopbandFrequency',params.Fstop ,...
              'PassbandFrequency',params.Fpass,...
              'StopbandAttenuation',Astop, ...
              'PassbandRipple',Apass,...
              'SampleRate',fs,...
              'DesignMethod','butter');
      end
      f = filtfilt(params.F1,s);
  else
      f = s;
  end
  %the smoothing window is identical for every level: build it once
  if params.smoothing
      wind = hamming(params.winlength);
      if params.normalize_smoothingwindow
          wind = wind./(sqrt(3*sum(wind.^2) + sum(wind)^2));
      end
  end
  %vectorized stationary wavelet transform (approximation part only).
  %non vectorized version:
  %   [SWTa,~] = swt(s,wavLevel,wavelet); out22 = TEO(SWTa);
  lo_D = wfilters(params.wavelet);
  out_ = zeros(size(s));
  ss = f;
  for k=1:params.wavLevel
      %periodic extension, then filtering along the columns
      lf = length(lo_D);
      ss = extendswt(ss,lf);
      swa = conv2(ss,lo_D','valid');
      swa = swa(2:end,:); %even number of filter coefficients
      %apply teo to swt output
      temp = abs(TEO(swa,1));
      if params.smoothing
          %Process each channel separately: applying conv2 with the row
          %kernel wind' to the whole matrix convolved across columns and
          %thereby mixed neighbouring channels into each other.
          %NOTE(review): wind' is a row vector, so on a single column this
          %only scales temp by the window's centre sample; smoothing along
          %time would be conv2(temp(:,jj),wind,'same'). The global_fac
          %defaults appear to be tuned to the current behaviour - re-tune
          %them before changing the kernel orientation.
          for jj = 1:c
              temp(:,jj) = conv2(temp(:,jj),wind','same');
          end
      end
      out_ = out_ + temp;
      %dyadic upsampling of the filter coefficients for the next level
      lo_D = dyadup(lo_D,0,1);
      ss = swa;
  end
  %extract spikes
  switch params.method
      case 'auto'
          %threshold: global_fac times the noise level of the input signal
          if isfield(params,'global_fac')
              global_fac = params.global_fac;
          else
              global_fac = 1.11e+03; %previously tried: 1.6285e3, 540, 1800, 430, 1198
          end
          if c == 1
              [CC,LL] = wavedec(s,5,'sym5');
              lambda = global_fac*wnoisest(CC,LL,1);
              thout = wthresh(out_,'h',lambda);
              spikepos = getSpikePositions(thout,fs,s,params);
          else
              spikepos = cell(c,1);
              for jj=1:c
                  [CC,LL] = wavedec(s(:,jj),5,'sym5');
                  lambda = global_fac*wnoisest(CC,LL,1);
                  thout = wthresh(out_(:,jj),'h',lambda);
                  spikepos{jj}=getSpikePositions(thout,fs,s(:,jj),params);
              end
          end
      case 'auto2'
          %threshold: global_fac times the noise level of the TEO output
          if isfield(params,'global_fac')
              global_fac = params.global_fac;
          else
              global_fac = 9.064e+02; %previously tried: 1.3454e3, 800, 1800, 430, 1198
          end
          %spike extraction itself is done as for 'auto'
          params.method = 'auto';
          if c == 1
              [CC,LL] = wavedec(out_,5,'sym5');
              lambda = global_fac*wnoisest(CC,LL,1);
              thout = wthresh(out_,'h',lambda);
              spikepos = getSpikePositions(thout,fs,s,params);
          else
              spikepos = cell(c,1);
              for jj=1:c
                  [CC,LL] = wavedec(out_(:,jj),5,'sym5');
                  lambda = global_fac*wnoisest(CC,LL,1);
                  thout = wthresh(out_(:,jj),'h',lambda);
                  spikepos{jj}=getSpikePositions(thout,fs,s(:,jj),params);
              end
          end
      case 'numspikes'
          if c == 1
              spikepos=getSpikePositions(out_,fs,s,params);
          else
              spikepos = cell(1,c);
              params_tmp = params;
              for jj=1:c
                  % extract spike positions from wteo output
                  params_tmp.numspikes = params.numspikes(jj);
                  spikepos{jj}=getSpikePositions(out_(:,jj),fs,s(:,jj),params_tmp);
              end
          end
      case 'lambda'
          if c == 1
              thout = wthresh(out_,'h',params.lambda);
              spikepos = getSpikePositions(thout,fs,s,params);
          else
              %threshold and extract every channel on its own; lambda may
              %be a scalar or hold one value per channel
              spikepos = cell(c,1);
              for jj=1:c
                  if isscalar(params.lambda)
                      lam = params.lambda;
                  else
                      lam = params.lambda(jj);
                  end
                  thout = wthresh(out_(:,jj),'h',lam);
                  spikepos{jj} = getSpikePositions(thout,fs,s(:,jj),params);
              end
          end
      case 'energy'
          if ~isfield(params,'p')
              params.p = 0.80;
          end
          if ~isfield(params,'rel_norm')
              params.rel_norm = 5.718e-3; %previously tried: 4.842e-3, 22e-5, 1.445e-4
          end
          %optional wavelet denoising of the detection signal (disabled)
          wdenoising = 0;
          n = 9;
          w = 'sym5';
          tptr = 'sqtwolog'; %'rigrsure','heursure','sqtwolog','minimaxi'
          if c == 1
              if wdenoising == 1
                  out_ = wden(out_,tptr,'h','mln',n,w);
                  %high frequencies, decision variable (kept in gabcoef so
                  %the channel count c is not overwritten)
                  gabcoef = dgtreal(out_,{'hann',10},1,200);
                  out_ = sum(abs(gabcoef).^2,1);
              end
              spikepos = getSpikePositions(out_,fs,s,params);
          else
              spikepos = cell(c,1);
              for jj=1:c
                  if wdenoising == 1
                      out_(:,jj) = wden(out_(:,jj),tptr,'h','mln',n,w);
                  end
                  spikepos{jj} = getSpikePositions(out_(:,jj),fs,s(:,jj),params);
              end
          end
      otherwise
          error('unknown detection method specified');
  end
  185. %internal functions:
  186. %--------------------------------------------------------------------------
  function [params,s,fs] = parseInput(in,params)
  %PARSEINPUT parses input variables
  %   [params,s,fs] = parseInput(in,params) returns the data s = in.M and
  %   the sampling frequency fs = in.SaRa. It fills in a default for every
  %   option the caller did not set in params:
  %     method                    'auto'
  %     wavLevel                  2
  %     wavelet                   'sym5'
  %     winlength                 ceil(1.3e-3*fs) (1.3 ms if fs is in Hz)
  %     normalize_smoothingwindow 0
  %     smoothing                 1
  %     filter                    0
  %   Methods that need a user supplied value raise an error when it is
  %   missing: 'numspikes' requires params.numspikes and 'lambda' requires
  %   params.lambda.
  s = in.M;
  fs = in.SaRa;
  %Default settings for detection method
  if ~isfield(params,'method')
      params.method = 'auto';
  end
  if strcmp(params.method,'numspikes') && ~isfield(params,'numspikes')
      error('please specify number of spikes in params.numspikes');
  end
  if strcmp(params.method,'lambda') && ~isfield(params,'lambda')
      error('please specify the threshold in params.lambda');
  end
  %Default settings for the stationary wavelet transform, smoothing and
  %prefiltering; applied in this order so missing fields are appended to
  %params in the same order as before
  defaults = {'wavLevel',                  2
              'wavelet',                   'sym5'
              'winlength',                 ceil(1.3e-3*fs)
              'normalize_smoothingwindow', 0
              'smoothing',                 1
              'filter',                    0};
  for ii = 1:size(defaults,1)
      if ~isfield(params,defaults{ii,1})
          params.(defaults{ii,1}) = defaults{ii,2};
      end
  end
  function y = extendswt(x,lf)
  %EXTENDSWT extends the signal periodically at the boundaries
  %   y = extendswt(x,lf) prepends the last lf/2 rows of x and appends its
  %   first lf/2 rows, so y has size(x,1)+lf rows. Each column is extended
  %   independently; lf must be even. Row indices are wrapped modulo
  %   size(x,1), so the extension also works when lf/2 exceeds the signal
  %   length (the signal is then repeated periodically).
  half = lf/2;
  r = size(x,1);
  idx = mod(-half:r+half-1, r) + 1;
  %double() keeps the output class of the former zeros()-based version
  y = double(x(idx,:));
  226. % function idx2 = getSpikePositions(input_sig,fs,orig_sig,params)
  227. % %GETSPIKEPOSITIONS computes spike positions from thresholded data
  228. % %
  229. % % This function computes the exact spike locations based on a thresholded
  230. % % signal. The spike locations are indicated as non-zero elements in
  231. % % input_sig and are accordingly evaluated.
  232. % %
  233. % % The outputs are the spike positions in absolute index values (no time
  234. % % dependence).
  235. % %
  236. % % Author: F.Lieb, February 2016
  237. % %
  238. %
  239. %
  240. % %Define a fixed spike duration: a spike is not considered finished
  241. % %before this duration has elapsed, even if zeros occur earlier
  242. % spikeduration = 1e-3*fs;
  243. % offset = 1;
  244. % L = length(input_sig);
  245. %
  246. % switch params.method
  247. % case 'numspikes'
  248. % out = input_sig;
  249. % np = 0;
  250. % idx2 = zeros(1,params.numspikes);
  251. % while (np < params.numspikes)
  252. % [~, idxmax] = max(out);
  253. % idxl = idxmax;
  254. % idxr = idxmax;
  255. % out(idxmax) = 0;
  256. % offsetcounter = 0;
  257. % while( out(max(1,idxl-2)) < out(max(1,idxl-1)) ||...
  258. % offsetcounter < spikeduration )
  259. % out(max(1,idxl-1)) = 0;
  260. % idxl = idxl-1;
  261. % offsetcounter = offsetcounter + 1;
  262. % end
  263. % offsetcounter = 0;
  264. % while( out(min(L,idxr+2)) < out(min(L,idxr+1)) ||...
  265. % offsetcounter < spikeduration )
  266. % out(min(L,idxr+1)) = 0;
  267. % idxr = idxr+1;
  268. % offsetcounter = offsetcounter + 1;
  269. % end
  270. % indexx = min(L,idxl-offset:idxr+offset);
  271. % indexx = max(1,indexx);
  272. % idxx = find( abs(orig_sig(indexx)) == ...
  273. % max( abs(orig_sig(indexx) )),1,'first');
  274. % idx2(np+1) = idxl - offset + idxx-1;
  275. % np = np + 1;
  276. % end
  277. %
  278. % case {'auto','lambda'}
  279. % %helper variables
  280. % idx2=[];
  281. % iii=1;
  282. % test2 = input_sig;
  283. % %loop until the input_sig is only zeros
  284. % while (sum(test2) ~= 0)
  285. % %get the first nonzero position
  286. % tmp = find(test2,1,'first');
  287. % test2(tmp) = 0;
  288. % %tmp2 is the counter until the spike duration
  289. % tmp2 = min(length(test2),tmp + 1);%protect against end of vec
  290. % counter = 0;
  291. % %search for the end of the spike
  292. % while(test2(tmp2) ~= 0 || counter<spikeduration )
  293. % test2(tmp2) = 0;
  294. % tmp2 = min(length(test2),tmp2 + 1);
  295. % counter = counter + 1;
  296. % end
  297. % %spike location is in interval [tmp tmp2], look for the max
  298. % %element in the original signal with some predefined offset:
  299. % indexx = min(length(orig_sig),tmp-offset:tmp2+offset);
  300. % indexx = max(1,indexx);
  301. % idxx = find( abs(orig_sig(indexx)) == ...
  302. % max( abs(orig_sig(indexx) )),1,'first');
  303. % idx2(iii) = tmp - offset + idxx-1;
  304. % iii = iii+1;
  305. % end
  306. % otherwise
  307. % error('unknown method');
  308. % end
  309. %
  310. %
  311. %