123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import copy
- import warnings
- import brian2 as br
- import matplotlib.pyplot as plt
- import numpy as np
- from brian2.units import *
class MicroCircuitBistability:
    """Single neuron with an auto-synapse, used to measure its response to timed inhibition.

    The network consists of one neuron built from ``exc_neuron_model`` (plus a
    ``v_threshold`` parameter), a synapse from that neuron onto itself, a
    ``SpikeMonitor`` and a ``StateMonitor``. Its state is stored at the end of
    ``__init__`` and restored before every simulation, so runs are independent.

    NOTE(review): ``interneuron_model``, ``interneuron_parameters`` and
    ``stimulus`` are accepted but not used yet -- only the excitatory neuron is
    simulated.
    """

    def __init__(self, exc_neuron_model, interneuron_model, synapse_model, synapse_on_pre, exc_neuron_parameters,
                 interneuron_parameters, synapse_parameters, initial_states, spike_threshold, integration_method,
                 current_drive, stimulus, record_variables):
        """Build the network and store its initial state.

        Args:
            exc_neuron_model: Brian2 equation string of the simulated neuron; ``v`` and ``I``
                are accessed below, so presumably both are defined there -- confirm.
            interneuron_model: currently unused.
            synapse_model: Brian2 model string of the auto-synapse; ``run_sim`` assigns its
                ``synaptic_strength``.
            synapse_on_pre: statement the synapse executes on each presynaptic spike.
            exc_neuron_parameters: dict of names used by the neuron equations; merged with
                ``synapse_parameters`` into the namespace passed to every run.
            interneuron_parameters: currently unused.
            synapse_parameters: dict assigned as synapse attributes and added to the run namespace.
            initial_states: dict of neuron attributes set before the state is stored.
            spike_threshold: value of the neuron's ``v_threshold`` parameter (declared in volt).
            integration_method: Brian2 integration method for the neuron equations.
            current_drive: value assigned to the neuron's ``I``.
            stimulus: currently unused.
            record_variables: variable name(s) recorded by the ``StateMonitor``.
        """
        # The network below contains a single neuron: the excitatory one.
        neuron_model = exc_neuron_model
        neuron_parameters = exc_neuron_parameters

        threshold_eqs = "\nv_threshold: volt\n"
        self.neuron = br.NeuronGroup(
            N=1,
            model=neuron_model + threshold_eqs,
            threshold='v > v_threshold',
            # Refractory for as long as v stays above threshold, so a crossing is counted once.
            refractory='v > v_threshold',
            method=integration_method,
        )
        self.neuron.v_threshold = spike_threshold
        set_parameters_from_dict(self.neuron, initial_states)
        self.neuron.I = current_drive

        self.spike_recorder = br.SpikeMonitor(source=self.neuron)
        self.neuron_state_recorder = br.StateMonitor(self.neuron, record_variables, record=True)

        # The neuron projects onto itself, so every spike triggers one synaptic event
        # after the (run-specific) delay.
        self.auto_synapse = br.Synapses(source=self.neuron, target=self.neuron, model=synapse_model,
                                        on_pre=synapse_on_pre,
                                        delay=0.0 * ms)
        self.auto_synapse.connect()
        set_parameters_from_dict(self.auto_synapse, synapse_parameters)

        self.net = br.Network(self.neuron, self.auto_synapse, self.spike_recorder, self.neuron_state_recorder)
        # Snapshot restored at the start of every run_sim call.
        self.net.store()

        self.network_params = copy.deepcopy(neuron_parameters)
        self.network_params.update(synapse_parameters)

    def measure_response_curve(self, number_of_points, inhibitory_strength, offset_after_spike_peak=1 * ms,
                               offset_before_spike_peak=1 * ms):
        """Sweep the delay of the self-inhibition pulse over one unperturbed period.

        First runs 200 ms without inhibition to measure the unperturbed period, then,
        for ``number_of_points`` delays evenly spaced between
        ``offset_after_spike_peak`` and ``period - offset_before_spike_peak``,
        runs 100 ms with the given synaptic strength and measures the mean period.

        Returns:
            ``(period_unperturbed, inhibitory_delays, perturbed_periods)``, the last being
            an array multiplied by ``second``.
        """
        self.run_sim(duration=200 * ms)
        period_unperturbed = get_mean_period(self.spike_recorder.spike_trains()[0])

        inhibitory_delays = np.linspace(offset_after_spike_peak, period_unperturbed - offset_before_spike_peak,
                                        number_of_points)
        perturbed_periods = []
        for delay in inhibitory_delays:
            self.run_sim(delay, inhibitory_strength, duration=100 * ms)
            perturbed_periods.append(get_mean_period(self.spike_recorder.spike_trains()[0]))
        perturbed_periods = np.array(perturbed_periods) * second
        return period_unperturbed, inhibitory_delays, perturbed_periods

    def run_sim(self, delay=0.0 * ms, pulse_strength=0.0, record_states=False, duration=50 * ms):
        """Restore the stored initial state and run one simulation.

        Args:
            delay: delay of the auto-synapse, i.e. time from a spike to its synaptic event
                (time units, like the delay the synapse is created with).
            pulse_strength: value assigned to the synapse's ``synaptic_strength``; its
                units depend on ``synapse_model``.
            record_states: whether the ``StateMonitor`` records during this run.
            duration: simulated time.

        Returns:
            ``(spike_recorder, neuron_state_recorder)``.
        """
        self.net.restore()
        # A Brian2 StateMonitor fixes its recorded indices at construction, so reassigning
        # ``.record`` would have no effect; ``active`` switches the monitor on/off per run.
        self.neuron_state_recorder.active = record_states
        self.auto_synapse.delay = delay
        self.auto_synapse.synaptic_strength = pulse_strength
        self.net.run(duration=duration, namespace=self.network_params)
        return self.spike_recorder, self.neuron_state_recorder
def get_mean_period(spike_train):
    """Return the mean inter-spike interval of ``spike_train``.

    Computed as ``(max - min) / (count - 1)``, which equals the mean of the
    consecutive differences when the spike times are sorted.

    Args:
        spike_train: sequence or array of spike times; the result is in the same
            units as the input.

    Returns:
        The mean period.

    Raises:
        ValueError: if fewer than two spikes are given (no interval exists).
    """
    n_spikes = len(spike_train)
    if n_spikes < 2:
        raise ValueError("need at least two spikes to compute a period, got {:d}".format(n_spikes))
    return (np.max(spike_train) - np.min(spike_train)) / (n_spikes - 1)
def set_parameters_from_dict(neurongroup, dictionary_of_parameters):
    """Assign every key/value pair of ``dictionary_of_parameters`` as an attribute of ``neurongroup``.

    ``neurongroup`` may be any object with a ``name`` attribute (in this file it
    is used on a ``NeuronGroup`` and on a ``Synapses``). Keys the object rejects
    with ``AttributeError`` are skipped with a warning; any other exception
    propagates to the caller.
    """
    for param_key, param_value in dictionary_of_parameters.items():
        try:
            setattr(neurongroup, param_key, param_value)
        except AttributeError:
            # stacklevel=2 attributes the warning to the caller that passed the unknown key.
            warnings.warn("{:s} has no parameter {:s}".format(neurongroup.name, param_key), stacklevel=2)
def analyze_response_curve(period_unperturbed, inhibitory_delay, periods):
    """Plot the period curve, phase response curve and iterative map of a delay sweep.

    Intended for the output of ``MicroCircuitBistability.measure_response_curve``.

    Args:
        period_unperturbed: period without inhibition (time quantity; divided by ``ms``
            for plotting).
        inhibitory_delay: 1-D array of spike-to-inhibition delays (time quantities).
        periods: perturbed period for each entry of ``inhibitory_delay`` (same length).

    Shows the first figure with one ``plt.show()`` and the phase-response and
    iterative-map figures with a second one; returns ``None``.
    """
    # Figure 1: perturbed period versus delay, with the unperturbed period as reference.
    plt.plot(inhibitory_delay / ms, periods / ms)
    plt.hlines(period_unperturbed / ms, inhibitory_delay[0] / ms, inhibitory_delay[-1] / ms, label="No inhibition")
    plt.xlabel("Delay spike to inhibition (ms)")
    plt.ylabel("Period (ms)")
    plt.legend()
    plt.show()

    # Phase of the pulse within the cycle, and the relative change of the period
    # (positive when the perturbed period is shorter than the unperturbed one).
    phases = inhibitory_delay / period_unperturbed
    dphases = (period_unperturbed - periods) / period_unperturbed
    optimal_idx = int(np.argmax(dphases))  # assumes a single maximum

    # Figure 2: phase response curve.
    plt.figure()
    plt.plot(phases, dphases)
    # Mark the maximum on the phase axis; the raw delay is a time, not a phase,
    # and would be drawn at the wrong x position.
    plt.vlines(phases[optimal_idx], -1, 1)
    plt.title("Phase response curve")
    plt.ylim(-1, 1)

    # Figure 3: iterative map. Each phase phases[j] is paired with the PRC value
    # optimal_idx samples later, offset so that the pair at j = 0 maps to phases[0] + 1.
    # NOTE(review): confirm this shift matches the intended model of the map.
    plt.figure()
    n_pairs = len(phases) - optimal_idx  # explicit length avoids the phases[:-0] == [] pitfall
    delta_i = phases[:n_pairs]
    delta_iplus1 = delta_i + 1 + dphases[optimal_idx:] - dphases[optimal_idx]
    plt.plot(delta_i, delta_iplus1)
    plt.plot(delta_i, delta_i, '--')
    plt.hlines(1, 0, 1)
    plt.xlabel("Delta i")
    plt.ylabel("Delta i+1")
    plt.title("Iterative map")
    plt.ylim(0, 1.2)
    plt.xlim(0, 1)
    plt.show()
|