# analyse_entropy_maximisation_orientation_map_placement_jitter.py (9.0 KB)

  1. import multiprocessing
  2. import numpy as np
  3. from brian2.units import Hz
  4. from pypet import Trajectory
  5. from pypet.brian2 import Brian2MonitorResult, Brian2Result
  6. from scripts.spatial_network.head_direction_index_over_noise_scale import calculate_rates
  7. from scripts.spatial_network.placement_jitter.run_entropy_maximisation_orientation_map_placement_jitter import DATA_FOLDER, TRAJ_NAME
  8. traj = None
  9. directions = None
  10. def get_spike_train_dictionary(number_of_neurons, spike_times, neuron_indices):
  11. spike_train_dict = {}
  12. for neuron_idx in range(number_of_neurons):
  13. spike_train_dict[neuron_idx] = []
  14. for neuron_idx, t in zip(neuron_indices, spike_times):
  15. spike_train_dict[neuron_idx].append(t)
  16. return spike_train_dict
  17. def get_firing_rate_dict_per_cell_and_direction(traj, run_name):
  18. traj.f_set_crun(run_name)
  19. firing_rate_dict = {}
  20. direction_names = ["dir{:d}".format(idx) for idx in range(traj.input.number_of_directions)]
  21. for idx in range(traj.N_E):
  22. firing_rate_dict[idx] = []
  23. for direction in direction_names:
  24. number_of_neurons = traj.N_E
  25. all_spike_times = traj.results.runs[run_name][direction].spikes.e.t
  26. neuron_indices = traj.results.runs[run_name][direction].spikes.e.i
  27. ex_spike_trains = get_spike_train_dictionary(number_of_neurons, all_spike_times,
  28. neuron_indices)
  29. ex_spike_rates = calculate_rates(ex_spike_trains.values())
  30. for idx, spike_rate in enumerate(ex_spike_rates):
  31. firing_rate_dict[idx].append(spike_rate)
  32. traj.f_restore_default()
  33. return firing_rate_dict
  34. def get_firing_rate_array_per_cell_and_direction(traj, run_name):
  35. traj.f_set_crun(run_name)
  36. firing_rate_array = np.ndarray((traj.N_E, traj.input.number_of_directions))
  37. direction_names = ["dir{:d}".format(idx) for idx in range(traj.input.number_of_directions)]
  38. for dir_idx, direction in enumerate(direction_names):
  39. number_of_neurons = traj.N_E
  40. all_spike_times = traj.results.runs[run_name][direction].spikes.e.t
  41. neuron_indices = traj.results.runs[run_name][direction].spikes.e.i
  42. ex_spike_trains = get_spike_train_dictionary(number_of_neurons, all_spike_times,
  43. neuron_indices)
  44. ex_spike_rates = calculate_rates(ex_spike_trains.values())
  45. for n_idx, spike_rate in enumerate(ex_spike_rates):
  46. #TODO: Why on earth does the unit vanish?
  47. firing_rate_array[n_idx, dir_idx] = spike_rate
  48. traj.f_restore_default()
  49. return firing_rate_array
  50. def get_head_direction_indices(directions, firing_rate_array):
  51. n_exc_neurons = firing_rate_array.shape[0]
  52. n_directions = len(directions)
  53. tuning_vectors = np.zeros((n_exc_neurons,n_directions,2))
  54. rate_sums = np.zeros((n_exc_neurons,))
  55. for ex_id, ex_rates in enumerate(firing_rate_array):
  56. rate_sum = 0.
  57. for dir_id, dir in enumerate(directions):
  58. tuning_vectors[ex_id, dir_id] = np.array([np.cos(dir), np.sin(dir)]) * ex_rates[dir_id]
  59. rate_sum += ex_rates[dir_id]
  60. rate_sums[ex_id] = rate_sum
  61. tuning_vectors_return = tuning_vectors.copy()
  62. for ex_id in range(n_exc_neurons):
  63. if rate_sums[ex_id] != 0.:
  64. tuning_vectors[ex_id, :, :] /= rate_sums[ex_id]
  65. tuning_vectors_summed = np.sum(tuning_vectors, axis=1)
  66. head_direction_indices = np.array([np.linalg.norm(v) for v in tuning_vectors_summed])
  67. return head_direction_indices, tuning_vectors_return
  68. def get_inhibitory_firing_rate_array_per_cell_and_direction(traj, run_name):
  69. number_of_neurons = traj.N_I
  70. traj.f_set_crun(run_name)
  71. firing_rate_array = np.ndarray((number_of_neurons, traj.input.number_of_directions))
  72. direction_names = ["dir{:d}".format(idx) for idx in range(traj.input.number_of_directions)]
  73. try:
  74. traj.results.runs[run_name]['dir0'].spikes.i.t
  75. except:
  76. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  77. print('Cant find t for run {} with label {}'.format(run_name,label))
  78. return np.zeros((number_of_neurons, traj.input.number_of_directions))
  79. for dir_idx, direction in enumerate(direction_names):
  80. all_spike_times = traj.results.runs[run_name][direction].spikes.i.t
  81. neuron_indices = traj.results.runs[run_name][direction].spikes.i.i
  82. inh_spike_trains = get_spike_train_dictionary(number_of_neurons, all_spike_times,
  83. neuron_indices)
  84. inh_spike_rates = calculate_rates(inh_spike_trains.values())
  85. for n_idx, spike_rate in enumerate(inh_spike_rates):
  86. #TODO: Why on earth does the unit vanish?
  87. firing_rate_array[n_idx, dir_idx] = spike_rate
  88. traj.f_restore_default()
  89. return firing_rate_array
  90. def get_inhibitory_head_direction_indices(directions, firing_rate_array):
  91. n_inh_neurons = firing_rate_array.shape[0]
  92. n_directions = len(directions)
  93. tuning_vectors = np.zeros((n_inh_neurons,n_directions,2))
  94. rate_sums = np.zeros((n_inh_neurons,))
  95. for inh_id, inh_rates in enumerate(firing_rate_array):
  96. rate_sum = 0.
  97. for dir_id, dir in enumerate(directions):
  98. tuning_vectors[inh_id, dir_id] = np.array([np.cos(dir), np.sin(dir)]) * inh_rates[dir_id]
  99. rate_sum += inh_rates[dir_id]
  100. rate_sums[inh_id] = rate_sum
  101. tuning_vectors_return = tuning_vectors.copy()
  102. for inh_id in range(n_inh_neurons):
  103. if rate_sums[inh_id] != 0.:
  104. tuning_vectors[inh_id, :, :] /= rate_sums[inh_id]
  105. tuning_vectors_summed = np.sum(tuning_vectors, axis=1)
  106. head_direction_indices = np.array([np.linalg.norm(v) for v in tuning_vectors_summed])
  107. return head_direction_indices, tuning_vectors_return
  108. def get_runs_with_circular_morphology(traj):
  109. filtered_indices = traj.f_find_idx(('parameters.long_axis', 'parameters.short_axis'), lambda r1, r2: r1 == r2)
  110. return filtered_indices
  111. def analyse_single_run(run_name):
  112. traj.f_set_crun(run_name)
  113. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  114. print(run_name, label)
  115. if label != 'no conn':
  116. print(run_name,'inh')
  117. inh_firing_rate_array = get_inhibitory_firing_rate_array_per_cell_and_direction(traj, run_name)
  118. inh_head_direction_indices, inh_tuning_vectors = get_inhibitory_head_direction_indices(directions, inh_firing_rate_array)
  119. else:
  120. n_inh = traj.N_I
  121. n_dir = traj.input.number_of_directions
  122. inh_firing_rate_array = np.zeros((n_inh, n_dir))
  123. inh_head_direction_indices = np.zeros(n_inh)
  124. inh_tuning_vectors = np.zeros((n_inh,n_dir,2))
  125. exc_firing_rate_array = get_firing_rate_array_per_cell_and_direction(traj, run_name)
  126. exc_head_direction_indices, exc_tuning_vectors = get_head_direction_indices(directions, exc_firing_rate_array)
  127. return exc_firing_rate_array, exc_head_direction_indices, exc_tuning_vectors, inh_firing_rate_array, inh_head_direction_indices, inh_tuning_vectors
  128. def main():
  129. global traj, directions
  130. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  131. NO_LOADING = 0
  132. FULL_LOAD = 2
  133. traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=FULL_LOAD)
  134. # traj.v_auto_load = True
  135. correlation_lengths = traj.f_get('correlation_length').f_get_range()
  136. long_axis = traj.f_get('long_axis').f_get_range()
  137. short_axis = traj.f_get('short_axis').f_get_range()
  138. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  139. run_names = traj.f_get_run_names()[::-1]
  140. # run_names = ['run_00000900']
  141. pool = multiprocessing.Pool()
  142. # for idx, run_name in enumerate(tqdm(traj.f_get_run_names())):
  143. multi_proc_result = pool.map(analyse_single_run, run_names)
  144. print(type(multi_proc_result), len(multi_proc_result))
  145. for idx, run_name in enumerate(run_names):
  146. traj.f_set_crun(run_name)
  147. traj.f_add_result('runs.$.firing_rate_array', multi_proc_result[idx][0],
  148. comment='The firing rates of the excitatory population')
  149. traj.f_add_result('runs.$.head_direction_indices', multi_proc_result[idx][1],
  150. comment='The HDIs of the excitatory population')
  151. traj.f_add_result('runs.$.tuning_vectors', multi_proc_result[idx][2],
  152. comment='The tuning vectors of the excitatory population')
  153. traj.f_add_result('runs.$.inh_firing_rate_array', multi_proc_result[idx][3],
  154. comment='The firing rates of the inhibitory population')
  155. traj.f_add_result('runs.$.inh_head_direction_indices', multi_proc_result[idx][4],
  156. comment='The HDIs of the inhibitory population')
  157. traj.f_add_result('runs.$.inh_tuning_vectors', multi_proc_result[idx][5],
  158. comment='The tuning vectors of the inhibitory population')
  159. traj.f_restore_default()
  160. traj.f_store()
  161. if __name__ == "__main__":
  162. main()