Переглянути джерело

Reworked model figure plots

moritz 4 роки тому
батько
коміт
5ca8693236

+ 21 - 1
scripts/interneuron_placement.py

@@ -35,7 +35,7 @@ class Pickle:
 
     def get_ellipse(self):
         ell = Ellipse((self.x, self.y), 2.0 * self.a, 2.0 * self.b,
-                      angle=self.phi * 180. / np.pi + 90, linewidth=1, fill=False, zorder=2)
+                      angle=self.phi * 180. / np.pi, linewidth=1, fill=False, zorder=2)
         return ell
 
     def get_area(self):
@@ -66,6 +66,26 @@ def get_position_mesh(ex_positions):
     Y = np.array([y for _, y in ex_positions]).reshape((n_ex, n_ex))
     return X, Y
 
+def get_correct_position_mesh(ex_positions):
+    n_ex = int(np.sqrt(len(ex_positions)))
+
+    x_wrong = np.array([x for x, _ in ex_positions]).reshape((n_ex, n_ex))
+    y_wrong = np.array([y for _, y in ex_positions]).reshape((n_ex, n_ex))
+
+    x = x_wrong[0, :]
+    y = y_wrong[:, 0]
+
+    xmin = np.min(x)
+    xmax = np.max(x)
+    dx = (xmax - xmin) / (x.shape[0] - 1)
+
+    ymin = np.min(y)
+    ymax = np.max(y)
+    dy = (ymax - ymin) / (y.shape[0] - 1)
+
+    X, Y = np.meshgrid(np.arange(xmin, xmax + 2 * dx, dx) - dx / 2., np.arange(ymin, ymax + 2 * dy, dy) - dy / 2.)
+
+    return X, Y
 
 def get_entropy_of_axonal_coverage(ex_neuron_tuning, ie_connections):
     ex_phase_vals_per_pickle = []

+ 0 - 0
scripts/model_figure/__init__.py


+ 49 - 0
scripts/model_figure/max_entropy_bar_plots.py

@@ -0,0 +1,49 @@
+import matplotlib.pyplot as plt
+from matplotlib import cm
+import numpy as np
+
+FIGURE_SAVE_PATH = '../../figures/figure_4_paper_draft/'
+
+
+cmap = cm.get_cmap('hsv')
+
+high_entropy = [1, 2, 2, 1]
+low_entropy = [0, 2, 4, 0]
+
+
+bars = ('0°', '90°', '180°', '270°')
+xticks = [0, 90, 180, 270]
+y_pos = np.arange(len(bars))
+
+ax = plt.subplot(111)
+
+ax.spines['right'].set_visible(False)
+ax.spines['top'].set_visible(False)
+
+ax.bar(y_pos, low_entropy, color=[cmap(0.), cmap(0.25), cmap(0.5), cmap(0.75)])
+ax.set_xticks(y_pos)
+ax.set_xticklabels(bars, fontsize=20)
+ax.set_yticks([])
+ax.set_ylim(0,4)
+
+for axis in ['bottom','left']:
+  ax.spines[axis].set_linewidth(2.5)
+
+plt.savefig(FIGURE_SAVE_PATH + 'low_entropy_bar_plot.png', dpi=200)
+plt.close()
+
+ax2 = plt.subplot(111)
+
+ax2.spines['right'].set_visible(False)
+ax2.spines['top'].set_visible(False)
+
+ax2.bar(y_pos, high_entropy, color=[cmap(0.), cmap(0.25), cmap(0.5), cmap(0.75)])
+ax2.set_xticks(y_pos)
+ax2.set_xticklabels(bars, fontsize=20)
+ax2.set_yticks([])
+ax2.set_ylim(0,4)
+
+for axis in ['bottom','left']:
+  ax2.spines[axis].set_linewidth(2.5)
+
+plt.savefig(FIGURE_SAVE_PATH + 'high_entropy_bar_plot.png', dpi=200)

+ 27 - 0
scripts/model_figure/plot_circular_colorbar.py

@@ -0,0 +1,27 @@
+import matplotlib.pyplot as plt
+from matplotlib.patches import Ellipse
+import numpy as np
+
+FIGURE_SAVE_PATH = '../../figures/figure_4_paper_draft/'
+
+azimuths = np.arange(0, 361, 1)
+zeniths = np.arange(55, 70, 1)
+values = azimuths * np.ones((15, 361))
+fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
+ax.pcolormesh(azimuths*np.pi/180.0, zeniths, values, cmap='hsv')
+# ax.set_yticks([])
+ax.set_rticks([55.0])  # less radial ticks
+ax.grid(True)
+
+y_tick_labels = []
+ax.set_yticklabels(y_tick_labels)
+
+gridlines = ax.yaxis.get_gridlines()
+gridlines[0].set_color("k")
+gridlines[0].set_linewidth(1.0)
+
+gridlines = ax.xaxis.get_gridlines()
+[line.set_linewidth(0.0) for line in gridlines]
+
+# plt.show()
+plt.savefig(FIGURE_SAVE_PATH + 'circular_color_bar.png', dpi=200)

+ 131 - 41
scripts/spatial_network/figures_spatial_head_direction_network_orientation_map.py

@@ -57,7 +57,7 @@ def plot_tuning_curve(traj, direction_idx, plot_run_names):
     ax.set_title('tuning curves', fontsize=16)
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'tuning_curve.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'tuning_curve.png', dpi=200)
 
 
 def plot_firing_rate_map(traj, direction_idx, plot_run_names):
@@ -84,7 +84,7 @@ def plot_firing_rate_map(traj, direction_idx, plot_run_names):
     fig.suptitle('spatial firing rate map', fontsize=16)
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png', dpi=200)
 
 
 def plot_spatial_hdi_map(traj, plot_run_names):
@@ -108,7 +108,7 @@ def plot_spatial_hdi_map(traj, plot_run_names):
     fig.suptitle('spatial HDI map', fontsize=16)
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'spatial_hdi_map.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'spatial_hdi_map.png', dpi=200)
 
 
 def plot_hdi_in_space(ax, positions, head_direction_indices, max_val=1):
@@ -143,14 +143,14 @@ def plot_hdi_over_tuning(traj, plot_run_names):
     ax.set_title('hdi over input tuning', fontsize=16)
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png', dpi=200)
 
 
 def plot_axonal_clouds(traj, plot_run_names):
     n_ex = int(np.sqrt(traj.N_E))
 
-    fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
-    for ax, run_name in zip(axes, plot_run_names):
+    fig, axes = plt.subplots(1, 2, figsize=(9., 4.5))
+    for ax, run_name in zip(axes, plot_run_names[:-1]):
         traj.f_set_crun(run_name)
 
         label = traj.derived_parameters.runs[run_name].morphology.morph_label
@@ -175,8 +175,72 @@ def plot_axonal_clouds(traj, plot_run_names):
     traj.f_restore_default()
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'axonal_clouds.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'axonal_clouds.png', dpi=200)
+
+def plot_orientation_maps_diff_scales(traj):
+
+    n_ex = int(np.sqrt(traj.N_E))
+
+    scale_run_names = []
+    plot_scales = [0.0, 100.0, 200.0, 300.0]
+    for scale in plot_scales:
+        par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
+        scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
+
+    fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
+    for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
+        traj.f_set_crun(run_name)
+
+        X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
+
+        head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
+        # TODO: Why was this transposed for plotting? (now changed)
+        c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
+        ax.set_title('Correlation length: {}'.format(scale))
+        fig.colorbar(c, ax=ax, label="Tuning")
+
+    # fig.suptitle('axonal cloud', fontsize=16)
+    traj.f_restore_default()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png', dpi=200)
+
+
+def plot_orientation_maps_diff_scales_with_ellipse(traj):
+    n_ex = int(np.sqrt(traj.N_E))
+
+    scale_run_names = []
+    plot_scales = [0.0, 100.0, 200.0, 300.0]
+    for scale in plot_scales:
+        par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
+        scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
+    print(scale_run_names)
+
+    fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
+    for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
+        traj.f_set_crun(run_name)
+
+        X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
+
+        inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
+        axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
+                         inhibitory_axonal_cloud_array]
+
+        head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
+        # TODO: Why was this transposed for plotting? (now changed)
+        c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
+        ax.set_title('Correlation length: {}'.format(scale))
+        fig.colorbar(c, ax=ax, label="Tuning")
 
+        p = axonal_clouds[45]
+        ell = p.get_ellipse()
+        ell._linewidth = 3.
+        ax.add_artist(ell)
+    # fig.suptitle('axonal cloud', fontsize=16)
+    traj.f_restore_default()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales_with_ellipse.png', dpi=200)
 
 def plot_in_degree_map(traj, plot_run_names):
     n_ex = int(np.sqrt(traj.N_E))
@@ -188,8 +252,8 @@ def plot_in_degree_map(traj, plot_run_names):
         if run_max_degree > max_degree:
             max_degree = run_max_degree
 
-    fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
-    for ax, run_name in zip(axes, plot_run_names):
+    fig, axes = plt.subplots(1, 2, figsize=(9., 4.5))
+    for ax, run_name in zip(axes, plot_run_names[:-1]):
         traj.f_set_crun(run_name)
 
         label = traj.derived_parameters.runs[run_name].morphology.morph_label
@@ -207,7 +271,7 @@ def plot_in_degree_map(traj, plot_run_names):
     traj.f_restore_default()
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'in_degree_map.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'in_degree_map.png', dpi=200)
 
 
 def plot_example_polar_plots(traj, plot_run_names):
@@ -237,33 +301,35 @@ def plot_example_polar_plots(traj, plot_run_names):
             ax.set_title(label)
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'example_polar_plots.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'example_polar_plots.png', dpi=200)
 
 def plot_condensed_polar_plot(traj, plot_run_names):
     directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
     directions_plt = list(directions)
     directions_plt.append(directions[0])
     fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5), subplot_kw=dict(projection='polar'))
-    head_direction_indices = traj.results.runs[plot_run_names[1]].head_direction_indices
-    max_hdi_idx = np.argmax(head_direction_indices)
-
+    head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
+    sorted_ids = np.argsort(head_direction_indices)
+    plot_n_idx = sorted_ids[-75]
 
     for run_idx, run_name in enumerate(plot_run_names):
         # ax = axes[max_hdi_idx, run_idx]
         label = traj.derived_parameters.runs[run_name].morphology.morph_label
 
         tuning_vectors = traj.results.runs[run_name].tuning_vectors
-        rate_plot = [np.linalg.norm(v) for v in tuning_vectors[max_hdi_idx]]
+        rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
         rate_plot.append(rate_plot[0])
+        print(rate_plot)
         # ax.scatter(tuning_vectors[exc_n_idx, :, 0], tuning_vectors[exc_n_idx, :, 1])
         ax.plot(directions_plt, rate_plot, label=label)
     ax.set_title('Firing Rate')
     # TODO: Set ticks for polar
     # ax.set_rticks(4)
     plt.legend()
+    plt.tight_layout()
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'condensed_polar_plot.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'condensed_polar_plot.png', dpi=200)
 
 def plot_explanation_hdi_vector(traj, plot_run_names):
     directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
@@ -301,7 +367,7 @@ def plot_explanation_hdi_vector(traj, plot_run_names):
     plt.legend()
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'explanation_hdi_vector.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'explanation_hdi_vector.png', dpi=200)
 
 def plot_hdi_over_corr_len(traj, plot_run_names):
     corr_len_expl = traj.f_get('correlation_length').f_get_range()
@@ -335,7 +401,7 @@ def plot_hdi_over_corr_len(traj, plot_run_names):
     ax.set_ylabel('Head Direction Index')
     ax.legend()
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len.png', dpi=200)
 
 
 def filter_run_indices_by_par_dict(traj, par_dict):
@@ -376,7 +442,7 @@ def plot_hdi_over_in_degree(traj, plot_run_names):
     ax.set_ylabel("Head Direction Index")
     ax.set_title('hdi over in-degree', fontsize=16)
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_in_degree.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_in_degree.png', dpi=200)
 
 def plot_hdi_histogram_excitatory(traj, plot_run_names):
     labels = []
@@ -398,20 +464,23 @@ def plot_hdi_histogram_excitatory(traj, plot_run_names):
     ax.set_xlabel("HDI")
     ax.legend()
 
+    fig.tight_layout()
+
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
 
 
 def plot_hdi_histogram_inhibitory(traj, plot_run_names):
     labels = []
     hdis = []
-    colors = ['black', 'red', 'green']
+    colors = ['black', 'red']
     for run_idx, run_name in enumerate(plot_run_names):
         label = traj.derived_parameters.runs[run_name].morphology.morph_label
-        labels.append(label)
+        if label != 'no conn':
+            labels.append(label)
 
-        head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
-        hdis.append(head_direction_indices)
+            head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
+            hdis.append(head_direction_indices)
 
     fig, ax = plt.subplots(1, 1, figsize=(6, 3))
     ax.hist(hdis, color=colors, label=labels)
@@ -422,8 +491,10 @@ def plot_hdi_histogram_inhibitory(traj, plot_run_names):
     ax.set_xlabel("HDI")
     ax.legend()
 
+    fig.tight_layout()
+
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
 
 
 
@@ -438,7 +509,7 @@ def plot_firing_rate_map_diff_dir(traj, dir_indices, plot_run_names):
                 max_val = run_max_val
 
     directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
-
+    labels = []
     fig, axes = plt.subplots(4, 3, figsize=(13.5, 18.0))
     for run_idx, run_name in enumerate(plot_run_names):
         for plt_dir_idx, dir_idx in enumerate(dir_indices):
@@ -447,19 +518,31 @@ def plot_firing_rate_map_diff_dir(traj, dir_indices, plot_run_names):
             print(dir)
             ax = axes[plt_dir_idx, run_idx]
             label = traj.derived_parameters.runs[run_name].morphology.morph_label
-
+            labels.append(label)
             X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
             firing_rate_array = traj.results[run_name].firing_rate_array
             number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
             c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, dir_idx], (number_of_excitatory_neurons_per_row,
-                                                                                 number_of_excitatory_neurons_per_row)),
+                                                                           number_of_excitatory_neurons_per_row)),
                           vmin=0, vmax=max_val, cmap='hot')
-            ax.set_title('{}, input at {} deg.'.format(label,dir_idx * 30))
+            # ax.set_title('{}, input at {} deg.'.format(label,dir_idx * 30))
             fig.colorbar(c, ax=ax, label="f (Hz)")
-    fig.suptitle('spatial firing rate map', fontsize=16)
+    # fig.suptitle('spatial firing rate map', fontsize=16)
+
+    rows = ['{}°'.format(dir_idx * 30) for dir_idx in dir_indices]
+    print(labels)
+    cols = ['ellipsoid', 'circular', 'no conn.']
+
+    for ax, row in zip(axes[:, 0], rows):
+        ax.set_ylabel(row, rotation=0, size='x-large')
+
+    for ax, col in zip(axes[0, :], cols):
+        ax.set_title(col, size='x-large')
+
+    fig.tight_layout()
 
     if save_figs:
-        plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map_diff_dir.png', dpi=200)
 
 def filter_run_names_by_par_dict(traj, par_dict):
     run_name_list = []
@@ -482,7 +565,7 @@ if __name__ == "__main__":
     FULL_LOAD = 2
     traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=NO_LOADING)
     traj.v_auto_load = True
-    save_figs = False
+    save_figs = True
 
     plot_corr_len = get_closest_correlation_length(traj, 200.0)
     par_dict = {'seed': 1, 'correlation_length': plot_corr_len}
@@ -496,9 +579,7 @@ if __name__ == "__main__":
     #
     # plot_firing_rate_map(traj, direction_idx, plot_run_names)
     #
-    # plot_spatial_hdi_map(traj, plot_run_names)
-    #
-    # plot_axonal_clouds(traj, plot_run_names)
+    plot_spatial_hdi_map(traj, plot_run_names)
     #
     # plot_in_degree_map(traj, plot_run_names)
     #
@@ -506,18 +587,27 @@ if __name__ == "__main__":
     #
     # plot_example_polar_plots(traj, plot_run_names)
     #
-    # plot_hdi_over_corr_len(traj, traj.f_get_run_names())
-
     # plot_hdi_over_in_degree(traj, plot_run_names)
 
+    # plot_axonal_clouds(traj, plot_run_names)
+    #
+    # plot_hdi_over_corr_len(traj, traj.f_get_run_names())
+    #
     # plot_condensed_polar_plot(traj, plot_run_names)
-
+    #
     # plot_explanation_hdi_vector(traj, plot_run_names)
-
+    #
     # plot_hdi_histogram_excitatory(traj, plot_run_names)
+    #
+    # plot_hdi_histogram_inhibitory(traj, plot_run_names)
 
-    plot_firing_rate_map_diff_dir(traj, dir_indices, plot_run_names)
+    # plot_firing_rate_map_diff_dir(traj, dir_indices, plot_run_names)
+
+    # plot_orientation_maps_diff_scales_with_ellipse(traj)
+    #
+    # plot_orientation_maps_diff_scales(traj)
 
-    plt.show()
+    if not save_figs:
+        plt.show()
 
     traj.f_restore_default()

+ 41 - 6
scripts/spatial_network/figures_synaptic_strength_spatial_head_direction_network.py

@@ -9,7 +9,7 @@ import pandas as pd
 
 from scripts.spatial_network.run_synaptic_strength_scan_orientation_map import DATA_FOLDER, TRAJ_NAME
 
-FIGURE_SAVE_PATH = '../../figures/figures_syn_strength_spatial_head_direction_network_orientation_map/'
+FIGURE_SAVE_PATH = '../../figures/figures_spatial_head_direction_network_orientation_map/'
 
 def plot_hdi_synaptic_strength(traj, plot_run_names):
 
@@ -50,8 +50,9 @@ def plot_hdi_synaptic_strength(traj, plot_run_names):
     # TODO: Standart deviation also for the population
     hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 1, 3]).mean()
     # print(hdi_exc_n_and_seed_mean)
-    fig, axes = plt.subplots(1, 2, figsize=(9, 4.5))
-    for ax, label in zip(axes, ['circular', 'ellipsoid']):
+    fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
+    mean_hdi_for_diff = []
+    for ax, label in zip(axes[:-1], ['circular', 'ellipsoid']):
         hdi_mean = hdi_exc_n_and_seed_mean[:, :, label]
 
         # values x and y give values at z
@@ -79,8 +80,11 @@ def plot_hdi_synaptic_strength(traj, plot_run_names):
         print(label)
         # print(hdi_mean)
         hdi_mean_plot = hdi_mean.values.reshape(len(inh_strength_range),len(exc_strength_range))
+
+        mean_hdi_for_diff.append(hdi_mean_plot)
+
         # print(hdi_mean_plot)
-        c = ax.pcolor(X, Y, hdi_mean_plot.T, vmin=0.0, vmax=1.0, cmap='hot')
+        c = ax.pcolor(X, Y, hdi_mean_plot.T, vmin=0.0, vmax=0.5, cmap='hot')
         ax.set_title(label)
         fig.colorbar(c, ax=ax, label="mean HDI")
         ax.set_xticks(np.arange(xmin, xmax + dx, dx))
@@ -91,10 +95,41 @@ def plot_hdi_synaptic_strength(traj, plot_run_names):
         ax.set_yticks(np.linspace(ymin, ymax, 6))
         ax.set_xlim(x[0], x[-1])
         ax.set_ylim(y[0], y[-1])
+
+    xmin = inh_strength_range[0]
+    xmax = inh_strength_range[-1]
+    dx = (xmax - xmin) / (len(inh_strength_range))
+    ymin = exc_strength_range[0]
+    ymax = exc_strength_range[-1]
+    dy = (ymax - ymin) / (len(exc_strength_range))
+    x = np.linspace(xmin, xmax, len(inh_strength_range) + 1)
+    y = np.linspace(ymin, ymax, len(exc_strength_range) + 1)
+    X, Y = np.meshgrid(x, y)
+
+    hdi_diff_plot = mean_hdi_for_diff[1] - mean_hdi_for_diff[0]
+    print(mean_hdi_for_diff[1])
+    print(mean_hdi_for_diff[0])
+    print(hdi_diff_plot)
+    print(hdi_diff_plot.shape)
+
+    ax = axes[2]
+
+    c = ax.pcolor(X, Y, hdi_diff_plot.T, cmap='hot')
+    ax.set_title('difference')
+    fig.colorbar(c, ax=ax, label="mean HDI difference")
+    ax.set_xticks(np.arange(xmin, xmax + dx, dx))
+    ax.set_yticks(np.arange(ymin, ymax + dy, dy))
+    ax.set_xlabel('inhibitory strength (nS)')
+    ax.set_ylabel('excitatory strength (nS)')
+    ax.set_xticks(np.linspace(xmin, xmax, 6))
+    ax.set_yticks(np.linspace(ymin, ymax, 6))
+    ax.set_xlim(x[0], x[-1])
+    ax.set_ylim(y[0], y[-1])
+
     fig.suptitle('Mean HDI over syn. strength', fontsize=16)
 
     if save_dont_show:
-        plt.savefig(FIGURE_SAVE_PATH + 'hdi_synaptic_strength.png')
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_synaptic_strength.png', dpi=200)
 
 def filter_run_names_by_par_dict(traj, par_dict):
     run_name_list = []
@@ -132,7 +167,7 @@ if __name__ == "__main__":
     traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=NO_LOADING)
     traj.v_auto_load = True
 
-    save_dont_show = False
+    save_dont_show = True
 
     if save_dont_show:
         mpl.use('Agg')

+ 751 - 0
scripts/spatial_network/paper_figures_spatial_head_direction_network_orientation_map.py

@@ -0,0 +1,751 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from brian2.units import *
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from pypet import Trajectory
+from pypet.brian2 import Brian2MonitorResult
+from scipy.optimize import curve_fit
+from matplotlib.patches import Ellipse
+from matplotlib.patches import Polygon
+
+from scripts.interneuron_placement import get_position_mesh, Pickle, get_correct_position_mesh
+from scripts.spatial_network.run_entropy_maximisation_orientation_map import DATA_FOLDER, TRAJ_NAME
+
+FIGURE_SAVE_PATH = '../../figures/figure_4_paper_draft/'
+
+def get_closest_correlation_length(traj, correlation_length):
+    available_lengths = sorted(list(set(traj.f_get("correlation_length").f_get_range())))
+    closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
+    if closest_length != correlation_length:
+        print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
+            correlation_length, closest_length))
+    corr_len = closest_length
+    return corr_len
+
+def gauss(x, *p):
+    A, mu, sigma, B = p
+    return A * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2)) + B
+
+
+def plot_tuning_curve(traj, direction_idx, plot_run_names):
+    # p0 is the initial guess for the fitting coefficients (A, mu and sigma above)
+    p0 = [30., 0., 1., 10.]
+
+    fig, ax = plt.subplots(1, 1)
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        ex_tunings = traj.results.runs[run_name].ex_tunings
+
+        coeff_list = []
+
+        ex_tunings_plt = np.array(ex_tunings)
+        sort_ids = ex_tunings_plt.argsort()
+        ex_tunings_plt = ex_tunings_plt[sort_ids]
+
+        firing_rate_array = traj.results[run_name].firing_rate_array
+        # firing_rate_array = traj.f_get('firing_rate_array')
+
+        rates_plt = firing_rate_array[:, direction_idx]
+        rates_plt = rates_plt[sort_ids]
+        coeff, var_matrix = curve_fit(gauss, ex_tunings_plt, rates_plt / hertz, p0=p0)
+        coeff_list.append(coeff)
+        ax.scatter(ex_tunings_plt, rates_plt / hertz, label=label, alpha=0.3)
+        ax.plot(ex_tunings_plt, gauss(ex_tunings_plt, *coeff), label=label + '-fit')
+
+    ax.legend()
+    ax.set_xlabel("Angles (rad)")
+    ax.set_ylabel("f (Hz)")
+    ax.set_title('tuning curves', fontsize=16)
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'tuning_curve.png', dpi=200)
+
+def colorbar(mappable):
+    from mpl_toolkits.axes_grid1 import make_axes_locatable
+    import matplotlib.pyplot as plt
+    last_axes = plt.gca()
+    ax = mappable.axes
+    fig = ax.figure
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes("right", size="5%", pad=0.05)
+    cbar = fig.colorbar(mappable, cax=cax)
+    plt.sca(last_axes)
+    return cbar
+
+
+def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names):
+    max_val = 0
+    for run_name in plot_run_names:
+        fr_array = traj.results.runs[run_name].firing_rate_array
+        f_rates = fr_array[:, direction_idx]
+        run_max_val = np.max(f_rates)
+        if run_max_val > max_val:
+            # if traj.derived_parameters.runs[run_name].morphology.morph_label == 'ellipsoid':
+            #     n_id_max_rate = np.argmax(f_rates)
+            max_val = run_max_val
+
+    n_id_polar_plot = 609
+
+    # Mark the neuron that is shown in Polar plot
+    ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
+    polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
+
+    # Vertices for the plotted triangle
+    tr_scale = 13.
+    tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
+    tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
+    tr_vertices = np.array([[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
+
+    height = 4.5
+    # color_bar_size = 0.05 * height + 0.05
+    # width = 3 * height + color_bar_size
+    width = 13.5
+
+    fig, axes = plt.subplots(1, 3, figsize=(width, height))
+    for ax, run_name in zip(axes, plot_run_names[::-1]):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
+        firing_rate_array = traj.results[run_name].firing_rate_array
+        number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
+        c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
+                                                                             number_of_excitatory_neurons_per_row)),
+                      vmin=0, vmax=max_val, cmap='Reds')
+        ax.set_title(label)
+        # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
+        # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
+        ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
+        ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
+    # fig.suptitle('spatial firing rate map', fontsize=16)
+    colorbar(c)
+
+
+    fig.tight_layout()
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png', dpi=200)
+
+    return n_id_polar_plot
+
+def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names):
+    max_val = 0
+    for run_name in plot_run_names:
+        fr_array = traj.results.runs[run_name].inh_firing_rate_array
+        f_rates = fr_array[:, direction_idx]
+        run_max_val = np.max(f_rates)
+        if run_max_val > max_val:
+            max_val = run_max_val
+
+    n_id_polar_plot = 52
+
+    # Mark the neuron that is shown in Polar plot
+    inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
+    polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
+    polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
+
+    plot_run_names_sorted = [plot_run_names[1], plot_run_names[0]]
+
+    fig, axes = plt.subplots(1, 2, figsize=(9.0, 4.5))
+    for ax, run_name in zip(axes, plot_run_names_sorted):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
+
+        inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
+
+        X, Y = get_correct_position_mesh(inh_positions)
+        inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
+        number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
+        c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
+                                                                             number_of_inhibitory_neurons_per_row)),
+                      vmin=0, vmax=max_val, cmap='Blues')
+        ax.set_title(label)
+
+        circle_r = 40.
+        ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
+        ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
+        # fig.colorbar(c, ax=ax, label="f (Hz)")
+    # fig.suptitle('spatial firing rate map', fontsize=16)
+    colorbar(c)
+    fig.tight_layout()
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'inh_firing_rate_map.png', dpi=200)
+    return n_id_polar_plot, max_val
+
+def plot_hdi_over_tuning(traj, plot_run_names):
+    fig, ax = plt.subplots(1, 1)
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        ex_tunings = traj.results.runs[run_name].ex_tunings
+
+        ex_tunings_plt = np.array(ex_tunings)
+        sort_ids = ex_tunings_plt.argsort()
+        ex_tunings_plt = ex_tunings_plt[sort_ids]
+
+        head_direction_indices = traj.results[run_name].head_direction_indices
+
+        hdi_plt = head_direction_indices
+        hdi_plt = hdi_plt[sort_ids]
+        ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
+
+    ax.legend()
+    ax.set_xlabel("Angles (rad)")
+    ax.set_ylabel("head direction index")
+    ax.set_title('hdi over input tuning', fontsize=16)
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png', dpi=200)
+
+
+def plot_axonal_clouds(traj, plot_run_names):
+    n_ex = int(np.sqrt(traj.N_E))
+
+    fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
+    for ax, run_name in zip(axes, plot_run_names[::-1]):
+        traj.f_set_crun(run_name)
+
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
+
+        inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
+        axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
+                         inhibitory_axonal_cloud_array]
+
+        head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
+        # TODO: Why was this transposed for plotting? (now changed)
+        c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
+        ax.set_title(label)
+        # fig.colorbar(c, ax=ax, label="Tuning")
+
+        if label != 'no conn' and axonal_clouds is not None:
+            for i, p in enumerate(axonal_clouds):
+                ell = p.get_ellipse()
+                ax.add_artist(ell)
+    # fig.suptitle('axonal cloud', fontsize=16)
+    traj.f_restore_default()
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'axonal_clouds.png', dpi=200)
+
+def plot_orientation_maps_diff_scales(traj):
+
+    n_ex = int(np.sqrt(traj.N_E))
+
+    scale_run_names = []
+    plot_scales = [0.0, 100.0, 200.0, 300.0]
+    for scale in plot_scales:
+        par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
+        scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
+
+    fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
+    for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
+        traj.f_set_crun(run_name)
+
+        X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
+
+        head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
+        # TODO: Why was this transposed for plotting? (now changed)
+        c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
+        ax.set_title('Correlation length: {}'.format(scale))
+        fig.colorbar(c, ax=ax, label="Tuning")
+
+    # fig.suptitle('axonal cloud', fontsize=16)
+    traj.f_restore_default()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png', dpi=200)
+
+
+def plot_orientation_maps_diff_scales_with_ellipse(traj):
+    n_ex = int(np.sqrt(traj.N_E))
+
+    scale_run_names = []
+    plot_scales = [0.0, 100.0, 200.0, 300.0, 400.0]
+    for scale in plot_scales:
+        par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
+        scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
+    print(scale_run_names)
+
+    fig, axes = plt.subplots(1, 5, figsize=(18., 4.5))
+    for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
+        traj.f_set_crun(run_name)
+
+        X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
+
+        inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
+        axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
+                         inhibitory_axonal_cloud_array]
+
+        head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
+        # TODO: Why was this transposed for plotting? (now changed)
+        c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
+        # ax.set_title('Correlation length: {}'.format(scale))
+        # fig.colorbar(c, ax=ax, label="Tuning")
+        ax.set_xticks([])
+        ax.set_yticks([])
+
+        p1 = axonal_clouds[44]
+        ell = p1.get_ellipse()
+        ell._linewidth = 5.
+        ax.add_artist(ell)
+
+        p2 = axonal_clouds[77]
+        circ_r = 2 * np.sqrt(2500.)
+        circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
+        circ._linewidth = 5.
+
+        ax.add_artist(circ)
+
+    # fig.suptitle('axonal cloud', fontsize=16)
+    traj.f_restore_default()
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales_with_ellipse.png', dpi=200)
+
+def plot_excitatory_condensed_polar_plot(traj, plot_run_names, polar_plot_id):
+    directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
+    directions_plt = list(directions)
+    directions_plt.append(directions[0])
+    fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))
+    # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
+    # sorted_ids = np.argsort(head_direction_indices)
+    # plot_n_idx = sorted_ids[-75]
+    plot_n_idx = polar_plot_id
+
+    line_styles = ['dotted', 'solid', 'dashed']
+    colors = ['r', 'lightsalmon', 'grey']
+
+    max_rate = 0.0
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        tuning_vectors = traj.results.runs[run_name].tuning_vectors
+        rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
+        run_max_rate = np.max(rate_plot)
+        if run_max_rate > max_rate:
+            max_rate = run_max_rate
+        rate_plot.append(rate_plot[0])
+        ax.plot(directions_plt, rate_plot, label=label, color=colors[run_idx], linestyle=line_styles[run_idx])
+    # ax.set_title('Firing Rate')
+    ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
+    # TODO: Set ticks for polar
+    ticks = [30., 60., 90.]
+    ax.set_rticks(ticks)
+    ax.set_rlabel_position(230)
+    ax.legend(loc='upper center', bbox_to_anchor=(0.2, 1.05),
+          fancybox=True, shadow=True)
+    plt.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'condensed_polar_plot.png', dpi=200)
+
+def plot_inhibitory_condensed_polar_plot(traj, plot_run_names, polar_plot_id, max_rate):
+    directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
+    directions_plt = list(directions)
+    directions_plt.append(directions[0])
+    fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))
+    # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
+    # sorted_ids = np.argsort(head_direction_indices)
+    # plot_n_idx = sorted_ids[-75]
+    plot_n_idx = polar_plot_id
+
+    line_styles = ['dotted', 'solid']
+    colors = ['b', 'lightblue']
+
+    for run_idx, run_name in enumerate(plot_run_names[:2]):
+        # ax = axes[max_hdi_idx, run_idx]
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+
+        tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
+        rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
+        rate_plot.append(rate_plot[0])
+        ax.plot(directions_plt, rate_plot, label=label, color=colors[run_idx], linestyle=line_styles[run_idx])
+    # ax.set_title('Inh. Firing Rate')
+    # TODO: Set ticks for polar
+    # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
+    ticks = [40., 80., 120.]
+    ax.set_rticks(ticks)
+    ax.set_rlabel_position(230)
+    ax.legend(loc='upper center', bbox_to_anchor=(0.2, 1.05),
+              fancybox=True, shadow=True)
+    plt.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'condensed_inhibitory_polar_plot.png', dpi=200)
+
+def plot_hdi_over_corr_len(traj, plot_run_names):
+    corr_len_expl = traj.f_get('correlation_length').f_get_range()
+    seed_expl = traj.f_get('seed').f_get_range()
+    label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
+    label_range = set(label_expl)
+
+    hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
+    hdi_frame.index.names = ["corr_len", "seed", "label"]
+
+    for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
+        ex_tunings = traj.results.runs[run_name].ex_tunings
+        head_direction_indices = traj.results[run_name].head_direction_indices
+        hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
+
+    # TODO: Standart deviation also for the population
+    hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
+    hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
+
+    # Ellipsoid markers
+    rx, ry = 5., 12.
+    # area = rx * ry * np.pi * 2.
+    area = 1.
+    theta = np.arange(0, 2 * np.pi + 0.01, 0.1)
+    verts = np.column_stack([rx / area * np.cos(theta), ry / area * np.sin(theta)])
+
+    style_dict = {
+        'no conn': ['grey', 'dashed', '', 0],
+        'ellipsoid': ['blue', 'solid', verts, 10.],
+        'circular': ['lightblue', 'solid', 'o', 8.]
+
+    }
+    # colors = ['blue', 'grey', 'lightblue']
+    # linestyles = ['solid', 'dashed', 'solid']
+    # markers = [verts, '', 'o']
+
+
+    fig, ax = plt.subplots(1, 1)
+
+    for label in label_range:
+        hdi_mean = hdi_exc_n_and_seed_mean[:, label]
+        hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
+        corr_len_range = hdi_mean.keys().to_numpy()
+
+        col, lin, mar, mar_size = style_dict[label]
+
+        ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
+        plt.fill_between(corr_len_range, hdi_mean - hdi_std,
+                         hdi_mean + hdi_std, alpha=0.4, color=col)
+    ax.set_xlabel('Correlation length')
+    ax.set_ylabel('Head Direction Index')
+    ax.axvline(206.9, color='k', linewidth=0.5)
+    ax.set_ylim(0.0,1.0)
+    ax.set_xlim(0.0,400.)
+    ax.legend()
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len_scaled.png', dpi=200)
+
+
+def plot_hdi_histogram_excitatory(traj, plot_run_names):
+    labels = []
+    hdis = []
+    colors = ['black', 'red', 'green']
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+        labels.append(label)
+
+        head_direction_indices = traj.results.runs[run_name].head_direction_indices
+        hdis.append(head_direction_indices)
+
+    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+    ax.hist(hdis, color=colors, label=labels, bins=30)
+
+    for hdi, color in zip(hdis, colors):
+        mean_hdi = np.mean(hdi)
+        ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
+    ax.set_xlabel("HDI")
+    ax.legend()
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
+
+
+def plot_hdi_violin_excitatory(traj, plot_run_names):
+    labels = []
+    hdis = []
+    colors = ['black', 'red', 'green']
+    no_conn_hdi = 0.
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+        head_direction_indices = traj.results.runs[run_name].head_direction_indices
+
+        if label == 'no conn':
+            no_conn_hdi = np.mean(head_direction_indices)
+        else:
+            labels.append(label)
+            hdis.append(sorted(head_direction_indices))
+
+    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+    # hdis = np.array(hdis)
+    viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
+
+    viol_plt['cmeans'].set_color('black')
+
+    for pc in viol_plt['bodies']:
+        pc.set_facecolor('red')
+        pc.set_edgecolor('black')
+        pc.set_alpha(0.7)
+
+    ax.axhline(no_conn_hdi, color='black', linestyle='--')
+    ax.annotate('no conn', xy=(0.45,0.48), xycoords='axes fraction')
+
+    ax.set_xticks(np.arange(1, len(labels) + 1))
+    ax.set_xticklabels(labels)
+    ax.set_ylabel('HDI')
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_excitatory.png', dpi=200)
+
+
+def plot_hdi_violin_inhibitory(traj, plot_run_names):
+    labels = []
+    hdis = []
+    colors = ['black', 'red']
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+        if label != 'no conn':
+            labels.append(label)
+
+            head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
+            hdis.append(sorted(head_direction_indices))
+
+    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+    viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
+
+    viol_plt['cmeans'].set_color('black')
+
+    for pc in viol_plt['bodies']:
+        pc.set_facecolor('blue')
+        pc.set_edgecolor('black')
+        pc.set_alpha(0.7)
+
+    ax.set_xticks(np.arange(1, len(labels) + 1))
+    ax.set_xticklabels(labels)
+    ax.set_ylabel('HDI')
+
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_inhibitory.png', dpi=200)
+
+def plot_hdi_violin_combined(traj, plot_run_names):
+    labels = []
+    inh_hdis = []
+    exc_hdis = []
+    no_conn_hdi = 0.
+
+    colors = ['black', 'red']
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+        if label != 'no conn':
+            labels.append(label)
+
+            inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
+            inh_hdis.append(sorted(inh_head_direction_indices))
+
+            exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
+            exc_hdis.append(sorted(exc_head_direction_indices))
+        else:
+            exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
+            no_conn_hdi = np.mean(exc_head_direction_indices)
+
+    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+    inh_viol_plt = ax.violinplot(inh_hdis, showmeans=True, showextrema=False)
+
+    # viol_plt['cmeans'].set_color('black')
+    #
+    # for pc in viol_plt['bodies']:
+    #     pc.set_facecolor('blue')
+    #     pc.set_edgecolor('black')
+    #     pc.set_alpha(0.7)
+
+    for b in inh_viol_plt['bodies']:
+        m = np.mean(b.get_paths()[0].vertices[:, 0])
+        b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
+        b.set_color('b')
+
+    exc_viol_plt = ax.violinplot(exc_hdis, showmeans=True, showextrema=False)
+
+    for b in exc_viol_plt['bodies']:
+        m = np.mean(b.get_paths()[0].vertices[:, 0])
+        b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
+        b.set_color('r')
+
+    ax.axhline(no_conn_hdi, color='black', linestyle='--')
+    ax.annotate('no conn', xy=(0.45, 0.48), xycoords='axes fraction')
+
+    ax.set_xticks(np.arange(1, len(labels) + 1))
+    ax.set_xticklabels(labels)
+    ax.set_ylabel('HDI')
+
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined.svg', dpi=200)
+
+def plot_hdi_violin_combined_and_overlayed(traj, plot_run_names):
+    labels = []
+    inh_hdis = []
+    exc_hdis = []
+    no_conn_hdi = 0.
+
+    colors = ['black', 'red']
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+        if label != 'no conn':
+            labels.append(label)
+
+            inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
+            inh_hdis.append(sorted(inh_head_direction_indices))
+
+            exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
+            exc_hdis.append(sorted(exc_head_direction_indices))
+        else:
+            exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
+            no_conn_hdi = np.mean(exc_head_direction_indices)
+
+    fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.5))
+
+    inh_ell_viol_plt = ax.violinplot(inh_hdis[0], showmeans=True, showextrema=False)
+    for b in inh_ell_viol_plt['bodies']:
+        m = np.mean(b.get_paths()[0].vertices[:, 0])
+        b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
+        b.set_color('b')
+    mean_line = inh_ell_viol_plt['cmeans']
+    mean_line.set_color('b')
+    mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
+
+    exc_ell_viol_plt = ax.violinplot(exc_hdis[0], showmeans=True, showextrema=False)
+    for b in exc_ell_viol_plt['bodies']:
+        m = np.mean(b.get_paths()[0].vertices[:, 0])
+        b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
+        b.set_color('r')
+    mean_line = exc_ell_viol_plt['cmeans']
+    mean_line.set_color('r')
+    mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
+
+    inh_cir_viol_plt = ax.violinplot(inh_hdis[1], showmeans=True, showextrema=False)
+    for b in inh_cir_viol_plt['bodies']:
+        m = np.mean(b.get_paths()[0].vertices[:, 0])
+        b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
+        b.set_color('b')
+    mean_line = inh_cir_viol_plt['cmeans']
+    mean_line.set_color('b')
+    mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
+
+    exc_cir_viol_plt = ax.violinplot(exc_hdis[1], showmeans=True, showextrema=False)
+    for b in exc_cir_viol_plt['bodies']:
+        m = np.mean(b.get_paths()[0].vertices[:, 0])
+        b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
+        b.set_color('r')
+    mean_line = exc_cir_viol_plt['cmeans']
+    mean_line.set_color('r')
+    mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
+
+    ax.axhline(no_conn_hdi, 0.5, 1., color='black', linestyle='--')
+    ax.axvline(1.0, color='k')
+    ax.annotate('no conn', xy=(0.75, 0.415), xycoords='axes fraction')
+    ax.set_xlim(0.5, 1.5)
+    ax.set_ylim(0.0, 1.0)
+    ax.set_xticks([0.75, 1.25])
+    ax.set_xticklabels(['circular', 'ellipsoid'])
+    ax.set_ylabel('HDI')
+
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
+
+def plot_hdi_histogram_inhibitory(traj, plot_run_names):
+    labels = []
+    hdis = []
+    colors = ['black', 'red']
+    for run_idx, run_name in enumerate(plot_run_names):
+        label = traj.derived_parameters.runs[run_name].morphology.morph_label
+        if label != 'no conn':
+            labels.append(label)
+
+            head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
+            hdis.append(head_direction_indices)
+
+    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
+    ax.hist(hdis, color=colors, label=labels, bins=30)
+
+    for hdi, color in zip(hdis, colors):
+        mean_hdi = np.mean(hdi)
+        ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
+    ax.set_xlabel("HDI")
+    ax.legend()
+
+    fig.tight_layout()
+
+    if save_figs:
+        plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
+
+
+def filter_run_names_by_par_dict(traj, par_dict):
+    run_name_list = []
+    for run_idx, run_name in enumerate(traj.f_get_run_names()):
+        traj.f_set_crun(run_name)
+        paramters_equal = True
+        for key, val in par_dict.items():
+            if (traj.par[key] != val):
+                paramters_equal = False
+        if paramters_equal:
+            run_name_list.append(run_name)
+    traj.f_restore_default()
+
+    return run_name_list
+
+
+if __name__ == "__main__":
+    traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
+    NO_LOADING = 0
+    FULL_LOAD = 2
+    traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=NO_LOADING)
+    traj.v_auto_load = True
+    save_figs = True
+
+    plot_corr_len = get_closest_correlation_length(traj, 200.0)
+    par_dict = {'seed': 1, 'correlation_length': plot_corr_len}
+    plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
+    print(plot_run_names)
+
+    direction_idx = 6
+    dir_indices = [0, 3, 6, 9]
+
+    # plot_axonal_clouds(traj, plot_run_names)
+    # #
+    # ex_polar_plot_id = plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names)
+    # #
+    # in_polar_plot_id, in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names)
+    #
+    # plot_orientation_maps_diff_scales_with_ellipse(traj)
+    #
+    # plot_hdi_histogram_inhibitory(traj, plot_run_names)
+    #
+    # plot_hdi_histogram_excitatory(traj, plot_run_names)
+    #
+    # plot_hdi_over_corr_len(traj, traj.f_get_run_names())
+
+    # plot_excitatory_condensed_polar_plot(traj, plot_run_names, ex_polar_plot_id)
+    #
+    # plot_inhibitory_condensed_polar_plot(traj, plot_run_names, in_polar_plot_id, in_max_rate)
+
+    # plot_hdi_violin_combined(traj, plot_run_names)
+    #
+    plot_hdi_violin_combined_and_overlayed(traj, plot_run_names)
+
+    if not save_figs:
+        plt.show()
+
+    traj.f_restore_default()