klustakwikio.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. """
  2. Reading and writing from KlustaKwik-format files.
  3. Ref: http://klusters.sourceforge.net/UserManual/data-files.html
  4. Supported : Read, Write
  5. Author : Chris Rodgers
  6. TODO:
  7. * When reading, put the Unit into the RCG, RC hierarchy
  8. * When writing, figure out how to get group and cluster if those annotations
  9. weren't set. Consider removing those annotations if they are redundant.
  10. * Load features in addition to spiketimes.
  11. """
import glob
import logging
import os.path
import re
import shutil

# note neo.core need only numpy and quantities
import numpy as np

try:
    import matplotlib.mlab as mlab
except ImportError as err:
    HAVE_MLAB = False
    MLAB_ERR = err
else:
    HAVE_MLAB = True
    MLAB_ERR = None

# I need to subclass BaseIO
from neo.io.baseio import BaseIO
from neo.core import Block, Segment, Unit, SpikeTrain
  29. # Pasted version of feature file format spec
  30. """
  31. The Feature File
  32. Generic file name: base.fet.n
  33. Format: ASCII, integer values
  34. The feature file lists for each spike the PCA coefficients for each
  35. electrode, followed by the timestamp of the spike (more features can
  36. be inserted between the PCA coefficients and the timestamp).
  37. The first line contains the number of dimensions.
  38. Assuming N1 spikes (spike1...spikeN1), N2 electrodes (e1...eN2) and
  39. N3 coefficients (c1...cN3), this file looks like:
  40. nbDimensions
  41. c1_e1_spk1 c2_e1_spk1 ... cN3_e1_spk1 c1_e2_spk1 ... cN3_eN2_spk1 timestamp_spk1
  42. c1_e1_spk2 c2_e1_spk2 ... cN3_e1_spk2 c1_e2_spk2 ... cN3_eN2_spk2 timestamp_spk2
  43. ...
  44. c1_e1_spkN1 c2_e1_spkN1 ... cN3_e1_spkN1 c1_e2_spkN1 ... cN3_eN2_spkN1 timestamp_spkN1
  45. The timestamp is expressed in multiples of the sampling interval. For
  46. instance, for a 20kHz recording (50 microsecond sampling interval), a
  47. timestamp of 200 corresponds to 200x0.000050s=0.01s from the beginning
  48. of the recording session.
  49. Notice that the last line must end with a newline or carriage return.
  50. """
  51. class KlustaKwikIO(BaseIO):
  52. """Reading and writing from KlustaKwik-format files."""
  53. # Class variables demonstrating capabilities of this IO
  54. is_readable = True
  55. is_writable = True
  56. # This IO can only manipulate objects relating to spike times
  57. supported_objects = [Block, SpikeTrain, Unit]
  58. # Keep things simple by always returning a block
  59. readable_objects = [Block]
  60. # And write a block
  61. writeable_objects = [Block]
  62. # Not sure what these do, if anything
  63. has_header = False
  64. is_streameable = False
  65. # GUI params
  66. read_params = {}
  67. # GUI params
  68. write_params = {}
  69. # The IO name and the file extensions it uses
  70. name = 'KlustaKwik'
  71. extensions = ['fet', 'clu', 'res', 'spk']
  72. # Operates on directories
  73. mode = 'file'
  74. def __init__(self, filename, sampling_rate=30000.):
  75. """Create a new IO to operate on a directory
  76. filename : the directory to contain the files
  77. basename : string, basename of KlustaKwik format, or None
  78. sampling_rate : in Hz, necessary because the KlustaKwik files
  79. stores data in samples.
  80. """
  81. if not HAVE_MLAB:
  82. raise MLAB_ERR
  83. BaseIO.__init__(self)
  84. # self.filename = os.path.normpath(filename)
  85. self.filename, self.basename = os.path.split(os.path.abspath(filename))
  86. self.sampling_rate = float(sampling_rate)
  87. # error check
  88. if not os.path.isdir(self.filename):
  89. raise ValueError("filename must be a directory")
  90. # initialize a helper object to parse filenames
  91. self._fp = FilenameParser(dirname=self.filename, basename=self.basename)
  92. def read_block(self, lazy=False):
  93. """Returns a Block containing spike information.
  94. There is no obvious way to infer the segment boundaries from
  95. raw spike times, so for now all spike times are returned in one
  96. big segment. The way around this would be to specify the segment
  97. boundaries, and then change this code to put the spikes in the right
  98. segments.
  99. """
  100. assert not lazy, 'Do not support lazy'
  101. # Create block and segment to hold all the data
  102. block = Block()
  103. # Search data directory for KlustaKwik files.
  104. # If nothing found, return empty block
  105. self._fetfiles = self._fp.read_filenames('fet')
  106. self._clufiles = self._fp.read_filenames('clu')
  107. if len(self._fetfiles) == 0:
  108. return block
  109. # Create a single segment to hold all of the data
  110. seg = Segment(name='seg0', index=0, file_origin=self.filename)
  111. block.segments.append(seg)
  112. # Load spike times from each group and store in a dict, keyed
  113. # by group number
  114. self.spiketrains = dict()
  115. for group in sorted(self._fetfiles.keys()):
  116. # Load spike times
  117. fetfile = self._fetfiles[group]
  118. spks, features = self._load_spike_times(fetfile)
  119. # Load cluster ids or generate
  120. if group in self._clufiles:
  121. clufile = self._clufiles[group]
  122. uids = self._load_unit_id(clufile)
  123. else:
  124. # unclustered data, assume all zeros
  125. uids = np.zeros(spks.shape, dtype=np.int32)
  126. # error check
  127. if len(spks) != len(uids):
  128. raise ValueError("lengths of fet and clu files are different")
  129. # Create Unit for each cluster
  130. unique_unit_ids = np.unique(uids)
  131. for unit_id in sorted(unique_unit_ids):
  132. # Initialize the unit
  133. u = Unit(name=('unit %d from group %d' % (unit_id, group)),
  134. index=unit_id, group=group)
  135. # Initialize a new SpikeTrain for the spikes from this unit
  136. st = SpikeTrain(
  137. times=spks[uids == unit_id] / self.sampling_rate,
  138. units='sec', t_start=0.0,
  139. t_stop=spks.max() / self.sampling_rate,
  140. name=('unit %d from group %d' % (unit_id, group)))
  141. st.annotations['cluster'] = unit_id
  142. st.annotations['group'] = group
  143. # put features in
  144. if len(features) != 0:
  145. st.annotations['waveform_features'] = features
  146. # Link
  147. u.spiketrains.append(st)
  148. seg.spiketrains.append(st)
  149. block.create_many_to_one_relationship()
  150. return block
  151. # Helper hidden functions for reading
  152. def _load_spike_times(self, fetfilename):
  153. """Reads and returns the spike times and features"""
  154. with open(fetfilename, mode='r') as f:
  155. # Number of clustering features is integer on first line
  156. nbFeatures = int(f.readline().strip())
  157. # Each subsequent line consists of nbFeatures values, followed by
  158. # the spike time in samples.
  159. names = ['fet%d' % n for n in range(nbFeatures)]
  160. names.append('spike_time')
  161. # Load into recarray
  162. data = np.recfromtxt(fetfilename, names=names, skip_header=1, delimiter=' ')
  163. # get features
  164. features = np.array([data['fet%d' % n] for n in range(nbFeatures)])
  165. # Return the spike_time column
  166. return data['spike_time'], features.transpose()
  167. def _load_unit_id(self, clufilename):
  168. """Reads and return the cluster ids as int32"""
  169. with open(clufilename, mode='r') as f:
  170. # Number of clusters on this tetrode is integer on first line
  171. nbClusters = int(f.readline().strip())
  172. # Read each cluster name as a string
  173. cluster_names = f.readlines()
  174. # Convert names to integers
  175. # I think the spec requires cluster names to be integers, but
  176. # this code could be modified to support string names which are
  177. # auto-numbered.
  178. try:
  179. cluster_ids = [int(name) for name in cluster_names]
  180. except ValueError:
  181. raise ValueError(
  182. "Could not convert cluster name to integer in %s" % clufilename)
  183. # convert to numpy array and error check
  184. cluster_ids = np.array(cluster_ids, dtype=np.int32)
  185. if len(np.unique(cluster_ids)) != nbClusters:
  186. logging.warning("warning: I got %d clusters instead of %d in %s" % (
  187. len(np.unique(cluster_ids)), nbClusters, clufilename))
  188. return cluster_ids
  189. # writing functions
  190. def write_block(self, block):
  191. """Write spike times and unit ids to disk.
  192. Currently descends hierarchy from block to segment to spiketrain.
  193. Then gets group and cluster information from spiketrain.
  194. Then writes the time and cluster info to the file associated with
  195. that group.
  196. The group and cluster information are extracted from annotations,
  197. eg `sptr.annotations['group']`. If no cluster information exists,
  198. it is assigned to cluster 0.
  199. Note that all segments are essentially combined in
  200. this process, since the KlustaKwik format does not allow for
  201. segment boundaries.
  202. As implemented currently, does not use the `Unit` object at all.
  203. We first try to use the sampling rate of each SpikeTrain, or if this
  204. is not set, we use `self.sampling_rate`.
  205. If the files already exist, backup copies are created by appending
  206. the filenames with a "~".
  207. """
  208. # set basename
  209. if self.basename is None:
  210. logging.warning("warning: no basename provided, using `basename`")
  211. self.basename = 'basename'
  212. # First create file handles for each group which will be stored
  213. self._make_all_file_handles(block)
  214. # We'll detect how many features belong in each group
  215. self._group2features = {}
  216. # Iterate through segments in this block
  217. for seg in block.segments:
  218. # Write each spiketrain of the segment
  219. for st in seg.spiketrains:
  220. # Get file handles for this spiketrain using its group
  221. group = self.st2group(st)
  222. fetfilehandle = self._fetfilehandles[group]
  223. clufilehandle = self._clufilehandles[group]
  224. # Get the id to write to clu file for this spike train
  225. cluster = self.st2cluster(st)
  226. # Choose sampling rate to convert to samples
  227. try:
  228. sr = st.annotations['sampling_rate']
  229. except KeyError:
  230. sr = self.sampling_rate
  231. # Convert to samples
  232. spike_times_in_samples = np.rint(
  233. np.array(st) * sr).astype(np.int)
  234. # Try to get features from spiketrain
  235. try:
  236. all_features = st.annotations['waveform_features']
  237. except KeyError:
  238. # Use empty
  239. all_features = [
  240. [] for _ in range(len(spike_times_in_samples))]
  241. all_features = np.asarray(all_features)
  242. if all_features.ndim != 2:
  243. raise ValueError("waveform features should be 2d array")
  244. # Check number of features we're supposed to have
  245. try:
  246. n_features = self._group2features[group]
  247. except KeyError:
  248. # First time through .. set number of features
  249. n_features = all_features.shape[1]
  250. self._group2features[group] = n_features
  251. # and write to first line of file
  252. fetfilehandle.write("%d\n" % n_features)
  253. if n_features != all_features.shape[1]:
  254. raise ValueError("inconsistent number of features: " +
  255. "supposed to be %d but I got %d" %
  256. (n_features, all_features.shape[1]))
  257. # Write features and time for each spike
  258. for stt, features in zip(spike_times_in_samples, all_features):
  259. # first features
  260. for val in features:
  261. fetfilehandle.write(str(val))
  262. fetfilehandle.write(" ")
  263. # now time
  264. fetfilehandle.write("%d\n" % stt)
  265. # and cluster id
  266. clufilehandle.write("%d\n" % cluster)
  267. # We're done, so close the files
  268. self._close_all_files()
  269. # Helper functions for writing
  270. def st2group(self, st):
  271. # Not sure this is right so make it a method in case we change it
  272. try:
  273. return st.annotations['group']
  274. except KeyError:
  275. return 0
  276. def st2cluster(self, st):
  277. # Not sure this is right so make it a method in case we change it
  278. try:
  279. return st.annotations['cluster']
  280. except KeyError:
  281. return 0
  282. def _make_all_file_handles(self, block):
  283. """Get the tetrode (group) of each neuron (cluster) by descending
  284. the hierarchy through segment and block.
  285. Store in a dict {group_id: list_of_clusters_in_that_group}
  286. """
  287. group2clusters = {}
  288. for seg in block.segments:
  289. for st in seg.spiketrains:
  290. group = self.st2group(st)
  291. cluster = self.st2cluster(st)
  292. if group in group2clusters:
  293. if cluster not in group2clusters[group]:
  294. group2clusters[group].append(cluster)
  295. else:
  296. group2clusters[group] = [cluster]
  297. # Make new file handles for each group
  298. self._fetfilehandles, self._clufilehandles = {}, {}
  299. for group, clusters in group2clusters.items():
  300. self._new_group(group, nbClusters=len(clusters))
  301. def _new_group(self, id_group, nbClusters):
  302. # generate filenames
  303. fetfilename = os.path.join(self.filename,
  304. self.basename + ('.fet.%d' % id_group))
  305. clufilename = os.path.join(self.filename,
  306. self.basename + ('.clu.%d' % id_group))
  307. # back up before overwriting
  308. if os.path.exists(fetfilename):
  309. shutil.copyfile(fetfilename, fetfilename + '~')
  310. if os.path.exists(clufilename):
  311. shutil.copyfile(clufilename, clufilename + '~')
  312. # create file handles
  313. self._fetfilehandles[id_group] = open(fetfilename, mode='w')
  314. self._clufilehandles[id_group] = open(clufilename, mode='w')
  315. # write out first line
  316. # self._fetfilehandles[id_group].write("0\n") # Number of features
  317. self._clufilehandles[id_group].write("%d\n" % nbClusters)
  318. def _close_all_files(self):
  319. for val in self._fetfilehandles.values():
  320. val.close()
  321. for val in self._clufilehandles.values():
  322. val.close()
  323. class FilenameParser:
  324. """Simple class to interpret user's requests into KlustaKwik filenames"""
  325. def __init__(self, dirname, basename=None):
  326. """Initialize a new parser for a directory containing files
  327. dirname: directory containing files
  328. basename: basename in KlustaKwik format spec
  329. If basename is left None, then files with any basename in the directory
  330. will be used. An error is raised if files with multiple basenames
  331. exist in the directory.
  332. """
  333. self.dirname = os.path.normpath(dirname)
  334. self.basename = basename
  335. # error check
  336. if not os.path.isdir(self.dirname):
  337. raise ValueError("filename must be a directory")
  338. def read_filenames(self, typestring='fet'):
  339. """Returns filenames in the data directory matching the type.
  340. Generally, `typestring` is one of the following:
  341. 'fet', 'clu', 'spk', 'res'
  342. Returns a dict {group_number: filename}, e.g.:
  343. { 0: 'basename.fet.0',
  344. 1: 'basename.fet.1',
  345. 2: 'basename.fet.2'}
  346. 'basename' can be any string not containing whitespace.
  347. Only filenames that begin with "basename.typestring." and end with
  348. a sequence of digits are valid. The digits are converted to an integer
  349. and used as the group number.
  350. """
  351. all_filenames = glob.glob(os.path.join(self.dirname, '*'))
  352. # Fill the dict with valid filenames
  353. d = {}
  354. for v in all_filenames:
  355. # Test whether matches format, ie ends with digits
  356. split_fn = os.path.split(v)[1]
  357. m = glob.re.search((r'^(\w+)\.%s\.(\d+)$' % typestring), split_fn)
  358. if m is not None:
  359. # get basename from first hit if not specified
  360. if self.basename is None:
  361. self.basename = m.group(1)
  362. # return files with correct basename
  363. if self.basename == m.group(1):
  364. # Key the group number to the filename
  365. # This conversion to int should always work since only
  366. # strings of digits will match the regex
  367. tetn = int(m.group(2))
  368. d[tetn] = v
  369. return d