lif_simulation.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import os
  2. import numpy as np
  3. import multiprocessing
  4. import scipy.signal as signal
  5. from joblib import Parallel, delayed
  6. from ..util import LIF, mutual_info
  7. def whitenoise(cflow, cfup, dt, duration, rng=np.random):
  8. """Band-limited white noise.
  9. Generates white noise with a flat power spectrum between `cflow` and
  10. `cfup` Hertz, zero mean and unit standard deviation. Note, that in
  11. particular for short segments of the generated noise the mean and
  12. standard deviation can deviate from zero and one.
  13. Parameters
  14. ----------
  15. cflow: float
  16. Lower cutoff frequency in Hertz.
  17. cfup: float
  18. Upper cutoff frequency in Hertz.
  19. dt: float
  20. Time step of the resulting array in seconds.
  21. duration: float
  22. Total duration of the resulting array in seconds.
  23. Returns
  24. -------
  25. noise: 1-D array
  26. White noise.
  27. """
  28. # next power of two:
  29. n = int(duration//dt)
  30. nn = int(2**(np.ceil(np.log2(n))))
  31. # draw random numbers in Fourier domain:
  32. inx0 = int(np.round(dt*nn*cflow))
  33. inx1 = int(np.round(dt*nn*cfup))
  34. if inx0 < 0:
  35. inx0 = 0
  36. if inx1 >= nn/2:
  37. inx1 = nn/2
  38. sigma = 0.5 / np.sqrt(float(inx1 - inx0))
  39. whitef = np.zeros((nn//2+1), dtype=complex)
  40. if inx0 == 0:
  41. whitef[0] = rng.randn()
  42. inx0 = 1
  43. if inx1 >= nn//2:
  44. whitef[nn//2] = rng.randn()
  45. inx1 = nn//2-1
  46. m = inx1 - inx0 + 1
  47. whitef[inx0:inx1+1] = rng.randn(m) + 1j*rng.randn(m)
  48. # inverse FFT:
  49. noise = np.real(np.fft.irfft(whitef))[:n]*sigma*nn
  50. return noise
  51. def create_noise_stimulus(duration, cutoff, dt, amplitude):
  52. """Create band-limited Gaussian noise stimulus.
  53. Parameters
  54. ----------
  55. duration : float
  56. duration of the stimulus in seconds.
  57. cutoff : float
  58. the cutoff frequency in Hertz.
  59. dt : float
  60. the temporal resolution in seconds.
  61. amplitude : float
  62. the amplitude of the stimulus.
  63. Returns
  64. -------
  65. np.ndarray
  66. the stimulus
  67. """
  68. print("generating stimulus...", end="")
  69. nyquist = 1./dt/2 # Hz
  70. stimulus = np.random.randn(int(np.round(duration/dt))) * amplitude
  71. b, a = signal.butter(8, cutoff/nyquist)
  72. stimulus = signal.filtfilt(b, a, stimulus)
  73. stimulus[0] = np.mean(stimulus)
  74. print(" done")
  75. return stimulus
  76. def create_lif_responses(num_trials, stimulus, dt, noise=0., tau_m=0.00125):
  77. """Creates a set of model responses to the stimulus.
  78. Parameters
  79. ----------
  80. num_trials : int
  81. number of responses
  82. stimulus : np.ndarray
  83. the stimulus
  84. dt : float
  85. the stepsize of the stimulation. Given in s.
  86. noise : float
  87. noise in the model. Defaults tp 0.0
  88. tau_m : float, optional
  89. the membrane time-constant. Defaults to 0.005 s.
  90. Returns
  91. -------
  92. list of np.ndarray
  93. list of spike respones.
  94. """
  95. spikes = []
  96. lif_model = LIF(stepsize=dt, offset=2.5, tau_m=tau_m, D=noise)
  97. for i in range(num_trials):
  98. _, _ , spike_times = lif_model.run_stimulus(stimulus)
  99. spikes.append(spike_times)
  100. return spikes
  101. def create_population(spike_respones, pop_size, cell_density=0, conduction_velocity=50.):
  102. """
  103. Creates a random population of the spike_responses from the set of passed responses. If cell_density is larger than 0 the responses will be delayed according to the conduction delay.
  104. The function assumes a constant distance between "cells".
  105. Parameters
  106. ----------
  107. spike_respones : list of lists
  108. Each entry is a list/array of spike times
  109. pop_size : int
  110. the population size.
  111. cell_density : float, optional
  112. cell density in m^-1. Defaults to 0, which indicates infinitely, i.e. no conduction delays.
  113. conduction_velocity :float, optional
  114. The conduction velocity. Defaults to 50 m/s
  115. Returns
  116. -------
  117. list of np.ndarray
  118. randomly selected and delayed responses.
  119. """
  120. assert(conduction_velocity > 0.0)
  121. assert(pop_size <= len(spike_respones))
  122. indices = np.arange(len(spike_respones), dtype=int)
  123. np.random.shuffle(indices)
  124. population = []
  125. for i in range(pop_size):
  126. spike_times = spike_respones[indices[i]]
  127. spike_times = np.asarray(spike_times)
  128. delay = 0.0
  129. if cell_density > 0:
  130. delay = i / cell_density / conduction_velocity
  131. population.append(spike_times + delay)
  132. return population
  133. def get_population_information(spike_responses, stimulus, population_size, density=0., cutoff=(0., 500.),
  134. kernel_sigma=0.0005, stepsize=0.0001, trial_duration=10., repeats=1,
  135. conduction_velocity=50.):
  136. """Estimate the amount of information carried by the population response.
  137. Parameters
  138. ----------
  139. spike_responses : list of np.ndarray
  140. list of spike times
  141. stimulus : np.ndarray
  142. the stimulus
  143. population_size : int
  144. population size
  145. density : float, optional
  146. cell density in m^-1. Defaults to 0/m, which indicates infinitely high density, i.e. no conduction delays.
  147. cutoff : tuple of float, optional
  148. the lower and upper cutoff frequency of the stimulus. Defaults to (0, 500)
  149. kernel_sigma : float, optional
  150. std of the Gaussian kernel used for firing rate estimation. Defaults to 0.0005.
  151. stepsize : float, optional
  152. Temporal resolution of the firing rate. Defaults to 0.0001.
  153. trial_duration : float, optional
  154. trial duration in seconds. Defaults to 10.
  155. repeats : int, optional
  156. number of random populations per population size. Defaults to 1.
  157. conduction_velocity : float, optional
  158. conduction velocity in m/s. Defaults to 50 m/s.
  159. Returns
  160. -------
  161. np.ndarray
  162. mutual information between stimulus and population response. Has the shape [num_population_sizes, repeats]
  163. """
  164. information = np.zeros(repeats)
  165. for i in range(repeats):
  166. population_spikes = create_population(spike_responses, population_size, density, conduction_velocity)
  167. _, _, mi, _, _ = mutual_info(population_spikes, 0.0, [False] * len(population_spikes),
  168. stimulus, cutoff, kernel_sigma, stepsize=stepsize, trial_duration=trial_duration)
  169. information[i] = mi[-1]
  170. return information
  171. def process_populations(spike_responses, stimulus, population_sizes, density, cutoff=(0., 500.), kernel_sigma=0.0005,
  172. stepsize=0.0001, trial_duration=10., repeats=1, conduction_velocity=50., num_cores=1):
  173. """_summary_
  174. Parameters
  175. ----------
  176. spike_responses : list of np.arrays
  177. containing the spike times in response to the stimulus presentations
  178. stimulus : np.array
  179. the stimulus waveform
  180. population_sizes : list
  181. list of population sizes that should be analysed
  182. density : float
  183. spatial density of model neurons in m^-1
  184. cutoff : tuple, optional
  185. lower and upper cutoff frequencies of the analyzed frequency bands by default (0., 500.)
  186. kernel_sigma : float, optional
  187. The sigma of the Gaussian kernel used for firing rate estimation, by default 0.0005
  188. stepsize : float, optional
  189. temporal stepsize of the simulation, by default 0.0001
  190. trial_duration : float, optional
  191. duration of the trials in seconds, by default 10.
  192. repeats : int, optional
  193. number of repetitions, by default 1
  194. conduction_velocity : float, optional
  195. simulated axonal conduction velocity, by default 50.
  196. num_cores : int, optional
  197. number of spawned parallel processes, by default 1
  198. Returns
  199. -------
  200. list
  201. the results
  202. """
  203. processed_list = Parallel(n_jobs=num_cores)(delayed(get_population_information)(spike_responses, stimulus, ps, density, cutoff, kernel_sigma, stepsize, trial_duration, repeats, conduction_velocity) for ps in population_sizes)
  204. information = np.zeros((len(population_sizes), repeats))
  205. for i, pr in enumerate(processed_list):
  206. information[i, :] = pr[0]
  207. return information
  208. def generate_responses(num_responses, stimulus, repeats, dt, noise, tau_m, num_cores):
  209. """Generate spike responses to whitenoise stimulli.
  210. Parameters
  211. ----------
  212. num_responses : int
  213. number of responses
  214. stimulus : np.adarray
  215. The stimulus waveform.
  216. repeats : int
  217. number of stimulus repetitions
  218. dt : float
  219. time-step of the simulation
  220. noise : float
  221. the noise in the LIF model
  222. tau_m : float
  223. membrane time constant
  224. num_cores : int
  225. number of parallel processes that should be spawned
  226. Returns
  227. -------
  228. list
  229. spike responses, i.e. np.arrays of spike times
  230. """
  231. num_calls = int(num_responses / 5)
  232. processed_list = Parallel(n_jobs=num_cores)(delayed(create_lif_responses)(repeats, stimulus, dt, noise, tau_m) for i in range(num_calls))
  233. spikes = []
  234. for pr in processed_list:
  235. spikes.extend(pr)
  236. return spikes
  237. def main(args=None):
  238. """Run the simulation.
  239. """
  240. if args is None:
  241. num_cores = int(multiprocessing.cpu_count() / 2)
  242. else:
  243. num_cores = args.jobs
  244. dt = 0.00001 # s sr = 100 kHz
  245. duration = 2 # s
  246. lower_cutoffs = [0, 100, 200] # Hz
  247. upper_cutoffs = [100, 200, 300] # Hz
  248. amplitude = 0.0625
  249. num_responses = 300
  250. noise = 5
  251. tau_m = 0.0025
  252. repeats = 5
  253. kernel_sigma = 0.00125 # s
  254. density = 2000
  255. vel_efish = 50.
  256. vel_squid = 25.
  257. vel_corp_callosum = 7.0
  258. for lc, uc in zip(lower_cutoffs, upper_cutoffs):
  259. print("generating stimulus %i -- %i Hz..." % (lc, uc), end="\r")
  260. stimulus = whitenoise(lc, uc, dt, duration) * amplitude
  261. print("generating stimulus %i -- %i Hz... done" % (lc, uc))
  262. print("generating responses... ", end="\r")
  263. spikes = generate_responses(num_responses, stimulus, repeats, dt, noise, tau_m, num_cores)
  264. print("generating responses... done")
  265. population_sizes = list(range(1, int(np.round(2 * num_responses / 3) + 1), 2))
  266. print("analysing populations, no delay... ", end="\r", flush=True)
  267. information = process_populations(spikes, stimulus, population_sizes, 0.0, (lc, uc), kernel_sigma, dt, duration, repeats, vel_efish, num_cores)
  268. print(r"analysing populations, density: %i m^-1, vel.: %.1f m s^-1 ..." % (density, vel_efish), end="\r", flush=True)
  269. information_efish = process_populations(spikes, stimulus, population_sizes, density, (lc, uc), kernel_sigma, dt, duration, repeats, vel_efish, num_cores)
  270. print(r"analysing populations, density: %i m^-1, vel.: %.1f m s^-1 ..." % (density, vel_squid), end="\r", flush=True)
  271. information_squid = process_populations(spikes, stimulus, population_sizes, density, (lc, uc), kernel_sigma, dt, duration, repeats, vel_squid, num_cores)
  272. print(r"analysing populations, density: %i m^-1, vel.: %.1f m s^-1 ..." % (density, vel_corp_callosum), end="\r", flush=True)
  273. information_cc = process_populations(spikes, stimulus, population_sizes, density, (lc, uc), kernel_sigma, dt, duration, repeats, vel_corp_callosum, num_cores)
  274. print(r"analysing populations ... done" + " " * 80)
  275. velocities = {"efish": vel_efish, "squid": vel_squid, "corpus callosum": vel_corp_callosum}
  276. np.savez_compressed(os.path.join("derived_data", f"lif_simulation_{int(lc)}_{int(uc)}_test.npz"), population_sizes=population_sizes,
  277. info_no_delays=information, info_delay_efish=information_efish,
  278. info_delay_squid=information_squid, info_delay_cc=information_cc,
  279. cutoffs=(lc, uc), density=density, velocities=velocities)