kwikio.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # -*- coding: utf-8 -*-
  2. """
  3. Class for reading data from a .kwik dataset
  4. Depends on: scipy
  5. klusta
  6. Supported: Read
  7. Author: Mikkel E. Lepperød @CINPLA
  8. """
  9. # TODO: writing to file
  10. # needed for python 3 compatibility
  11. from __future__ import absolute_import
  12. from __future__ import division
  13. import numpy as np
  14. import quantities as pq
  15. import os
  16. try:
  17. from scipy import stats
  18. except ImportError as err:
  19. HAVE_SCIPY = False
  20. SCIPY_ERR = err
  21. else:
  22. HAVE_SCIPY = True
  23. SCIPY_ERR = None
  24. try:
  25. from klusta import kwik
  26. except ImportError as err:
  27. HAVE_KWIK = False
  28. KWIK_ERR = err
  29. else:
  30. HAVE_KWIK = True
  31. KWIK_ERR = None
  32. # I need to subclass BaseIO
  33. from neo.io.baseio import BaseIO
  34. # to import from core
  35. from neo.core import (Segment, SpikeTrain, Unit, Epoch, AnalogSignal,
  36. ChannelIndex, Block)
  37. import neo.io.tools
  38. class KwikIO(BaseIO):
  39. """
  40. Class for "reading" experimental data from a .kwik file.
  41. Generates a :class:`Segment` with a :class:`AnalogSignal`
  42. """
  43. is_readable = True # This class can only read data
  44. is_writable = False # write is not supported
  45. supported_objects = [Block, Segment, SpikeTrain, AnalogSignal,
  46. ChannelIndex]
  47. # This class can return either a Block or a Segment
  48. # The first one is the default ( self.read )
  49. # These lists should go from highest object to lowest object because
  50. # common_io_test assumes it.
  51. readable_objects = [Block]
  52. # This class is not able to write objects
  53. writeable_objects = []
  54. has_header = False
  55. is_streameable = False
  56. name = 'Kwik'
  57. description = 'This IO reads experimental data from a .kwik dataset'
  58. extensions = ['kwik']
  59. mode = 'file'
  60. def __init__(self, filename):
  61. """
  62. Arguments:
  63. filename : the filename
  64. """
  65. if not HAVE_KWIK:
  66. raise KWIK_ERR
  67. BaseIO.__init__(self)
  68. self.filename = os.path.abspath(filename)
  69. model = kwik.KwikModel(self.filename) # TODO this group is loaded twice
  70. self.models = [kwik.KwikModel(self.filename, channel_group=grp)
  71. for grp in model.channel_groups]
  72. def read_block(self,
  73. lazy=False,
  74. get_waveforms=True,
  75. cluster_group=None,
  76. raw_data_units='uV',
  77. get_raw_data=False,
  78. ):
  79. """
  80. Reads a block with segments and channel_indexes
  81. Parameters:
  82. get_waveforms: bool, default = False
  83. Wether or not to get the waveforms
  84. get_raw_data: bool, default = False
  85. Wether or not to get the raw traces
  86. raw_data_units: str, default = "uV"
  87. SI units of the raw trace according to voltage_gain given to klusta
  88. cluster_group: str, default = None
  89. Which clusters to load, possibilities are "noise", "unsorted",
  90. "good", if None all is loaded.
  91. """
  92. assert not lazy, 'Do not support lazy'
  93. blk = Block()
  94. seg = Segment(file_origin=self.filename)
  95. blk.segments += [seg]
  96. for model in self.models:
  97. group_id = model.channel_group
  98. group_meta = {'group_id': group_id}
  99. group_meta.update(model.metadata)
  100. chx = ChannelIndex(name='channel group #{}'.format(group_id),
  101. index=model.channels,
  102. **group_meta)
  103. blk.channel_indexes.append(chx)
  104. clusters = model.spike_clusters
  105. for cluster_id in model.cluster_ids:
  106. meta = model.cluster_metadata[cluster_id]
  107. if cluster_group is None:
  108. pass
  109. elif cluster_group != meta:
  110. continue
  111. sptr = self.read_spiketrain(cluster_id=cluster_id,
  112. model=model,
  113. get_waveforms=get_waveforms,
  114. raw_data_units=raw_data_units)
  115. sptr.annotations.update({'cluster_group': meta,
  116. 'group_id': model.channel_group})
  117. sptr.channel_index = chx
  118. unit = Unit(cluster_group=meta,
  119. group_id=model.channel_group,
  120. name='unit #{}'.format(cluster_id))
  121. unit.spiketrains.append(sptr)
  122. chx.units.append(unit)
  123. unit.channel_index = chx
  124. seg.spiketrains.append(sptr)
  125. if get_raw_data:
  126. ana = self.read_analogsignal(model, units=raw_data_units)
  127. ana.channel_index = chx
  128. seg.analogsignals.append(ana)
  129. seg.duration = model.duration * pq.s
  130. blk.create_many_to_one_relationship()
  131. return blk
  132. def read_analogsignal(self, model, units='uV', lazy=False):
  133. """
  134. Reads analogsignals
  135. Parameters:
  136. units: str, default = "uV"
  137. SI units of the raw trace according to voltage_gain given to klusta
  138. """
  139. assert not lazy, 'Do not support lazy'
  140. arr = model.traces[:] * model.metadata['voltage_gain']
  141. ana = AnalogSignal(arr, sampling_rate=model.sample_rate * pq.Hz,
  142. units=units,
  143. file_origin=model.metadata['raw_data_files'])
  144. return ana
  145. def read_spiketrain(self, cluster_id, model,
  146. lazy=False,
  147. get_waveforms=True,
  148. raw_data_units=None
  149. ):
  150. """
  151. Reads sorted spiketrains
  152. Parameters:
  153. get_waveforms: bool, default = False
  154. Wether or not to get the waveforms
  155. cluster_id: int,
  156. Which cluster to load, according to cluster id from klusta
  157. model: klusta.kwik.KwikModel
  158. A KwikModel object obtained by klusta.kwik.KwikModel(fname)
  159. """
  160. try:
  161. if ((not (cluster_id in model.cluster_ids))):
  162. raise ValueError
  163. except ValueError:
  164. print("Exception: cluster_id (%d) not found !! " % cluster_id)
  165. return
  166. clusters = model.spike_clusters
  167. idx = np.nonzero(clusters == cluster_id)
  168. if get_waveforms:
  169. w = model.all_waveforms[idx]
  170. # klusta: num_spikes, samples_per_spike, num_chans = w.shape
  171. w = w.swapaxes(1, 2)
  172. w = pq.Quantity(w, raw_data_units)
  173. else:
  174. w = None
  175. sptr = SpikeTrain(times=model.spike_times[idx],
  176. t_stop=model.duration, waveforms=w, units='s',
  177. sampling_rate=model.sample_rate * pq.Hz,
  178. file_origin=self.filename,
  179. **{'cluster_id': cluster_id})
  180. return sptr