# klustakwikio.py
# -*- coding: utf-8 -*-
"""
Reading and writing from KlustaKwik-format files.
Ref: http://klusters.sourceforge.net/UserManual/data-files.html

Supported : Read, Write

Author : Chris Rodgers

TODO:
* When reading, put the Unit into the RCG, RC hierarchy.
* When writing, figure out how to get group and cluster if those annotations
  weren't set. Consider removing those annotations if they are redundant.
* Load features in addition to spiketimes.
"""
import glob
import logging
import os.path
import re
import shutil

# note neo.core need only numpy and quantities
import numpy as np

try:
    import matplotlib.mlab as mlab
except ImportError as err:
    HAVE_MLAB = False
    MLAB_ERR = err
else:
    HAVE_MLAB = True
    MLAB_ERR = None

# I need to subclass BaseIO
from neo.io.baseio import BaseIO
from neo.core import Block, Segment, Unit, SpikeTrain
# Pasted version of feature file format spec
"""
The Feature File

Generic file name: base.fet.n

Format: ASCII, integer values

The feature file lists for each spike the PCA coefficients for each
electrode, followed by the timestamp of the spike (more features can
be inserted between the PCA coefficients and the timestamp).
The first line contains the number of dimensions.
Assuming N1 spikes (spike1...spikeN1), N2 electrodes (e1...eN2) and
N3 coefficients (c1...cN3), this file looks like:

nbDimensions
c1_e1_spike1   c2_e1_spike1  ... cN3_e1_spike1   c1_e2_spike1  ... cN3_eN2_spike1   timestamp_spike1
c1_e1_spike2   c2_e1_spike2  ... cN3_e1_spike2   c1_e2_spike2  ... cN3_eN2_spike2   timestamp_spike2
...
c1_e1_spikeN1  c2_e1_spikeN1 ... cN3_e1_spikeN1  c1_e2_spikeN1 ... cN3_eN2_spikeN1  timestamp_spikeN1

The timestamp is expressed in multiples of the sampling interval. For
instance, for a 20kHz recording (50 microsecond sampling interval), a
timestamp of 200 corresponds to 200x0.000050s=0.01s from the beginning
of the recording session.

Notice that the last line must end with a newline or carriage return.
"""
  52. class KlustaKwikIO(BaseIO):
  53. """Reading and writing from KlustaKwik-format files."""
  54. # Class variables demonstrating capabilities of this IO
  55. is_readable = True
  56. is_writable = True
  57. # This IO can only manipulate objects relating to spike times
  58. supported_objects = [Block, SpikeTrain, Unit]
  59. # Keep things simple by always returning a block
  60. readable_objects = [Block]
  61. # And write a block
  62. writeable_objects = [Block]
  63. # Not sure what these do, if anything
  64. has_header = False
  65. is_streameable = False
  66. # GUI params
  67. read_params = {}
  68. # GUI params
  69. write_params = {}
  70. # The IO name and the file extensions it uses
  71. name = 'KlustaKwik'
  72. extensions = ['fet', 'clu', 'res', 'spk']
  73. # Operates on directories
  74. mode = 'file'
  75. def __init__(self, filename, sampling_rate=30000.):
  76. """Create a new IO to operate on a directory
  77. filename : the directory to contain the files
  78. basename : string, basename of KlustaKwik format, or None
  79. sampling_rate : in Hz, necessary because the KlustaKwik files
  80. stores data in samples.
  81. """
  82. if not HAVE_MLAB:
  83. raise MLAB_ERR
  84. BaseIO.__init__(self)
  85. #self.filename = os.path.normpath(filename)
  86. self.filename, self.basename = os.path.split(os.path.abspath(filename))
  87. self.sampling_rate = float(sampling_rate)
  88. # error check
  89. if not os.path.isdir(self.filename):
  90. raise ValueError("filename must be a directory")
  91. # initialize a helper object to parse filenames
  92. self._fp = FilenameParser(dirname=self.filename, basename=self.basename)
  93. # The reading methods. The `lazy` and `cascade` parameters are imposed
  94. # by neo.io API
  95. def read_block(self, lazy=False, cascade=True):
  96. """Returns a Block containing spike information.
  97. There is no obvious way to infer the segment boundaries from
  98. raw spike times, so for now all spike times are returned in one
  99. big segment. The way around this would be to specify the segment
  100. boundaries, and then change this code to put the spikes in the right
  101. segments.
  102. """
  103. # Create block and segment to hold all the data
  104. block = Block()
  105. # Search data directory for KlustaKwik files.
  106. # If nothing found, return empty block
  107. self._fetfiles = self._fp.read_filenames('fet')
  108. self._clufiles = self._fp.read_filenames('clu')
  109. if len(self._fetfiles) == 0 or not cascade:
  110. return block
  111. # Create a single segment to hold all of the data
  112. seg = Segment(name='seg0', index=0, file_origin=self.filename)
  113. block.segments.append(seg)
  114. # Load spike times from each group and store in a dict, keyed
  115. # by group number
  116. self.spiketrains = dict()
  117. for group in sorted(self._fetfiles.keys()):
  118. # Load spike times
  119. fetfile = self._fetfiles[group]
  120. spks, features = self._load_spike_times(fetfile)
  121. # Load cluster ids or generate
  122. if group in self._clufiles:
  123. clufile = self._clufiles[group]
  124. uids = self._load_unit_id(clufile)
  125. else:
  126. # unclustered data, assume all zeros
  127. uids = np.zeros(spks.shape, dtype=np.int32)
  128. # error check
  129. if len(spks) != len(uids):
  130. raise ValueError("lengths of fet and clu files are different")
  131. # Create Unit for each cluster
  132. unique_unit_ids = np.unique(uids)
  133. for unit_id in sorted(unique_unit_ids):
  134. # Initialize the unit
  135. u = Unit(name=('unit %d from group %d' % (unit_id, group)),
  136. index=unit_id, group=group)
  137. # Initialize a new SpikeTrain for the spikes from this unit
  138. if lazy:
  139. st = SpikeTrain(
  140. times=[],
  141. units='sec', t_start=0.0,
  142. t_stop=spks.max() / self.sampling_rate,
  143. name=('unit %d from group %d' % (unit_id, group)))
  144. st.lazy_shape = len(spks[uids==unit_id])
  145. else:
  146. st = SpikeTrain(
  147. times=spks[uids==unit_id] / self.sampling_rate,
  148. units='sec', t_start=0.0,
  149. t_stop=spks.max() / self.sampling_rate,
  150. name=('unit %d from group %d' % (unit_id, group)))
  151. st.annotations['cluster'] = unit_id
  152. st.annotations['group'] = group
  153. # put features in
  154. if not lazy and len(features) != 0:
  155. st.annotations['waveform_features'] = features
  156. # Link
  157. u.spiketrains.append(st)
  158. seg.spiketrains.append(st)
  159. block.create_many_to_one_relationship()
  160. return block
  161. # Helper hidden functions for reading
  162. def _load_spike_times(self, fetfilename):
  163. """Reads and returns the spike times and features"""
  164. f = file(fetfilename, 'r')
  165. # Number of clustering features is integer on first line
  166. nbFeatures = int(f.readline().strip())
  167. # Each subsequent line consists of nbFeatures values, followed by
  168. # the spike time in samples.
  169. names = ['fet%d' % n for n in xrange(nbFeatures)]
  170. names.append('spike_time')
  171. # Load into recarray
  172. data = mlab.csv2rec(f, names=names, skiprows=1, delimiter=' ')
  173. f.close()
  174. # get features
  175. features = np.array([data['fet%d' % n] for n in xrange(nbFeatures)])
  176. # Return the spike_time column
  177. return data['spike_time'], features.transpose()
  178. def _load_unit_id(self, clufilename):
  179. """Reads and return the cluster ids as int32"""
  180. f = file(clufilename, 'r')
  181. # Number of clusters on this tetrode is integer on first line
  182. nbClusters = int(f.readline().strip())
  183. # Read each cluster name as a string
  184. cluster_names = f.readlines()
  185. f.close()
  186. # Convert names to integers
  187. # I think the spec requires cluster names to be integers, but
  188. # this code could be modified to support string names which are
  189. # auto-numbered.
  190. try:
  191. cluster_ids = [int(name) for name in cluster_names]
  192. except ValueError:
  193. raise ValueError(
  194. "Could not convert cluster name to integer in %s" % clufilename)
  195. # convert to numpy array and error check
  196. cluster_ids = np.array(cluster_ids, dtype=np.int32)
  197. if len(np.unique(cluster_ids)) != nbClusters:
  198. logging.warning("warning: I got %d clusters instead of %d in %s" % (
  199. len(np.unique(cluster_ids)), nbClusters, clufilename))
  200. return cluster_ids
  201. # writing functions
  202. def write_block(self, block):
  203. """Write spike times and unit ids to disk.
  204. Currently descends hierarchy from block to segment to spiketrain.
  205. Then gets group and cluster information from spiketrain.
  206. Then writes the time and cluster info to the file associated with
  207. that group.
  208. The group and cluster information are extracted from annotations,
  209. eg `sptr.annotations['group']`. If no cluster information exists,
  210. it is assigned to cluster 0.
  211. Note that all segments are essentially combined in
  212. this process, since the KlustaKwik format does not allow for
  213. segment boundaries.
  214. As implemented currently, does not use the `Unit` object at all.
  215. We first try to use the sampling rate of each SpikeTrain, or if this
  216. is not set, we use `self.sampling_rate`.
  217. If the files already exist, backup copies are created by appending
  218. the filenames with a "~".
  219. """
  220. # set basename
  221. if self.basename is None:
  222. logging.warning("warning: no basename provided, using `basename`")
  223. self.basename = 'basename'
  224. # First create file handles for each group which will be stored
  225. self._make_all_file_handles(block)
  226. # We'll detect how many features belong in each group
  227. self._group2features = {}
  228. # Iterate through segments in this block
  229. for seg in block.segments:
  230. # Write each spiketrain of the segment
  231. for st in seg.spiketrains:
  232. # Get file handles for this spiketrain using its group
  233. group = self.st2group(st)
  234. fetfilehandle = self._fetfilehandles[group]
  235. clufilehandle = self._clufilehandles[group]
  236. # Get the id to write to clu file for this spike train
  237. cluster = self.st2cluster(st)
  238. # Choose sampling rate to convert to samples
  239. try:
  240. sr = st.annotations['sampling_rate']
  241. except KeyError:
  242. sr = self.sampling_rate
  243. # Convert to samples
  244. spike_times_in_samples = np.rint(
  245. np.array(st) * sr).astype(np.int)
  246. # Try to get features from spiketrain
  247. try:
  248. all_features = st.annotations['waveform_features']
  249. except KeyError:
  250. # Use empty
  251. all_features = [
  252. [] for _ in range(len(spike_times_in_samples))]
  253. all_features = np.asarray(all_features)
  254. if all_features.ndim != 2:
  255. raise ValueError("waveform features should be 2d array")
  256. # Check number of features we're supposed to have
  257. try:
  258. n_features = self._group2features[group]
  259. except KeyError:
  260. # First time through .. set number of features
  261. n_features = all_features.shape[1]
  262. self._group2features[group] = n_features
  263. # and write to first line of file
  264. fetfilehandle.write("%d\n" % n_features)
  265. if n_features != all_features.shape[1]:
  266. raise ValueError("inconsistent number of features: " +
  267. "supposed to be %d but I got %d" %\
  268. (n_features, all_features.shape[1]))
  269. # Write features and time for each spike
  270. for stt, features in zip(spike_times_in_samples, all_features):
  271. # first features
  272. for val in features:
  273. fetfilehandle.write(str(val))
  274. fetfilehandle.write(" ")
  275. # now time
  276. fetfilehandle.write("%d\n" % stt)
  277. # and cluster id
  278. clufilehandle.write("%d\n" % cluster)
  279. # We're done, so close the files
  280. self._close_all_files()
  281. # Helper functions for writing
  282. def st2group(self, st):
  283. # Not sure this is right so make it a method in case we change it
  284. try:
  285. return st.annotations['group']
  286. except KeyError:
  287. return 0
  288. def st2cluster(self, st):
  289. # Not sure this is right so make it a method in case we change it
  290. try:
  291. return st.annotations['cluster']
  292. except KeyError:
  293. return 0
  294. def _make_all_file_handles(self, block):
  295. """Get the tetrode (group) of each neuron (cluster) by descending
  296. the hierarchy through segment and block.
  297. Store in a dict {group_id: list_of_clusters_in_that_group}
  298. """
  299. group2clusters = {}
  300. for seg in block.segments:
  301. for st in seg.spiketrains:
  302. group = self.st2group(st)
  303. cluster = self.st2cluster(st)
  304. if group in group2clusters:
  305. if cluster not in group2clusters[group]:
  306. group2clusters[group].append(cluster)
  307. else:
  308. group2clusters[group] = [cluster]
  309. # Make new file handles for each group
  310. self._fetfilehandles, self._clufilehandles = {}, {}
  311. for group, clusters in group2clusters.items():
  312. self._new_group(group, nbClusters=len(clusters))
  313. def _new_group(self, id_group, nbClusters):
  314. # generate filenames
  315. fetfilename = os.path.join(self.filename,
  316. self.basename + ('.fet.%d' % id_group))
  317. clufilename = os.path.join(self.filename,
  318. self.basename + ('.clu.%d' % id_group))
  319. # back up before overwriting
  320. if os.path.exists(fetfilename):
  321. shutil.copyfile(fetfilename, fetfilename + '~')
  322. if os.path.exists(clufilename):
  323. shutil.copyfile(clufilename, clufilename + '~')
  324. # create file handles
  325. self._fetfilehandles[id_group] = file(fetfilename, 'w')
  326. self._clufilehandles[id_group] = file(clufilename, 'w')
  327. # write out first line
  328. #self._fetfilehandles[id_group].write("0\n") # Number of features
  329. self._clufilehandles[id_group].write("%d\n" % nbClusters)
  330. def _close_all_files(self):
  331. for val in self._fetfilehandles.values(): val.close()
  332. for val in self._clufilehandles.values(): val.close()
  333. class FilenameParser:
  334. """Simple class to interpret user's requests into KlustaKwik filenames"""
  335. def __init__(self, dirname, basename=None):
  336. """Initialize a new parser for a directory containing files
  337. dirname: directory containing files
  338. basename: basename in KlustaKwik format spec
  339. If basename is left None, then files with any basename in the directory
  340. will be used. An error is raised if files with multiple basenames
  341. exist in the directory.
  342. """
  343. self.dirname = os.path.normpath(dirname)
  344. self.basename = basename
  345. # error check
  346. if not os.path.isdir(self.dirname):
  347. raise ValueError("filename must be a directory")
  348. def read_filenames(self, typestring='fet'):
  349. """Returns filenames in the data directory matching the type.
  350. Generally, `typestring` is one of the following:
  351. 'fet', 'clu', 'spk', 'res'
  352. Returns a dict {group_number: filename}, e.g.:
  353. { 0: 'basename.fet.0',
  354. 1: 'basename.fet.1',
  355. 2: 'basename.fet.2'}
  356. 'basename' can be any string not containing whitespace.
  357. Only filenames that begin with "basename.typestring." and end with
  358. a sequence of digits are valid. The digits are converted to an integer
  359. and used as the group number.
  360. """
  361. all_filenames = glob.glob(os.path.join(self.dirname, '*'))
  362. # Fill the dict with valid filenames
  363. d = {}
  364. for v in all_filenames:
  365. # Test whether matches format, ie ends with digits
  366. split_fn = os.path.split(v)[1]
  367. m = glob.re.search(('^(\w+)\.%s\.(\d+)$' % typestring), split_fn)
  368. if m is not None:
  369. # get basename from first hit if not specified
  370. if self.basename is None:
  371. self.basename = m.group(1)
  372. # return files with correct basename
  373. if self.basename == m.group(1):
  374. # Key the group number to the filename
  375. # This conversion to int should always work since only
  376. # strings of digits will match the regex
  377. tetn = int(m.group(2))
  378. d[tetn] = v
  379. return d