analyse_synaptic_strength_orientation_map.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import multiprocessing
  2. import numpy as np
  3. from brian2.units import Hz
  4. from pypet import Trajectory
  5. from pypet.brian2 import Brian2MonitorResult, Brian2Result
  6. from tqdm import tqdm
  7. from scripts.spatial_network.head_direction_index_over_noise_scale import calculate_rates
  8. from scripts.spatial_network.run_synaptic_strength_scan_orientation_map_small_scale import DATA_FOLDER, TRAJ_NAME
# NOTE(review): not referenced by any code visible in this file; presumably an
# excitatory-neuron index kept around for debug printing - confirm use or remove.
n_exc_test_print = 465
  10. def get_spike_train_dictionary(number_of_neurons, spike_times, neuron_indices):
  11. spike_train_dict = {}
  12. for neuron_idx in range(number_of_neurons):
  13. spike_train_dict[neuron_idx] = []
  14. for neuron_idx, t in zip(neuron_indices, spike_times):
  15. spike_train_dict[neuron_idx].append(t)
  16. return spike_train_dict
  17. def get_firing_rate_dict_per_cell_and_direction(traj, run_name):
  18. traj.f_set_crun(run_name)
  19. firing_rate_dict = {}
  20. direction_names = ["dir{:d}".format(idx) for idx in range(traj.input.number_of_directions)]
  21. for idx in range(traj.N_E):
  22. firing_rate_dict[idx] = []
  23. for direction in direction_names:
  24. number_of_neurons = traj.N_E
  25. all_spike_times = traj.results.runs[run_name][direction].spikes.e.t
  26. neuron_indices = traj.results.runs[run_name][direction].spikes.e.i
  27. ex_spike_trains = get_spike_train_dictionary(number_of_neurons, all_spike_times,
  28. neuron_indices)
  29. ex_spike_rates = calculate_rates(ex_spike_trains.values())
  30. for idx, spike_rate in enumerate(ex_spike_rates):
  31. firing_rate_dict[idx].append(spike_rate)
  32. traj.f_restore_default()
  33. return firing_rate_dict
  34. def get_firing_rate_array_per_cell_and_direction(traj, run_name):
  35. traj.f_set_crun(run_name)
  36. firing_rate_array = np.ndarray((traj.N_E, traj.input.number_of_directions))
  37. direction_names = ["dir{:d}".format(idx) for idx in range(traj.input.number_of_directions)]
  38. for dir_idx, direction in enumerate(direction_names):
  39. number_of_neurons = traj.N_E
  40. all_spike_times = traj.results.runs[run_name][direction].spikes.e.t
  41. neuron_indices = traj.results.runs[run_name][direction].spikes.e.i
  42. ex_spike_trains = get_spike_train_dictionary(number_of_neurons, all_spike_times,
  43. neuron_indices)
  44. ex_spike_rates = calculate_rates(ex_spike_trains.values())
  45. for n_idx, spike_rate in enumerate(ex_spike_rates):
  46. #TODO: Why on earth does the unit vanish?
  47. firing_rate_array[n_idx, dir_idx] = spike_rate
  48. traj.f_restore_default()
  49. return firing_rate_array
  50. def get_head_direction_indices(directions, firing_rate_array):
  51. n_exc_neurons = firing_rate_array.shape[0]
  52. n_directions = len(directions)
  53. tuning_vectors = np.zeros((n_exc_neurons,n_directions,2))
  54. rate_sums = np.zeros((n_exc_neurons,))
  55. for ex_id, ex_rates in enumerate(firing_rate_array):
  56. rate_sum = 0.
  57. for dir_id, dir in enumerate(directions):
  58. tuning_vectors[ex_id, dir_id] = np.array([np.cos(dir), np.sin(dir)]) * ex_rates[dir_id]
  59. rate_sum += ex_rates[dir_id]
  60. rate_sums[ex_id] = rate_sum
  61. tuning_vectors_return = tuning_vectors.copy()
  62. for ex_id in range(n_exc_neurons):
  63. if rate_sums[ex_id] != 0.:
  64. tuning_vectors[ex_id, :, :] /= rate_sums[ex_id]
  65. tuning_vectors_summed = np.sum(tuning_vectors, axis=1)
  66. head_direction_indices = np.array([np.linalg.norm(v) for v in tuning_vectors_summed])
  67. return head_direction_indices, tuning_vectors_return
  68. def get_runs_with_circular_morphology(traj):
  69. filtered_indices = traj.f_find_idx(('parameters.long_axis', 'parameters.short_axis'), lambda r1, r2: r1 == r2)
  70. return filtered_indices
  71. def analyse_single_run(run_name):
  72. print(run_name)
  73. firing_rate_array = get_firing_rate_array_per_cell_and_direction(traj, run_name)
  74. # firing_rate_dict = get_firing_rate_dict_per_cell_and_direction(traj, run_name)
  75. # traj.f_add_result('runs.$.firing_rate_array', firing_rate_array,
  76. # comment='The firing rates of the excitatory population')
  77. head_direction_indices, tuning_vectors = get_head_direction_indices(directions, firing_rate_array)
  78. # traj.f_add_result('runs.$.head_direction_indices', head_direction_indices,
  79. # comment='The HDIs of the excitatory population')
  80. # traj.f_add_result('runs.$.tuning_vectors', tuning_vectors,
  81. # comment='The tuning vectors of the excitatory population')
  82. return firing_rate_array, head_direction_indices, tuning_vectors
  83. if __name__ == "__main__":
  84. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  85. NO_LOADING = 0
  86. FULL_LOAD = 2
  87. traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=FULL_LOAD)
  88. # traj.v_auto_load = True
  89. correlation_lengths = traj.f_get('correlation_length').f_get_range()
  90. long_axis = traj.f_get('long_axis').f_get_range()
  91. short_axis = traj.f_get('short_axis').f_get_range()
  92. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  93. circular_indices = list(get_runs_with_circular_morphology(traj))
  94. pool = multiprocessing.Pool()
  95. # for idx, run_name in enumerate(tqdm(traj.f_get_run_names())):
  96. multi_proc_result = pool.map(analyse_single_run, traj.f_get_run_names())
  97. print(type(multi_proc_result),len(multi_proc_result))
  98. for idx, run_name in enumerate(traj.f_get_run_names()):
  99. traj.f_set_crun(run_name)
  100. traj.f_add_result('runs.$.firing_rate_array', multi_proc_result[idx][0],
  101. comment='The firing rates of the excitatory population')
  102. traj.f_add_result('runs.$.head_direction_indices', multi_proc_result[idx][1],
  103. comment='The HDIs of the excitatory population')
  104. traj.f_add_result('runs.$.tuning_vectors', multi_proc_result[idx][2],
  105. comment='The tuning vectors of the excitatory population')
  106. traj.f_restore_default()
  107. traj.f_store()