figures_synaptic_strength_spatial_head_direction_network.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import numpy as np
  2. from pypet import Trajectory
  3. from pypet.brian2 import Brian2MonitorResult
  4. from tqdm import tqdm
  5. from brian2.units import *
  6. import matplotlib.pyplot as plt
  7. import matplotlib as mpl
  8. import pandas as pd
  9. from scripts.spatial_network.run_synaptic_strength_scan_orientation_map import DATA_FOLDER, TRAJ_NAME
  10. FIGURE_SAVE_PATH = '../../figures/figures_spatial_head_direction_network_orientation_map/'
  11. def plot_hdi_synaptic_strength(traj, plot_run_names):
  12. inh_strength_expl = traj.f_get('inhibitory').f_get_range()
  13. exc_strength_expl = traj.f_get('excitatory').f_get_range()
  14. seed_expl = traj.f_get('seed').f_get_range()
  15. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  16. label_expl = []
  17. for i in range(3000):
  18. label_expl.append('ellipsoid')
  19. for i in range(3000):
  20. label_expl.append('circular')
  21. print(label_expl)
  22. inh_strength_expl = inh_strength_expl[:-10]
  23. exc_strength_expl = exc_strength_expl[:-10]
  24. seed_expl = seed_expl[:-10]
  25. # label_range = list(set(label_expl))
  26. inh_strength_range = sorted(set(inh_strength_expl))
  27. print(inh_strength_range)
  28. exc_strength_range = sorted(set(exc_strength_expl))
  29. idiot_run_names = plot_run_names[:-10]
  30. hdi_frame = pd.Series(index=[inh_strength_expl, exc_strength_expl, seed_expl, label_expl])
  31. hdi_frame.index.names = ["inhibitory", "excitatory", "seed", "label"]
  32. idiot_id = 0
  33. for run_name, inh_strength, exc_strength, seed, label in tqdm(zip(idiot_run_names, inh_strength_expl, exc_strength_expl, seed_expl, label_expl), total=len(idiot_run_names)):
  34. if idiot_id >= 6000:
  35. continue
  36. # The tunings, while not used, must be accessed or the following line will fail!
  37. ex_tunings = traj.results.runs[run_name].ex_tunings
  38. head_direction_indices = traj.results[run_name].head_direction_indices
  39. hdi_frame[inh_strength, exc_strength, seed, label] = np.mean(head_direction_indices)
  40. idiot_id += 1
  41. # print(hdi_frame)
  42. # TODO: Standart deviation also for the population
  43. hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 1, 3]).mean()
  44. # print(hdi_exc_n_and_seed_mean)
  45. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  46. mean_hdi_for_diff = []
  47. for ax, label in zip(axes[:-1], ['circular', 'ellipsoid']):
  48. hdi_mean = hdi_exc_n_and_seed_mean[:, :, label]
  49. # values x and y give values at z
  50. xmin = inh_strength_range[0]
  51. xmax = inh_strength_range[-1]
  52. dx = (xmax - xmin) / (len(inh_strength_range))
  53. ymin = exc_strength_range[0]
  54. ymax = exc_strength_range[-1]
  55. dy = (ymax - ymin) / (len(exc_strength_range))
  56. print(xmin, xmax, dx)
  57. print(ymin, ymax, dy)
  58. # transform x and y to boundaries of x and y
  59. # print(np.arange(xmin, xmax + dx, dx) - dx / 2.)
  60. # print(np.arange(ymin, ymax + dy, dy) - dy / 2.)
  61. print(exc_strength_range)
  62. # x = np.arange(xmin, xmax + dx, dx) - dx / 2.
  63. # y = np.arange(ymin, ymax + dy, dy) - dy / 2.
  64. x = np.linspace(xmin, xmax, len(inh_strength_range) + 1)
  65. y = np.linspace(ymin, ymax, len(exc_strength_range) + 1)
  66. print(x)
  67. print(y)
  68. X, Y = np.meshgrid(x, y)
  69. # X, Y = np.meshgrid(inh_strength_range,exc_strength_range)
  70. print(label)
  71. # print(hdi_mean)
  72. hdi_mean_plot = hdi_mean.values.reshape(len(inh_strength_range),len(exc_strength_range))
  73. mean_hdi_for_diff.append(hdi_mean_plot)
  74. # print(hdi_mean_plot)
  75. c = ax.pcolor(X, Y, hdi_mean_plot.T, vmin=0.0, vmax=0.5, cmap='hot')
  76. ax.set_title(label)
  77. fig.colorbar(c, ax=ax, label="mean HDI")
  78. ax.set_xticks(np.arange(xmin, xmax + dx, dx))
  79. ax.set_yticks(np.arange(ymin, ymax + dy, dy))
  80. ax.set_xlabel('inhibitory strength (nS)')
  81. ax.set_ylabel('excitatory strength (nS)')
  82. ax.set_xticks(np.linspace(xmin, xmax, 6))
  83. ax.set_yticks(np.linspace(ymin, ymax, 6))
  84. ax.set_xlim(x[0], x[-1])
  85. ax.set_ylim(y[0], y[-1])
  86. xmin = inh_strength_range[0]
  87. xmax = inh_strength_range[-1]
  88. dx = (xmax - xmin) / (len(inh_strength_range))
  89. ymin = exc_strength_range[0]
  90. ymax = exc_strength_range[-1]
  91. dy = (ymax - ymin) / (len(exc_strength_range))
  92. x = np.linspace(xmin, xmax, len(inh_strength_range) + 1)
  93. y = np.linspace(ymin, ymax, len(exc_strength_range) + 1)
  94. X, Y = np.meshgrid(x, y)
  95. hdi_diff_plot = mean_hdi_for_diff[1] - mean_hdi_for_diff[0]
  96. print(mean_hdi_for_diff[1])
  97. print(mean_hdi_for_diff[0])
  98. print(hdi_diff_plot)
  99. print(hdi_diff_plot.shape)
  100. ax = axes[2]
  101. c = ax.pcolor(X, Y, hdi_diff_plot.T, cmap='hot')
  102. ax.set_title('difference')
  103. fig.colorbar(c, ax=ax, label="mean HDI difference")
  104. ax.set_xticks(np.arange(xmin, xmax + dx, dx))
  105. ax.set_yticks(np.arange(ymin, ymax + dy, dy))
  106. ax.set_xlabel('inhibitory strength (nS)')
  107. ax.set_ylabel('excitatory strength (nS)')
  108. ax.set_xticks(np.linspace(xmin, xmax, 6))
  109. ax.set_yticks(np.linspace(ymin, ymax, 6))
  110. ax.set_xlim(x[0], x[-1])
  111. ax.set_ylim(y[0], y[-1])
  112. fig.suptitle('Mean HDI over syn. strength', fontsize=16)
  113. if save_dont_show:
  114. plt.savefig(FIGURE_SAVE_PATH + 'hdi_synaptic_strength.png', dpi=200)
  115. def filter_run_names_by_par_dict(traj, par_dict):
  116. run_name_list = []
  117. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  118. traj.f_set_crun(run_name)
  119. paramters_equal = True
  120. for key, val in par_dict.items():
  121. if(traj.par[key] != val):
  122. paramters_equal = False
  123. if paramters_equal:
  124. run_name_list.append(run_name)
  125. traj.f_restore_default()
  126. return run_name_list
  127. def filter_run_names_and_duplicates_because_im_an_idiot(traj, par_dict):
  128. run_name_list = []
  129. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  130. traj.f_set_crun(run_name)
  131. paramters_equal = True
  132. for key, val in par_dict.items():
  133. if(traj.par[key] != val):
  134. paramters_equal = False
  135. if paramters_equal:
  136. run_name_list.append(run_name)
  137. traj.f_restore_default()
  138. return run_name_list
  139. if __name__ == "__main__":
  140. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  141. NO_LOADING = 0
  142. FULL_LOAD = 2
  143. traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=NO_LOADING)
  144. traj.v_auto_load = True
  145. save_dont_show = True
  146. if save_dont_show:
  147. mpl.use('Agg')
  148. plot_hdi_synaptic_strength(traj, traj.f_get_run_names())
  149. plt.show()
  150. traj.f_restore_default()