ring_network_reciprocal_and_binary.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import random
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. from brian2.units import *
  5. import scripts.models as modellib
  6. from scripts.ring_network.head_direction import get_phase_difference, get_head_direction_input, \
  7. get_half_width, ex_in_network
  8. '''
  9. Neuron and synapse models
  10. '''
  11. excitatory_eqs = modellib.hodgkin_huxley_eqs_with_synaptic_conductance + modellib.eqs_ih
  12. excitatory_params = modellib.hodgkin_huxley_params
  13. excitatory_params.update(modellib.ih_params)
  14. excitatory_params.update({"E_i": -80 * mV})
  15. excitatory_params['ghbar'] = 0. * nS
  16. lif_interneuron_eqs = """
  17. dv/dt =1.0/tau* (-v + u_ext) :volt (unless refractory)
  18. u_ext = u_ext_const : volt
  19. """
  20. lif_interneuron_params = {
  21. "tau": 7 * ms,
  22. "v_threshold": -40 * mV,
  23. "v_reset": -60 * mV,
  24. "tau_refractory": 0.0 * ms,
  25. "u_ext_const": -50 * mV
  26. }
  27. lif_interneuron_options = {
  28. "threshold": "v>v_threshold",
  29. "reset": "v=v_reset",
  30. "refractory": "tau_refractory",
  31. 'method': 'euler'
  32. }
  33. ei_synapse_model = modellib.delta_synapse_model
  34. ei_synapse_on_pre = modellib.delta_synapse_on_pre
  35. ei_synapse_param = modellib.delta_synapse_param
  36. ie_synapse_model = modellib.exponential_synapse
  37. ie_synapse_on_pre = modellib.exponential_synapse_on_pre
  38. ie_synapse_param = modellib.exponential_synapse_params
  39. ie_synapse_param["tau_syn"] = 2 * ms
  40. '''
  41. Binary connectvity
  42. I will compare two connectivity styles: uniform (an interneuron receives input from all direction and projects back
  43. to those same directions) and a two-pop connectivity (where an interneuron mediates competition between populations
  44. separated by a given angular distance
  45. '''
  46. N_E = 1000
  47. N_I = 100
  48. excitatory_neurons_per_interneuron = 100
  49. ### Uniform
  50. def distribute_inhibition_uniformly(number_of_excitatory_neurons, number_of_inhibitory_neurons,
  51. number_of_target_excitatory_neurons):
  52. set_of_available_neurons = set(range(number_of_excitatory_neurons))
  53. target_ids = []
  54. for in_idx in range(number_of_inhibitory_neurons):
  55. if len(set_of_available_neurons) > number_of_target_excitatory_neurons:
  56. ids = random.sample(set_of_available_neurons, k=number_of_target_excitatory_neurons)
  57. set_of_available_neurons -= set(ids)
  58. else:
  59. old_ids = list(set_of_available_neurons)
  60. set_of_available_neurons = set(range(N_E))
  61. new_ids = random.sample(set_of_available_neurons - set(old_ids),
  62. k=number_of_target_excitatory_neurons - len(old_ids))
  63. set_of_available_neurons -= set(new_ids)
  64. ids = old_ids + new_ids
  65. target_ids.append((ids))
  66. return target_ids
  67. list_of_excitatory_target_per_interneuron = distribute_inhibition_uniformly(N_E, N_I,
  68. excitatory_neurons_per_interneuron)
  69. inhibitory_synapse_strength = 15 * nS
  70. ie_uniform = np.zeros((N_I, N_E)) * nS
  71. for interneuron_idx, connected_excitatory_idxs in enumerate(list_of_excitatory_target_per_interneuron):
  72. ie_uniform[interneuron_idx, connected_excitatory_idxs] = inhibitory_synapse_strength
  73. excitatory_synapse_strength = 0.5 * mV
  74. ei_uniform = np.where(ie_uniform > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  75. ### mediate direct competition between different tunings
  76. ex_angles = np.linspace(-np.pi, np.pi, N_E)
  77. in_angles = np.linspace(-np.pi, np.pi, N_I)
  78. ex_in_differences = get_phase_difference(np.tile(ex_angles, (N_I, 1)).T - np.tile(in_angles, (N_E, 1)))
  79. in_ex_differences = ex_in_differences.T
  80. angle_between_competing_populations = 100.0 / 180.0 * np.pi
  81. angular_width = 2 * np.pi / N_E * excitatory_neurons_per_interneuron / 2.0
  82. ie_angular = np.zeros((N_I, N_E)) * nS
  83. in_left_arm = np.logical_and(in_ex_differences < -
  84. angle_between_competing_populations / 2.0 + angular_width / 2.0,
  85. in_ex_differences >
  86. - angle_between_competing_populations / 2.0 - angular_width)
  87. in_right_arm = np.logical_and(in_ex_differences >
  88. angle_between_competing_populations / 2.0 - angular_width / 2.0,
  89. in_ex_differences <
  90. angle_between_competing_populations / 2.0 + angular_width)
  91. ie_angular[np.logical_or(in_left_arm, in_right_arm)] = inhibitory_synapse_strength
  92. ei_angular = np.where(ie_angular > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  93. ## Clustered connectivity
  94. ie_clustered = np.zeros((N_I, N_E)) * nS
  95. cluster_width = 2 * np.pi / 360 * 20
  96. excitatory_neurons_per_cluster = int(N_E / N_I)
  97. in_cluster = np.logical_and(in_ex_differences >= -cluster_width / 2.0, in_ex_differences < cluster_width / 2.0)
  98. ie_clustered[in_cluster] = inhibitory_synapse_strength
  99. ex_neurons_indices = set(range(N_E))
  100. for excitatory_connections in ie_clustered:
  101. non_available_indices = set(np.nonzero(excitatory_connections / siemens)[0].tolist())
  102. connected_indices = random.sample(ex_neurons_indices - non_available_indices,
  103. k=excitatory_neurons_per_interneuron - excitatory_neurons_per_cluster)
  104. excitatory_connections[connected_indices] = inhibitory_synapse_strength
  105. ei_clustered = np.where(ie_clustered > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  106. ### No synapses
  107. no_conn_ie = np.zeros((N_I, N_E)) * nS
  108. no_conn_ei = np.zeros((N_E, N_I)) * mV
  109. connectivity_labels = ["No synapse", "Uniform", "Angular {:.1f}°".format(
  110. angle_between_competing_populations / np.pi * 180), "Cluster {:1f}°".format(cluster_width / np.pi * 180)]
  111. connectivities = [(no_conn_ei, no_conn_ie), (ei_uniform, ie_uniform), (ei_angular, ie_angular), (ei_clustered,
  112. ie_clustered)]
  113. ### Plot connectivities
  114. fig, axes = plt.subplots(len(connectivity_labels), 1, sharex=True)
  115. for ax, label, connectivity in zip(axes, connectivity_labels, connectivities):
  116. ei, ie = connectivity
  117. ax.imshow(ie / np.max(ie), vmin=0, vmax=1, cmap="gray")
  118. ax.set_title(label)
  119. ax.set_ylabel("I")
  120. axes[-1].set_xlabel("E")
  121. '''
  122. Prepare nets
  123. '''
  124. nets = []
  125. for ei_weights, ie_weights in connectivities:
  126. net = ex_in_network(N_E, N_I, excitatory_eqs, excitatory_params, lif_interneuron_eqs,
  127. lif_interneuron_params,
  128. lif_interneuron_options, ei_synapse_model, ei_synapse_on_pre,
  129. ei_synapse_param,
  130. ei_weights, ie_synapse_model, ie_synapse_on_pre,
  131. ie_synapse_param, ie_weights, random_seed=2)
  132. nets.append(net)
  133. '''
  134. Head direction input
  135. '''
  136. ex_input_baseline = 0.0 * nA
  137. first_peak_phase = 0
  138. input_sharpness = 1
  139. direction_input = get_head_direction_input(first_peak_phase, input_sharpness)
  140. max_head_direction_input_amplitude = 0.5 * nA
  141. input_to_excitatory_population = ex_input_baseline + max_head_direction_input_amplitude * direction_input(ex_angles)
  142. half_width_input = get_half_width(ex_angles, input_to_excitatory_population)
  143. '''
  144. Run simulation
  145. '''
  146. skip = 100 * ms
  147. length = 1100 * ms
  148. duration = skip + length
  149. for net in nets:
  150. excitatory_neurons = net["excitatory_neurons"]
  151. excitatory_neurons.I = input_to_excitatory_population
  152. net.run(duration)
  153. '''
  154. Get tuning of the network
  155. '''
  156. fig, ax = plt.subplots(1, 1)
  157. output_handles = []
  158. output_widths = []
  159. for net_label, net in zip(connectivity_labels, nets):
  160. excitatory_spike_monitor = net["excitatory_spike_monitor"]
  161. spike_trains = excitatory_spike_monitor.spike_trains()
  162. isis = [np.ediff1d(np.extract(spike_times / ms > skip / ms, spike_times / ms)) * ms for spike_times in
  163. spike_trains.values()]
  164. rates = np.array([1.0 / np.mean(isi / ms) if isi.shape[0] != 0 else 0 for isi in isis]) * khertz
  165. half_width_output = get_half_width(ex_angles, rates)
  166. output_widths.append(half_width_output / np.pi * 180)
  167. output_handles.append(ax.plot(ex_angles / np.pi * 180, rates / hertz))
  168. ax.set_xlabel("Angle")
  169. ax.set_ylabel("f(Hz)")
  170. ax_input = ax.twinx()
  171. input_label = "Input width {" \
  172. ":.1f}°".format(half_width_input / np.pi * 180)
  173. input_handle = ax_input.plot(ex_angles / np.pi * 180, input_to_excitatory_population / nA, 'k--')
  174. handles = [input_handle[0]] + [output_handle[0] for output_handle in output_handles]
  175. labels = [input_label] + ["{:s} output width {:.1f}°".format(type, width) for type, width in zip(connectivity_labels,
  176. output_widths)]
  177. ax.legend(handles, labels)
  178. ax_input.set_ylabel("Input (nA)")
  179. '''
  180. Spike trains
  181. '''
  182. height_ratios = list(map(lambda idx: 0.1 if (idx + 1) % 3 == 0 else 1, [idx for idx in range( len(
  183. connectivity_labels) * 3 - 1)]))
  184. fig, axes = plt.subplots(len(connectivity_labels) * 3 - 1, 1, sharex=True, gridspec_kw={"height_ratios": height_ratios})
  185. axs = [(axes[idx], axes[idx + 1]) for idx in range(0, len(connectivity_labels) * 3 - 1, 3)]
  186. # (axes[0], axes[1]), (axes[3], axes[4]), (axes[6], axes[7])]
  187. for net_label, net, raster_axes in zip(connectivity_labels, nets, axs):
  188. ax_ex = raster_axes[0]
  189. ax_in = raster_axes[1]
  190. ax_ex.set_ylim(-180, 180)
  191. ax_ex.set_ylabel("E angles")
  192. ax_ex_raster = ax_ex.twinx()
  193. ex_spike_trains = net["excitatory_spike_monitor"].spike_trains()
  194. for neuron_idx, spike_times in ex_spike_trains.items():
  195. ax_ex_raster.plot(spike_times / ms, neuron_idx + 1 * np.ones(spike_times.shape), 'r|')
  196. ax_ex_raster.set_xlim(0, duration / ms)
  197. ax_ex_raster.set_ylim(0, N_E + 1)
  198. ax_ex.set_title(net_label)
  199. ax_in.set_ylim(-180, 180)
  200. ax_in.set_ylabel("I angles")
  201. ax_in_raster = ax_in.twinx()
  202. in_spike_trains = net["inhibitory_spike_monitor"].spike_trains()
  203. for neuron_idx, spike_times in in_spike_trains.items():
  204. ax_in_raster.plot(spike_times / ms, neuron_idx + 1 * np.ones(spike_times.shape), 'b|')
  205. axes[-1].set_xlabel("Time (ms)")
  206. for ax in [axes[idx] for idx in range(2, len(connectivity_labels) - 1, 3)]:
  207. ax.set_frame_on(False)
  208. ax.axes.get_xaxis().set_visible(False)
  209. ax.axes.get_yaxis().set_visible(False)
  210. plt.show()