123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- import brian2 as br
- from brian2.units import *
- import numpy as np
- import matplotlib.pyplot as plt
- import warnings
- import brianutils, json, numpy
- import sys
# Script setup: load the neuron model, build a single self-coupled neuron
# (autapse), attach monitors, and snapshot the initial network state.
print(sys.path)
print("load model")
# Use a context manager so the model file is closed as soon as it is parsed.
# NOTE(review): the path is resolved relative to the current working
# directory, not to this file -- run the script from its own directory.
with open("../../models/Hodkin_Homebrew.json", encoding="utf-8") as model_file:
    model_dict = json.load(model_file)
eqs = brianutils.load_model(model_dict, substitution_depth=4)
print("model loaded")

# Each synapse has a per-synapse parameter `perturbation` (a voltage) that is
# added to the target's membrane potential `v` when a presynaptic spike arrives.
delta_synapse_model = 'perturbation: volt'
delta_synapse = 'v+=perturbation'

# NOTE(review): `threshold` and `neuron_properties` are not passed to the
# NeuronGroup below (it uses literal expressions instead); confirm whether
# they are still needed.
threshold = "v>v_threshold"
neuron_properties = {
    "tau": 10*ms,
    "v_threshold": -40*mV,
    "v_reset": -75*mV,
    "v_refractory": 'v > -40*mV',
    "u_ext": - 39 * mV
}

neuron = br.NeuronGroup(
    N=1,
    name='single_neuron',
    model=eqs,
    threshold='v > -40*mV',
    # Refractory while v stays above the threshold, so one upswing is counted
    # as a single spike rather than one per time step above -40 mV.
    refractory='v > -40*mV',
    method='exponential_euler',
)

record_variables = ['v', 'ih']
spike_recorder = br.SpikeMonitor(source=neuron)
state_recorder = br.StateMonitor(neuron, record_variables, record=True)

# Autapse: the neuron projects onto itself, so each of its spikes delivers
# `perturbation` back to its own `v` after the synaptic delay.
auto_synapse = br.Synapses(
    source=neuron,
    target=neuron,
    model=delta_synapse_model,
    on_pre=delta_synapse,
    delay=0.0 * ms,
)
auto_synapse.connect()

net = br.Network(neuron)
net.add(auto_synapse)
net.add(spike_recorder)
net.add(state_recorder)
# Snapshot the initial state so every simulation can start from it via
# net.restore().
net.store()
def run_sim(delay=0.0 * ms, pulse_strength=0.0 * mV, record_states=False, duration=50 * ms):
    """Run one simulation of the self-coupled neuron from the stored state.

    The network is restored to the snapshot taken by ``net.store()``, the
    autapse is configured, and the network is run. Results are read afterwards
    from the module-level ``spike_recorder`` / ``state_recorder``.

    Parameters
    ----------
    delay : Quantity (time)
        Synaptic delay of the autapse, i.e. how long after each spike the
        perturbation reaches ``v``.
    pulse_strength : Quantity (voltage)
        Value assigned to the synapse's ``perturbation``, which is added to
        ``v`` on every delivered spike.
    record_states : bool
        Whether the state monitor records during this run.
    duration : Quantity (time)
        Simulated time.
    """
    net.restore()
    # `active` is the Brian2 switch that pauses / resumes a monitor.
    state_recorder.active = record_states
    # NOTE(review): kept from the original. In Brian2 this appears only to
    # overwrite the monitor's `record` attribute (the recorded indices) and
    # does not pause recording -- `active` above does that. Confirm nothing
    # depends on it (e.g. `state_recorder[i]` lookups), then remove.
    state_recorder.record = record_states
    auto_synapse.perturbation = pulse_strength
    auto_synapse.delay = delay
    net.run(duration=duration)
def get_mean_period(spike_train):
    """Return the mean inter-spike interval of ``spike_train``.

    Computed as the span between the earliest and latest spike divided by the
    number of intervals (``n_spikes - 1``). Units are preserved, so a train of
    Brian2 Quantities yields a Quantity.

    Raises
    ------
    ValueError
        If the train has fewer than two spikes. A period is undefined then;
        previously this either failed inside numpy (empty train) or silently
        returned NaN (single spike).
    """
    n_spikes = len(spike_train)
    if n_spikes < 2:
        raise ValueError(
            f"need at least two spikes to estimate a period, got {n_spikes}"
        )
    return (np.max(spike_train) - np.min(spike_train)) / (n_spikes - 1)
def get_mean_response(spike_trains_dict, t_isi):
    """Return, per key, the change in mean period relative to ``t_isi``.

    For every entry of ``spike_trains_dict`` the mean inter-spike interval of
    its spike train (from ``get_mean_period``) is computed, and the reference
    period ``t_isi`` is subtracted. Positive values therefore mean the neuron
    fired more slowly than the reference. Keys are carried over unchanged.
    """
    return {
        key: get_mean_period(train) - t_isi
        for key, train in spike_trains_dict.items()
    }
def plot_spiking(spike_trains_dict):
    """Draw a raster of spike times, one row per dictionary key.

    Each spike time (divided by ``ms``) is drawn as a blue dot at a height
    equal to its key, and the x-axis gets vertical grid lines.
    """
    axis = plt.figure().add_subplot(111)
    for row_value, spike_times in spike_trains_dict.items():
        row = row_value * np.ones(spike_times.shape)
        axis.plot(spike_times / ms, row, 'b.')
    axis.grid(axis='x')
    axis.set_xlabel("Time(ms)")
def plot_mean_period(mean_period_dict):
    """Plot the dictionary values against its keys as a blue line.

    The x-axis is labelled as the delay in ms.
    """
    delays = list(mean_period_dict.keys())
    periods = list(mean_period_dict.values())
    axis = plt.figure().add_subplot(111)
    axis.plot(delays, periods, 'b')
    axis.grid(axis='x')
    axis.set_xlabel("Delay (ms)")
def plot_mean_response(mean_response_dict):
    """Plot the response against the inhibition time, with the identity as reference.

    Values are multiplied by 1e3 before plotting (presumably seconds -> ms;
    confirm against the caller). Keys are plotted as-is. A dashed black line
    y = x is overlaid for comparison.
    """
    t_inh = list(mean_response_dict.keys())
    response_ms = 1e3 * np.array(list(mean_response_dict.values()))
    axis = plt.figure().add_subplot(111)
    axis.plot(np.array(t_inh), response_ms, 'b')
    # Identity line: response equal to the inhibition time.
    axis.plot(t_inh, t_inh, 'k--')
    axis.grid(axis='x')
    axis.set_xlabel("t_inh (ms)")
    axis.set_ylabel("exc_spike_delay (ms)")
def plot_delay_minus_response(mean_response_dict):
    """Plot ``delay - response`` against delay, both in milliseconds.

    Keys of ``mean_response_dict`` are used as-is on the x-axis (labelled
    "Delay (ms)"). Values are multiplied by 1e3 before the subtraction, so they
    are presumably in seconds -- confirm against the caller.

    The intermediate arrays are printed for debugging, all on the same (ms)
    scale, so the printed difference matches the plotted curve. Previously
    the printed difference mixed ms and s.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    delays_ms = np.array(list(mean_response_dict.keys()))
    # Scale responses to ms so they are comparable with the delays.
    responses_ms = 1e3 * np.array(list(mean_response_dict.values()))
    difference_ms = delays_ms - responses_ms
    print(delays_ms)
    print(responses_ms)
    print(difference_ms)
    ax.plot(list(mean_response_dict.keys()), difference_ms, 'b')
    ax.grid(axis='x')
    ax.set_xlabel("Delay (ms)")
    ax.set_ylabel("Delay - response (ms)")
def plot_phase_response_curve(mean_response_dict, t_isi):
    """Plot a phase response curve, with both axes as fractions of ``t_isi``.

    Conversions:
      * Keys are interpreted as perturbation delays in ms (multiplied by
        ``ms``).
      * Values are interpreted as seconds (multiplied by ``second``).
      * The perturbation phase is ``delay / t_isi``.
      * The phase shift is ``-response / t_isi``.

    A dashed line ``delta_phi = -phi_inh`` is drawn for reference. Both
    quantities are dimensionless, so the axes are labelled as fractions of
    the period rather than in ms.
    """
    phi_inh = np.array(list(mean_response_dict.keys())) * ms / t_isi
    delta_phi = -np.array(list(mean_response_dict.values())) * second / t_isi
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(phi_inh, delta_phi, 'b')
    ax.plot(phi_inh, -phi_inh, 'k--')
    ax.grid(axis='x')
    ax.set_xlabel("phase of inhibition (fraction of period)")
    ax.set_ylabel("phase shift (fraction of period)")
# Reference run without perturbation: record v and ih, then measure the
# unperturbed inter-spike interval.
run_sim(record_states=True, pulse_strength=0.0*mV, duration=200 * ms)
t_state = state_recorder.t
# The group has a single neuron (N=1), so its trace is row 0 of each recorded
# variable; index 1 does not exist in this group.
v_state = state_recorder.v[0]
ih_state = state_recorder.ih[0]
spike_train = spike_recorder.spike_trains()[0]

fig, ax1 = plt.subplots()
# v is a voltage (it is compared against -40*mV), so plot it in mV.
ax1.plot(t_state/ms, v_state/mV, 'C2')
ax1.set_xlabel("Time (ms)")
ax1.set_ylabel("v (mV)")
ax2 = ax1.twinx()
# NOTE(review): ih is divided by ms as in the original; confirm the intended
# display unit for ih.
ax2.plot(t_state/ms, ih_state/ms, 'C1')
t_isi = get_mean_period(spike_train)
print(t_isi)

# Sweep the autapse delay across one unperturbed period and collect the
# resulting spike train for each delay.
n_simulations = 40
pulse_strength = -20 * mV
delay_list = np.linspace(0.1 * ms, t_isi - 0.1*ms, n_simulations)
spike_trains_dict = {}
for delay in delay_list:
    run_sim(delay, pulse_strength, record_states=False, duration=100 * ms)
    spike_train = spike_recorder.spike_trains()[0]
    # Key by delay / ms (a plain number) rather than the Quantity itself,
    # because dictionary keys must be hashable.
    spike_trains_dict[delay / ms] = spike_train

mean_response_dict = get_mean_response(spike_trains_dict, t_isi)
plot_mean_response(mean_response_dict)
# Optional extra figures:
# plot_phase_response_curve(mean_response_dict, t_isi)
# plot_spiking(spike_trains_dict)
plt.show()
|