simulate_network.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import nest
  2. import time
  3. import numpy as np
  4. import json
  5. import argparse
  6. import sys
  7. import os
  8. from pathlib import Path
  def build_network(config, weights, spikes_ex_output, spikes_in_output):
      """Build the excitatory/inhibitory NEST network described by *config*.

      Resets the NEST kernel, creates ``config['NE']`` excitatory and
      ``config['NI']`` inhibitory neurons, drives all of them from one
      stimulus device, attaches one ASCII spike recorder per population and
      wires the recurrent connections given by *weights*.

      Args:
          config: Parameter dict. Keys read here: ``seed``, ``dt``,
              ``neuron_model``, ``neuron_params``, ``delay`` (``(min, max)``
              of the uniform delay distribution), ``synapse_model``,
              ``J_ex``, ``NE``, ``NI``, ``stimulus``, ``p_rate`` and
              ``parrot_input``.
          weights: Square ``(NE + NI, NE + NI)`` matrix. Row ``i`` holds the
              outgoing weights of node ``i`` and column ``j`` the target node
              ``j`` (excitatory nodes first, then inhibitory). Zero entries
              create no synapse.
          spikes_ex_output: Target path of the excitatory spike file. Its
              directory becomes the NEST ``data_path``. The part of its file
              name before the first ``_`` (plus ``_``) becomes the
              ``data_prefix``, and its suffix becomes the recorders' file
              extension.
          spikes_in_output: Target path of the inhibitory spike file. It is
              not read here and is kept for interface compatibility. Both
              recorders write into the directory of *spikes_ex_output*.

      Returns:
          ``(espikes_path, ispikes_path)``: the paths the two recorders are
          expected to write to, so the caller can move them after
          ``nest.Simulate``.

      Raises:
          ValueError: If *weights* is not an ``(NE + NI, NE + NI)`` matrix.
      """
      # Validate before touching the kernel, so a bad matrix fails fast
      # instead of being silently truncated by zip() further down.
      n_nodes = config['NE'] + config['NI']
      weights = np.asarray(weights)
      if weights.shape != (n_nodes, n_nodes):
          raise ValueError(
              f"weights must have shape ({n_nodes}, {n_nodes}) to match "
              f"NE + NI, got {weights.shape}")

      # Initialising the network
      nest.ResetKernel()
      ex_output = Path(spikes_ex_output)
      data_path = ex_output.parent
      data_prefix = ex_output.name.split('_')[0] + '_'
      file_extension = ex_output.suffix.strip('.')

      np.random.seed(config['seed'])
      # NEST's RNG seed is set once, through the kernel status.
      nest.SetKernelStatus({"resolution": config['dt'],
                            "print_time": True,
                            "overwrite_files": True,
                            "rng_seed": config['seed'],
                            "data_path": str(data_path),
                            "data_prefix": data_prefix})
      nest.SetDefaults(config['neuron_model'], config['neuron_params'])

      # Recurrent delays are drawn from U(delay[0], delay[1]). Links made with
      # the "excitatory" model (stimulus and recorders) use the mean delay.
      delay_dist = nest.random.uniform(min=config['delay'][0],
                                       max=config['delay'][1])
      avg_delay = np.mean(config['delay'])
      nest.CopyModel(config['synapse_model'], "excitatory",
                     {"weight": config['J_ex'], "delay": avg_delay})

      # Creation order is kept stable, because it determines the node IDs
      # that end up in the recorder file names below.
      nodes_ex = nest.Create(config['neuron_model'], config['NE'])
      nodes_in = nest.Create(config['neuron_model'], config['NI'])
      espikes = nest.Create("spike_recorder")
      ispikes = nest.Create("spike_recorder")
      espikes.set(record_to='ascii', label='ex', file_extension=file_extension,
                  precision=2)
      ispikes.set(record_to='ascii', label='in', file_extension=file_extension,
                  precision=2)

      noise = nest.Create(config['stimulus'], params={"rate": config['p_rate']})
      if config['parrot_input']:
          # Route the stimulus through a single parrot neuron, presumably so
          # every target receives the same spike train -- confirm intent.
          stimulus_input = nest.Create('parrot_neuron', 1)
          nest.Connect(noise, stimulus_input, syn_spec="excitatory")
      else:
          stimulus_input = noise

      # Connecting devices
      all_nodes = nodes_ex + nodes_in
      nest.Connect(stimulus_input, all_nodes, syn_spec="excitatory")
      nest.Connect(nodes_ex, espikes, syn_spec="excitatory")
      nest.Connect(nodes_in, ispikes, syn_spec="excitatory")

      # Connecting the recurrent network: one one_to_one call per source
      # node. The ID array is built once, not on every iteration.
      node_ids = np.array(all_nodes.tolist())
      for pre_id, row in zip(node_ids, weights):
          post_idx = np.flatnonzero(row)
          if post_idx.size == 0:
              continue  # no outgoing synapses; skip the empty Connect call
          delays = np.array([delay_dist.GetValue() for _ in post_idx])
          nest.Connect(np.full(post_idx.size, pre_id, dtype=int),
                       node_ids[post_idx],
                       conn_spec='one_to_one',
                       syn_spec={'weight': row[post_idx], 'delay': delays})

      def recorder_file(recorder, label):
          # NOTE(review): assumes NEST names ASCII output
          # "<prefix><label>-<node id>-<vp>.<ext>" without zero padding;
          # confirm this against the NEST version in use.
          return (f'{data_prefix}{label}-{recorder.get("global_id")}'
                  f'-{recorder.get("vp")}.{file_extension}')

      espikes_path = data_path / recorder_file(espikes, 'ex')
      ispikes_path = data_path / recorder_file(ispikes, 'in')
      return espikes_path, ispikes_path
  if __name__ == '__main__':
      # Command-line entry point: load the network config, apply the CLI
      # overrides, build and simulate the network, then move the spike files
      # NEST wrote to --spikes_ex_path / --spikes_in_path.
      CLI = argparse.ArgumentParser()
      CLI.add_argument("--N", nargs='?', type=int)
      CLI.add_argument("--f", nargs='?', type=float)
      CLI.add_argument("--mu", nargs='?', type=float)
      CLI.add_argument("--epsilon", nargs='?', type=float)
      CLI.add_argument("--sigma_ex", nargs='?', type=float)
      CLI.add_argument("--sigma_in", nargs='?', type=float)
      CLI.add_argument("--simtime", nargs='?', type=float)
      CLI.add_argument("--eta", nargs='?', type=float)
      CLI.add_argument("--seed", nargs='?', type=int)
      CLI.add_argument("--spikes_ex_path", nargs='?', type=str)
      CLI.add_argument("--spikes_in_path", nargs='?', type=str)
      CLI.add_argument("--weights_path", nargs='?', type=str)
      CLI.add_argument("--network_config", nargs='?', type=str)
      args, unknown = CLI.parse_known_args()

      # Create the directories of all given paths. The spike files' directory
      # becomes NEST's data_path, which presumably has to exist already.
      # An empty dirname (bare file name) means the current directory.
      for _output_path in (args.weights_path, args.spikes_ex_path,
                           args.spikes_in_path):
          if _output_path and os.path.dirname(_output_path):
              os.makedirs(os.path.dirname(_output_path), exist_ok=True)

      # Loading default configs from the config module next to
      # --network_config. It is imported under the fixed name "config".
      sys.path.append(os.path.dirname(args.network_config))
      import config as network_config
      # Keep only the config module's own names: anything that also exists
      # at this script's top level (dunders, shared imports, ...) is dropped.
      locals_dict = locals()
      config = {k: v for k, v in vars(network_config).items()
                if k not in locals_dict}

      # Updating config with command-line arguments (only the parameters
      # that are eventually expanded). Arguments that were not given are
      # None and leave the config default in place.
      update_params = ['N', 'f', 'mu', 'epsilon', 'sigma_ex', 'sigma_in',
                       'simtime', 'seed', 'eta']
      config.update({k: v for k, v in vars(args).items()
                     if k in update_params and v is not None})
      # p_rate depends on eta, so recompute it from the effective eta.
      # NOTE(review): other derived values in config.py (e.g. NE/NI from N
      # and f) are NOT recomputed here -- confirm this is intended.
      config['p_rate'] = network_config.p_rate_func(config.get('eta'))

      # Building network
      starttime = time.perf_counter()
      espikes_path, ispikes_path = build_network(
          config=config,
          weights=np.load(args.weights_path),
          spikes_ex_output=args.spikes_ex_path,
          spikes_in_output=args.spikes_in_path)
      # Run simulation
      nest.Simulate(config['simtime'])
      endtime = time.perf_counter()
      print("Simulation time : {:.2f} s".format(endtime - starttime))

      # Move NEST's auto-named recorder files to the requested paths.
      # Path.replace overwrites an existing target on every platform.
      espikes_path.replace(args.spikes_ex_path)
      ispikes_path.replace(args.spikes_in_path)