pynnio.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
# -*- coding: utf-8 -*-
"""
Module for reading/writing data from/to legacy PyNN formats.

PyNN is available at http://neuralensemble.org/PyNN

Classes:

PyNNNumpyIO
PyNNTextIO

Supported: Read/Write

Authors: Andrew Davison, Pierre Yger
"""
  11. from itertools import chain
  12. import numpy
  13. import quantities as pq
  14. import warnings
  15. from neo.io.baseio import BaseIO
  16. from neo.core import Segment, AnalogSignal, SpikeTrain
  17. try:
  18. unicode
  19. PY2 = True
  20. except NameError:
  21. PY2 = False
  22. UNITS_MAP = {
  23. 'spikes': pq.ms,
  24. 'v': pq.mV,
  25. 'gsyn': pq.UnitQuantity('microsiemens', 1e-6*pq.S, 'uS', 'µS'), # checked
  26. }
  27. class BasePyNNIO(BaseIO):
  28. """
  29. Base class for PyNN IO classes
  30. """
  31. is_readable = True
  32. is_writable = True
  33. has_header = True
  34. is_streameable = False # TODO - correct spelling to "is_streamable"
  35. supported_objects = [Segment, AnalogSignal, SpikeTrain]
  36. readable_objects = supported_objects
  37. writeable_objects = supported_objects
  38. mode = 'file'
  39. def _read_file_contents(self):
  40. raise NotImplementedError
  41. def _extract_array(self, data, channel_index):
  42. idx = numpy.where(data[:, 1] == channel_index)[0]
  43. return data[idx, 0]
  44. def _determine_units(self, metadata):
  45. if 'units' in metadata:
  46. return metadata['units']
  47. elif 'variable' in metadata and metadata['variable'] in UNITS_MAP:
  48. return UNITS_MAP[metadata['variable']]
  49. else:
  50. raise IOError("Cannot determine units")
  51. def _extract_signals(self, data, metadata, lazy):
  52. signal = None
  53. if lazy and data.size > 0:
  54. signal = AnalogSignal([],
  55. units=self._determine_units(metadata),
  56. sampling_period=metadata['dt']*pq.ms)
  57. signal.lazy_shape = None
  58. else:
  59. arr = numpy.vstack(self._extract_array(data, channel_index)
  60. for channel_index in range(metadata['first_index'], metadata['last_index'] + 1))
  61. if len(arr) > 0:
  62. signal = AnalogSignal(arr.T,
  63. units=self._determine_units(metadata),
  64. sampling_period=metadata['dt']*pq.ms)
  65. if signal is not None:
  66. signal.annotate(label=metadata["label"],
  67. variable=metadata["variable"])
  68. return signal
  69. def _extract_spikes(self, data, metadata, channel_index, lazy):
  70. spiketrain = None
  71. if lazy:
  72. if channel_index in data[:, 1]:
  73. spiketrain = SpikeTrain([], units=pq.ms, t_stop=0.0)
  74. spiketrain.lazy_shape = None
  75. else:
  76. spike_times = self._extract_array(data, channel_index)
  77. if len(spike_times) > 0:
  78. spiketrain = SpikeTrain(spike_times, units=pq.ms, t_stop=spike_times.max())
  79. if spiketrain is not None:
  80. spiketrain.annotate(label=metadata["label"],
  81. channel_index=channel_index,
  82. dt=metadata["dt"])
  83. return spiketrain
  84. def _write_file_contents(self, data, metadata):
  85. raise NotImplementedError
  86. def read_segment(self, lazy=False, cascade=True):
  87. data, metadata = self._read_file_contents()
  88. annotations = dict((k, metadata.get(k, 'unknown')) for k in ("label", "variable", "first_id", "last_id"))
  89. seg = Segment(**annotations)
  90. if cascade:
  91. if metadata['variable'] == 'spikes':
  92. for i in range(metadata['first_index'], metadata['last_index'] + 1):
  93. spiketrain = self._extract_spikes(data, metadata, i, lazy)
  94. if spiketrain is not None:
  95. seg.spiketrains.append(spiketrain)
  96. seg.annotate(dt=metadata['dt']) # store dt for SpikeTrains only, as can be retrieved from sampling_period for AnalogSignal
  97. else:
  98. signal = self._extract_signals(data, metadata, lazy)
  99. if signal is not None:
  100. seg.analogsignals.append(signal)
  101. seg.create_many_to_one_relationship()
  102. return seg
  103. def write_segment(self, segment):
  104. source = segment.analogsignals or segment.spiketrains
  105. assert len(source) > 0, "Segment contains neither analog signals nor spike trains."
  106. metadata = segment.annotations.copy()
  107. s0 = source[0]
  108. if isinstance(s0, AnalogSignal):
  109. if len(source) > 1:
  110. warnings.warn("Cannot handle multiple analog signals. Writing only the first.")
  111. source = s0.T
  112. metadata['size'] = s0.shape[1]
  113. n = source.size
  114. else:
  115. metadata['size'] = len(source)
  116. n = sum(s.size for s in source)
  117. metadata['first_index'] = 0
  118. metadata['last_index'] = metadata['size'] - 1
  119. if 'label' not in metadata:
  120. metadata['label'] = 'unknown'
  121. if 'dt' not in metadata: # dt not included in annotations if Segment contains only AnalogSignals
  122. metadata['dt'] = s0.sampling_period.rescale(pq.ms).magnitude
  123. metadata['n'] = n
  124. data = numpy.empty((n, 2))
  125. # if the 'variable' annotation is a standard one from PyNN, we rescale
  126. # to use standard PyNN units
  127. # we take the units from the first element of source and scale all
  128. # the signals to have the same units
  129. if 'variable' in segment.annotations:
  130. units = UNITS_MAP.get(segment.annotations['variable'], source[0].dimensionality)
  131. else:
  132. units = source[0].dimensionality
  133. metadata['variable'] = 'unknown'
  134. try:
  135. metadata['units'] = units.unicode
  136. except AttributeError:
  137. metadata['units'] = units.u_symbol
  138. start = 0
  139. for i, signal in enumerate(source): # here signal may be AnalogSignal or SpikeTrain
  140. end = start + signal.size
  141. data[start:end, 0] = numpy.array(signal.rescale(units))
  142. data[start:end, 1] = i*numpy.ones((signal.size,), dtype=float)
  143. start = end
  144. self._write_file_contents(data, metadata)
  145. def read_analogsignal(self, lazy=False):
  146. data, metadata = self._read_file_contents()
  147. if metadata['variable'] == 'spikes':
  148. raise TypeError("File contains spike data, not analog signals")
  149. else:
  150. signal = self._extract_signals(data, metadata, lazy)
  151. if signal is None:
  152. raise IndexError("File does not contain a signal")
  153. else:
  154. return signal
  155. def read_spiketrain(self, lazy=False, channel_index=0):
  156. data, metadata = self._read_file_contents()
  157. if metadata['variable'] != 'spikes':
  158. raise TypeError("File contains analog signals, not spike data")
  159. else:
  160. spiketrain = self._extract_spikes(data, metadata, channel_index, lazy)
  161. if spiketrain is None:
  162. raise IndexError("File does not contain any spikes with channel index %d" % channel_index)
  163. else:
  164. return spiketrain
  165. class PyNNNumpyIO(BasePyNNIO):
  166. """
  167. Reads/writes data from/to PyNN NumpyBinaryFile format
  168. """
  169. name = "PyNN NumpyBinaryFile"
  170. extensions = ['npz']
  171. def _read_file_contents(self):
  172. contents = numpy.load(self.filename)
  173. data = contents["data"]
  174. metadata = {}
  175. for name,value in contents['metadata']:
  176. try:
  177. metadata[name] = eval(value)
  178. except Exception:
  179. metadata[name] = value
  180. return data, metadata
  181. def _write_file_contents(self, data, metadata):
  182. # we explicitly set the dtype to ensure roundtrips preserve file contents exactly
  183. max_metadata_length = max(chain([len(k) for k in metadata.keys()],
  184. [len(str(v)) for v in metadata.values()]))
  185. if PY2:
  186. dtype = "S%d" % max_metadata_length
  187. else:
  188. dtype = "U%d" % max_metadata_length
  189. metadata_array = numpy.array(sorted(metadata.items()), dtype)
  190. numpy.savez(self.filename, data=data, metadata=metadata_array)
  191. class PyNNTextIO(BasePyNNIO):
  192. """
  193. Reads/writes data from/to PyNN StandardTextFile format
  194. """
  195. name = "PyNN StandardTextFile"
  196. extensions = ['v', 'ras', 'gsyn']
  197. def _read_metadata(self):
  198. metadata = {}
  199. with open(self.filename) as f:
  200. for line in f:
  201. if line[0] == "#":
  202. name, value = line[1:].strip().split("=")
  203. name = name.strip()
  204. try:
  205. metadata[name] = eval(value)
  206. except Exception:
  207. metadata[name] = value.strip()
  208. else:
  209. break
  210. return metadata
  211. def _read_file_contents(self):
  212. data = numpy.loadtxt(self.filename)
  213. metadata = self._read_metadata()
  214. return data, metadata
  215. def _write_file_contents(self, data, metadata):
  216. with open(self.filename, 'wb') as f:
  217. for item in sorted(metadata.items()):
  218. f.write(("# %s = %s\n" % item).encode('utf8'))
  219. numpy.savetxt(f, data)