change_point_detection.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. # -*- coding: utf-8 -*-
  2. """
  3. This algorithm determines if a spike train `spk` can be considered as stationary
  4. process (constant firing rate) or not as stationary process (i.e. presence of
  5. one or more points at which the rate increases or decreases). In case of
  6. non-stationarity, the output is a list of detected Change Points (CPs).
  7. Essentially, a set of two-sided windows of width `h` (`_filter(t, h, spk)`)
  8. slides over the spike train within the time `[h, t_final-h]`. This generates a
  9. `_filter_process(dt, h, spk)` that assigns at each time `t` the difference
  10. between a spike lying in the right and left window. If at any time `t` this
  11. difference is large 'enough' is assumed the presence of a rate Change Point in
  12. a neighborhood of `t`. A threshold `test_quantile` for the maximum of
  13. the filter_process (max difference of spike count between the left and right
  14. window) is derived based on asymptotic considerations. The procedure is repeated
  15. for an arbitrary set of windows, with different size `h`.
  16. Examples
  17. --------
  18. The following applies multiple_filter_test to a spike train.
  19. >>> import quantities as pq
  20. >>> import neo
  21. >>> from elephant.change_point_detection import multiple_filter_test
  22. >>> test_array = [1.1,1.2,1.4, 1.6,1.7,1.75,1.8,1.85,1.9,1.95]
  23. >>> st = neo.SpikeTrain(test_array, units='s', t_stop = 2.1)
  24. >>> window_size = [0.5]*pq.s
  25. >>> t_fin = 2.1*pq.s
  26. >>> alpha = 5.0
  27. >>> num_surrogates = 10000
  28. >>> change_points = multiple_filter_test(window_size, st, t_fin, alpha,
  29. ...     num_surrogates, dt=0.5*pq.s)
  30. References
  31. ----------
  32. Messer, M., Kirchner, M., Schiemann, J., Roeper, J., Neininger, R., & Schneider,
  33. G. (2014). A multiple filter test for the detection of rate changes in renewal
  34. processes with varying variance. The Annals of Applied Statistics, 8(4),2027-2067.
  35. Original code
  36. -------------
  37. Adapted from the published R implementation:
  38. DOI: 10.1214/14-AOAS782SUPP (supplementary .r file)
  39. """
  40. import numpy as np
  41. import quantities as pq
  42. def multiple_filter_test(window_sizes, spiketrain, t_final, alpha, n_surrogates,
  43. test_quantile=None, test_param=None, dt=None):
  44. """
  45. Detects change points.
  46. This function returns the detected change points, that correspond to the
  47. maxima of the `_filter_processes`. These are the processes generated by
  48. sliding the windows of step `dt`; at each step the difference between spike
  49. on the right and left window is calculated.
  50. Parameters
  51. ----------
  52. window_sizes : list of quantity objects
  53. list that contains windows sizes
  54. spiketrain : neo.SpikeTrain, numpy array or list
  55. spiketrain objects to analyze
  56. t_final : quantity
  57. final time of the spike train which is to be analysed
  58. alpha : float
  59. alpha-quantile in range [0, 100] for the set of maxima of the limit
  60. processes
  61. n_surrogates : integer
  62. numbers of simulated limit processes
  63. test_quantile : float
  64. threshold for the maxima of the filter derivative processes, if any
  65. of these maxima is larger than this value, it is assumed the
  66. presence of a cp at the time corresponding to that maximum
  67. dt : quantity
  68. resolution, time step at which the windows are slided
  69. test_param : np.array of shape (3, num of window),
  70. first row: list of `h`, second and third rows: empirical means and
  71. variances of the limit process correspodning to `h`. This will be
  72. used to normalize the `filter_process` in order to give to the every
  73. maximum the same impact on the global statistic.
  74. Returns:
  75. --------
  76. cps : list of lists
  77. one list for each window size `h`, containing the points detected with
  78. the corresponding `filter_process`. N.B.: only cps whose h-neighborhood
  79. does not include previously detected cps (with smaller window h) are
  80. added to the list.
  81. """
  82. if (test_quantile is None) and (test_param is None):
  83. test_quantile, test_param = empirical_parameters(window_sizes, t_final,
  84. alpha, n_surrogates,
  85. dt)
  86. elif test_quantile is None:
  87. test_quantile = empirical_parameters(window_sizes, t_final, alpha,
  88. n_surrogates, dt)[0]
  89. elif test_param is None:
  90. test_param = empirical_parameters(window_sizes, t_final, alpha,
  91. n_surrogates, dt)[1]
  92. spk = spiketrain
  93. # List of lists of detected change points (CPs), to be returned
  94. cps = []
  95. for i, h in enumerate(window_sizes):
  96. # automatic setting of dt
  97. dt_temp = h / 20 if dt is None else dt
  98. # filter_process for window of size h
  99. t, differences = _filter_process(dt_temp, h, spk, t_final, test_param)
  100. time_index = np.arange(len(differences))
  101. # Point detected with window h
  102. cps_window = []
  103. while np.max(differences) > test_quantile:
  104. cp_index = np.argmax(differences)
  105. # from index to time
  106. cp = cp_index * dt_temp + h
  107. #print("detected point {0}".format(cp), "with filter {0}".format(h))
  108. # before repeating the procedure, the h-neighbourgs of detected CP
  109. # are discarded, because rate changes into it are alrady explained
  110. mask_fore = time_index > cp_index - int((h / dt_temp).simplified)
  111. mask_back = time_index < cp_index + int((h / dt_temp).simplified)
  112. differences[mask_fore & mask_back] = 0
  113. # check if the neighbourhood of detected cp does not contain cps
  114. # detected with other windows
  115. neighbourhood_free = True
  116. # iterate on lists of cps detected with smaller window
  117. for j in range(i):
  118. # iterate on CPs detected with the j-th smallest window
  119. for c_pre in cps[j]:
  120. if c_pre - h < cp < c_pre + h:
  121. neighbourhood_free = False
  122. break
  123. # if none of the previously detected CPs falls in the h-
  124. # neighbourhood
  125. if neighbourhood_free:
  126. # add the current CP to the list
  127. cps_window.append(cp)
  128. # add the present list to the grand list
  129. cps.append(cps_window)
  130. return cps
  131. def _brownian_motion(t_in, t_fin, x_in, dt):
  132. """
  133. Generate a Brownian Motion.
  134. Parameters
  135. ----------
  136. t_in : quantities,
  137. initial time
  138. t_fin : quantities,
  139. final time
  140. x_in : float,
  141. initial point of the process: _brownian_motio(0) = x_in
  142. dt : quantities,
  143. resolution, time step at which brownian increments are summed
  144. Returns
  145. -------
  146. Brownian motion on [t_in, t_fin], with resolution dt and initial state x_in
  147. """
  148. u = 1 * pq.s
  149. try:
  150. t_in_sec = t_in.rescale(u).magnitude
  151. except ValueError:
  152. raise ValueError("t_in must be a time quantity")
  153. try:
  154. t_fin_sec = t_fin.rescale(u).magnitude
  155. except ValueError:
  156. raise ValueError("t_fin must be a time quantity")
  157. try:
  158. dt_sec = dt.rescale(u).magnitude
  159. except ValueError:
  160. raise ValueError("dt must be a time quantity")
  161. x = np.random.normal(0, np.sqrt(dt_sec), size=int((t_fin_sec - t_in_sec)
  162. / dt_sec))
  163. s = np.cumsum(x)
  164. return s + x_in
  165. def _limit_processes(window_sizes, t_final, dt):
  166. """
  167. Generate the limit processes (depending only on t_final and h), one for
  168. each window size `h` in H. The distribution of maxima of these processes
  169. is used to derive threshold `test_quantile` and parameters `test_param`.
  170. Parameters
  171. ----------
  172. window_sizes : list of quantities
  173. set of windows' size
  174. t_final : quantity object
  175. end of limit process
  176. dt : quantity object
  177. resolution, time step at which the windows are slided
  178. Returns
  179. -------
  180. limit_processes : list of numpy array
  181. each entries contains the limit processes for each h,
  182. evaluated in [h,T-h] with steps dt
  183. """
  184. limit_processes = []
  185. u = 1 * pq.s
  186. try:
  187. window_sizes_sec = window_sizes.rescale(u).magnitude
  188. except ValueError:
  189. raise ValueError("window_sizes must be a list of times")
  190. try:
  191. dt_sec = dt.rescale(u).magnitude
  192. except ValueError:
  193. raise ValueError("dt must be a time quantity")
  194. w = _brownian_motion(0 * u, t_final, 0, dt)
  195. for h in window_sizes_sec:
  196. # BM on [h,T-h], shifted in time t-->t+h
  197. brownian_right = w[int(2 * h/dt_sec):]
  198. # BM on [h,T-h], shifted in time t-->t-h
  199. brownian_left = w[:int(-2 * h/dt_sec)]
  200. # BM on [h,T-h]
  201. brownian_center = w[int(h/dt_sec):int(-h/dt_sec)]
  202. modul = np.abs(brownian_right + brownian_left - 2 * brownian_center)
  203. limit_process_h = modul / (np.sqrt(2 * h))
  204. limit_processes.append(limit_process_h)
  205. return limit_processes
  206. def empirical_parameters(window_sizes, t_final, alpha, n_surrogates, dt = None):
  207. """
  208. This function generates the threshold and the null parameters.
  209. The`_filter_process_h` has been proved to converge (for t_fin, h-->infinity)
  210. to a continuous functional of a Brownaian motion ('limit_process').
  211. Using a MonteCarlo technique, maxima of these limit_processes are
  212. collected.
  213. The threshold is defined as the alpha quantile of this set of maxima.
  214. Namely:
  215. test_quantile := alpha quantile of {max_(h in window_size)[
  216. max_(t in [h, t_final-h])_limit_process_h(t)]}
  217. Parameters
  218. ----------
  219. window_sizes : list of quantity objects
  220. set of windows' size
  221. t_final : quantity object
  222. final time of the spike
  223. alpha : float
  224. alpha-quantile in range [0, 100]
  225. n_surrogates : integer
  226. numbers of simulated limit processes
  227. dt : quantity object
  228. resolution, time step at which the windows are slided
  229. Returns
  230. -------
  231. test_quantile : float
  232. threshold for the maxima of the filter derivative processes, if any
  233. of these maxima is larger than this value, it is assumed the
  234. presence of a cp at the time corresponding to that maximum
  235. test_param : np.array 3 * num of window,
  236. first row: list of `h`, second and third rows: empirical means and
  237. variances of the limit process correspodning to `h`. This will be
  238. used to normalize the `filter_process` in order to give to the every
  239. maximum the same impact on the global statistic.
  240. """
  241. # try:
  242. # window_sizes_sec = window_sizes.rescale(u)
  243. # except ValueError:
  244. # raise ValueError("H must be a list of times")
  245. # window_sizes_mag = window_sizes_sec.magnitude
  246. # try:
  247. # t_final_sec = t_final.rescale(u)
  248. # except ValueError:
  249. # raise ValueError("T must be a time quantity")
  250. # t_final_mag = t_final_sec.magnitude
  251. if not isinstance(window_sizes, pq.Quantity):
  252. raise ValueError("window_sizes must be a list of time quantities")
  253. if not isinstance(t_final, pq.Quantity):
  254. raise ValueError("t_final must be a time quantity")
  255. if not isinstance(n_surrogates, int):
  256. raise TypeError("n_surrogates must be an integer")
  257. if not (isinstance(dt, pq.Quantity) or (dt is None)):
  258. raise ValueError("dt must be a time quantity")
  259. if t_final <= 0:
  260. raise ValueError("t_final needs to be strictly positive")
  261. if alpha * (100.0 - alpha) < 0:
  262. raise ValueError("alpha needs to be in (0,100)")
  263. if np.min(window_sizes) <= 0:
  264. raise ValueError("window size needs to be strictly positive")
  265. if np.max(window_sizes) >= t_final / 2:
  266. raise ValueError("window size too large")
  267. if dt is not None:
  268. for h in window_sizes:
  269. if int(h.rescale('us')) % int(dt.rescale('us')) != 0:
  270. raise ValueError(
  271. "Every window size h must be a multiple of dt")
  272. # Generate a matrix M*: n X m where n = n_surrogates is the number of
  273. # simulated limit processes and m is the number of chosen window sizes.
  274. # Elements are: M*(i,h) = max(t in T)[`limit_process_h`(t)],
  275. # for each h in H and surrogate i
  276. maxima_matrix = []
  277. for i in range(n_surrogates):
  278. # mh_star = []
  279. simu = _limit_processes(window_sizes, t_final, dt)
  280. # for i, h in enumerate(window_sizes_mag):
  281. # # max over time of the limit process generated with window h
  282. # m_h = np.max(simu[i])
  283. # mh_star.append(m_h)
  284. mh_star = [np.max(x) for x in simu] # max over time of the limit process generated with window h
  285. maxima_matrix.append(mh_star)
  286. maxima_matrix = np.asanyarray(maxima_matrix)
  287. # these parameters will be used to normalize both the limit_processes (H0)
  288. # and the filter_processes
  289. null_mean = maxima_matrix.mean(axis=0)
  290. null_var = maxima_matrix.var(axis=0)
  291. # matrix normalization by mean and variance of the limit process, in order
  292. # to give, for every h, the same impact on the global maximum
  293. matrix_normalized = (maxima_matrix - null_mean) / np.sqrt(null_var)
  294. great_maxs = np.max(matrix_normalized, axis=1)
  295. test_quantile = np.percentile(great_maxs, 100.0 - alpha)
  296. null_parameters = [window_sizes, null_mean, null_var]
  297. test_param = np.asanyarray(null_parameters)
  298. return test_quantile, test_param
  299. def _filter(t, h, spk):
  300. """
  301. This function calculates the difference of spike counts in the left and right
  302. side of a window of size h centered in t and normalized by its variance.
  303. The variance of this count can be expressed as a combination of mean and var
  304. of the I.S.I. lying inside the window.
  305. Parameters
  306. ----------
  307. h : quantity
  308. window's size
  309. t : quantity
  310. time on which the window is centered
  311. spk : list, numpy array or SpikeTrain
  312. spike train to analyze
  313. Returns
  314. -------
  315. difference : float,
  316. difference of spike count normalized by its variance
  317. """
  318. u = 1 * pq.s
  319. try:
  320. t_sec = t.rescale(u).magnitude
  321. except AttributeError:
  322. raise ValueError("t must be a quantities object")
  323. # tm = t_sec.magnitude
  324. try:
  325. h_sec = h.rescale(u).magnitude
  326. except AttributeError:
  327. raise ValueError("h must be a time quantity")
  328. # hm = h_sec.magnitude
  329. try:
  330. spk_sec = spk.rescale(u).magnitude
  331. except AttributeError:
  332. raise ValueError(
  333. "spiketrain must be a list (array) of times or a neo spiketrain")
  334. # cut spike-train on the right
  335. train_right = spk_sec[(t_sec < spk_sec) & (spk_sec < t_sec + h_sec)]
  336. # cut spike-train on the left
  337. train_left = spk_sec[(t_sec - h_sec < spk_sec) & (spk_sec < t_sec)]
  338. # spike count in the right side
  339. count_right = train_right.size
  340. # spike count in the left side
  341. count_left = train_left.size
  342. # form spikes to I.S.I
  343. isi_right = np.diff(train_right)
  344. isi_left = np.diff(train_left)
  345. if isi_right.size == 0:
  346. mu_ri = 0
  347. sigma_ri = 0
  348. else:
  349. # mean of I.S.I inside the window
  350. mu_ri = np.mean(isi_right)
  351. # var of I.S.I inside the window
  352. sigma_ri = np.var(isi_right)
  353. if isi_left.size == 0:
  354. mu_le = 0
  355. sigma_le = 0
  356. else:
  357. mu_le = np.mean(isi_left)
  358. sigma_le = np.var(isi_left)
  359. if (sigma_le > 0) & (sigma_ri > 0):
  360. s_quad = (sigma_ri / mu_ri**3) * h_sec + (sigma_le / mu_le**3) * h_sec
  361. else:
  362. s_quad = 0
  363. if s_quad == 0:
  364. difference = 0
  365. else:
  366. difference = (count_right - count_left) / np.sqrt(s_quad)
  367. return difference
  368. def _filter_process(dt, h, spk, t_final, test_param):
  369. """
  370. Given a spike train `spk` and a window size `h`, this function generates
  371. the `filter derivative process` by evaluating the function `_filter`
  372. in steps of `dt`.
  373. Parameters
  374. ----------
  375. h : quantity object
  376. window's size
  377. t_final : quantity,
  378. time on which the window is centered
  379. spk : list, array or SpikeTrain
  380. spike train to analyze
  381. dt : quantity object, time step at which the windows are slided
  382. resolution
  383. test_param : matrix, the means of the first row list of `h`,
  384. the second row Empirical and the third row variances of
  385. the limit processes `Lh` are used to normalize the number
  386. of elements inside the windows
  387. Returns
  388. -------
  389. time_domain : numpy array
  390. time domain of the `filter derivative process`
  391. filter_process : array,
  392. values of the `filter derivative process`
  393. """
  394. u = 1 * pq.s
  395. try:
  396. h_sec = h.rescale(u).magnitude
  397. except AttributeError:
  398. raise ValueError("h must be a time quantity")
  399. try:
  400. t_final_sec = t_final.rescale(u).magnitude
  401. except AttributeError:
  402. raise ValueError("t_final must be a time quanity")
  403. try:
  404. dt_sec = dt.rescale(u).magnitude
  405. except AttributeError:
  406. raise ValueError("dt must be a time quantity")
  407. # domain of the process
  408. time_domain = np.arange(h_sec, t_final_sec - h_sec, dt_sec)
  409. filter_trajectrory = []
  410. # taken from the function used to generate the threshold
  411. emp_mean_h = test_param[1][test_param[0] == h]
  412. emp_var_h = test_param[2][test_param[0] == h]
  413. for t in time_domain:
  414. filter_trajectrory.append(_filter(t*u, h, spk))
  415. filter_trajectrory = np.asanyarray(filter_trajectrory)
  416. # ordered normalization to give each process the same impact on the max
  417. filter_process = (
  418. np.abs(filter_trajectrory) - emp_mean_h) / np.sqrt(emp_var_h)
  419. return time_domain, filter_process