"""Probe how a delayed self-inhibition pulse shifts the spiking of one neuron.

A single-neuron Brian2 network is built from a JSON model description and
wired to itself through a delta synapse.  A pulse-free baseline run gives the
mean inter-spike interval; the synaptic delay is then swept across one such
interval and the change in mean spiking period is plotted against the delay.
"""

import json
import sys
import warnings

import brian2 as br
from brian2.units import *
import matplotlib.pyplot as plt
import numpy
import numpy as np

import brianutils

print(sys.path)

print("load model")
# Use a context manager so the model file is closed once it has been parsed.
with open("../../models/Hodkin_Homebrew.json") as model_file:
    model_dict = json.load(model_file)
eqs = brianutils.load_model(model_dict, substitution_depth=4)
print("model loaded")

# Delta synapse: every presynaptic spike adds `perturbation` to v.
delta_synapse_model = 'perturbation: volt'
delta_synapse = 'v+=perturbation'

# NOTE(review): `threshold` and `neuron_properties` are not passed to any
# object below (the NeuronGroup uses literal threshold/refractory strings);
# they are kept because other code may read them.
threshold = "v>v_threshold"
neuron_properties = {
    "tau": 10*ms,
    "v_threshold": -40*mV,
    "v_reset": -75*mV,
    "v_refractory": 'v > -40*mV',
    "u_ext": - 39 * mV
}

neuron = br.NeuronGroup(
    N=1,
    name='single_neuron',
    model=eqs,
    threshold='v > -40*mV',
    refractory='v > -40*mV',
    method='exponential_euler',
)

record_variables = ['v', 'ih']
spike_recorder = br.SpikeMonitor(source=neuron)
state_recorder = br.StateMonitor(neuron, record_variables, record=True)

# The neuron is connected to itself; delay and pulse size are set per run.
auto_synapse = br.Synapses(
    source=neuron,
    target=neuron,
    model=delta_synapse_model,
    on_pre=delta_synapse,
    delay=0.0 * ms,
)
auto_synapse.connect()

net = br.Network(neuron)
net.add(auto_synapse)
net.add(spike_recorder)
net.add(state_recorder)
# Snapshot taken once so that run_sim() can restore it before every run.
net.store()


def run_sim(delay=0.0, pulse_strength=0.0, record_states=False, duration=50 * ms):
    """Restore the stored network state and simulate one trial.

    Parameters
    ----------
    delay
        Value assigned to the self-synapse delay.
    pulse_strength
        Value assigned to the synapse's ``perturbation`` variable.
    record_states
        Value assigned to ``state_recorder.record``.
    duration
        Simulated time handed to ``net.run``.

    Results are not returned; read them from ``spike_recorder`` and
    ``state_recorder`` afterwards.
    """
    net.restore()
    # NOTE(review): `record` is assigned after the monitor was constructed;
    # whether Brian2 honours this at run time is not visible here - confirm.
    state_recorder.record = record_states
    auto_synapse.perturbation = pulse_strength
    auto_synapse.delay = delay
    net.run(duration=duration)


def get_mean_period(spike_train):
    """Return the mean interval between consecutive spikes.

    Computed as (last spike - first spike) / (number of spikes - 1).  At least
    two spikes are needed for a meaningful value: with one spike the
    expression is 0/0, and an empty train makes ``np.max`` raise.
    """
    n_intervals = spike_train.shape[0] - 1
    return (np.max(spike_train) - np.min(spike_train)) / n_intervals


def get_mean_response(spike_trains_dict, t_isi):
    """Map each delay to (mean period of its spike train) - ``t_isi``.

    Parameters
    ----------
    spike_trains_dict
        Mapping from delay to the spike train recorded with that delay.
    t_isi
        Reference period subtracted from every mean period.
    """
    return {
        delay: get_mean_period(spike_train) - t_isi
        for delay, spike_train in spike_trains_dict.items()
    }


def _new_axes():
    """Open a new figure and return its single axes."""
    _, ax = plt.subplots()
    return ax


def plot_spiking(spike_trains_dict):
    """Raster-style plot: spike times (ms) on x, the dictionary key on y."""
    ax = _new_axes()
    for key, times in spike_trains_dict.items():
        ax.plot(times/ms, key*np.ones(times.shape), 'b.')
    ax.grid(axis='x')
    ax.set_xlabel("Time(ms)")


def plot_mean_period(mean_period_dict):
    """Plot the dictionary values against the dictionary keys."""
    ax = _new_axes()
    ax.plot(list(mean_period_dict.keys()), list(mean_period_dict.values()), 'b')
    ax.grid(axis='x')
    ax.set_xlabel("Delay (ms)")


def plot_mean_response(mean_response_dict):
    """Plot 1e3 * response against its key, plus the identity line (dashed)."""
    ax = _new_axes()
    delays = list(mean_response_dict.keys())
    responses = np.array(list(mean_response_dict.values()))
    # The 1e3 factor rescales the response values before plotting.
    ax.plot(np.array(delays), 1e3*responses, 'b')
    ax.plot(delays, delays, 'k--')
    ax.grid(axis='x')
    ax.set_xlabel("t_inh (ms)")
    ax.set_ylabel("exc_spike_delay (ms)")


def plot_delay_minus_response(mean_response_dict):
    """Print the raw arrays, then plot key - 1e3 * response against key."""
    ax = _new_axes()
    delays = np.array(list(mean_response_dict.keys()))
    responses = np.array(list(mean_response_dict.values()))
    print(delays)
    print(responses)
    # The printed difference is unscaled, as originally written; only the
    # plotted curve applies the 1e3 factor.
    print(delays - responses)
    ax.plot(list(mean_response_dict.keys()), delays - 1e3*responses, 'b')
    ax.grid(axis='x')
    ax.set_xlabel("Delay (ms)")


def plot_phase_response_curve(mean_response_dict, t_isi):
    """Plot the response curve with both axes expressed relative to ``t_isi``.

    Keys are multiplied by ``ms`` and values by ``second`` before being
    divided by ``t_isi``; the dashed line is y = -x.  The axis labels are kept
    exactly as originally written.
    """
    phi_inh = np.array(list(mean_response_dict.keys())) * ms / t_isi
    delta_phi = -np.array(list(mean_response_dict.values())) * second / t_isi
    ax = _new_axes()
    ax.plot(phi_inh, delta_phi, 'b')
    ax.plot(phi_inh, -phi_inh, 'k--')
    ax.grid(axis='x')
    ax.set_xlabel("phase of inhibition (ms)")
    ax.set_ylabel("phase shift (ms)")


# --- Baseline run without a pulse: traces and unperturbed mean period ------
run_sim(record_states=True, pulse_strength=0.0*mV, duration=200 * ms)
t_state = state_recorder.t
# NOTE(review): the group holds a single neuron, so index 0 would be expected;
# index 1 is kept from the original because the lookup depends on how Brian2
# treats the `record` value overwritten in run_sim - confirm before changing.
v_state = state_recorder[1].v
ih_state = state_recorder[1].ih
spike_train = spike_recorder.spike_trains()[0]

fig, ax1 = plt.subplots()
# v is a voltage, so it is scaled by mV (the original divided by ms, which
# gives the same numbers but the wrong unit).
ax1.plot(t_state/ms, v_state/mV, 'C2')
ax2 = ax1.twinx()
# NOTE(review): the unit of `ih` comes from the model file, which is not shown
# here; the original scaling is kept - confirm the intended unit.
ax2.plot(t_state/ms, ih_state/ms, 'C1')

t_isi = get_mean_period(spike_train)
print(t_isi)

# --- Sweep the pulse delay across one unperturbed period ------------------
n_simulations = 40
pulse_strength = -20 * mV
delay_list = np.linspace(0.1 * ms, t_isi - 0.1*ms, n_simulations)

spike_trains_dict = {}
for delay in delay_list:
    run_sim(delay, pulse_strength, record_states=False, duration=100 * ms)
    spike_train = spike_recorder.spike_trains()[0]
    # Careful: the key is delay / ms, not a Quantity, because it must be
    # hashable.
    spike_trains_dict[delay / ms] = spike_train

mean_response_dict = get_mean_response(spike_trains_dict, t_isi)
plot_mean_response(mean_response_dict)
# Alternative views of the same data:
# plot_phase_response_curve(mean_response_dict, t_isi)
# plot_spiking(spike_trains_dict)
plt.show()