kwikio.py

# -*- coding: utf-8 -*-
"""
Class for reading data from a .kwik dataset

Depends on: scipy
            klusta

Supported: Read

Author: Mikkel E. Lepperød @CINPLA
"""
# TODO: writing to file

# needed for python 3 compatibility
from __future__ import absolute_import
from __future__ import division

import numpy as np
import quantities as pq
import os

try:
    from scipy import stats
except ImportError as err:
    HAVE_SCIPY = False
    SCIPY_ERR = err
else:
    HAVE_SCIPY = True
    SCIPY_ERR = None

try:
    from klusta import kwik
except ImportError as err:
    HAVE_KWIK = False
    KWIK_ERR = err
else:
    HAVE_KWIK = True
    KWIK_ERR = None

# I need to subclass BaseIO
from neo.io.baseio import BaseIO
# to import from core
from neo.core import (Segment, SpikeTrain, Unit, Epoch, AnalogSignal,
                      ChannelIndex, Block)
import neo.io.tools


class KwikIO(BaseIO):
    """
    Class for "reading" experimental data from a .kwik file.

    Generates a :class:`Segment` with an :class:`AnalogSignal`
    """
    is_readable = True   # This class can only read data
    is_writable = False  # write is not supported

    supported_objects = [Block, Segment, SpikeTrain, AnalogSignal,
                         ChannelIndex]

    # This class can return either a Block or a Segment
    # The first one is the default ( self.read )
    # These lists should go from highest object to lowest object because
    # common_io_test assumes it.
    readable_objects = [Block]

    # This class is not able to write objects
    writeable_objects = []

    has_header = False
    is_streameable = False

    name = 'Kwik'
    description = 'This IO reads experimental data from a .kwik dataset'
    extensions = ['kwik']
    mode = 'file'

    def __init__(self, filename):
        """
        Arguments:
            filename : the filename
        """
        if not HAVE_KWIK:
            raise KWIK_ERR
        BaseIO.__init__(self)
        self.filename = os.path.abspath(filename)
        model = kwik.KwikModel(self.filename)  # TODO this group is loaded twice
        self.models = [kwik.KwikModel(self.filename, channel_group=grp)
                       for grp in model.channel_groups]

    def read_block(self,
                   lazy=False,
                   cascade=True,
                   get_waveforms=True,
                   cluster_group=None,
                   raw_data_units='uV',
                   get_raw_data=False,
                   ):
        """
        Reads a block with segments and channel_indexes

        Parameters:
        get_waveforms: bool, default = True
            Whether or not to get the waveforms
        get_raw_data: bool, default = False
            Whether or not to get the raw traces
        raw_data_units: str, default = "uV"
            Units of the raw traces, which are scaled by the voltage_gain
            given to klusta
        cluster_group: str, default = None
            Which clusters to load; possibilities are "noise", "unsorted",
            "good". If None, all are loaded.
        """
        blk = Block()
        if cascade:
            seg = Segment(file_origin=self.filename)
            blk.segments += [seg]

            # one ChannelIndex per channel group, one Unit per cluster
            for model in self.models:
                group_id = model.channel_group
                group_meta = {'group_id': group_id}
                group_meta.update(model.metadata)
                chx = ChannelIndex(name='channel group #{}'.format(group_id),
                                   index=model.channels,
                                   **group_meta)
                blk.channel_indexes.append(chx)
                clusters = model.spike_clusters
                for cluster_id in model.cluster_ids:
                    meta = model.cluster_metadata[cluster_id]
                    # skip clusters whose sorting label does not match the
                    # requested cluster_group
                    if cluster_group is None:
                        pass
                    elif cluster_group != meta:
                        continue
                    sptr = self.read_spiketrain(cluster_id=cluster_id,
                                                model=model, lazy=lazy,
                                                cascade=cascade,
                                                get_waveforms=get_waveforms,
                                                raw_data_units=raw_data_units)
                    sptr.annotations.update({'cluster_group': meta,
                                             'group_id': model.channel_group})
                    sptr.channel_index = chx
                    unit = Unit(cluster_group=meta,
                                group_id=model.channel_group,
                                name='unit #{}'.format(cluster_id))
                    unit.spiketrains.append(sptr)
                    chx.units.append(unit)
                    unit.channel_index = chx
                    seg.spiketrains.append(sptr)

                if get_raw_data:
                    ana = self.read_analogsignal(model, raw_data_units,
                                                 lazy, cascade)
                    ana.channel_index = chx
                    seg.analogsignals.append(ana)

            seg.duration = model.duration * pq.s

        blk.create_many_to_one_relationship()
        return blk

    def read_analogsignal(self, model, units='uV',
                          lazy=False,
                          cascade=True,
                          ):
        """
        Reads analogsignals

        Parameters:
        units: str, default = "uV"
            Units of the raw traces, which are scaled by the voltage_gain
            given to klusta
        """
        # scale the raw traces to physical units before wrapping them
        arr = model.traces[:] * model.metadata['voltage_gain']
        ana = AnalogSignal(arr, sampling_rate=model.sample_rate * pq.Hz,
                           units=units,
                           file_origin=model.metadata['raw_data_files'])
        return ana
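
    # Illustrative direct use (not part of the original file): raw traces for a
    # single channel group can be read without assembling a whole Block.
    # 'data.kwik' is a hypothetical path.
    #
    #     reader = KwikIO('data.kwik')
    #     ana = reader.read_analogsignal(reader.models[0], units='uV')
    #     print(ana.sampling_rate, ana.shape)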

    def read_spiketrain(self, cluster_id, model,
                        lazy=False,
                        cascade=True,
                        get_waveforms=True,
                        raw_data_units=None,
                        ):
        """
        Reads sorted spiketrains

        Parameters:
        get_waveforms: bool, default = True
            Whether or not to get the waveforms
        cluster_id: int,
            Which cluster to load, according to the cluster id from klusta
        model: klusta.kwik.KwikModel
            A KwikModel object obtained by klusta.kwik.KwikModel(fname)
        """
        if cluster_id not in model.cluster_ids:
            raise ValueError("cluster_id (%d) not found" % cluster_id)
        clusters = model.spike_clusters
        # indices of the spikes belonging to this cluster
        idx = np.argwhere(clusters == cluster_id)
        if get_waveforms:
            w = model.all_waveforms[idx]
            # klusta: num_spikes, samples_per_spike, num_chans = w.shape
            # neo expects (num_spikes, num_chans, samples_per_spike)
            w = w.swapaxes(1, 2)
            w = pq.Quantity(w, raw_data_units)
        else:
            w = None
        sptr = SpikeTrain(times=model.spike_times[idx],
                          t_stop=model.duration, waveforms=w, units='s',
                          sampling_rate=model.sample_rate * pq.Hz,
                          file_origin=self.filename,
                          **{'cluster_id': cluster_id})
        return sptr
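

# Minimal usage sketch (not part of the original file): driving the IO from a
# script. 'experiment.kwik' is a hypothetical path, and the 'good' label
# assumes the standard klusta cluster groups listed in read_block's docstring.
if __name__ == '__main__':
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else 'experiment.kwik'
    reader = KwikIO(filename=path)
    blk = reader.read_block(get_waveforms=True, cluster_group='good')
    for seg in blk.segments:
        for sptr in seg.spiketrains:
            print('cluster {} (group {}): {} spikes'.format(
                sptr.annotations['cluster_id'],
                sptr.annotations['group_id'],
                len(sptr)))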