paper_figures_spatial_head_direction_network_orientation_map.py 49 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import pandas as pd
  4. from brian2.units import *
  5. from matplotlib import gridspec
  6. from mpl_toolkits.axes_grid1 import make_axes_locatable
  7. from pypet import Trajectory
  8. from pypet.brian2 import Brian2MonitorResult
  9. from scipy.optimize import curve_fit
  10. from matplotlib.patches import Ellipse
  11. from matplotlib.patches import Polygon
  12. import matplotlib.legend as mlegend
  13. from matplotlib.patches import Rectangle
  14. from scripts.interneuron_placement import get_position_mesh, Pickle, get_correct_position_mesh
  15. from scripts.spatial_network.figures_spatial_head_direction_network_orientation_map import plot_hdi_in_space
  16. from scripts.spatial_network.run_entropy_maximisation_orientation_map import DATA_FOLDER, TRAJ_NAME
  17. # FIGURE_SAVE_PATH = '../../figures/figure_4_paper_draft/'
  18. FIGURE_SAVE_PATH = '../../figures/figure_4_poster_fens/'
  19. def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
  20. """
  21. Place a table legend on the axes.
  22. Creates a legend where the labels are not directly placed with the artists,
  23. but are used as row and column headers, looking like this:
  24. title_label | col_labels[1] | col_labels[2] | col_labels[3]
  25. -------------------------------------------------------------
  26. row_labels[1] |
  27. row_labels[2] | <artists go there>
  28. row_labels[3] |
  29. Parameters
  30. ----------
  31. ax : `matplotlib.axes.Axes`
  32. The artist that contains the legend table, i.e. current axes instant.
  33. col_labels : list of str, optional
  34. A list of labels to be used as column headers in the legend table.
  35. `len(col_labels)` needs to match `ncol`.
  36. row_labels : list of str, optional
  37. A list of labels to be used as row headers in the legend table.
  38. `len(row_labels)` needs to match `len(handles) // ncol`.
  39. title_label : str, optional
  40. Label for the top left corner in the legend table.
  41. ncol : int
  42. Number of columns.
  43. Other Parameters
  44. ----------------
  45. Refer to `matplotlib.legend.Legend` for other parameters.
  46. """
  47. #################### same as `matplotlib.axes.Axes.legend` #####################
  48. handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
  49. if len(extra_args):
  50. raise TypeError('legend only accepts two non-keyword arguments')
  51. if col_labels is None and row_labels is None:
  52. ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
  53. ax.legend_._remove_method = ax._remove_legend
  54. return ax.legend_
  55. #################### modifications for table legend ############################
  56. else:
  57. ncol = kwargs.pop('ncol')
  58. handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
  59. title_label = [title_label]
  60. # blank rectangle handle
  61. extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
  62. # empty label
  63. empty = [""]
  64. # number of rows infered from number of handles and desired number of columns
  65. nrow = len(handles) // ncol
  66. # organise the list of handles and labels for table construction
  67. if col_labels is None:
  68. assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
  69. leg_handles = extra * nrow
  70. leg_labels = row_labels
  71. elif row_labels is None:
  72. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
  73. leg_handles = []
  74. leg_labels = []
  75. else:
  76. assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
  77. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
  78. leg_handles = extra + extra * nrow
  79. leg_labels = title_label + row_labels
  80. for col in range(ncol):
  81. if col_labels is not None:
  82. leg_handles += extra
  83. leg_labels += [col_labels[col]]
  84. leg_handles += handles[col*nrow:(col+1)*nrow]
  85. leg_labels += empty * nrow
  86. # Create legend
  87. ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol+int(row_labels is not None), handletextpad=handletextpad, **kwargs)
  88. ax.legend_._remove_method = ax._remove_legend
  89. return ax.legend_
  90. def get_closest_correlation_length(traj, correlation_length):
  91. available_lengths = sorted(list(set(traj.f_get("correlation_length").f_get_range())))
  92. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
  93. if closest_length != correlation_length:
  94. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  95. correlation_length, closest_length))
  96. corr_len = closest_length
  97. return corr_len
  98. def gauss(x, *p):
  99. A, mu, sigma, B = p
  100. return A * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2)) + B
  101. def plot_tuning_curve(traj, direction_idx, plot_run_names):
  102. seed_expl = traj.f_get('seed').f_get_range()
  103. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  104. label_range = set(label_expl)
  105. rate_frame = pd.Series(index=[seed_expl, label_expl])
  106. rate_frame.index.names = ["seed", "label"]
  107. dir_bins = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  108. rate_bins = [[] for i in range(len(dir_bins)-1)]
  109. for run_name, seed, label in zip(plot_run_names, seed_expl, label_expl):
  110. ex_tunings = traj.results.runs[run_name].ex_tunings
  111. binned_idx = np.digitize(ex_tunings, dir_bins)
  112. #TODO: Avareage over directions by recentering
  113. firing_rate_array = traj.results[run_name].firing_rate_array
  114. for bin_idx, rate in zip(binned_idx, firing_rate_array[:, direction_idx]):
  115. rate_bins[bin_idx].append(rate)
  116. rate_bins_mean = [np.mean(rate_bin) for rate_bin in rate_bins]
  117. rate_frame[seed, label] = firing_rate_array
  118. # TODO: Standart deviation also for the population
  119. rate_seed_mean = rate_frame.groupby(level=[1]).mean()
  120. rate_seed_std_dev = rate_frame.groupby(level=[1]).std()
  121. style_dict = {
  122. 'no conn': ['grey', 'dashed', '', 0],
  123. 'ellipsoid': ['blue', 'solid', 'x', 10.],
  124. 'circular': ['lightblue', 'solid', 'o', 8.]
  125. }
  126. fig, ax = plt.subplots(1, 1)
  127. for label in label_range:
  128. hdi_mean = rate_seed_mean[label]
  129. hdi_std = rate_seed_std_dev[label]
  130. ex_tunings = traj.results.runs[run_name].ex_tunings
  131. col, lin, mar, mar_size = style_dict[label]
  132. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  133. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  134. hdi_mean + hdi_std, alpha=0.4, color=col)
  135. ax.set_xlabel('Correlation length')
  136. ax.set_ylabel('Head Direction Index')
  137. ax.axvline(206.9, color='k', linewidth=0.5)
  138. ax.set_ylim(0.0, 1.0)
  139. ax.set_xlim(0.0, 400.)
  140. ax.legend()
  141. fig, ax = plt.subplots(1, 1)
  142. for run_idx, run_name in enumerate(plot_run_names):
  143. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  144. ex_tunings = traj.results.runs[run_name].ex_tunings
  145. coeff_list = []
  146. ex_tunings_plt = np.array(ex_tunings)
  147. sort_ids = ex_tunings_plt.argsort()
  148. ex_tunings_plt = ex_tunings_plt[sort_ids]
  149. firing_rate_array = traj.results[run_name].firing_rate_array
  150. # firing_rate_array = traj.f_get('firing_rate_array')
  151. rates_plt = firing_rate_array[:, direction_idx]
  152. rates_plt = rates_plt[sort_ids]
  153. ax.scatter(ex_tunings_plt, rates_plt / hertz, label=label, alpha=0.3)
  154. ax.legend()
  155. ax.set_xlabel("Angles (rad)")
  156. ax.set_ylabel("f (Hz)")
  157. ax.set_title('tuning curves', fontsize=16)
  158. if save_figs:
  159. plt.savefig(FIGURE_SAVE_PATH + 'tuning_curve.png', dpi=200)
  160. def colorbar(mappable):
  161. from mpl_toolkits.axes_grid1 import make_axes_locatable
  162. import matplotlib.pyplot as plt
  163. last_axes = plt.gca()
  164. ax = mappable.axes
  165. fig = ax.figure
  166. divider = make_axes_locatable(ax)
  167. cax = divider.append_axes("right", size="5%", pad=0.05)
  168. cbar = fig.colorbar(mappable, cax=cax)
  169. plt.sca(last_axes)
  170. return cbar
  171. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names):
  172. max_val = 0
  173. for run_name in plot_run_names:
  174. fr_array = traj.results.runs[run_name].firing_rate_array
  175. f_rates = fr_array[:, direction_idx]
  176. run_max_val = np.max(f_rates)
  177. if run_max_val > max_val:
  178. # if traj.derived_parameters.runs[run_name].morphology.morph_label == 'ellipsoid':
  179. # n_id_max_rate = np.argmax(f_rates)
  180. max_val = run_max_val
  181. n_id_polar_plot = 609
  182. # Mark the neuron that is shown in Polar plot
  183. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  184. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  185. # Vertices for the plotted triangle
  186. tr_scale = 13.
  187. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  188. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  189. tr_vertices = np.array([[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  190. height = 3.5
  191. # color_bar_size = 0.05 * height + 0.05
  192. # width = 3 * height + color_bar_size
  193. width = 10.5
  194. fig, axes = plt.subplots(1, 3, figsize=(width, height))
  195. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2,0,1]]):
  196. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  197. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  198. firing_rate_array = traj.results[run_name].firing_rate_array
  199. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  200. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  201. number_of_excitatory_neurons_per_row)),
  202. vmin=0, vmax=max_val, cmap='Reds')
  203. ax.set_title(label)
  204. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  205. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  206. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  207. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  208. # fig.suptitle('spatial firing rate map', fontsize=16)
  209. colorbar(c)
  210. fig.tight_layout()
  211. if save_figs:
  212. plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png', dpi=200)
  213. return n_id_polar_plot
  214. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names):
  215. max_val = 0
  216. for run_name in plot_run_names:
  217. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  218. f_rates = fr_array[:, direction_idx]
  219. run_max_val = np.max(f_rates)
  220. if run_max_val > max_val:
  221. max_val = run_max_val
  222. n_id_polar_plot = 52
  223. # Mark the neuron that is shown in Polar plot
  224. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  225. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  226. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  227. fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.5))
  228. for ax, run_name in zip(axes, [plot_run_names[i] for i in [0,1]]):
  229. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  230. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  231. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  232. X, Y = get_correct_position_mesh(inh_positions)
  233. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  234. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  235. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  236. number_of_inhibitory_neurons_per_row)),
  237. vmin=0, vmax=max_val, cmap='Blues')
  238. ax.set_title(label)
  239. circle_r = 40.
  240. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  241. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  242. # fig.colorbar(c, ax=ax, label="f (Hz)")
  243. # fig.suptitle('spatial firing rate map', fontsize=16)
  244. colorbar(c)
  245. fig.tight_layout()
  246. if save_figs:
  247. plt.savefig(FIGURE_SAVE_PATH + 'inh_firing_rate_map.png', dpi=200)
  248. return n_id_polar_plot, max_val
  249. def plot_hdi_over_tuning(traj, plot_run_names):
  250. fig, ax = plt.subplots(1, 1)
  251. for run_idx, run_name in enumerate(plot_run_names):
  252. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  253. ex_tunings = traj.results.runs[run_name].ex_tunings
  254. ex_tunings_plt = np.array(ex_tunings)
  255. sort_ids = ex_tunings_plt.argsort()
  256. ex_tunings_plt = ex_tunings_plt[sort_ids]
  257. head_direction_indices = traj.results[run_name].head_direction_indices
  258. hdi_plt = head_direction_indices
  259. hdi_plt = hdi_plt[sort_ids]
  260. ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
  261. ax.legend()
  262. ax.set_xlabel("Angles (rad)")
  263. ax.set_ylabel("head direction index")
  264. ax.set_title('hdi over input tuning', fontsize=16)
  265. if save_figs:
  266. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png', dpi=200)
  267. def plot_axonal_clouds(traj, plot_run_names):
  268. n_ex = int(np.sqrt(traj.N_E))
  269. fig, axes = plt.subplots(1, 3, figsize=(10.5, 3.5))
  270. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2,0,1]]):
  271. traj.f_set_crun(run_name)
  272. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  273. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  274. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  275. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  276. inhibitory_axonal_cloud_array]
  277. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  278. # TODO: Why was this transposed for plotting? (now changed)
  279. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
  280. ax.set_title(label)
  281. # fig.colorbar(c, ax=ax, label="Tuning")
  282. if label != 'no conn' and axonal_clouds is not None:
  283. for i, p in enumerate(axonal_clouds):
  284. ell = p.get_ellipse()
  285. ax.add_artist(ell)
  286. # fig.suptitle('axonal cloud', fontsize=16)
  287. traj.f_restore_default()
  288. fig.tight_layout()
  289. if save_figs:
  290. plt.savefig(FIGURE_SAVE_PATH + 'axonal_clouds.png', dpi=200)
  291. def plot_orientation_maps_diff_scales(traj):
  292. n_ex = int(np.sqrt(traj.N_E))
  293. scale_run_names = []
  294. plot_scales = [0.0, 100.0, 200.0, 300.0]
  295. for scale in plot_scales:
  296. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
  297. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  298. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  299. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  300. traj.f_set_crun(run_name)
  301. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  302. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  303. # TODO: Why was this transposed for plotting? (now changed)
  304. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  305. ax.set_title('Correlation length: {}'.format(scale))
  306. fig.colorbar(c, ax=ax, label="Tuning")
  307. # fig.suptitle('axonal cloud', fontsize=16)
  308. traj.f_restore_default()
  309. if save_figs:
  310. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png', dpi=200)
  311. def plot_orientation_maps_diff_scales_with_ellipse(traj):
  312. n_ex = int(np.sqrt(traj.N_E))
  313. scale_run_names = []
  314. plot_scales = [0.0, 100.0, 200.0, 300.0, 400.0]
  315. for scale in plot_scales:
  316. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
  317. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  318. print(scale_run_names)
  319. fig, axes = plt.subplots(1, 5, figsize=(18., 4.5))
  320. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  321. traj.f_set_crun(run_name)
  322. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  323. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  324. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  325. inhibitory_axonal_cloud_array]
  326. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  327. # TODO: Why was this transposed for plotting? (now changed)
  328. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
  329. # ax.set_title('Correlation length: {}'.format(scale))
  330. # fig.colorbar(c, ax=ax, label="Tuning")
  331. ax.set_xticks([])
  332. ax.set_yticks([])
  333. p1 = axonal_clouds[44]
  334. ell = p1.get_ellipse()
  335. ell._linewidth = 5.
  336. ax.add_artist(ell)
  337. p2 = axonal_clouds[77]
  338. circ_r = 2 * np.sqrt(2500.)
  339. circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
  340. circ._linewidth = 5.
  341. ax.add_artist(circ)
  342. # fig.suptitle('axonal cloud', fontsize=16)
  343. traj.f_restore_default()
  344. fig.tight_layout()
  345. if save_figs:
  346. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales_with_ellipse.png', dpi=200)
  347. def plot_excitatory_condensed_polar_plot(traj, plot_run_names, polar_plot_id, ex_polar_plot_hdi):
  348. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  349. directions_plt = list(directions)
  350. directions_plt.append(directions[0])
  351. fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))
  352. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  353. # sorted_ids = np.argsort(head_direction_indices)
  354. # plot_n_idx = sorted_ids[-75]
  355. plot_n_idx = polar_plot_id
  356. line_styles = ['dotted', 'solid', 'dashed']
  357. colors = ['r', 'lightsalmon', 'grey']
  358. max_rate = 0.0
  359. for run_idx, run_name in enumerate(plot_run_names):
  360. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  361. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  362. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  363. run_max_rate = np.max(rate_plot)
  364. if run_max_rate > max_rate:
  365. max_rate = run_max_rate
  366. rate_plot.append(rate_plot[0])
  367. ax.plot(directions_plt, rate_plot, label='{0}, HDI: {1:.2f}'.format(label, ex_polar_plot_hdi[run_idx]), color=colors[run_idx], linestyle=line_styles[run_idx])
  368. # ax.set_title('Firing Rate')
  369. ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
  370. # TODO: Set ticks for polar
  371. ticks = [30., 60., 90.]
  372. ax.set_rticks(ticks)
  373. ax.set_rlabel_position(230)
  374. ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
  375. fancybox=True, shadow=True)
  376. plt.tight_layout()
  377. if save_figs:
  378. plt.savefig(FIGURE_SAVE_PATH + 'condensed_polar_plot.png', dpi=200)
  379. def plot_inhibitory_condensed_polar_plot(traj, plot_run_names, polar_plot_id, in_polar_plot_hdi):
  380. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  381. directions_plt = list(directions)
  382. directions_plt.append(directions[0])
  383. fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))
  384. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  385. # sorted_ids = np.argsort(head_direction_indices)
  386. # plot_n_idx = sorted_ids[-75]
  387. plot_n_idx = polar_plot_id
  388. line_styles = ['dotted', 'solid']
  389. colors = ['b', 'lightblue']
  390. for run_idx, run_name in enumerate(plot_run_names[:2]):
  391. # ax = axes[max_hdi_idx, run_idx]
  392. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  393. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  394. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  395. rate_plot.append(rate_plot[0])
  396. ax.plot(directions_plt, rate_plot, label='{0}, HDI: {1:.2f}'.format(label, in_polar_plot_hdi[run_idx]), color=colors[run_idx], linestyle=line_styles[run_idx])
  397. # ax.set_title('Inh. Firing Rate')
  398. # TODO: Set ticks for polar
  399. # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
  400. ticks = [40., 80., 120.]
  401. ax.set_rticks(ticks)
  402. ax.set_rlabel_position(230)
  403. ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
  404. fancybox=True, shadow=True)
  405. plt.tight_layout()
  406. if save_figs:
  407. plt.savefig(FIGURE_SAVE_PATH + 'condensed_inhibitory_polar_plot.png', dpi=200)
  408. def plot_hdi_over_corr_len(traj, plot_run_names):
  409. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  410. seed_expl = traj.f_get('seed').f_get_range()
  411. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  412. label_range = set(label_expl)
  413. hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  414. hdi_frame.index.names = ["corr_len", "seed", "label"]
  415. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  416. ex_tunings = traj.results.runs[run_name].ex_tunings
  417. head_direction_indices = traj.results[run_name].head_direction_indices
  418. hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  419. # TODO: Standart deviation also for the population
  420. hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
  421. hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
  422. # Ellipsoid markers
  423. rx, ry = 5., 12.
  424. # area = rx * ry * np.pi * 2.
  425. area = 1.
  426. theta = np.arange(0, 2 * np.pi + 0.01, 0.1)
  427. verts = np.column_stack([rx / area * np.cos(theta), ry / area * np.sin(theta)])
  428. style_dict = {
  429. 'no conn': ['grey', 'dashed', '', 0],
  430. 'ellipsoid': ['blue', 'solid', verts, 10.],
  431. 'circular': ['lightblue', 'solid', 'o', 8.]
  432. }
  433. # colors = ['blue', 'grey', 'lightblue']
  434. # linestyles = ['solid', 'dashed', 'solid']
  435. # markers = [verts, '', 'o']
  436. fig, ax = plt.subplots(1, 1)
  437. for label in label_range:
  438. hdi_mean = hdi_exc_n_and_seed_mean[:, label]
  439. hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
  440. corr_len_range = hdi_mean.keys().to_numpy()
  441. col, lin, mar, mar_size = style_dict[label]
  442. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  443. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  444. hdi_mean + hdi_std, alpha=0.4, color=col)
  445. ax.set_xlabel('Correlation length')
  446. ax.set_ylabel('Head Direction Index')
  447. ax.axvline(206.9, color='k', linewidth=0.5)
  448. ax.set_ylim(0.0,1.0)
  449. ax.set_xlim(0.0,400.)
  450. ax.legend()
  451. if save_figs:
  452. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len_scaled.png', dpi=200)
  453. def plot_hdi_histogram_excitatory(traj, plot_run_names, ex_polar_plot_id):
  454. labels = []
  455. hdis = []
  456. colors = ['black', 'red', 'green']
  457. for run_idx, run_name in enumerate(plot_run_names):
  458. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  459. labels.append(label)
  460. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  461. print('exc {}: {}'.format(label, head_direction_indices[ex_polar_plot_id]))
  462. hdis.append(head_direction_indices)
  463. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  464. ax.hist(hdis, color=colors, label=labels, bins=30)
  465. for hdi, color in zip(hdis, colors):
  466. mean_hdi = np.mean(hdi)
  467. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  468. ax.set_xlabel("HDI")
  469. ax.legend()
  470. fig.tight_layout()
  471. if save_figs:
  472. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
  473. def plot_hdi_violin_excitatory(traj, plot_run_names):
  474. labels = []
  475. hdis = []
  476. colors = ['black', 'red', 'green']
  477. no_conn_hdi = 0.
  478. for run_idx, run_name in enumerate(plot_run_names):
  479. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  480. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  481. if label == 'no conn':
  482. no_conn_hdi = np.mean(head_direction_indices)
  483. else:
  484. labels.append(label)
  485. hdis.append(sorted(head_direction_indices))
  486. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  487. # hdis = np.array(hdis)
  488. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  489. viol_plt['cmeans'].set_color('black')
  490. for pc in viol_plt['bodies']:
  491. pc.set_facecolor('red')
  492. pc.set_edgecolor('black')
  493. pc.set_alpha(0.7)
  494. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  495. ax.annotate('no conn', xy=(0.45,0.48), xycoords='axes fraction')
  496. ax.set_xticks(np.arange(1, len(labels) + 1))
  497. ax.set_xticklabels(labels)
  498. ax.set_ylabel('HDI')
  499. fig.tight_layout()
  500. if save_figs:
  501. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_excitatory.png', dpi=200)
  502. def plot_hdi_violin_inhibitory(traj, plot_run_names):
  503. labels = []
  504. hdis = []
  505. colors = ['black', 'red']
  506. for run_idx, run_name in enumerate(plot_run_names):
  507. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  508. if label != 'no conn':
  509. labels.append(label)
  510. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  511. hdis.append(sorted(head_direction_indices))
  512. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  513. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  514. viol_plt['cmeans'].set_color('black')
  515. for pc in viol_plt['bodies']:
  516. pc.set_facecolor('blue')
  517. pc.set_edgecolor('black')
  518. pc.set_alpha(0.7)
  519. ax.set_xticks(np.arange(1, len(labels) + 1))
  520. ax.set_xticklabels(labels)
  521. ax.set_ylabel('HDI')
  522. fig.tight_layout()
  523. if save_figs:
  524. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_inhibitory.png', dpi=200)
  525. def plot_hdi_violin_combined(traj, plot_run_names):
  526. labels = []
  527. inh_hdis = []
  528. exc_hdis = []
  529. no_conn_hdi = 0.
  530. colors = ['black', 'red']
  531. for run_idx, run_name in enumerate(plot_run_names):
  532. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  533. if label != 'no conn':
  534. labels.append(label)
  535. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  536. inh_hdis.append(sorted(inh_head_direction_indices))
  537. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  538. exc_hdis.append(sorted(exc_head_direction_indices))
  539. else:
  540. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  541. no_conn_hdi = np.mean(exc_head_direction_indices)
  542. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  543. inh_viol_plt = ax.violinplot(inh_hdis, showmeans=True, showextrema=False)
  544. # viol_plt['cmeans'].set_color('black')
  545. #
  546. # for pc in viol_plt['bodies']:
  547. # pc.set_facecolor('blue')
  548. # pc.set_edgecolor('black')
  549. # pc.set_alpha(0.7)
  550. for b in inh_viol_plt['bodies']:
  551. m = np.mean(b.get_paths()[0].vertices[:, 0])
  552. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  553. b.set_color('b')
  554. exc_viol_plt = ax.violinplot(exc_hdis, showmeans=True, showextrema=False)
  555. for b in exc_viol_plt['bodies']:
  556. m = np.mean(b.get_paths()[0].vertices[:, 0])
  557. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  558. b.set_color('r')
  559. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  560. ax.annotate('no conn', xy=(0.45, 0.48), xycoords='axes fraction')
  561. ax.set_xticks(np.arange(1, len(labels) + 1))
  562. ax.set_xticklabels(labels)
  563. ax.set_ylabel('HDI')
  564. fig.tight_layout()
  565. if save_figs:
  566. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined.svg', dpi=200)
  567. def plot_hdi_violin_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id):
  568. labels = []
  569. inh_hdis = []
  570. exc_hdis = []
  571. no_conn_hdi = 0.
  572. in_polar_plot_hdi = []
  573. ex_polar_plot_hdi = []
  574. colors = ['black', 'red']
  575. for run_idx, run_name in enumerate(plot_run_names):
  576. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  577. if label != 'no conn':
  578. labels.append(label)
  579. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  580. inh_hdis.append(sorted(inh_head_direction_indices))
  581. in_polar_plot_hdi.append(inh_head_direction_indices[in_polar_plot_id])
  582. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  583. exc_hdis.append(sorted(exc_head_direction_indices))
  584. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  585. else:
  586. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  587. no_conn_hdi = np.mean(exc_head_direction_indices)
  588. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  589. fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.5))
  590. inh_ell_viol_plt = ax.violinplot(inh_hdis[0], showmeans=True, showextrema=False)
  591. for b in inh_ell_viol_plt['bodies']:
  592. m = np.mean(b.get_paths()[0].vertices[:, 0])
  593. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  594. b.set_color('b')
  595. mean_line = inh_ell_viol_plt['cmeans']
  596. mean_line.set_color('b')
  597. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  598. exc_ell_viol_plt = ax.violinplot(exc_hdis[0], showmeans=True, showextrema=False)
  599. for b in exc_ell_viol_plt['bodies']:
  600. m = np.mean(b.get_paths()[0].vertices[:, 0])
  601. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  602. b.set_color('r')
  603. mean_line = exc_ell_viol_plt['cmeans']
  604. mean_line.set_color('r')
  605. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  606. inh_cir_viol_plt = ax.violinplot(inh_hdis[1], showmeans=True, showextrema=False)
  607. for b in inh_cir_viol_plt['bodies']:
  608. m = np.mean(b.get_paths()[0].vertices[:, 0])
  609. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  610. b.set_color('b')
  611. mean_line = inh_cir_viol_plt['cmeans']
  612. mean_line.set_color('b')
  613. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  614. exc_cir_viol_plt = ax.violinplot(exc_hdis[1], showmeans=True, showextrema=False)
  615. for b in exc_cir_viol_plt['bodies']:
  616. m = np.mean(b.get_paths()[0].vertices[:, 0])
  617. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  618. b.set_color('r')
  619. mean_line = exc_cir_viol_plt['cmeans']
  620. mean_line.set_color('r')
  621. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  622. ax.axhline(no_conn_hdi, 0.5, 1., color='black', linestyle='--')
  623. ax.axvline(1.0, color='k')
  624. ax.annotate('no conn', xy=(0.75, 0.415), xycoords='axes fraction')
  625. ax.set_xlim(0.5, 1.5)
  626. ax.set_ylim(0.0, 1.0)
  627. ax.set_xticks([0.75, 1.25])
  628. ax.set_xticklabels(['circular', 'ellipsoid'])
  629. ax.set_ylabel('HDI')
  630. fig.tight_layout()
  631. if save_figs:
  632. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
  633. return ex_polar_plot_hdi, in_polar_plot_hdi
  634. def plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id):
  635. labels = []
  636. inh_hdis = []
  637. exc_hdis = []
  638. no_conn_hdi = 0.
  639. in_polar_plot_hdi = []
  640. ex_polar_plot_hdi = []
  641. colors = ['black', 'red']
  642. for run_idx, run_name in enumerate(plot_run_names):
  643. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  644. if label != 'no conn':
  645. labels.append(label)
  646. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  647. inh_hdis.append(sorted(inh_head_direction_indices))
  648. in_polar_plot_hdi.append(inh_head_direction_indices[in_polar_plot_id])
  649. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  650. exc_hdis.append(sorted(exc_head_direction_indices))
  651. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  652. else:
  653. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  654. no_conn_hdi = np.mean(exc_head_direction_indices)
  655. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  656. # # fig = plt.figure(figsize=(3.5, 3.5))
  657. # # gs1 = gridspec.GridSpec(1, 2)
  658. #
  659. # fig = plt.figure(constrained_layout=True)
  660. # gs = gridspec.GridSpec(ncols=1, nrows=2, hspace=0.0, wspace=0.0, figure=fig)
  661. # # gs.update(wspace=0.0, hspace=0.0) # set the spacing between axes.
  662. # # f2_ax1 = fig.add_subplot(gs[0, 0])
  663. # # f2_ax2 = fig.add_subplot(gs[0, 1])
  664. # print(gs.get_subplot_params())
  665. fig, axes = plt.subplots(2, 1, figsize=(3.5, 2.5))
  666. plt.subplots_adjust(wspace=0, hspace=0)
  667. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  668. for i in range(2):
  669. # i = i + 1 # grid spec indexes from 0
  670. # ax = fig.add_subplot(gs[i])
  671. ax = axes[i]
  672. ax.hist(exc_hdis[i], color='r', edgecolor='r', alpha=0.3, bins=bins, density=True)
  673. ax.hist(inh_hdis[i], color='b', edgecolor='b', alpha=0.3, bins=bins, density=True)
  674. ax.axvline(np.mean(exc_hdis[i]), color='r')
  675. ax.axvline(np.mean(inh_hdis[i]), color='b')
  676. ax.axvline(no_conn_hdi, color='grey', linestyle='--')
  677. ax.set_ylabel(labels[i], rotation='vertical')
  678. # plt.axis('on')
  679. if i == 0:
  680. ax.set_xticklabels([])
  681. ax.set_yticks([])
  682. # ax.set_aspect('equal')
  683. # ax.set_yticks([2.5], label, rotation='vertical')
  684. # ax.axhline(no_conn_hdi, 0.5, 1., color='black', linestyle='--')
  685. # ax.axvline(1.0, color='k')
  686. # ax.annotate('no conn', xy=(0.75, 0.415), xycoords='axes fraction')
  687. # ax.set_xlim(0.5, 1.5)
  688. # ax.set_ylim(0.0, 1.0)
  689. # ax.set_xticks([0.75, 1.25])
  690. # ax.set_xticklabels(['circular', 'ellipsoid'])
  691. # ax.set_ylabel('HDI')
  692. # fig.tight_layout()
  693. if save_figs:
  694. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
  695. return ex_polar_plot_hdi, in_polar_plot_hdi
  696. def plot_hdi_histogram_inhibitory(traj, plot_run_names, in_polar_plot_id):
  697. labels = []
  698. hdis = []
  699. colors = ['black', 'red']
  700. for run_idx, run_name in enumerate(plot_run_names):
  701. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  702. if label != 'no conn':
  703. labels.append(label)
  704. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  705. print('inh {}: {}'.format(label, head_direction_indices[in_polar_plot_id]))
  706. hdis.append(head_direction_indices)
  707. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  708. ax.hist(hdis, color=colors, label=labels, bins=30)
  709. for hdi, color in zip(hdis, colors):
  710. mean_hdi = np.mean(hdi)
  711. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  712. ax.set_xlabel("HDI")
  713. ax.legend()
  714. fig.tight_layout()
  715. if save_figs:
  716. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
  717. def filter_run_names_by_par_dict(traj, par_dict):
  718. run_name_list = []
  719. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  720. traj.f_set_crun(run_name)
  721. paramters_equal = True
  722. for key, val in par_dict.items():
  723. if (traj.par[key] != val):
  724. paramters_equal = False
  725. if paramters_equal:
  726. run_name_list.append(run_name)
  727. traj.f_restore_default()
  728. return run_name_list
  729. def plot_exc_and_inh_hdi_over_corr_len(traj, plot_run_names):
  730. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  731. seed_expl = traj.f_get('seed').f_get_range()
  732. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  733. label_range = set(label_expl)
  734. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  735. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  736. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  737. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  738. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  739. ex_tunings = traj.results.runs[run_name].ex_tunings
  740. head_direction_indices = traj.results[run_name].head_direction_indices
  741. #TODO: Actual correlation lengths
  742. # actual_corr_len = get_correlation_length(ex_tunings.reshape((30,30)), 450, 30)
  743. exc_hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  744. inh_head_direction_indices = traj.results[run_name].inh_head_direction_indices
  745. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_head_direction_indices)
  746. # TODO: Standart deviation also for the population
  747. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  748. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  749. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  750. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  751. exc_style_dict = {
  752. 'no conn': ['grey', 'dashed', '', 0],
  753. 'ellipsoid': ['red', 'solid', '^', 8.],
  754. 'circular': ['lightsalmon', 'solid', '^', 8.]
  755. }
  756. inh_style_dict = {
  757. 'no conn': ['grey', 'dashed', '', 0],
  758. 'ellipsoid': ['blue', 'solid', 'o', 8.],
  759. 'circular': ['lightblue', 'solid', 'o', 8.]
  760. }
  761. fig, ax = plt.subplots(1, 1)
  762. for label in label_range:
  763. if label == 'no conn':
  764. ax.axhline(exc_hdi_n_and_seed_mean[0, label], color='grey', linestyle='--')
  765. ax.annotate('input', xy=(1.01, 0.44), xycoords='axes fraction')
  766. continue
  767. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  768. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  769. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  770. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  771. corr_len_range = exc_hdi_mean.keys().to_numpy()
  772. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  773. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  774. ax.plot(corr_len_range, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin, markersize=exc_mar_size, alpha=0.5)
  775. plt.fill_between(corr_len_range, exc_hdi_mean - exc_hdi_std,
  776. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  777. ax.plot(corr_len_range, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin, markersize=inh_mar_size, alpha=0.5)
  778. plt.fill_between(corr_len_range, inh_hdi_mean - inh_hdi_std,
  779. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  780. ax.set_xlabel('Correlation length')
  781. ax.set_ylabel('Head Direction Index')
  782. ax.axvline(206.9, color='k', linewidth=0.5)
  783. ax.set_ylim(0.0,1.0)
  784. ax.set_xlim(0.0,400.)
  785. tablelegend(ax, ncol=2, bbox_to_anchor=(1, 1),
  786. row_labels=['exc.', 'inh.'],
  787. col_labels=['ellipsoid', 'circular'],
  788. title_label='')
  789. # plt.legend()
  790. if save_figs:
  791. plt.savefig(FIGURE_SAVE_PATH + 'exc_and_inh_hdi_over_corr_len_scaled.png', dpi=200)
  792. def plot_in_degree_map(traj, plot_run_names):
  793. n_ex = int(np.sqrt(traj.N_E))
  794. max_degree = 0
  795. for run_name in plot_run_names:
  796. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  797. exc_degree = np.sum(ie_adjacency, axis=0)
  798. run_max_degree = np.max(exc_degree)
  799. if run_max_degree > max_degree:
  800. max_degree = run_max_degree
  801. fig, axes = plt.subplots(1, 2, figsize=(9., 4.5))
  802. for ax, run_name in zip(axes, plot_run_names[:-1]):
  803. traj.f_set_crun(run_name)
  804. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  805. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  806. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  807. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  808. exc_degree = np.sum(ie_adjacency, axis=0)
  809. c = ax.pcolor(X, Y, np.reshape(exc_degree, (number_of_excitatory_neurons_per_row,
  810. number_of_excitatory_neurons_per_row)), vmin=0, vmax=max_degree,
  811. cmap='hot')
  812. ax.set_title(label)
  813. fig.colorbar(c, ax=ax, label="in/out-degree")
  814. fig.suptitle('in/out-degree', fontsize=16)
  815. traj.f_restore_default()
  816. if save_figs:
  817. plt.savefig(FIGURE_SAVE_PATH + 'in_degree_map.png', dpi=200)
  818. def plot_spatial_hdi_map(traj, plot_run_names):
  819. max_val = 0
  820. for run_name in plot_run_names:
  821. hdis = traj.results.runs[run_name].head_direction_indices
  822. run_max_val = np.max(hdis)
  823. if run_max_val > max_val:
  824. max_val = run_max_val
  825. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  826. for ax, run_name in zip(axes, plot_run_names):
  827. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  828. positions = traj.results.runs[run_name].ex_positions
  829. head_direction_indices = traj.results[run_name].head_direction_indices
  830. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  831. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  832. ax.set_title(label)
  833. fig.colorbar(c, ax=ax, label="head direction index")
  834. fig.suptitle('spatial HDI map', fontsize=16)
  835. if save_figs:
  836. plt.savefig(FIGURE_SAVE_PATH + 'spatial_hdi_map.png', dpi=200)
  837. def plot_ellipse_hdi_over_circular_hdi(traj, plot_run_names):
  838. labels = []
  839. inh_hdis = []
  840. exc_hdis = []
  841. no_conn_hdi = 0.
  842. colors = ['black', 'red']
  843. for run_idx, run_name in enumerate(plot_run_names):
  844. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  845. if label != 'no conn':
  846. labels.append(label)
  847. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  848. inh_hdis.append(inh_head_direction_indices)
  849. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  850. exc_hdis.append(exc_head_direction_indices)
  851. else:
  852. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  853. no_conn_hdi = np.mean(exc_head_direction_indices)
  854. fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5))
  855. plt.subplots_adjust(wspace=0, hspace=0)
  856. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  857. ax.scatter(exc_hdis[1], exc_hdis[0], color='r', alpha=0.3)
  858. ax.scatter(inh_hdis[1], inh_hdis[0], color='b', alpha=0.3)
  859. ax.set_xlabel('{} HDI'.format(labels[1]))
  860. ax.set_ylabel('{} HDI'.format(labels[0]))
  861. ax.plot([0.0, 1.0], [0.0, 1.0], 'k')
  862. # fig.tight_layout()
  863. if save_figs:
  864. plt.savefig(FIGURE_SAVE_PATH + 'ellipse_hdi_over_circular_hdi.png', dpi=200)
  865. def plot_hdi_over_in_degree_binned(traj, plot_run_names):
  866. fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5))
  867. for run_idx, run_name in enumerate(plot_run_names):
  868. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  869. if label == 'no conn':
  870. continue
  871. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  872. exc_degree = np.sum(ie_adjacency, axis=0)
  873. head_direction_indices = traj.results[run_name].head_direction_indices
  874. n_bins = 13
  875. bins = np.arange(n_bins + 1) - 0.5
  876. exc_degree_bin_ids = np.digitize(exc_degree, bins)
  877. # print(bins)
  878. # print(exc_degree)
  879. # print(exc_degree_bin_ids)
  880. hdis_per_bin = [[] for i in range(n_bins)]
  881. for idx, hdi in zip(exc_degree_bin_ids, head_direction_indices):
  882. hdis_per_bin[idx - 1].append(hdi)
  883. mean_hdi_per_bin = [np.mean(hdis) for hdis in hdis_per_bin]
  884. ax.bar(np.arange(n_bins), mean_hdi_per_bin, label=label, alpha=0.3)
  885. ax.legend()
  886. ax.set_xlabel("In-degree of exc. neurons")
  887. ax.set_ylabel("Mean Head Direction Index")
  888. # ax.set_title('hdi over in-degree', fontsize=16)
  889. if save_figs:
  890. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_in_degree_binned.png', dpi=200)
  891. def plot_in_degree_hist(traj, plot_run_names):
  892. fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5))
  893. for run_idx, run_name in enumerate(plot_run_names):
  894. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  895. if label == 'no conn':
  896. continue
  897. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  898. exc_degree = np.sum(ie_adjacency, axis=0)
  899. head_direction_indices = traj.results[run_name].head_direction_indices
  900. n_bins = 13
  901. bins = np.arange(n_bins + 1) - 0.5
  902. in_degree_hist, bin_edges = np.histogram(exc_degree, bins)
  903. print(in_degree_hist)
  904. # ax.hist(exc_degree, bins=bins, label=label, alpha=0.3)
  905. ax.bar(np.arange(n_bins), in_degree_hist, label=label, alpha=0.3)
  906. ax.legend()
  907. ax.set_xlabel("In-degree of exc. neurons")
  908. ax.set_ylabel("# exc. neurons")
  909. # ax.set_title('hdi over in-degree', fontsize=16)
  910. if save_figs:
  911. plt.savefig(FIGURE_SAVE_PATH + 'in_degree_hist.png', dpi=200)
  912. def plot_hdi_over_in_degree(traj, plot_run_names):
  913. fig, ax = plt.subplots(1, 1)
  914. for run_idx, run_name in enumerate(plot_run_names):
  915. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  916. if label == 'no conn':
  917. continue
  918. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  919. exc_degree = np.sum(ie_adjacency, axis=0)
  920. sort_ids = exc_degree.argsort()
  921. exc_degree_plt = exc_degree[sort_ids]
  922. head_direction_indices = traj.results[run_name].head_direction_indices
  923. hdi_plt = head_direction_indices
  924. hdi_plt = hdi_plt[sort_ids]
  925. ax.scatter(exc_degree_plt, hdi_plt, label=label, alpha=0.3)
  926. ax.legend()
  927. ax.set_xlabel("In-degree of exc. neurons")
  928. ax.set_ylabel("Head Direction Index")
  929. ax.set_title('hdi over in-degree', fontsize=16)
  930. if save_figs:
  931. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_in_degree.png', dpi=200)
  932. if __name__ == "__main__":
  933. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  934. NO_LOADING = 0
  935. FULL_LOAD = 2
  936. traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=NO_LOADING)
  937. traj.v_auto_load = True
  938. save_figs = False
  939. plot_corr_len = get_closest_correlation_length(traj, 200.0)
  940. par_dict = {'seed': 1, 'correlation_length': plot_corr_len}
  941. plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
  942. print(plot_run_names)
  943. direction_idx = 6
  944. dir_indices = [0, 3, 6, 9]
  945. # plot_axonal_clouds(traj, plot_run_names)
  946. # #
  947. ex_polar_plot_id = plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names)
  948. #
  949. in_polar_plot_id, in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names)
  950. #
  951. ex_polar_plot_hdi, in_polar_plot_hdi = plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id)
  952. plot_excitatory_condensed_polar_plot(traj, plot_run_names, ex_polar_plot_id, ex_polar_plot_hdi)
  953. plot_inhibitory_condensed_polar_plot(traj, plot_run_names, in_polar_plot_id, in_polar_plot_hdi)
  954. # plot_ellipse_hdi_over_circular_hdi(traj, plot_run_names)
  955. #
  956. # plot_hdi_over_in_degree_binned(traj, plot_run_names)
  957. #
  958. # plot_hdi_over_in_degree(traj, plot_run_names)
  959. #
  960. # plot_in_degree_hist(traj, plot_run_names)
  961. #
  962. # plot_orientation_maps_diff_scales_with_ellipse(traj)
  963. #
  964. # plot_hdi_histogram_inhibitory(traj, plot_run_names, in_polar_plot_id)
  965. #
  966. # plot_hdi_histogram_excitatory(traj, plot_run_names, ex_polar_plot_id)
  967. #
  968. # plot_hdi_over_corr_len(traj, traj.f_get_run_names())
  969. # plot_hdi_violin_combined(traj, plot_run_names)
  970. #
  971. # plot_exc_and_inh_hdi_over_corr_len(traj, traj.f_get_run_names())
  972. # plot_in_degree_map(traj, plot_run_names)
  973. #
  974. # plot_spatial_hdi_map(traj, plot_run_names)
  975. # par_dict = {'correlation_length': plot_corr_len}
  976. # single_corr_len_run_names = filter_run_names_by_par_dict(traj, par_dict)
  977. # plot_tuning_curve(traj, direction_idx, single_corr_len_run_names)
  978. if not save_figs:
  979. plt.show()
  980. traj.f_restore_default()