klustakwikio.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. # -*- coding: utf-8 -*-
  2. """
  3. Reading and writing from KlustaKwik-format files.
  4. Ref: http://klusters.sourceforge.net/UserManual/data-files.html
  5. Supported : Read, Write
  6. Author : Chris Rodgers
  7. TODO:
  8. * When reading, put the Unit into the RCG, RC hierarchy
  9. * When writing, figure out how to get group and cluster if those annotations
  10. weren't set. Consider removing those annotations if they are redundant.
  11. * Load features in addition to spiketimes.
  12. """
import glob
import logging
import os.path
import re
import shutil
# note: neo.core needs only numpy and quantities
  18. import numpy as np
  19. try:
  20. import matplotlib.mlab as mlab
  21. except ImportError as err:
  22. HAVE_MLAB = False
  23. MLAB_ERR = err
  24. else:
  25. HAVE_MLAB = True
  26. MLAB_ERR = None
  27. # I need to subclass BaseIO
  28. from neo.io.baseio import BaseIO
  29. from neo.core import Block, Segment, Unit, SpikeTrain
  30. # Pasted version of feature file format spec
  31. """
  32. The Feature File
  33. Generic file name: base.fet.n
  34. Format: ASCII, integer values
  35. The feature file lists for each spike the PCA coefficients for each
  36. electrode, followed by the timestamp of the spike (more features can
  37. be inserted between the PCA coefficients and the timestamp).
  38. The first line contains the number of dimensions.
  39. Assuming N1 spikes (spike1...spikeN1), N2 electrodes (e1...eN2) and
  40. N3 coefficients (c1...cN3), this file looks like:
  41. nbDimensions
  42. c1_e1_spk1 c2_e1_spk1 ... cN3_e1_spk1 c1_e2_spk1 ... cN3_eN2_spk1 timestamp_spk1
  43. c1_e1_spk2 c2_e1_spk2 ... cN3_e1_spk2 c1_e2_spk2 ... cN3_eN2_spk2 timestamp_spk2
  44. ...
  45. c1_e1_spkN1 c2_e1_spkN1 ... cN3_e1_spkN1 c1_e2_spkN1 ... cN3_eN2_spkN1 timestamp_spkN1
  46. The timestamp is expressed in multiples of the sampling interval. For
  47. instance, for a 20kHz recording (50 microsecond sampling interval), a
  48. timestamp of 200 corresponds to 200x0.000050s=0.01s from the beginning
  49. of the recording session.
  50. Notice that the last line must end with a newline or carriage return.
  51. """
  52. class KlustaKwikIO(BaseIO):
  53. """Reading and writing from KlustaKwik-format files."""
  54. # Class variables demonstrating capabilities of this IO
  55. is_readable = True
  56. is_writable = True
  57. # This IO can only manipulate objects relating to spike times
  58. supported_objects = [Block, SpikeTrain, Unit]
  59. # Keep things simple by always returning a block
  60. readable_objects = [Block]
  61. # And write a block
  62. writeable_objects = [Block]
  63. # Not sure what these do, if anything
  64. has_header = False
  65. is_streameable = False
  66. # GUI params
  67. read_params = {}
  68. # GUI params
  69. write_params = {}
  70. # The IO name and the file extensions it uses
  71. name = 'KlustaKwik'
  72. extensions = ['fet', 'clu', 'res', 'spk']
  73. # Operates on directories
  74. mode = 'file'
  75. def __init__(self, filename, sampling_rate=30000.):
  76. """Create a new IO to operate on a directory
  77. filename : the directory to contain the files
  78. basename : string, basename of KlustaKwik format, or None
  79. sampling_rate : in Hz, necessary because the KlustaKwik files
  80. stores data in samples.
  81. """
  82. if not HAVE_MLAB:
  83. raise MLAB_ERR
  84. BaseIO.__init__(self)
  85. # self.filename = os.path.normpath(filename)
  86. self.filename, self.basename = os.path.split(os.path.abspath(filename))
  87. self.sampling_rate = float(sampling_rate)
  88. # error check
  89. if not os.path.isdir(self.filename):
  90. raise ValueError("filename must be a directory")
  91. # initialize a helper object to parse filenames
  92. self._fp = FilenameParser(dirname=self.filename, basename=self.basename)
  93. def read_block(self, lazy=False):
  94. """Returns a Block containing spike information.
  95. There is no obvious way to infer the segment boundaries from
  96. raw spike times, so for now all spike times are returned in one
  97. big segment. The way around this would be to specify the segment
  98. boundaries, and then change this code to put the spikes in the right
  99. segments.
  100. """
  101. assert not lazy, 'Do not support lazy'
  102. # Create block and segment to hold all the data
  103. block = Block()
  104. # Search data directory for KlustaKwik files.
  105. # If nothing found, return empty block
  106. self._fetfiles = self._fp.read_filenames('fet')
  107. self._clufiles = self._fp.read_filenames('clu')
  108. if len(self._fetfiles) == 0:
  109. return block
  110. # Create a single segment to hold all of the data
  111. seg = Segment(name='seg0', index=0, file_origin=self.filename)
  112. block.segments.append(seg)
  113. # Load spike times from each group and store in a dict, keyed
  114. # by group number
  115. self.spiketrains = dict()
  116. for group in sorted(self._fetfiles.keys()):
  117. # Load spike times
  118. fetfile = self._fetfiles[group]
  119. spks, features = self._load_spike_times(fetfile)
  120. # Load cluster ids or generate
  121. if group in self._clufiles:
  122. clufile = self._clufiles[group]
  123. uids = self._load_unit_id(clufile)
  124. else:
  125. # unclustered data, assume all zeros
  126. uids = np.zeros(spks.shape, dtype=np.int32)
  127. # error check
  128. if len(spks) != len(uids):
  129. raise ValueError("lengths of fet and clu files are different")
  130. # Create Unit for each cluster
  131. unique_unit_ids = np.unique(uids)
  132. for unit_id in sorted(unique_unit_ids):
  133. # Initialize the unit
  134. u = Unit(name=('unit %d from group %d' % (unit_id, group)),
  135. index=unit_id, group=group)
  136. # Initialize a new SpikeTrain for the spikes from this unit
  137. st = SpikeTrain(
  138. times=spks[uids == unit_id] / self.sampling_rate,
  139. units='sec', t_start=0.0,
  140. t_stop=spks.max() / self.sampling_rate,
  141. name=('unit %d from group %d' % (unit_id, group)))
  142. st.annotations['cluster'] = unit_id
  143. st.annotations['group'] = group
  144. # put features in
  145. if len(features) != 0:
  146. st.annotations['waveform_features'] = features
  147. # Link
  148. u.spiketrains.append(st)
  149. seg.spiketrains.append(st)
  150. block.create_many_to_one_relationship()
  151. return block
  152. # Helper hidden functions for reading
  153. def _load_spike_times(self, fetfilename):
  154. """Reads and returns the spike times and features"""
  155. with open(fetfilename, mode='r') as f:
  156. # Number of clustering features is integer on first line
  157. nbFeatures = int(f.readline().strip())
  158. # Each subsequent line consists of nbFeatures values, followed by
  159. # the spike time in samples.
  160. names = ['fet%d' % n for n in range(nbFeatures)]
  161. names.append('spike_time')
  162. # Load into recarray
  163. data = mlab.csv2rec(f, names=names, skiprows=1, delimiter=' ')
  164. # get features
  165. features = np.array([data['fet%d' % n] for n in range(nbFeatures)])
  166. # Return the spike_time column
  167. return data['spike_time'], features.transpose()
  168. def _load_unit_id(self, clufilename):
  169. """Reads and return the cluster ids as int32"""
  170. with open(clufilename, mode='r') as f:
  171. # Number of clusters on this tetrode is integer on first line
  172. nbClusters = int(f.readline().strip())
  173. # Read each cluster name as a string
  174. cluster_names = f.readlines()
  175. # Convert names to integers
  176. # I think the spec requires cluster names to be integers, but
  177. # this code could be modified to support string names which are
  178. # auto-numbered.
  179. try:
  180. cluster_ids = [int(name) for name in cluster_names]
  181. except ValueError:
  182. raise ValueError(
  183. "Could not convert cluster name to integer in %s" % clufilename)
  184. # convert to numpy array and error check
  185. cluster_ids = np.array(cluster_ids, dtype=np.int32)
  186. if len(np.unique(cluster_ids)) != nbClusters:
  187. logging.warning("warning: I got %d clusters instead of %d in %s" % (
  188. len(np.unique(cluster_ids)), nbClusters, clufilename))
  189. return cluster_ids
  190. # writing functions
  191. def write_block(self, block):
  192. """Write spike times and unit ids to disk.
  193. Currently descends hierarchy from block to segment to spiketrain.
  194. Then gets group and cluster information from spiketrain.
  195. Then writes the time and cluster info to the file associated with
  196. that group.
  197. The group and cluster information are extracted from annotations,
  198. eg `sptr.annotations['group']`. If no cluster information exists,
  199. it is assigned to cluster 0.
  200. Note that all segments are essentially combined in
  201. this process, since the KlustaKwik format does not allow for
  202. segment boundaries.
  203. As implemented currently, does not use the `Unit` object at all.
  204. We first try to use the sampling rate of each SpikeTrain, or if this
  205. is not set, we use `self.sampling_rate`.
  206. If the files already exist, backup copies are created by appending
  207. the filenames with a "~".
  208. """
  209. # set basename
  210. if self.basename is None:
  211. logging.warning("warning: no basename provided, using `basename`")
  212. self.basename = 'basename'
  213. # First create file handles for each group which will be stored
  214. self._make_all_file_handles(block)
  215. # We'll detect how many features belong in each group
  216. self._group2features = {}
  217. # Iterate through segments in this block
  218. for seg in block.segments:
  219. # Write each spiketrain of the segment
  220. for st in seg.spiketrains:
  221. # Get file handles for this spiketrain using its group
  222. group = self.st2group(st)
  223. fetfilehandle = self._fetfilehandles[group]
  224. clufilehandle = self._clufilehandles[group]
  225. # Get the id to write to clu file for this spike train
  226. cluster = self.st2cluster(st)
  227. # Choose sampling rate to convert to samples
  228. try:
  229. sr = st.annotations['sampling_rate']
  230. except KeyError:
  231. sr = self.sampling_rate
  232. # Convert to samples
  233. spike_times_in_samples = np.rint(
  234. np.array(st) * sr).astype(np.int)
  235. # Try to get features from spiketrain
  236. try:
  237. all_features = st.annotations['waveform_features']
  238. except KeyError:
  239. # Use empty
  240. all_features = [
  241. [] for _ in range(len(spike_times_in_samples))]
  242. all_features = np.asarray(all_features)
  243. if all_features.ndim != 2:
  244. raise ValueError("waveform features should be 2d array")
  245. # Check number of features we're supposed to have
  246. try:
  247. n_features = self._group2features[group]
  248. except KeyError:
  249. # First time through .. set number of features
  250. n_features = all_features.shape[1]
  251. self._group2features[group] = n_features
  252. # and write to first line of file
  253. fetfilehandle.write("%d\n" % n_features)
  254. if n_features != all_features.shape[1]:
  255. raise ValueError("inconsistent number of features: " +
  256. "supposed to be %d but I got %d" %
  257. (n_features, all_features.shape[1]))
  258. # Write features and time for each spike
  259. for stt, features in zip(spike_times_in_samples, all_features):
  260. # first features
  261. for val in features:
  262. fetfilehandle.write(str(val))
  263. fetfilehandle.write(" ")
  264. # now time
  265. fetfilehandle.write("%d\n" % stt)
  266. # and cluster id
  267. clufilehandle.write("%d\n" % cluster)
  268. # We're done, so close the files
  269. self._close_all_files()
  270. # Helper functions for writing
  271. def st2group(self, st):
  272. # Not sure this is right so make it a method in case we change it
  273. try:
  274. return st.annotations['group']
  275. except KeyError:
  276. return 0
  277. def st2cluster(self, st):
  278. # Not sure this is right so make it a method in case we change it
  279. try:
  280. return st.annotations['cluster']
  281. except KeyError:
  282. return 0
  283. def _make_all_file_handles(self, block):
  284. """Get the tetrode (group) of each neuron (cluster) by descending
  285. the hierarchy through segment and block.
  286. Store in a dict {group_id: list_of_clusters_in_that_group}
  287. """
  288. group2clusters = {}
  289. for seg in block.segments:
  290. for st in seg.spiketrains:
  291. group = self.st2group(st)
  292. cluster = self.st2cluster(st)
  293. if group in group2clusters:
  294. if cluster not in group2clusters[group]:
  295. group2clusters[group].append(cluster)
  296. else:
  297. group2clusters[group] = [cluster]
  298. # Make new file handles for each group
  299. self._fetfilehandles, self._clufilehandles = {}, {}
  300. for group, clusters in group2clusters.items():
  301. self._new_group(group, nbClusters=len(clusters))
  302. def _new_group(self, id_group, nbClusters):
  303. # generate filenames
  304. fetfilename = os.path.join(self.filename,
  305. self.basename + ('.fet.%d' % id_group))
  306. clufilename = os.path.join(self.filename,
  307. self.basename + ('.clu.%d' % id_group))
  308. # back up before overwriting
  309. if os.path.exists(fetfilename):
  310. shutil.copyfile(fetfilename, fetfilename + '~')
  311. if os.path.exists(clufilename):
  312. shutil.copyfile(clufilename, clufilename + '~')
  313. # create file handles
  314. self._fetfilehandles[id_group] = open(fetfilename, mode='w')
  315. self._clufilehandles[id_group] = open(clufilename, mode='w')
  316. # write out first line
  317. # self._fetfilehandles[id_group].write("0\n") # Number of features
  318. self._clufilehandles[id_group].write("%d\n" % nbClusters)
  319. def _close_all_files(self):
  320. for val in self._fetfilehandles.values():
  321. val.close()
  322. for val in self._clufilehandles.values():
  323. val.close()
  324. class FilenameParser:
  325. """Simple class to interpret user's requests into KlustaKwik filenames"""
  326. def __init__(self, dirname, basename=None):
  327. """Initialize a new parser for a directory containing files
  328. dirname: directory containing files
  329. basename: basename in KlustaKwik format spec
  330. If basename is left None, then files with any basename in the directory
  331. will be used. An error is raised if files with multiple basenames
  332. exist in the directory.
  333. """
  334. self.dirname = os.path.normpath(dirname)
  335. self.basename = basename
  336. # error check
  337. if not os.path.isdir(self.dirname):
  338. raise ValueError("filename must be a directory")
  339. def read_filenames(self, typestring='fet'):
  340. """Returns filenames in the data directory matching the type.
  341. Generally, `typestring` is one of the following:
  342. 'fet', 'clu', 'spk', 'res'
  343. Returns a dict {group_number: filename}, e.g.:
  344. { 0: 'basename.fet.0',
  345. 1: 'basename.fet.1',
  346. 2: 'basename.fet.2'}
  347. 'basename' can be any string not containing whitespace.
  348. Only filenames that begin with "basename.typestring." and end with
  349. a sequence of digits are valid. The digits are converted to an integer
  350. and used as the group number.
  351. """
  352. all_filenames = glob.glob(os.path.join(self.dirname, '*'))
  353. # Fill the dict with valid filenames
  354. d = {}
  355. for v in all_filenames:
  356. # Test whether matches format, ie ends with digits
  357. split_fn = os.path.split(v)[1]
  358. m = glob.re.search(('^(\w+)\.%s\.(\d+)$' % typestring), split_fn)
  359. if m is not None:
  360. # get basename from first hit if not specified
  361. if self.basename is None:
  362. self.basename = m.group(1)
  363. # return files with correct basename
  364. if self.basename == m.group(1):
  365. # Key the group number to the filename
  366. # This conversion to int should always work since only
  367. # strings of digits will match the regex
  368. tetn = int(m.group(2))
  369. d[tetn] = v
  370. return d