# head_direction_index_over_noise_scale.py (17 KB)
  1. import itertools
  2. import multiprocessing
  3. import matplotlib.pyplot as plt
  4. import noise
  5. import numpy as np
  6. from brian2.units import *
  7. from tqdm import tqdm
  8. import scripts.models as modellib
  9. from scripts.interneuron_placement import create_grid_of_excitatory_neurons, \
  10. create_interneuron_sheet_by_repulsive_force, get_excitatory_neurons_in_inhibitory_axonal_clouds
  11. from scripts.interneuron_placement import create_interneuron_sheet_entropy_max_orientation
  12. from scripts.ring_network.head_direction import get_head_direction_input, \
  13. ex_in_network
  14. from scripts.spatial_maps.orientation_map import OrientationMap
  15. trials_per_scale = 1
  16. N_E = 900
  17. N_I = 90
  18. sheet_x = 450 * um
  19. sheet_y = 450 * um
  20. inhibitory_axon_long_axis = 100 * um
  21. inhibitory_axon_short_axis = 25 * um
  22. number_of_excitatory_neurons_per_row = int(np.sqrt(N_E))
  23. '''
  24. Neuron and synapse models
  25. '''
  26. excitatory_eqs = modellib.hodgkin_huxley_eqs_with_synaptic_conductance + modellib.eqs_ih
  27. excitatory_params = modellib.hodgkin_huxley_params
  28. excitatory_params.update(modellib.ih_params)
  29. excitatory_params.update({"E_i": -80 * mV})
  30. excitatory_params['ghbar'] = 0. * nS
  31. lif_interneuron_eqs = """
  32. dv/dt =1.0/tau* (-v + u_ext) :volt (unless refractory)
  33. u_ext = u_ext_const : volt
  34. """
  35. lif_interneuron_params = {
  36. "tau": 7 * ms,
  37. "v_threshold": -40 * mV,
  38. "v_reset": -60 * mV,
  39. "tau_refractory": 0.0 * ms,
  40. "u_ext_const": -50 * mV
  41. }
  42. lif_interneuron_options = {
  43. "threshold": "v>v_threshold",
  44. "reset": "v=v_reset",
  45. "refractory": "tau_refractory",
  46. 'method': 'euler'
  47. }
  48. ei_synapse_model = modellib.delta_synapse_model
  49. ei_synapse_on_pre = modellib.delta_synapse_on_pre
  50. ei_synapse_param = modellib.delta_synapse_param
  51. ie_synapse_model = modellib.exponential_synapse
  52. ie_synapse_on_pre = modellib.exponential_synapse_on_pre
  53. ie_synapse_param = modellib.exponential_synapse_params
  54. ie_synapse_param["tau_syn"] = 2 * ms
  55. '''
  56. Tuning Maps
  57. '''
  58. # tuning_label = "Perlin"
  59. tuning_label = "Orientation map"
  60. # optimization_label = "Repulsive"
  61. optimization_label = "Entropy Optimization"
  62. ellipse_trial_sharpening_list = []
  63. circle_trial_sharpening_list = []
  64. no_conn_trial_sharpening_list = []
  65. def get_fwhm_for_corr_len_and_seed(corr_len, seed, tuning_center):
  66. print(corr_len, tuning_center, seed)
  67. if tuning_label == "Perlin": # TODO: How to handle scale in Perlin
  68. tuning_map = lambda x, y: noise.pnoise2(x / 100.0, y / 100.0, octaves=2) * np.pi
  69. elif tuning_label == "Orientation map":
  70. map = OrientationMap(number_of_excitatory_neurons_per_row + 1, number_of_excitatory_neurons_per_row + 1,
  71. corr_len, sheet_x / um, sheet_y / um, seed)
  72. # map.improve(10)
  73. try:
  74. map.load_orientation_map()
  75. except:
  76. print(
  77. 'No map yet with {}x{} pixels and {} pixel correllation length and {} seed'.format(map.x_dim, map.y_dim,
  78. map.corr_len,
  79. map.rnd_seed))
  80. return -1, -1, -1
  81. tuning_map = lambda x, y: map.orientation_map(x, y)
  82. ex_positions, ex_tunings = create_grid_of_excitatory_neurons(sheet_x / um, sheet_y / um,
  83. number_of_excitatory_neurons_per_row, tuning_map)
  84. inhibitory_radial_axis = np.sqrt(inhibitory_axon_long_axis * inhibitory_axon_short_axis)
  85. if optimization_label == "Repulsive":
  86. inhibitory_axonal_clouds = create_interneuron_sheet_by_repulsive_force(N_I, inhibitory_axon_long_axis / um,
  87. inhibitory_axon_short_axis / um,
  88. sheet_x / um,
  89. sheet_y / um, random_seed=2,
  90. n_iterations=1000)
  91. inhibitory_axonal_circles = create_interneuron_sheet_by_repulsive_force(N_I, inhibitory_radial_axis / um,
  92. inhibitory_radial_axis / um,
  93. sheet_x / um,
  94. sheet_y / um, random_seed=2,
  95. n_iterations=1000)
  96. elif optimization_label == "Entropy Optimization":
  97. inhibitory_axonal_clouds, ellipse_single_trial_entropy = create_interneuron_sheet_entropy_max_orientation(
  98. ex_positions, ex_tunings, N_I, inhibitory_axon_long_axis / um,
  99. inhibitory_axon_short_axis / um, sheet_x / um,
  100. sheet_y / um, trial_orientations=30)
  101. inhibitory_axonal_circles, circle_single_trial_entropy = create_interneuron_sheet_entropy_max_orientation(
  102. ex_positions, ex_tunings, N_I, inhibitory_radial_axis / um,
  103. inhibitory_radial_axis / um, sheet_x / um,
  104. sheet_y / um, trial_orientations=1)
  105. '''
  106. Connectvities
  107. '''
  108. # Spatial network with ellipsoid axons
  109. ie_connections = get_excitatory_neurons_in_inhibitory_axonal_clouds(ex_positions, inhibitory_axonal_clouds)
  110. inhibitory_synapse_strength = 30 * nS
  111. excitatory_synapse_strength = 1 * mV
  112. ex_in_weights, in_ex_weights = get_synaptic_weights(N_E, N_I, ie_connections, excitatory_synapse_strength,
  113. inhibitory_synapse_strength)
  114. # Spatial network with circular axons
  115. ie_connections_circle = get_excitatory_neurons_in_inhibitory_axonal_clouds(ex_positions,
  116. inhibitory_axonal_circles)
  117. in_ex_weights_circle = np.zeros((N_I, N_E)) * nS
  118. for interneuron_idx, connected_excitatory_idxs in enumerate(ie_connections_circle):
  119. in_ex_weights_circle[interneuron_idx, connected_excitatory_idxs] = inhibitory_synapse_strength
  120. excitatory_synapse_strength = 1 * mV
  121. ex_in_weights_circle = np.where(in_ex_weights_circle > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  122. # No synapses
  123. no_conn_ie = np.zeros((N_I, N_E)) * nS
  124. no_conn_ei = np.zeros((N_E, N_I)) * mV
  125. '''
  126. Prepare nets
  127. '''
  128. nets = []
  129. connectivity_label = ["No synapse", "Ellipsoid", "Circle"]
  130. connectivities = [(no_conn_ei, no_conn_ie), (ex_in_weights, in_ex_weights),
  131. (ex_in_weights_circle, in_ex_weights_circle)]
  132. for ei_weights, ie_weights in connectivities:
  133. net = ex_in_network(N_E, N_I, excitatory_eqs, excitatory_params, lif_interneuron_eqs,
  134. lif_interneuron_params,
  135. lif_interneuron_options, ei_synapse_model, ei_synapse_on_pre,
  136. ei_synapse_param,
  137. ei_weights, ie_synapse_model, ie_synapse_on_pre,
  138. ie_synapse_param, ie_weights, random_seed=2)
  139. nets.append(net)
  140. '''
  141. Head direction input
  142. '''
  143. ex_input_baseline = 0.0 * nA
  144. input_sharpness = 1
  145. max_head_direction_input_amplitude = 0.5 * nA
  146. input_to_excitatory_population = create_head_direction_input(ex_input_baseline, ex_tunings, input_sharpness,
  147. max_head_direction_input_amplitude, tuning_center)
  148. '''
  149. Run simulation
  150. '''
  151. skip = 100 * ms
  152. length = 1100 * ms
  153. duration = skip + length
  154. for net in nets:
  155. excitatory_neurons = net["excitatory_neurons"]
  156. excitatory_neurons.I = input_to_excitatory_population
  157. net.run(duration)
  158. '''
  159. Get spatial map of rates
  160. '''
  161. excitatory_rates = [get_rates(net["excitatory_spike_monitor"], skip) for net in nets]
  162. '''
  163. Get rate distribution over angles
  164. '''
  165. return ex_tunings, excitatory_rates
  166. def get_rates(spike_monitor, min_time=0 * ms):
  167. list_of_spike_times = spike_monitor.spike_trains().values()
  168. rates = calculate_rates(list_of_spike_times, min_time)
  169. return rates
  170. def calculate_rates(list_of_spike_times, min_time=0*ms):
  171. isis = [np.ediff1d(np.extract(spike_times / ms > min_time / ms, spike_times / ms)) * ms for spike_times in
  172. list_of_spike_times]
  173. rates = np.array([1.0 / np.mean(isi / ms) if isi.shape[0] != 0 else 0 for isi in isis]) * khertz
  174. return rates
  175. def create_head_direction_input(ex_input_baseline, ex_tunings, input_sharpness, max_head_direction_input_amplitude,
  176. tuning_center):
  177. peak_phase = tuning_center
  178. direction_input = get_head_direction_input(peak_phase, input_sharpness)
  179. input_to_excitatory_population = ex_input_baseline + max_head_direction_input_amplitude * direction_input(
  180. np.array(ex_tunings))
  181. return input_to_excitatory_population
  182. def get_synaptic_weights(N_E, N_I, ie_connections, excitatory_synapse_strength, inhibitory_synapse_strength):
  183. in_ex_weights = np.zeros((N_I, N_E)) * nS
  184. for interneuron_idx, connected_excitatory_idxs in enumerate(ie_connections):
  185. in_ex_weights[interneuron_idx, connected_excitatory_idxs] = inhibitory_synapse_strength
  186. ex_in_weights = np.where(in_ex_weights > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  187. return ex_in_weights, in_ex_weights
  188. if __name__ == "__main__":
  189. corr_len_range = range(0, 451, 15)
  190. corr_len_range_len = len(corr_len_range)
  191. tuning_range_len = 12
  192. tuning_range = np.linspace(-np.pi, np.pi, tuning_range_len, endpoint=False)
  193. seed_range = range(10)
  194. seed_range_len = len(seed_range)
  195. pool_arguments = itertools.product(corr_len_range, seed_range, tuning_range)
  196. use_saved_array = True
  197. if not use_saved_array:
  198. pool = multiprocessing.Pool()
  199. data = pool.starmap(get_fwhm_for_corr_len_and_seed, [*pool_arguments])
  200. print(type(data))
  201. # print(data)
  202. # data_array = np.reshape(np.array(data),(corr_len_range_len,seed_range_len,tuning_range_len,3))
  203. np.save('../../simulations/2020_02_27_head_direction_index_over_noise_scale/data.npy', np.array(data))
  204. else:
  205. data = np.load('../../simulations/2020_02_27_head_direction_index_over_noise_scale/data.npy', allow_pickle=True)
  206. print('Calculating HDI')
  207. no_conn_trial_hdi_array = np.array((corr_len_range_len, seed_range_len, tuning_range_len))
  208. ellipse_trial_hdi_array = np.array((corr_len_range_len, seed_range_len, tuning_range_len))
  209. circle_trial_hdi_array = np.array((corr_len_range_len, seed_range_len, tuning_range_len))
  210. # pool_arguments = itertools.product(corr_len_range, seed_range, tuning_range)
  211. # ex_tunings, excitatory_rates = data[0]
  212. # plt.plot(ex_tunings, excitatory_rates[0] / hertz)
  213. # plt.show()
  214. # for id, corr_len, tuning_center, seed in enumerate(pool_arguments):
  215. # ex_tunings, excitatory_rates = data[id]
  216. # print(id, corr_len, tuning_center, seed)
  217. circle_mean_hdi_overall = []
  218. no_conn_mean_hdi_overall = []
  219. ellipse_mean_hdi_overall = []
  220. tuning_vectors_test = []
  221. for cl_id, corr_len in enumerate(tqdm(corr_len_range)):
  222. circle_mean_hdi_per_corr_len = []
  223. no_conn_mean_hdi_per_corr_len = []
  224. ellipse_mean_hdi_per_corr_len = []
  225. for s_id, seed in enumerate(seed_range):
  226. circle_tuning_vector_list = np.zeros((N_E, 2))
  227. ellipse_tuning_vector_list = np.zeros((N_E, 2))
  228. no_conn_tuning_vector_list = np.zeros((N_E, 2))
  229. circle_rates_sum = np.zeros(N_E)
  230. ellipse_rates_sum = np.zeros(N_E)
  231. no_conn_rates_sum = np.zeros(N_E)
  232. for t_id, tuning_center in enumerate(tuning_range):
  233. total_id = t_id + tuning_range_len * s_id + tuning_range_len * seed_range_len * cl_id
  234. # print(cl_id, s_id, t_id, total_id)
  235. ex_tunings, excitatory_rates = data[total_id]
  236. circle_tuning_vector_list_at_tuning = np.array([np.array([np.cos(tuning_center), np.sin(tuning_center)]) \
  237. * rate for rate in excitatory_rates[2]])
  238. circle_tuning_vector_list = circle_tuning_vector_list + circle_tuning_vector_list_at_tuning
  239. circle_rates_sum += excitatory_rates[2]
  240. ellipse_tuning_vector_list_at_tuning = np.array(
  241. [np.array([np.cos(tuning_center), np.sin(tuning_center)]) \
  242. * rate for rate in excitatory_rates[1]])
  243. ellipse_tuning_vector_list = ellipse_tuning_vector_list + ellipse_tuning_vector_list_at_tuning
  244. ellipse_rates_sum += excitatory_rates[1]
  245. no_conn_tuning_vector_list_at_tuning = np.array(
  246. [np.array([np.cos(tuning_center), np.sin(tuning_center)]) \
  247. * rate for rate in excitatory_rates[0]])
  248. no_conn_tuning_vector_list = no_conn_tuning_vector_list + no_conn_tuning_vector_list_at_tuning
  249. no_conn_rates_sum += excitatory_rates[0]
  250. if cl_id == 0 and s_id == 0:
  251. print('rates: \n', excitatory_rates[0][0])
  252. print('vectors: \n', no_conn_tuning_vector_list_at_tuning[0])
  253. print('vec norm: \n', np.linalg.norm(no_conn_tuning_vector_list_at_tuning[0]))
  254. tuning_vectors_test.append(no_conn_tuning_vector_list_at_tuning[0])
  255. circle_hdi_list = [np.linalg.norm(vec) / rate_sum for vec, rate_sum in
  256. zip(circle_tuning_vector_list, circle_rates_sum)]
  257. circle_hdi_mean_over_pop_list = np.sum(circle_hdi_list) / len(circle_hdi_list)
  258. circle_mean_hdi_per_corr_len.append(circle_hdi_mean_over_pop_list)
  259. ellipse_hdi_list = [np.linalg.norm(vec) / rate_sum for vec, rate_sum in
  260. zip(ellipse_tuning_vector_list, ellipse_rates_sum)]
  261. ellipse_hdi_mean_over_pop_list = np.sum(ellipse_hdi_list) / len(ellipse_hdi_list)
  262. ellipse_mean_hdi_per_corr_len.append(ellipse_hdi_mean_over_pop_list)
  263. no_conn_hdi_list = [np.linalg.norm(vec) / rate_sum for vec, rate_sum in
  264. zip(no_conn_tuning_vector_list, no_conn_rates_sum)]
  265. no_conn_hdi_mean_over_pop_list = np.sum(no_conn_hdi_list) / len(no_conn_hdi_list)
  266. no_conn_mean_hdi_per_corr_len.append(no_conn_hdi_mean_over_pop_list)
  267. # print(circle_tuning_vector_mean)
  268. circle_mean_hdi_overall.append(circle_mean_hdi_per_corr_len)
  269. no_conn_mean_hdi_overall.append(no_conn_mean_hdi_per_corr_len)
  270. ellipse_mean_hdi_overall.append(ellipse_mean_hdi_per_corr_len)
  271. plt.figure()
  272. plt.scatter(np.array(tuning_vectors_test)[:, 0], np.array(tuning_vectors_test)[:, 1])
  273. plt.show()
  274. # print(data.shape)
  275. #
  276. # for i in range(corr_len_range_len):
  277. # no_conn_trial_sharpening_list.append(data[i,:,0])
  278. # ellipse_trial_sharpening_list.append(data[i,:,1])
  279. # circle_trial_sharpening_list.append(data[i,:,2])
  280. # print(circle_trial_sharpening_list)
  281. ellipse_sharpening_mean = np.array([np.mean(i) for i in ellipse_mean_hdi_overall])
  282. circle_sharpening_mean = np.array([np.mean(i) for i in circle_mean_hdi_overall])
  283. no_conn_sharpening_mean = np.array([np.mean(i) for i in no_conn_mean_hdi_overall])
  284. ellipse_sharpening_std_dev = np.array([np.std(i) for i in ellipse_mean_hdi_overall])
  285. circle_sharpening_std_dev = np.array([np.std(i) for i in circle_mean_hdi_overall])
  286. no_conn_sharpening_std_dev = np.array([np.std(i) for i in no_conn_mean_hdi_overall])
  287. # print(ellipse_trial_sharpening_list)
  288. # print(ellipse_entropy_std_dev)
  289. plt.figure()
  290. plt.plot(corr_len_range, circle_sharpening_mean, label='Circle', marker='o', color='C1')
  291. plt.fill_between(corr_len_range, circle_sharpening_mean - circle_sharpening_std_dev,
  292. circle_sharpening_mean + circle_sharpening_std_dev, color='C1', alpha=0.4)
  293. plt.plot(corr_len_range, ellipse_sharpening_mean, label='Ellipse', marker='o', color='C2')
  294. plt.fill_between(corr_len_range, ellipse_sharpening_mean - ellipse_sharpening_std_dev,
  295. ellipse_sharpening_mean + ellipse_sharpening_std_dev, color='C2', alpha=0.4)
  296. plt.plot(corr_len_range, no_conn_sharpening_mean, label='No Conn.', marker='o', color='C3')
  297. plt.fill_between(corr_len_range, no_conn_sharpening_mean - no_conn_sharpening_std_dev,
  298. no_conn_sharpening_mean + no_conn_sharpening_std_dev, color='C3', alpha=0.4)
  299. plt.xlabel('Correlation length')
  300. plt.ylabel('Head Direction Index')
  301. plt.legend()
  302. plt.show()