sharpening_over_noise_scale.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import matplotlib.pyplot as plt
  2. import noise
  3. import numpy as np
  4. from brian2.units import *
  5. from scripts.spatial_maps.orientation_map import OrientationMap
  6. from scripts.interneuron_placement import create_interneuron_sheet_entropy_max_orientation
  7. from tqdm import tqdm
  8. import scripts.models as modellib
  9. from scripts.ring_network.head_direction import get_head_direction_input, \
  10. ex_in_network
  11. from scipy.optimize import curve_fit
  12. from scripts.interneuron_placement import create_grid_of_excitatory_neurons, \
  13. create_interneuron_sheet_by_repulsive_force, get_excitatory_neurons_in_inhibitory_axonal_clouds
  14. use_saved_array = False
  15. trials_per_scale = 1
  16. N_E = 900
  17. N_I = 90
  18. sheet_x = 450 * um
  19. sheet_y = 450 * um
  20. inhibitory_axon_long_axis = 100 * um
  21. inhibitory_axon_short_axis = 25 * um
  22. number_of_excitatory_neurons_per_row = int(np.sqrt(N_E))
  23. '''
  24. Neuron and synapse models
  25. '''
  26. excitatory_eqs = modellib.hodgkin_huxley_eqs_with_synaptic_conductance + modellib.eqs_ih
  27. excitatory_params = modellib.hodgkin_huxley_params
  28. excitatory_params.update(modellib.ih_params)
  29. excitatory_params.update({"E_i": -80 * mV})
  30. excitatory_params['ghbar'] = 0. * nS
  31. lif_interneuron_eqs = """
  32. dv/dt =1.0/tau* (-v + u_ext) :volt (unless refractory)
  33. u_ext = u_ext_const : volt
  34. """
  35. lif_interneuron_params = {
  36. "tau": 7 * ms,
  37. "v_threshold": -40 * mV,
  38. "v_reset": -60 * mV,
  39. "tau_refractory": 0.0 * ms,
  40. "u_ext_const": -50 * mV
  41. }
  42. lif_interneuron_options = {
  43. "threshold": "v>v_threshold",
  44. "reset": "v=v_reset",
  45. "refractory": "tau_refractory",
  46. 'method': 'euler'
  47. }
  48. ei_synapse_model = modellib.delta_synapse_model
  49. ei_synapse_on_pre = modellib.delta_synapse_on_pre
  50. ei_synapse_param = modellib.delta_synapse_param
  51. ie_synapse_model = modellib.exponential_synapse
  52. ie_synapse_on_pre = modellib.exponential_synapse_on_pre
  53. ie_synapse_param = modellib.exponential_synapse_params
  54. ie_synapse_param["tau_syn"] = 2 * ms
  55. '''
  56. Tuning Maps
  57. '''
  58. # tuning_label = "Perlin"
  59. tuning_label = "Orientation map"
  60. # optimization_label = "Repulsive"
  61. optimization_label = "Entropy Optimization"
  62. corr_len_list = range(0,301,15) #For these values maps exist
  63. ellipse_trial_sharpening_list = []
  64. circle_trial_sharpening_list = []
  65. no_conn_trial_sharpening_list = []
  66. if not use_saved_array:
  67. for corr_len in tqdm(corr_len_list, desc="Calculating sharpening over scale"):
  68. ellipse_single_trial_sharpening_list = []
  69. circle_single_trial_sharpening_list = []
  70. no_conn_single_trial_sharpening_list = []
  71. for seed in range(6):
  72. print(corr_len,seed)
  73. if tuning_label == "Perlin": #TODO: How to handle scale in Perlin
  74. tuning_map = lambda x, y: noise.pnoise2(x / 100.0, y / 100.0, octaves=2)*np.pi
  75. elif tuning_label == "Orientation map":
  76. map = OrientationMap(number_of_excitatory_neurons_per_row + 1,number_of_excitatory_neurons_per_row + 1,
  77. corr_len,sheet_x/um,sheet_y/um,seed)
  78. # map.improve(10)
  79. try:
  80. map.load_orientation_map()
  81. except:
  82. print('No map yet with {}x{} pixels and {} pixel correllation length and {} seed'.format(map.x_dim, map.y_dim, map.corr_len, map.rnd_seed))
  83. continue
  84. tuning_map = lambda x, y: map.tuning(x, y)
  85. ex_positions, ex_tunings = create_grid_of_excitatory_neurons(sheet_x / um, sheet_y / um,
  86. number_of_excitatory_neurons_per_row, tuning_map)
  87. inhibitory_radial_axis = np.sqrt(inhibitory_axon_long_axis * inhibitory_axon_short_axis)
  88. if optimization_label == "Repulsive":
  89. inhibitory_axonal_clouds = create_interneuron_sheet_by_repulsive_force(N_I, inhibitory_axon_long_axis / um,
  90. inhibitory_axon_short_axis / um, sheet_x / um,
  91. sheet_y / um, random_seed=2, n_iterations=1000)
  92. inhibitory_axonal_circles = create_interneuron_sheet_by_repulsive_force(N_I, inhibitory_radial_axis / um,
  93. inhibitory_radial_axis / um, sheet_x / um,
  94. sheet_y / um, random_seed=2, n_iterations=1000)
  95. elif optimization_label == "Entropy Optimization":
  96. inhibitory_axonal_clouds, ellipse_single_trial_entropy = create_interneuron_sheet_entropy_max_orientation(ex_positions, ex_tunings, N_I, inhibitory_axon_long_axis / um,
  97. inhibitory_axon_short_axis / um, sheet_x / um,
  98. sheet_y / um, trial_orientations=30)
  99. inhibitory_axonal_circles, circle_single_trial_entropy = create_interneuron_sheet_entropy_max_orientation(ex_positions, ex_tunings, N_I, inhibitory_radial_axis / um,
  100. inhibitory_radial_axis / um, sheet_x / um,
  101. sheet_y / um, trial_orientations=1)
  102. interneuron_tunings = [inhibitory_axonal_clouds, inhibitory_axonal_circles]
  103. '''
  104. Connectvities
  105. '''
  106. # Spatial network with ellipsoid axons
  107. ie_connections = get_excitatory_neurons_in_inhibitory_axonal_clouds(ex_positions, inhibitory_axonal_clouds)
  108. inhibitory_synapse_strength = 30 * nS
  109. in_ex_weights = np.zeros((N_I, N_E)) * nS
  110. for interneuron_idx, connected_excitatory_idxs in enumerate(ie_connections):
  111. in_ex_weights[interneuron_idx, connected_excitatory_idxs] = inhibitory_synapse_strength
  112. excitatory_synapse_strength = 1 * mV
  113. ex_in_weights = np.where(in_ex_weights > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  114. # Spatial network with circular axons
  115. ie_connections_circle = get_excitatory_neurons_in_inhibitory_axonal_clouds(ex_positions,
  116. inhibitory_axonal_circles)
  117. in_ex_weights_circle = np.zeros((N_I, N_E)) * nS
  118. for interneuron_idx, connected_excitatory_idxs in enumerate(ie_connections_circle):
  119. in_ex_weights_circle[interneuron_idx, connected_excitatory_idxs] = inhibitory_synapse_strength
  120. excitatory_synapse_strength = 1 * mV
  121. ex_in_weights_circle = np.where(in_ex_weights_circle > 0 * nS, excitatory_synapse_strength, 0 * mV).T * volt
  122. # No synapses
  123. no_conn_ie = np.zeros((N_I, N_E)) * nS
  124. no_conn_ei = np.zeros((N_E, N_I)) * mV
  125. '''
  126. Prepare nets
  127. '''
  128. nets = []
  129. connectivity_label = ["No synapse", "Ellipsoid", "Circle"]
  130. connectivities = [(no_conn_ei, no_conn_ie), (ex_in_weights, in_ex_weights),
  131. (ex_in_weights_circle, in_ex_weights_circle)]
  132. for ei_weights, ie_weights in connectivities:
  133. net = ex_in_network(N_E, N_I, excitatory_eqs, excitatory_params, lif_interneuron_eqs,
  134. lif_interneuron_params,
  135. lif_interneuron_options, ei_synapse_model, ei_synapse_on_pre,
  136. ei_synapse_param,
  137. ei_weights, ie_synapse_model, ie_synapse_on_pre,
  138. ie_synapse_param, ie_weights, random_seed=2)
  139. nets.append(net)
  140. '''
  141. Head direction input
  142. '''
  143. ex_input_baseline = 0.0 * nA
  144. peak_phase = 0
  145. input_sharpness = 1
  146. direction_input = get_head_direction_input(peak_phase, input_sharpness)
  147. max_head_direction_input_amplitude = 0.5 * nA
  148. input_to_excitatory_population = ex_input_baseline + max_head_direction_input_amplitude * direction_input(
  149. np.array(ex_tunings))
  150. '''
  151. Run simulation
  152. '''
  153. skip = 100 * ms
  154. length = 1100 * ms
  155. duration = skip + length
  156. for net in nets:
  157. excitatory_neurons = net["excitatory_neurons"]
  158. excitatory_neurons.I = input_to_excitatory_population
  159. net.run(duration)
  160. '''
  161. Get spatial map of rates
  162. '''
  163. def get_rates(spike_monitor, min_time=0 * ms):
  164. spike_trains = spike_monitor.spike_trains()
  165. isis = [np.ediff1d(np.extract(spike_times / ms > skip / ms, spike_times / ms)) * ms for spike_times in
  166. spike_trains.values()]
  167. rates = np.array([1.0 / np.mean(isi / ms) if isi.shape[0] != 0 else 0 for isi in isis]) * khertz
  168. return rates
  169. excitatory_rates = [get_rates(net["excitatory_spike_monitor"]) for net in nets]
  170. '''
  171. Get rate distribution over angles
  172. '''
  173. def gauss(x, *p):
  174. A, mu, sigma, B = p
  175. return A * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2)) + B
  176. p0 = [1., 0., 1., 0.]
  177. # rate_dist_bins = np.linspace(-np.pi, np.pi, 40)
  178. fwhm_list = []
  179. # fig, ax = plt.subplots(1, 1)
  180. for rates, label in zip(excitatory_rates, connectivity_label):
  181. # tuning_mean, bin_edges, bin_number = binned_statistic(ex_tunings, rates / hertz, statistic='mean',
  182. # bins=rate_dist_bins)
  183. coeff, var_matrix = curve_fit(gauss, ex_tunings, rates / hertz, p0=p0)
  184. contrast = np.max(gauss(ex_tunings,*coeff)) - np.min(gauss(ex_tunings,*coeff))
  185. print('Fitted contrast = ', contrast)
  186. fwhm_list.append(contrast)
  187. # print('Fitted standard deviation = ', np.abs(coeff[2]))
  188. # fwhm_list.append(2.355*np.abs(coeff[2]))
  189. no_conn_single_trial_sharpening_list.append(fwhm_list[0])
  190. ellipse_single_trial_sharpening_list.append(fwhm_list[1])
  191. circle_single_trial_sharpening_list.append(fwhm_list[2])
  192. ellipse_trial_sharpening_list.append(ellipse_single_trial_sharpening_list)
  193. circle_trial_sharpening_list.append(circle_single_trial_sharpening_list)
  194. no_conn_trial_sharpening_list.append(no_conn_single_trial_sharpening_list)
  195. np.save('../simulations/2020_02_27_sharpening_over_noise_scale/circle_trial_sharpening_list.npy',circle_trial_sharpening_list)
  196. np.save('../simulations/2020_02_27_sharpening_over_noise_scale/ellipse_trial_sharpening_list.npy',ellipse_trial_sharpening_list)
  197. np.save('../simulations/2020_02_27_sharpening_over_noise_scale/no_conn_trial_sharpening_list.npy',no_conn_trial_sharpening_list)
  198. else:
  199. circle_trial_entropy_list = np.load(
  200. '../../simulations/2020_02_27_entropy_over_noise_scale/circle_trial_entropy_list.npy')
  201. ellipse_trial_entropy_list = np.load(
  202. '../../simulations/2020_02_27_entropy_over_noise_scale/ellipse_trial_entropy_list.npy')
  203. print(circle_trial_sharpening_list)
  204. ellipse_sharpening_mean = np.array([np.mean(i) for i in ellipse_trial_sharpening_list])
  205. circle_sharpening_mean = np.array([np.mean(i) for i in circle_trial_sharpening_list])
  206. no_conn_sharpening_mean = np.array([np.mean(i) for i in no_conn_trial_sharpening_list])
  207. ellipse_sharpening_std_dev = np.array([np.std(i) for i in ellipse_trial_sharpening_list])
  208. circle_sharpening_std_dev = np.array([np.std(i) for i in circle_trial_sharpening_list])
  209. no_conn_sharpening_std_dev = np.array([np.std(i) for i in no_conn_trial_sharpening_list])
  210. # print(ellipse_trial_sharpening_list)
  211. # print(ellipse_entropy_std_dev)
  212. plt.figure()
  213. plt.plot(corr_len_list,circle_sharpening_mean, label='Circle', marker='o',color='C1')
  214. plt.fill_between(corr_len_list,circle_sharpening_mean-circle_sharpening_std_dev,circle_sharpening_mean+circle_sharpening_std_dev,color='C1',alpha=0.4)
  215. plt.plot(corr_len_list,ellipse_sharpening_mean, label='Ellipse', marker='o',color='C2')
  216. plt.fill_between(corr_len_list,ellipse_sharpening_mean-ellipse_sharpening_std_dev,ellipse_sharpening_mean+ellipse_sharpening_std_dev,color='C2',alpha=0.4)
  217. plt.plot(corr_len_list,no_conn_sharpening_mean, label='No Conn.', marker='o',color='C3')
  218. plt.fill_between(corr_len_list,no_conn_sharpening_mean-no_conn_sharpening_std_dev,no_conn_sharpening_mean+no_conn_sharpening_std_dev,color='C3',alpha=0.4)
  219. plt.xlabel('Correlation length')
  220. plt.ylabel('Contrast')
  221. plt.legend()
  222. plt.show()