prc_hodkin_huxley_h_current_backup.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. import copy
  2. import warnings
  3. import brian2 as br
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. from brian2.units import *
  7. def set_parameters_from_dict(neurongroup, dictionary_of_parameters):
  8. for param_key, param_value in dictionary_of_parameters.items():
  9. try:
  10. neurongroup.__setattr__(param_key, param_value)
  11. except AttributeError as err:
  12. warnings.warn("{:s} has no parameter {:s}".format(neurongroup.name, param_key))
  13. def get_mean_period(spike_train):
  14. return (np.max(spike_train) - np.min(spike_train)) / (spike_train.shape[0] - 1)
  15. def get_difference_in_periods(spike_trains_dict, t_isi):
  16. mean_response_dict = {}
  17. # t_isi = get_mean_period(spike_trains_dict[0.0 * ms])
  18. for delay, spike_train in spike_trains_dict.items():
  19. mean_response_dict[delay] = get_mean_period(spike_train) - t_isi
  20. return mean_response_dict
  21. def plot_spiking(spike_trains_dict):
  22. fig = plt.figure()
  23. ax = fig.add_subplot(111)
  24. for key, times in spike_trains_dict.items():
  25. ax.plot(times / ms, key * np.ones(times.shape), 'b.')
  26. # ax.plot(spike_train[0]/ms, np.ones(spike_train[0].shape), 'b|')
  27. ax.grid(axis='x')
  28. # ax.set_ylim(-0.1, 1.1)
  29. ax.set_xlabel("Time(ms)")
  30. def plot_mean_period(mean_period_dict):
  31. fig = plt.figure()
  32. ax = fig.add_subplot(111)
  33. ax.plot(list(mean_period_dict.keys()), list(mean_period_dict.values()), 'b')
  34. ax.grid(axis='x')
  35. ax.set_xlabel("Delay (ms)")
  36. def plot_mean_response(mean_response_dict):
  37. fig = plt.figure()
  38. ax = fig.add_subplot(111)
  39. ax.plot(np.array(list(mean_response_dict.keys())), 1e3 * np.array(list(mean_response_dict.values())), 'b')
  40. ax.plot(list(mean_response_dict.keys()), list(mean_response_dict.keys()), 'k--')
  41. ax.grid(axis='x')
  42. ax.set_xlabel("t_inh (ms)")
  43. ax.set_ylabel("exc_spike_delay (ms)")
  44. def plot_delay_minus_response(mean_response_dict):
  45. fig = plt.figure()
  46. ax = fig.add_subplot(111)
  47. print(np.array(list(mean_response_dict.keys())))
  48. print(np.array(list(mean_response_dict.values())))
  49. print(np.array(list(mean_response_dict.keys())) - np.array(list(mean_response_dict.values())))
  50. ax.plot(list(mean_response_dict.keys()),
  51. np.array(list(mean_response_dict.keys())) - 1e3 * np.array(list(mean_response_dict.values())), 'b')
  52. ax.grid(axis='x')
  53. ax.set_xlabel("Delay (ms)");
  54. def plot_phase_response_curve(mean_response_dict, t_isi):
  55. phi_inh = np.array(list(mean_response_dict.keys())) * ms / t_isi
  56. delta_phi = -np.array(list(mean_response_dict.values())) * second / t_isi
  57. fig = plt.figure()
  58. ax = fig.add_subplot(111)
  59. ax.plot(phi_inh, delta_phi, 'b')
  60. ax.plot(phi_inh, -phi_inh, 'k--')
  61. ax.grid(axis='x')
  62. ax.set_xlabel("phase of inhibition (ms)")
  63. ax.set_ylabel("phase shift (ms)");
  64. # Hodgkin Huxley model from Brian2 documentation
  65. area = 20000 * umetre ** 2
  66. hodgkin_huxley_params = {
  67. "Cm": 1 * ufarad * cm ** -2 * area,
  68. "gl": 5e-5 * siemens * cm ** -2 * area,
  69. "El": -65 * mV,
  70. "EK": -90 * mV,
  71. "ENa": 50 * mV,
  72. "g_na": 100 * msiemens * cm ** -2 * area,
  73. "g_kd": 30 * msiemens * cm ** -2 * area,
  74. "VT": -63 * mV
  75. }
# The model
# Membrane potential v with gating variables m, n and h, driven by the
# per-neuron current I. The term `ih` is not defined in these equations:
# it has to be provided by appending another equation block (this file adds
# `eqs_ih`). The remaining constants (gl, El, g_na, ...) are not defined here
# either; this file passes them as the `namespace` of `net.run`.
hodgkin_huxley_eqs = br.Equations('''
dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I + ih)/Cm : volt
dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
(exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
(exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
(exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
I : amp
''')
# Same model with an additional synaptic current -g_syn*(v-E_i), where
# g_syn is a per-neuron conductance and E_i must be supplied as a constant.
hodgkin_huxley_eqs_with_synaptic_conductance = br.Equations('''
dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I + ih - g_syn*(v-E_i))/Cm : volt
dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
(exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
(exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
(exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
I : amp
g_syn: siemens
''')
# # First H-current from Izhikevich p. 48
# (disabled alternative parameter set, kept for reference)
# ghbar = 40. * nS #Other values would be 0.5, 2, 3.5, 20 depending on neuron type (Rothman, Manis)
# Eh = -43*mV
# V_half_h = -75.0
# k_h = -5.5
# V_max_h = -75.
# sigma_h = 15.
# C_amp_h = 1000.
# C_base_h = 100.
# eqs_ih = """
# ih = ghbar*r*(Eh-v) : amp
# dr/dt= (rinf-r)/rtau : 1
# rinf = 1. / (1+exp((V_half_h - v/mV) / k_h)) : 1
# rtau = (C_base_h + C_amp_h * exp(-(V_max_h-v/mV)**2./sigma_h**2)) * ms : second
# """
# Second H-current from Izhikevich p. 48
# Parameters of the active h-current. The entries without units are plain
# numbers: eqs_ih combines V_half_h and V_max_h with v/mV, and multiplies
# the C_base_h / C_amp_h expression by ms.
ih_params = {
    "ghbar": 60. * nS,  # Other values would be 0.5, 2, 3.5, 20 depending on neuron type (Rothman, Manis)
    "Eh": -1. * mV,
    "V_half_h": -98.0,
    "k_h": -5.5,
    "V_max_h": -75.,
    "sigma_h": 25.,
    "C_amp_h": 60.,
    "C_base_h": 40.,
}
# h-current equations: ih is the current added to dv/dt, r its gating
# variable relaxing towards rinf with time constant rtau. Note the extra
# factor 0.1 in rtau compared with the disabled variant above.
eqs_ih = """
ih = ghbar*r*(Eh-v) : amp
dr/dt= (rinf-r)/rtau : 1
rinf = 1. / (1+exp((V_half_h - v/mV) / k_h)) : 1
rtau = 0.1*(C_base_h + C_amp_h * exp(-(V_max_h-v/mV)**2./sigma_h**2)) * ms : second
"""
# Another disabled variant with hard-coded constants, kept for reference.
# eqs_ih = """
# ih = ghbar*r*(Eh-v) : amp
# dr/dt=(rinf-r)/rtau : 1
# rinf = 1. / (1+exp((v/mV + 90.) / 7.)) : 1
# # rtau = ((100000. / (237.*exp((v/mV+60.) / 12.) + 17.*exp(-(v/mV+60.) / 14.))) + 25.)*ms : second
# rtau = 0.01*(100. + 1000. * exp(-(-76.-v/mV)**2./15.**2)) * ms : second #From Izhikevich
# """
# Delta synapse: the on_pre statement adds `synaptic_strength` (a voltage)
# directly to v. It needs no additional parameters.
delta_synapse_model = 'synaptic_strength: volt'
delta_synapse_on_pre = 'v+=synaptic_strength'
delta_synapse_param = {}
# Exponential synapse: the on_pre statement increments the conductance g by
# `synaptic_strength`; g decays with time constant tau_syn and is summed
# into the target variable g_syn.
exponential_synapse = """
dg/dt = -g/tau_syn : siemens
g_syn_post = g :siemens (summed)
tau_syn : second
synaptic_strength : siemens
"""
exponential_synapse_on_pre = "g+=synaptic_strength"
exponential_synapse_params = {
    "tau_syn": 5 * ms
}
# Voltage used as the spike threshold (assigned to the neuron's v_threshold).
spike_threshold = +40 * mV
  151. ## With voltage delta synapse
  152. neuron_eqs = hodgkin_huxley_eqs+eqs_ih
  153. synapse_model = delta_synapse_model
  154. synapse_on_pre = delta_synapse_on_pre
  155. synapse_params = delta_synapse_param
  156. neuron_params = hodgkin_huxley_params
  157. neuron_params.update(ih_params)
  158. inhibition_off = 0.0 * mV
  159. inhibition_on = -30*mV
  160. # ### With conductance based delta synapse
  161. # neuron_eqs = hodgkin_huxley_eqs_with_synaptic_conductance + eqs_ih
  162. #
  163. # synapse_model = exponential_synapse
  164. # synapse_on_pre = exponential_synapse_on_pre
  165. # synapse_params = exponential_synapse_params
  166. #
  167. # neuron_params = hodgkin_huxley_params
  168. # neuron_params.update(ih_params)
  169. # neuron_params.update({"E_i": -80 * mV})
  170. #
  171. # inhibition_off = 0.0 * nS
  172. # inhibition_on = 100 * nS
  173. network_params = copy.deepcopy(neuron_params)
  174. network_params.update(synapse_params)
  175. record_variables = ['v', 'ih']
  176. integration_method = 'exponential_euler'
  177. initial_states = {
  178. "v": hodgkin_huxley_params["El"]
  179. }
  180. threshold_eqs = """
  181. v_threshold: volt
  182. """
  183. neuron = br.NeuronGroup(N=1, \
  184. model=neuron_eqs + threshold_eqs, \
  185. threshold='v > v_threshold', \
  186. refractory='v > v_threshold', \
  187. method=integration_method)
  188. neuron.v_threshold = spike_threshold
  189. set_parameters_from_dict(neuron, initial_states)
  190. neuron.I = 0.5 * nA
  191. spike_recorder = br.SpikeMonitor(source=neuron)
  192. neuron_state_recorder = br.StateMonitor(neuron, record_variables, record=True)
  193. auto_synapse = br.Synapses(source=neuron, target=neuron, model=synapse_model, on_pre=synapse_on_pre,
  194. delay=0.0 * ms)
  195. auto_synapse.connect()
  196. set_parameters_from_dict(auto_synapse, synapse_params)
  197. net = br.Network(neuron)
  198. net.add(auto_synapse)
  199. net.add(spike_recorder)
  200. net.add(neuron_state_recorder)
  201. net.store()
  202. def run_sim(delay=0.0, pulse_strength=0.0, record_states=False, duration=50 * ms):
  203. net.restore()
  204. neuron_state_recorder.record = record_states
  205. auto_synapse.delay = delay
  206. auto_synapse.synaptic_strength = pulse_strength
  207. net.run(duration=duration, namespace=network_params)
  208. run_sim(delay=3 * ms, record_states=True, pulse_strength=inhibition_off, duration=200 * ms)
  209. t_state = neuron_state_recorder.t
  210. v_state = neuron_state_recorder[1].v
  211. ih_state = neuron_state_recorder[1].ih
  212. spike_train = spike_recorder.spike_trains()[0]
  213. fig, ax1 = plt.subplots()
  214. ax1.plot(t_state / ms, v_state / ms, 'C2')
  215. plt.show()
  216. period_unperturbed = get_mean_period(spike_train)
  217. print(period_unperturbed)
  218. n_simulations = 40
  219. inhibitory_delay = np.linspace(1 * ms, period_unperturbed - 1 * ms, n_simulations)
  220. spike_trains_dict = {}
  221. periods = []
  222. for delay in inhibitory_delay:
  223. run_sim(delay, inhibition_on, record_states=False, duration=100 * ms)
  224. spike_train = spike_recorder.spike_trains()[0]
  225. # print(spike_train)
  226. spike_trains_dict[delay / ms] = spike_train # Careful, delay is not a Quantity anymore because in must be hashable
  227. periods.append(get_mean_period(spike_train))
  228. periods = np.array(periods)
  229. plt.plot(inhibitory_delay/ms, periods/ms)
  230. plt.hlines(period_unperturbed/ms, inhibitory_delay[0]/ms, inhibitory_delay[-1]/ms, label="No inhibition")
  231. plt.xlabel("Delay spike to inhibition (ms)")
  232. plt.ylabel("Period (ms")
  233. plt.legend()
  234. t_diff_dict = get_difference_in_periods(spike_trains_dict, period_unperturbed)
  235. phases = np.array([delay / period_unperturbed for delay in inhibitory_delay])
  236. phase_diff = np.array([-t_diff_dict[delay / ms] / period_unperturbed for delay in inhibitory_delay])
  237. optimal_inhibitory_delay_idx = np.argmax(phase_diff)
  238. optimal_inhibitory_delay = list(t_diff_dict.keys())[optimal_inhibitory_delay_idx] # assumes a single maximum
  239. optimal_inhibitory_delay_phase = phases[optimal_inhibitory_delay_idx] # assumes a single maximum
  240. plt.figure()
  241. plt.plot(phases, phase_diff)
  242. plt.vlines(optimal_inhibitory_delay, -1, 1)
  243. plt.title("Phase response curve")
  244. plt.ylim(-1, 1)
  245. plt.figure()
  246. if optimal_inhibitory_delay_idx == 0:
  247. delta_i = phases
  248. else:
  249. delta_i = phases[:-optimal_inhibitory_delay_idx]
  250. print(delta_i.shape)
  251. delta_iplus1 = delta_i + 1 + phase_diff[optimal_inhibitory_delay_idx:] - phase_diff[optimal_inhibitory_delay_idx]
  252. plt.plot(delta_i, delta_iplus1)
  253. plt.plot(delta_i, delta_i, '--')
  254. plt.hlines(1, 0, 1)
  255. plt.xlabel("Delta i")
  256. plt.ylabel("Delta i+1")
  257. plt.title("Iterative map")
  258. plt.ylim(0, 1.2)
  259. plt.xlim(0, 1)
  260. # plot_mean_response(t_diff_dict)
  261. # plot_phase_response_curve(t_diff_dict, t_isi)
  262. # plot_spiking(spike_trains_dict)
  263. plt.show()