paper_figures_spatial_head_direction_network_orientation_map_placement_jitter.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import pandas as pd
  4. # from brian2.units import *
  5. from mpl_toolkits.axes_grid1 import make_axes_locatable
  6. from pypet import Trajectory
  7. from pypet.brian2 import Brian2MonitorResult
  8. from scipy.optimize import curve_fit
  9. from matplotlib.patches import Ellipse
  10. from matplotlib.patches import Polygon
  11. import matplotlib.legend as mlegend
  12. from matplotlib.patches import Rectangle
  13. from scripts.interneuron_placement import get_position_mesh, Pickle, get_correct_position_mesh
  14. from scripts.spatial_network.placement_jitter.run_entropy_maximisation_orientation_map_placement_jitter import DATA_FOLDER, TRAJ_NAME
  15. FIGURE_SAVE_PATH = '../../../figures/placement_jitter/'
  16. def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
  17. """
  18. Place a table legend on the axes.
  19. Creates a legend where the labels are not directly placed with the artists,
  20. but are used as row and column headers, looking like this:
  21. title_label | col_labels[1] | col_labels[2] | col_labels[3]
  22. -------------------------------------------------------------
  23. row_labels[1] |
  24. row_labels[2] | <artists go there>
  25. row_labels[3] |
  26. Parameters
  27. ----------
  28. ax : `matplotlib.axes.Axes`
  29. The artist that contains the legend table, i.e. current axes instant.
  30. col_labels : list of str, optional
  31. A list of labels to be used as column headers in the legend table.
  32. `len(col_labels)` needs to match `ncol`.
  33. row_labels : list of str, optional
  34. A list of labels to be used as row headers in the legend table.
  35. `len(row_labels)` needs to match `len(handles) // ncol`.
  36. title_label : str, optional
  37. Label for the top left corner in the legend table.
  38. ncol : int
  39. Number of columns.
  40. Other Parameters
  41. ----------------
  42. Refer to `matplotlib.legend.Legend` for other parameters.
  43. """
  44. #################### same as `matplotlib.axes.Axes.legend` #####################
  45. handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
  46. if len(extra_args):
  47. raise TypeError('legend only accepts two non-keyword arguments')
  48. if col_labels is None and row_labels is None:
  49. ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
  50. ax.legend_._remove_method = ax._remove_legend
  51. return ax.legend_
  52. #################### modifications for table legend ############################
  53. else:
  54. ncol = kwargs.pop('ncol')
  55. handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
  56. title_label = [title_label]
  57. # blank rectangle handle
  58. extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
  59. # empty label
  60. empty = [""]
  61. # number of rows infered from number of handles and desired number of columns
  62. nrow = len(handles) // ncol
  63. # organise the list of handles and labels for table construction
  64. if col_labels is None:
  65. assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
  66. leg_handles = extra * nrow
  67. leg_labels = row_labels
  68. elif row_labels is None:
  69. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
  70. leg_handles = []
  71. leg_labels = []
  72. else:
  73. assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
  74. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
  75. leg_handles = extra + extra * nrow
  76. leg_labels = title_label + row_labels
  77. for col in range(ncol):
  78. if col_labels is not None:
  79. leg_handles += extra
  80. leg_labels += [col_labels[col]]
  81. leg_handles += handles[col*nrow:(col+1)*nrow]
  82. leg_labels += empty * nrow
  83. # Create legend
  84. ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol+int(row_labels is not None), handletextpad=handletextpad, **kwargs)
  85. ax.legend_._remove_method = ax._remove_legend
  86. return ax.legend_
  87. def get_closest_correlation_length(traj, correlation_length):
  88. available_lengths = sorted(list(set(traj.f_get("correlation_length").f_get_range())))
  89. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
  90. if closest_length != correlation_length:
  91. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  92. correlation_length, closest_length))
  93. corr_len = closest_length
  94. return corr_len
  95. def gauss(x, *p):
  96. A, mu, sigma, B = p
  97. return A * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2)) + B
  98. def plot_tuning_curve(traj, direction_idx, plot_run_names):
  99. seed_expl = traj.f_get('seed').f_get_range()
  100. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  101. label_range = set(label_expl)
  102. rate_frame = pd.Series(index=[seed_expl, label_expl])
  103. rate_frame.index.names = ["seed", "label"]
  104. dir_bins = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  105. rate_bins = [[] for i in range(len(dir_bins)-1)]
  106. for run_name, seed, label in zip(plot_run_names, seed_expl, label_expl):
  107. ex_tunings = traj.results.runs[run_name].ex_tunings
  108. binned_idx = np.digitize(ex_tunings, dir_bins)
  109. #TODO: Avareage over directions by recentering
  110. firing_rate_array = traj.results[run_name].firing_rate_array
  111. for bin_idx, rate in zip(binned_idx, firing_rate_array[:, direction_idx]):
  112. rate_bins[bin_idx].append(rate)
  113. rate_bins_mean = [np.mean(rate_bin) for rate_bin in rate_bins]
  114. rate_frame[seed, label] = firing_rate_array
  115. # TODO: Standart deviation also for the population
  116. rate_seed_mean = rate_frame.groupby(level=[1]).mean()
  117. rate_seed_std_dev = rate_frame.groupby(level=[1]).std()
  118. style_dict = {
  119. 'no conn': ['grey', 'dashed', '', 0],
  120. 'ellipsoid': ['blue', 'solid', 'x', 10.],
  121. 'circular': ['lightblue', 'solid', 'o', 8.]
  122. }
  123. fig, ax = plt.subplots(1, 1)
  124. for label in label_range:
  125. hdi_mean = rate_seed_mean[label]
  126. hdi_std = rate_seed_std_dev[label]
  127. ex_tunings = traj.results.runs[run_name].ex_tunings
  128. col, lin, mar, mar_size = style_dict[label]
  129. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  130. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  131. hdi_mean + hdi_std, alpha=0.4, color=col)
  132. ax.set_xlabel('Correlation length')
  133. ax.set_ylabel('Head Direction Index')
  134. ax.axvline(206.9, color='k', linewidth=0.5)
  135. ax.set_ylim(0.0, 1.0)
  136. ax.set_xlim(0.0, 400.)
  137. ax.legend()
  138. fig, ax = plt.subplots(1, 1)
  139. for run_idx, run_name in enumerate(plot_run_names):
  140. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  141. ex_tunings = traj.results.runs[run_name].ex_tunings
  142. coeff_list = []
  143. ex_tunings_plt = np.array(ex_tunings)
  144. sort_ids = ex_tunings_plt.argsort()
  145. ex_tunings_plt = ex_tunings_plt[sort_ids]
  146. firing_rate_array = traj.results[run_name].firing_rate_array
  147. # firing_rate_array = traj.f_get('firing_rate_array')
  148. rates_plt = firing_rate_array[:, direction_idx]
  149. rates_plt = rates_plt[sort_ids]
  150. ax.scatter(ex_tunings_plt, rates_plt / hertz, label=label, alpha=0.3)
  151. ax.legend()
  152. ax.set_xlabel("Angles (rad)")
  153. ax.set_ylabel("f (Hz)")
  154. ax.set_title('tuning curves', fontsize=16)
  155. if save_figs:
  156. plt.savefig(FIGURE_SAVE_PATH + 'tuning_curve.png', dpi=200)
  157. def colorbar(mappable):
  158. from mpl_toolkits.axes_grid1 import make_axes_locatable
  159. import matplotlib.pyplot as plt
  160. last_axes = plt.gca()
  161. ax = mappable.axes
  162. fig = ax.figure
  163. divider = make_axes_locatable(ax)
  164. cax = divider.append_axes("right", size="5%", pad=0.05)
  165. cbar = fig.colorbar(mappable, cax=cax)
  166. plt.sca(last_axes)
  167. return cbar
  168. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names):
  169. max_val = 0
  170. for run_name in plot_run_names:
  171. fr_array = traj.results.runs[run_name].firing_rate_array
  172. f_rates = fr_array[:, direction_idx]
  173. run_max_val = np.max(f_rates)
  174. if run_max_val > max_val:
  175. # if traj.derived_parameters.runs[run_name].morphology.morph_label == 'ellipsoid':
  176. # n_id_max_rate = np.argmax(f_rates)
  177. max_val = run_max_val
  178. n_id_polar_plot = 609
  179. # Mark the neuron that is shown in Polar plot
  180. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  181. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  182. # Vertices for the plotted triangle
  183. tr_scale = 13.
  184. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  185. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  186. tr_vertices = np.array([[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  187. height = 4.5
  188. # color_bar_size = 0.05 * height + 0.05
  189. # width = 3 * height + color_bar_size
  190. width = 13.5
  191. fig, axes = plt.subplots(1, 3, figsize=(width, height))
  192. for ax, run_name in zip(axes, plot_run_names[::-1]):
  193. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  194. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  195. firing_rate_array = traj.results[run_name].firing_rate_array
  196. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  197. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  198. number_of_excitatory_neurons_per_row)),
  199. vmin=0, vmax=max_val, cmap='Reds')
  200. ax.set_title(label)
  201. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  202. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  203. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  204. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  205. # fig.suptitle('spatial firing rate map', fontsize=16)
  206. colorbar(c)
  207. fig.tight_layout()
  208. if save_figs:
  209. plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png', dpi=200)
  210. return n_id_polar_plot
  211. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names):
  212. max_val = 0
  213. for run_name in plot_run_names:
  214. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  215. f_rates = fr_array[:, direction_idx]
  216. run_max_val = np.max(f_rates)
  217. if run_max_val > max_val:
  218. max_val = run_max_val
  219. n_id_polar_plot = 52
  220. # Mark the neuron that is shown in Polar plot
  221. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  222. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  223. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  224. plot_run_names_sorted = [plot_run_names[1], plot_run_names[0]]
  225. fig, axes = plt.subplots(1, 2, figsize=(9.0, 4.5))
  226. for ax, run_name in zip(axes, plot_run_names_sorted):
  227. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  228. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  229. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  230. X, Y = get_correct_position_mesh(inh_positions)
  231. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  232. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  233. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  234. number_of_inhibitory_neurons_per_row)),
  235. vmin=0, vmax=max_val, cmap='Blues')
  236. ax.set_title(label)
  237. circle_r = 40.
  238. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  239. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  240. # fig.colorbar(c, ax=ax, label="f (Hz)")
  241. # fig.suptitle('spatial firing rate map', fontsize=16)
  242. colorbar(c)
  243. fig.tight_layout()
  244. if save_figs:
  245. plt.savefig(FIGURE_SAVE_PATH + 'inh_firing_rate_map.png', dpi=200)
  246. return n_id_polar_plot, max_val
  247. def plot_hdi_over_tuning(traj, plot_run_names):
  248. fig, ax = plt.subplots(1, 1)
  249. for run_idx, run_name in enumerate(plot_run_names):
  250. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  251. ex_tunings = traj.results.runs[run_name].ex_tunings
  252. ex_tunings_plt = np.array(ex_tunings)
  253. sort_ids = ex_tunings_plt.argsort()
  254. ex_tunings_plt = ex_tunings_plt[sort_ids]
  255. head_direction_indices = traj.results[run_name].head_direction_indices
  256. hdi_plt = head_direction_indices
  257. hdi_plt = hdi_plt[sort_ids]
  258. ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
  259. ax.legend()
  260. ax.set_xlabel("Angles (rad)")
  261. ax.set_ylabel("head direction index")
  262. ax.set_title('hdi over input tuning', fontsize=16)
  263. if save_figs:
  264. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png', dpi=200)
  265. def plot_axonal_clouds(traj, plot_run_names):
  266. n_ex = int(np.sqrt(traj.N_E))
  267. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  268. for ax, run_name in zip(axes, plot_run_names[::-1]):
  269. traj.f_set_crun(run_name)
  270. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  271. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  272. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  273. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  274. inhibitory_axonal_cloud_array]
  275. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  276. # TODO: Why was this transposed for plotting? (now changed)
  277. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
  278. ax.set_title(label)
  279. # fig.colorbar(c, ax=ax, label="Tuning")
  280. if label != 'no conn' and axonal_clouds is not None:
  281. for i, p in enumerate(axonal_clouds):
  282. ell = p.get_ellipse()
  283. ax.add_artist(ell)
  284. # fig.suptitle('axonal cloud', fontsize=16)
  285. traj.f_restore_default()
  286. fig.tight_layout()
  287. if save_figs:
  288. plt.savefig(FIGURE_SAVE_PATH + 'axonal_clouds.png', dpi=200)
  289. def plot_orientation_maps_diff_scales(traj):
  290. n_ex = int(np.sqrt(traj.N_E))
  291. scale_run_names = []
  292. plot_scales = [0.0, 100.0, 200.0, 300.0]
  293. for scale in plot_scales:
  294. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
  295. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  296. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  297. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  298. traj.f_set_crun(run_name)
  299. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  300. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  301. # TODO: Why was this transposed for plotting? (now changed)
  302. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  303. ax.set_title('Correlation length: {}'.format(scale))
  304. fig.colorbar(c, ax=ax, label="Tuning")
  305. # fig.suptitle('axonal cloud', fontsize=16)
  306. traj.f_restore_default()
  307. if save_figs:
  308. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png', dpi=200)
  309. def plot_orientation_maps_diff_scales_with_ellipse(traj):
  310. n_ex = int(np.sqrt(traj.N_E))
  311. scale_run_names = []
  312. plot_scales = [0.0, 100.0, 200.0, 300.0, 400.0]
  313. for scale in plot_scales:
  314. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
  315. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  316. print(scale_run_names)
  317. fig, axes = plt.subplots(1, 5, figsize=(18., 4.5))
  318. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  319. traj.f_set_crun(run_name)
  320. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  321. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  322. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  323. inhibitory_axonal_cloud_array]
  324. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  325. # TODO: Why was this transposed for plotting? (now changed)
  326. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
  327. # ax.set_title('Correlation length: {}'.format(scale))
  328. # fig.colorbar(c, ax=ax, label="Tuning")
  329. ax.set_xticks([])
  330. ax.set_yticks([])
  331. p1 = axonal_clouds[44]
  332. ell = p1.get_ellipse()
  333. ell._linewidth = 5.
  334. ax.add_artist(ell)
  335. p2 = axonal_clouds[77]
  336. circ_r = 2 * np.sqrt(2500.)
  337. circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
  338. circ._linewidth = 5.
  339. ax.add_artist(circ)
  340. # fig.suptitle('axonal cloud', fontsize=16)
  341. traj.f_restore_default()
  342. fig.tight_layout()
  343. if save_figs:
  344. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales_with_ellipse.png', dpi=200)
  345. def plot_excitatory_condensed_polar_plot(traj, plot_run_names, polar_plot_id):
  346. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  347. directions_plt = list(directions)
  348. directions_plt.append(directions[0])
  349. fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))
  350. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  351. # sorted_ids = np.argsort(head_direction_indices)
  352. # plot_n_idx = sorted_ids[-75]
  353. plot_n_idx = polar_plot_id
  354. line_styles = ['dotted', 'solid', 'dashed']
  355. colors = ['r', 'lightsalmon', 'grey']
  356. max_rate = 0.0
  357. for run_idx, run_name in enumerate(plot_run_names):
  358. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  359. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  360. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  361. run_max_rate = np.max(rate_plot)
  362. if run_max_rate > max_rate:
  363. max_rate = run_max_rate
  364. rate_plot.append(rate_plot[0])
  365. ax.plot(directions_plt, rate_plot, label=label, color=colors[run_idx], linestyle=line_styles[run_idx])
  366. # ax.set_title('Firing Rate')
  367. ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
  368. # TODO: Set ticks for polar
  369. ticks = [30., 60., 90.]
  370. ax.set_rticks(ticks)
  371. ax.set_rlabel_position(230)
  372. ax.legend(loc='upper center', bbox_to_anchor=(0.2, 1.05),
  373. fancybox=True, shadow=True)
  374. plt.tight_layout()
  375. if save_figs:
  376. plt.savefig(FIGURE_SAVE_PATH + 'condensed_polar_plot.png', dpi=200)
  377. def plot_inhibitory_condensed_polar_plot(traj, plot_run_names, polar_plot_id, max_rate):
  378. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  379. directions_plt = list(directions)
  380. directions_plt.append(directions[0])
  381. fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))
  382. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  383. # sorted_ids = np.argsort(head_direction_indices)
  384. # plot_n_idx = sorted_ids[-75]
  385. plot_n_idx = polar_plot_id
  386. line_styles = ['dotted', 'solid']
  387. colors = ['b', 'lightblue']
  388. for run_idx, run_name in enumerate(plot_run_names[:2]):
  389. # ax = axes[max_hdi_idx, run_idx]
  390. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  391. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  392. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  393. rate_plot.append(rate_plot[0])
  394. ax.plot(directions_plt, rate_plot, label=label, color=colors[run_idx], linestyle=line_styles[run_idx])
  395. # ax.set_title('Inh. Firing Rate')
  396. # TODO: Set ticks for polar
  397. # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
  398. ticks = [40., 80., 120.]
  399. ax.set_rticks(ticks)
  400. ax.set_rlabel_position(230)
  401. ax.legend(loc='upper center', bbox_to_anchor=(0.2, 1.05),
  402. fancybox=True, shadow=True)
  403. plt.tight_layout()
  404. if save_figs:
  405. plt.savefig(FIGURE_SAVE_PATH + 'condensed_inhibitory_polar_plot.png', dpi=200)
  406. def plot_hdi_over_corr_len(traj, plot_run_names):
  407. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  408. seed_expl = traj.f_get('seed').f_get_range()
  409. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  410. label_range = set(label_expl)
  411. hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  412. hdi_frame.index.names = ["corr_len", "seed", "label"]
  413. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  414. ex_tunings = traj.results.runs[run_name].ex_tunings
  415. head_direction_indices = traj.results[run_name].head_direction_indices
  416. hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  417. # TODO: Standart deviation also for the population
  418. hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
  419. hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
  420. # Ellipsoid markers
  421. rx, ry = 5., 12.
  422. # area = rx * ry * np.pi * 2.
  423. area = 1.
  424. theta = np.arange(0, 2 * np.pi + 0.01, 0.1)
  425. verts = np.column_stack([rx / area * np.cos(theta), ry / area * np.sin(theta)])
  426. style_dict = {
  427. 'no conn': ['grey', 'dashed', '', 0],
  428. 'ellipsoid': ['blue', 'solid', verts, 10.],
  429. 'circular': ['lightblue', 'solid', 'o', 8.]
  430. }
  431. # colors = ['blue', 'grey', 'lightblue']
  432. # linestyles = ['solid', 'dashed', 'solid']
  433. # markers = [verts, '', 'o']
  434. fig, ax = plt.subplots(1, 1)
  435. for label in label_range:
  436. hdi_mean = hdi_exc_n_and_seed_mean[:, label]
  437. hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
  438. corr_len_range = hdi_mean.keys().to_numpy()
  439. col, lin, mar, mar_size = style_dict[label]
  440. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  441. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  442. hdi_mean + hdi_std, alpha=0.4, color=col)
  443. ax.set_xlabel('Correlation length')
  444. ax.set_ylabel('Head Direction Index')
  445. ax.axvline(206.9, color='k', linewidth=0.5)
  446. ax.set_ylim(0.0,1.0)
  447. ax.set_xlim(0.0,400.)
  448. ax.legend()
  449. if save_figs:
  450. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len_scaled.png', dpi=200)
  451. def plot_hdi_histogram_excitatory(traj, plot_run_names):
  452. labels = []
  453. hdis = []
  454. colors = ['black', 'red', 'green']
  455. for run_idx, run_name in enumerate(plot_run_names):
  456. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  457. labels.append(label)
  458. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  459. hdis.append(head_direction_indices)
  460. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  461. ax.hist(hdis, color=colors, label=labels, bins=30)
  462. for hdi, color in zip(hdis, colors):
  463. mean_hdi = np.mean(hdi)
  464. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  465. ax.set_xlabel("HDI")
  466. ax.legend()
  467. fig.tight_layout()
  468. if save_figs:
  469. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
  470. def plot_hdi_violin_excitatory(traj, plot_run_names):
  471. labels = []
  472. hdis = []
  473. colors = ['black', 'red', 'green']
  474. no_conn_hdi = 0.
  475. for run_idx, run_name in enumerate(plot_run_names):
  476. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  477. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  478. if label == 'no conn':
  479. no_conn_hdi = np.mean(head_direction_indices)
  480. else:
  481. labels.append(label)
  482. hdis.append(sorted(head_direction_indices))
  483. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  484. # hdis = np.array(hdis)
  485. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  486. viol_plt['cmeans'].set_color('black')
  487. for pc in viol_plt['bodies']:
  488. pc.set_facecolor('red')
  489. pc.set_edgecolor('black')
  490. pc.set_alpha(0.7)
  491. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  492. ax.annotate('no conn', xy=(0.45,0.48), xycoords='axes fraction')
  493. ax.set_xticks(np.arange(1, len(labels) + 1))
  494. ax.set_xticklabels(labels)
  495. ax.set_ylabel('HDI')
  496. fig.tight_layout()
  497. if save_figs:
  498. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_excitatory.png', dpi=200)
  499. def plot_hdi_violin_inhibitory(traj, plot_run_names):
  500. labels = []
  501. hdis = []
  502. colors = ['black', 'red']
  503. for run_idx, run_name in enumerate(plot_run_names):
  504. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  505. if label != 'no conn':
  506. labels.append(label)
  507. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  508. hdis.append(sorted(head_direction_indices))
  509. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  510. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  511. viol_plt['cmeans'].set_color('black')
  512. for pc in viol_plt['bodies']:
  513. pc.set_facecolor('blue')
  514. pc.set_edgecolor('black')
  515. pc.set_alpha(0.7)
  516. ax.set_xticks(np.arange(1, len(labels) + 1))
  517. ax.set_xticklabels(labels)
  518. ax.set_ylabel('HDI')
  519. fig.tight_layout()
  520. if save_figs:
  521. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_inhibitory.png', dpi=200)
  def plot_hdi_violin_combined(traj, plot_run_names):
      """Plot split violins of inhibitory and excitatory HDI for each connected run.

      Every run in ``plot_run_names`` whose morphology label is not
      ``'no conn'`` gets one violin position.  At that position the inhibitory
      head direction indices are drawn in blue, keeping only the part of the
      violin at or right of its centre, and the excitatory ones in red, keeping
      only the part at or left of it.  The mean excitatory HDI of a
      ``'no conn'`` run is drawn as a dashed horizontal line.

      The figure is saved below ``FIGURE_SAVE_PATH`` when the module-level
      ``save_figs`` flag is truthy.

      :param traj: trajectory exposing ``derived_parameters.runs`` and
          ``results.runs`` per run name.
      :param plot_run_names: names of the runs to plot.
      """
      tick_labels = []
      inhibitory_sets = []
      excitatory_sets = []
      baseline_hdi = 0.

      for name in plot_run_names:
          morph_label = traj.derived_parameters.runs[name].morphology.morph_label
          if morph_label == 'no conn':
              # The unconnected run only contributes the reference line.
              baseline_hdi = np.mean(traj.results.runs[name].head_direction_indices)
              continue
          tick_labels.append(morph_label)
          inhibitory_sets.append(sorted(traj.results.runs[name].inh_head_direction_indices))
          excitatory_sets.append(sorted(traj.results.runs[name].head_direction_indices))

      fig, ax = plt.subplots(1, 1, figsize=(6, 3))

      def _draw_half_violins(datasets, keep_right, colour):
          # Draw full violins, then collapse one side of every body onto the
          # body's mean x-coordinate so that only half of it stays visible.
          violins = ax.violinplot(datasets, showmeans=True, showextrema=False)
          for body in violins['bodies']:
              outline = body.get_paths()[0].vertices
              centre = np.mean(outline[:, 0])
              if keep_right:
                  outline[:, 0] = np.clip(outline[:, 0], centre, np.inf)
              else:
                  outline[:, 0] = np.clip(outline[:, 0], -np.inf, centre)
              body.set_color(colour)

      _draw_half_violins(inhibitory_sets, True, 'b')
      _draw_half_violins(excitatory_sets, False, 'r')

      ax.axhline(baseline_hdi, color='black', linestyle='--')
      ax.annotate('no conn', xy=(0.45, 0.48), xycoords='axes fraction')

      # violinplot places the datasets at x = 1, 2, ..., so label those ticks.
      ax.set_xticks(np.arange(1, len(tick_labels) + 1))
      ax.set_xticklabels(tick_labels)
      ax.set_ylabel('HDI')

      fig.tight_layout()
      if save_figs:
          plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined.svg', dpi=200)
  def plot_hdi_violin_combined_and_overlayed(traj, plot_run_names):
      """Overlay inhibitory and excitatory HDI half-violins of two runs in one violin.

      The first two runs in ``plot_run_names`` whose morphology label is not
      ``'no conn'`` are used.  For the first of them the inhibitory (blue) and
      excitatory (red) distributions are drawn on the right half of a single
      violin at x = 1, for the second one on the left half.  The x tick labels
      are fixed to ``'circular'`` (left) and ``'ellipsoid'`` (right).
      NOTE(review): this presumes the ellipsoid run precedes the circular run
      in ``plot_run_names`` -- confirm against the caller.

      The mean excitatory HDI of a ``'no conn'`` run is drawn as a dashed
      line across the right half of the axes.  The figure is saved below
      ``FIGURE_SAVE_PATH`` when the module-level ``save_figs`` flag is truthy.

      :param traj: trajectory exposing ``derived_parameters.runs`` and
          ``results.runs`` per run name.
      :param plot_run_names: names of the runs to plot.
      """
      inhibitory_sets = []
      excitatory_sets = []
      baseline_hdi = 0.

      for name in plot_run_names:
          morph_label = traj.derived_parameters.runs[name].morphology.morph_label
          if morph_label == 'no conn':
              baseline_hdi = np.mean(traj.results.runs[name].head_direction_indices)
              continue
          inhibitory_sets.append(sorted(traj.results.runs[name].inh_head_direction_indices))
          excitatory_sets.append(sorted(traj.results.runs[name].head_direction_indices))

      fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.5))

      def _overlay_half_violin(values, keep_right, colour):
          # Draw one violin and hide one of its sides by clamping the x
          # coordinates of both the body and the mean marker at the body centre.
          violin = ax.violinplot(values, showmeans=True, showextrema=False)

          def _clamp(xs, centre):
              if keep_right:
                  return np.clip(xs, centre, np.inf)
              return np.clip(xs, -np.inf, centre)

          for body in violin['bodies']:
              outline = body.get_paths()[0].vertices
              centre = np.mean(outline[:, 0])
              outline[:, 0] = _clamp(outline[:, 0], centre)
              body.set_color(colour)

          mean_marker = violin['cmeans']
          mean_marker.set_color(colour)
          marker_xy = mean_marker.get_paths()[0].vertices
          # The mean marker is clamped at the centre of the last body drawn.
          marker_xy[:, 0] = _clamp(marker_xy[:, 0], centre)

      # Run 0 fills the right half, run 1 the left half; within each half the
      # inhibitory violin is drawn first and the excitatory one on top.
      for run_pos, keep_right in ((0, True), (1, False)):
          _overlay_half_violin(inhibitory_sets[run_pos], keep_right, 'b')
          _overlay_half_violin(excitatory_sets[run_pos], keep_right, 'r')

      ax.axhline(baseline_hdi, xmin=0.5, xmax=1., color='black', linestyle='--')
      ax.axvline(1.0, color='k')
      ax.annotate('no conn', xy=(0.75, 0.415), xycoords='axes fraction')

      ax.set_xlim(0.5, 1.5)
      ax.set_ylim(0.0, 1.0)
      ax.set_xticks([0.75, 1.25])
      ax.set_xticklabels(['circular', 'ellipsoid'])
      ax.set_ylabel('HDI')

      fig.tight_layout()
      if save_figs:
          plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
  def plot_hdi_histogram_inhibitory(traj, plot_run_names):
      """Plot a histogram of the inhibitory head direction indices per connected run.

      Runs labelled ``'no conn'`` are left out.  The remaining runs are drawn
      as one grouped 30-bin histogram, and the mean HDI of each run is marked
      with a dashed vertical line in that run's colour.  Two colours (black,
      red) are available.  The figure is saved below ``FIGURE_SAVE_PATH`` when
      the module-level ``save_figs`` flag is truthy.

      :param traj: trajectory exposing ``derived_parameters.runs`` and
          ``results.runs`` per run name.
      :param plot_run_names: names of the runs to plot.
      """
      palette = ['black', 'red']
      legend_names = []
      hdi_sets = []

      for name in plot_run_names:
          morph_label = traj.derived_parameters.runs[name].morphology.morph_label
          if morph_label == 'no conn':
              continue
          legend_names.append(morph_label)
          hdi_sets.append(traj.results.runs[name].inh_head_direction_indices)

      fig, ax = plt.subplots(1, 1, figsize=(6, 3))
      ax.hist(hdi_sets, color=palette, label=legend_names, bins=30)

      # One dashed line per run at its mean HDI, spanning the full axes height.
      for run_hdis, colour in zip(hdi_sets, palette):
          ax.axvline(np.mean(run_hdis), 0, 1, color=colour, linestyle='--')

      ax.set_xlabel("HDI")
      ax.legend()
      fig.tight_layout()

      if save_figs:
          plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
  def filter_run_names_by_par_dict(traj, par_dict):
      """Return the names of all runs whose parameters match ``par_dict``.

      Each run of ``traj`` is made the current run in turn and is kept when,
      for every ``key, val`` pair in ``par_dict``, ``traj.par[key]`` does not
      compare unequal to ``val``.  An empty ``par_dict`` therefore selects
      every run.

      The trajectory is always reset with ``f_restore_default()`` before the
      function returns, including when a parameter lookup raises, so it is
      never left pointing at a single run.

      :param traj: trajectory providing ``f_get_run_names``, ``f_set_crun``,
          ``f_restore_default`` and ``par``.
      :param par_dict: mapping from parameter name to the required value.
      :return: list of matching run names, in the order given by
          ``traj.f_get_run_names()``.
      """
      matching_run_names = []
      try:
          for run_name in traj.f_get_run_names():
              traj.f_set_crun(run_name)
              # Stop checking a run at the first parameter that differs.
              if not any(traj.par[key] != val for key, val in par_dict.items()):
                  matching_run_names.append(run_name)
      finally:
          traj.f_restore_default()
      return matching_run_names
  def plot_exc_and_inh_hdi_over_corr_len(traj, plot_run_names):
      """Plot mean excitatory and inhibitory HDI as a function of correlation length.

      For every run in ``plot_run_names`` the mean of its excitatory and of its
      inhibitory head direction indices is stored under the run's explored
      (correlation length, seed, morphology label).  These per-run means are
      then averaged per (correlation length, label); the standard deviation of
      the same groups is drawn as a shaded band around each curve.  The
      ``'no conn'`` label is drawn as a horizontal reference line annotated
      ``'input'`` instead of a curve.

      The figure is saved below ``FIGURE_SAVE_PATH`` when the module-level
      ``save_figs`` flag is truthy.

      :param traj: trajectory exposing the explored ``correlation_length`` and
          ``seed`` parameters plus ``derived_parameters.runs`` and
          ``results.runs`` per run name.
      :param plot_run_names: names of the runs to include; may be any subset
          of ``traj.f_get_run_names()``.
      """
      all_run_names = traj.f_get_run_names()
      corr_len_expl = traj.f_get('correlation_length').f_get_range()
      seed_expl = traj.f_get('seed').f_get_range()
      label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label
                    for run_name in all_run_names]
      label_range = set(label_expl)

      # Build the three-level index explicitly and force a float dtype, so the
      # result does not depend on how pandas interprets a list of sequences or
      # on its default dtype for a Series created without data.
      run_index = pd.MultiIndex.from_arrays(
          [list(corr_len_expl), list(seed_expl), label_expl],
          names=["corr_len", "seed", "label"])
      exc_hdi_frame = pd.Series(index=run_index, dtype=float)
      inh_hdi_frame = pd.Series(index=run_index, dtype=float)

      # Look the explored values up by run name, so that passing only a subset
      # of the runs cannot pair a run with another run's parameters.
      run_coords = dict(zip(all_run_names, zip(corr_len_expl, seed_expl, label_expl)))
      for run_name in plot_run_names:
          corr_len, seed, label = run_coords[run_name]
          run_results = traj.results.runs[run_name]
          # TODO: use the actual (measured) correlation length of the input map
          # instead of the explored value.
          exc_hdi_frame.loc[(corr_len, seed, label)] = np.mean(run_results.head_direction_indices)
          inh_hdi_frame.loc[(corr_len, seed, label)] = np.mean(run_results.inh_head_direction_indices)

      # TODO: standard deviation also for the population.
      # Runs that were not requested stay NaN and are ignored by mean/std.
      group_levels = ["corr_len", "label"]
      exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=group_levels).mean()
      exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=group_levels).std()
      inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=group_levels).mean()
      inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=group_levels).std()

      # Per label: colour, line style, marker, marker size.
      exc_style_dict = {
          'no conn': ['grey', 'dashed', '', 0],
          'ellipsoid': ['red', 'solid', '^', 8.],
          'circular': ['lightsalmon', 'solid', '^', 8.]
      }
      inh_style_dict = {
          'no conn': ['grey', 'dashed', '', 0],
          'ellipsoid': ['blue', 'solid', 'o', 8.],
          'circular': ['lightblue', 'solid', 'o', 8.]
      }

      # Plot in a fixed order rather than in set order, which can change from
      # one interpreter run to the next.
      # NOTE(review): the order follows the legend column headers on the
      # assumption that tablelegend lays handles out in plotting order - confirm.
      legend_columns = ['ellipsoid', 'circular']
      ordered_labels = [lab for lab in legend_columns if lab in label_range]
      ordered_labels += sorted(label_range.difference(legend_columns))

      fig, ax = plt.subplots(1, 1)
      for label in ordered_labels:
          if label == 'no conn':
              # Reference level is read at correlation length 0.
              # NOTE(review): requires a 'no conn' run with that value - confirm.
              ax.axhline(exc_hdi_n_and_seed_mean[0, label], color='grey', linestyle='--')
              ax.annotate('input', xy=(1.01, 0.44), xycoords='axes fraction')
              continue

          exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
          exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
          inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
          inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
          corr_len_range = exc_hdi_mean.keys().to_numpy()

          exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
          inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]

          ax.plot(corr_len_range, exc_hdi_mean, label='exc., ' + label, marker=exc_mar,
                  color=exc_col, linestyle=exc_lin, markersize=exc_mar_size, alpha=0.5)
          ax.fill_between(corr_len_range, exc_hdi_mean - exc_hdi_std,
                          exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
          ax.plot(corr_len_range, inh_hdi_mean, label='inh., ' + label, marker=inh_mar,
                  color=inh_col, linestyle=inh_lin, markersize=inh_mar_size, alpha=0.5)
          ax.fill_between(corr_len_range, inh_hdi_mean - inh_hdi_std,
                          inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)

      ax.set_xlabel('Correlation length')
      ax.set_ylabel('Head Direction Index')
      ax.axvline(206.9, color='k', linewidth=0.5)
      ax.set_ylim(0.0, 1.0)
      ax.set_xlim(0.0, 400.)

      tablelegend(ax, ncol=2, bbox_to_anchor=(1, 1),
                  row_labels=['exc.', 'inh.'],
                  col_labels=legend_columns,
                  title_label='')

      if save_figs:
          plt.savefig(FIGURE_SAVE_PATH + 'exc_and_inh_hdi_over_corr_len_scaled.png', dpi=200)
  def plot_inhibitory_condensed_polar_plot_with_input(traj, plot_run_names, polar_plot_id, max_rate):
      """Draw the directional tuning of one inhibitory neuron for several runs in one polar plot.

      For each run the norm of every tuning vector of neuron ``polar_plot_id``
      is plotted against ``traj.input.number_of_directions`` equally spaced
      directions in ``[-pi, pi)``; the first point is repeated at the end so
      the curve closes.  Line styles and colours are provided for up to three
      runs.  The figure is saved below ``FIGURE_SAVE_PATH`` when the
      module-level ``save_figs`` flag is truthy.

      :param traj: trajectory exposing ``input.number_of_directions``,
          ``derived_parameters.runs`` and ``results.runs``.
      :param plot_run_names: names of the runs to overlay.
      :param polar_plot_id: index of the inhibitory neuron to plot.
      :param max_rate: currently unused; the radial ticks are fixed.
      """
      angles = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
      # Repeat the first angle so the polar curve is closed.
      closed_angles = list(angles) + [angles[0]]

      fig, ax = plt.subplots(1, 1, figsize=(3.5, 3.5), subplot_kw=dict(projection='polar'))

      neuron_idx = polar_plot_id
      run_line_styles = ['dotted', 'solid', 'dashed']
      run_colours = ['b', 'lightblue', 'k']

      for position, name in enumerate(plot_run_names):
          morph_label = traj.derived_parameters.runs[name].morphology.morph_label
          inh_vectors = traj.results.runs[name].inh_tuning_vectors
          rates = [np.linalg.norm(vector) for vector in inh_vectors[neuron_idx]]
          rates.append(rates[0])
          ax.plot(closed_angles, rates, label=morph_label,
                  color=run_colours[position], linestyle=run_line_styles[position])

      # TODO: derive the radial ticks from max_rate instead of fixing them.
      radial_ticks = [40., 80., 120.]
      ax.set_rticks(radial_ticks)
      ax.set_rlabel_position(230)

      ax.legend(loc='upper center', bbox_to_anchor=(0.2, 1.05),
                fancybox=True, shadow=True)
      plt.tight_layout()

      if save_figs:
          plt.savefig(FIGURE_SAVE_PATH + 'condensed_inhibitory_polar_plot.png', dpi=200)
  if __name__ == "__main__":
      # Open the stored trajectory.  Parameters are loaded completely, results
      # are not loaded up front.
      traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
      NO_LOADING = 0
      FULL_LOAD = 2
      hdf5_file = DATA_FOLDER + TRAJ_NAME + ".hdf5"
      traj.f_load(filename=hdf5_file, load_parameters=FULL_LOAD, load_results=NO_LOADING)
      # NOTE(review): presumably makes results load on first access - confirm
      # against the pypet documentation.
      traj.v_auto_load = True

      # Module-level flag read by the plotting functions: save figures to disk
      # when True, otherwise show them interactively at the end.
      save_figs = True

      # Select the runs with seed 1, no placement jitter and the explored
      # correlation length closest to 200.
      plot_corr_len = get_closest_correlation_length(traj, 200.0)
      par_dict = {
          'seed': 1,
          'correlation_length': plot_corr_len,
          'placement_jitter': 0.0,
      }
      plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
      print(plot_run_names)

      direction_idx = 6
      dir_indices = [0, 3, 6, 9]

      # Figures that are currently produced.
      plot_axonal_clouds(traj, plot_run_names)
      ex_polar_plot_id = plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names)
      in_polar_plot_id, in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names)
      plot_excitatory_condensed_polar_plot(traj, plot_run_names, ex_polar_plot_id)
      plot_inhibitory_condensed_polar_plot_with_input(traj, plot_run_names, in_polar_plot_id, in_max_rate)
      plot_hdi_violin_combined_and_overlayed(traj, plot_run_names)

      # Further figures, currently disabled:
      # plot_orientation_maps_diff_scales_with_ellipse(traj)
      # plot_hdi_histogram_inhibitory(traj, plot_run_names)
      # plot_hdi_histogram_excitatory(traj, plot_run_names)
      # plot_hdi_over_corr_len(traj, traj.f_get_run_names())
      # plot_hdi_violin_combined(traj, plot_run_names)
      # plot_exc_and_inh_hdi_over_corr_len(traj, traj.f_get_run_names())
      # par_dict = {'correlation_length': plot_corr_len}
      # single_corr_len_run_names = filter_run_names_by_par_dict(traj, par_dict)
      # plot_tuning_curve(traj, direction_idx, single_corr_len_run_names)

      if not save_figs:
          plt.show()

      traj.f_restore_default()