1
0

utils.py 4.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import numpy as np
  2. import quantities as pq
  3. import neo
  4. import elephant.spike_train_generation as stg
  5. import matplotlib.pyplot as plt
  6. def generate_spike_trains_with_coinc(lambda_b, lambda_c, trial_duration, coinc_duration, num_trials,
  7. num_units=2, unit_ids_sync=(1,2), RateJitter=0*pq.Hz):
  8. """
  9. generate stationary poisson spiketrains with injected coincidences
  10. """
  11. if not isinstance(unit_ids_sync[0], (list, tuple, np.ndarray)):
  12. unit_ids_sync = [unit_ids_sync]
  13. sync_count = np.zeros(num_units, dtype=int)
  14. for unit_ids in unit_ids_sync:
  15. for unit_id in unit_ids:
  16. sync_count[unit_id - 1] += 1
  17. t_coinc_start = (trial_duration - coinc_duration) / 2.
  18. t_coinc_stop = t_coinc_start + coinc_duration
  19. rate_jitters = (np.random.rand(num_trials) - 0.5) * RateJitter
  20. spike_trains = []
  21. for i_trial in range(num_trials):
  22. rate = lambda_b + rate_jitters[i_trial]
  23. # spiketrains of non-coincident spikes
  24. sp_pre_coinc = [stg.homogeneous_poisson_process(rate, 0.*pq.ms, t_coinc_start) for _ in range(num_units)]
  25. sp_coinc = [stg.homogeneous_poisson_process(rate - lambda_c*sync_count[i], t_coinc_start, t_coinc_stop) for i in range(num_units)]
  26. sp_post_coinc = [stg.homogeneous_poisson_process(rate, t_coinc_stop, trial_duration) for _ in range(num_units)]
  27. # spiketrains of coincident spikes (one spiketrain per synchronous subset of units)
  28. coinc = [stg.homogeneous_poisson_process(lambda_c, t_coinc_start, t_coinc_stop) for _ in unit_ids_sync]
  29. sts = []
  30. for i_unit in range(num_units):
  31. # collect all spike times of a unit
  32. spike_times = [sp_pre_coinc[i_unit].times.rescale('ms').magnitude,
  33. sp_coinc[i_unit].times.rescale('ms').magnitude,
  34. sp_post_coinc[i_unit].times.rescale('ms').magnitude]
  35. for i, unit_ids in enumerate(unit_ids_sync):
  36. if i_unit + 1 in unit_ids:
  37. spike_times.append(coinc[i].times.rescale('ms').magnitude)
  38. # concatenate the collected spike times and sort them
  39. spike_times = np.sort(np.concatenate(spike_times))
  40. sts.append(neo.SpikeTrain(spike_times*pq.ms, t_start=0.*pq.ms, t_stop=trial_duration))
  41. spike_trains.append(sts)
  42. return spike_trains
  43. def generate_spike_trains_with_osc_coinc(num_trials, num_units, trial_duration, freq_coinc, amp_coinc, offset_coinc,
  44. freq_bg, amp_bg, offset_bg, RateJitter=10*pq.Hz):
  45. """
  46. generate non-stationary poisson spiketrains with oscillatory rate modulation and injected coincidences
  47. """
  48. dt = 1 * pq.ms
  49. times = np.arange(0, trial_duration.rescale('s').magnitude, dt.rescale('s').magnitude)
  50. pi2 = np.pi * 2
  51. # modulatory coincidence rate
  52. phases_coinc = pi2 * freq_coinc.rescale('Hz').magnitude * times
  53. rate_coinc = (offset_coinc + amp_coinc * np.sin(phases_coinc)).rescale('Hz').magnitude
  54. rate_coinc[rate_coinc < 0] = 0
  55. # background rate
  56. phases_bg = pi2 * freq_bg.rescale('Hz').magnitude * times
  57. rate_bg = (offset_bg + amp_bg * np.sin(phases_bg)).rescale('Hz').magnitude
  58. rate_bg[rate_bg < 0] = 0
  59. # inhomogenious rate across trials
  60. rate_jitters = (np.random.rand(num_trials) - 0.5) * RateJitter
  61. spiketrain = []
  62. for i in range(num_trials):
  63. rate_signal_bg = neo.AnalogSignal(rate_bg + rate_jitters[i].magnitude, sampling_period=dt, units=pq.Hz, t_start=0*pq.ms)
  64. rate_signal_coinc = neo.AnalogSignal(rate_coinc, sampling_period=dt, units=pq.Hz, t_start=0*pq.ms)
  65. sts_bg = [stg.inhomogeneous_poisson_process(rate_signal_bg) for _ in range(num_units)]
  66. # inserting coincidences
  67. sts_coinc = stg.inhomogeneous_poisson_process(rate_signal_coinc)
  68. sts_bg_coinc = []
  69. for st_bg in sts_bg:
  70. spike_times = np.sort(np.append(st_bg.times.magnitude, sts_coinc.times.magnitude))
  71. st_bg_coinc = neo.SpikeTrain(spike_times, units=st_bg.units, t_start=st_bg.t_start, t_stop=st_bg.t_stop)
  72. sts_bg_coinc.append(st_bg_coinc)
  73. spiketrain.append(sts_bg_coinc)
  74. return {'st':spiketrain, 'background_rate':rate_bg, 'coinc_rate':rate_coinc}