hdi_optimization_via_syn_strength.py 15 KB


  1. import matplotlib.pyplot as plt
  2. import noise
  3. import numpy as np
  4. from brian2.units import *
  5. from scripts.spatial_maps.orientation_map import OrientationMap
  6. from scripts.interneuron_placement import create_interneuron_sheet_entropy_max_orientation
  7. import scripts.models as modellib
  8. from scripts.ring_network.head_direction import get_head_direction_input, \
  9. ex_in_network
  10. from tqdm import tqdm
  11. from scripts.interneuron_placement import create_grid_of_excitatory_neurons, \
  12. create_interneuron_sheet_by_repulsive_force, get_excitatory_neurons_in_inhibitory_axonal_clouds
  13. import multiprocessing
  14. import itertools
  # Sheet geometry and population sizes shared by every simulation below.
  trials_per_scale = 1  # not referenced elsewhere in this script
  N_E = 900  # excitatory neurons, laid out on a square grid
  N_I = 90  # inhibitory interneurons
  sheet_x, sheet_y = 450 * um, 450 * um
  # Axes of the elliptical inhibitory axonal clouds.
  inhibitory_axon_long_axis, inhibitory_axon_short_axis = 100 * um, 25 * um
  number_of_excitatory_neurons_per_row = int(np.sqrt(N_E))  # side length of the excitatory grid
  # Neuron and synapse models.
  excitatory_eqs = modellib.hodgkin_huxley_eqs_with_synaptic_conductance + modellib.eqs_ih
  # Copy the library parameter dicts before customizing them: updating them in
  # place would silently change modellib's module-level defaults for any other
  # code in the same process that uses modellib.
  excitatory_params = dict(modellib.hodgkin_huxley_params)
  excitatory_params.update(modellib.ih_params)
  excitatory_params.update({"E_i": -80 * mV})
  excitatory_params['ghbar'] = 0. * nS  # presumably switches off the I_h conductance - confirm in modellib

  # Leaky integrate-and-fire interneurons relaxing towards a constant potential u_ext.
  lif_interneuron_eqs = (
      "\n"
      "dv/dt =1.0/tau* (-v + u_ext) :volt (unless refractory)\n"
      "u_ext = u_ext_const : volt\n"
  )
  lif_interneuron_params = {
      "tau": 7 * ms,
      "v_threshold": -40 * mV,
      "v_reset": -60 * mV,
      "tau_refractory": 0.0 * ms,
      "u_ext_const": -50 * mV,
  }
  lif_interneuron_options = {
      "threshold": "v>v_threshold",
      "reset": "v=v_reset",
      "refractory": "tau_refractory",
      'method': 'euler',
  }

  # E->I: modellib's delta synapse; I->E: modellib's exponential synapse with tau_syn = 2 ms.
  ei_synapse_model = modellib.delta_synapse_model
  ei_synapse_on_pre = modellib.delta_synapse_on_pre
  ei_synapse_param = modellib.delta_synapse_param
  ie_synapse_model = modellib.exponential_synapse
  ie_synapse_on_pre = modellib.exponential_synapse_on_pre
  ie_synapse_param = dict(modellib.exponential_synapse_params)
  ie_synapse_param["tau_syn"] = 2 * ms
  # Tuning maps and interneuron placement strategy used by the simulation.
  # Alternatives: tuning_label = "Perlin", optimization_label = "Repulsive".
  tuning_label = "Orientation map"
  optimization_label = "Entropy Optimization"
  # Per-trial accumulators; not used by the active code of this script.
  ellipse_trial_sharpening_list, circle_trial_sharpening_list, no_conn_trial_sharpening_list = [], [], []
  def _weights_from_axonal_clouds(ex_positions, axonal_clouds):
      """Build reciprocal E->I and I->E weight matrices from axonal-cloud membership.

      Every excitatory neuron that ``get_excitatory_neurons_in_inhibitory_axonal_clouds``
      assigns to an interneuron's cloud receives a 30 nS inhibitory synapse from
      that interneuron, and projects back onto it with a 1 mV excitatory synapse.

      Returns
      -------
      tuple
          ``(ex_in_weights, in_ex_weights)`` with shapes (N_E, N_I) and (N_I, N_E).
      """
      inhibitory_synapse_strength = 30 * nS
      excitatory_synapse_strength = 1 * mV
      connections = get_excitatory_neurons_in_inhibitory_axonal_clouds(ex_positions, axonal_clouds)
      in_ex_weights = np.zeros((N_I, N_E)) * nS
      for interneuron_idx, connected_excitatory_idxs in enumerate(connections):
          in_ex_weights[interneuron_idx, connected_excitatory_idxs] = inhibitory_synapse_strength
      # NOTE(review): the trailing ``* volt`` implies np.where yields unitless SI
      # values here; confirm against the installed brian2 version.
      ex_in_weights = np.where(in_ex_weights > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
      return ex_in_weights, in_ex_weights


  def get_fwhm_for_corr_len_and_seed(corr_len, seed, tuning_center):
      """Simulate the E-I network for one tuning map and one head-direction input.

      Despite its name, no FWHM is computed: the raw excitatory firing rates are
      returned so that head-direction indices can be computed afterwards.

      Three networks receive identical input: no E-I synapses, interneurons with
      elliptical axonal clouds, and interneurons with circular clouds whose
      radius is the geometric mean of the two ellipse axes.

      Parameters
      ----------
      corr_len : number
          Correlation length (in map pixels) of the orientation map.
      seed : int
          Random seed of the orientation map.
      tuning_center : float
          Head direction (radians) at which the excitatory input peaks.

      Returns
      -------
      tuple
          ``(ex_tunings, excitatory_rates)``, where ``excitatory_rates`` is a list
          of three rate arrays ordered "No synapse", "Ellipsoid", "Circle".
          If no precomputed orientation map can be loaded, ``(-1, -1, -1)`` is
          returned instead. NOTE(review): callers unpacking two values fail on
          this sentinel.

      Raises
      ------
      ValueError
          If ``tuning_label`` or ``optimization_label`` is not supported.
      """
      print(corr_len, tuning_center, seed)

      # Tuning map
      if tuning_label == "Perlin":  # TODO: How to handle scale in Perlin
          tuning_map = lambda x, y: noise.pnoise2(x / 100.0, y / 100.0, octaves=2) * np.pi
      elif tuning_label == "Orientation map":
          orientation_map = OrientationMap(number_of_excitatory_neurons_per_row + 1,
                                           number_of_excitatory_neurons_per_row + 1,
                                           corr_len, sheet_x / um, sheet_y / um, seed)
          try:
              orientation_map.load_orientation_map()
          except Exception:
              print(
                  'No map yet with {}x{} pixels and {} pixel correllation length and {} seed'.format(
                      orientation_map.x_dim, orientation_map.y_dim, orientation_map.corr_len,
                      orientation_map.rnd_seed))
              return -1, -1, -1
          tuning_map = lambda x, y: orientation_map.tuning(x, y)
      else:
          raise ValueError("Unsupported tuning_label: {!r}".format(tuning_label))
      ex_positions, ex_tunings = create_grid_of_excitatory_neurons(sheet_x / um, sheet_y / um,
                                                                   number_of_excitatory_neurons_per_row,
                                                                   tuning_map)

      # Interneuron placement: elliptical clouds and circular control clouds.
      inhibitory_radial_axis = np.sqrt(inhibitory_axon_long_axis * inhibitory_axon_short_axis)
      if optimization_label == "Repulsive":
          inhibitory_axonal_clouds = create_interneuron_sheet_by_repulsive_force(
              N_I, inhibitory_axon_long_axis / um, inhibitory_axon_short_axis / um,
              sheet_x / um, sheet_y / um, random_seed=2, n_iterations=1000)
          inhibitory_axonal_circles = create_interneuron_sheet_by_repulsive_force(
              N_I, inhibitory_radial_axis / um, inhibitory_radial_axis / um,
              sheet_x / um, sheet_y / um, random_seed=2, n_iterations=1000)
      elif optimization_label == "Entropy Optimization":
          # The second return value is not used here.
          inhibitory_axonal_clouds, _ellipse_entropy = create_interneuron_sheet_entropy_max_orientation(
              ex_positions, ex_tunings, N_I, inhibitory_axon_long_axis / um,
              inhibitory_axon_short_axis / um, sheet_x / um,
              sheet_y / um, trial_orientations=30)
          inhibitory_axonal_circles, _circle_entropy = create_interneuron_sheet_entropy_max_orientation(
              ex_positions, ex_tunings, N_I, inhibitory_radial_axis / um,
              inhibitory_radial_axis / um, sheet_x / um,
              sheet_y / um, trial_orientations=1)
      else:
          raise ValueError("Unsupported optimization_label: {!r}".format(optimization_label))

      # Connectivities
      ex_in_weights, in_ex_weights = _weights_from_axonal_clouds(ex_positions, inhibitory_axonal_clouds)
      ex_in_weights_circle, in_ex_weights_circle = _weights_from_axonal_clouds(ex_positions,
                                                                               inhibitory_axonal_circles)
      no_conn_ie = np.zeros((N_I, N_E)) * nS
      no_conn_ei = np.zeros((N_E, N_I)) * mV

      # Networks in the order "No synapse", "Ellipsoid", "Circle"; this order is
      # also the order of the returned rate arrays.
      connectivities = [(no_conn_ei, no_conn_ie), (ex_in_weights, in_ex_weights),
                        (ex_in_weights_circle, in_ex_weights_circle)]
      nets = [ex_in_network(N_E, N_I, excitatory_eqs, excitatory_params, lif_interneuron_eqs,
                            lif_interneuron_params,
                            lif_interneuron_options, ei_synapse_model, ei_synapse_on_pre,
                            ei_synapse_param,
                            ei_weights, ie_synapse_model, ie_synapse_on_pre,
                            ie_synapse_param, ie_weights, random_seed=2)
              for ei_weights, ie_weights in connectivities]

      # Head direction input peaking at ``tuning_center``.
      ex_input_baseline = 0.0 * nA
      input_sharpness = 1
      direction_input = get_head_direction_input(tuning_center, input_sharpness)
      max_head_direction_input_amplitude = 0.5 * nA
      input_to_excitatory_population = ex_input_baseline + max_head_direction_input_amplitude * direction_input(
          np.array(ex_tunings))

      # Run simulation; spikes before ``skip`` are excluded from the rates.
      skip = 100 * ms
      length = 1100 * ms
      duration = skip + length
      for net in nets:
          net["excitatory_neurons"].I = input_to_excitatory_population
          net.run(duration)

      def get_rates(spike_monitor, min_time=skip):
          """Return each neuron's rate as 1 / mean inter-spike interval (in kHz units).

          Only spikes strictly later than ``min_time`` are considered; neurons
          with fewer than two such spikes get a rate of 0.
          """
          spike_trains = spike_monitor.spike_trains()
          isis = [np.ediff1d(np.extract(spike_times / ms > min_time / ms, spike_times / ms)) * ms
                  for spike_times in spike_trains.values()]
          rates = np.array([1.0 / np.mean(isi / ms) if isi.shape[0] != 0 else 0 for isi in isis]) * khertz
          return rates

      excitatory_rates = [get_rates(net["excitatory_spike_monitor"]) for net in nets]
      return ex_tunings, excitatory_rates
  # Parameter sweep over correlation length x map seed x head direction.
  corr_len_range = range(0, 451, 15)
  corr_len_range_len = len(corr_len_range)
  tuning_range_len = 12
  tuning_range = np.linspace(-np.pi, np.pi, tuning_range_len, endpoint=False)
  seed_range = range(10)
  seed_range_len = len(seed_range)
  # itertools.product order (corr_len slowest, tuning center fastest) defines how
  # the flat result list ``data`` is indexed later on.
  pool_arguments = itertools.product(corr_len_range, seed_range, tuning_range)
  data_path = '../../simulations/2020_02_27_head_direction_index_over_noise_scale/data.npy'
  use_saved_array = True
  if not use_saved_array:
      # NOTE(review): this script has no ``if __name__ == "__main__"`` guard, so the
      # pool only works with the "fork" start method (default on Linux).
      with multiprocessing.Pool() as pool:
          data = pool.starmap(get_fwhm_for_corr_len_and_seed, list(pool_arguments))
      print(type(data))
      # Each entry is a ragged (ex_tunings, excitatory_rates) tuple; NumPy >= 1.24
      # refuses to build such arrays unless dtype=object is requested explicitly.
      np.save(data_path, np.array(data, dtype=object))
  else:
      # allow_pickle is required for the object array; only load files produced by this script.
      data = np.load(data_path, allow_pickle=True)
  # Result grids indexed by (correlation length, seed, tuning center).
  # np.array((a, b, c)) only stored the three sizes; np.zeros allocates the grid.
  # NOTE(review): these grids are not filled by the analysis in this script.
  no_conn_trial_hdi_array = np.zeros((corr_len_range_len, seed_range_len, tuning_range_len))
  ellipse_trial_hdi_array = np.zeros((corr_len_range_len, seed_range_len, tuning_range_len))
  circle_trial_hdi_array = np.zeros((corr_len_range_len, seed_range_len, tuning_range_len))
  # Head direction index (HDI) per excitatory neuron:
  #   HDI = |sum_t r_t * (cos t, sin t)| / sum_t r_t
  # over all tested head directions t, averaged over the population for every
  # (correlation length, seed) pair.
  # Slot of each connectivity in the per-trial ``excitatory_rates`` list
  # (simulation order: "No synapse", "Ellipsoid", "Circle").
  _RATE_INDEX = {"circle": 2, "ellipse": 1, "no_conn": 0}


  def _population_mean_hdi(tuning_vectors, rates_sum):
      """Return the population-mean HDI, ignoring neurons that never fired.

      A neuron whose summed rate is zero has an undefined HDI (0 / 0 = NaN); it is
      excluded so that a single silent neuron does not make the mean NaN.
      """
      with np.errstate(invalid="ignore", divide="ignore"):
          hdi_per_neuron = [np.linalg.norm(vec) / rate_sum for vec, rate_sum in zip(tuning_vectors, rates_sum)]
      return np.nanmean(hdi_per_neuron)


  circle_mean_hdi_overall = []
  no_conn_mean_hdi_overall = []
  ellipse_mean_hdi_overall = []
  # Debug trace: no-connection vector of neuron 0 for every trial with cl_id == 0 and s_id == 0.
  tuning_vectors_test = []
  _mean_hdi_overall = {
      "circle": circle_mean_hdi_overall,
      "ellipse": ellipse_mean_hdi_overall,
      "no_conn": no_conn_mean_hdi_overall,
  }
  for cl_id, corr_len in enumerate(tqdm(corr_len_range)):
      mean_hdi_per_corr_len = {label: [] for label in _RATE_INDEX}
      for s_id, seed in enumerate(seed_range):
          tuning_vector_sums = {label: np.zeros((N_E, 2)) for label in _RATE_INDEX}
          rates_sums = {label: np.zeros(N_E) for label in _RATE_INDEX}
          for t_id, tuning_center in enumerate(tuning_range):
              # ``data`` is ordered like itertools.product(corr_len, seed, tuning).
              total_id = t_id + tuning_range_len * s_id + tuning_range_len * seed_range_len * cl_id
              _ex_tunings, excitatory_rates = data[total_id]
              direction = np.array([np.cos(tuning_center), np.sin(tuning_center)])
              vectors_at_tuning = {}
              for label, rate_idx in _RATE_INDEX.items():
                  vectors_at_tuning[label] = np.array([direction * rate for rate in excitatory_rates[rate_idx]])
                  tuning_vector_sums[label] = tuning_vector_sums[label] + vectors_at_tuning[label]
                  rates_sums[label] += excitatory_rates[rate_idx]
              if cl_id == 0 and s_id == 0:
                  first_no_conn_vector = vectors_at_tuning["no_conn"][0]
                  print('rates: \n', excitatory_rates[0][0])
                  print('vectors: \n', first_no_conn_vector)
                  print('vec norm: \n', np.linalg.norm(first_no_conn_vector))
                  tuning_vectors_test.append(first_no_conn_vector)
          for label in _RATE_INDEX:
              mean_hdi_per_corr_len[label].append(
                  _population_mean_hdi(tuning_vector_sums[label], rates_sums[label]))
      for label, per_corr_len in mean_hdi_per_corr_len.items():
          _mean_hdi_overall[label].append(per_corr_len)

  plt.figure()
  tuning_vectors_test_array = np.array(tuning_vectors_test)
  plt.scatter(tuning_vectors_test_array[:, 0], tuning_vectors_test_array[:, 1])
  plt.show()
  def _hdi_mean_and_std(mean_hdi_overall):
      """Return per-correlation-length mean and standard deviation over seeds."""
      means = np.array([np.mean(hdis_over_seeds) for hdis_over_seeds in mean_hdi_overall])
      std_devs = np.array([np.std(hdis_over_seeds) for hdis_over_seeds in mean_hdi_overall])
      return means, std_devs


  # The "*_sharpening_*" names are historical; the values are head direction indices.
  ellipse_sharpening_mean, ellipse_sharpening_std_dev = _hdi_mean_and_std(ellipse_mean_hdi_overall)
  circle_sharpening_mean, circle_sharpening_std_dev = _hdi_mean_and_std(circle_mean_hdi_overall)
  no_conn_sharpening_mean, no_conn_sharpening_std_dev = _hdi_mean_and_std(no_conn_mean_hdi_overall)

  # Mean HDI vs. correlation length, with a +/- one standard deviation band.
  plt.figure()
  for curve_label, curve_color, curve_mean, curve_std in (
          ('Circle', 'C1', circle_sharpening_mean, circle_sharpening_std_dev),
          ('Ellipse', 'C2', ellipse_sharpening_mean, ellipse_sharpening_std_dev),
          ('No Conn.', 'C3', no_conn_sharpening_mean, no_conn_sharpening_std_dev),
  ):
      plt.plot(corr_len_range, curve_mean, label=curve_label, marker='o', color=curve_color)
      plt.fill_between(corr_len_range, curve_mean - curve_std, curve_mean + curve_std,
                       color=curve_color, alpha=0.4)
  plt.xlabel('Correlation length')
  plt.ylabel('Head Direction Index')
  plt.legend()
  plt.show()