# spiketrain.py
'''
This module implements :class:`SpikeTrain`, an array of spike times.

:class:`SpikeTrain` derives from :class:`BaseNeo`, from
:module:`neo.core.baseneo`, and from :class:`quantities.Quantity`, which
inherits from :class:`numpy.array`.

Inheritance from :class:`numpy.array` is explained here:
http://docs.scipy.org/doc/numpy/user/basics.subclassing.html

In brief:
* Initialization of a new object from constructor happens in :meth:`__new__`.
  This is where user-specified attributes are set.
* :meth:`__array_finalize__` is called for all new objects, including those
  created by slicing. This is where attributes are copied over from
  the old object.
'''
  15. import neo
  16. import sys
  17. from copy import deepcopy, copy
  18. import warnings
  19. import numpy as np
  20. import quantities as pq
  21. from neo.core.baseneo import BaseNeo, MergeError, merge_annotations
  22. from neo.core.dataobject import DataObject, ArrayDict
  23. def check_has_dimensions_time(*values):
  24. '''
  25. Verify that all arguments have a dimensionality that is compatible
  26. with time.
  27. '''
  28. errmsgs = []
  29. for value in values:
  30. dim = value.dimensionality.simplified
  31. if (len(dim) != 1 or
  32. list(dim.values())[0] != 1 or not
  33. isinstance(list(dim.keys())[0], pq.UnitTime)):
  34. errmsgs.append(
  35. "value {} has dimensions {}, not [time]".format(
  36. value, dim))
  37. if errmsgs:
  38. raise ValueError("\n".join(errmsgs))
  39. def _check_time_in_range(value, t_start, t_stop, view=False):
  40. '''
  41. Verify that all times in :attr:`value` are between :attr:`t_start`
  42. and :attr:`t_stop` (inclusive.
  43. If :attr:`view` is True, vies are used for the test.
  44. Using drastically increases the speed, but is only safe if you are
  45. certain that the dtype and units are the same
  46. '''
  47. if t_start > t_stop:
  48. raise ValueError("t_stop ({}) is before t_start ({})".format(t_stop, t_start))
  49. if not value.size:
  50. return
  51. if view:
  52. value = value.view(np.ndarray)
  53. t_start = t_start.view(np.ndarray)
  54. t_stop = t_stop.view(np.ndarray)
  55. if value.min() < t_start:
  56. raise ValueError("The first spike ({}) is before t_start ({})".format(value, t_start))
  57. if value.max() > t_stop:
  58. raise ValueError("The last spike ({}) is after t_stop ({})".format(value, t_stop))
  59. def _check_waveform_dimensions(spiketrain):
  60. '''
  61. Verify that waveform is compliant with the waveform definition as
  62. quantity array 3D (spike, channel_index, time)
  63. '''
  64. if not spiketrain.size:
  65. return
  66. waveforms = spiketrain.waveforms
  67. if (waveforms is None) or (not waveforms.size):
  68. return
  69. if waveforms.shape[0] != len(spiketrain):
  70. raise ValueError("Spiketrain length (%s) does not match to number of "
  71. "waveforms present (%s)" % (len(spiketrain), waveforms.shape[0]))
  72. def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None, copy=True,
  73. sampling_rate=1.0 * pq.Hz, t_start=0.0 * pq.s, waveforms=None, left_sweep=None,
  74. name=None, file_origin=None, description=None, array_annotations=None,
  75. annotations=None, segment=None, unit=None):
  76. '''
  77. A function to map :meth:`BaseAnalogSignal.__new__` to function that
  78. does not do the unit checking. This is needed for :module:`pickle` to work.
  79. '''
  80. if annotations is None:
  81. annotations = {}
  82. obj = SpikeTrain(signal, t_stop, units, dtype, copy, sampling_rate, t_start, waveforms,
  83. left_sweep, name, file_origin, description, array_annotations, **annotations)
  84. obj.segment = segment
  85. obj.unit = unit
  86. return obj
class SpikeTrain(DataObject):
    '''
    :class:`SpikeTrain` is a :class:`Quantity` array of spike times.

    It is an ensemble of action potentials (spikes) emitted by the same unit
    in a period of time.

    *Usage*::

        >>> from neo.core import SpikeTrain
        >>> from quantities import s
        >>>
        >>> train = SpikeTrain([3, 4, 5]*s, t_stop=10.0)
        >>> train2 = train[1:3]
        >>>
        >>> train.t_start
        array(0.0) * s
        >>> train.t_stop
        array(10.0) * s
        >>> train
        <SpikeTrain(array([ 3., 4., 5.]) * s, [0.0 s, 10.0 s])>
        >>> train2
        <SpikeTrain(array([ 4., 5.]) * s, [0.0 s, 10.0 s])>

    *Required attributes/properties*:
        :times: (quantity array 1D, numpy array 1D, or list) The times of
            each spike.
        :units: (quantity units) Required if :attr:`times` is a list or
            :class:`~numpy.ndarray`, not if it is a
            :class:`~quantities.Quantity`.
        :t_stop: (quantity scalar, numpy scalar, or float) Time at which
            :class:`SpikeTrain` ended. This will be converted to the
            same units as :attr:`times`. This argument is required because it
            specifies the period of time over which spikes could have occurred.
            Note that :attr:`t_start` is highly recommended for the same
            reason.

    Note: If :attr:`times` contains values outside of the
    range [t_start, t_stop], an Exception is raised.

    *Recommended attributes/properties*:
        :name: (str) A label for the dataset.
        :description: (str) Text description.
        :file_origin: (str) Filesystem path or URL of the original data file.
        :t_start: (quantity scalar, numpy scalar, or float) Time at which
            :class:`SpikeTrain` began. This will be converted to the
            same units as :attr:`times`.
            Default: 0.0 seconds.
        :waveforms: (quantity array 3D (spike, channel_index, time))
            The waveforms of each spike.
        :sampling_rate: (quantity scalar) Number of samples per unit time
            for the waveforms.
        :left_sweep: (quantity array 1D) Time from the beginning
            of the waveform to the trigger time of the spike.
        :sort: (bool) If True, the spike train will be sorted by time.

    *Optional attributes/properties*:
        :dtype: (numpy dtype or str) Override the dtype of the signal array.
        :copy: (bool) Whether to copy the times array. True by default.
            Must be True when you request a change of units or dtype.
        :array_annotations: (dict) Dict mapping strings to numpy arrays containing annotations \
                                   for all data points

    Note: Any other additional arguments are assumed to be user-specific
    metadata and stored in :attr:`annotations`.

    *Properties available on this object*:
        :sampling_period: (quantity scalar) Interval between two samples.
            (1/:attr:`sampling_rate`)
        :duration: (quantity scalar) Duration over which spikes can occur,
            read-only.
            (:attr:`t_stop` - :attr:`t_start`)
        :spike_duration: (quantity scalar) Duration of a waveform, read-only.
            (:attr:`waveform`.shape[2] * :attr:`sampling_period`)
        :right_sweep: (quantity scalar) Time from the trigger times of the
            spikes to the end of the waveforms, read-only.
            (:attr:`left_sweep` + :attr:`spike_duration`)
        :times: (quantity array 1D) Returns the :class:`SpikeTrain` as a quantity array.

    *Slicing*:
        :class:`SpikeTrain` objects can be sliced. When this occurs, a new
        :class:`SpikeTrain` (actually a view) is returned, with the same
        metadata, except that :attr:`waveforms` is also sliced in the same way
        (along dimension 0). Note that t_start and t_stop are not changed
        automatically, although you can still manually change them.
    '''

    _single_parent_objects = ('Segment', 'Unit')
    _single_parent_attrs = ('segment', 'unit')
    _quantity_attr = 'times'
    _necessary_attrs = (('times', pq.Quantity, 1), ('t_start', pq.Quantity, 0),
                        ('t_stop', pq.Quantity, 0))
    _recommended_attrs = ((('waveforms', pq.Quantity, 3), ('left_sweep', pq.Quantity, 0),
                           ('sampling_rate', pq.Quantity, 0)) + BaseNeo._recommended_attrs)
  170. def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, sampling_rate=1.0 * pq.Hz,
  171. t_start=0.0 * pq.s, waveforms=None, left_sweep=None, name=None, file_origin=None,
  172. description=None, array_annotations=None, **annotations):
  173. '''
  174. Constructs a new :clas:`Spiketrain` instance from data.
  175. This is called whenever a new :class:`SpikeTrain` is created from the
  176. constructor, but not when slicing.
  177. '''
  178. if len(times) != 0 and waveforms is not None and len(times) != waveforms.shape[0]:
  179. # len(times)!=0 has been used to workaround a bug occuring during neo import
  180. raise ValueError("the number of waveforms should be equal to the number of spikes")
  181. # Make sure units are consistent
  182. # also get the dimensionality now since it is much faster to feed
  183. # that to Quantity rather than a unit
  184. if units is None:
  185. # No keyword units, so get from `times`
  186. try:
  187. dim = times.units.dimensionality
  188. except AttributeError:
  189. raise ValueError('you must specify units')
  190. else:
  191. if hasattr(units, 'dimensionality'):
  192. dim = units.dimensionality
  193. else:
  194. dim = pq.quantity.validate_dimensionality(units)
  195. if hasattr(times, 'dimensionality'):
  196. if times.dimensionality.items() == dim.items():
  197. units = None # units will be taken from times, avoids copying
  198. else:
  199. if not copy:
  200. raise ValueError("cannot rescale and return view")
  201. else:
  202. # this is needed because of a bug in python-quantities
  203. # see issue # 65 in python-quantities github
  204. # remove this if it is fixed
  205. times = times.rescale(dim)
  206. if dtype is None:
  207. if not hasattr(times, 'dtype'):
  208. dtype = np.float
  209. elif hasattr(times, 'dtype') and times.dtype != dtype:
  210. if not copy:
  211. raise ValueError("cannot change dtype and return view")
  212. # if t_start.dtype or t_stop.dtype != times.dtype != dtype,
  213. # _check_time_in_range can have problems, so we set the t_start
  214. # and t_stop dtypes to be the same as times before converting them
  215. # to dtype below
  216. # see ticket #38
  217. if hasattr(t_start, 'dtype') and t_start.dtype != times.dtype:
  218. t_start = t_start.astype(times.dtype)
  219. if hasattr(t_stop, 'dtype') and t_stop.dtype != times.dtype:
  220. t_stop = t_stop.astype(times.dtype)
  221. # check to make sure the units are time
  222. # this approach is orders of magnitude faster than comparing the
  223. # reference dimensionality
  224. if (len(dim) != 1 or list(dim.values())[0] != 1 or not isinstance(list(dim.keys())[0],
  225. pq.UnitTime)):
  226. ValueError("Unit has dimensions %s, not [time]" % dim.simplified)
  227. # Construct Quantity from data
  228. obj = pq.Quantity(times, units=units, dtype=dtype, copy=copy).view(cls)
  229. # spiketrain times always need to be 1-dimensional
  230. if len(obj.shape) > 1:
  231. raise ValueError("Spiketrain times array has more than 1 dimension")
  232. # if the dtype and units match, just copy the values here instead
  233. # of doing the much more expensive creation of a new Quantity
  234. # using items() is orders of magnitude faster
  235. if (hasattr(t_start, 'dtype')
  236. and t_start.dtype == obj.dtype
  237. and hasattr(t_start, 'dimensionality')
  238. and t_start.dimensionality.items() == dim.items()):
  239. obj.t_start = t_start.copy()
  240. else:
  241. obj.t_start = pq.Quantity(t_start, units=dim, dtype=obj.dtype)
  242. if (hasattr(t_stop, 'dtype') and t_stop.dtype == obj.dtype
  243. and hasattr(t_stop, 'dimensionality')
  244. and t_stop.dimensionality.items() == dim.items()):
  245. obj.t_stop = t_stop.copy()
  246. else:
  247. obj.t_stop = pq.Quantity(t_stop, units=dim, dtype=obj.dtype)
  248. # Store attributes
  249. obj.waveforms = waveforms
  250. obj.left_sweep = left_sweep
  251. obj.sampling_rate = sampling_rate
  252. # parents
  253. obj.segment = None
  254. obj.unit = None
  255. # Error checking (do earlier?)
  256. _check_time_in_range(obj, obj.t_start, obj.t_stop, view=True)
  257. return obj
  258. def __init__(self, times, t_stop, units=None, dtype=np.float, copy=True,
  259. sampling_rate=1.0 * pq.Hz, t_start=0.0 * pq.s, waveforms=None, left_sweep=None,
  260. name=None, file_origin=None, description=None, array_annotations=None,
  261. **annotations):
  262. '''
  263. Initializes a newly constructed :class:`SpikeTrain` instance.
  264. '''
  265. # This method is only called when constructing a new SpikeTrain,
  266. # not when slicing or viewing. We use the same call signature
  267. # as __new__ for documentation purposes. Anything not in the call
  268. # signature is stored in annotations.
  269. # Calls parent __init__, which grabs universally recommended
  270. # attributes and sets up self.annotations
  271. DataObject.__init__(self, name=name, file_origin=file_origin, description=description,
  272. array_annotations=array_annotations, **annotations)
    def _repr_pretty_(self, pp, cycle):
        # IPython pretty-printing hook; delegate to the parent implementation
        super()._repr_pretty_(pp, cycle)
    def rescale(self, units):
        '''
        Return a copy of the :class:`SpikeTrain` converted to the specified
        units.

        :attr:`t_start` and :attr:`t_stop` are rescaled to the same units,
        and the parent ``unit`` reference is carried over.
        '''
        obj = super().rescale(units)
        obj.t_start = self.t_start.rescale(units)
        obj.t_stop = self.t_stop.rescale(units)
        obj.unit = self.unit
        return obj
  285. def __reduce__(self):
  286. '''
  287. Map the __new__ function onto _new_BaseAnalogSignal, so that pickle
  288. works
  289. '''
  290. import numpy
  291. return _new_spiketrain, (self.__class__, numpy.array(self), self.t_stop, self.units,
  292. self.dtype, True, self.sampling_rate, self.t_start,
  293. self.waveforms, self.left_sweep, self.name, self.file_origin,
  294. self.description, self.array_annotations, self.annotations,
  295. self.segment, self.unit)
    def __array_finalize__(self, obj):
        '''
        This is called every time a new :class:`SpikeTrain` is created.

        It is the appropriate place to set default values for attributes
        for :class:`SpikeTrain` constructed by slicing or viewing.

        User-specified values are only relevant for construction from
        constructor, and these are set in __new__. Then they are just
        copied over here.

        Note that the :attr:`waveforms` attibute is not sliced here. Nor is
        :attr:`t_start` or :attr:`t_stop` modified.
        '''
        # This calls Quantity.__array_finalize__ which deals with
        # dimensionality
        super().__array_finalize__(obj)

        # Supposedly, during initialization from constructor, obj is supposed
        # to be None, but this never happens. It must be something to do
        # with inheritance from Quantity.
        if obj is None:
            return

        # Set all attributes of the new object `self` from the attributes
        # of `obj`. For instance, when slicing, we want to copy over the
        # attributes of the original object.
        self.t_start = getattr(obj, 't_start', None)
        self.t_stop = getattr(obj, 't_stop', None)
        self.waveforms = getattr(obj, 'waveforms', None)
        self.left_sweep = getattr(obj, 'left_sweep', None)
        self.sampling_rate = getattr(obj, 'sampling_rate', None)
        self.segment = getattr(obj, 'segment', None)
        self.unit = getattr(obj, 'unit', None)

        # The additional arguments
        self.annotations = getattr(obj, 'annotations', {})
        # Add empty array annotations, because they cannot always be copied,
        # but do not overwrite existing ones from slicing etc.
        # This ensures the attribute exists
        if not hasattr(self, 'array_annotations'):
            self.array_annotations = ArrayDict(self._get_arr_ann_length())

        # Note: Array annotations have to be changed when slicing or initializing an object,
        # copying them over in spite of changed data would result in unexpected behaviour

        # Globally recommended attributes
        self.name = getattr(obj, 'name', None)
        self.file_origin = getattr(obj, 'file_origin', None)
        self.description = getattr(obj, 'description', None)

        # lazy_shape is only present on lazily-loaded objects; propagate it
        if hasattr(obj, 'lazy_shape'):
            self.lazy_shape = obj.lazy_shape
    def __repr__(self):
        '''
        Returns a string representing the :class:`SpikeTrain`, e.g.
        ``<SpikeTrain(array([ 3.,  4.]) * s, [0.0 s, 10.0 s])>``.
        '''
        return '<SpikeTrain(%s, [%s, %s])>' % (
            super().__repr__(), self.t_start, self.t_stop)
    def sort(self):
        '''
        Sorts the :class:`SpikeTrain` and its :attr:`waveforms`, if any,
        by time.

        Array annotations are reordered with the same permutation so they
        stay aligned with the spikes. Sorting is done in place.
        '''
        # sort the waveforms by the times
        sort_indices = np.argsort(self)
        if self.waveforms is not None and self.waveforms.any():
            self.waveforms = self.waveforms[sort_indices]
        self.array_annotate(**deepcopy(self.array_annotations_at_index(sort_indices)))

        # now sort the times
        # We have sorted twice, but `self = self[sort_indices]` introduces
        # a dependency on the slicing functionality of SpikeTrain.
        super().sort()
    def __getslice__(self, i, j):
        '''
        Get a slice from :attr:`i` to :attr:`j`.

        Doesn't get called in Python 3, :meth:`__getitem__` is called instead
        (Python 2 legacy protocol).
        '''
        return self.__getitem__(slice(i, j))
    def __add__(self, time):
        '''
        Shifts the time point of all spikes by adding the amount in
        :attr:`time` (:class:`Quantity`)

        If `time` is a scalar, this also shifts :attr:`t_start` and :attr:`t_stop`.
        If `time` is an array, :attr:`t_start` and :attr:`t_stop` are not changed unless
        some of the new spikes would be outside this range.
        In this case :attr:`t_start` and :attr:`t_stop` are modified if necessary to
        ensure they encompass all spikes.

        It is not possible to add two SpikeTrains (raises TypeError).
        '''
        spikes = self.view(pq.Quantity)
        check_has_dimensions_time(time)
        if isinstance(time, SpikeTrain):
            raise TypeError("Can't add two spike trains")
        new_times = spikes + time
        if time.size > 1:
            # per-spike shifts: extend the interval only as far as needed
            t_start = min(self.t_start, np.min(new_times))
            t_stop = max(self.t_stop, np.max(new_times))
        else:
            t_start = self.t_start + time
            t_stop = self.t_stop + time
        return SpikeTrain(times=new_times, t_stop=t_stop, units=self.units,
                          sampling_rate=self.sampling_rate, t_start=t_start,
                          waveforms=self.waveforms, left_sweep=self.left_sweep, name=self.name,
                          file_origin=self.file_origin, description=self.description,
                          array_annotations=deepcopy(self.array_annotations),
                          **self.annotations)
  394. def __sub__(self, time):
  395. '''
  396. Shifts the time point of all spikes by subtracting the amount in
  397. :attr:`time` (:class:`Quantity`)
  398. If `time` is a scalar, this also shifts :attr:`t_start` and :attr:`t_stop`.
  399. If `time` is an array, :attr:`t_start` and :attr:`t_stop` are not changed unless
  400. some of the new spikes would be outside this range.
  401. In this case :attr:`t_start` and :attr:`t_stop` are modified if necessary to
  402. ensure they encompass all spikes.
  403. In general, it is not possible to subtract two SpikeTrain objects (raises ValueError).
  404. However, if `time` is itself a SpikeTrain of the same size as the SpikeTrain,
  405. returns a Quantities array (since this is often used in checking
  406. whether two spike trains are the same or in calculating the inter-spike interval.
  407. '''
  408. spikes = self.view(pq.Quantity)
  409. check_has_dimensions_time(time)
  410. if isinstance(time, SpikeTrain):
  411. if self.size == time.size:
  412. return spikes - time
  413. else:
  414. raise TypeError("Can't subtract spike trains with different sizes")
  415. else:
  416. new_times = spikes - time
  417. if time.size > 1:
  418. t_start = min(self.t_start, np.min(new_times))
  419. t_stop = max(self.t_stop, np.max(new_times))
  420. else:
  421. t_start = self.t_start - time
  422. t_stop = self.t_stop - time
  423. return SpikeTrain(times=spikes - time, t_stop=t_stop, units=self.units,
  424. sampling_rate=self.sampling_rate, t_start=t_start,
  425. waveforms=self.waveforms, left_sweep=self.left_sweep, name=self.name,
  426. file_origin=self.file_origin, description=self.description,
  427. array_annotations=deepcopy(self.array_annotations),
  428. **self.annotations)
    def __getitem__(self, i):
        '''
        Get the item or slice :attr:`i`.

        Waveforms and array annotations are indexed along the spike axis so
        they stay aligned with the returned times.
        '''
        obj = super().__getitem__(i)
        if hasattr(obj, 'waveforms') and obj.waveforms is not None:
            obj.waveforms = obj.waveforms.__getitem__(i)
        try:
            obj.array_annotate(**deepcopy(self.array_annotations_at_index(i)))
        except AttributeError:  # If Quantity was returned, not SpikeTrain
            pass
        return obj
    def __setitem__(self, i, value):
        '''
        Set the value the item or slice :attr:`i`.

        Unitless values are interpreted in this train's units; values must
        lie within [:attr:`t_start`, :attr:`t_stop`].
        '''
        if not hasattr(value, "units"):
            value = pq.Quantity(value,
                                units=self.units)  # or should we be strict: raise ValueError(
            # "Setting a value # requires a quantity")?
        # check for values outside t_start, t_stop
        _check_time_in_range(value, self.t_start, self.t_stop)
        super().__setitem__(i, value)
    def __setslice__(self, i, j, value):
        '''
        Set the slice from :attr:`i` to :attr:`j` to :attr:`value`.

        Python 2 legacy protocol; in Python 3 slice assignment goes through
        :meth:`__setitem__` instead.
        '''
        if not hasattr(value, "units"):
            value = pq.Quantity(value, units=self.units)
        _check_time_in_range(value, self.t_start, self.t_stop)
        super().__setslice__(i, j, value)
  457. def _copy_data_complement(self, other, deep_copy=False):
  458. '''
  459. Copy the metadata from another :class:`SpikeTrain`.
  460. Note: Array annotations can not be copied here because length of data can change
  461. '''
  462. # Note: Array annotations cannot be copied because length of data can be changed
  463. # here which would cause inconsistencies
  464. for attr in ("left_sweep", "sampling_rate", "name", "file_origin", "description",
  465. "annotations"):
  466. attr_value = getattr(other, attr, None)
  467. if deep_copy:
  468. attr_value = deepcopy(attr_value)
  469. setattr(self, attr, attr_value)
    def duplicate_with_new_data(self, signal, t_start=None, t_stop=None, waveforms=None,
                                deep_copy=True, units=None):
        '''
        Create a new :class:`SpikeTrain` with the same metadata
        but different data (times, t_start, t_stop)
        Note: Array annotations can not be copied here because length of data can change
        '''
        # using previous t_start and t_stop if no values are provided
        if t_start is None:
            t_start = self.t_start
        if t_stop is None:
            t_stop = self.t_stop
        if waveforms is None:
            waveforms = self.waveforms
        if units is None:
            units = self.units
        else:
            units = pq.quantity.validate_dimensionality(units)

        new_st = self.__class__(signal, t_start=t_start, t_stop=t_stop, waveforms=waveforms,
                                units=units)
        new_st._copy_data_complement(self, deep_copy=deep_copy)

        # Note: Array annotations are not copied here, because length of data could change

        # overwriting t_start and t_stop with new values
        new_st.t_start = t_start
        new_st.t_stop = t_stop

        # consistency check
        _check_time_in_range(new_st, new_st.t_start, new_st.t_stop, view=False)
        _check_waveform_dimensions(new_st)
        return new_st
    def time_slice(self, t_start, t_stop):
        '''
        Creates a new :class:`SpikeTrain` corresponding to the time slice of
        the original :class:`SpikeTrain` between (and including) times
        :attr:`t_start` and :attr:`t_stop`. Either parameter can also be None
        to use infinite endpoints for the time interval.

        Raises :class:`ValueError` if the requested slice lies completely
        outside [:attr:`t_start`, :attr:`t_stop`] of this train.
        '''
        _t_start = t_start
        _t_stop = t_stop
        # None means an open-ended bound
        if t_start is None:
            _t_start = -np.inf
        if t_stop is None:
            _t_stop = np.inf

        if _t_start > self.t_stop or _t_stop < self.t_start:
            # the alternative to raising an exception would be to return
            # a zero-duration spike train set at self.t_stop or self.t_start
            raise ValueError("A time slice completely outside the "
                             "boundaries of the spike train is not defined.")

        indices = (self >= _t_start) & (self <= _t_stop)

        # Time slicing should create a deep copy of the object
        new_st = deepcopy(self[indices])

        # clip the interval to the intersection with the original one
        new_st.t_start = max(_t_start, self.t_start)
        new_st.t_stop = min(_t_stop, self.t_stop)
        if self.waveforms is not None:
            new_st.waveforms = self.waveforms[indices]

        return new_st
    def time_shift(self, t_shift):
        """
        Shifts a :class:`SpikeTrain` to start at a new time.

        Parameters:
        -----------
        t_shift: Quantity (time)
            Amount of time by which to shift the :class:`SpikeTrain`.

        Returns:
        --------
        spiketrain: :class:`SpikeTrain`
            New instance of a :class:`SpikeTrain` object starting at t_shift later than the
            original :class:`SpikeTrain` (the original :class:`SpikeTrain` is not modified).
        """
        new_st = self.duplicate_with_new_data(
            signal=self.times.view(pq.Quantity) + t_shift,
            t_start=self.t_start + t_shift,
            t_stop=self.t_stop + t_shift)

        # Here we can safely copy the array annotations since we know that
        # the length of the SpikeTrain does not change.
        new_st.array_annotate(**self.array_annotations)

        return new_st
    def merge(self, *others):
        '''
        Merge other :class:`SpikeTrain` objects into this one.

        The times of the :class:`SpikeTrain` objects combined in one array
        and sorted.

        If the attributes of the :class:`SpikeTrain` objects are not
        compatible, an Exception is raised.

        Returns a new :class:`SpikeTrain`; the inputs are not modified, but
        the result is appended to the parent segment's spiketrains, if any.
        '''
        # every train must be a loaded SpikeTrain with identical timing
        # metadata and the same parent segment
        for other in others:
            if isinstance(other, neo.io.proxyobjects.SpikeTrainProxy):
                raise MergeError("Cannot merge, SpikeTrainProxy objects cannot be merged"
                                 "into regular SpikeTrain objects, please load them first.")
            elif not isinstance(other, SpikeTrain):
                raise MergeError("Cannot merge, only SpikeTrain"
                                 "can be merged into a SpikeTrain.")
            if self.sampling_rate != other.sampling_rate:
                raise MergeError("Cannot merge, different sampling rates")
            if self.t_start != other.t_start:
                raise MergeError("Cannot merge, different t_start")
            if self.t_stop != other.t_stop:
                raise MergeError("Cannot merge, different t_stop")
            if self.left_sweep != other.left_sweep:
                raise MergeError("Cannot merge, different left_sweep")
            if self.segment != other.segment:
                raise MergeError("Cannot merge these signals as they belong to"
                                 " different segments.")

        all_spiketrains = [self]
        all_spiketrains.extend([st.rescale(self.units) for st in others])

        # either all trains carry waveforms or none may
        wfs = [st.waveforms is not None for st in all_spiketrains]
        if any(wfs) and not all(wfs):
            raise MergeError("Cannot merge signal with waveform and signal "
                             "without waveform.")

        stack = np.concatenate([np.asarray(st) for st in all_spiketrains])
        sorting = np.argsort(stack)
        stack = stack[sorting]

        kwargs = {}
        kwargs['array_annotations'] = self._merge_array_annotations(others, sorting=sorting)

        # combine name/description/file_origin into a "merge(a; b; ...)" string
        for name in ("name", "description", "file_origin"):
            attr = getattr(self, name)
            # check if self is already a merged spiketrain
            # if it is, get rid of the bracket at the end to append more attributes
            if attr is not None:
                if attr.startswith('merge(') and attr.endswith(')'):
                    attr = attr[:-1]
            for other in others:
                attr_other = getattr(other, name)
                # both attributes are None --> nothing to do
                if attr is None and attr_other is None:
                    continue
                # one of the attributes is None --> convert to string in order to merge them
                elif attr is None or attr_other is None:
                    attr = str(attr)
                    attr_other = str(attr_other)
                # check if the other spiketrain is already a merged spiketrain
                # if it is, append all of its merged attributes that aren't already in attr
                if attr_other.startswith('merge(') and attr_other.endswith(')'):
                    for subattr in attr_other[6:-1].split('; '):
                        if subattr not in attr:
                            attr += '; ' + subattr
                    if not attr.startswith('merge('):
                        attr = 'merge(' + attr
                # if the other attribute is not in the list --> append
                # if attr doesn't already start with merge add merge( in the beginning
                elif attr_other not in attr:
                    attr += '; ' + attr_other
                    if not attr.startswith('merge('):
                        attr = 'merge(' + attr
            # close the bracket of merge(...) if necessary
            if attr is not None:
                if attr.startswith('merge('):
                    attr += ')'
            # write attr into kwargs dict
            kwargs[name] = attr

        merged_annotations = merge_annotations(*(st.annotations for st in
                                                 all_spiketrains))
        kwargs.update(merged_annotations)

        train = SpikeTrain(stack, units=self.units, dtype=self.dtype, copy=False,
                           t_start=self.t_start, t_stop=self.t_stop,
                           sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs)
        if all(wfs):
            # stack waveforms in a common unit and apply the same sorting
            wfs_stack = np.vstack([st.waveforms.rescale(self.waveforms.units)
                                   for st in all_spiketrains])
            wfs_stack = wfs_stack[sorting] * self.waveforms.units
            train.waveforms = wfs_stack
        train.segment = self.segment
        if train.segment is not None:
            self.segment.spiketrains.append(train)

        return train
  634. def _merge_array_annotations(self, others, sorting=None):
  635. '''
  636. Merges array annotations of multiple different objects.
  637. The merge happens in such a way that the result fits the merged data
  638. In general this means concatenating the arrays from the objects.
  639. If an annotation is not present in one of the objects, it will be omitted.
  640. Apart from that the array_annotations need to be sorted according to the sorting of
  641. the spikes.
  642. :return Merged array_annotations
  643. '''
  644. assert sorting is not None, "The order of the merged spikes must be known"
  645. merged_array_annotations = {}
  646. omitted_keys_self = []
  647. keys = self.array_annotations.keys()
  648. for key in keys:
  649. try:
  650. self_ann = deepcopy(self.array_annotations[key])
  651. other_ann = np.concatenate([deepcopy(other.array_annotations[key])
  652. for other in others])
  653. if isinstance(self_ann, pq.Quantity):
  654. other_ann.rescale(self_ann.units)
  655. arr_ann = np.concatenate([self_ann, other_ann]) * self_ann.units
  656. else:
  657. arr_ann = np.concatenate([self_ann, other_ann])
  658. merged_array_annotations[key] = arr_ann[sorting]
  659. # Annotation only available in 'self', must be skipped
  660. # Ignore annotations present only in one of the SpikeTrains
  661. except KeyError:
  662. omitted_keys_self.append(key)
  663. continue
  664. omitted_keys_other = [key for key in np.unique([key for other in others
  665. for key in other.array_annotations])
  666. if key not in self.array_annotations]
  667. if omitted_keys_self or omitted_keys_other:
  668. warnings.warn("The following array annotations were omitted, because they were only "
  669. "present in one of the merged objects: {} from the one that was merged "
  670. "into and {} from the ones that were merged into it."
  671. "".format(omitted_keys_self, omitted_keys_other), UserWarning)
  672. return merged_array_annotations
    @property
    def times(self):
        '''
        Returns the :class:`SpikeTrain` as a quantity array.
        '''
        # copy into a plain Quantity, dropping SpikeTrain-specific attributes
        return pq.Quantity(self)
  679. @property
  680. def duration(self):
  681. '''
  682. Duration over which spikes can occur,
  683. (:attr:`t_stop` - :attr:`t_start`)
  684. '''
  685. if self.t_stop is None or self.t_start is None:
  686. return None
  687. return self.t_stop - self.t_start
  688. @property
  689. def spike_duration(self):
  690. '''
  691. Duration of a waveform.
  692. (:attr:`waveform`.shape[2] * :attr:`sampling_period`)
  693. '''
  694. if self.waveforms is None or self.sampling_rate is None:
  695. return None
  696. return self.waveforms.shape[2] / self.sampling_rate
  697. @property
  698. def sampling_period(self):
  699. '''
  700. Interval between two samples.
  701. (1/:attr:`sampling_rate`)
  702. '''
  703. if self.sampling_rate is None:
  704. return None
  705. return 1.0 / self.sampling_rate
  706. @sampling_period.setter
  707. def sampling_period(self, period):
  708. '''
  709. Setter for :attr:`sampling_period`
  710. '''
  711. if period is None:
  712. self.sampling_rate = None
  713. else:
  714. self.sampling_rate = 1.0 / period
  715. @property
  716. def right_sweep(self):
  717. '''
  718. Time from the trigger times of the spikes to the end of the waveforms.
  719. (:attr:`left_sweep` + :attr:`spike_duration`)
  720. '''
  721. dur = self.spike_duration
  722. if self.left_sweep is None or dur is None:
  723. return None
  724. return self.left_sweep + dur