microcircuit_bistability.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import copy
  2. import warnings
  3. import brian2 as br
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. from brian2.units import *
  class MicroCircuitBistability:
      """One Brian2 neuron with a self-connection, used to probe delayed pulses.

      The neuron (``self.neuron``) is driven by a constant current and connected
      to itself through ``self.auto_synapse``. Every spike therefore triggers the
      synapse's ``on_pre`` on the same neuron after a configurable delay. The
      network is stored once after construction and restored before every run.

      NOTE(review): ``interneuron_model``, ``interneuron_parameters`` and
      ``stimulus`` are accepted but not used anywhere in this class yet.
      """

      def __init__(self, exc_neuron_model, interneuron_model, synapse_model, synapse_on_pre,
                   exc_neuron_parameters, interneuron_parameters, synapse_parameters,
                   initial_states, spike_threshold, integration_method, current_drive,
                   stimulus, record_variables):
          """Build the neuron, its self-synapse and the monitors, then store the network.

          Args:
              exc_neuron_model: Brian2 equations string for the neuron. It must
                  define ``v`` (used in the threshold) and ``I`` (set to
                  ``current_drive``).
              interneuron_model: currently unused.
              synapse_model: Brian2 ``Synapses`` model string. It must define
                  ``synaptic_strength``, which ``run_sim`` assigns.
              synapse_on_pre: ``on_pre`` code of the self-synapse.
              exc_neuron_parameters: dict of constants for the neuron equations.
                  It is merged with ``synapse_parameters`` into the run namespace.
              interneuron_parameters: currently unused.
              synapse_parameters: dict applied to the synapse object and also
                  merged into the run namespace.
              initial_states: dict of initial neuron state values.
              spike_threshold: value of the per-neuron ``v_threshold`` (volt).
              integration_method: Brian2 integration method name.
              current_drive: value assigned to the neuron's ``I``.
              stimulus: currently unused.
              record_variables: variable name(s) recorded by the StateMonitor.
          """
          # Extra per-neuron parameter so the threshold can be set from spike_threshold.
          threshold_eqs = "\nv_threshold: volt\n"
          self.neuron = br.NeuronGroup(
              N=1,
              model=exc_neuron_model + threshold_eqs,
              threshold='v > v_threshold',
              # Refractory while above threshold. Brian2 does not emit spikes
              # while refractory, so one crossing yields a single spike.
              refractory='v > v_threshold',
              method=integration_method,
          )
          self.neuron.v_threshold = spike_threshold
          set_parameters_from_dict(self.neuron, initial_states)
          self.neuron.I = current_drive

          self.spike_recorder = br.SpikeMonitor(source=self.neuron)
          self.neuron_state_recorder = br.StateMonitor(self.neuron, record_variables, record=True)

          # Self-connection: each spike of the neuron acts back on the neuron
          # after ``delay``. The delay is reconfigured per run in run_sim.
          self.auto_synapse = br.Synapses(source=self.neuron, target=self.neuron,
                                          model=synapse_model, on_pre=synapse_on_pre,
                                          delay=0.0 * ms)
          self.auto_synapse.connect()
          set_parameters_from_dict(self.auto_synapse, synapse_parameters)

          self.net = br.Network(self.neuron, self.auto_synapse,
                                self.spike_recorder, self.neuron_state_recorder)
          # Snapshot of the initial state; every run_sim call restores it.
          self.net.store()

          # Namespace passed to Network.run to resolve constants in the equations.
          # Deep-copied so updating it never mutates the caller's dict.
          self.network_params = copy.deepcopy(exc_neuron_parameters)
          self.network_params.update(synapse_parameters)

      def measure_response_curve(self, number_of_points, inhibitory_strength,
                                 offset_after_spike_peak=1 * ms,
                                 offset_before_spike_peak=1 * ms):
          """Measure the firing period as a function of the self-pulse delay.

          First runs 200 ms with zero pulse strength to obtain the unperturbed
          period. Then, for ``number_of_points`` delays evenly spaced between
          ``offset_after_spike_peak`` and
          ``period_unperturbed - offset_before_spike_peak``, runs 100 ms with
          pulse strength ``inhibitory_strength`` and records the mean period.

          Returns:
              ``(period_unperturbed, inhibitory_delays, perturbed_periods)``.
          """
          self.run_sim(duration=200 * ms)
          period_unperturbed = get_mean_period(self.spike_recorder.spike_trains()[0])

          inhibitory_delays = np.linspace(offset_after_spike_peak,
                                          period_unperturbed - offset_before_spike_peak,
                                          number_of_points)
          perturbed_periods = []
          for delay in inhibitory_delays:
              self.run_sim(delay, inhibitory_strength, duration=100 * ms)
              perturbed_periods.append(get_mean_period(self.spike_recorder.spike_trains()[0]))
          # np.array strips Brian2 units (values are in SI base units), so
          # seconds are re-attached here.
          perturbed_periods = np.array(perturbed_periods) * second
          return period_unperturbed, inhibitory_delays, perturbed_periods

      def run_sim(self, delay=0.0, pulse_strength=0.0, record_states=False, duration=50 * ms):
          """Restore the stored initial state and run one simulation.

          Args:
              delay: delay of the self-connection for this run.
              pulse_strength: value assigned to the synapse's ``synaptic_strength``.
              record_states: whether the StateMonitor records during this run.
              duration: simulated time.

          Returns:
              ``(spike_recorder, neuron_state_recorder)`` after the run.
          """
          self.net.restore()
          # Re-assigning ``StateMonitor.record`` after construction does not
          # change the recorded indices, and it breaks index lookups on the
          # monitor. Toggling ``active`` is what switches recording on or off.
          self.neuron_state_recorder.active = record_states
          self.auto_synapse.delay = delay
          self.auto_synapse.synaptic_strength = pulse_strength
          self.net.run(duration=duration, namespace=self.network_params)
          return self.spike_recorder, self.neuron_state_recorder
  def get_mean_period(spike_train):
      """Return the mean inter-spike interval of ``spike_train``.

      Computed as ``(max - min) / (n_spikes - 1)``. This equals the mean of the
      consecutive differences of the sorted spike times.

      With fewer than two spikes the period is undefined, so NaN is returned
      instead of failing on an empty array or dividing by zero. The NaN keeps
      the train's units when it is a Brian2 quantity. This matters when a
      perturbation silences the neuron.
      """
      spike_train = np.asanyarray(spike_train)  # keeps Quantity subclasses intact
      n_spikes = spike_train.shape[0]
      if n_spikes < 2:
          # An empty slice sums to a zero carrying the input's units; * NaN keeps them.
          return spike_train[:0].sum() * np.nan
      return (np.max(spike_train) - np.min(spike_train)) / (n_spikes - 1)
  def set_parameters_from_dict(neurongroup, dictionary_of_parameters):
      """Assign each ``name -> value`` pair as an attribute of ``neurongroup``.

      Names the object rejects with ``AttributeError`` are skipped with a
      warning instead of aborting. The warning names ``neurongroup.name`` and
      the offending key, and is attributed to the caller's line.
      """
      for name, value in dictionary_of_parameters.items():
          try:
              setattr(neurongroup, name, value)
          except AttributeError:
              warnings.warn("{:s} has no parameter {:s}".format(neurongroup.name, name),
                            stacklevel=2)
  def analyze_response_curve(period_unperturbed, inhibitory_delay, periods):
      """Plot the period curve, the phase response curve and the iterative map.

      Args:
          period_unperturbed: firing period without the pulse (time quantity).
          inhibitory_delay: array of pulse delays (time quantities).
          periods: firing period measured for each delay (time quantities).

      Three figures are produced:
          1. period vs. delay in ms, with the unperturbed period as a horizontal
             line (shown immediately with ``plt.show()``);
          2. the phase response curve: phase = delay / T0 against
             phase change = (T0 - T) / T0, with a vertical marker at the phase
             of the largest phase change;
          3. the iterative map delta_i -> delta_{i+1} against the identity line.
      """
      plt.plot(inhibitory_delay / ms, periods / ms)
      plt.hlines(period_unperturbed / ms, inhibitory_delay[0] / ms, inhibitory_delay[-1] / ms,
                 label="No inhibition")
      plt.xlabel("Delay spike to inhibition (ms)")
      plt.ylabel("Period (ms)")
      plt.legend()
      plt.show()

      phases = inhibitory_delay / period_unperturbed
      dphases = (period_unperturbed - periods) / period_unperturbed
      # Index of the largest phase change; assumes a single maximum.
      optimal_idx = np.argmax(dphases)

      plt.figure()
      plt.plot(phases, dphases)
      # The x axis is phase, so the marker goes at the optimal phase, not at the
      # optimal delay (a time value).
      plt.vlines(phases[optimal_idx], -1, 1)
      plt.title("Phase response curve")
      plt.ylim(-1, 1)

      plt.figure()
      # Keep the first len - optimal_idx phases so delta_i lines up with
      # dphases[optimal_idx:]. ``phases[:-0]`` would be empty, which is why the
      # slice is written with an explicit stop instead.
      delta_i = phases[:len(phases) - optimal_idx]
      delta_iplus1 = delta_i + 1 + dphases[optimal_idx:] - dphases[optimal_idx]
      plt.plot(delta_i, delta_iplus1)
      plt.plot(delta_i, delta_i, '--')
      plt.hlines(1, 0, 1)
      plt.xlabel("Delta i")
      plt.ylabel("Delta i+1")
      plt.title("Iterative map")
      plt.ylim(0, 1.2)
      plt.xlim(0, 1)
      plt.show()