|
@@ -208,51 +208,40 @@ def plot_single_cell_analysis(traj, plot_run_names):
|
|
|
|
|
|
def plot_hdi_over_corr_len(traj, plot_run_names):
|
|
|
|
|
|
- fig, ax = plt.subplots(1, 1)
|
|
|
|
|
|
corr_len_expl = traj.f_get('correlation_length').f_get_range()
|
|
|
seed_expl = traj.f_get('seed').f_get_range()
|
|
|
label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
|
|
|
- # seed_range = sorted(set(seed_expl))
|
|
|
- # corr_len_range = sorted(set(corr_len_expl))
|
|
|
- # label_range = set(label_expl)
|
|
|
- # print(seed_range)
|
|
|
- # print(corr_len_range)
|
|
|
- # print(label_range)
|
|
|
-
|
|
|
- # index = pd.MultiIndex.from_tuples(zip(corr_len_expl, seed_expl, label_expl), names=['corr_len', 'seed', 'label'])
|
|
|
- hdi_frame = pd.DataFrame(index=[corr_len_expl, seed_expl, label_expl])
|
|
|
+ label_range = set(label_expl)
|
|
|
|
|
|
- # label_range = [p.f_get() for p in traj.f_get_derived_parameters().values()]
|
|
|
+ hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
|
|
|
+ hdi_frame.index.names = ["corr_len", "seed", "label"]
|
|
|
|
|
|
for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
|
|
|
- print(run_name, corr_len, seed, label)
|
|
|
ex_tunings = traj.results.runs[run_name].ex_tunings
|
|
|
head_direction_indices = traj.results[run_name].head_direction_indices
|
|
|
- print(head_direction_indices.shape)
|
|
|
hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
|
|
|
|
|
|
+ # TODO: Standart deviation also for the population
|
|
|
+ hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
|
|
|
+ hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
|
|
|
+
|
|
|
+
|
|
|
+ fig, ax = plt.subplots(1, 1)
|
|
|
+
|
|
|
+ for label in label_range:
|
|
|
+ hdi_mean = hdi_exc_n_and_seed_mean[:, label]
|
|
|
+ hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
|
|
|
+ corr_len_range = hdi_mean.keys().to_numpy()
|
|
|
+
|
|
|
+
|
|
|
+ ax.plot(corr_len_range, hdi_mean, label=label, marker='o')
|
|
|
+ plt.fill_between(corr_len_range, hdi_mean - hdi_std,
|
|
|
+ hdi_mean + hdi_std, alpha=0.4)
|
|
|
+ ax.set_xlabel('Correlation length')
|
|
|
+ ax.set_ylabel('Head Direction Index')
|
|
|
+ ax.legend()
|
|
|
|
|
|
- # label = traj.derived_parameters.runs[run_name].morphology.morph_label
|
|
|
- #
|
|
|
- # ex_tunings = traj.results.runs[run_name].ex_tunings
|
|
|
- #
|
|
|
- #
|
|
|
- # ex_tunings_plt = np.array(ex_tunings)
|
|
|
- # sort_ids = ex_tunings_plt.argsort()
|
|
|
- # ex_tunings_plt = ex_tunings_plt[sort_ids]
|
|
|
- #
|
|
|
- # head_direction_indices = traj.results[run_name].head_direction_indices
|
|
|
- #
|
|
|
- # hdi_plt = head_direction_indices
|
|
|
- # hdi_plt = hdi_plt[sort_ids]
|
|
|
- # ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
|
|
|
-
|
|
|
- # ax.legend()
|
|
|
- # ax.set_xlabel("Angles (rad)")
|
|
|
- # ax.set_ylabel("head direction index")
|
|
|
- # ax.set_title('hdi over input tuning', fontsize=16)
|
|
|
- print(hdi_frame)
|
|
|
|
|
|
|
|
|
def filter_run_indices_by_par_dict(traj, par_dict):
|