prc_hodkin_huxley_h_current_brianutils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import brian2 as br
  2. from brian2.units import *
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. import warnings
  6. import brianutils, json, numpy
  7. import sys
  8. print(sys.path)
  9. print("load model")
  10. model_dict= json.load(open("../../models/Hodkin_Homebrew.json"))
  11. eqs = brianutils.load_model(model_dict,substitution_depth=4)
  12. print("model loaded")
  13. delta_synapse_model = 'perturbation: volt'
  14. delta_synapse = 'v+=perturbation'
  15. threshold = "v>v_threshold"
  16. neuron_properties = {
  17. "tau": 10*ms,
  18. "v_threshold": -40*mV,
  19. "v_reset": -75*mV,
  20. "v_refractory": 'v > -40*mV',
  21. "u_ext": - 39 * mV
  22. }
  23. # pulse_strength = -10. * mV
  24. neuron = br.NeuronGroup(N=1, \
  25. name='single_neuron',\
  26. model=eqs, \
  27. threshold='v > -40*mV', \
  28. refractory = 'v > -40*mV',\
  29. method='exponential_euler')
  30. record_variables = ['v','ih']
  31. spike_recorder = br.SpikeMonitor(source=neuron)
  32. state_recorder = br.StateMonitor(neuron, record_variables, record=True)
  33. auto_synapse = br.Synapses(source=neuron, target=neuron, model=delta_synapse_model, on_pre = delta_synapse, delay = 0.0 * ms)
  34. auto_synapse.connect()
  35. net = br.Network(neuron)
  36. net.add(auto_synapse)
  37. net.add(spike_recorder)
  38. net.add(state_recorder)
  39. net.store()
  40. def run_sim(delay=0.0, pulse_strength=0.0, record_states=False, duration=50 * ms):
  41. net.restore()
  42. state_recorder.record = record_states
  43. auto_synapse.perturbation = pulse_strength
  44. auto_synapse.delay = delay
  45. net.run(duration=duration)
  46. def get_mean_period(spike_train):
  47. return (np.max(spike_train) - np.min(spike_train)) / (spike_train.shape[0] - 1)
  48. def get_mean_response(spike_trains_dict,t_isi):
  49. mean_response_dict = {}
  50. # t_isi = get_mean_period(spike_trains_dict[0.0 * ms])
  51. for delay, spike_train in spike_trains_dict.items():
  52. mean_response_dict[delay] = get_mean_period(spike_train) - t_isi
  53. return mean_response_dict
  54. def plot_spiking(spike_trains_dict):
  55. fig = plt.figure()
  56. ax = fig.add_subplot(111)
  57. for key, times in spike_trains_dict.items():
  58. ax.plot(times/ms, key*np.ones(times.shape), 'b.')
  59. # ax.plot(spike_train[0]/ms, np.ones(spike_train[0].shape), 'b|')
  60. ax.grid(axis='x')
  61. # ax.set_ylim(-0.1, 1.1)
  62. ax.set_xlabel("Time(ms)");
  63. def plot_mean_period(mean_period_dict):
  64. fig = plt.figure()
  65. ax = fig.add_subplot(111)
  66. ax.plot(list(mean_period_dict.keys()), list(mean_period_dict.values()), 'b')
  67. ax.grid(axis='x')
  68. ax.set_xlabel("Delay (ms)");
  69. def plot_mean_response(mean_response_dict):
  70. fig = plt.figure()
  71. ax = fig.add_subplot(111)
  72. ax.plot(np.array(list(mean_response_dict.keys())), 1e3*np.array(list(mean_response_dict.values())), 'b')
  73. ax.plot(list(mean_response_dict.keys()), list(mean_response_dict.keys()), 'k--')
  74. ax.grid(axis='x')
  75. ax.set_xlabel("t_inh (ms)")
  76. ax.set_ylabel("exc_spike_delay (ms)")
  77. def plot_delay_minus_response(mean_response_dict):
  78. fig = plt.figure()
  79. ax = fig.add_subplot(111)
  80. print(np.array(list(mean_response_dict.keys())))
  81. print(np.array(list(mean_response_dict.values())))
  82. print(np.array(list(mean_response_dict.keys()))-np.array(list(mean_response_dict.values())))
  83. ax.plot(list(mean_response_dict.keys()), np.array(list(mean_response_dict.keys()))-1e3*np.array(list(mean_response_dict.values())), 'b')
  84. ax.grid(axis='x')
  85. ax.set_xlabel("Delay (ms)");
  86. def plot_phase_response_curve(mean_response_dict, t_isi):
  87. phi_inh = np.array(list(mean_response_dict.keys())) * ms / t_isi
  88. delta_phi = -np.array(list(mean_response_dict.values())) * second / t_isi
  89. fig = plt.figure()
  90. ax = fig.add_subplot(111)
  91. ax.plot(phi_inh, delta_phi, 'b')
  92. ax.plot(phi_inh, -phi_inh, 'k--')
  93. ax.grid(axis='x')
  94. ax.set_xlabel("phase of inhibition (ms)")
  95. ax.set_ylabel("phase shift (ms)");
  96. run_sim(record_states=True, pulse_strength=0.0*mV, duration=200 * ms)
  97. t_state = state_recorder.t
  98. v_state = state_recorder[1].v
  99. ih_state = state_recorder[1].ih
  100. spike_train = spike_recorder.spike_trains()[0]
  101. fig, ax1 = plt.subplots()
  102. ax1.plot(t_state/ms,v_state/ms,'C2')
  103. ax2 = ax1.twinx()
  104. ax2.plot(t_state/ms,ih_state/ms,'C1')
  105. t_isi = get_mean_period(spike_train)
  106. print(t_isi)
  107. n_simulations = 40
  108. pulse_strength = -20 * mV
  109. delay_list = np.linspace(0.1 * ms,t_isi - 0.1*ms,n_simulations)
  110. spike_trains_dict = {}
  111. for delay in delay_list:
  112. run_sim(delay, pulse_strength, record_states=False, duration=100 * ms)
  113. spike_train = spike_recorder.spike_trains()[0]
  114. # print(spike_train)
  115. spike_trains_dict[delay / ms] = spike_train #Careful, delay is not a Quantity anymore because in must be hashable
  116. mean_response_dict = get_mean_response(spike_trains_dict, t_isi)
  117. plot_mean_response(mean_response_dict)
  118. # plot_phase_response_curve(mean_response_dict, t_isi)
  119. # plot_spiking(spike_trains_dict)
  120. plt.show()