123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- # -*- coding: utf-8 -*-
- '''
- Created on 03.10.2012
- @author: frank
- '''
- from .chunkfile import ChunkFile, ChunkFileWriter, FileFormat
- import struct
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib.cm as cm
- import logging
class Error(Exception):
    """Root of the exception hierarchy raised by the Connection classes."""
class NumberOfNodesMismatchError(Error):
    """Raised when fewer target nodes are supplied than a connection has.

    :param con_n_target: number of target neurons stored in the connection.
    :param target_nodes: the sequence of target nodes that was supplied
        (shorter than ``con_n_target``).
    """

    def __init__(self, con_n_target, target_nodes):
        # Forward the arguments to the base class so ``self.args`` is
        # populated; without this the exception cannot be unpickled (its
        # __reduce__ would re-create it with no arguments), which breaks e.g.
        # multiprocessing error propagation.
        super(NumberOfNodesMismatchError, self).__init__(con_n_target, target_nodes)
        self.con_n_target = con_n_target
        self.target_nodes = target_nodes

    def __str__(self):
        return ("too few target nodes: connection has %s but only %s are given"
                % (self.con_n_target, len(self.target_nodes)))
def main(file_path="../test/data/simdata_som02_gauss/conForward0weights.dat19",
         target_nr=7690):
    """Demo: show the incoming weights of one target neuron as a grey image.

    :param file_path: objsim weight file to load.
    :param target_nr: index of the target neuron whose incoming weights are
        plotted.
    """
    # The constructor calls load_weight_file(), which already builds the
    # pre/post synapse lookup tables, so no extra calc_pre_post_arrays() here.
    c = Connection(file_path)
    a = c.get_incomming_matrix(target_nr)
    print(a.sum())
    plt.imshow(a, cmap=cm.gray, interpolation="nearest")
    plt.show()
-
class ConnectionHeader(object):
    """Represents the header of an objsim weight file.

    The header holds the number of source and target neurons together with
    their (x, y) layer dimensions. Version 1 headers additionally carry an
    ``xfast`` flag; for version 0 headers ``xfast`` defaults to True.
    """

    # Packed header sizes: six native ints (V0), seven with the xfast flag (V1).
    HEADER_LEN_V0 = 6*4
    HEADER_LEN_V1 = 7*4

    def __init__(self, header_bytes="", n_source=0, ns_x=0, ns_y=0, n_target=0, nt_x=0, nt_y=0, xfast=True):
        """Create a header by parsing *header_bytes* or from explicit fields.

        If *header_bytes* is non-empty it is parsed and the keyword fields are
        ignored.
        """
        if header_bytes:
            self.from_string(header_bytes)
        else:
            self.n_source = n_source
            self.ns_x = ns_x
            self.ns_y = ns_y
            self.n_target = n_target
            self.nt_x = nt_x
            self.nt_y = nt_y
            self.xfast = xfast

    def from_string(self, header_bytes):
        """Populate the fields from a packed header.

        :param header_bytes: packed header, at least HEADER_LEN_V0 bytes long.
            The xfast flag is read only if the length is exactly HEADER_LEN_V1;
            otherwise xfast is set to True.
        :raises Error: if header_bytes is shorter than a V0 header.
        """
        header_len = len(header_bytes)
        if header_len < self.HEADER_LEN_V0:
            raise Error("wrong header length: " + str(header_len))
        (self.n_source, self.ns_x, self.ns_y,
         self.n_target, self.nt_x, self.nt_y) = struct.unpack(
            'iiiiii', header_bytes[0:self.HEADER_LEN_V0])
        if header_len == self.HEADER_LEN_V1:
            # struct.unpack always returns a tuple, even for a single value;
            # unpack it so the flag is compared as an int, not as a tuple
            # (which never equals 1).
            (xfast_int,) = struct.unpack(
                'i', header_bytes[self.HEADER_LEN_V0:self.HEADER_LEN_V1])
            self.xfast = (xfast_int == 1)
        else:
            self.xfast = True

    def bytes(self, with_xfast=True):
        """Return the packed header; V1 layout if *with_xfast*, else V0."""
        fields = [self.n_source, self.ns_x, self.ns_y,
                  self.n_target, self.nt_x, self.nt_y]
        if with_xfast:
            return struct.pack('iiiiiii', *(fields + [self.xfast]))
        return struct.pack('iiiiii', *fields)

    def get_members(self, names):
        """Return ``name=value`` pairs for the given attributes, space-separated."""
        return " ".join(n + "=" + str(getattr(self, n)) for n in names)

    def __repr__(self):
        members = ["n_source", "ns_x", "ns_y", "n_target", "nt_x", "nt_y"]
        return "<ConnectionHeader " + self.get_members(members) + ">"
class Connection(object):
    """Access to objsim weight files (``VecConnection`` chunk files).

    A connection is stored as parallel per-synapse arrays ``source_nr``,
    ``target_nr``, ``delays`` and ``weights`` plus a ConnectionHeader holding
    the source/target population sizes and layer dimensions.
    ``pre_syn_nr[t]`` lists the indices of synapses ending at target ``t`` and
    ``post_syn_nr[s]`` those starting at source ``s``; both are built by
    calc_pre_post_arrays().
    """

    def __init__(self, file_path=None):
        """Create an empty connection, or load *file_path* if given."""
        self.con_header = None
        self.source_nr = None
        self.target_nr = None
        self.weights = None
        self.delays = None
        # Synapse lookup tables; filled in by calc_pre_post_arrays().
        self.pre_syn_nr = None
        self.post_syn_nr = None
        if file_path is not None:
            self.load_weight_file(file_path)

    def load_weight_file(self, file_path):
        """Load an objsim weight file and build the synapse lookup tables."""
        weight_file = ChunkFile(file_path)
        header_bytes = weight_file.read_chunk("VecHeader")
        self.con_header = ConnectionHeader(header_bytes)
        logging.debug(self.con_header)
        self.source_nr = weight_file.read_array("SourceNr", 'i')
        self.target_nr = weight_file.read_array("TargetNr", 'i')
        self.delays = weight_file.read_array("Delays", 'i')
        self.weights = weight_file.read_array("Weights", 'f')
        self.calc_pre_post_arrays()

    @classmethod
    def from_nest_con_list(cls, con_list, source_nodes, target_nodes, source_dim=None, target_dim=None, dt=0.25):
        """Build a Connection from a list of NEST connections.

        :param con_list: NEST connections, as accepted by ``nest.GetStatus``.
        :param source_nodes: NEST gids of the source population.
        :param target_nodes: NEST gids of the target population.
        :param source_dim: [x, y] size of the source layer; defaults to
            [n_source, 1].
        :param target_dim: [x, y] size of the target layer; defaults to
            [n_target, 1].
        :param dt: step size; NEST delays are divided by it and rounded to
            integer steps.
        :returns: a Connection with ``xfast=False`` and its synapse lookup
            tables already built.
        """
        import nest
        connection = cls()
        connection.weights = np.array(nest.GetStatus(con_list, 'weight'), dtype='f')
        # Neuron indices are gid offsets from the first node of each list.
        # NOTE(review): this assumes each node list holds consecutive,
        # ascending gids - confirm with callers.
        connection.source_nr = np.array(nest.GetStatus(con_list, 'source'), dtype='i')
        connection.source_nr -= source_nodes[0]
        connection.target_nr = np.array(nest.GetStatus(con_list, 'target'), dtype='i')
        connection.target_nr -= target_nodes[0]
        delays_ms = np.array(nest.GetStatus(con_list, 'delay'))
        # Round instead of truncating: e.g. 0.7 / 0.1 == 6.999999999999999 in
        # floating point and would otherwise become 6 steps instead of 7.
        connection.delays = np.rint(delays_ms / dt).astype('i')
        n_source = len(source_nodes)
        n_target = len(target_nodes)
        if source_dim is None:
            source_dim = [n_source, 1]
        if target_dim is None:
            target_dim = [n_target, 1]
        connection.con_header = ConnectionHeader(n_source=n_source,
                                                 ns_x=source_dim[0],
                                                 ns_y=source_dim[1],
                                                 n_target=n_target,
                                                 nt_x=target_dim[0],
                                                 nt_y=target_dim[1],
                                                 xfast=False)
        # Build the lookup tables so the query methods (get_incomming_*,
        # calc_nest_con_params) work on the result, as they do after
        # load_weight_file().
        connection.calc_pre_post_arrays()
        return connection

    def save(self, file_path, major=2, minor=1):
        """Write the connection as a ``VecConnection`` chunk file.

        :param major: file format major version.
        :param minor: file format minor version; the header includes the
            xfast flag only when ``minor > 0``.
        """
        file_format = FileFormat('VecConnection', major, minor)
        weight_file = ChunkFileWriter(file_path, file_format)
        try:
            weight_file.write_byte_chunk('VecHeader', self.con_header.bytes(with_xfast=(minor > 0)))
            weight_file.write_int_array_chunk('SourceNr', self.source_nr)
            weight_file.write_int_array_chunk('TargetNr', self.target_nr)
            weight_file.write_int_array_chunk('Delays', self.delays)
            weight_file.write_float_array_chunk('Weights', self.weights)
        finally:
            # Close even when a write fails so the writer is not leaked.
            weight_file.close()

    def get_n_source(self):
        """Return the number of source neurons from the header."""
        return self.con_header.n_source

    def get_n_target(self):
        """Return the number of target neurons from the header."""
        return self.con_header.n_target

    def get_target_nx(self):
        """Return the x size of the target layer from the header."""
        return self.con_header.nt_x

    def get_target_ny(self):
        """Return the y size of the target layer from the header."""
        return self.con_header.nt_y

    def calc_pre_post_arrays(self):
        """Build the ``pre_syn_nr`` / ``post_syn_nr`` synapse lookup tables.

        :raises Error: if ``source_nr`` and ``target_nr`` differ in length.
        """
        if len(self.source_nr) != len(self.target_nr):
            raise Error("source/target array length mismatch: "
                        + str(len(self.source_nr)) + " != "
                        + str(len(self.target_nr)))
        self.pre_syn_nr = [[] for _ in range(self.con_header.n_target)]
        self.post_syn_nr = [[] for _ in range(self.con_header.n_source)]
        for syn, (src, tgt) in enumerate(zip(self.source_nr, self.target_nr)):
            self.pre_syn_nr[tgt].append(syn)
            self.post_syn_nr[src].append(syn)

    def get_incomming_matrix(self, target_nr):
        """Return the summed incoming weights of one target neuron.

        Weights of all synapses onto *target_nr* are accumulated per source
        neuron and returned as an array of shape ``(ns_y, ns_x)``.
        """
        m = np.zeros(self.con_header.n_source)
        for syn_nr in self.pre_syn_nr[target_nr]:
            m[self.source_nr[syn_nr]] += self.weights[syn_nr]
        # numpy indexing order: [z,y,x]. higher to lower dimension from left to right
        return m.reshape(self.con_header.ns_y, self.con_header.ns_x)

    def get_incomming_weight_sum(self):
        """Return the sum of incoming weights of every target neuron.

        The result has shape ``(nt_y, nt_x)``.
        """
        weights = self.weights
        m = np.zeros(self.con_header.n_target)
        for target_nr, synapses in enumerate(self.pre_syn_nr):
            m[target_nr] = sum(weights[s] for s in synapses)
        return m.reshape(self.con_header.nt_y, self.con_header.nt_x)

    def calc_nest_con_params(self, target_nodes):
        """
        Calculate NEST connection parameters to be used with nest.DataConnect

        One parameter dict is produced per source neuron that has at least
        one outgoing synapse, in source order.

        :param target_nodes: list of NEST gid numbers for target nodes.
            Must be at least as long as the number of target nodes of the
            connection object
        :rtype: [{'target': ..., 'weight': ..., 'delay': ...}, ...]
            to be used with nest.DataConnect

        :raises: NumberOfNodesMismatchError when target_nodes has too few items for this connection

        """
        n_target = self.get_n_target()
        if len(target_nodes) < n_target:
            raise NumberOfNodesMismatchError(n_target, target_nodes)
        params = []
        for synapses in self.post_syn_nr:
            if not synapses:
                continue
            params.append({
                'target': [float(target_nodes[self.target_nr[s]]) for s in synapses],
                'weight': [float(self.weights[s]) for s in synapses],
                # Stored delays are integers with a +1.0 offset applied here.
                # NOTE(review): from_nest_con_list divides NEST delays by dt,
                # but no dt factor is applied on export - confirm the unit
                # NEST expects.
                'delay': [float(self.delays[s]) + 1.0 for s in synapses],
            })
        logging.debug("n params = %d", len(params))
        return params
if __name__ == '__main__':
    # Run the plotting demo only when executed as a script, never on import.
    main()
|