# lif_simulation.py

import os

import numpy as np
import scipy.signal as signal
from joblib import Parallel, delayed

from ..util import LIF, mutual_info, default_job_count


def whitenoise(cflow, cfup, dt, duration, rng=np.random):
    """Band-limited white noise.

    Generates white noise with a flat power spectrum between `cflow` and
    `cfup` Hertz, zero mean and unit standard deviation. Note that, in
    particular for short segments of the generated noise, the mean and
    standard deviation can deviate from zero and one.

    Parameters
    ----------
    cflow: float
        Lower cutoff frequency in Hertz.
    cfup: float
        Upper cutoff frequency in Hertz.
    dt: float
        Time step of the resulting array in seconds.
    duration: float
        Total duration of the resulting array in seconds.
    rng: numpy random number generator, optional
        Random number generator to use. Defaults to np.random.

    Returns
    -------
    noise: 1-D array
        White noise.
    """
    n = int(duration//dt)
    # number of samples rounded up to the next power of two:
    nn = int(2**(np.ceil(np.log2(n))))
    # draw random numbers in Fourier domain:
    inx0 = int(np.round(dt*nn*cflow))
    inx1 = int(np.round(dt*nn*cfup))
    if inx0 < 0:
        inx0 = 0
    if inx1 >= nn/2:
        inx1 = nn/2
    sigma = 0.5 / np.sqrt(float(inx1 - inx0))
    whitef = np.zeros((nn//2+1), dtype=complex)
    if inx0 == 0:
        whitef[0] = rng.randn()
        inx0 = 1
    if inx1 >= nn//2:
        whitef[nn//2] = rng.randn()
        inx1 = nn//2-1
    m = inx1 - inx0 + 1
    whitef[inx0:inx1+1] = rng.randn(m) + 1j*rng.randn(m)
    # inverse FFT:
    noise = np.real(np.fft.irfft(whitef))[:n]*sigma*nn
    return noise
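
# Minimal usage sketch for whitenoise() (hypothetical values, not executed on
# import): 10 s of 0--300 Hz band-limited noise at 0.1 ms resolution, scaled to
# an arbitrary amplitude. Before scaling, the returned array has roughly zero
# mean and unit standard deviation.
#
#     dt = 0.0001                                # s
#     noise = whitenoise(0.0, 300.0, dt, 10.0)   # 1-D array covering 10 s
#     noise *= 0.1                               # scale to the desired amplitude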


def create_noise_stimulus(duration, cutoff, dt, amplitude):
    """Create band-limited Gaussian noise stimulus.

    Parameters
    ----------
    duration : float
        duration of the stimulus in seconds.
    cutoff : float
        the cutoff frequency in Hertz.
    dt : float
        the temporal resolution in seconds.
    amplitude : float
        the amplitude of the stimulus.

    Returns
    -------
    np.ndarray
        the stimulus
    """
    print("generating stimulus...", end="")
    nyquist = 1./dt/2  # Hz
    stimulus = np.random.randn(int(np.round(duration/dt))) * amplitude
    b, a = signal.butter(8, cutoff/nyquist)
    stimulus = signal.filtfilt(b, a, stimulus)
    stimulus[0] = np.mean(stimulus)
    print(" done")
    return stimulus
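
# Usage sketch for create_noise_stimulus() (hypothetical values, not executed
# on import): 2 s of Gaussian noise low-pass filtered at 300 Hz with 0.1 ms
# resolution. The 8th-order Butterworth filter rolls off smoothly rather than
# cutting sharply at the band edges like whitenoise() above.
#
#     stim = create_noise_stimulus(duration=2.0, cutoff=300.0, dt=0.0001,
#                                  amplitude=0.0625)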


def create_lif_responses(num_trials, stimulus, dt, noise=0., tau_m=0.00125):
    """Create a set of model responses to the stimulus.

    Parameters
    ----------
    num_trials : int
        number of responses
    stimulus : np.ndarray
        the stimulus
    dt : float
        the stepsize of the simulation in seconds.
    noise : float, optional
        noise in the model. Defaults to 0.0.
    tau_m : float, optional
        the membrane time-constant. Defaults to 0.00125 s.

    Returns
    -------
    list of np.ndarray
        list of spike responses.
    """
    spikes = []
    lif_model = LIF(stepsize=dt, offset=2.5, tau_m=tau_m, D=noise)
    for i in range(num_trials):
        _, _, spike_times = lif_model.run_stimulus(stimulus)
        spikes.append(spike_times)
    return spikes
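
# Usage sketch for create_lif_responses() (hypothetical values; assumes the LIF
# class imported from ..util):
#
#     dt = 0.0001
#     stim = create_noise_stimulus(2.0, 300.0, dt, 0.0625)
#     spike_trains = create_lif_responses(10, stim, dt, noise=5.0, tau_m=0.0025)
#     # spike_trains is a list with one array of spike times per trial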


def create_population(spike_responses, pop_size, cell_density=0, conduction_velocity=50.):
    """Create a random population from the set of passed spike responses.

    If `cell_density` is larger than 0, the responses are delayed according to
    the conduction delay. The function assumes a constant distance between
    "cells".

    Parameters
    ----------
    spike_responses : list of lists
        Each entry is a list/array of spike times.
    pop_size : int
        the population size.
    cell_density : float, optional
        cell density in m^-1. Defaults to 0, which indicates infinite density,
        i.e. no conduction delays.
    conduction_velocity : float, optional
        The conduction velocity. Defaults to 50 m/s.

    Returns
    -------
    list of np.ndarray
        randomly selected and delayed responses.
    """
    assert conduction_velocity > 0.0
    assert pop_size <= len(spike_responses)
    indices = np.arange(len(spike_responses), dtype=int)
    np.random.shuffle(indices)
    population = []
    for i in range(pop_size):
        spike_times = np.asarray(spike_responses[indices[i]])
        delay = 0.0
        if cell_density > 0:
            # the i-th cell sits i / cell_density meters away from the first one
            delay = i / cell_density / conduction_velocity
        population.append(spike_times + delay)
    return population
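
# Worked example of the conduction delay used above: with the values from
# main() below (cell_density = 2000 per meter, conduction_velocity = 50 m/s),
# adjacent "cells" are 1/2000 m = 0.5 mm apart, so the i-th selected response
# is shifted by i * 0.0005 m / (50 m/s) = i * 10 microseconds.
#
#     population = create_population(spike_trains, pop_size=5,
#                                    cell_density=2000, conduction_velocity=50.0)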


def get_population_information(spike_responses, stimulus, population_size, density=0., cutoff=(0., 500.),
                               kernel_sigma=0.0005, stepsize=0.0001, trial_duration=10., repeats=1,
                               conduction_velocity=50.):
    """Estimate the amount of information carried by the population response.

    Parameters
    ----------
    spike_responses : list of np.ndarray
        list of spike times
    stimulus : np.ndarray
        the stimulus
    population_size : int
        population size
    density : float, optional
        cell density in m^-1. Defaults to 0/m, which indicates infinitely high
        density, i.e. no conduction delays.
    cutoff : tuple of float, optional
        the lower and upper cutoff frequency of the stimulus. Defaults to (0, 500).
    kernel_sigma : float, optional
        std of the Gaussian kernel used for firing rate estimation. Defaults to 0.0005.
    stepsize : float, optional
        Temporal resolution of the firing rate. Defaults to 0.0001.
    trial_duration : float, optional
        trial duration in seconds. Defaults to 10.
    repeats : int, optional
        number of random populations per population size. Defaults to 1.
    conduction_velocity : float, optional
        conduction velocity in m/s. Defaults to 50 m/s.

    Returns
    -------
    np.ndarray
        mutual information between stimulus and population response, one value
        per repeat. Has the shape (repeats,).
    """
    information = np.zeros(repeats)
    for i in range(repeats):
        population_spikes = create_population(spike_responses, population_size, density, conduction_velocity)
        _, _, mi, _, _ = mutual_info(population_spikes, 0.0, [False] * len(population_spikes),
                                     stimulus, cutoff, kernel_sigma, stepsize=stepsize,
                                     trial_duration=trial_duration)
        information[i] = mi[-1]
    return information
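
# Usage sketch for get_population_information() (hypothetical values; relies on
# mutual_info from ..util, whose third return value is used as the information
# estimate):
#
#     info = get_population_information(spike_trains, stim, population_size=5,
#                                       density=2000.0, cutoff=(0.0, 300.0),
#                                       kernel_sigma=0.00125, stepsize=0.0001,
#                                       trial_duration=2.0, repeats=5)
#     # info has shape (repeats,), one estimate per randomly drawn population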


def process_populations(spike_responses, stimulus, population_sizes, density, cutoff=(0., 500.), kernel_sigma=0.0005,
                        stepsize=0.0001, trial_duration=10., repeats=1, conduction_velocity=50., num_cores=1):
    """Estimate the mutual information between stimulus and population response
    for a range of population sizes.

    Parameters
    ----------
    spike_responses : list of np.arrays
        containing the spike times in response to the stimulus presentations
    stimulus : np.array
        the stimulus waveform
    population_sizes : list
        list of population sizes that should be analysed
    density : float
        spatial density of model neurons in m^-1
    cutoff : tuple, optional
        lower and upper cutoff frequencies of the analyzed frequency band, by default (0., 500.)
    kernel_sigma : float, optional
        The sigma of the Gaussian kernel used for firing rate estimation, by default 0.0005
    stepsize : float, optional
        temporal stepsize of the simulation, by default 0.0001
    trial_duration : float, optional
        duration of the trials in seconds, by default 10.
    repeats : int, optional
        number of repetitions, by default 1
    conduction_velocity : float, optional
        simulated axonal conduction velocity, by default 50.
    num_cores : int, optional
        number of spawned parallel processes, by default 1

    Returns
    -------
    np.ndarray
        mutual information estimates with shape [len(population_sizes), repeats]
    """
    processed_list = Parallel(n_jobs=num_cores)(
        delayed(get_population_information)(spike_responses, stimulus, ps, density, cutoff, kernel_sigma,
                                            stepsize, trial_duration, repeats, conduction_velocity)
        for ps in population_sizes)
    information = np.zeros((len(population_sizes), repeats))
    for i, pr in enumerate(processed_list):
        # each worker returns one information estimate per repeat
        information[i, :] = pr
    return information
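
# Usage sketch for process_populations() (hypothetical values): scan a range of
# population sizes in parallel; the result has one row per population size and
# one column per repeat.
#
#     pop_sizes = [1, 5, 10, 20]
#     info = process_populations(spike_trains, stim, pop_sizes, density=0.0,
#                                cutoff=(0.0, 300.0), kernel_sigma=0.00125,
#                                stepsize=0.0001, trial_duration=2.0, repeats=5,
#                                conduction_velocity=50.0, num_cores=4)
#     # info.shape == (len(pop_sizes), 5)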


def generate_responses(num_responses, stimulus, repeats, dt, noise, tau_m, num_cores):
    """Generate spike responses to whitenoise stimuli.

    Parameters
    ----------
    num_responses : int
        total number of responses
    stimulus : np.ndarray
        The stimulus waveform.
    repeats : int
        number of stimulus repetitions simulated per parallel job
    dt : float
        time-step of the simulation
    noise : float
        the noise in the LIF model
    tau_m : float
        membrane time constant
    num_cores : int
        number of parallel processes that should be spawned

    Returns
    -------
    list
        spike responses, i.e. np.arrays of spike times
    """
    # each parallel call simulates `repeats` trials
    num_calls = int(num_responses / repeats)
    processed_list = Parallel(n_jobs=num_cores)(
        delayed(create_lif_responses)(repeats, stimulus, dt, noise, tau_m) for i in range(num_calls))
    spikes = []
    for pr in processed_list:
        spikes.extend(pr)
    return spikes
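
# Usage sketch for generate_responses() (hypothetical values): simulate 300
# responses in chunks of `repeats` trials per worker process.
#
#     spike_trains = generate_responses(300, stim, repeats=5, dt=0.0001,
#                                       noise=5.0, tau_m=0.0025, num_cores=4)
#     # len(spike_trains) == 300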


def main(args=None):
    """Run the simulation."""
    if args is None:
        num_cores = default_job_count()
    else:
        num_cores = args.jobs
    dt = 0.00001  # s, i.e. a sampling rate of 100 kHz
    duration = 2  # s
    lower_cutoffs = [0, 100, 200]  # Hz
    upper_cutoffs = [100, 200, 300]  # Hz
    amplitude = 0.0625
    num_responses = 300
    noise = 5
    tau_m = 0.0025  # s
    repeats = 5
    kernel_sigma = 0.00125  # s
    density = 2000  # m^-1
    vel_efish = 50.  # m/s
    vel_squid = 25.  # m/s
    vel_corp_callosum = 7.0  # m/s

    for lc, uc in zip(lower_cutoffs, upper_cutoffs):
        print("generating stimulus %i -- %i Hz..." % (lc, uc), end="\r")
        stimulus = whitenoise(lc, uc, dt, duration) * amplitude
        print("generating stimulus %i -- %i Hz... done" % (lc, uc))

        print("generating responses... ", end="\r")
        spikes = generate_responses(num_responses, stimulus, repeats, dt, noise, tau_m, num_cores)
        print("generating responses... done")

        population_sizes = list(range(1, int(np.round(2 * num_responses / 3) + 1), 2))
        print("analysing populations, no delay... ", end="\r", flush=True)
        information = process_populations(spikes, stimulus, population_sizes, 0.0, (lc, uc), kernel_sigma,
                                          dt, duration, repeats, vel_efish, num_cores)
        print(r"analysing populations, density: %i m^-1, vel.: %.1f m s^-1 ..." % (density, vel_efish), end="\r", flush=True)
        information_efish = process_populations(spikes, stimulus, population_sizes, density, (lc, uc), kernel_sigma,
                                                dt, duration, repeats, vel_efish, num_cores)
        print(r"analysing populations, density: %i m^-1, vel.: %.1f m s^-1 ..." % (density, vel_squid), end="\r", flush=True)
        information_squid = process_populations(spikes, stimulus, population_sizes, density, (lc, uc), kernel_sigma,
                                                dt, duration, repeats, vel_squid, num_cores)
        print(r"analysing populations, density: %i m^-1, vel.: %.1f m s^-1 ..." % (density, vel_corp_callosum), end="\r", flush=True)
        information_cc = process_populations(spikes, stimulus, population_sizes, density, (lc, uc), kernel_sigma,
                                             dt, duration, repeats, vel_corp_callosum, num_cores)
        print(r"analysing populations ... done" + " " * 80)

        velocities = {"efish": vel_efish, "squid": vel_squid, "corpus callosum": vel_corp_callosum}
        np.savez_compressed(os.path.join("derived_data", f"lif_simulation_{int(lc)}_{int(uc)}_test.npz"),
                            population_sizes=population_sizes,
                            info_no_delays=information, info_delay_efish=information_efish,
                            info_delay_squid=information_squid, info_delay_cc=information_cc,
                            cutoffs=(lc, uc), density=density, velocities=velocities)
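
# A sketch of how the stored results could be read back (file name and keys as
# written by np.savez_compressed above; allow_pickle is needed because the
# velocities dict is stored as an object array):
#
#     data = np.load("derived_data/lif_simulation_0_100_test.npz", allow_pickle=True)
#     pop_sizes = data["population_sizes"]
#     info_no_delays = data["info_no_delays"]   # shape (len(pop_sizes), repeats)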


def command_line_parser(parent_parser):
    parser = parent_parser.add_parser("population", help="Calculate the mutual information as a function of delay and population size.")
    parser.add_argument("-j", "--jobs", type=int, default=default_job_count(),
                        help=f"The number of parallel processes spawned for the analyses (defaults to {default_job_count()})")
    parser.set_defaults(func=main)
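
# Sketch of how this subcommand could be wired into a top-level CLI
# (hypothetical driver script; assumes an argparse parser with subparsers):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     subparsers = parser.add_subparsers()
#     command_line_parser(subparsers)
#     args = parser.parse_args()
#     args.func(args)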