123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- import random
- import matplotlib.pyplot as plt
- import numpy as np
- from brian2.units import *
- import scripts.models as modellib
- from scripts.ring_network.head_direction import get_phase_difference, get_head_direction_input, \
- get_half_width, ex_in_network
# Neuron and synapse models
# Excitatory cells: Hodgkin-Huxley equations with synaptic conductance, extended by
# modellib.eqs_ih (presumably an h-current — TODO confirm against scripts.models).
excitatory_eqs = modellib.hodgkin_huxley_eqs_with_synaptic_conductance + modellib.eqs_ih
# Copy before updating: updating modellib.hodgkin_huxley_params in place would silently
# change the parameters seen by every other user of scripts.models.
excitatory_params = dict(modellib.hodgkin_huxley_params)
excitatory_params.update(modellib.ih_params)
# E_i is presumably the inhibitory reversal potential used by the synaptic conductance term.
excitatory_params.update({"E_i": -80 * mV})
# ghbar = 0 nS presumably switches the I_h conductance off — TODO confirm.
excitatory_params['ghbar'] = 0. * nS
# LIF interneurons: v relaxes toward the constant drive u_ext with time constant tau,
# fires when it exceeds v_threshold and is reset to v_reset. Without further input v
# settles at u_ext_const (-50 mV), i.e. below the -40 mV threshold.
lif_interneuron_eqs = (
    "\n"
    "dv/dt =1.0/tau* (-v + u_ext) :volt (unless refractory)\n"
    "u_ext = u_ext_const : volt\n"
)
lif_interneuron_params = dict(
    tau=7 * ms,
    v_threshold=-40 * mV,
    v_reset=-60 * mV,
    tau_refractory=0.0 * ms,  # zero-length refractory period
    u_ext_const=-50 * mV,
)
# Keyword options for the interneuron group (threshold/reset/refractory expressions and
# the integration method).
lif_interneuron_options = dict(
    threshold="v>v_threshold",
    reset="v=v_reset",
    refractory="tau_refractory",
    method='euler',
)
# Synapse models. E->I synapses use modellib's delta-synapse definitions, I->E synapses
# use its exponential-synapse definitions with tau_syn overridden to 2 ms.
ei_synapse_model = modellib.delta_synapse_model
ei_synapse_on_pre = modellib.delta_synapse_on_pre
ei_synapse_param = modellib.delta_synapse_param
ie_synapse_model = modellib.exponential_synapse
ie_synapse_on_pre = modellib.exponential_synapse_on_pre
# Copy before overriding tau_syn so the change does not leak into the shared
# modellib.exponential_synapse_params dict.
ie_synapse_param = dict(modellib.exponential_synapse_params)
ie_synapse_param["tau_syn"] = 2 * ms
# Binary connectivity
# Connectivity styles compared: uniform (an interneuron receives input from all directions
# and projects back to those same directions) and two-population connectivity (an
# interneuron mediates competition between populations separated by a given angular
# distance). A clustered scheme and an unconnected control are added further below.
N_E = 1000  # number of excitatory neurons
N_I = 100  # number of inhibitory interneurons
# Number of excitatory neurons each interneuron inhibits; the E->I weights reuse the same
# pattern, so it is also the number of excitatory neurons driving each interneuron.
excitatory_neurons_per_interneuron = 100
### Uniform
def distribute_inhibition_uniformly(number_of_excitatory_neurons, number_of_inhibitory_neurons,
                                    number_of_target_excitatory_neurons):
    """Assign each interneuron a random set of excitatory targets, spreading them evenly.

    Targets are drawn without replacement from a pool of excitatory neurons not yet used
    in the current round. When the pool can no longer supply a full set, its leftover
    neurons are taken first and a new round starts from the full population (minus the
    leftovers), so every excitatory neuron is targeted once per round.

    :param number_of_excitatory_neurons: size of the excitatory population (ids 0..N-1)
    :param number_of_inhibitory_neurons: number of interneurons to assign targets to
    :param number_of_target_excitatory_neurons: number of targets per interneuron
    :return: list with one entry per interneuron, each a list of distinct excitatory ids
    :raises ValueError: if more targets are requested than excitatory neurons exist
    """
    if number_of_target_excitatory_neurons > number_of_excitatory_neurons:
        raise ValueError(
            "cannot pick {} distinct targets from {} excitatory neurons".format(
                number_of_target_excitatory_neurons, number_of_excitatory_neurons))
    all_neurons = set(range(number_of_excitatory_neurons))
    set_of_available_neurons = set(all_neurons)
    target_ids = []
    for _ in range(number_of_inhibitory_neurons):
        if len(set_of_available_neurons) > number_of_target_excitatory_neurons:
            # random.sample() no longer accepts sets (TypeError since Python 3.11);
            # sorting also gives a reproducible population order.
            ids = random.sample(sorted(set_of_available_neurons),
                                k=number_of_target_excitatory_neurons)
            set_of_available_neurons -= set(ids)
        else:
            # Finish the current round with the leftovers, then start a new round.
            old_ids = sorted(set_of_available_neurons)
            new_ids = random.sample(sorted(all_neurons - set(old_ids)),
                                    k=number_of_target_excitatory_neurons - len(old_ids))
            # In the new round only new_ids have been used so far.
            set_of_available_neurons = all_neurons - set(new_ids)
            ids = old_ids + new_ids
        target_ids.append(ids)
    return target_ids
# Uniform I->E weights (shape N_I x N_E): interneuron `row` inhibits exactly the excitatory
# neurons assigned to it by distribute_inhibition_uniformly.
list_of_excitatory_target_per_interneuron = distribute_inhibition_uniformly(
    N_E, N_I, excitatory_neurons_per_interneuron)
inhibitory_synapse_strength = 15 * nS
ie_uniform = np.zeros((N_I, N_E)) * nS
for row, targets in enumerate(list_of_excitatory_target_per_interneuron):
    ie_uniform[row, targets] = inhibitory_synapse_strength
# E->I weights (shape N_E x N_I) mirror the I->E pattern. np.where presumably returns a
# unitless array here, hence the explicit `* volt` — TODO confirm with brian2.
excitatory_synapse_strength = 0.5 * mV
ei_uniform = np.where(ie_uniform > 0 * nS,
                      excitatory_synapse_strength,
                      0 * mV).T * volt
### mediate direct competition between different tunings
# Preferred directions on the ring. endpoint=False: -pi and pi are the same direction, so
# including both would give two neurons identical tuning and make the spacing differ from
# the 2*pi/N that angular_width below assumes.
ex_angles = np.linspace(-np.pi, np.pi, N_E, endpoint=False)
in_angles = np.linspace(-np.pi, np.pi, N_I, endpoint=False)
# Entry [e, i] = get_phase_difference(ex_angles[e] - in_angles[i]), shape (N_E, N_I);
# presumably wrapped to (-pi, pi] — TODO confirm.
ex_in_differences = get_phase_difference(ex_angles[:, np.newaxis] - in_angles[np.newaxis, :])
# Same values, shape (N_I, N_E): one row per interneuron.
in_ex_differences = ex_in_differences.T
angle_between_competing_populations = 100.0 / 180.0 * np.pi
# Each of the two arms spans angular_width, i.e. about excitatory_neurons_per_interneuron / 2
# neurons, so an interneuron gets the same total number of targets as in the uniform scheme.
angular_width = 2 * np.pi / N_E * excitatory_neurons_per_interneuron / 2.0
arm_centre_offset = angle_between_competing_populations / 2.0
half_arm_width = angular_width / 2.0
ie_angular = np.zeros((N_I, N_E)) * nS
# Arms are centred at -/+ arm_centre_offset and are symmetric (+/- half_arm_width) around it.
in_left_arm = np.logical_and(in_ex_differences > -arm_centre_offset - half_arm_width,
                             in_ex_differences < -arm_centre_offset + half_arm_width)
in_right_arm = np.logical_and(in_ex_differences > arm_centre_offset - half_arm_width,
                              in_ex_differences < arm_centre_offset + half_arm_width)
ie_angular[np.logical_or(in_left_arm, in_right_arm)] = inhibitory_synapse_strength
# E->I weights mirror the I->E pattern (shape N_E x N_I).
ei_angular = np.where(ie_angular > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
## Clustered connectivity
# Each interneuron inhibits every excitatory neuron whose angular difference to it lies in
# [-cluster_width/2, cluster_width/2), plus randomly chosen excitatory neurons outside that
# cluster until it has excitatory_neurons_per_interneuron targets in total.
ie_clustered = np.zeros((N_I, N_E)) * nS
cluster_width = 2 * np.pi / 360 * 20
# Nominal cluster size N_E / N_I. It is NOT used to size the random top-up: with N_E = 1000
# a 20° cluster holds about 55 neurons rather than 10, so actual membership is counted.
excitatory_neurons_per_cluster = int(N_E / N_I)
in_cluster = np.logical_and(in_ex_differences >= -cluster_width / 2.0,
                            in_ex_differences < cluster_width / 2.0)
ie_clustered[in_cluster] = inhibitory_synapse_strength
ex_neurons_indices = set(range(N_E))
for interneuron_idx in range(N_I):
    non_available_indices = set(np.flatnonzero(in_cluster[interneuron_idx]).tolist())
    # Never negative: a cluster already at or above the target size gets no extra synapses.
    number_of_extra_targets = max(
        excitatory_neurons_per_interneuron - len(non_available_indices), 0)
    # random.sample() rejects sets since Python 3.11; sort for a reproducible population order.
    connected_indices = random.sample(sorted(ex_neurons_indices - non_available_indices),
                                      k=number_of_extra_targets)
    ie_clustered[interneuron_idx, connected_indices] = inhibitory_synapse_strength
# E->I weights mirror the I->E pattern (shape N_E x N_I).
ei_clustered = np.where(ie_clustered > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
### No synapses
# Unconnected control: all-zero weight matrices with the same shapes and units as the others.
no_conn_ie = np.zeros((N_I, N_E)) * nS
no_conn_ei = np.zeros((N_E, N_I)) * mV
# One label per connectivity scheme, in the same order as `connectivities`.
# Both angles use one decimal; the previous "{:1f}" printed six decimals (width 1).
connectivity_labels = [
    "No synapse",
    "Uniform",
    "Angular {:.1f}°".format(angle_between_competing_populations / np.pi * 180),
    "Cluster {:.1f}°".format(cluster_width / np.pi * 180),
]
# (E->I weights, I->E weights) per scheme.
connectivities = [
    (no_conn_ei, no_conn_ie),
    (ei_uniform, ie_uniform),
    (ei_angular, ie_angular),
    (ei_clustered, ie_clustered),
]
### Plot connectivities
# One grayscale image of the I->E weight matrix (rows: interneurons, columns: excitatory
# neurons) per connectivity scheme, normalised to its strongest synapse.
fig, axes = plt.subplots(len(connectivity_labels), 1, sharex=True)
for ax, label, (_, ie) in zip(axes, connectivity_labels, connectivities):
    if np.count_nonzero(ie):
        normalized_ie = ie / np.max(ie)
    else:
        # All-zero matrix ("No synapse"): ie / max(ie) would be 0/0 = NaN everywhere.
        normalized_ie = np.zeros(ie.shape)
    ax.imshow(normalized_ie, vmin=0, vmax=1, cmap="gray")
    ax.set_title(label)
    ax.set_ylabel("I")
axes[-1].set_xlabel("E")
# Prepare nets: one excitatory-inhibitory network per connectivity scheme. All are built
# from the same neuron and synapse models and random_seed=2, so presumably only the wiring
# differs between them.
nets = [
    ex_in_network(N_E, N_I,
                  excitatory_eqs, excitatory_params,
                  lif_interneuron_eqs, lif_interneuron_params, lif_interneuron_options,
                  ei_synapse_model, ei_synapse_on_pre, ei_synapse_param, ei_weights,
                  ie_synapse_model, ie_synapse_on_pre, ie_synapse_param, ie_weights,
                  random_seed=2)
    for ei_weights, ie_weights in connectivities
]
# Head direction input: the direction profile returned by get_head_direction_input
# (configured with first_peak_phase and input_sharpness) is evaluated at every excitatory
# neuron's preferred angle, scaled by max_head_direction_input_amplitude and added to a
# constant baseline.
ex_input_baseline = 0.0 * nA
first_peak_phase = 0
input_sharpness = 1
max_head_direction_input_amplitude = 0.5 * nA
direction_input = get_head_direction_input(first_peak_phase, input_sharpness)
input_to_excitatory_population = (
    ex_input_baseline
    + max_head_direction_input_amplitude * direction_input(ex_angles)
)
# Width of the input profile; used as the reference in the tuning-plot legend.
half_width_input = get_half_width(ex_angles, input_to_excitatory_population)
# Run simulation: every network receives the same head-direction current and is simulated
# for skip + length. The first `skip` is discarded later when estimating rates.
skip = 100 * ms
length = 1100 * ms
duration = skip + length
for network in nets:
    network["excitatory_neurons"].I = input_to_excitatory_population
    network.run(duration)
# Get tuning of the network: firing rate of every excitatory neuron (ignoring the initial
# `skip` transient) against its preferred angle, one curve per connectivity scheme, with
# the input profile overlaid on a twin axis.
fig, ax = plt.subplots(1, 1)
output_handles = []
output_widths = []
for net in nets:
    spike_trains = net["excitatory_spike_monitor"].spike_trains()
    # Inter-spike intervals per neuron, using only spikes after the transient.
    isis = [np.ediff1d(np.extract(spike_times / ms > skip / ms, spike_times / ms)) * ms
            for spike_times in spike_trains.values()]
    # Rate = 1 / mean ISI; a neuron with fewer than two spikes after `skip` has no ISI
    # and is assigned rate 0.
    rates = np.array([1.0 / np.mean(isi / ms) if isi.size else 0 for isi in isis]) * khertz
    half_width_output = get_half_width(ex_angles, rates)
    output_widths.append(half_width_output / np.pi * 180)
    output_handles.append(ax.plot(ex_angles / np.pi * 180, rates / hertz))
ax.set_xlabel("Angle")
ax.set_ylabel("f(Hz)")
ax_input = ax.twinx()
input_label = "Input width {:.1f}°".format(half_width_input / np.pi * 180)
input_handle = ax_input.plot(ex_angles / np.pi * 180, input_to_excitatory_population / nA, 'k--')
handles = [input_handle[0]] + [output_handle[0] for output_handle in output_handles]
# Loop variable renamed from `type` so the builtin is not shadowed.
labels = [input_label] + ["{:s} output width {:.1f}°".format(connectivity_label, width)
                          for connectivity_label, width in zip(connectivity_labels, output_widths)]
ax.legend(handles, labels)
ax_input.set_ylabel("Input (nA)")
# Spike trains: for each connectivity scheme an excitatory raster and an inhibitory raster,
# followed by a thin spacer row (no spacer after the last pair).
n_rows = len(connectivity_labels) * 3 - 1
height_ratios = [0.1 if (row + 1) % 3 == 0 else 1 for row in range(n_rows)]
fig, axes = plt.subplots(n_rows, 1, sharex=True, gridspec_kw={"height_ratios": height_ratios})
# (excitatory axis, inhibitory axis) pairs: rows (0, 1), (3, 4), (6, 7), ...
axs = [(axes[row], axes[row + 1]) for row in range(0, n_rows, 3)]
for net_label, net, (ax_ex, ax_in) in zip(connectivity_labels, nets, axs):
    # Left axis shows the angle range; the raster is drawn on a twin axis by neuron index.
    ax_ex.set_ylim(-180, 180)
    ax_ex.set_ylabel("E angles")
    ax_ex_raster = ax_ex.twinx()
    for neuron_idx, spike_times in net["excitatory_spike_monitor"].spike_trains().items():
        # Neuron i is drawn at raster height i + 1.
        ax_ex_raster.plot(spike_times / ms, np.full(spike_times.shape, neuron_idx + 1.0), 'r|')
    ax_ex_raster.set_xlim(0, duration / ms)
    ax_ex_raster.set_ylim(0, N_E + 1)
    ax_ex.set_title(net_label)
    ax_in.set_ylim(-180, 180)
    ax_in.set_ylabel("I angles")
    ax_in_raster = ax_in.twinx()
    for neuron_idx, spike_times in net["inhibitory_spike_monitor"].spike_trains().items():
        ax_in_raster.plot(spike_times / ms, np.full(spike_times.shape, neuron_idx + 1.0), 'b|')
    # Same index range convention as the excitatory raster.
    ax_in_raster.set_ylim(0, N_I + 1)
axes[-1].set_xlabel("Time (ms)")
# Hide every spacer row (2, 5, 8, ...). The old bound len(connectivity_labels) - 1 only
# reached the first spacer.
for spacer_ax in axes[2:n_rows:3]:
    spacer_ax.set_frame_on(False)
    spacer_ax.get_xaxis().set_visible(False)
    spacer_ax.get_yaxis().set_visible(False)
plt.show()
|