kwikio.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
"""
Class for reading data from a .kwik dataset.

Depends on: klusta (required), numpy, quantities, neo; scipy is optional
Supported: Read
Author: Mikkel E. Lepperød @CINPLA
"""
  8. # TODO: writing to file
  9. import numpy as np
  10. import quantities as pq
  11. import os
  12. try:
  13. from scipy import stats
  14. except ImportError as err:
  15. HAVE_SCIPY = False
  16. SCIPY_ERR = err
  17. else:
  18. HAVE_SCIPY = True
  19. SCIPY_ERR = None
  20. try:
  21. from klusta import kwik
  22. except ImportError as err:
  23. HAVE_KWIK = False
  24. KWIK_ERR = err
  25. else:
  26. HAVE_KWIK = True
  27. KWIK_ERR = None
  28. # I need to subclass BaseIO
  29. from neo.io.baseio import BaseIO
  30. # to import from core
  31. from neo.core import (Segment, SpikeTrain, Unit, Epoch, AnalogSignal,
  32. ChannelIndex, Block)
  33. import neo.io.tools
  34. class KwikIO(BaseIO):
  35. """
  36. Class for "reading" experimental data from a .kwik file.
  37. Generates a :class:`Segment` with a :class:`AnalogSignal`
  38. """
  39. is_readable = True # This class can only read data
  40. is_writable = False # write is not supported
  41. supported_objects = [Block, Segment, SpikeTrain, AnalogSignal,
  42. ChannelIndex]
  43. # This class can return either a Block or a Segment
  44. # The first one is the default ( self.read )
  45. # These lists should go from highest object to lowest object because
  46. # common_io_test assumes it.
  47. readable_objects = [Block]
  48. # This class is not able to write objects
  49. writeable_objects = []
  50. has_header = False
  51. is_streameable = False
  52. name = 'Kwik'
  53. description = 'This IO reads experimental data from a .kwik dataset'
  54. extensions = ['kwik']
  55. mode = 'file'
  56. def __init__(self, filename):
  57. """
  58. Arguments:
  59. filename : the filename
  60. """
  61. if not HAVE_KWIK:
  62. raise KWIK_ERR
  63. BaseIO.__init__(self)
  64. self.filename = os.path.abspath(filename)
  65. model = kwik.KwikModel(self.filename) # TODO this group is loaded twice
  66. self.models = [kwik.KwikModel(self.filename, channel_group=grp)
  67. for grp in model.channel_groups]
  68. def read_block(self,
  69. lazy=False,
  70. get_waveforms=True,
  71. cluster_group=None,
  72. raw_data_units='uV',
  73. get_raw_data=False,
  74. ):
  75. """
  76. Reads a block with segments and channel_indexes
  77. Parameters:
  78. get_waveforms: bool, default = False
  79. Wether or not to get the waveforms
  80. get_raw_data: bool, default = False
  81. Wether or not to get the raw traces
  82. raw_data_units: str, default = "uV"
  83. SI units of the raw trace according to voltage_gain given to klusta
  84. cluster_group: str, default = None
  85. Which clusters to load, possibilities are "noise", "unsorted",
  86. "good", if None all is loaded.
  87. """
  88. assert not lazy, 'Do not support lazy'
  89. blk = Block()
  90. seg = Segment(file_origin=self.filename)
  91. blk.segments += [seg]
  92. for model in self.models:
  93. group_id = model.channel_group
  94. group_meta = {'group_id': group_id}
  95. group_meta.update(model.metadata)
  96. chx = ChannelIndex(name='channel group #{}'.format(group_id),
  97. index=model.channels,
  98. **group_meta)
  99. blk.channel_indexes.append(chx)
  100. clusters = model.spike_clusters
  101. for cluster_id in model.cluster_ids:
  102. meta = model.cluster_metadata[cluster_id]
  103. if cluster_group is None:
  104. pass
  105. elif cluster_group != meta:
  106. continue
  107. sptr = self.read_spiketrain(cluster_id=cluster_id,
  108. model=model,
  109. get_waveforms=get_waveforms,
  110. raw_data_units=raw_data_units)
  111. sptr.annotations.update({'cluster_group': meta,
  112. 'group_id': model.channel_group})
  113. sptr.channel_index = chx
  114. unit = Unit(cluster_group=meta,
  115. group_id=model.channel_group,
  116. name='unit #{}'.format(cluster_id))
  117. unit.spiketrains.append(sptr)
  118. chx.units.append(unit)
  119. unit.channel_index = chx
  120. seg.spiketrains.append(sptr)
  121. if get_raw_data:
  122. ana = self.read_analogsignal(model, units=raw_data_units)
  123. ana.channel_index = chx
  124. seg.analogsignals.append(ana)
  125. seg.duration = model.duration * pq.s
  126. blk.create_many_to_one_relationship()
  127. return blk
  128. def read_analogsignal(self, model, units='uV', lazy=False):
  129. """
  130. Reads analogsignals
  131. Parameters:
  132. units: str, default = "uV"
  133. SI units of the raw trace according to voltage_gain given to klusta
  134. """
  135. assert not lazy, 'Do not support lazy'
  136. arr = model.traces[:] * model.metadata['voltage_gain']
  137. ana = AnalogSignal(arr, sampling_rate=model.sample_rate * pq.Hz,
  138. units=units,
  139. file_origin=model.metadata['raw_data_files'])
  140. return ana
  141. def read_spiketrain(self, cluster_id, model,
  142. lazy=False,
  143. get_waveforms=True,
  144. raw_data_units=None
  145. ):
  146. """
  147. Reads sorted spiketrains
  148. Parameters:
  149. get_waveforms: bool, default = False
  150. Wether or not to get the waveforms
  151. cluster_id: int,
  152. Which cluster to load, according to cluster id from klusta
  153. model: klusta.kwik.KwikModel
  154. A KwikModel object obtained by klusta.kwik.KwikModel(fname)
  155. """
  156. try:
  157. if (not (cluster_id in model.cluster_ids)):
  158. raise ValueError
  159. except ValueError:
  160. print("Exception: cluster_id (%d) not found !! " % cluster_id)
  161. return
  162. clusters = model.spike_clusters
  163. idx = np.nonzero(clusters == cluster_id)
  164. if get_waveforms:
  165. w = model.all_waveforms[idx]
  166. # klusta: num_spikes, samples_per_spike, num_chans = w.shape
  167. w = w.swapaxes(1, 2)
  168. w = pq.Quantity(w, raw_data_units)
  169. else:
  170. w = None
  171. sptr = SpikeTrain(times=model.spike_times[idx],
  172. t_stop=model.duration, waveforms=w, units='s',
  173. sampling_rate=model.sample_rate * pq.Hz,
  174. file_origin=self.filename,
  175. **{'cluster_id': cluster_id})
  176. return sptr