123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- # -*- coding: utf-8 -*-
- """
- Module for reading/writing data from/to legacy PyNN formats.
- PyNN is available at http://neuralensemble.org/PyNN
- Classes:
- PyNNNumpyIO
- PyNNTextIO
- Supported: Read/Write
- Authors: Andrew Davison, Pierre Yger
- """
- from itertools import chain
- import numpy
- import quantities as pq
- import warnings
- from neo.io.baseio import BaseIO
- from neo.core import Segment, AnalogSignal, SpikeTrain
# Python 2/3 detection: the ``unicode`` builtin only exists on Python 2.
try:
    unicode
    PY2 = True
except NameError:
    PY2 = False

# Default units for the standard PyNN recordable variables, used when the
# file metadata carries no explicit 'units' entry.
UNITS_MAP = {
    'spikes': pq.ms,
    'v': pq.mV,
    'gsyn': pq.UnitQuantity('microsiemens', 1e-6*pq.S, 'uS', 'µS'),  # checked
}
class BasePyNNIO(BaseIO):
    """
    Base class for PyNN IO classes.

    Files are assumed to contain a two-column data array of
    (value, channel_index) rows plus a metadata dictionary; subclasses
    implement the actual (de)serialization in `_read_file_contents` and
    `_write_file_contents`.
    """
    is_readable = True
    is_writable = True
    has_header = True
    is_streameable = False  # TODO - correct spelling to "is_streamable"
    supported_objects = [Segment, AnalogSignal, SpikeTrain]
    readable_objects = supported_objects
    writeable_objects = supported_objects
    mode = 'file'

    def _read_file_contents(self):
        """Return (data, metadata) read from self.filename. Implemented by subclasses."""
        raise NotImplementedError

    def _extract_array(self, data, channel_index):
        """Return the values (column 0) of the rows whose channel column (1) equals channel_index."""
        idx = numpy.where(data[:, 1] == channel_index)[0]
        return data[idx, 0]

    def _determine_units(self, metadata):
        """
        Return the units for the recorded variable.

        Prefers an explicit 'units' metadata entry, then the standard PyNN
        units for the recorded variable; raises IOError if neither exists.
        """
        if 'units' in metadata:
            return metadata['units']
        elif 'variable' in metadata and metadata['variable'] in UNITS_MAP:
            return UNITS_MAP[metadata['variable']]
        else:
            raise IOError("Cannot determine units")

    def _extract_signals(self, data, metadata, lazy):
        """
        Build an AnalogSignal (one column per channel) from the raw data
        array, or return None if there is nothing to extract.
        """
        signal = None
        if lazy and data.size > 0:
            signal = AnalogSignal([],
                                  units=self._determine_units(metadata),
                                  sampling_period=metadata['dt']*pq.ms)
            signal.lazy_shape = None
        else:
            # numpy.vstack requires a real sequence of arrays: passing a
            # generator is deprecated and rejected by modern numpy, so
            # build the per-channel rows as a list first.
            arr = numpy.vstack([self._extract_array(data, channel_index)
                                for channel_index in range(metadata['first_index'],
                                                           metadata['last_index'] + 1)])
            if len(arr) > 0:
                signal = AnalogSignal(arr.T,
                                      units=self._determine_units(metadata),
                                      sampling_period=metadata['dt']*pq.ms)
        if signal is not None:
            signal.annotate(label=metadata["label"],
                            variable=metadata["variable"])
        return signal

    def _extract_spikes(self, data, metadata, channel_index, lazy):
        """
        Build a SpikeTrain for one channel from the raw data array, or
        return None if that channel has no spikes.
        """
        spiketrain = None
        if lazy:
            if channel_index in data[:, 1]:
                spiketrain = SpikeTrain([], units=pq.ms, t_stop=0.0)
                spiketrain.lazy_shape = None
        else:
            spike_times = self._extract_array(data, channel_index)
            if len(spike_times) > 0:
                # t_stop is taken as the last spike, since the file stores no
                # recording duration
                spiketrain = SpikeTrain(spike_times, units=pq.ms, t_stop=spike_times.max())
        if spiketrain is not None:
            spiketrain.annotate(label=metadata["label"],
                                channel_index=channel_index,
                                dt=metadata["dt"])
        return spiketrain

    def _write_file_contents(self, data, metadata):
        """Write (data, metadata) to self.filename. Implemented by subclasses."""
        raise NotImplementedError

    def read_segment(self, lazy=False, cascade=True):
        """Read the entire file into a single Segment."""
        data, metadata = self._read_file_contents()
        annotations = {k: metadata.get(k, 'unknown')
                       for k in ("label", "variable", "first_id", "last_id")}
        seg = Segment(**annotations)
        if cascade:
            if metadata['variable'] == 'spikes':
                for i in range(metadata['first_index'], metadata['last_index'] + 1):
                    spiketrain = self._extract_spikes(data, metadata, i, lazy)
                    if spiketrain is not None:
                        seg.spiketrains.append(spiketrain)
                seg.annotate(dt=metadata['dt'])  # store dt for SpikeTrains only, as can be retrieved from sampling_period for AnalogSignal
            else:
                signal = self._extract_signals(data, metadata, lazy)
                if signal is not None:
                    seg.analogsignals.append(signal)
        seg.create_many_to_one_relationship()
        return seg

    def write_segment(self, segment):
        """
        Write a Segment to file: its first AnalogSignal (others are dropped
        with a warning) or all of its SpikeTrains, flattened into the
        two-column (value, channel_index) layout.
        """
        source = segment.analogsignals or segment.spiketrains
        assert len(source) > 0, "Segment contains neither analog signals nor spike trains."
        metadata = segment.annotations.copy()
        s0 = source[0]
        if isinstance(s0, AnalogSignal):
            if len(source) > 1:
                warnings.warn("Cannot handle multiple analog signals. Writing only the first.")
            # iterate the transposed signal so each element is one channel's trace
            source = s0.T
            metadata['size'] = s0.shape[1]
            n = source.size
        else:
            metadata['size'] = len(source)
            n = sum(s.size for s in source)
        metadata['first_index'] = 0
        metadata['last_index'] = metadata['size'] - 1
        if 'label' not in metadata:
            metadata['label'] = 'unknown'
        if 'dt' not in metadata:  # dt not included in annotations if Segment contains only AnalogSignals
            metadata['dt'] = s0.sampling_period.rescale(pq.ms).magnitude
        metadata['n'] = n
        data = numpy.empty((n, 2))
        # if the 'variable' annotation is a standard one from PyNN, we rescale
        # to use standard PyNN units
        # we take the units from the first element of source and scale all
        # the signals to have the same units
        if 'variable' in segment.annotations:
            units = UNITS_MAP.get(segment.annotations['variable'], source[0].dimensionality)
        else:
            units = source[0].dimensionality
            metadata['variable'] = 'unknown'
        try:
            metadata['units'] = units.unicode
        except AttributeError:
            metadata['units'] = units.u_symbol
        start = 0
        for i, signal in enumerate(source):  # here signal may be AnalogSignal or SpikeTrain
            end = start + signal.size
            data[start:end, 0] = numpy.array(signal.rescale(units))
            data[start:end, 1] = i*numpy.ones((signal.size,), dtype=float)
            start = end
        self._write_file_contents(data, metadata)

    def read_analogsignal(self, lazy=False):
        """Read the file's AnalogSignal; raises if the file holds spike data."""
        data, metadata = self._read_file_contents()
        if metadata['variable'] == 'spikes':
            raise TypeError("File contains spike data, not analog signals")
        else:
            signal = self._extract_signals(data, metadata, lazy)
            if signal is None:
                raise IndexError("File does not contain a signal")
            else:
                return signal

    def read_spiketrain(self, lazy=False, channel_index=0):
        """Read the SpikeTrain for one channel; raises if the file holds analog data."""
        data, metadata = self._read_file_contents()
        if metadata['variable'] != 'spikes':
            raise TypeError("File contains analog signals, not spike data")
        else:
            spiketrain = self._extract_spikes(data, metadata, channel_index, lazy)
            if spiketrain is None:
                raise IndexError("File does not contain any spikes with channel index %d" % channel_index)
            else:
                return spiketrain
class PyNNNumpyIO(BasePyNNIO):
    """
    Reads/writes data from/to PyNN NumpyBinaryFile format
    """
    name = "PyNN NumpyBinaryFile"
    extensions = ['npz']

    def _read_file_contents(self):
        """Return (data, metadata) read from the .npz file at self.filename."""
        # use the NpzFile as a context manager so its underlying file handle
        # is closed (numpy.load keeps .npz archives open otherwise)
        with numpy.load(self.filename) as contents:
            data = contents["data"]
            metadata = {}
            for name, value in contents['metadata']:
                try:
                    # NOTE(review): eval of file-supplied metadata — only open
                    # trusted files. Values that fail to evaluate are kept as
                    # plain strings.
                    metadata[name] = eval(value)
                except Exception:
                    metadata[name] = value
        return data, metadata

    def _write_file_contents(self, data, metadata):
        """Write data plus a string-array rendering of metadata to self.filename."""
        # we explicitly set the dtype to ensure roundtrips preserve file contents exactly
        max_metadata_length = max(chain([len(k) for k in metadata.keys()],
                                        [len(str(v)) for v in metadata.values()]))
        if PY2:
            dtype = "S%d" % max_metadata_length
        else:
            dtype = "U%d" % max_metadata_length
        metadata_array = numpy.array(sorted(metadata.items()), dtype)
        numpy.savez(self.filename, data=data, metadata=metadata_array)
class PyNNTextIO(BasePyNNIO):
    """
    Reads/writes data from/to PyNN StandardTextFile format
    """
    name = "PyNN StandardTextFile"
    extensions = ['v', 'ras', 'gsyn']

    def _read_metadata(self):
        """
        Parse the leading '# name = value' header lines of the file into a
        dict; parsing stops at the first non-comment line.
        """
        metadata = {}
        with open(self.filename) as f:
            for line in f:
                # startswith avoids an IndexError on an empty line
                if line.startswith("#"):
                    # split on the first '=' only, in case the value itself
                    # contains '=' (e.g. an expression)
                    name, value = line[1:].strip().split("=", 1)
                    name = name.strip()
                    try:
                        # NOTE(review): eval of file-supplied metadata — only
                        # open trusted files. Values that fail to evaluate are
                        # kept as stripped strings.
                        metadata[name] = eval(value)
                    except Exception:
                        metadata[name] = value.strip()
                else:
                    break
        return metadata

    def _read_file_contents(self):
        """Return (data, metadata): the numeric body and the parsed header."""
        data = numpy.loadtxt(self.filename)
        metadata = self._read_metadata()
        return data, metadata

    def _write_file_contents(self, data, metadata):
        """Write '# name = value' header lines followed by the data table."""
        with open(self.filename, 'wb') as f:
            for item in sorted(metadata.items()):
                f.write(("# %s = %s\n" % item).encode('utf8'))
            numpy.savetxt(f, data)
|