123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- import nest
- import time
- import numpy as np
- import json
- import argparse
- import sys
- import os
- from pathlib import Path
def build_network(config, weights, spikes_ex_output, spikes_in_output):
    """Build the excitatory/inhibitory network in NEST and attach recorders.

    The NEST kernel is reset, seeded and configured so that the two ASCII
    spike recorders write into the directory of ``spikes_ex_output``. The
    file prefix is the part of that file name before the first ``_``, and
    the file extension is taken from its suffix.

    Parameters
    ----------
    config : dict
        Network configuration. Keys read here: ``seed``, ``dt``,
        ``neuron_model``, ``neuron_params``, ``delay`` (``[min, max]``),
        ``synapse_model``, ``J_ex``, ``NE``, ``NI``, ``stimulus``,
        ``p_rate`` and ``parrot_input``.
    weights : array_like
        Square connectivity matrix of shape ``(NE + NI, NE + NI)``.
        Row ``i`` holds the outgoing weights of neuron ``i``, with the
        excitatory neurons first and then the inhibitory ones. Zero entries
        mean "no connection".
    spikes_ex_output : str or Path
        Requested path of the excitatory spike file. Only its directory,
        prefix and suffix are used here.
    spikes_in_output : str or Path
        Requested path of the inhibitory spike file. It is not read by this
        function and is kept for interface compatibility.

    Returns
    -------
    tuple of Path
        Paths of the files the excitatory and inhibitory recorders write,
        e.g. so the caller can move them to their final names.

    Raises
    ------
    ValueError
        If ``weights`` is not an ``(NE + NI, NE + NI)`` matrix.
    """
    # Validate before touching the kernel. A mismatched matrix would
    # otherwise be silently truncated by zip() or fail deep inside the loop.
    weights = np.asarray(weights)
    n_total = config['NE'] + config['NI']
    if weights.shape != (n_total, n_total):
        raise ValueError(
            f"weights must have shape ({n_total}, {n_total}), "
            f"got {weights.shape}")

    # Initialize the kernel.
    nest.ResetKernel()

    out_path = Path(spikes_ex_output)
    data_path = out_path.parent
    data_prefix = out_path.name.split('_')[0] + '_'
    file_extension = out_path.suffix.strip('.')

    # Seed both NumPy and the NEST kernel RNG for reproducible runs.
    np.random.seed(config['seed'])
    nest.SetKernelStatus({"resolution": config['dt'],
                          "print_time": True,
                          "overwrite_files": True,
                          "rng_seed": config['seed'],
                          "data_path": str(data_path),
                          "data_prefix": data_prefix})
    nest.SetDefaults(config['neuron_model'], config['neuron_params'])

    # Recurrent delays are drawn per connection from U(delay[0], delay[1]).
    # The "excitatory" model (used for input and recorders) gets the mean.
    delay_dist = nest.random.uniform(min=config['delay'][0],
                                     max=config['delay'][1])
    avg_delay = np.mean(config['delay'])
    nest.CopyModel(config['synapse_model'], "excitatory",
                   {"weight": config['J_ex'], "delay": avg_delay})

    # Creation order fixes the global IDs, which appear in the file names.
    nodes_ex = nest.Create(config['neuron_model'], config['NE'])
    nodes_in = nest.Create(config['neuron_model'], config['NI'])
    all_nodes = nodes_ex + nodes_in

    espikes = nest.Create("spike_recorder")
    ispikes = nest.Create("spike_recorder")
    for recorder, label in ((espikes, 'ex'), (ispikes, 'in')):
        recorder.set(record_to='ascii', label=label,
                     file_extension=file_extension, precision=2)

    noise = nest.Create(config['stimulus'], params={"rate": config['p_rate']})
    if config['parrot_input']:
        # A single parrot neuron relays the stimulus, so all neurons are
        # driven by the one spike train it emits.
        stimulus_input = nest.Create('parrot_neuron', 1)
        nest.Connect(noise, stimulus_input, syn_spec="excitatory")
    else:
        stimulus_input = noise

    # Connect the devices.
    nest.Connect(stimulus_input, all_nodes, syn_spec="excitatory")
    nest.Connect(nodes_ex, espikes, syn_spec="excitatory")
    nest.Connect(nodes_in, ispikes, syn_spec="excitatory")

    # Recurrent connections: one one_to_one Connect per presynaptic neuron,
    # covering all of its nonzero outgoing weights.
    node_ids = np.array(all_nodes)  # loop-invariant, built once
    for pre, row in zip(all_nodes, weights):
        targets = np.where(row != 0)[0]
        if targets.size == 0:
            continue  # no outgoing connections; skip an empty Connect call
        pre_ids = np.full(len(targets), pre.get('global_id'), dtype=int)
        delays = np.array([delay_dist.GetValue() for _ in targets])
        nest.Connect(pre_ids, node_ids[targets], conn_spec='one_to_one',
                     syn_spec={'weight': row[targets], 'delay': delays})

    def recorder_file(recorder, label):
        # Expected ASCII output name:
        # <prefix><label>-<recorder gid>-<vp>.<ext>
        return data_path / (f'{data_prefix}{label}-{recorder.get("global_id")}'
                            f'-{recorder.get("vp")}.{file_extension}')

    return recorder_file(espikes, 'ex'), recorder_file(ispikes, 'in')
if __name__ == '__main__':
    CLI = argparse.ArgumentParser()
    CLI.add_argument("--N", nargs='?', type=int)
    CLI.add_argument("--f", nargs='?', type=float)
    CLI.add_argument("--mu", nargs='?', type=float)
    CLI.add_argument("--epsilon", nargs='?', type=float)
    CLI.add_argument("--sigma_ex", nargs='?', type=float)
    CLI.add_argument("--sigma_in", nargs='?', type=float)
    CLI.add_argument("--simtime", nargs='?', type=float)
    CLI.add_argument("--eta", nargs='?', type=float)
    CLI.add_argument("--seed", nargs='?', type=int)
    CLI.add_argument("--spikes_ex_path", nargs='?', type=str)
    CLI.add_argument("--spikes_in_path", nargs='?', type=str)
    CLI.add_argument("--weights_path", nargs='?', type=str)
    CLI.add_argument("--network_config", nargs='?', type=str)
    args, unknown = CLI.parse_known_args()

    # Ensure the weights directory exists. An empty dirname means the
    # current directory, for which os.makedirs('') would raise.
    dirname = os.path.dirname(args.weights_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # Load default parameters from the `config` module that sits next to
    # --network_config. Prepend its directory so that module wins over any
    # other `config` already reachable on sys.path.
    sys.path.insert(0, os.path.dirname(args.network_config))
    import config as network_config

    # At module level locals() is the module namespace. Config attributes
    # whose names clash with names already defined in this script (imported
    # modules, dunders, ...) are therefore dropped.
    locals_dict = locals()
    config = {k: v for k, v in vars(network_config).items()
              if k not in locals_dict}

    # Override config with command-line values (the parameters that are
    # expanded per run). Options that were not given are None and keep the
    # config-file defaults.
    if args.eta is not None:
        args.p_rate = network_config.p_rate_func(args.eta)
    update_params = ['N', 'f', 'mu', 'epsilon', 'sigma_ex', 'sigma_in',
                     'simtime', 'seed', 'eta', 'p_rate']
    config.update({k: v for k, v in vars(args).items()
                   if k in update_params and v is not None})

    # Build the network.
    starttime = time.perf_counter()
    espikes_path, ispikes_path = build_network(
        config=config,
        weights=np.load(args.weights_path),
        spikes_ex_output=args.spikes_ex_path,
        spikes_in_output=args.spikes_in_path)

    # Run the simulation.
    nest.Simulate(config['simtime'])
    endtime = time.perf_counter()
    print("Simulation time : {:.2f} s".format(endtime - starttime))

    # Move the recorder files to the requested output paths. Path.replace
    # also overwrites an existing target on Windows, where rename() fails.
    espikes_path.replace(args.spikes_ex_path)
    ispikes_path.replace(args.spikes_in_path)
|