# -*- coding: utf-8 -*-
"""
Class for reading data from a .kwik dataset

Depends on: scipy
            klusta

Supported: Read

Author: Mikkel E. Lepperød @CINPLA
"""
# TODO: writing to file
# needed for python 3 compatibility
from __future__ import absolute_import
from __future__ import division

import numpy as np
import quantities as pq
import os
try:
    from scipy import stats
except ImportError as err:
    HAVE_SCIPY = False
    SCIPY_ERR = err
else:
    HAVE_SCIPY = True
    SCIPY_ERR = None

try:
    from klusta import kwik
except ImportError as err:
    HAVE_KWIK = False
    KWIK_ERR = err
else:
    HAVE_KWIK = True
    KWIK_ERR = None
# I need to subclass BaseIO
from neo.io.baseio import BaseIO
# to import from core
from neo.core import (Segment, SpikeTrain, Unit, Epoch, AnalogSignal,
                      ChannelIndex, Block)
import neo.io.tools


class KwikIO(BaseIO):
- """
- Class for "reading" experimental data from a .kwik file.
- Generates a :class:`Segment` with a :class:`AnalogSignal`
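
    A minimal usage sketch (the file path is hypothetical):
        >>> from neo import io
        >>> r = io.KwikIO(filename='mydata.kwik')
        >>> blk = r.read_block()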
- """
    is_readable = True   # This class can only read data
    is_writable = False  # write is not supported

    supported_objects = [Block, Segment, SpikeTrain, AnalogSignal,
                         ChannelIndex]

    # This class can return either a Block or a Segment
    # The first one is the default ( self.read )
    # These lists should go from highest object to lowest object because
    # common_io_test assumes it.
    readable_objects = [Block]

    # This class is not able to write objects
    writeable_objects = []

    has_header = False
    is_streameable = False

    name = 'Kwik'
    description = 'This IO reads experimental data from a .kwik dataset'
    extensions = ['kwik']
    mode = 'file'

    def __init__(self, filename):
        """
        Arguments:
            filename : the filename
        """
        if not HAVE_KWIK:
            raise KWIK_ERR
        BaseIO.__init__(self)
        self.filename = os.path.abspath(filename)
        model = kwik.KwikModel(self.filename)  # TODO this group is loaded twice
        self.models = [kwik.KwikModel(self.filename, channel_group=grp)
                       for grp in model.channel_groups]

    def read_block(self,
                   lazy=False,
                   cascade=True,
                   get_waveforms=True,
                   cluster_group=None,
                   raw_data_units='uV',
                   get_raw_data=False,
                   ):
- """
- Reads a block with segments and channel_indexes
- Parameters:
- get_waveforms: bool, default = False
- Wether or not to get the waveforms
- get_raw_data: bool, default = False
- Wether or not to get the raw traces
- raw_data_units: str, default = "uV"
- SI units of the raw trace according to voltage_gain given to klusta
- cluster_group: str, default = None
- Which clusters to load, possibilities are "noise", "unsorted",
- "good", if None all is loaded.
- """
        blk = Block()
        if cascade:
            seg = Segment(file_origin=self.filename)
            blk.segments += [seg]
            for model in self.models:
                group_id = model.channel_group
                group_meta = {'group_id': group_id}
                group_meta.update(model.metadata)
                chx = ChannelIndex(name='channel group #{}'.format(group_id),
                                   index=model.channels,
                                   **group_meta)
                blk.channel_indexes.append(chx)
                for cluster_id in model.cluster_ids:
                    meta = model.cluster_metadata[cluster_id]
                    # skip clusters outside the requested cluster group
                    if cluster_group is not None and cluster_group != meta:
                        continue
                    sptr = self.read_spiketrain(cluster_id=cluster_id,
                                                model=model, lazy=lazy,
                                                cascade=cascade,
                                                get_waveforms=get_waveforms,
                                                raw_data_units=raw_data_units)
                    sptr.annotations.update({'cluster_group': meta,
                                             'group_id': model.channel_group})
                    sptr.channel_index = chx
                    unit = Unit(cluster_group=meta,
                                group_id=model.channel_group,
                                name='unit #{}'.format(cluster_id))
                    unit.spiketrains.append(sptr)
                    chx.units.append(unit)
                    unit.channel_index = chx
                    seg.spiketrains.append(sptr)
                if get_raw_data:
                    ana = self.read_analogsignal(model, raw_data_units,
                                                 lazy, cascade)
                    ana.channel_index = chx
                    seg.analogsignals.append(ana)
            seg.duration = model.duration * pq.s
        blk.create_many_to_one_relationship()
        return blk

    def read_analogsignal(self, model, units='uV',
                          lazy=False,
                          cascade=True,
                          ):
- """
- Reads analogsignals
- Parameters:
- units: str, default = "uV"
- SI units of the raw trace according to voltage_gain given to klusta
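
        A minimal usage sketch (assuming ``r`` is a ``KwikIO`` instance):
            >>> model = r.models[0]  # first channel group
            >>> ana = r.read_analogsignal(model, units='uV')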
- """
        # scale the raw traces by the voltage gain to obtain voltages
        arr = model.traces[:] * model.metadata['voltage_gain']
        ana = AnalogSignal(arr, sampling_rate=model.sample_rate * pq.Hz,
                           units=units,
                           file_origin=model.metadata['raw_data_files'])
        return ana

    def read_spiketrain(self, cluster_id, model,
                        lazy=False,
                        cascade=True,
                        get_waveforms=True,
                        raw_data_units=None
                        ):
- """
- Reads sorted spiketrains
- Parameters:
- get_waveforms: bool, default = False
- Wether or not to get the waveforms
- cluster_id: int,
- Which cluster to load, according to cluster id from klusta
- model: klusta.kwik.KwikModel
- A KwikModel object obtained by klusta.kwik.KwikModel(fname)
- """
        """
        if cluster_id not in model.cluster_ids:
            raise ValueError('cluster_id {} not found in the kwik model'
                             .format(cluster_id))
        clusters = model.spike_clusters
        # flatten so spike_times and all_waveforms are indexed with a 1d
        # array (np.argwhere returns shape (n, 1))
        idx = np.argwhere(clusters == cluster_id).flatten()
        if get_waveforms:
            w = model.all_waveforms[idx]
            # klusta gives (num_spikes, samples_per_spike, num_chans);
            # neo expects (num_spikes, num_chans, samples_per_spike)
            w = w.swapaxes(1, 2)
            w = pq.Quantity(w, raw_data_units)
        else:
            w = None
        sptr = SpikeTrain(times=model.spike_times[idx],
                          t_stop=model.duration, waveforms=w, units='s',
                          sampling_rate=model.sample_rate * pq.Hz,
                          file_origin=self.filename,
                          **{'cluster_id': cluster_id})
        return sptr