# -*- coding: utf-8 -*-
'''
Read, write and inspect objsim weight files ("VecConnection" chunk files).

Created on 03.10.2012

@author: frank
'''

import logging
import struct

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

from .chunkfile import ChunkFile, ChunkFileWriter, FileFormat


class Error(Exception):
    """Base class for Connection errors."""
    pass


class NumberOfNodesMismatchError(Error):
    """Raised when fewer target nodes are supplied than the connection has."""

    def __init__(self, con_n_target, target_nodes):
        self.con_n_target = con_n_target
        self.target_nodes = target_nodes

    def __str__(self):
        ret = "too few target nodes: connection has "
        ret += str(self.con_n_target)
        ret += " but only " + str(len(self.target_nodes)) + " are given"
        return ret


def main():
    """Load an example weight file and display one target's incoming weights."""
    c = Connection()
    # load_weight_file() already builds the pre/post synapse index lists.
    c.load_weight_file("../test/data/simdata_som02_gauss/conForward0weights.dat19")
    a = c.get_incomming_matrix(7690)
    print(a.sum())
    plt.imshow(a, cmap=cm.gray, interpolation="nearest")
    plt.show()


class ConnectionHeader(object):
    """Represents the header of an objsim weight file.

    On disk the header is a sequence of ints (struct format 'i', native byte
    order, assumed 4 bytes each): n_source, ns_x, ns_y, n_target, nt_x, nt_y
    (6 ints, HEADER_LEN_V0), optionally followed by an xfast flag
    (7 ints, HEADER_LEN_V1).
    """

    HEADER_LEN_V0 = 6 * 4
    HEADER_LEN_V1 = 7 * 4

    # Attribute names reported by __repr__, in on-disk order.
    _MEMBER_NAMES = ("n_source", "ns_x", "ns_y",
                     "n_target", "nt_x", "nt_y", "xfast")

    def __init__(self, header_bytes="", n_source=0, ns_x=0, ns_y=0,
                 n_target=0, nt_x=0, nt_y=0, xfast=True):
        """Create a header either by parsing bytes or from explicit values.

        :param header_bytes: raw header bytes; if non-empty, all other
            arguments are ignored and the fields are parsed from it
        :param n_source: number of source neurons (laid out as ns_y x ns_x)
        :param n_target: number of target neurons (laid out as nt_y x nt_x)
        :param xfast: flag stored in version-1 headers. NOTE(review): its
            meaning is not visible here; presumably "x is the fastest-varying
            index" — confirm against objsim.
        """
        if header_bytes:
            self.from_string(header_bytes)
        else:
            self.n_source = n_source
            self.ns_x = ns_x
            self.ns_y = ns_y
            self.n_target = n_target
            self.nt_x = nt_x
            self.nt_y = nt_y
            self.xfast = xfast

    def from_string(self, header_bytes):
        """Parse the header fields from raw bytes.

        :raises Error: if header_bytes is shorter than a version-0 header
        """
        header_len = len(header_bytes)
        if header_len < self.HEADER_LEN_V0:
            raise Error("wrong header length: " + str(header_len))
        (self.n_source, self.ns_x, self.ns_y,
         self.n_target, self.nt_x, self.nt_y) = struct.unpack(
            'iiiiii', header_bytes[0:self.HEADER_LEN_V0])
        if header_len == self.HEADER_LEN_V1:
            # struct.unpack always returns a tuple; unpack the single value so
            # the comparison below is against the int, not the 1-tuple.
            (xfast_int,) = struct.unpack(
                'i', header_bytes[self.HEADER_LEN_V0:self.HEADER_LEN_V1])
            self.xfast = (xfast_int == 1)
        else:
            self.xfast = True

    def bytes(self, with_xfast=True):
        """Serialise the header.

        :param with_xfast: append the xfast flag (version-1 layout); otherwise
            write the 6-int version-0 layout
        :rtype: bytes
        """
        fields = [self.n_source, self.ns_x, self.ns_y,
                  self.n_target, self.nt_x, self.nt_y]
        if with_xfast:
            return struct.pack('iiiiiii', *(fields + [int(self.xfast)]))
        return struct.pack('iiiiii', *fields)

    def get_members(self, names):
        """Return "name=value" pairs for the given attributes, comma-separated."""
        return ", ".join(n + "=" + str(getattr(self, n)) for n in names)

    def __repr__(self):
        return "ConnectionHeader(" + self.get_members(self._MEMBER_NAMES) + ")"


class Connection(object):
    """Class to access objsim weight files.

    Synapses are stored as parallel arrays: synapse ``i`` connects source
    neuron ``source_nr[i]`` to target neuron ``target_nr[i]`` with weight
    ``weights[i]`` and delay ``delays[i]``.
    """

    def __init__(self, file_path=None):
        self.con_header = None
        self.source_nr = None
        self.target_nr = None
        self.weights = None
        self.delays = None
        # Per-neuron lists of synapse indices, built by calc_pre_post_arrays():
        # pre_syn_nr[t] = synapses onto target t, post_syn_nr[s] = synapses
        # leaving source s.
        self.pre_syn_nr = None
        self.post_syn_nr = None
        if file_path is not None:
            self.load_weight_file(file_path)

    def load_weight_file(self, file_path):
        """Load an objsim weight file and build the synapse index lists."""
        weight_file = ChunkFile(file_path)
        header_bytes = weight_file.read_chunk("VecHeader")
        self.con_header = ConnectionHeader(header_bytes)
        logging.debug(self.con_header)
        self.source_nr = weight_file.read_array("SourceNr", 'i')
        self.target_nr = weight_file.read_array("TargetNr", 'i')
        self.delays = weight_file.read_array("Delays", 'i')
        self.weights = weight_file.read_array("Weights", 'f')
        self.calc_pre_post_arrays()

    @classmethod
    def from_nest_con_list(cls, con_list, source_nodes, target_nodes,
                           source_dim=None, target_dim=None, dt=0.25):
        """Build a Connection from a NEST connection list.

        Source/target GIDs are converted to 0-based indices relative to the
        first element of source_nodes/target_nodes. Delays (ms) are converted
        to integer steps of size dt.

        :param source_dim: [x, y] layout of the source population
            (default [len(source_nodes), 1])
        :param target_dim: [x, y] layout of the target population
            (default [len(target_nodes), 1])
        """
        import nest
        connection = cls()
        connection.weights = np.array(nest.GetStatus(con_list, 'weight'), dtype='f')
        connection.source_nr = np.array(nest.GetStatus(con_list, 'source'), dtype='i')
        connection.source_nr -= source_nodes[0]
        connection.target_nr = np.array(nest.GetStatus(con_list, 'target'), dtype='i')
        connection.target_nr -= target_nodes[0]
        delays_ms = np.array(nest.GetStatus(con_list, 'delay'))
        # Round rather than truncate: e.g. 0.3 / 0.1 == 2.9999999999999996,
        # which a plain int conversion would turn into 2 steps instead of 3.
        connection.delays = np.array(np.rint(delays_ms / dt), dtype='i')
        n_source = len(source_nodes)
        n_target = len(target_nodes)
        if source_dim is None:
            source_dim = [n_source, 1]
        if target_dim is None:
            target_dim = [n_target, 1]
        connection.con_header = ConnectionHeader(
            n_source=n_source, ns_x=source_dim[0], ns_y=source_dim[1],
            n_target=n_target, nt_x=target_dim[0], nt_y=target_dim[1],
            xfast=False)
        return connection

    def save(self, file_path, major=2, minor=1):
        """Write the connection as a VecConnection chunk file.

        The header includes the xfast flag only when minor > 0.
        """
        file_format = FileFormat('VecConnection', major, minor)
        weight_file = ChunkFileWriter(file_path, file_format)
        try:
            weight_file.write_byte_chunk(
                'VecHeader', self.con_header.bytes(with_xfast=(minor > 0)))
            weight_file.write_int_array_chunk('SourceNr', self.source_nr)
            weight_file.write_int_array_chunk('TargetNr', self.target_nr)
            weight_file.write_int_array_chunk('Delays', self.delays)
            weight_file.write_float_array_chunk('Weights', self.weights)
        finally:
            # Close the writer even if a chunk write fails.
            weight_file.close()

    def get_n_source(self):
        return self.con_header.n_source

    def get_n_target(self):
        return self.con_header.n_target

    def get_target_nx(self):
        return self.con_header.nt_x

    def get_target_ny(self):
        return self.con_header.nt_y

    def calc_pre_post_arrays(self):
        """Build pre_syn_nr / post_syn_nr, the per-neuron synapse index lists."""
        self.pre_syn_nr = [[] for _ in range(self.con_header.n_target)]
        self.post_syn_nr = [[] for _ in range(self.con_header.n_source)]
        target_nr = self.target_nr
        for syn_nr, source in enumerate(self.source_nr):
            self.pre_syn_nr[target_nr[syn_nr]].append(syn_nr)
            self.post_syn_nr[source].append(syn_nr)

    def _ensure_pre_post_arrays(self):
        """Build the synapse index lists on first use (e.g. after from_nest_con_list)."""
        if (getattr(self, "pre_syn_nr", None) is None
                or getattr(self, "post_syn_nr", None) is None):
            self.calc_pre_post_arrays()

    def get_incomming_matrix(self, target_nr):
        """Return the incoming weights of one target as an (ns_y, ns_x) array.

        Weights of multiple synapses from the same source are summed.
        """
        self._ensure_pre_post_arrays()
        m = np.zeros(self.con_header.n_source)
        for syn_nr in self.pre_syn_nr[target_nr]:
            m[self.source_nr[syn_nr]] += self.weights[syn_nr]
        # numpy indexing order: [z,y,x]. higher to lower dimension from left to right
        return m.reshape(self.con_header.ns_y, self.con_header.ns_x)

    def get_incomming_weight_sum(self):
        """Return the sum of incoming weights per target as an (nt_y, nt_x) array."""
        self._ensure_pre_post_arrays()
        m = np.zeros(self.con_header.n_target)
        for target_nr in range(self.con_header.n_target):
            m[target_nr] = sum(self.weights[syn_nr]
                               for syn_nr in self.pre_syn_nr[target_nr])
        return m.reshape(self.con_header.nt_y, self.con_header.nt_x)

    def calc_nest_con_params(self, target_nodes):
        """
        Calculate NEST connection parameters to be used with nest.DataConnect

        :param target_nodes: list of NEST gid numbers for target nodes.
            Must be as long as number of target nodes of the connection object
        :rtype: [{'target': ..., 'weight': ..., 'delay': ...}, ...]
            to be used with nest.DataConnect; one entry per source neuron
            that has outgoing synapses, in source order
        :raises: NumberOfNodesMismatchError when target_nodes has too few
            items for this connection
        """
        n_source = self.get_n_source()
        n_target = self.get_n_target()
        if len(target_nodes) < n_target:
            raise NumberOfNodesMismatchError(n_target, target_nodes)
        self._ensure_pre_post_arrays()
        params = []
        for i in range(n_source):
            synapses = self.post_syn_nr[i]
            if not synapses:
                continue
            target = [float(target_nodes[self.target_nr[s]]) for s in synapses]
            # NOTE(review): stored delays get +1.0 and are not rescaled by dt,
            # whereas from_nest_con_list divides ms by dt — confirm intended.
            delay = [float(self.delays[s]) + 1.0 for s in synapses]
            weight = [float(self.weights[s]) for s in synapses]
            params.append({'target': target, 'weight': weight, 'delay': delay})
        logging.debug("n params = %d", len(params))
        return params


if __name__ == '__main__':
    main()