asset.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936
  1. # -*- coding: utf-8 -*-
  2. """
  3. ASSET is a statistical method :cite:`asset-Torre16_e1004939` for the detection
  4. of repeating sequences of synchronous spiking events in parallel spike trains.
  5. ASSET analysis class object of finding patterns
  6. -----------------------------------------------
  7. .. autosummary::
  8. :toctree: toctree/asset/
  9. ASSET
  10. Patterns post-exploration
  11. -------------------------
  12. .. autosummary::
  13. :toctree: toctree/asset/
  14. synchronous_events_intersection
  15. synchronous_events_difference
  16. synchronous_events_identical
  17. synchronous_events_no_overlap
  18. synchronous_events_contained_in
  19. synchronous_events_contains_all
  20. synchronous_events_overlap
  21. Tutorial
  22. --------
  23. :doc:`View tutorial <../tutorials/asset>`
  24. Run tutorial interactively:
  25. .. image:: https://mybinder.org/badge.svg
  26. :target: https://mybinder.org/v2/gh/NeuralEnsemble/elephant/master
  27. ?filepath=doc/tutorials/asset.ipynb
  28. Examples
  29. --------
  30. 0) Create `ASSET` class object that holds spike trains.
  31. `ASSET` requires at least one argument - a list of spike trains. If
  32. `spiketrains_y` is not provided, the same spike trains are used to build an
  33. intersection matrix with.
  34. >>> import neo
  35. >>> import numpy as np
  36. >>> import quantities as pq
  37. >>> from elephant import asset
  38. >>> spiketrains = [
  39. ... neo.SpikeTrain([start, start + 6] * (3 * pq.ms) + 10 * pq.ms,
  40. ... t_stop=60 * pq.ms)
  41. ... for _ in range(3)
  42. ... for start in range(3)
  43. ... ]
  44. >>> asset_obj = asset.ASSET(spiketrains, bin_size=3*pq.ms, verbose=False)
  45. 1) Build the intersection matrix `imat`:
  46. >>> imat = asset_obj.intersection_matrix()
  47. 2) Estimate the probability matrix `pmat`, using the analytical method:
  48. >>> pmat = asset_obj.probability_matrix_analytical(imat,
  49. ... kernel_width=9*pq.ms)
  50. 3) Compute the joint probability matrix `jmat`, using a suitable filter:
  51. >>> jmat = asset_obj.joint_probability_matrix(pmat, filter_shape=(5, 1),
  52. ... n_largest=3)
  53. 4) Create the masked version of the intersection matrix, `mmat`, from `pmat`
  54. and `jmat`:
  55. >>> mmat = asset_obj.mask_matrices([pmat, jmat], thresholds=.9)
  56. 5) Cluster significant elements of imat into diagonal structures:
  57. >>> cmat = asset_obj.cluster_matrix_entries(mmat, max_distance=3,
  58. ... min_neighbors=3, stretch=5)
  59. 6) Extract sequences of synchronous events:
  60. >>> sses = asset_obj.extract_synchronous_events(cmat)
  61. The ASSET found 2 sequences of synchronous events:
  62. >>> from pprint import pprint
  63. >>> pprint(sses)
  64. {1: {(9, 3): {0, 3, 6}, (10, 4): {1, 4, 7}, (11, 5): {8, 2, 5}}}
  65. """
  66. from __future__ import division, print_function, unicode_literals
  67. import warnings
  68. import neo
  69. import numpy as np
  70. import quantities as pq
  71. import scipy.spatial
  72. import scipy.stats
  73. from sklearn.cluster import dbscan
  74. from tqdm import trange, tqdm
  75. import elephant.conversion as conv
  76. from elephant import spike_train_surrogates
  77. try:
  78. from mpi4py import MPI
  79. mpi_accelerated = True
  80. comm = MPI.COMM_WORLD
  81. size = comm.Get_size()
  82. rank = comm.Get_rank()
  83. except ImportError:
  84. mpi_accelerated = False
  85. size = 1
  86. rank = 0
  87. __all__ = [
  88. "ASSET",
  89. "synchronous_events_intersection",
  90. "synchronous_events_difference",
  91. "synchronous_events_identical",
  92. "synchronous_events_no_overlap",
  93. "synchronous_events_contained_in",
  94. "synchronous_events_contains_all",
  95. "synchronous_events_overlap"
  96. ]
  97. # =============================================================================
  98. # Some Utility Functions to be dealt with in some way or another
  99. # =============================================================================
  100. def _signals_same_attribute(signals, attr_name):
  101. """
  102. Check whether a list of signals (`neo.AnalogSignal` or `neo.SpikeTrain`)
  103. have same attribute `attr_name`. If so, return that value. Otherwise,
  104. raise ValueError.
  105. Parameters
  106. ----------
  107. signals : list
  108. A list of signals (e.g. `neo.AnalogSignal` or `neo.SpikeTrain`) having
  109. attribute `attr_name`.
  110. Returns
  111. -------
  112. pq.Quantity
  113. The value of the common attribute `attr_name` of the list of signals.
  114. Raises
  115. ------
  116. ValueError
  117. If `signals` is an empty list.
  118. If `signals` have different `attr_name` attribute values.
  119. """
  120. if len(signals) == 0:
  121. raise ValueError('Empty signals list')
  122. attribute = getattr(signals[0], attr_name)
  123. for sig in signals[1:]:
  124. if getattr(sig, attr_name) != attribute:
  125. raise ValueError(
  126. "Signals have different '{}' values".format(attr_name))
  127. return attribute
  128. def _quantities_almost_equal(x, y):
  129. """
  130. Returns True if two quantities are almost equal, i.e., if `x - y` is
  131. "very close to 0" (not larger than machine precision for floats).
  132. Parameters
  133. ----------
  134. x : pq.Quantity
  135. First Quantity to compare.
  136. y : pq.Quantity
  137. Second Quantity to compare. Must have same unit type as `x`, but not
  138. necessarily the same shape. Any shapes of `x` and `y` for which `x - y`
  139. can be calculated are permitted.
  140. Returns
  141. -------
  142. np.ndarray
  143. Array of `bool`, which is True at any position where `x - y` is almost
  144. zero.
  145. Notes
  146. -----
  147. Not the same as `numpy.testing.assert_allclose` (which does not work
  148. with Quantities) and `numpy.testing.assert_almost_equal` (which works only
  149. with decimals)
  150. """
  151. eps = np.finfo(float).eps
  152. relative_diff = (x - y).magnitude
  153. return np.all([-eps <= relative_diff, relative_diff <= eps], axis=0)
  154. def _transactions(spiketrains, bin_size, t_start, t_stop, ids=None):
  155. """
  156. Transform parallel spike trains into a list of sublists, called
  157. transactions, each corresponding to a time bin and containing the list
  158. of spikes in `spiketrains` falling into that bin.
  159. To compute each transaction, the spike trains are binned (with adjacent
  160. exclusive binning) and clipped (i.e., spikes from the same train falling
  161. in the same bin are counted as one event). The list of spike IDs within
  162. each bin form the corresponding transaction.
  163. Parameters
  164. ----------
  165. spiketrains : list of neo.SpikeTrain or list of tuple
  166. A list of `neo.SpikeTrain` objects, or list of pairs
  167. (Train_ID, `neo.SpikeTrain`), where `Train_ID` can be any hashable
  168. object.
  169. bin_size : pq.Quantity
  170. Width of each time bin. Time is binned to determine synchrony.
  171. t_start : pq.Quantity
  172. The starting time. Only spikes occurring at times `t >= t_start` are
  173. considered. The first transaction contains spikes falling into the
  174. time segment `[t_start, t_start+bin_size]`.
  175. If None, takes the value of `spiketrain.t_start`, common for all
  176. input `spiketrains` (raises ValueError if it's not the case).
  177. Default: None.
  178. t_stop : pq.Quantity
  179. The ending time. Only spikes occurring at times `t < t_stop` are
  180. considered.
  181. If None, takes the value of `spiketrain.t_stop`, common for all
  182. input `spiketrains` (raises ValueError if it's not the case).
  183. Default: None.
  184. ids : list of int, optional
  185. List of spike train IDs.
  186. If None, the IDs `0` to `N-1` are used, where `N` is the number of
  187. input spike trains.
  188. Default: None.
  189. Returns
  190. -------
  191. list of list
  192. A list of transactions, where each transaction corresponds to a time
  193. bin and represents the list of spike train IDs having a spike in that
  194. time bin.
  195. Raises
  196. ------
  197. TypeError
  198. If `spiketrains` is not a list of `neo.SpikeTrain` or a list of tuples
  199. (id, `neo.SpikeTrain`).
  200. """
  201. if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
  202. trains = spiketrains
  203. if ids is None:
  204. ids = range(len(spiketrains))
  205. else:
  206. # (id, SpikeTrain) pairs
  207. try:
  208. ids, trains = zip(*spiketrains)
  209. except TypeError:
  210. raise TypeError('spiketrains must be either a list of ' +
  211. 'SpikeTrains or a list of (id, SpikeTrain) pairs')
  212. # Bin the spike trains and take for each of them the ids of filled bins
  213. binned = conv.BinnedSpikeTrain(
  214. trains, bin_size=bin_size, t_start=t_start, t_stop=t_stop)
  215. filled_bins = binned.spike_indices
  216. # Compute and return the transaction list
  217. return [[train_id for train_id, b in zip(ids, filled_bins)
  218. if bin_id in b] for bin_id in range(binned.n_bins)]
  219. def _analog_signal_step_interp(signal, times):
  220. """
  221. Compute the step-wise interpolation of a signal at desired times.
  222. Given a signal (e.g. a `neo.AnalogSignal`) `s` taking values `s[t0]` and
  223. `s[t1]` at two consecutive time points `t0` and `t1` (`t0 < t1`), the value
  224. of the step-wise interpolation at time `t: t0 <= t < t1` is given by
  225. `s[t] = s[t0]`.
  226. Parameters
  227. ----------
  228. signal : neo.AnalogSignal
  229. The analog signal, containing the discretization of the function to
  230. interpolate.
  231. times : pq.Quantity
  232. A vector of time points at which the step interpolation is computed.
  233. Returns
  234. -------
  235. pq.Quantity
  236. Object with same shape of `times` and containing
  237. the values of the interpolated signal at the time points in `times`.
  238. """
  239. dt = signal.sampling_period
  240. # Compute the ids of the signal times to the left of each time in times
  241. time_ids = np.floor(
  242. ((times - signal.t_start) / dt).rescale(
  243. pq.dimensionless).magnitude).astype('i')
  244. return (signal.magnitude[time_ids] * signal.units).rescale(signal.units)
  245. # =============================================================================
  246. # HERE ASSET STARTS
  247. # =============================================================================
  248. def _stretched_metric_2d(x, y, stretch, ref_angle):
  249. r"""
  250. Given a list of points on the real plane, identified by their abscissa `x`
  251. and ordinate `y`, compute a stretched transformation of the Euclidean
  252. distance among each of them.
  253. The classical euclidean distance `d` between points `(x1, y1)` and
  254. `(x2, y2)`, i.e., :math:`\sqrt((x1-x2)^2 + (y1-y2)^2)`, is multiplied by a
  255. factor
  256. .. math::
  257. 1 + (stretch - 1.) * \abs(\sin(ref_angle - \theta)),
  258. where :math:`\theta` is the angle between the points and the 45 degree
  259. direction (i.e., the line `y = x`).
  260. The stretching factor thus steadily varies between 1 (if the line
  261. connecting `(x1, y1)` and `(x2, y2)` has inclination `ref_angle`) and
  262. `stretch` (if that line has inclination `90 + ref_angle`).
  263. Parameters
  264. ----------
  265. x : (n,) np.ndarray
  266. Array of abscissas of all points among which to compute the distance.
  267. y : (n,) np.ndarray
  268. Array of ordinates of all points among which to compute the distance
  269. (same shape as `x`).
  270. stretch : float
  271. Maximum stretching factor, applied if the line connecting the points
  272. has inclination `90 + ref_angle`.
  273. ref_angle : float
  274. Reference angle in degrees (i.e., the inclination along which the
  275. stretching factor is 1).
  276. Returns
  277. -------
  278. D : (n,n) np.ndarray
  279. Square matrix of distances between all pairs of points.
  280. """
  281. alpha = np.deg2rad(ref_angle) # reference angle in radians
  282. # Create the array of points (one per row) for which to compute the
  283. # stretched distance
  284. points = np.vstack([x, y]).T
  285. # Compute the matrix D[i, j] of euclidean distances among points i and j
  286. D = scipy.spatial.distance_matrix(points, points)
  287. # Compute the angular coefficients of the line between each pair of points
  288. x_array = np.tile(x, reps=(len(x), 1))
  289. y_array = np.tile(y, reps=(len(y), 1))
  290. dX = x_array.T - x_array # dX[i,j]: x difference between points i and j
  291. dY = y_array.T - y_array # dY[i,j]: y difference between points i and j
  292. # Compute the matrix Theta of angles between each pair of points
  293. theta = np.arctan2(dY, dX)
  294. # Transform [-pi, pi] back to [-pi/2, pi/2]
  295. theta[theta < -np.pi / 2] += np.pi
  296. theta[theta > np.pi / 2] -= np.pi
  297. # Compute the matrix of stretching factors for each pair of points
  298. stretch_mat = 1 + (stretch - 1.) * np.abs(np.sin(alpha - theta))
  299. # Return the stretched distance matrix
  300. return D * stretch_mat
  301. def _interpolate_signals(signals, sampling_times, verbose=False):
  302. """
  303. Interpolate signals at given sampling times.
  304. """
  305. # Reshape all signals to one-dimensional array object (e.g. AnalogSignal)
  306. for i, signal in enumerate(signals):
  307. if signal.ndim == 2:
  308. signals[i] = signal.flatten()
  309. elif signal.ndim > 2:
  310. raise ValueError('elements in fir_rates must have 2 dimensions')
  311. if verbose:
  312. print('create time slices of the rates...')
  313. # Interpolate in the time bins
  314. interpolated_signal = np.vstack([_analog_signal_step_interp(
  315. signal, sampling_times).rescale('Hz').magnitude
  316. for signal in signals]) * pq.Hz
  317. return interpolated_signal
  318. def _num_iterations(n, d):
  319. if d > n:
  320. return 0
  321. if d == 1:
  322. return n
  323. if d == 2:
  324. # equivalent to np.sum(count_matrix)
  325. return n * (n + 1) // 2 - 1
  326. # Create square matrix with diagonal values equal to 2 to `n`.
  327. # Start from row/column with index == 2 to facilitate indexing.
  328. count_matrix = np.zeros((n + 1, n + 1), dtype=int)
  329. np.fill_diagonal(count_matrix, np.arange(n + 1))
  330. count_matrix[1, 1] = 0
  331. # Accumulate counts of all the iterations where the first index
  332. # is in the interval `d` to `n`.
  333. #
  334. # The counts for every level is obtained by accumulating the
  335. # `count_matrix`, which is the count of iterations with the first
  336. # index between `d` and `n`, when `d` == 2.
  337. #
  338. # For every value from 3 to `d`...
  339. # 1. Define each row `n` in the count matrix as the sum of all rows
  340. # equal or above.
  341. # 2. Set all rows above the current value of `d` with zeros.
  342. #
  343. # Example for `n` = 6 and `d` = 4:
  344. #
  345. # d = 2 (start) d = 3
  346. # count count
  347. # n n
  348. # 2 2 0 0 0 0
  349. # 3 0 3 0 0 0 ==> 3 2 3 0 0 0 ==>
  350. # 4 0 0 4 0 0 4 2 3 4 0 0
  351. # 5 0 0 0 5 0 5 2 3 4 5 0
  352. # 6 0 0 0 0 6 6 2 3 4 5 6
  353. #
  354. # d = 4
  355. # count
  356. # n
  357. #
  358. # 4 4 6 4 0 0
  359. # 5 6 9 8 5 0
  360. # 6 8 12 12 10 6
  361. #
  362. # The total number is the sum of the `count_matrix` when `d` has
  363. # the value passed to the function.
  364. #
  365. for cur_d in range(3, d + 1):
  366. for cur_n in range(n, 2, -1):
  367. count_matrix[cur_n, :] = np.sum(count_matrix[:cur_n + 1, :],
  368. axis=0)
  369. # Set previous `d` level to zeros
  370. count_matrix[cur_d - 1, :] = 0
  371. return np.sum(count_matrix)
  372. def _combinations_with_replacement(n, d):
  373. # Generate sequences of {a_i} such that
  374. # a_0 >= a_1 >= ... >= a_(d-1) and
  375. # d-i <= a_i <= n, for each i in [0, d-1].
  376. #
  377. # Almost equivalent to
  378. # list(itertools.combinations_with_replacement(range(n, 0, -1), r=d))[::-1]
  379. #
  380. # Example:
  381. # _combinations_with_replacement(n=13, d=3) -->
  382. # (3, 2, 1), (3, 2, 2), (3, 3, 1), ... , (13, 13, 12), (13, 13, 13).
  383. #
  384. # The implementation follows the insertion sort algorithm:
  385. # insert a new element a_i from right to left to keep the reverse sorted
  386. # order. Now substitute increment operation for insert.
  387. if d > n:
  388. return
  389. if d == 1:
  390. for matrix_entry in range(1, n + 1):
  391. yield (matrix_entry,)
  392. return
  393. sequence_sorted = list(range(d, 0, -1))
  394. input_order = tuple(sequence_sorted) # fixed
  395. while sequence_sorted[0] != n + 1:
  396. for last_element in range(1, sequence_sorted[-2] + 1):
  397. sequence_sorted[-1] = last_element
  398. yield tuple(sequence_sorted)
  399. increment_id = d - 2
  400. while increment_id > 0 and sequence_sorted[increment_id - 1] == \
  401. sequence_sorted[increment_id]:
  402. increment_id -= 1
  403. sequence_sorted[increment_id + 1:] = input_order[increment_id + 1:]
  404. sequence_sorted[increment_id] += 1
def _jsf_uniform_orderstat_3d(u, n, verbose=False):
    r"""
    Joint survival function of the `d` highest order statistics of `n`
    i.i.d. uniform random variables, evaluated at each row of `u`.

    Considered n independent random variables X1, X2, ..., Xn all having
    uniform distribution in the interval (0, 1):

    .. centered:: Xi ~ Uniform(0, 1),

    given a 2D matrix U = (u_ij) where each U_i is an array of length d:
    U_i = [u0, u1, ..., u_{d-1}] of quantiles, with u1 <= u2 <= ... <= un,
    computes the joint survival function (jsf) of the d highest order
    statistics (U_{n-d+1}, U_{n-d+2}, ..., U_n),
    where U_k := "k-th highest X's" at each u_i, i.e.:

    .. centered:: jsf(u_i) = Prob(U_{n-k} >= u_ijk, k=0,1,..., d-1).

    Parameters
    ----------
    u : (A,d) np.ndarray
        2D matrix of floats between 0 and 1.
        Each row `u_i` is an array of length `d`, considered a set of
        `d` largest order statistics extracted from a sample of `n` random
        variables whose cdf is `F(x) = x` for each `x`.
        The routine computes the joint cumulative probability of the `d`
        values in `u_ij`, for each `i` and `j`.
    n : int
        Size of the sample where the `d` largest order statistics `u_ij`
        are assumed to have been sampled from.
    verbose : bool
        If True, print messages during the computation.
        Default: False.

    Returns
    -------
    P_total : (A,) np.ndarray
        Matrix of joint survival probabilities. `s_ij` is the joint
        survival probability of the values `{u_ijk, k=0, ..., d-1}`.
        Note: the joint probability matrix computed for the ASSET analysis
        is `1 - S`.
    """
    num_p_vals, d = u.shape
    # The integral is computed as a sum over all admissible index
    # combinations (the mute variables range over [1..n], [2..n], ...,
    # [d..n]); it_todo is the total number of terms, used only for the
    # progress bar.
    it_todo = _num_iterations(n, d)
    log_1 = np.log(1.)
    # Log of the integral's coefficient: log(n!)
    logK = np.sum(np.log(np.arange(1, n + 1)))
    # Pad each row of u with 0 on the left and 1 on the right, then take
    # the consecutive differences along the second dimension: du[i, k] is
    # the width of the k-th interval between adjacent quantiles of row i.
    du = np.diff(u, prepend=0, append=1, axis=1)
    # precompute logarithms
    # ignore warnings about infinities, see inside the loop:
    # we replace 0 * ln(0) by 1 to get exp(0 * ln(0)) = 0 ** 0 = 1
    # the remaining infinities correctly evaluate to
    # exp(ln(0)) = exp(-inf) = 0
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', RuntimeWarning)
        log_du = np.log(du)
    # prepare scratch arrays reused (overwritten) on every loop iteration
    di_scratch = np.empty_like(du, dtype=np.int32)
    log_du_scratch = np.empty_like(log_du)
    # precompute log(factorial)s
    # pad with a zero to get log_factorial[0] == log(0!) == 0
    log_factorial = np.hstack((0, np.cumsum(np.log(range(1, n + 1)))))
    # compute the probabilities for each unique row of du
    # only loop over the index combinations and do all du rows at once
    # using matrix algebra
    # initialise probabilities to 0
    P_total = np.zeros(du.shape[0], dtype=np.float32)
    for iter_id, matrix_entries in enumerate(
            tqdm(_combinations_with_replacement(n, d=d),
                 total=it_todo,
                 desc="Joint survival function",
                 disable=not verbose)):
        # if running with MPI, each rank handles every `size`-th term only
        if mpi_accelerated and iter_id % size != rank:
            continue
        # we only need the differences of the indices:
        di = -np.diff((n,) + matrix_entries + (0,))
        # broadcast di into the scratch matrix, row-compatible with du
        di_scratch[:, range(len(di))] = di
        # use precomputed factorials
        sum_log_di_factorial = log_factorial[di].sum()
        # Compute for each row the contribution to the probability
        # given by this step, and add it to the total probability.
        # Use precomputed log.
        np.copyto(log_du_scratch, log_du)
        # for each row, replace du with 1 whenever di_scratch = 0, so that
        # du ** di_scratch = 1 (this avoids nans when both du and
        # di_scratch are 0, and is mathematically correct)
        log_du_scratch[di_scratch == 0] = log_1
        di_log_du = di_scratch * log_du_scratch
        sum_di_log_du = di_log_du.sum(axis=1)
        logP = sum_di_log_du - sum_log_di_factorial
        P_total += np.exp(logP + logK)
    if mpi_accelerated:
        totals = np.zeros(du.shape[0], dtype=np.float32)
        # exchange all the results across MPI ranks
        comm.Allreduce(
            [P_total, MPI.FLOAT],
            [totals, MPI.FLOAT],
            op=MPI.SUM)
        # We need to return the collected totals instead of the local P_total
        return totals
    return P_total
  506. def _pmat_neighbors(mat, filter_shape, n_largest):
  507. """
  508. Build the 3D matrix `L` of largest neighbors of elements in a 2D matrix
  509. `mat`.
  510. For each entry `mat[i, j]`, collects the `n_largest` elements with largest
  511. values around `mat[i, j]`, say `z_i, i=1,2,...,n_largest`, and assigns them
  512. to `L[i, j, :]`.
  513. The zone around `mat[i, j]` where largest neighbors are collected from is
  514. a rectangular area (kernel) of shape `(l, w) = filter_shape` centered
  515. around `mat[i, j]` and aligned along the diagonal.
  516. If `mat` is symmetric, only the triangle below the diagonal is considered.
  517. Parameters
  518. ----------
  519. mat : np.ndarray
  520. A square matrix of real-valued elements.
  521. filter_shape : tuple of int
  522. A pair of integers representing the kernel shape `(l, w)`.
  523. n_largest : int
  524. The number of largest neighbors to collect for each entry in `mat`.
  525. Returns
  526. -------
  527. lmat : np.ndarray
  528. A matrix of shape `(n_largest, l, w)` containing along the first
  529. dimension `lmat[:, i, j]` the largest neighbors of `mat[i, j]`.
  530. Raises
  531. ------
  532. ValueError
  533. If `filter_shape[1]` is not lower than `filter_shape[0]`.
  534. Warns
  535. -----
  536. UserWarning
  537. If both entries in `filter_shape` are not odd values (i.e., the kernel
  538. is not centered on the data point used in the calculation).
  539. """
  540. l, w = filter_shape
  541. # if the matrix is symmetric the diagonal was set to 0.5
  542. # when computing the probability matrix
  543. symmetric = np.all(np.diagonal(mat) == 0.5)
  544. # Check consistent arguments
  545. if w >= l:
  546. raise ValueError('filter_shape width must be lower than length')
  547. if not ((w % 2) and (l % 2)):
  548. warnings.warn('The kernel is not centered on the datapoint in whose'
  549. 'calculation it is used. Consider using odd values'
  550. 'for both entries of filter_shape.')
  551. # Construct the kernel
  552. filt = np.ones((l, l), dtype=np.float32)
  553. filt = np.triu(filt, -w)
  554. filt = np.tril(filt, w)
  555. # Convert mat values to floats, and replaces np.infs with specified input
  556. # values
  557. mat = np.array(mat, dtype=np.float32)
  558. # Initialize the matrix of d-largest values as a matrix of zeroes
  559. lmat = np.zeros((n_largest, mat.shape[0], mat.shape[1]), dtype=np.float32)
  560. N_bin_y = mat.shape[0]
  561. N_bin_x = mat.shape[1]
  562. # if the matrix is symmetric do not use kernel positions intersected
  563. # by the diagonal
  564. if symmetric:
  565. bin_range_y = range(l, N_bin_y - l + 1)
  566. else:
  567. bin_range_y = range(N_bin_y - l + 1)
  568. bin_range_x = range(N_bin_x - l + 1)
  569. # compute matrix of largest values
  570. for y in bin_range_y:
  571. if symmetric:
  572. # x range depends on y position
  573. bin_range_x = range(y - l + 1)
  574. for x in bin_range_x:
  575. patch = mat[y: y + l, x: x + l]
  576. mskd = np.multiply(filt, patch)
  577. largest_vals = np.sort(mskd, axis=None)[-n_largest:]
  578. lmat[:, y + (l // 2), x + (l // 2)] = largest_vals
  579. return lmat
  580. def synchronous_events_intersection(sse1, sse2, intersection='linkwise'):
  581. """
  582. Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
  583. consisting of a pool of positions `(iK, jK)` of matrix entries and
  584. associated synchronous events `SK`, finds the intersection among them.
  585. The intersection can be performed 'pixelwise' or 'linkwise'.
  586. * if 'pixelwise', it yields a new SSE which retains only events in
  587. `sse1` whose pixel position matches a pixel position in `sse2`. This
  588. operation is not symmetric:
  589. `intersection(sse1, sse2) != intersection(sse2, sse1)`.
  590. * if 'linkwise', an additional step is performed where each retained
  591. synchronous event `SK` in `sse1` is intersected with the
  592. corresponding event in `sse2`. This yields a symmetric operation:
  593. `intersection(sse1, sse2) = intersection(sse2, sse1)`.
  594. Both `sse1` and `sse2` must be provided as dictionaries of the type
  595. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  596. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  597. Parameters
  598. ----------
  599. sse1, sse2 : dict
  600. Each is a dictionary of pixel positions `(i, j)` as keys and sets `S`
  601. of synchronous events as values (see above).
  602. intersection : {'pixelwise', 'linkwise'}, optional
  603. The type of intersection to perform among the two SSEs (see above).
  604. Default: 'linkwise'.
  605. Returns
  606. -------
  607. sse_new : dict
  608. A new SSE (same structure as `sse1` and `sse2`) which retains only the
  609. events of `sse1` associated to keys present both in `sse1` and `sse2`.
  610. If `intersection = 'linkwise'`, such events are additionally
  611. intersected with the associated events in `sse2`.
  612. See Also
  613. --------
  614. ASSET.extract_synchronous_events : extract SSEs from given spike trains
  615. """
  616. sse_new = sse1.copy()
  617. for pixel1 in sse1.keys():
  618. if pixel1 not in sse2.keys():
  619. del sse_new[pixel1]
  620. if intersection == 'linkwise':
  621. for pixel1, link1 in sse_new.items():
  622. sse_new[pixel1] = link1.intersection(sse2[pixel1])
  623. if len(sse_new[pixel1]) == 0:
  624. del sse_new[pixel1]
  625. elif intersection == 'pixelwise':
  626. pass
  627. else:
  628. raise ValueError(
  629. "intersection (=%s) can only be" % intersection +
  630. " 'pixelwise' or 'linkwise'")
  631. return sse_new
  632. def synchronous_events_difference(sse1, sse2, difference='linkwise'):
  633. """
  634. Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
  635. consisting of a pool of pixel positions and associated synchronous events
  636. (see below), computes the difference between `sse1` and `sse2`.
  637. The difference can be performed 'pixelwise' or 'linkwise':
  638. * if 'pixelwise', it yields a new SSE which contains all (and only) the
  639. events in `sse1` whose pixel position doesn't match any pixel in
  640. `sse2`.
  641. * if 'linkwise', for each pixel `(i, j)` in `sse1` and corresponding
  642. synchronous event `S1`, if `(i, j)` is a pixel in `sse2`
  643. corresponding to the event `S2`, it retains the set difference
  644. `S1 - S2`. If `(i, j)` is not a pixel in `sse2`, it retains the full
  645. set `S1`.
  646. Note that in either case the difference is a non-symmetric operation:
  647. `intersection(sse1, sse2) != intersection(sse2, sse1)`.
  648. Both `sse1` and `sse2` must be provided as dictionaries of the type
  649. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  650. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  651. Parameters
  652. ----------
  653. sse1, sse2 : dict
  654. Dictionaries of pixel positions `(i, j)` as keys and sets `S` of
  655. synchronous events as values (see above).
  656. difference : {'pixelwise', 'linkwise'}, optional
  657. The type of difference to perform between `sse1` and `sse2` (see
  658. above).
  659. Default: 'linkwise'.
  660. Returns
  661. -------
  662. sse_new : dict
  663. A new SSE (same structure as `sse1` and `sse2`) which retains the
  664. difference between `sse1` and `sse2`.
  665. See Also
  666. --------
  667. ASSET.extract_synchronous_events : extract SSEs from given spike trains
  668. """
  669. sse_new = sse1.copy()
  670. for pixel1 in sse1.keys():
  671. if pixel1 in sse2.keys():
  672. if difference == 'pixelwise':
  673. del sse_new[pixel1]
  674. elif difference == 'linkwise':
  675. sse_new[pixel1] = sse_new[pixel1].difference(sse2[pixel1])
  676. if len(sse_new[pixel1]) == 0:
  677. del sse_new[pixel1]
  678. else:
  679. raise ValueError(
  680. "difference (=%s) can only be" % difference +
  681. " 'pixelwise' or 'linkwise'")
  682. return sse_new
  683. def _remove_empty_events(sse):
  684. """
  685. Given a sequence of synchronous events (SSE) `sse` consisting of a pool of
  686. pixel positions and associated synchronous events (see below), returns a
  687. copy of `sse` where all empty events have been removed.
  688. `sse` must be provided as a dictionary of type
  689. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  690. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  691. Parameters
  692. ----------
  693. sse : dict
  694. A dictionary of pixel positions `(i, j)` as keys, and sets `S` of
  695. synchronous events as values (see above).
  696. Returns
  697. -------
  698. sse_new : dict
  699. A copy of `sse` where all empty events have been removed.
  700. """
  701. sse_new = sse.copy()
  702. for pixel, link in sse.items():
  703. if link == set([]):
  704. del sse_new[pixel]
  705. return sse_new
  706. def synchronous_events_identical(sse1, sse2):
  707. """
  708. Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
  709. consisting of a pool of pixel positions and associated synchronous events
  710. (see below), determines whether `sse1` is strictly contained in `sse2`.
  711. `sse1` is strictly contained in `sse2` if all its pixels are pixels of
  712. `sse2`,
  713. if its associated events are subsets of the corresponding events
  714. in `sse2`, and if `sse2` contains events, or neuron IDs in some event,
  715. which do not belong to `sse1` (i.e., `sse1` and `sse2` are not identical).
  716. Both `sse1` and `sse2` must be provided as dictionaries of the type
  717. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  718. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  719. Parameters
  720. ----------
  721. sse1, sse2 : dict
  722. Dictionaries of pixel positions `(i, j)` as keys and sets `S` of
  723. synchronous events as values.
  724. Returns
  725. -------
  726. bool
  727. True if `sse1` is identical to `sse2`.
  728. See Also
  729. --------
  730. ASSET.extract_synchronous_events : extract SSEs from given spike trains
  731. """
  732. # Remove empty links from sse11 and sse22, if any
  733. sse11 = _remove_empty_events(sse1)
  734. sse22 = _remove_empty_events(sse2)
  735. # Return whether sse11 == sse22
  736. return sse11 == sse22
  737. def synchronous_events_no_overlap(sse1, sse2):
  738. """
  739. Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
  740. consisting of a pool of pixel positions and associated synchronous events
  741. (see below), determines whether `sse1` and `sse2` are disjoint.
  742. Two SSEs are disjoint if they don't share pixels, or if the events
  743. associated to common pixels are disjoint.
  744. Both `sse1` and `sse2` must be provided as dictionaries of the type
  745. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  746. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  747. Parameters
  748. ----------
  749. sse1, sse2 : dict
  750. Dictionaries of pixel positions `(i, j)` as keys and sets `S` of
  751. synchronous events as values.
  752. Returns
  753. -------
  754. bool
  755. True if `sse1` is disjoint from `sse2`.
  756. See Also
  757. --------
  758. ASSET.extract_synchronous_events : extract SSEs from given spike trains
  759. """
  760. # Remove empty links from sse11 and sse22, if any
  761. sse11 = _remove_empty_events(sse1)
  762. sse22 = _remove_empty_events(sse2)
  763. # If both SSEs are empty, return False (we consider them equal)
  764. if sse11 == {} and sse22 == {}:
  765. return False
  766. common_pixels = set(sse11.keys()).intersection(set(sse22.keys()))
  767. if common_pixels == set([]):
  768. return True
  769. elif all(sse11[p].isdisjoint(sse22[p]) for p in common_pixels):
  770. return True
  771. else:
  772. return False
  773. def synchronous_events_contained_in(sse1, sse2):
  774. """
  775. Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
  776. consisting of a pool of pixel positions and associated synchronous events
  777. (see below), determines whether `sse1` is strictly contained in `sse2`.
  778. `sse1` is strictly contained in `sse2` if all its pixels are pixels of
  779. `sse2`, if its associated events are subsets of the corresponding events
  780. in `sse2`, and if `sse2` contains non-empty events, or neuron IDs in some
  781. event, which do not belong to `sse1` (i.e., `sse1` and `sse2` are not
  782. identical).
  783. Both `sse1` and `sse2` must be provided as dictionaries of the type
  784. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  785. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  786. Parameters
  787. ----------
  788. sse1, sse2 : dict
  789. Dictionaries of pixel positions `(i, j)` as keys and sets `S` of
  790. synchronous events as values.
  791. Returns
  792. -------
  793. bool
  794. True if `sse1` is a subset of `sse2`.
  795. See Also
  796. --------
  797. ASSET.extract_synchronous_events : extract SSEs from given spike trains
  798. """
  799. # Remove empty links from sse11 and sse22, if any
  800. sse11 = _remove_empty_events(sse1)
  801. sse22 = _remove_empty_events(sse2)
  802. # Return False if sse11 and sse22 are disjoint
  803. if synchronous_events_identical(sse11, sse22):
  804. return False
  805. # Return False if any pixel in sse1 is not contained in sse2, or if any
  806. # link of sse1 is not a subset of the corresponding link in sse2.
  807. # Otherwise (if sse1 is a subset of sse2) continue
  808. for pixel1, link1 in sse11.items():
  809. if pixel1 not in sse22.keys():
  810. return False
  811. elif not link1.issubset(sse22[pixel1]):
  812. return False
  813. # Check that sse1 is a STRICT subset of sse2, i.e. that sse2 contains at
  814. # least one pixel or neuron id not present in sse1.
  815. return not synchronous_events_identical(sse11, sse22)
def synchronous_events_contains_all(sse1, sse2):
    """
    Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
    consisting of a pool of pixel positions and associated synchronous events
    (see below), determines whether `sse1` strictly contains `sse2`.

    `sse1` strictly contains `sse2` if it contains all pixels of `sse2`, if all
    associated events in `sse1` contain those in `sse2`, and if `sse1`
    additionally contains other pixels / events not contained in `sse2`.

    Both `sse1` and `sse2` must be provided as dictionaries of the type

    .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},

    where each `i`, `j` is an integer and each `S` is a set of neuron IDs.

    Parameters
    ----------
    sse1, sse2 : dict
        Dictionaries of pixel positions `(i, j)` as keys and sets `S` of
        synchronous events as values.

    Returns
    -------
    bool
        True if `sse1` strictly contains `sse2`.

    Notes
    -----
    `synchronous_events_contains_all(sse1, sse2)` is identical to
    `synchronous_events_contained_in(sse2, sse1)`.

    See Also
    --------
    ASSET.extract_synchronous_events : extract SSEs from given spike trains
    """
    # Delegate to the mirror-image predicate with the arguments swapped
    return synchronous_events_contained_in(sse2, sse1)
  845. def synchronous_events_overlap(sse1, sse2):
  846. """
  847. Given two sequences of synchronous events (SSEs) `sse1` and `sse2`, each
  848. consisting of a pool of pixel positions and associated synchronous events
  849. (see below), determines whether the two SSEs overlap.
  850. The SSEs overlap if they are not equal and none of them is a superset of
  851. the other one but they are also not disjoint.
  852. Both `sse1` and `sse2` must be provided as dictionaries of the type
  853. .. centered:: {(i1, j1): S1, (i2, j2): S2, ..., (iK, jK): SK},
  854. where each `i`, `j` is an integer and each `S` is a set of neuron IDs.
  855. Parameters
  856. ----------
  857. sse1, sse2 : dict
  858. Dictionaries of pixel positions `(i, j)` as keys and sets `S` of
  859. synchronous events as values.
  860. Returns
  861. -------
  862. bool
  863. True if `sse1` and `sse2` overlap.
  864. See Also
  865. --------
  866. ASSET.extract_synchronous_events : extract SSEs from given spike trains
  867. """
  868. contained_in = synchronous_events_contained_in(sse1, sse2)
  869. contains_all = synchronous_events_contains_all(sse1, sse2)
  870. identical = synchronous_events_identical(sse1, sse2)
  871. is_disjoint = synchronous_events_no_overlap(sse1, sse2)
  872. return not (contained_in or contains_all or identical or is_disjoint)
  873. def _signals_t_start_stop(signals, t_start=None, t_stop=None):
  874. if t_start is None:
  875. t_start = _signals_same_attribute(signals, 't_start')
  876. if t_stop is None:
  877. t_stop = _signals_same_attribute(signals, 't_stop')
  878. return t_start, t_stop
def _intersection_matrix(spiketrains, spiketrains_y, bin_size, t_start_x,
                         t_start_y, t_stop_x, t_stop_y, normalization=None):
    """
    Compute the (optionally normalized) intersection matrix between two
    lists of spike trains, binned along the x and y time axes respectively.

    Parameters
    ----------
    spiketrains : list of neo.SpikeTrain
        Spike trains binned along the x axis.
    spiketrains_y : list of neo.SpikeTrain or None
        Spike trains binned along the y axis; if None, `spiketrains` is used
        for both axes.
    bin_size : pq.Quantity
        Width of the time bins used on both axes.
    t_start_x, t_start_y, t_stop_x, t_stop_y : pq.Quantity
        Binning boundaries for the x and y axes.
    normalization : {None, 'intersection', 'mean', 'union'}, optional
        Per-entry normalization coefficient (see `ASSET.intersection_matrix`).
        Default: None.

    Returns
    -------
    imat : np.ndarray of float32
        The intersection matrix, one row per x bin and one column per y bin.
    """
    if spiketrains_y is None:
        spiketrains_y = spiketrains
    # Compute the binned spike train matrices, along both time axes
    spiketrains_binned = conv.BinnedSpikeTrain(
        spiketrains, bin_size=bin_size,
        t_start=t_start_x, t_stop=t_stop_x)
    spiketrains_binned_y = conv.BinnedSpikeTrain(
        spiketrains_y, bin_size=bin_size,
        t_start=t_start_y, t_stop=t_stop_y)
    # Sparse (neurons x bins) matrices for both axes
    bsts_x = spiketrains_binned.sparse_matrix
    bsts_y = spiketrains_binned_y.sparse_matrix
    # Compute the number of spikes in each bin, for both time axes
    # 'A1' property returns self as a flattened ndarray.
    spikes_per_bin_x = bsts_x.sum(axis=0).A1
    spikes_per_bin_y = bsts_y.sum(axis=0).A1
    # Compute the raw intersection matrix by matrix multiplication:
    # imat[i, j] counts spike trains active in both x-bin i and y-bin j
    imat = bsts_x.T.dot(bsts_y).toarray().astype(np.float32)
    for ii in range(bsts_x.shape[1]):
        # Normalize row ii of imat by the requested coefficient
        col_sum = bsts_x[:, ii].sum()
        if normalization is None or col_sum == 0:
            # No normalization requested, or empty x-bin: leave counts as-is
            norm_coef = 1.
        elif normalization == 'intersection':
            # Elementwise minimum of the spike counts of the two bins
            norm_coef = np.minimum(
                spikes_per_bin_x[ii], spikes_per_bin_y)
        elif normalization == 'mean':
            # geometric mean of the spike counts of the two bins
            norm_coef = np.sqrt(
                spikes_per_bin_x[ii] * spikes_per_bin_y)
        elif normalization == 'union':
            # Number of nonzero entries in the elementwise sum, i.e. the
            # number of spike trains active in at least one of the two bins
            norm_coef = np.array([(bsts_x[:, ii]
                                   + bsts_y[:, jj]).count_nonzero()
                                  for jj in range(bsts_y.shape[1])])
        else:
            raise ValueError(
                "Invalid parameter 'norm': {}".format(normalization))
        # Where norm_coef is 0 (e.g. bsts_y[:, j] identically 0), a plain
        # division would produce nan; np.divide with `where` writes 0 into
        # those entries instead (via the pre-zeroed `out` buffer).
        imat[ii, :] = np.divide(imat[ii, :], norm_coef,
                                out=np.zeros(imat.shape[1],
                                             dtype=np.float32),
                                where=norm_coef != 0)
    # Return only the (normalized) intersection matrix
    return imat
  928. class ASSET(object):
  929. """
  930. Analysis of Sequences of Synchronous EvenTs class.
  931. Parameters
  932. ----------
  933. spiketrains_i, spiketrains_j : list of neo.SpikeTrain
  934. Input spike trains for the first and second time dimensions,
  935. respectively, to compute the p-values from.
  936. If `spiketrains_y` is None, it's set to `spiketrains`.
  937. bin_size : pq.Quantity, optional
  938. The width of the time bins used to compute the probability matrix.
  939. t_start_i, t_start_j : pq.Quantity, optional
  940. The start time of the binning for the first and second axes,
  941. respectively.
  942. If None, the attribute `t_start` of the spike trains is used
  943. (if the same for all spike trains).
  944. Default: None.
  945. t_stop_i, t_stop_j : pq.Quantity, optional
  946. The stop time of the binning for the first and second axes,
  947. respectively.
  948. If None, the attribute `t_stop` of the spike trains is used
  949. (if the same for all spike trains).
  950. Default: None.
  951. verbose : bool, optional
  952. If True, print messages and show progress bar.
  953. Default: True.
  954. Raises
  955. ------
  956. ValueError
  957. If the `t_start` & `t_stop` times are not (one of):
  958. perfectly aligned;
  959. fully disjoint.
  960. """
  961. def __init__(self, spiketrains_i, spiketrains_j=None, bin_size=3 * pq.ms,
  962. t_start_i=None, t_start_j=None, t_stop_i=None, t_stop_j=None,
  963. verbose=True):
  964. self.spiketrains_i = spiketrains_i
  965. if spiketrains_j is None:
  966. spiketrains_j = spiketrains_i
  967. self.spiketrains_j = spiketrains_j
  968. self.bin_size = bin_size
  969. self.t_start_i, self.t_stop_i = _signals_t_start_stop(
  970. spiketrains_i,
  971. t_start=t_start_i,
  972. t_stop=t_stop_i)
  973. self.t_start_j, self.t_stop_j = _signals_t_start_stop(
  974. spiketrains_j,
  975. t_start=t_start_j,
  976. t_stop=t_stop_j)
  977. self.verbose = verbose
  978. msg = 'The time intervals for x and y need to be either identical ' \
  979. 'or fully disjoint, but they are:\n' \
  980. 'x: ({}, {}) and y: ({}, {}).'.format(self.t_start_i,
  981. self.t_stop_i,
  982. self.t_start_j,
  983. self.t_stop_j)
  984. # the starts have to be perfectly aligned for the binning to work
  985. # the stops can differ without impacting the binning
  986. if self.t_start_i == self.t_start_j:
  987. if not _quantities_almost_equal(self.t_stop_i, self.t_stop_j):
  988. raise ValueError(msg)
  989. elif (self.t_start_i < self.t_start_j < self.t_stop_i) \
  990. or (self.t_start_i < self.t_stop_j < self.t_stop_i):
  991. raise ValueError(msg)
  992. # Compute the binned spike train matrices, along both time axes
  993. self.spiketrains_binned_i = conv.BinnedSpikeTrain(
  994. self.spiketrains_i, bin_size=self.bin_size,
  995. t_start=self.t_start_i, t_stop=self.t_stop_i)
  996. self.spiketrains_binned_j = conv.BinnedSpikeTrain(
  997. self.spiketrains_j, bin_size=self.bin_size,
  998. t_start=self.t_start_j, t_stop=self.t_stop_j)
  999. @property
  1000. def x_edges(self):
  1001. """
  1002. A Quantity array of `n+1` edges of the bins used for the horizontal
  1003. axis of the intersection matrix, where `n` is the number of bins that
  1004. time was discretized in.
  1005. """
  1006. return self.spiketrains_binned_i.bin_edges.rescale(self.bin_size.units)
  1007. @property
  1008. def y_edges(self):
  1009. """
  1010. A Quantity array of `n+1` edges of the bins used for the vertical axis
  1011. of the intersection matrix, where `n` is the number of bins that
  1012. time was discretized in.
  1013. """
  1014. return self.spiketrains_binned_j.bin_edges.rescale(self.bin_size.units)
  1015. def is_symmetric(self):
  1016. """
  1017. Returns
  1018. -------
  1019. bool
  1020. Whether the intersection matrix is symmetric or not.
  1021. See Also
  1022. --------
  1023. ASSET.intersection_matrix
  1024. """
  1025. return _quantities_almost_equal(self.x_edges[0], self.y_edges[0])
  1026. def intersection_matrix(self, normalization=None):
  1027. """
  1028. Generates the intersection matrix from a list of spike trains.
  1029. Given a list of `neo.SpikeTrain`, consider two binned versions of them
  1030. differing for the starting and ending times of the binning:
  1031. `t_start_x`, `t_stop_x`, `t_start_y` and `t_stop_y` respectively (the
  1032. time intervals can be either identical or completely disjoint). Then
  1033. calculate the intersection matrix `M` of the two binned data, where
  1034. `M[i,j]` is the overlap of bin `i` in the first binned data and bin `j`
  1035. in the second binned data (i.e., the number of spike trains spiking at
  1036. both bin `i` and bin `j`).
  1037. The matrix entries can be normalized to values between `0` and `1` via
  1038. different normalizations (see "Parameters" section).
  1039. Parameters
  1040. ----------
  1041. normalization : {'intersection', 'mean', 'union'} or None, optional
  1042. The normalization type to be applied to each entry `M[i,j]` of the
  1043. intersection matrix `M`. Given the sets `s_i` and `s_j` of neuron
  1044. IDs in the bins `i` and `j` respectively, the normalization
  1045. coefficient can be:
  1046. * None: no normalisation (row counts)
  1047. * 'intersection': `len(intersection(s_i, s_j))`
  1048. * 'mean': `sqrt(len(s_1) * len(s_2))`
  1049. * 'union': `len(union(s_i, s_j))`
  1050. Default: None.
  1051. Returns
  1052. -------
  1053. imat : (n,n) np.ndarray
  1054. The floating point intersection matrix of a list of spike trains.
  1055. It has the shape `(n, n)`, where `n` is the number of bins that
  1056. time was discretized in.
  1057. """
  1058. imat = _intersection_matrix(self.spiketrains_i, self.spiketrains_j,
  1059. self.bin_size,
  1060. self.t_start_i, self.t_start_j,
  1061. self.t_stop_i, self.t_stop_j,
  1062. normalization=normalization)
  1063. return imat
    def probability_matrix_montecarlo(self, n_surrogates, imat=None,
                                      surrogate_method='dither_spikes',
                                      surrogate_dt=None):
        """
        Given a list of parallel spike trains, estimate the cumulative
        probability of each entry in their intersection matrix by a Monte
        Carlo approach using surrogate data.

        Contrarily to the analytical version (see
        :func:`ASSET.probability_matrix_analytical`) the Monte Carlo one does
        not incorporate the assumptions of Poissonianity in the null
        hypothesis.

        The method produces surrogate spike trains (using one of several
        methods at disposal, see "Parameters" section) and calculates their
        intersection matrix `M`. For each entry `(i, j)`, the intersection
        CDF `P[i, j]` is then given by:

        .. centered::  P[i, j] = #(spike_train_surrogates such that
                       M[i, j] < I[i, j]) / #(spike_train_surrogates)

        If `P[i, j]` is large (close to 1), `I[i, j]` is statistically
        significant: the probability to observe an overlap equal to or
        larger than `I[i, j]` under the null hypothesis is `1 - P[i, j]`,
        very small.

        Parameters
        ----------
        n_surrogates : int
            The number of spike train surrogates to generate for the
            bootstrap procedure.
        imat : (n,n) np.ndarray or None, optional
            The floating point intersection matrix of a list of spike trains.
            It has the shape `(n, n)`, where `n` is the number of bins that
            time was discretized in.
            If None, the output of :func:`ASSET.intersection_matrix` is used.
            Default: None
        surrogate_method : {'dither_spike_train', 'dither_spikes',
                            'jitter_spikes',
                            'randomise_spikes', 'shuffle_isis',
                            'joint_isi_dithering'}, optional
            The method to generate surrogate spike trains. Refer to the
            :func:`spike_train_surrogates.surrogates` documentation for more
            information about each surrogate method. Note that some of these
            methods need `surrogate_dt` parameter, others ignore it.
            Default: 'dither_spikes'.
        surrogate_dt : pq.Quantity, optional
            For surrogate methods shifting spike times randomly around their
            original time ('dither_spike_train', 'dither_spikes') or
            replacing them randomly within a certain window
            ('jitter_spikes'), `surrogate_dt` represents the size of that
            shift (window). For other methods, `surrogate_dt` is ignored.
            If None, it's set to `self.bin_size * 5`.
            Default: None.

        Returns
        -------
        pmat : np.ndarray
            The cumulative probability matrix. `pmat[i, j]` represents the
            estimated probability of having an overlap between bins `i` and
            `j` STRICTLY LOWER than the observed overlap, under the null
            hypothesis of independence of the input spike trains.

        Notes
        -----
        We recommend playing with `surrogate_dt` parameter to see how it
        influences the result matrix. For this, refer to the ASSET tutorial.

        See Also
        --------
        ASSET.probability_matrix_analytical : analytical derivation of the
                                              matrix
        """
        if imat is None:
            # Compute the intersection matrix of the original data
            imat = self.intersection_matrix()
        if surrogate_dt is None:
            surrogate_dt = self.bin_size * 5
        symmetric = self.is_symmetric()
        # Compute the p-value matrix pmat; pmat[i, j] counts the fraction of
        # surrogate data whose intersection value at (i, j) is lower than or
        # equal to that of the original data.
        # pmat accumulates integer counts and is divided by n_surrogates at
        # the end.
        pmat = np.zeros(imat.shape, dtype=np.int32)
        for surr_id in trange(n_surrogates, desc="pmat_bootstrap",
                              disable=not self.verbose):
            # Under MPI, surrogate iterations are distributed round-robin
            # across ranks and the partial counts are summed afterwards
            if mpi_accelerated and surr_id % size != rank:
                continue
            # Surrogates for the x-axis spike trains
            surrogates = [spike_train_surrogates.surrogates(
                st, n_surrogates=1,
                method=surrogate_method,
                dt=surrogate_dt,
                decimals=None,
                edges=True)[0]
                for st in self.spiketrains_i]

            if symmetric:
                # Symmetric case: reuse the same surrogates for both axes
                surrogates_y = surrogates
            else:
                surrogates_y = [spike_train_surrogates.surrogates(
                    st, n_surrogates=1, method=surrogate_method,
                    dt=surrogate_dt, decimals=None, edges=True)[0]
                    for st in self.spiketrains_j]

            imat_surr = _intersection_matrix(surrogates, surrogates_y,
                                             self.bin_size,
                                             self.t_start_i, self.t_start_j,
                                             self.t_stop_i, self.t_stop_j)

            # Count surrogates with intersection STRICTLY lower than observed
            pmat += (imat_surr <= (imat - 1))

            # Free the surrogate matrix before the next iteration
            del imat_surr

        if mpi_accelerated:
            # Sum the per-rank partial counts across all MPI ranks
            pmat = comm.allreduce(pmat, op=MPI.SUM)

        # Convert counts to fractions (the CDF estimate)
        pmat = pmat * 1. / n_surrogates

        if symmetric:
            # Diagonal entries are trivially significant; set them to 0.5
            np.fill_diagonal(pmat, 0.5)

        return pmat
    def probability_matrix_analytical(self, imat=None,
                                      firing_rates_x='estimate',
                                      firing_rates_y='estimate',
                                      kernel_width=100 * pq.ms):
        r"""
        Given a list of spike trains, approximates the cumulative probability
        of each entry in their intersection matrix.

        The approximation is analytical and works under the assumptions that
        the input spike trains are independent and Poisson. It works as
        follows:

        * Bin each spike train at the specified `bin_size`: this yields a
          binary array of 1s (spike in bin) and 0s (no spike in bin;
          clipping used);
        * If required, estimate the rate profile of each spike train by
          convolving the binned array with a boxcar kernel of user-defined
          length;
        * For each neuron `k` and each pair of bins `i` and `j`, compute
          the probability :math:`p_ijk` that neuron `k` fired in both bins
          `i` and `j`.
        * Approximate the probability distribution of the intersection
          value at `(i, j)` by a Poisson distribution with mean parameter
          :math:`l = \sum_k (p_ijk)`,
          justified by Le Cam's approximation of a sum of independent
          Bernouilli random variables with a Poisson distribution.

        Parameters
        ----------
        imat : (n,n) np.ndarray or None, optional
            The intersection matrix of a list of spike trains.
            It has the shape `(n, n)`, where `n` is the number of bins that
            time was discretized in.
            If None, the output of :func:`ASSET.intersection_matrix` is used.
            Default: None
        firing_rates_x, firing_rates_y : list of neo.AnalogSignal or 'estimate'
            If a list, `firing_rates[i]` is the firing rate of the spike
            train `spiketrains[i]`.
            If 'estimate', firing rates are estimated by simple boxcar kernel
            convolution, with the specified `kernel_width`.
            Default: 'estimate'.
        kernel_width : pq.Quantity, optional
            The total width of the kernel used to estimate the rate profiles
            when `firing_rates` is 'estimate'.
            Default: 100 * pq.ms.

        Returns
        -------
        pmat : np.ndarray
            The cumulative probability matrix. `pmat[i, j]` represents the
            estimated probability of having an overlap between bins `i` and
            `j` STRICTLY LOWER than the observed overlap, under the null
            hypothesis of independence of the input spike trains.
        """
        if imat is None:
            # Compute the intersection matrix of the original data
            imat = self.intersection_matrix()
        symmetric = self.is_symmetric()

        bsts_x_matrix = self.spiketrains_binned_i.to_bool_array()

        if symmetric:
            # Symmetric case: both axes share the same binned matrix
            bsts_y_matrix = bsts_x_matrix
        else:
            bsts_y_matrix = self.spiketrains_binned_j.to_bool_array()

        # Check that the nr. neurons is identical between the two axes
        if bsts_x_matrix.shape[0] != bsts_y_matrix.shape[0]:
            raise ValueError(
                'Different number of neurons along the x and y axis!')

        # Define the firing rate profiles
        if firing_rates_x == 'estimate':
            # If rates are to be estimated, create the rate profiles as
            # Quantity objects obtained by boxcar-kernel convolution
            fir_rate_x = self._rate_of_binned_spiketrain(bsts_x_matrix,
                                                         kernel_width)
        elif isinstance(firing_rates_x, list):
            # If rates provided as lists of AnalogSignals, create time slices
            # for both axes, interpolate in the time bins of interest and
            # convert to Quantity
            fir_rate_x = _interpolate_signals(
                firing_rates_x, self.spiketrains_binned_i.bin_edges[:-1],
                self.verbose)
        else:
            raise ValueError(
                'fir_rates_x must be a list or the string "estimate"')

        if symmetric:
            # Symmetric case: reuse the x-axis rates for the y axis
            fir_rate_y = fir_rate_x
        elif firing_rates_y == 'estimate':
            fir_rate_y = self._rate_of_binned_spiketrain(bsts_y_matrix,
                                                         kernel_width)
        elif isinstance(firing_rates_y, list):
            # If rates provided as lists of AnalogSignals, create time slices
            # for both axes, interpolate in the time bins of interest and
            # convert to Quantity
            fir_rate_y = _interpolate_signals(
                firing_rates_y, self.spiketrains_binned_j.bin_edges[:-1],
                self.verbose)
        else:
            raise ValueError(
                'fir_rates_y must be a list or the string "estimate"')

        # For each neuron, compute the prob. that that neuron spikes in any
        # bin, assuming Poisson firing: p = 1 - exp(-rate * bin_size)
        if self.verbose:
            print('compute the prob. that each neuron fires in each pair of '
                  'bins...')

        spike_probs_x = [1. - np.exp(-(rate * self.bin_size).rescale(
            pq.dimensionless).magnitude) for rate in fir_rate_x]
        if symmetric:
            spike_probs_y = spike_probs_x
        else:
            spike_probs_y = [1. - np.exp(-(rate * self.bin_size).rescale(
                pq.dimensionless).magnitude) for rate in fir_rate_y]

        # For each neuron k compute the matrix of probabilities p_ijk that
        # neuron k spikes in both bins i and j. (For i = j it's just spike
        # probs[k][i])
        spike_prob_mats = [np.outer(probx, proby) for (probx, proby) in
                           zip(spike_probs_x, spike_probs_y)]

        # Compute the matrix Mu[i, j] of parameters for the Poisson
        # distributions which describe, at each (i, j), the approximated
        # overlap probability. This matrix is just the sum of the probability
        # matrices computed above
        if self.verbose:
            print(
                "compute the probability matrix by Le Cam's approximation...")

        Mu = np.sum(spike_prob_mats, axis=0)

        # Compute the probability matrix obtained from imat using the Poisson
        # pdfs; cdf(imat - 1) gives P(overlap STRICTLY lower than observed)
        pmat = scipy.stats.poisson.cdf(imat - 1, Mu)

        if symmetric:
            # Substitute 0.5 to the elements along the main diagonal
            if self.verbose:
                print("substitute 0.5 to elements along the main diagonal...")
            np.fill_diagonal(pmat, 0.5)

        return pmat
  1296. def joint_probability_matrix(self, pmat, filter_shape, n_largest,
  1297. min_p_value=1e-5):
  1298. """
  1299. Map a probability matrix `pmat` to a joint probability matrix `jmat`,
  1300. where `jmat[i, j]` is the joint p-value of the largest neighbors of
  1301. `pmat[i, j]`.
  1302. The values of `pmat` are assumed to be uniformly distributed in the
  1303. range [0, 1]. Centered a rectangular kernel of shape
  1304. `filter_shape=(l, w)` around each entry `pmat[i, j]`,
  1305. aligned along the diagonal where `pmat[i, j]` lies into, extracts the
  1306. `n_largest` values falling within the kernel and computes their joint
  1307. p-value `jmat[i, j]`.
  1308. Parameters
  1309. ----------
  1310. pmat : np.ndarray
  1311. A square matrix, the output of
  1312. :func:`ASSET.probability_matrix_montecarlo` or
  1313. :func:`ASSET.probability_matrix_analytical`, of cumulative
  1314. probability values between 0 and 1. The values are assumed
  1315. to be uniformly distributed in the said range.
  1316. filter_shape : tuple of int
  1317. A pair of integers representing the kernel shape `(l, w)`.
  1318. n_largest : int
  1319. The number of the largest neighbors to collect for each entry in
  1320. `jmat`.
  1321. min_p_value : float, optional
  1322. The minimum p-value in range `[0, 1)` for individual entries in
  1323. `pmat`. Each `pmat[i, j]` is set to
  1324. `min(pmat[i, j], 1-p_value_min)` to avoid that a single highly
  1325. significant value in `pmat` (extreme case: `pmat[i, j] = 1`) yields
  1326. joint significance of itself and its neighbors.
  1327. Default: 1e-5.
  1328. Returns
  1329. -------
  1330. jmat : np.ndarray
  1331. The joint probability matrix associated to `pmat`.
  1332. """
  1333. l, w = filter_shape
  1334. # Find for each P_ij in the probability matrix its neighbors and
  1335. # maximize them by the maximum value 1-p_value_min
  1336. pmat_neighb = _pmat_neighbors(
  1337. pmat, filter_shape=filter_shape, n_largest=n_largest)
  1338. pmat_neighb = np.minimum(pmat_neighb, 1. - min_p_value)
  1339. # in order to avoid doing the same calculation multiple times:
  1340. # find all unique sets of values in pmat_neighb
  1341. # and store the corresponding indices
  1342. # flatten the second and third dimension in order to use np.unique
  1343. pmat_neighb = pmat_neighb.reshape(n_largest, pmat.size).T
  1344. pmat_neighb, pmat_neighb_indices = np.unique(pmat_neighb, axis=0,
  1345. return_inverse=True)
  1346. # Compute the joint p-value matrix jpvmat
  1347. n = l * (1 + 2 * w) - w * (
  1348. w + 1) # number of entries covered by kernel
  1349. jpvmat = _jsf_uniform_orderstat_3d(pmat_neighb, n,
  1350. verbose=self.verbose)
  1351. # restore the original shape using the stored indices
  1352. jpvmat = jpvmat[pmat_neighb_indices].reshape(pmat.shape)
  1353. return 1. - jpvmat
  1354. @staticmethod
  1355. def mask_matrices(matrices, thresholds):
  1356. """
  1357. Given a list of `matrices` and a list of `thresholds`, return a boolean
  1358. matrix `B` ("mask") such that `B[i,j]` is True if each input matrix in
  1359. the list strictly exceeds the corresponding threshold at that position.
  1360. If multiple matrices are passed along with only one threshold the same
  1361. threshold is applied to all matrices.
  1362. Parameters
  1363. ----------
  1364. matrices : list of np.ndarray
  1365. The matrices which are compared to the respective thresholds to
  1366. build the mask. All matrices must have the same shape.
  1367. Typically, it is a list `[pmat, jmat]`, i.e., the (cumulative)
  1368. probability and joint probability matrices.
  1369. thresholds : float or list of float
  1370. The significance thresholds for each matrix in `matrices`.
  1371. Returns
  1372. -------
  1373. mask : np.ndarray
  1374. Boolean mask matrix with the shape of the input matrices.
  1375. Raises
  1376. ------
  1377. ValueError
  1378. If `matrices` or `thresholds` is an empty list.
  1379. If `matrices` and `thresholds` have different lengths.
  1380. See Also
  1381. --------
  1382. ASSET.probability_matrix_montecarlo : for `pmat` generation
  1383. ASSET.probability_matrix_analytical : for `pmat` generation
  1384. ASSET.joint_probability_matrix : for `jmat` generation
  1385. """
  1386. if len(matrices) == 0:
  1387. raise ValueError("Empty list of matrices")
  1388. if isinstance(thresholds, float):
  1389. thresholds = np.full(shape=len(matrices), fill_value=thresholds)
  1390. if len(matrices) != len(thresholds):
  1391. raise ValueError(
  1392. '`matrices` and `thresholds` must have same length')
  1393. mask = np.ones_like(matrices[0], dtype=bool)
  1394. for (mat, thresh) in zip(matrices, thresholds):
  1395. mask &= mat > thresh
  1396. # Replace nans, coming from False * np.inf, with zeros
  1397. mask[np.isnan(mask)] = False
  1398. return mask
  1399. @staticmethod
  1400. def cluster_matrix_entries(mask_matrix, max_distance, min_neighbors,
  1401. stretch):
  1402. r"""
  1403. Given a matrix `mask_matrix`, replaces its positive elements with
  1404. integers representing different cluster IDs. Each cluster comprises
  1405. close-by elements.
  1406. In ASSET analysis, `mask_matrix` is a thresholded ("masked") version
  1407. of the intersection matrix `imat`, whose values are those of `imat`
  1408. only if considered statistically significant, and zero otherwise.
  1409. A cluster is built by pooling elements according to their distance,
  1410. via the DBSCAN algorithm (see `sklearn.cluster.DBSCAN` class). Elements
  1411. form a neighbourhood if at least one of them has a distance not larger
  1412. than `max_distance` from the others, and if they are at least
  1413. `min_neighbors`. Overlapping neighborhoods form a cluster:
  1414. * Clusters are assigned integers from `1` to the total number `k`
  1415. of clusters;
  1416. * Unclustered ("isolated") positive elements of `mask_matrix` are
  1417. assigned value `-1`;
  1418. * Non-positive elements are assigned the value `0`.
  1419. The distance between the positions of two positive elements in
  1420. `mask_matrix` is given by a Euclidean metric which is stretched if the
  1421. two positions are not aligned along the 45 degree direction (the main
  1422. diagonal direction), as more, with maximal stretching along the
  1423. anti-diagonal. Specifically, the Euclidean distance between positions
  1424. `(i1, j1)` and `(i2, j2)` is stretched by a factor
  1425. .. math::
  1426. 1 + (\mathtt{stretch} - 1.) *
  1427. \left|\sin((\pi / 4) - \theta)\right|,
  1428. where :math:`\theta` is the angle between the pixels and the 45 degree
  1429. direction. The stretching factor thus varies between 1 and `stretch`.
  1430. Parameters
  1431. ----------
  1432. mask_matrix : np.ndarray
  1433. The boolean matrix, whose elements with positive values are to be
  1434. clustered. The output of :func:`ASSET.mask_matrices`.
  1435. max_distance : float
  1436. The maximum distance between two elements in `mask_matrix` to be
  1437. a part of the same neighbourhood in the DBSCAN algorithm.
  1438. min_neighbors : int
  1439. The minimum number of elements to form a neighbourhood.
  1440. stretch : float
  1441. The stretching factor of the euclidean metric for elements aligned
  1442. along the 135 degree direction (anti-diagonal). The actual
  1443. stretching increases from 1 to `stretch` as the direction of the
  1444. two elements moves from the 45 to the 135 degree direction.
  1445. `stretch` must be greater than 1.
  1446. Returns
  1447. -------
  1448. cluster_mat : np.ndarray
  1449. A matrix with the same shape of `mask_matrix`, each of whose
  1450. elements is either:
  1451. * a positive integer (cluster ID) if the element is part of a
  1452. cluster;
  1453. * `0` if the corresponding element in `mask_matrix` is
  1454. non-positive;
  1455. * `-1` if the element does not belong to any cluster.
  1456. See Also
  1457. --------
  1458. sklearn.cluster.DBSCAN
  1459. """
  1460. # Don't do anything if mat is identically zero
  1461. if np.all(mask_matrix == 0):
  1462. return mask_matrix
  1463. # List the significant pixels of mat in a 2-columns array
  1464. xpos_sgnf, ypos_sgnf = np.where(mask_matrix > 0)
  1465. # Compute the matrix D[i, j] of euclidean distances between pixels i
  1466. # and j
  1467. D = _stretched_metric_2d(
  1468. xpos_sgnf, ypos_sgnf, stretch=stretch, ref_angle=45)
  1469. # Cluster positions of significant pixels via dbscan
  1470. core_samples, config = dbscan(
  1471. D, eps=max_distance, min_samples=min_neighbors,
  1472. metric='precomputed')
  1473. # Construct the clustered matrix, where each element has value
  1474. # * i = 1 to k if it belongs to a cluster i,
  1475. # * 0 if it is not significant,
  1476. # * -1 if it is significant but does not belong to any cluster
  1477. cluster_mat = np.zeros_like(mask_matrix, dtype=np.int32)
  1478. cluster_mat[xpos_sgnf, ypos_sgnf] = \
  1479. config * (config == -1) + (config + 1) * (config >= 0)
  1480. return cluster_mat
  1481. def extract_synchronous_events(self, cmat, ids=None):
  1482. """
  1483. Given a list of spike trains, a bin size, and a clustered
  1484. intersection matrix obtained from those spike trains via ASSET
  1485. analysis, extracts the sequences of synchronous events (SSEs)
  1486. corresponding to clustered elements in the cluster matrix.
  1487. Parameters
  1488. ----------
  1489. cmat: (n,n) np.ndarray
  1490. The cluster matrix, the output of
  1491. :func:`ASSET.cluster_matrix_entries`.
  1492. ids : list, optional
  1493. A list of spike train IDs. If provided, `ids[i]` is the identity
  1494. of `spiketrains[i]`. If None, the IDs `0,1,...,n-1` are used.
  1495. Default: None.
  1496. Returns
  1497. -------
  1498. sse_dict : dict
  1499. A dictionary `D` of SSEs, where each SSE is a sub-dictionary `Dk`,
  1500. `k=1,...,K`, where `K` is the max positive integer in `cmat` (i.e.,
  1501. the total number of clusters in `cmat`):
  1502. .. centered:: D = {1: D1, 2: D2, ..., K: DK}
  1503. Each sub-dictionary `Dk` represents the k-th diagonal structure
  1504. (i.e., the k-th cluster) in `cmat`, and is of the form
  1505. .. centered:: Dk = {(i1, j1): S1, (i2, j2): S2, ..., (iL, jL): SL}.
  1506. The keys `(i, j)` represent the positions (time bin IDs) of all
  1507. elements in `cmat` that compose the SSE (i.e., that take value `l`
  1508. and therefore belong to the same cluster), and the values `Sk` are
  1509. sets of neuron IDs representing a repeated synchronous event (i.e.,
  1510. spiking at time bins `i` and `j`).
  1511. """
  1512. nr_worms = cmat.max() # number of different clusters ("worms") in cmat
  1513. if nr_worms <= 0:
  1514. return {}
  1515. # Compute the transactions associated to the two binnings
  1516. tracts_x = _transactions(
  1517. self.spiketrains_i, bin_size=self.bin_size, t_start=self.t_start_i,
  1518. t_stop=self.t_stop_i,
  1519. ids=ids)
  1520. if self.spiketrains_j is self.spiketrains_i:
  1521. diag_id = 0
  1522. tracts_y = tracts_x
  1523. else:
  1524. if self.is_symmetric():
  1525. diag_id = 0
  1526. tracts_y = tracts_x
  1527. else:
  1528. diag_id = None
  1529. tracts_y = _transactions(
  1530. self.spiketrains_j, bin_size=self.bin_size,
  1531. t_start=self.t_start_j, t_stop=self.t_stop_j, ids=ids)
  1532. # Reconstruct each worm, link by link
  1533. sse_dict = {}
  1534. for k in range(1, nr_worms + 1): # for each worm
  1535. # worm k is a list of links (each link will be 1 sublist)
  1536. worm_k = {}
  1537. pos_worm_k = np.array(
  1538. np.where(cmat == k)).T # position of all links
  1539. # if no link lies on the reference diagonal
  1540. if all([y - x != diag_id for (x, y) in pos_worm_k]):
  1541. for bin_x, bin_y in pos_worm_k: # for each link
  1542. # reconstruct the link
  1543. link_l = set(tracts_x[bin_x]).intersection(
  1544. tracts_y[bin_y])
  1545. # and assign it to its pixel
  1546. worm_k[(bin_x, bin_y)] = link_l
  1547. sse_dict[k] = worm_k
  1548. return sse_dict
  1549. def _rate_of_binned_spiketrain(self, binned_spiketrains, kernel_width):
  1550. """
  1551. Calculate the rate of binned spiketrains using convolution with
  1552. a boxcar kernel.
  1553. """
  1554. if self.verbose:
  1555. print('compute rates by boxcar-kernel convolution...')
  1556. # Create the boxcar kernel and convolve it with the binned spike trains
  1557. k = int((kernel_width / self.bin_size).simplified.item())
  1558. kernel = np.full(k, fill_value=1. / k)
  1559. rate = np.vstack([np.convolve(bst, kernel, mode='same')
  1560. for bst in binned_spiketrains])
  1561. # The convolution results in an array decreasing at the borders due
  1562. # to absence of spikes beyond the borders. Replace the first and last
  1563. # (k//2) elements with the (k//2)-th / (n-k//2)-th ones, respectively
  1564. k2 = k // 2
  1565. for i in range(rate.shape[0]):
  1566. rate[i, :k2] = rate[i, k2]
  1567. rate[i, -k2:] = rate[i, -k2 - 1]
  1568. # Multiply the firing rates by the proper unit
  1569. rate = rate * (1. / self.bin_size).rescale('Hz')
  1570. return rate