import numpy as np
from pypet import Trajectory
from pypet.brian2 import Brian2MonitorResult
from tqdm import tqdm
from brian2.units import *
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd

from scripts.spatial_network.run_synaptic_strength_scan_orientation_map import DATA_FOLDER, TRAJ_NAME

FIGURE_SAVE_PATH = '../../figures/figures_spatial_head_direction_network_orientation_map/'

# Number of runs at the end of the exploration that are excluded from the
# analysis (presumably duplicated runs -- TODO confirm against the scan script).
N_EXCLUDED_TRAILING_RUNS = 10
# The morphology labels stored in the trajectory are NOT used. Instead the
# (truncated) runs are assumed to consist of this many 'ellipsoid' runs
# followed by the same number of 'circular' runs.
N_RUNS_PER_LABEL = 3000


def _drop_trailing(sequence, n):
    """Return ``sequence`` without its last ``n`` items (empty if it is shorter).

    Unlike ``sequence[:-n]`` this also behaves correctly for ``n == 0``.
    """
    return sequence[:max(len(sequence) - n, 0)]


def _cell_edges(values):
    """Return pcolor cell boundaries such that each cell is centred on one of ``values``.

    ``values`` must be sorted ascending and non-empty. Inner edges are the
    midpoints between neighbouring values; the outer edges lie half a spacing
    beyond the first and last value. A single value gets a cell of width 1
    centred on it.
    """
    values = np.asarray(values)
    if len(values) == 1:
        return np.array([values[0] - 0.5, values[0] + 0.5])
    midpoints = (values[:-1] + values[1:]) / 2.
    first_edge = values[0] - (midpoints[0] - values[0])
    last_edge = values[-1] + (values[-1] - midpoints[-1])
    return np.concatenate(([first_edge], midpoints, [last_edge]))


def _plot_heatmap(fig, ax, inh_strength_range, exc_strength_range, values, title, colorbar_label,
                  **pcolor_kwargs):
    """Draw a heatmap of ``values`` over the inhibitory x excitatory strength grid.

    ``values`` has shape (len(inh_strength_range), len(exc_strength_range)):
    row i / column j belong to inh_strength_range[i] / exc_strength_range[j].
    Extra keyword arguments (e.g. ``vmin``/``vmax``) are forwarded to ``ax.pcolor``.
    """
    x = _cell_edges(inh_strength_range)
    y = _cell_edges(exc_strength_range)
    X, Y = np.meshgrid(x, y)
    # pcolor expects rows along y (excitatory) and columns along x (inhibitory).
    c = ax.pcolor(X, Y, values.T, cmap='hot', **pcolor_kwargs)
    ax.set_title(title)
    fig.colorbar(c, ax=ax, label=colorbar_label)
    ax.set_xlabel('inhibitory strength (nS)')
    ax.set_ylabel('excitatory strength (nS)')
    ax.set_xticks(np.linspace(inh_strength_range[0], inh_strength_range[-1], 6))
    ax.set_yticks(np.linspace(exc_strength_range[0], exc_strength_range[-1], 6))
    ax.set_xlim(x[0], x[-1])
    ax.set_ylim(y[0], y[-1])


def plot_hdi_synaptic_strength(traj, plot_run_names, save_dont_show=False):
    """Plot the mean head direction index (HDI) over inhibitory and excitatory strength.

    The figure has three heatmap panels:
    - 'circular' runs, averaged over seeds (colour range 0-0.5);
    - 'ellipsoid' runs, averaged over seeds (colour range 0-0.5);
    - their difference, ellipsoid - circular.

    Parameters
    ----------
    traj : pypet.Trajectory
        Trajectory exploring the 'inhibitory', 'excitatory' and 'seed'
        parameters and storing the per-run results ``ex_tunings`` and
        ``head_direction_indices``.
    plot_run_names : sequence of str
        Run names, ordered like the explored parameter ranges.
    save_dont_show : bool, optional
        If True, save the figure as ``hdi_synaptic_strength.png`` in
        FIGURE_SAVE_PATH.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    inh_strength_expl = _drop_trailing(traj.f_get('inhibitory').f_get_range(), N_EXCLUDED_TRAILING_RUNS)
    exc_strength_expl = _drop_trailing(traj.f_get('excitatory').f_get_range(), N_EXCLUDED_TRAILING_RUNS)
    seed_expl = _drop_trailing(traj.f_get('seed').f_get_range(), N_EXCLUDED_TRAILING_RUNS)
    run_names = _drop_trailing(plot_run_names, N_EXCLUDED_TRAILING_RUNS)
    # Hard-coded labels replace the ones stored in
    # traj.derived_parameters.runs[run_name].morphology.morph_label.
    label_expl = ['ellipsoid'] * N_RUNS_PER_LABEL + ['circular'] * N_RUNS_PER_LABEL

    inh_strength_range = sorted(set(inh_strength_expl))
    exc_strength_range = sorted(set(exc_strength_expl))

    index = pd.MultiIndex.from_arrays(
        [inh_strength_expl, exc_strength_expl, seed_expl, label_expl],
        names=["inhibitory", "excitatory", "seed", "label"])
    # Explicit float dtype: without data, recent pandas would create an
    # object-dtype Series. NaN entries are ignored by the mean below.
    hdi_frame = pd.Series(np.nan, index=index, dtype=float)

    # zip stops after min(len(run_names), len(label_expl)) runs.
    n_runs = min(len(run_names), len(label_expl))
    for run_name, inh_strength, exc_strength, seed, label in tqdm(
            zip(run_names, inh_strength_expl, exc_strength_expl, seed_expl, label_expl),
            total=n_runs):
        # The tunings, while not used, must be accessed or the following line will fail!
        _ = traj.results.runs[run_name].ex_tunings
        head_direction_indices = traj.results[run_name].head_direction_indices
        hdi_frame.loc[(inh_strength, exc_strength, seed, label)] = np.mean(head_direction_indices)

    # Average over seeds, keeping (inhibitory, excitatory, label).
    # TODO: Standard deviation also for the population
    hdi_seed_mean = hdi_frame.groupby(level=["inhibitory", "excitatory", "label"]).mean()

    fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
    mean_hdi_per_label = {}
    for ax, label in zip(axes[:-1], ['circular', 'ellipsoid']):
        # (inhibitory x excitatory) grid; missing combinations become NaN
        # instead of silently scrambling a reshape.
        hdi_mean_grid = (hdi_seed_mean.xs(label, level="label")
                         .unstack("excitatory")
                         .reindex(index=inh_strength_range, columns=exc_strength_range)
                         .values)
        mean_hdi_per_label[label] = hdi_mean_grid
        _plot_heatmap(fig, ax, inh_strength_range, exc_strength_range, hdi_mean_grid,
                      title=label, colorbar_label="mean HDI", vmin=0.0, vmax=0.5)

    hdi_diff = mean_hdi_per_label['ellipsoid'] - mean_hdi_per_label['circular']
    _plot_heatmap(fig, axes[2], inh_strength_range, exc_strength_range, hdi_diff,
                  title='difference', colorbar_label="mean HDI difference")

    fig.suptitle('Mean HDI over syn. strength', fontsize=16)
    if save_dont_show:
        fig.savefig(FIGURE_SAVE_PATH + 'hdi_synaptic_strength.png', dpi=200)
    return fig


def filter_run_names_by_par_dict(traj, par_dict):
    """Return the names of all runs whose parameters match every entry of ``par_dict``.

    ``par_dict`` maps parameter names (as accepted by ``traj.par[...]``) to the
    required values. A run matches when none of its values compares ``!=``
    to the required one. The trajectory's current run is reset with
    ``traj.f_restore_default()`` after each run is inspected, even if a
    parameter lookup raises.
    """
    run_name_list = []
    for run_name in traj.f_get_run_names():
        traj.f_set_crun(run_name)
        try:
            if not any(traj.par[key] != val for key, val in par_dict.items()):
                run_name_list.append(run_name)
        finally:
            traj.f_restore_default()
    return run_name_list


def filter_run_names_and_duplicates_because_im_an_idiot(traj, par_dict):
    """Same as ``filter_run_names_by_par_dict``; kept for backward compatibility.

    Despite its name, this does not remove any duplicate runs.
    """
    return filter_run_names_by_par_dict(traj, par_dict)


if __name__ == "__main__":
    traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
    NO_LOADING = 0
    FULL_LOAD = 2
    traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD,
                load_results=NO_LOADING)
    traj.v_auto_load = True

    save_dont_show = True
    if save_dont_show:
        # Non-interactive backend: figures are only written to disk.
        mpl.use('Agg')

    plot_hdi_synaptic_strength(traj, traj.f_get_run_names(), save_dont_show=save_dont_show)
    if not save_dont_show:
        plt.show()

    traj.f_restore_default()