# final_analysis_and_plotting_perlin_map.py
  1. import itertools
  2. import os
  3. import matplotlib.legend as mlegend
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import pandas as pd
  7. from matplotlib.patches import Ellipse
  8. from matplotlib.patches import Polygon
  9. from matplotlib.patches import Rectangle
  10. from mpl_toolkits.axes_grid1 import ImageGrid
  11. from pypet import Trajectory
  12. from pypet.brian2 import Brian2MonitorResult
  13. from scripts.spatial_maps.spatial_network_layout import Interneuron, get_position_mesh
  14. from scripts.spatial_maps.correlation_length_fit.correlation_length_fit import \
  15. correlation_length_fit_dict
  16. from scripts.model_figure.figure_utils import remove_frame, remove_ticks, add_length_scale, cm_per_inch, \
  17. panel_size, head_direction_input_colormap
  18. from scripts.spatial_network.perlin_map.run_simulation_perlin_map import DATA_FOLDER, TRAJ_NAME, \
  19. get_input_head_directions, POLARIZED, CIRCULAR, NO_SYNAPSES
  20. plt.style.use('../../model_figure/figures.mplstyle')
  21. FIGURE_SAVE_PATH = '../../../figures/' + TRAJ_NAME + '/'
  22. FIGURE_SAVE_FORMAT = '.pdf'
  23. save_figs = False
  24. def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
  25. """
  26. Place a table legend on the axes.
  27. Creates a legend where the labels are not directly placed with the artists,
  28. but are used as row and column headers, looking like this:
  29. title_label | col_labels[1] | col_labels[2] | col_labels[3]
  30. -------------------------------------------------------------
  31. row_labels[1] |
  32. row_labels[2] | <artists go there>
  33. row_labels[3] |
  34. Parameters
  35. ----------
  36. ax : `matplotlib.axes.Axes`
  37. The artist that contains the legend table, i.e. current axes instant.
  38. col_labels : list of str, optional
  39. A list of labels to be used as column headers in the legend table.
  40. `len(col_labels)` needs to match `ncol`.
  41. row_labels : list of str, optional
  42. A list of labels to be used as row headers in the legend table.
  43. `len(row_labels)` needs to match `len(handles) // ncol`.
  44. title_label : str, optional
  45. Label for the top left corner in the legend table.
  46. ncol : int
  47. Number of columns.
  48. Other Parameters
  49. ----------------
  50. Refer to `matplotlib.legend.Legend` for other parameters.
  51. """
  52. #################### same as `matplotlib.axes.Axes.legend` #####################
  53. handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
  54. if len(extra_args):
  55. raise TypeError('legend only accepts two non-keyword arguments')
  56. if col_labels is None and row_labels is None:
  57. ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
  58. ax.legend_._remove_method = ax._remove_legend
  59. return ax.legend_
  60. #################### modifications for table legend ############################
  61. else:
  62. ncol = kwargs.pop('ncol')
  63. handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
  64. title_label = [title_label]
  65. # blank rectangle handle
  66. extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
  67. # empty label
  68. empty = [""]
  69. # number of rows infered from number of handles and desired number of columns
  70. nrow = len(handles) // ncol
  71. # organise the list of handles and labels for table construction
  72. if col_labels is None:
  73. assert nrow == len(
  74. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  75. nrow, len(row_labels))
  76. leg_handles = extra * nrow
  77. leg_labels = row_labels
  78. elif row_labels is None:
  79. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  80. ncol, len(col_labels))
  81. leg_handles = []
  82. leg_labels = []
  83. else:
  84. assert nrow == len(
  85. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  86. nrow, len(row_labels))
  87. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  88. ncol, len(col_labels))
  89. leg_handles = extra + extra * nrow
  90. leg_labels = title_label + row_labels
  91. for col in range(ncol):
  92. if col_labels is not None:
  93. leg_handles += extra
  94. leg_labels += [col_labels[col]]
  95. leg_handles += handles[col * nrow:(col + 1) * nrow]
  96. leg_labels += empty * nrow
  97. # Create legend
  98. ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol + int(row_labels is not None),
  99. handletextpad=handletextpad, **kwargs)
  100. ax.legend_._remove_method = ax._remove_legend
  101. return ax.legend_
  102. def get_closest_scale(traj, scale):
  103. available_lengths = sorted(list(set(traj.f_get("scale").f_get_range())))
  104. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - scale))]
  105. if closest_length != scale:
  106. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  107. scale, closest_length))
  108. corr_len = closest_length
  109. return corr_len
  110. def get_closest_fit_correlation_length(traj, fit_corr_len):
  111. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='perlin_map', load=True)
  112. available_lengths = list(corr_len_fit_dict.values())
  113. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - fit_corr_len))]
  114. closest_corresponding_scale = list(corr_len_fit_dict.keys())[list(corr_len_fit_dict.values()).index(closest_length)]
  115. if closest_length != fit_corr_len:
  116. print("Warning: desired fit correlation length {:.1f} not available. Taking {:.1f} instead".format(
  117. fit_corr_len, closest_length))
  118. return closest_corresponding_scale
  119. def colorbar(mappable):
  120. from mpl_toolkits.axes_grid1 import make_axes_locatable
  121. import matplotlib.pyplot as plt
  122. last_axes = plt.gca()
  123. ax = mappable.axes
  124. fig = ax.figure
  125. divider = make_axes_locatable(ax)
  126. cax = divider.append_axes("right", size="5%", pad=0.05)
  127. cbar = fig.colorbar(mappable, cax=cax)
  128. plt.sca(last_axes)
  129. return cbar
  130. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, exemplary_excitatory_neuron_id):
  131. max_val = 0
  132. for run_name in plot_run_names:
  133. fr_array = traj.results.runs[run_name].firing_rate_array
  134. f_rates = fr_array[:, direction_idx]
  135. run_max_val = np.max(f_rates)
  136. if run_max_val > max_val:
  137. # if traj.derived_parameters.runs[run_name].morphology.morph_label == POLARIZED:
  138. # n_id_max_rate = np.argmax(f_rates)
  139. max_val = run_max_val
  140. n_id_polar_plot = exemplary_excitatory_neuron_id
  141. # Mark the neuron that is shown in polar plot
  142. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  143. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  144. # Vertices for the plotted triangle
  145. tr_scale = 30.
  146. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  147. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  148. tr_vertices = np.array(
  149. [[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  150. height = 1 * panel_size
  151. width = 3 * panel_size
  152. fig = plt.figure(figsize=(width, height))
  153. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  154. cbar_mode="single",
  155. cbar_size="7%", cbar_pad=panel_size / 10.0,
  156. nrows_ncols=(1, 3))
  157. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  158. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  159. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  160. firing_rate_array = traj.results[run_name].firing_rate_array
  161. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  162. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  163. number_of_excitatory_neurons_per_row)),
  164. vmin=0, vmax=max_val, cmap='Reds')
  165. # ax.set_title(adjust_label(label))
  166. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  167. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  168. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  169. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  170. for spine in ax.spines.values():
  171. spine.set_edgecolor("grey")
  172. spine.set_linewidth(1)
  173. remove_ticks(ax)
  174. # fig.suptitle('spatial firing rate map', fontsize=16)
  175. ax.cax.colorbar(c)
  176. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  177. # fig.tight_layout()
  178. if save_figs:
  179. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_excitatory' + FIGURE_SAVE_FORMAT, dpi=300)
  180. plt.close(fig)
  181. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron):
  182. max_val = 0
  183. for run_name in plot_run_names:
  184. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  185. f_rates = fr_array[:, direction_idx]
  186. run_max_val = np.max(f_rates)
  187. if run_max_val > max_val:
  188. max_val = run_max_val
  189. n_id_polar_plot = selected_inhibitory_neuron
  190. # Mark the neuron that is shown in Polar plot
  191. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  192. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  193. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  194. width = 3 * panel_size
  195. height = panel_size
  196. fig = plt.figure(figsize=(width, height))
  197. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  198. cbar_mode="single",
  199. cbar_size="7%", cbar_pad=panel_size / 10.0,
  200. nrows_ncols=(1, 3))
  201. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  202. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  203. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  204. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  205. X, Y = get_position_mesh(inh_positions)
  206. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  207. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  208. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  209. number_of_inhibitory_neurons_per_row)),
  210. vmin=0, vmax=max_val, cmap='Blues')
  211. # ax.set_title(adjust_label(label))
  212. circle_r = 40.
  213. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  214. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  215. for spine in ax.spines.values():
  216. spine.set_edgecolor("grey")
  217. spine.set_linewidth(0.5)
  218. remove_ticks(ax)
  219. # fig.colorbar(c, ax=ax, label="f (Hz)")
  220. # fig.suptitle('spatial firing rate map', fontsize=16)
  221. cbar = ax.cax.colorbar(c)
  222. cbar_ticks = range(0,151,30)
  223. cbar.ax.set_yticks(cbar_ticks)
  224. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  225. # fig.tight_layout()
  226. if save_figs:
  227. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_inhibitory' + FIGURE_SAVE_FORMAT, dpi=300)
  228. plt.close(fig)
  229. return max_val
  230. def normal_labels(label):
  231. if label == POLARIZED:
  232. label = 'polarized'
  233. elif label == NO_SYNAPSES:
  234. label = 'no interneurons'
  235. return label
  236. def short_labels(label):
  237. if label == POLARIZED:
  238. label = 'pol.'
  239. elif label == CIRCULAR:
  240. label = 'cir.'
  241. elif label == "no conn":
  242. label = "no inh."
  243. return label
  244. def plot_input_map(traj, run_name, figsize=(panel_size, panel_size), figname='input_map' + FIGURE_SAVE_FORMAT):
  245. n_ex = int(np.sqrt(traj.N_E))
  246. width, height = figsize
  247. fig = plt.figure(figsize=(width, height))
  248. axes = ImageGrid(fig, (0.15, 0.1, 0.7, 0.8), axes_pad=panel_size / 3.0, cbar_location="right", cbar_mode=None,
  249. nrows_ncols=(1, 1))
  250. ax = axes[0]
  251. traj.f_set_crun(run_name)
  252. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  253. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  254. scale_length = traj.morphology.long_axis * 2
  255. traj.f_restore_default()
  256. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  257. for spine in ax.spines.values():
  258. spine.set_edgecolor("grey")
  259. spine.set_linewidth(0.5)
  260. remove_ticks(ax)
  261. start_scale_x = 100
  262. end_scale_x = start_scale_x + scale_length
  263. start_scale_y = -70
  264. end_scale_y = start_scale_y
  265. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  266. ax.cax.set_visible(False)
  267. if save_figs:
  268. plt.savefig(FIGURE_SAVE_PATH + figname)
  269. plt.close(fig)
  270. def plot_example_input_maps(traj, figsize=(panel_size, panel_size), figname='perlin_example_input_maps' + FIGURE_SAVE_FORMAT):
  271. n_ex = int(np.sqrt(traj.N_E))
  272. width, height = figsize
  273. fig = plt.figure(figsize=(width, height))
  274. axes = ImageGrid(fig, (0.1, 0.1, 0.85, 0.85), axes_pad=panel_size / 12.0, nrows_ncols=(3, 3))
  275. fit_corr_len_list = [20, 50, 100]
  276. seed_list = [1, 2, 3]
  277. corr_and_seed = itertools.product(fit_corr_len_list, seed_list)
  278. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='perlin_map', load=True)
  279. for idx, (ax, (fit_corr_len, seed)) in enumerate(zip(axes, corr_and_seed)):
  280. corresp_scale = get_closest_fit_correlation_length(traj, fit_corr_len)
  281. par_dict = {'seed': seed, 'scale': corresp_scale}
  282. run_name = filter_run_names_by_par_dict(traj, par_dict)[0]
  283. traj.f_set_crun(run_name)
  284. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  285. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  286. scale_length = traj.morphology.long_axis * 2
  287. traj.f_restore_default()
  288. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  289. for spine in ax.spines.values():
  290. spine.set_edgecolor("grey")
  291. spine.set_linewidth(0.5)
  292. remove_ticks(ax)
  293. ax.set_ylabel('{:3.0f} um'.format(corr_len_fit_dict[corresp_scale]))
  294. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  295. plt_polar_id = 255
  296. plt_polar_coordinates = inhibitory_axonal_cloud_array[plt_polar_id]
  297. plt_polar_neuron = Interneuron(plt_polar_coordinates[0],
  298. plt_polar_coordinates[1],
  299. traj.morphology.long_axis,
  300. traj.morphology.short_axis,
  301. plt_polar_coordinates[2])
  302. circ_radius = np.sqrt(traj.morphology.long_axis * traj.morphology.short_axis)
  303. plt_circ_id = 65
  304. plt_circ_coordinates = inhibitory_axonal_cloud_array[plt_circ_id]
  305. plt_circ_neuron = Interneuron(plt_circ_coordinates[0],
  306. plt_circ_coordinates[1],
  307. circ_radius,
  308. circ_radius,
  309. plt_circ_coordinates[2])
  310. plt_interneurons = [plt_polar_neuron, plt_circ_neuron]
  311. for p in plt_interneurons:
  312. ell = p.get_ellipse()
  313. edgecolor = 'black'
  314. alpha = 1
  315. zorder = 10
  316. linewidth = 1.
  317. ell.set_edgecolor(edgecolor)
  318. ell.set_alpha(alpha)
  319. ell.set_zorder(zorder)
  320. ell.set_linewidth(linewidth)
  321. ax.add_artist(ell)
  322. if idx == 8:
  323. start_scale_x = 100
  324. end_scale_x = start_scale_x + scale_length
  325. start_scale_y = -70
  326. end_scale_y = start_scale_y
  327. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  328. ax.cax.set_visible(False)
  329. if save_figs:
  330. plt.savefig(FIGURE_SAVE_PATH + figname)
  331. plt.close(fig)
  332. def plot_axonal_clouds(traj, plot_run_names):
  333. n_ex = int(np.sqrt(traj.N_E))
  334. height = 1 * panel_size
  335. width = 3 * panel_size
  336. cluster_positions = [(250, 250), (750, 200), (450, 600)]
  337. cluster_sizes = [4, 4, 4]
  338. selected_neurons = []
  339. inhibitory_positions = traj.results.runs[plot_run_names[0]].inhibitory_axonal_cloud_array[:, :2]
  340. for cluster_position, number_of_neurons_in_cluster in zip(cluster_positions, cluster_sizes):
  341. selection = get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster,
  342. inhibitory_positions)
  343. selected_neurons.extend(selection)
  344. fig = plt.figure(figsize=(width, height))
  345. axes = ImageGrid(fig, 111, axes_pad=0.15, cbar_location="right", cbar_mode="single", cbar_size="7%",
  346. nrows_ncols=(1, 3))
  347. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  348. traj.f_set_crun(run_name)
  349. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  350. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  351. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  352. axonal_clouds = [Interneuron(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  353. inhibitory_axonal_cloud_array]
  354. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  355. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  356. # ax.set_title(normal_labels(label))
  357. ax.set_aspect('equal')
  358. remove_frame(ax)
  359. remove_ticks(ax)
  360. # fig.colorbar(c, ax=ax, label="Tuning")
  361. if label != NO_SYNAPSES and axonal_clouds is not None:
  362. for i, p in enumerate(axonal_clouds):
  363. ell = p.get_ellipse()
  364. if i in selected_neurons:
  365. edgecolor = 'black'
  366. alpha = 1
  367. zorder = 10
  368. linewidth = 1
  369. else:
  370. edgecolor = 'gray'
  371. alpha = 0.5
  372. zorder = 1
  373. linewidth = 0.3
  374. ell.set_edgecolor(edgecolor)
  375. ell.set_alpha(alpha)
  376. ell.set_zorder(zorder)
  377. ell.set_linewidth(linewidth)
  378. ax.add_artist(ell)
  379. traj.f_set_crun(plot_run_names[0])
  380. scale_length = traj.morphology.long_axis * 2
  381. traj.f_restore_default()
  382. start_scale_x = 100
  383. end_scale_x = start_scale_x + scale_length
  384. start_scale_y = -70
  385. end_scale_y = start_scale_y
  386. add_length_scale(axes[0], scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  387. axes[1].set_yticklabels([])
  388. axes[2].set_yticklabels([])
  389. axes[0].cax.set_visible(False)
  390. traj.f_restore_default()
  391. if save_figs:
  392. plt.savefig(FIGURE_SAVE_PATH + 'B_i_axonal_clouds' + FIGURE_SAVE_FORMAT)
  393. plt.close(fig)
  394. def get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster, positions):
  395. position = np.array(cluster_position)
  396. distance_vectors = positions - np.expand_dims(position, 0).repeat(positions.shape[0],
  397. axis=0)
  398. distances = np.linalg.norm(distance_vectors, axis=1)
  399. selection = list(np.argpartition(distances, number_of_neurons_in_cluster)[:number_of_neurons_in_cluster])
  400. return selection
  401. def plot_orientation_maps_diff_scales(traj):
  402. n_ex = int(np.sqrt(traj.N_E))
  403. scale_run_names = []
  404. plot_scales = [0.0, 100.0, 200.0, 300.0]
  405. for scale in plot_scales:
  406. par_dict = {'seed': 1, 'correlation_length': get_closest_scale(traj, scale), 'long_axis': 100.}
  407. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  408. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  409. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  410. traj.f_set_crun(run_name)
  411. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  412. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  413. # TODO: Why was this transposed for plotting? (now changed)
  414. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  415. ax.set_title('Correlation length: {}'.format(scale))
  416. fig.colorbar(c, ax=ax, label="Tuning")
  417. # fig.suptitle('axonal cloud', fontsize=16)
  418. traj.f_restore_default()
  419. if save_figs:
  420. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales' + FIGURE_SAVE_FORMAT)
  421. plt.close(fig)
  422. def plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_idx,
  423. figname=FIGURE_SAVE_PATH + 'D_polar_plot_excitatory' + FIGURE_SAVE_FORMAT):
  424. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  425. directions_plt = list(directions)
  426. directions_plt.append(directions[0])
  427. height = panel_size
  428. width = panel_size
  429. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  430. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  431. # sorted_ids = np.argsort(head_direction_indices)
  432. # plot_n_idx = sorted_ids[-75]
  433. plot_n_idx = selected_neuron_idx
  434. line_styles = ['dotted', 'solid', 'dashed']
  435. colors = ['r', 'lightsalmon', 'grey']
  436. labels = ['pol. ', 'cir. ', 'no inh.']
  437. line_widths = [1.5, 1.5, 1]
  438. zorders = [10, 2, 1]
  439. max_rate = 0.0
  440. ax.plot([], [], color='white',label=' ')
  441. hdis = []
  442. for run_idx, run_name in enumerate(plot_run_names):
  443. label = labels[run_idx]
  444. hdi = traj.results.runs[run_name].head_direction_indices[selected_neuron_idx]
  445. hdis.append(hdi)
  446. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  447. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  448. run_max_rate = np.max(rate_plot)
  449. if run_max_rate > max_rate:
  450. max_rate = run_max_rate
  451. rate_plot.append(rate_plot[0])
  452. plt_label = '{:s}'.format(short_labels(label))
  453. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  454. label=plt_label, color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  455. ticks = [40., 80.]
  456. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  457. enumerate(ticks)], angle=60)
  458. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  459. ax.xaxis.grid(linewidth=0.4)
  460. ax.yaxis.grid(linewidth=0.4)
  461. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.2), handlelength=1, fontsize="medium")
  462. leg.get_frame().set_linewidth(0.0)
  463. hdi_box_x, hdi_box_y = (0.86, -0.09)
  464. hdi_box_dy = 0.14
  465. hdi_box = ax.text(hdi_box_x, hdi_box_y + 3 * hdi_box_dy, 'HDI', fontsize='medium', transform=ax.transAxes,zorder=9.)
  466. hdi_box = ax.text(hdi_box_x, hdi_box_y + 2 * hdi_box_dy, '{:.2f}'.format(hdis[0]), fontsize='medium',transform=ax.transAxes, zorder=9.)
  467. hdi_box = ax.text(hdi_box_x, hdi_box_y + hdi_box_dy, '{:.2f}'.format(hdis[1]), fontsize='medium', transform=ax.transAxes,zorder=9.)
  468. hdi_box = ax.text(hdi_box_x, hdi_box_y, '{:.2f}'.format(hdis[2]), fontsize='medium', transform=ax.transAxes,zorder=9.)
  469. ax.axes.spines["polar"].set_visible(False)
  470. if save_figs:
  471. plt.savefig(figname)
  472. plt.close(fig)
  473. def plot_polar_plot_inhibitory(traj, plot_run_names, selected_neuron_idx, figname=FIGURE_SAVE_PATH +
  474. 'D_polar_plot_inhibitory' + FIGURE_SAVE_FORMAT):
  475. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  476. directions_plt = list(directions)
  477. directions_plt.append(directions[0])
  478. height = panel_size
  479. width = panel_size
  480. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  481. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  482. # sorted_ids = np.argsort(head_direction_indices)
  483. # plot_n_idx = sorted_ids[-75]
  484. plot_n_idx = selected_neuron_idx
  485. line_styles = ['dotted', 'solid']
  486. colors = ['b', 'lightblue']
  487. labels = ['pol. ', 'cir. ']
  488. line_widths = [1.5, 1.5]
  489. zorders = [10, 2]
  490. ax.plot([], [], color='white',label=' ')
  491. hdis = []
  492. for run_idx, run_name in enumerate(plot_run_names[:2]):
  493. label = labels[run_idx]
  494. hdi = traj.results.runs[run_name].inh_head_direction_indices[selected_neuron_idx]
  495. hdis.append(hdi)
  496. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  497. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  498. rate_plot.append(rate_plot[0])
  499. plt_label = '{:s}'.format(short_labels(label))
  500. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  501. label=plt_label,
  502. color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  503. ticks = [40., 80., 120.]
  504. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  505. enumerate(ticks)], angle=60)
  506. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  507. ax.xaxis.grid(linewidth=0.4)
  508. ax.yaxis.grid(linewidth=0.4)
  509. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.15), handlelength=1, fontsize="medium")
  510. leg.get_frame().set_linewidth(0.0)
  511. hdi_box_x, hdi_box_y = (0.86, -0.04)
  512. hdi_box_dy = 0.14
  513. hdi_box = ax.text(hdi_box_x, hdi_box_y+2*hdi_box_dy, 'HDI', fontsize='medium', transform=ax.transAxes, zorder=9.)
  514. hdi_box = ax.text(hdi_box_x, hdi_box_y+hdi_box_dy, '{:.2f}'.format(hdis[0]), fontsize='medium', transform=ax.transAxes, zorder=9.)
  515. hdi_box = ax.text(hdi_box_x, hdi_box_y, '{:.2f}'.format(hdis[1]), fontsize='medium', transform=ax.transAxes, zorder=9.)
  516. ax.axes.spines["polar"].set_visible(False)
  517. if save_figs:
  518. plt.savefig(figname)
  519. plt.close(fig)
  520. def plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id, cut_off_dist):
  521. labels = []
  522. inh_hdis = []
  523. exc_hdis = []
  524. no_conn_hdi = 0.
  525. for run_idx, run_name in enumerate(plot_run_names):
  526. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  527. if label != NO_SYNAPSES:
  528. labels.append(normal_labels(label))
  529. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  530. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  531. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  532. (inh_axonal_cloud[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  533. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  534. (inh_axonal_cloud[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  535. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  536. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  537. ex_positions = traj.results.runs[run_name].ex_positions
  538. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  539. (ex_positions[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  540. (ex_positions[:, 1] >= cut_off_dist) & \
  541. (ex_positions[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  542. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  543. else:
  544. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  545. no_conn_hdi = np.mean(exc_head_direction_indices)
  546. # Look for a representative excitatory neuron
  547. hdi_mean_dict = {}
  548. excitatory_hdi_means = [np.mean(hdis) for hdis in exc_hdis]
  549. hdi_mean_dict["polar_exc"] = excitatory_hdi_means[0]
  550. hdi_mean_dict["circular_exc"] = excitatory_hdi_means[1]
  551. inhibitory_hdi_means = [np.mean(hdis) for hdis in inh_hdis]
  552. hdi_mean_dict["polar_inh"] = inhibitory_hdi_means[0]
  553. hdi_mean_dict["circular_inh"] = inhibitory_hdi_means[1]
  554. width = 2 * panel_size
  555. height = 1.2 * panel_size
  556. fig, axes = plt.subplots(2, 1, figsize=(width, height))
  557. plt.subplots_adjust(wspace=0, hspace=0.1)
  558. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  559. max_density = 0
  560. for i in range(2):
  561. # i = i + 1 # grid spec indexes from 0
  562. # ax = fig.add_subplot(gs[i])
  563. ax = axes[i]
  564. density_e, _, _ = ax.hist(exc_hdis[i], color='r', edgecolor='r', alpha=0.3, bins=bins, density=True)
  565. density_i, _, _ = ax.hist(inh_hdis[i], color='b', edgecolor='b', alpha=0.3, bins=bins, density=True)
  566. max_density = np.max([max_density, np.max(density_e), np.max(density_i)])
  567. ax.axvline(np.mean(exc_hdis[i]), color='r')
  568. ax.axvline(np.mean(inh_hdis[i]), color='b')
  569. ax.axvline(no_conn_hdi, color='dimgrey', linestyle='--', linewidth=1.5)
  570. ax.set_ylabel(labels[i], rotation='vertical')
  571. # plt.axis('on')
  572. if i == 0:
  573. ax.set_xticklabels([])
  574. else:
  575. ax.set_xlabel('HDI')
  576. remove_frame(ax, ["top", "right", "bottom"])
  577. max_density = 1.2 * max_density
  578. fig.subplots_adjust(left=0.2, right=0.95, bottom=0.2)
  579. axes[0].annotate('% cells', (0, 1.0), xycoords='axes fraction', va="bottom", ha="right")
  580. axes[1].annotate("no ihn.\n{:.2f}".format(no_conn_hdi), xy=(no_conn_hdi, max_density),
  581. xytext=(-2, 0), xycoords="data",
  582. textcoords="offset points",
  583. va="top", ha="right", color="dimgrey")
  584. for i, ax in enumerate(axes):
  585. ax.annotate("{:.2f}".format(np.mean(exc_hdis[i])), xy=(np.mean(exc_hdis[i]), max_density),
  586. xytext=(2, 0), xycoords="data",
  587. textcoords="offset points",
  588. va="top", ha="left", color="r")
  589. # i_ha = "left" if i == 1 else "right"
  590. # i_offset = 2 if i == 1 else -2
  591. i_ha = "right"
  592. i_offset = -1
  593. ax.annotate("{:.2f}".format(np.mean(inh_hdis[i])), xy=(np.mean(inh_hdis[i]), max_density),
  594. xytext=(i_offset, 0), xycoords="data",
  595. textcoords="offset points",
  596. va="top", ha=i_ha, color="b")
  597. for ax in axes:
  598. ax.set_ylim(0, max_density)
  599. # plt.annotate('probability density', (-0.2,1.5), xycoords='axes fraction', rotation=90, fontsize=18)
  600. if save_figs:
  601. plt.savefig(FIGURE_SAVE_PATH + 'E_hdi_histogram_combined_and_overlayed' + FIGURE_SAVE_FORMAT.format(cut_off_dist))
  602. plt.close(fig)
  603. return hdi_mean_dict
  604. def get_neurons_with_given_hdi(polar_hdi, circular_hdi, max_number_of_suggestions, plot_run_names, traj, type):
  605. polar_run_name = plot_run_names[0]
  606. circular_run_name = plot_run_names[1]
  607. polar_ex_hdis = traj.results.runs[polar_run_name].head_direction_indices if type == "ex" else traj.results.runs[
  608. polar_run_name].inh_head_direction_indices
  609. circular_ex_hdis = traj.results.runs[circular_run_name].head_direction_indices if type == "ex" else \
  610. traj.results.runs[
  611. polar_run_name].inh_head_direction_indices
  612. neuron_indices = get_indices_of_closest_values(polar_ex_hdis, polar_hdi,
  613. circular_ex_hdis,
  614. circular_hdi, 0.1 * np.abs(
  615. polar_hdi - circular_hdi), max_number_of_suggestions)
  616. return neuron_indices
  617. def get_indices_of_closest_values(first_list, first_value, second_list, second_value, absolute_tolerance_list_one,
  618. number_of_indices):
  619. is_close_in_list_one = np.abs(first_list - first_value) < absolute_tolerance_list_one
  620. indices_close_in_list_one = np.where(is_close_in_list_one)[0]
  621. indices_closest_in_list_two = indices_close_in_list_one[np.argpartition(np.abs(second_list[
  622. indices_close_in_list_one] - second_value),
  623. number_of_indices)]
  624. return indices_closest_in_list_two[:number_of_indices]
  625. def filter_run_names_by_par_dict(traj, par_dict):
  626. run_name_list = []
  627. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  628. traj.f_set_crun(run_name)
  629. paramters_equal = True
  630. for key, val in par_dict.items():
  631. if (traj.par[key] != val):
  632. paramters_equal = False
  633. if paramters_equal:
  634. run_name_list.append(run_name)
  635. traj.f_restore_default()
  636. return run_name_list
  637. def plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, plot_run_names, cut_off_dist):
  638. corr_len_expl = traj.f_get('scale').f_get_range()
  639. seed_expl = traj.f_get('seed').f_get_range()
  640. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  641. label_range = set(label_expl)
  642. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  643. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  644. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  645. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  646. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  647. ex_tunings = traj.results.runs[run_name].ex_tunings
  648. inh_hdis = []
  649. exc_hdis = []
  650. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  651. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  652. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  653. (inh_axonal_cloud[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  654. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  655. (inh_axonal_cloud[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  656. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  657. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  658. ex_positions = traj.results.runs[run_name].ex_positions
  659. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  660. (ex_positions[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  661. (ex_positions[:, 1] >= cut_off_dist) & \
  662. (ex_positions[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  663. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  664. exc_hdi_frame[corr_len, seed, label] = np.mean(exc_hdis)
  665. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_hdis)
  666. # TODO: Standard deviation also for the population
  667. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  668. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  669. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  670. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  671. markersize = 4.
  672. exc_style_dict = {
  673. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  674. POLARIZED: ['red', 'solid', '^', markersize],
  675. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  676. }
  677. inh_style_dict = {
  678. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  679. POLARIZED: ['blue', 'solid', 'o', markersize],
  680. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  681. }
  682. width = 2 * panel_size
  683. height = 1.2 * panel_size
  684. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  685. for label in sorted(label_range, reverse=True):
  686. if label == NO_SYNAPSES:
  687. no_conn_hdi = exc_hdi_n_and_seed_mean[get_closest_scale(traj, 200), label]
  688. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  689. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  690. textcoords="offset points",
  691. va="top", \
  692. ha="right",
  693. color="dimgrey")
  694. continue
  695. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  696. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  697. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  698. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  699. corr_len_range = exc_hdi_mean.keys().to_numpy()
  700. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  701. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  702. simplex_grid_scale = corr_len_range * np.sqrt(2)
  703. ax.plot(simplex_grid_scale, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  704. markersize=exc_mar_size, alpha=0.5)
  705. plt.fill_between(simplex_grid_scale, exc_hdi_mean - exc_hdi_std,
  706. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  707. ax.plot(simplex_grid_scale, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  708. markersize=inh_mar_size, alpha=0.5)
  709. plt.fill_between(simplex_grid_scale, inh_hdi_mean - inh_hdi_std,
  710. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  711. ax.set_xlabel('simplex grid scale')
  712. ax.set_ylabel('HDI')
  713. ax.axvline(get_closest_scale(traj, 200.0) * np.sqrt(2), color='k', linewidth=0.5, zorder=0)
  714. ax.set_ylim(0.0, 1.0)
  715. # ax.set_xlim(0.0, np.max(corr_len_range))
  716. remove_frame(ax, ["right", "top"])
  717. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  718. row_labels=None,
  719. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  720. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  721. fig.subplots_adjust(bottom=0.2, left=0.2)
  722. # plt.legend()
  723. if save_figs:
  724. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_grid_scale' + FIGURE_SAVE_FORMAT)
  725. plt.close(fig)
  726. def plot_exc_and_inh_hdi_over_fit_corr_len(traj, plot_run_names, cut_off_dist):
  727. corr_len_expl = traj.f_get('scale').f_get_range()
  728. seed_expl = traj.f_get('seed').f_get_range()
  729. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  730. label_range = set(label_expl)
  731. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  732. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  733. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  734. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  735. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  736. ex_tunings = traj.results.runs[run_name].ex_tunings
  737. inh_hdis = []
  738. exc_hdis = []
  739. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  740. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  741. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  742. (inh_axonal_cloud[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  743. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  744. (inh_axonal_cloud[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  745. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  746. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  747. ex_positions = traj.results.runs[run_name].ex_positions
  748. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  749. (ex_positions[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  750. (ex_positions[:, 1] >= cut_off_dist) & \
  751. (ex_positions[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  752. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  753. exc_hdi_frame[corr_len, seed, label] = np.mean(exc_hdis)
  754. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_hdis)
  755. # TODO: Standard deviation also for the population
  756. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  757. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  758. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  759. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  760. markersize = 4.
  761. exc_style_dict = {
  762. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  763. POLARIZED: ['red', 'solid', '^', markersize],
  764. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  765. }
  766. inh_style_dict = {
  767. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  768. POLARIZED: ['blue', 'solid', 'o', markersize],
  769. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  770. }
  771. # colors = ['blue', 'grey', 'lightblue']
  772. # linestyles = ['solid', 'dashed', 'solid']
  773. # markers = [verts, '', 'o']
  774. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='perlin_map', load=True)
  775. width = 2 * panel_size
  776. height = 1.2 * panel_size
  777. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  778. for label in sorted(label_range, reverse=True):
  779. if label == NO_SYNAPSES:
  780. no_conn_hdi = exc_hdi_n_and_seed_mean[get_closest_scale(traj, 200), label]
  781. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  782. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  783. textcoords="offset points",
  784. va="top", \
  785. ha="right",
  786. color="dimgrey")
  787. continue
  788. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  789. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  790. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  791. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  792. corr_len_range = exc_hdi_mean.keys().to_numpy()
  793. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  794. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  795. fit_corr_len = [corr_len_fit_dict[corr_len] for corr_len in corr_len_range]
  796. ax.plot(fit_corr_len, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  797. markersize=exc_mar_size, alpha=0.5)
  798. plt.fill_between(fit_corr_len, exc_hdi_mean - exc_hdi_std,
  799. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  800. ax.plot(fit_corr_len, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  801. markersize=inh_mar_size, alpha=0.5)
  802. plt.fill_between(fit_corr_len, inh_hdi_mean - inh_hdi_std,
  803. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  804. ax.set_xlabel('correlation length (um)')
  805. ax.set_ylabel('HDI')
  806. ax.axvline(corr_len_fit_dict[get_closest_scale(traj, 200.0)], color='k', linewidth=0.5, zorder=0)
  807. ax.set_ylim(0.0, 1.0)
  808. remove_frame(ax, ["right", "top"])
  809. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  810. row_labels=None,
  811. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  812. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  813. fig.subplots_adjust(bottom=0.2, left=0.2)
  814. # plt.legend()
  815. if save_figs:
  816. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_corr_len_scaled' + FIGURE_SAVE_FORMAT)
  817. plt.close(fig)
  818. def get_phase_difference(total_difference):
  819. """
  820. Map accumulated phase difference to shortest possible difference
  821. :param total_difference:
  822. :return: relative_difference
  823. """
  824. return (total_difference + np.pi) % (2 * np.pi) - np.pi
  825. def plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names, figsize=(9, 9)):
  826. # The plot that Imre wanted
  827. n_bins = traj.parameters.input.number_of_directions
  828. fig, ax = plt.subplots(1, 1, figsize=figsize)
  829. dir_bins = np.linspace(-np.pi, np.pi, n_bins + 1)
  830. plot_fr_array = []
  831. labels = []
  832. similarity_threshold = np.pi / 6.
  833. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  834. for run_idx, run_name in enumerate(plot_run_names):
  835. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  836. labels.append(short_labels(label))
  837. fr_similar_tunings = []
  838. fr_different_tunings = []
  839. ex_tunings = traj.results.runs[run_name].ex_tunings
  840. firing_rate_array = traj.results[run_name].firing_rate_array
  841. for tuning, firing_rates in zip(ex_tunings, firing_rate_array):
  842. for idx, dir in enumerate(directions):
  843. if np.abs(get_phase_difference(tuning - dir)) <= similarity_threshold:
  844. fr_similar_tunings.append(firing_rates[idx])
  845. elif np.abs(get_phase_difference(tuning + np.pi - dir)) <= similarity_threshold:
  846. fr_different_tunings.append(firing_rates[idx])
  847. plot_fr_array.append([np.mean(fr_similar_tunings), np.mean(fr_different_tunings)])
  848. x = np.arange(3) # the label locations
  849. width = 0.35 # the width of the bars
  850. plot_fr_array = np.array(plot_fr_array)
  851. rects1 = ax.bar(x - width / 2, plot_fr_array[:, 0], width,
  852. label=r'preferred')
  853. rects2 = ax.bar(x + width / 2, plot_fr_array[:, 1], width,
  854. label=r'opposite')
  855. ax.set_xticks(x)
  856. ax.set_xticklabels(labels)
  857. ax.spines['right'].set_visible(False)
  858. ax.spines['top'].set_visible(False)
  859. # ax.set_title('Mean firing rate for tunings similar and different to input')
  860. # ax.set_ylabel('Mean firing rate')
  861. ax.annotate(r'$\overline{\mathrm{fr}}$ (Hz)', (0.05, 1.0), xycoords='axes fraction', va="bottom", ha="right")
  862. leg = ax.legend(loc="upper left", bbox_to_anchor=(0.0, 1.2), handlelength=1, fontsize="medium")
  863. leg.get_frame().set_linewidth(0.0)
  864. def autolabel(rects):
  865. """Attach a text label above each bar in *rects*, displaying its height."""
  866. for rect in rects:
  867. height = rect.get_height()
  868. ax.annotate('{}'.format(np.round(height)),
  869. xy=(rect.get_x() + rect.get_width() / 2, height),
  870. xytext=(0, 3), # 3 points vertical offset
  871. textcoords="offset points",
  872. ha='center', va='bottom')
  873. autolabel(rects1)
  874. autolabel(rects2)
  875. fig.tight_layout()
  876. if save_figs:
  877. plt.savefig(FIGURE_SAVE_PATH + 'SUPPLEMENT_B_firing_rate_similar_vs_diff_tuning' + FIGURE_SAVE_FORMAT, dpi=200)
  878. plt.close(fig)
  879. def get_firing_rates_along_preferred_axis(traj, run_name, neuron_idx):
  880. firing_rates = traj.results[run_name].firing_rate_array[neuron_idx, :]
  881. tuning = traj.results[run_name].ex_tunings[neuron_idx]
  882. anti_tuning = tuning + np.pi if tuning + np.pi < np.pi else tuning - np.pi
  883. tuning_idx = np.argmin(np.abs(directions - tuning))
  884. anti_tuning_idx = np.argmin(np.abs(directions - anti_tuning))
  885. firing_at_the_preferred_direction = firing_rates[tuning_idx]
  886. firing_at_the_opposite_direction = firing_rates[anti_tuning_idx]
  887. return firing_at_the_preferred_direction, firing_at_the_opposite_direction
  888. def get_hdi(traj, run_name, neuron_idx, type):
  889. return traj.results.runs[run_name].head_direction_indices[neuron_idx] if type=="ex" else traj.results.runs[
  890. run_name].inh_head_direction_indices[neuron_idx]
  891. def plot_colorbar(figsize=(2, 2), figname=None):
  892. azimuth_no = 360
  893. zenith_no = 15
  894. azimuths = np.linspace(-180, 180, azimuth_no)
  895. zeniths = np.linspace(0.85, 1, zenith_no)
  896. values = azimuths * np.ones((zenith_no, azimuth_no))
  897. fig, ax = plt.subplots(subplot_kw=dict(projection='polar'), figsize=figsize)
  898. ax.pcolormesh(azimuths * np.pi / 180.0, zeniths, values, cmap=head_direction_input_colormap)
  899. # ax.set_yticks([])
  900. ax.set_thetagrids([0, 90, 180, 270])
  901. ax.tick_params(pad=-2)
  902. ax.set_ylim(0, 1)
  903. ax.grid(True)
  904. y_tick_labels = []
  905. ax.set_yticklabels(y_tick_labels)
  906. gridlines = ax.yaxis.get_gridlines()
  907. [line.set_linewidth(0.0) for line in gridlines]
  908. gridlines = ax.xaxis.get_gridlines()
  909. [line.set_linewidth(0.0) for line in gridlines]
  910. ax.axes.spines["polar"].set_visible(False)
  911. plt.subplots_adjust(left=0.25, right=0.75, bottom=0.25, top=0.75)
  912. # plt.show()
  913. if figname is not None:
  914. plt.savefig(figname, transparent=True)
  915. plt.close(fig)
  916. if __name__ == "__main__":
  917. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  918. NO_LOADING = 0
  919. FULL_LOAD = 2
  920. traj.f_load(filename=os.path.join(DATA_FOLDER, TRAJ_NAME + ".hdf5"), load_parameters=FULL_LOAD,
  921. load_results=NO_LOADING)
  922. traj.v_auto_load = True
  923. save_figs = True
  924. print("# Plotting script polarized interneurons")
  925. print()
  926. map_length_scale = 200.0
  927. map_seed = 1
  928. exemplary_head_direction = 0
  929. print("## Map specifications")
  930. print("\tinput map scale: {:.1f} um".format(map_length_scale))
  931. print("\tmap seed: {:d}".format(map_seed))
  932. print()
  933. print("## Input specification")
  934. print("\tselected head direction: {:.0f}°".format(exemplary_head_direction))
  935. print()
  936. print("## Selected simulations")
  937. plot_scale = get_closest_scale(traj, map_length_scale)
  938. par_dict = {'seed': map_seed, 'scale': plot_scale}
  939. plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
  940. run_name_dict = {}
  941. for run_name in plot_run_names:
  942. traj.f_set_crun(run_name)
  943. run_name_dict[traj.derived_parameters.runs[run_name].morphology.morph_label] = run_name
  944. for network_type, run_name in run_name_dict.items():
  945. print("{:s}: {:s}".format(network_type, run_name))
  946. directions = get_input_head_directions(traj)
  947. direction_idx = np.argmin(np.abs(np.array(directions) - np.deg2rad(exemplary_head_direction)))
  948. selected_neuron_excitatory = 1052
  949. selected_inhibitory_neuron = 28
  950. print("## Figure specification")
  951. print("\tpanel size: {:.2f} cm".format(panel_size * cm_per_inch))
  952. print()
  953. plot_colorbar(figsize=(0.8 * panel_size, 0.8 * panel_size), figname=FIGURE_SAVE_PATH + "A_i_colormap.svg")
  954. plot_input_map(traj, run_name_dict[POLARIZED], figname="A_i_exemplary_input_map" + FIGURE_SAVE_FORMAT,
  955. figsize=(panel_size, panel_size))
  956. plot_example_input_maps(traj, figsize=(2 * panel_size, 2 * panel_size))
  957. plot_axonal_clouds(traj, plot_run_names)
  958. plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, selected_neuron_excitatory)
  959. in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron)
  960. #
  961. hdi_means = plot_hdi_histogram_combined_and_overlayed(
  962. traj, plot_run_names,
  963. selected_neuron_excitatory,
  964. selected_inhibitory_neuron,
  965. cut_off_dist=100.)
  966. #
  967. plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_excitatory)
  968. plot_polar_plot_inhibitory(traj, plot_run_names, selected_inhibitory_neuron)
  969. plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names, figsize=(1.2*panel_size, 1.2*panel_size))
  970. plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, traj.f_get_run_names(), cut_off_dist=100.)
  971. plot_exc_and_inh_hdi_over_fit_corr_len(traj, traj.f_get_run_names(), cut_off_dist=100.)
  972. if not save_figs:
  973. plt.show()
  974. traj.f_restore_default()