final_analysis_and_plotting_perlin_map.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327
  1. import itertools
  2. import os
  3. import matplotlib.legend as mlegend
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import pandas as pd
  7. from matplotlib.patches import Ellipse
  8. from matplotlib.patches import Polygon
  9. from matplotlib.patches import Rectangle
  10. from mpl_toolkits.axes_grid1 import ImageGrid
  11. from pypet import Trajectory
  12. from pypet.brian2 import Brian2MonitorResult
  13. from scripts.spatial_maps.spatial_network_layout import Interneuron, get_position_mesh
  14. from scripts.spatial_maps.correlation_length_fit.correlation_length_fit import \
  15. correlation_length_fit_dict
  16. from scripts.model_figure.figure_utils import remove_frame, remove_ticks, add_length_scale, cm_per_inch, \
  17. panel_size, head_direction_input_colormap
  18. from scripts.spatial_network.perlin_map.run_simulation_perlin_map import DATA_FOLDER, TRAJ_NAME, \
  19. get_input_head_directions, POLARIZED, CIRCULAR, NO_SYNAPSES
  # Apply the shared matplotlib style sheet. The path is relative, so it is
  # resolved against the current working directory at import time.
  20. plt.style.use('../../model_figure/figures.mplstyle')
  # Directory prefix for saved figures: one sub-folder per trajectory name.
  21. FIGURE_SAVE_PATH = '../../../figures/' + TRAJ_NAME + '/'
  # File extension appended to the figure names built by the plotting functions.
  22. FIGURE_SAVE_FORMAT = '.pdf'
  # Checked by the plotting functions before they call plt.savefig.
  23. save_figs = False
  24. def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
  25. """
  26. Place a table legend on the axes.
  27. Creates a legend where the labels are not directly placed with the artists,
  28. but are used as row and column headers, looking like this:
  29. title_label | col_labels[1] | col_labels[2] | col_labels[3]
  30. -------------------------------------------------------------
  31. row_labels[1] |
  32. row_labels[2] | <artists go there>
  33. row_labels[3] |
  34. Parameters
  35. ----------
  36. ax : `matplotlib.axes.Axes`
  37. The artist that contains the legend table, i.e. current axes instant.
  38. col_labels : list of str, optional
  39. A list of labels to be used as column headers in the legend table.
  40. `len(col_labels)` needs to match `ncol`.
  41. row_labels : list of str, optional
  42. A list of labels to be used as row headers in the legend table.
  43. `len(row_labels)` needs to match `len(handles) // ncol`.
  44. title_label : str, optional
  45. Label for the top left corner in the legend table.
  46. ncol : int
  47. Number of columns.
  48. Other Parameters
  49. ----------------
  50. Refer to `matplotlib.legend.Legend` for other parameters.
  51. """
  52. #################### same as `matplotlib.axes.Axes.legend` #####################
  53. handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
  54. if len(extra_args):
  55. raise TypeError('legend only accepts two non-keyword arguments')
  56. if col_labels is None and row_labels is None:
  57. ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
  58. ax.legend_._remove_method = ax._remove_legend
  59. return ax.legend_
  60. #################### modifications for table legend ############################
  61. else:
  62. ncol = kwargs.pop('ncol')
  63. handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
  64. title_label = [title_label]
  65. # blank rectangle handle
  66. extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
  67. # empty label
  68. empty = [""]
  69. # number of rows infered from number of handles and desired number of columns
  70. nrow = len(handles) // ncol
  71. # organise the list of handles and labels for table construction
  72. if col_labels is None:
  73. assert nrow == len(
  74. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  75. nrow, len(row_labels))
  76. leg_handles = extra * nrow
  77. leg_labels = row_labels
  78. elif row_labels is None:
  79. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  80. ncol, len(col_labels))
  81. leg_handles = []
  82. leg_labels = []
  83. else:
  84. assert nrow == len(
  85. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  86. nrow, len(row_labels))
  87. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  88. ncol, len(col_labels))
  89. leg_handles = extra + extra * nrow
  90. leg_labels = title_label + row_labels
  91. for col in range(ncol):
  92. if col_labels is not None:
  93. leg_handles += extra
  94. leg_labels += [col_labels[col]]
  95. leg_handles += handles[col * nrow:(col + 1) * nrow]
  96. leg_labels += empty * nrow
  97. # Create legend
  98. ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol + int(row_labels is not None),
  99. handletextpad=handletextpad, **kwargs)
  100. ax.legend_._remove_method = ax._remove_legend
  101. return ax.legend_
  102. def get_closest_scale(traj, scale):
  103. available_lengths = sorted(list(set(traj.f_get("scale").f_get_range())))
  104. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - scale))]
  105. if closest_length != scale:
  106. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  107. scale, closest_length))
  108. corr_len = closest_length
  109. return corr_len
  110. def get_closest_fit_correlation_length(traj, fit_corr_len):
  111. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='perlin_map', load=True)
  112. available_lengths = list(corr_len_fit_dict.values())
  113. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - fit_corr_len))]
  114. closest_corresponding_scale = list(corr_len_fit_dict.keys())[list(corr_len_fit_dict.values()).index(closest_length)]
  115. if closest_length != fit_corr_len:
  116. print("Warning: desired fit correlation length {:.1f} not available. Taking {:.1f} instead".format(
  117. fit_corr_len, closest_length))
  118. return closest_corresponding_scale
  119. def colorbar(mappable):
  120. from mpl_toolkits.axes_grid1 import make_axes_locatable
  121. import matplotlib.pyplot as plt
  122. last_axes = plt.gca()
  123. ax = mappable.axes
  124. fig = ax.figure
  125. divider = make_axes_locatable(ax)
  126. cax = divider.append_axes("right", size="5%", pad=0.05)
  127. cbar = fig.colorbar(mappable, cax=cax)
  128. plt.sca(last_axes)
  129. return cbar
  130. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, exemplary_excitatory_neuron_id):
  131. max_val = 0
  132. for run_name in plot_run_names:
  133. fr_array = traj.results.runs[run_name].firing_rate_array
  134. f_rates = fr_array[:, direction_idx]
  135. run_max_val = np.max(f_rates)
  136. if run_max_val > max_val:
  137. # if traj.derived_parameters.runs[run_name].morphology.morph_label == POLARIZED:
  138. # n_id_max_rate = np.argmax(f_rates)
  139. max_val = run_max_val
  140. n_id_polar_plot = exemplary_excitatory_neuron_id
  141. # Mark the neuron that is shown in polar plot
  142. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  143. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  144. # Vertices for the plotted triangle
  145. tr_scale = 30.
  146. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  147. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  148. tr_vertices = np.array(
  149. [[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  150. height = 1 * panel_size
  151. width = 3 * panel_size
  152. fig = plt.figure(figsize=(width, height))
  153. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  154. cbar_mode="single",
  155. cbar_size="7%", cbar_pad=panel_size / 10.0,
  156. nrows_ncols=(1, 3))
  157. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  158. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  159. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  160. firing_rate_array = traj.results[run_name].firing_rate_array
  161. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  162. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  163. number_of_excitatory_neurons_per_row)),
  164. vmin=0, vmax=max_val, cmap='Reds')
  165. # ax.set_title(adjust_label(label))
  166. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  167. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  168. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  169. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  170. for spine in ax.spines.values():
  171. spine.set_edgecolor("grey")
  172. spine.set_linewidth(1)
  173. remove_ticks(ax)
  174. # fig.suptitle('spatial firing rate map', fontsize=16)
  175. ax.cax.colorbar(c)
  176. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  177. # fig.tight_layout()
  178. if save_figs:
  179. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_excitatory' + FIGURE_SAVE_FORMAT, dpi=300)
  180. plt.close(fig)
  181. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron):
  182. max_val = 0
  183. for run_name in plot_run_names:
  184. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  185. f_rates = fr_array[:, direction_idx]
  186. run_max_val = np.max(f_rates)
  187. if run_max_val > max_val:
  188. max_val = run_max_val
  189. n_id_polar_plot = selected_inhibitory_neuron
  190. # Mark the neuron that is shown in Polar plot
  191. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  192. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  193. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  194. width = 3 * panel_size
  195. height = panel_size
  196. fig = plt.figure(figsize=(width, height))
  197. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  198. cbar_mode="single",
  199. cbar_size="7%", cbar_pad=panel_size / 10.0,
  200. nrows_ncols=(1, 3))
  201. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  202. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  203. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  204. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  205. X, Y = get_position_mesh(inh_positions)
  206. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  207. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  208. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  209. number_of_inhibitory_neurons_per_row)),
  210. vmin=0, vmax=max_val, cmap='Blues')
  211. # ax.set_title(adjust_label(label))
  212. circle_r = 40.
  213. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  214. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  215. for spine in ax.spines.values():
  216. spine.set_edgecolor("grey")
  217. spine.set_linewidth(0.5)
  218. remove_ticks(ax)
  219. # fig.colorbar(c, ax=ax, label="f (Hz)")
  220. # fig.suptitle('spatial firing rate map', fontsize=16)
  221. cbar = ax.cax.colorbar(c)
  222. cbar_ticks = range(0,151,30)
  223. cbar.ax.set_yticks(cbar_ticks)
  224. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  225. # fig.tight_layout()
  226. if save_figs:
  227. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_inhibitory' + FIGURE_SAVE_FORMAT, dpi=300)
  228. plt.close(fig)
  229. return max_val
  230. def normal_labels(label):
  231. if label == POLARIZED:
  232. label = 'polarized'
  233. elif label == NO_SYNAPSES:
  234. label = 'no interneurons'
  235. return label
  236. def short_labels(label):
  237. if label == POLARIZED:
  238. label = 'pol.'
  239. elif label == CIRCULAR:
  240. label = 'cir.'
  241. elif label == "no conn":
  242. label = "no inh."
  243. return label
  244. def plot_input_map(traj, run_name, figsize=(panel_size, panel_size), figname='input_map' + FIGURE_SAVE_FORMAT):
  245. n_ex = int(np.sqrt(traj.N_E))
  246. width, height = figsize
  247. fig = plt.figure(figsize=(width, height))
  248. axes = ImageGrid(fig, (0.15, 0.1, 0.7, 0.8), axes_pad=panel_size / 3.0, cbar_location="right", cbar_mode=None,
  249. nrows_ncols=(1, 1))
  250. ax = axes[0]
  251. traj.f_set_crun(run_name)
  252. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  253. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  254. scale_length = traj.morphology.long_axis * 2
  255. traj.f_restore_default()
  256. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  257. for spine in ax.spines.values():
  258. spine.set_edgecolor("grey")
  259. spine.set_linewidth(0.5)
  260. remove_ticks(ax)
  261. start_scale_x = 100
  262. end_scale_x = start_scale_x + scale_length
  263. start_scale_y = -70
  264. end_scale_y = start_scale_y
  265. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  266. ax.cax.set_visible(False)
  267. if save_figs:
  268. plt.savefig(FIGURE_SAVE_PATH + figname)
  269. plt.close(fig)
  270. def plot_example_input_maps(traj, figsize=(panel_size, panel_size), figname='perlin_example_input_maps' + FIGURE_SAVE_FORMAT):
  271. n_ex = int(np.sqrt(traj.N_E))
  272. width, height = figsize
  273. fig = plt.figure(figsize=(width, height))
  274. axes = ImageGrid(fig, (0.1, 0.1, 0.85, 0.85), axes_pad=panel_size / 12.0, nrows_ncols=(3, 3))
  275. fit_corr_len_list = [20, 50, 100]
  276. seed_list = [1, 2, 3]
  277. corr_and_seed = itertools.product(fit_corr_len_list, seed_list)
  278. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='perlin_map', load=True)
  279. for idx, (ax, (fit_corr_len, seed)) in enumerate(zip(axes, corr_and_seed)):
  280. corresp_scale = get_closest_fit_correlation_length(traj, fit_corr_len)
  281. par_dict = {'seed': seed, 'correlation_length': corresp_scale}
  282. run_name = filter_run_names_by_par_dict(traj, par_dict)[0]
  283. traj.f_set_crun(run_name)
  284. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  285. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  286. scale_length = traj.morphology.long_axis * 2
  287. traj.f_restore_default()
  288. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  289. for spine in ax.spines.values():
  290. spine.set_edgecolor("grey")
  291. spine.set_linewidth(0.5)
  292. remove_ticks(ax)
  293. ax.set_ylabel('{:3.0f} um'.format(corr_len_fit_dict[corresp_scale]))
  294. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  295. plt_polar_id = 255
  296. plt_polar_coordinates = inhibitory_axonal_cloud_array[plt_polar_id]
  297. plt_polar_neuron = Interneuron(plt_polar_coordinates[0],
  298. plt_polar_coordinates[1],
  299. traj.morphology.long_axis,
  300. traj.morphology.short_axis,
  301. plt_polar_coordinates[2])
  302. circ_radius = np.sqrt(traj.morphology.long_axis * traj.morphology.short_axis)
  303. plt_circ_id = 65
  304. plt_circ_coordinates = inhibitory_axonal_cloud_array[plt_circ_id]
  305. plt_circ_neuron = Interneuron(plt_circ_coordinates[0],
  306. plt_circ_coordinates[1],
  307. circ_radius,
  308. circ_radius,
  309. plt_circ_coordinates[2])
  310. plt_interneurons = [plt_polar_neuron, plt_circ_neuron]
  311. for p in plt_interneurons:
  312. ell = p.get_ellipse()
  313. edgecolor = 'black'
  314. alpha = 1
  315. zorder = 10
  316. linewidth = 1.
  317. ell.set_edgecolor(edgecolor)
  318. ell.set_alpha(alpha)
  319. ell.set_zorder(zorder)
  320. ell.set_linewidth(linewidth)
  321. ax.add_artist(ell)
  322. if idx == 8:
  323. start_scale_x = 100
  324. end_scale_x = start_scale_x + scale_length
  325. start_scale_y = -70
  326. end_scale_y = start_scale_y
  327. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  328. ax.cax.set_visible(False)
  329. if save_figs:
  330. plt.savefig(FIGURE_SAVE_PATH + figname)
  331. plt.close(fig)
  332. def plot_axonal_clouds(traj, plot_run_names):
  333. n_ex = int(np.sqrt(traj.N_E))
  334. height = 1 * panel_size
  335. width = 3 * panel_size
  336. cluster_positions = [(250, 250), (750, 200), (450, 600)]
  337. cluster_sizes = [4, 4, 4]
  338. selected_neurons = []
  339. inhibitory_positions = traj.results.runs[plot_run_names[0]].inhibitory_axonal_cloud_array[:, :2]
  340. for cluster_position, number_of_neurons_in_cluster in zip(cluster_positions, cluster_sizes):
  341. selection = get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster,
  342. inhibitory_positions)
  343. selected_neurons.extend(selection)
  344. fig = plt.figure(figsize=(width, height))
  345. axes = ImageGrid(fig, 111, axes_pad=0.15, cbar_location="right", cbar_mode="single", cbar_size="7%",
  346. nrows_ncols=(1, 3))
  347. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  348. traj.f_set_crun(run_name)
  349. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  350. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  351. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  352. axonal_clouds = [Interneuron(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  353. inhibitory_axonal_cloud_array]
  354. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  355. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  356. # ax.set_title(normal_labels(label))
  357. ax.set_aspect('equal')
  358. remove_frame(ax)
  359. remove_ticks(ax)
  360. # fig.colorbar(c, ax=ax, label="Tuning")
  361. if label != NO_SYNAPSES and axonal_clouds is not None:
  362. for i, p in enumerate(axonal_clouds):
  363. ell = p.get_ellipse()
  364. if i in selected_neurons:
  365. edgecolor = 'black'
  366. alpha = 1
  367. zorder = 10
  368. linewidth = 1
  369. else:
  370. edgecolor = 'gray'
  371. alpha = 0.5
  372. zorder = 1
  373. linewidth = 0.3
  374. ell.set_edgecolor(edgecolor)
  375. ell.set_alpha(alpha)
  376. ell.set_zorder(zorder)
  377. ell.set_linewidth(linewidth)
  378. ax.add_artist(ell)
  379. traj.f_set_crun(plot_run_names[0])
  380. scale_length = traj.morphology.long_axis * 2
  381. traj.f_restore_default()
  382. start_scale_x = 100
  383. end_scale_x = start_scale_x + scale_length
  384. start_scale_y = -70
  385. end_scale_y = start_scale_y
  386. add_length_scale(axes[0], scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  387. axes[1].set_yticklabels([])
  388. axes[2].set_yticklabels([])
  389. axes[0].cax.set_visible(False)
  390. traj.f_restore_default()
  391. if save_figs:
  392. plt.savefig(FIGURE_SAVE_PATH + 'B_i_axonal_clouds' + FIGURE_SAVE_FORMAT)
  393. plt.close(fig)
  394. def get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster, positions):
  395. position = np.array(cluster_position)
  396. distance_vectors = positions - np.expand_dims(position, 0).repeat(positions.shape[0],
  397. axis=0)
  398. distances = np.linalg.norm(distance_vectors, axis=1)
  399. selection = list(np.argpartition(distances, number_of_neurons_in_cluster)[:number_of_neurons_in_cluster])
  400. return selection
  401. def plot_orientation_maps_diff_scales(traj):
  402. n_ex = int(np.sqrt(traj.N_E))
  403. scale_run_names = []
  404. plot_scales = [0.0, 100.0, 200.0, 300.0]
  405. for scale in plot_scales:
  406. par_dict = {'seed': 1, 'correlation_length': get_closest_scale(traj, scale), 'long_axis': 100.}
  407. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  408. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  409. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  410. traj.f_set_crun(run_name)
  411. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  412. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  413. # TODO: Why was this transposed for plotting? (now changed)
  414. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  415. ax.set_title('Correlation length: {}'.format(scale))
  416. fig.colorbar(c, ax=ax, label="Tuning")
  417. # fig.suptitle('axonal cloud', fontsize=16)
  418. traj.f_restore_default()
  419. if save_figs:
  420. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales' + FIGURE_SAVE_FORMAT)
  421. plt.close(fig)
  422. def plot_orientation_maps_diff_scales_with_ellipse(traj):
  423. n_ex = int(np.sqrt(traj.N_E))
  424. scale_run_names = []
  425. plot_scales = [0.0, 400.0, 800.0]
  426. real_scales = [get_closest_scale(traj, scale) for scale in plot_scales]
  427. for scale in plot_scales:
  428. par_dict = {'seed': 1, 'correlation_length': get_closest_scale(traj, scale), 'long_axis': 100.}
  429. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  430. inhibitory_positions = traj.results.runs[scale_run_names[1]].inhibitory_axonal_cloud_array[:, :2]
  431. selected_polar_neuron = get_neurons_close_to_given_position((300, 300), 1, inhibitory_positions)[0]
  432. selected_circular_neuron = get_neurons_close_to_given_position((600, 600), 1, inhibitory_positions)[0]
  433. width = panel_size
  434. height = 1.5 * panel_size
  435. fig = plt.figure(figsize=(width, height))
  436. axes = ImageGrid(fig, 111, axes_pad=0.15,
  437. nrows_ncols=(len(plot_scales), 1))
  438. for ax, run_name, scale in zip(axes, scale_run_names, real_scales):
  439. traj.f_set_crun(run_name)
  440. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  441. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  442. axonal_clouds = [Interneuron(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  443. inhibitory_axonal_cloud_array]
  444. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  445. # TODO: Why was this transposed for plotting? (now changed)
  446. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  447. # ax.set_title('Correlation length: {}'.format(scale))
  448. # fig.colorbar(c, ax=ax, label="Tuning")
  449. ax.set_xticks([])
  450. ax.set_yticks([])
  451. p1 = axonal_clouds[selected_polar_neuron]
  452. ell = p1.get_ellipse()
  453. ell._linewidth = 2.
  454. ax.add_artist(ell)
  455. p2 = axonal_clouds[selected_circular_neuron]
  456. circ_r = 2 * np.sqrt(p2.a * p2.b)
  457. circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
  458. circ._linewidth = 2.
  459. ax.add_artist(circ)
  460. ax.annotate("{:.0f}".format(scale), xy=(1.0, 0.5), xytext=(2, 0), xycoords="axes fraction", textcoords="offset "
  461. "points",
  462. va="center", ha="left")
  463. remove_frame(ax)
  464. # fig.suptitle('axonal cloud', fontsize=16)
  465. traj.f_restore_default()
  466. add_length_scale(axes[1], 200, -300, -100, 50, 50)
  467. axes[0].annotate("input maps", xy=(-0.5, 1.0), xycoords="axes fraction", rotation=90, ha="left", va="top")
  468. fig.tight_layout()
  469. if save_figs:
  470. plt.savefig(FIGURE_SAVE_PATH + 'F_orientation_maps_diff_scales_with_ellipse' + FIGURE_SAVE_FORMAT)
  471. plt.close(fig)
  472. def plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_idx,
  473. figname=FIGURE_SAVE_PATH + 'D_polar_plot_excitatory' + FIGURE_SAVE_FORMAT):
  474. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  475. directions_plt = list(directions)
  476. directions_plt.append(directions[0])
  477. height = panel_size
  478. width = panel_size
  479. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  480. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  481. # sorted_ids = np.argsort(head_direction_indices)
  482. # plot_n_idx = sorted_ids[-75]
  483. plot_n_idx = selected_neuron_idx
  484. line_styles = ['dotted', 'solid', 'dashed']
  485. colors = ['r', 'lightsalmon', 'grey']
  486. labels = ['pol. ', 'cir. ', 'no inh.']
  487. line_widths = [1.5, 1.5, 1]
  488. zorders = [10, 2, 1]
  489. max_rate = 0.0
  490. ax.plot([], [], color='white',label=' ')
  491. hdis = []
  492. for run_idx, run_name in enumerate(plot_run_names):
  493. # label = traj.derived_parameters.runs[run_name].morphology.morph_label
  494. label = labels[run_idx]
  495. hdi = traj.results.runs[run_name].head_direction_indices[selected_neuron_idx]
  496. hdis.append(hdi)
  497. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  498. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  499. run_max_rate = np.max(rate_plot)
  500. if run_max_rate > max_rate:
  501. max_rate = run_max_rate
  502. rate_plot.append(rate_plot[0])
  503. # plt_label = '{:s} {:.2f}'.format(short_labels(label), hdi)
  504. plt_label = '{:s}'.format(short_labels(label))
  505. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  506. label=plt_label, color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  507. # ax.set_title('Firing Rate')
  508. # ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
  509. # TODO: Set ticks for polar
  510. ticks = [40., 80.]
  511. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  512. enumerate(ticks)], angle=60)
  513. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  514. ax.xaxis.grid(linewidth=0.4)
  515. ax.yaxis.grid(linewidth=0.4)
  516. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.2), handlelength=1, fontsize="medium")
  517. leg.get_frame().set_linewidth(0.0)
  518. hdi_box_x, hdi_box_y = (0.86, -0.09)
  519. hdi_box_dy = 0.14
  520. hdi_box = ax.text(hdi_box_x, hdi_box_y + 3 * hdi_box_dy, 'HDI', fontsize='medium', transform=ax.transAxes,zorder=9.)
  521. hdi_box = ax.text(hdi_box_x, hdi_box_y + 2 * hdi_box_dy, '{:.2f}'.format(hdis[0]), fontsize='medium',transform=ax.transAxes, zorder=9.)
  522. hdi_box = ax.text(hdi_box_x, hdi_box_y + hdi_box_dy, '{:.2f}'.format(hdis[1]), fontsize='medium', transform=ax.transAxes,zorder=9.)
  523. hdi_box = ax.text(hdi_box_x, hdi_box_y, '{:.2f}'.format(hdis[2]), fontsize='medium', transform=ax.transAxes,zorder=9.)
  524. ax.axes.spines["polar"].set_visible(False)
  525. if save_figs:
  526. plt.savefig(figname)
  527. plt.close(fig)
  528. def plot_polar_plot_inhibitory(traj, plot_run_names, selected_neuron_idx, figname=FIGURE_SAVE_PATH +
  529. 'D_polar_plot_inhibitory' + FIGURE_SAVE_FORMAT):
  530. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  531. directions_plt = list(directions)
  532. directions_plt.append(directions[0])
  533. height = panel_size
  534. width = panel_size
  535. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  536. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  537. # sorted_ids = np.argsort(head_direction_indices)
  538. # plot_n_idx = sorted_ids[-75]
  539. plot_n_idx = selected_neuron_idx
  540. line_styles = ['dotted', 'solid']
  541. colors = ['b', 'lightblue']
  542. labels = ['pol. ', 'cir. ']
  543. line_widths = [1.5, 1.5]
  544. zorders = [10, 2]
  545. ax.plot([], [], color='white',label=' ')
  546. hdis = []
  547. for run_idx, run_name in enumerate(plot_run_names[:2]):
  548. # ax = axes[max_hdi_idx, run_idx]
  549. # label = traj.derived_parameters.runs[run_name].morphology.morph_label
  550. label = labels[run_idx]
  551. hdi = traj.results.runs[run_name].inh_head_direction_indices[selected_neuron_idx]
  552. hdis.append(hdi)
  553. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  554. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  555. rate_plot.append(rate_plot[0])
  556. # plt_label = '{:s} {:.2f}'.format(short_labels(label), hdi)
  557. plt_label = '{:s}'.format(short_labels(label))
  558. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  559. label=plt_label,
  560. color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  561. # ax.set_title('Inh. Firing Rate')
  562. # TODO: Set ticks for polar
  563. # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
  564. ticks = [40., 80., 120.]
  565. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  566. enumerate(ticks)], angle=60)
  567. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  568. ax.xaxis.grid(linewidth=0.4)
  569. ax.yaxis.grid(linewidth=0.4)
  570. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.15), handlelength=1, fontsize="medium")
  571. leg.get_frame().set_linewidth(0.0)
  572. hdi_box_x, hdi_box_y = (0.86, -0.04)
  573. hdi_box_dy = 0.14
  574. hdi_box = ax.text(hdi_box_x, hdi_box_y+2*hdi_box_dy, 'HDI', fontsize='medium', transform=ax.transAxes, zorder=9.)
  575. hdi_box = ax.text(hdi_box_x, hdi_box_y+hdi_box_dy, '{:.2f}'.format(hdis[0]), fontsize='medium', transform=ax.transAxes, zorder=9.)
  576. hdi_box = ax.text(hdi_box_x, hdi_box_y, '{:.2f}'.format(hdis[1]), fontsize='medium', transform=ax.transAxes, zorder=9.)
  577. ax.axes.spines["polar"].set_visible(False)
  578. if save_figs:
  579. plt.savefig(figname)
  580. plt.close(fig)
  581. def plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id, cut_off_dist):
  582. labels = []
  583. inh_hdis = []
  584. exc_hdis = []
  585. no_conn_hdi = 0.
  586. for run_idx, run_name in enumerate(plot_run_names):
  587. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  588. if label != NO_SYNAPSES:
  589. labels.append(normal_labels(label))
  590. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  591. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  592. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  593. (inh_axonal_cloud[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  594. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  595. (inh_axonal_cloud[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  596. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  597. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  598. ex_positions = traj.results.runs[run_name].ex_positions
  599. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  600. (ex_positions[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  601. (ex_positions[:, 1] >= cut_off_dist) & \
  602. (ex_positions[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  603. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  604. else:
  605. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  606. no_conn_hdi = np.mean(exc_head_direction_indices)
  607. # Look for a representative excitatory neuron
  608. hdi_mean_dict = {}
  609. excitatory_hdi_means = [np.mean(hdis) for hdis in exc_hdis]
  610. hdi_mean_dict["polar_exc"] = excitatory_hdi_means[0]
  611. hdi_mean_dict["circular_exc"] = excitatory_hdi_means[1]
  612. inhibitory_hdi_means = [np.mean(hdis) for hdis in inh_hdis]
  613. hdi_mean_dict["polar_inh"] = inhibitory_hdi_means[0]
  614. hdi_mean_dict["circular_inh"] = inhibitory_hdi_means[1]
  615. width = 2 * panel_size
  616. height = 1.2 * panel_size
  617. fig, axes = plt.subplots(2, 1, figsize=(width, height))
  618. plt.subplots_adjust(wspace=0, hspace=0.1)
  619. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  620. max_density = 0
  621. for i in range(2):
  622. # i = i + 1 # grid spec indexes from 0
  623. # ax = fig.add_subplot(gs[i])
  624. ax = axes[i]
  625. density_e, _, _ = ax.hist(exc_hdis[i], color='r', edgecolor='r', alpha=0.3, bins=bins, density=True)
  626. density_i, _, _ = ax.hist(inh_hdis[i], color='b', edgecolor='b', alpha=0.3, bins=bins, density=True)
  627. max_density = np.max([max_density, np.max(density_e), np.max(density_i)])
  628. ax.axvline(np.mean(exc_hdis[i]), color='r')
  629. ax.axvline(np.mean(inh_hdis[i]), color='b')
  630. ax.axvline(no_conn_hdi, color='dimgrey', linestyle='--', linewidth=1.5)
  631. ax.set_ylabel(labels[i], rotation='vertical')
  632. # plt.axis('on')
  633. if i == 0:
  634. ax.set_xticklabels([])
  635. else:
  636. ax.set_xlabel('HDI')
  637. remove_frame(ax, ["top", "right", "bottom"])
  638. max_density = 1.2 * max_density
  639. fig.subplots_adjust(left=0.2, right=0.95, bottom=0.2)
  640. axes[0].annotate('% cells', (0, 1.0), xycoords='axes fraction', va="bottom", ha="right")
  641. axes[1].annotate("no ihn.\n{:.2f}".format(no_conn_hdi), xy=(no_conn_hdi, max_density),
  642. xytext=(-2, 0), xycoords="data",
  643. textcoords="offset points",
  644. va="top", ha="right", color="dimgrey")
  645. for i, ax in enumerate(axes):
  646. ax.annotate("{:.2f}".format(np.mean(exc_hdis[i])), xy=(np.mean(exc_hdis[i]), max_density),
  647. xytext=(2, 0), xycoords="data",
  648. textcoords="offset points",
  649. va="top", ha="left", color="r")
  650. # i_ha = "left" if i == 1 else "right"
  651. # i_offset = 2 if i == 1 else -2
  652. i_ha = "right"
  653. i_offset = -1
  654. ax.annotate("{:.2f}".format(np.mean(inh_hdis[i])), xy=(np.mean(inh_hdis[i]), max_density),
  655. xytext=(i_offset, 0), xycoords="data",
  656. textcoords="offset points",
  657. va="top", ha=i_ha, color="b")
  658. for ax in axes:
  659. ax.set_ylim(0, max_density)
  660. # plt.annotate('probability density', (-0.2,1.5), xycoords='axes fraction', rotation=90, fontsize=18)
  661. if save_figs:
  662. plt.savefig(FIGURE_SAVE_PATH + 'E_hdi_histogram_combined_and_overlayed' + FIGURE_SAVE_FORMAT.format(cut_off_dist))
  663. plt.close(fig)
  664. return hdi_mean_dict
  665. def get_neurons_with_given_hdi(polar_hdi, circular_hdi, max_number_of_suggestions, plot_run_names, traj, type):
  666. polar_run_name = plot_run_names[0]
  667. circular_run_name = plot_run_names[1]
  668. polar_ex_hdis = traj.results.runs[polar_run_name].head_direction_indices if type == "ex" else traj.results.runs[
  669. polar_run_name].inh_head_direction_indices
  670. circular_ex_hdis = traj.results.runs[circular_run_name].head_direction_indices if type == "ex" else \
  671. traj.results.runs[
  672. polar_run_name].inh_head_direction_indices
  673. neuron_indices = get_indices_of_closest_values(polar_ex_hdis, polar_hdi,
  674. circular_ex_hdis,
  675. circular_hdi, 0.1 * np.abs(
  676. polar_hdi - circular_hdi), max_number_of_suggestions)
  677. return neuron_indices
  678. def get_indices_of_closest_values(first_list, first_value, second_list, second_value, absolute_tolerance_list_one,
  679. number_of_indices):
  680. is_close_in_list_one = np.abs(first_list - first_value) < absolute_tolerance_list_one
  681. indices_close_in_list_one = np.where(is_close_in_list_one)[0]
  682. indices_closest_in_list_two = indices_close_in_list_one[np.argpartition(np.abs(second_list[
  683. indices_close_in_list_one] - second_value),
  684. number_of_indices)]
  685. return indices_closest_in_list_two[:number_of_indices]
  686. def filter_run_names_by_par_dict(traj, par_dict):
  687. run_name_list = []
  688. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  689. traj.f_set_crun(run_name)
  690. paramters_equal = True
  691. for key, val in par_dict.items():
  692. if (traj.par[key] != val):
  693. paramters_equal = False
  694. if paramters_equal:
  695. run_name_list.append(run_name)
  696. traj.f_restore_default()
  697. return run_name_list
  698. def plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, plot_run_names, cut_off_dist):
  699. corr_len_expl = traj.f_get('scale').f_get_range()
  700. seed_expl = traj.f_get('seed').f_get_range()
  701. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  702. label_range = set(label_expl)
  703. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  704. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  705. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  706. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  707. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  708. ex_tunings = traj.results.runs[run_name].ex_tunings
  709. inh_hdis = []
  710. exc_hdis = []
  711. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  712. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  713. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  714. (inh_axonal_cloud[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  715. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  716. (inh_axonal_cloud[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  717. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  718. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  719. ex_positions = traj.results.runs[run_name].ex_positions
  720. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  721. (ex_positions[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  722. (ex_positions[:, 1] >= cut_off_dist) & \
  723. (ex_positions[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  724. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  725. exc_hdi_frame[corr_len, seed, label] = np.mean(exc_hdis)
  726. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_hdis)
  727. # TODO: Standard deviation also for the population
  728. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  729. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  730. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  731. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  732. markersize = 4.
  733. exc_style_dict = {
  734. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  735. POLARIZED: ['red', 'solid', '^', markersize],
  736. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  737. }
  738. inh_style_dict = {
  739. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  740. POLARIZED: ['blue', 'solid', 'o', markersize],
  741. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  742. }
  743. width = 2 * panel_size
  744. height = 1.2 * panel_size
  745. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  746. for label in sorted(label_range, reverse=True):
  747. if label == NO_SYNAPSES:
  748. no_conn_hdi = exc_hdi_n_and_seed_mean[get_closest_scale(traj, 200), label]
  749. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  750. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  751. textcoords="offset points",
  752. va="top", \
  753. ha="right",
  754. color="dimgrey")
  755. continue
  756. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  757. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  758. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  759. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  760. corr_len_range = exc_hdi_mean.keys().to_numpy()
  761. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  762. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  763. simplex_grid_scale = corr_len_range * np.sqrt(2)
  764. ax.plot(simplex_grid_scale, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  765. markersize=exc_mar_size, alpha=0.5)
  766. plt.fill_between(simplex_grid_scale, exc_hdi_mean - exc_hdi_std,
  767. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  768. ax.plot(simplex_grid_scale, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  769. markersize=inh_mar_size, alpha=0.5)
  770. plt.fill_between(simplex_grid_scale, inh_hdi_mean - inh_hdi_std,
  771. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  772. ax.set_xlabel('simplex grid scale')
  773. ax.set_ylabel('HDI')
  774. ax.axvline(get_closest_scale(traj, 200.0) * np.sqrt(2), color='k', linewidth=0.5, zorder=0)
  775. ax.set_ylim(0.0, 1.0)
  776. # ax.set_xlim(0.0, np.max(corr_len_range))
  777. remove_frame(ax, ["right", "top"])
  778. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  779. row_labels=None,
  780. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  781. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  782. fig.subplots_adjust(bottom=0.2, left=0.2)
  783. # plt.legend()
  784. if save_figs:
  785. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_grid_scale' + FIGURE_SAVE_FORMAT)
  786. plt.close(fig)
  787. def plot_exc_and_inh_hdi_over_fit_corr_len(traj, plot_run_names, cut_off_dist):
  788. corr_len_expl = traj.f_get('scale').f_get_range()
  789. seed_expl = traj.f_get('seed').f_get_range()
  790. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  791. label_range = set(label_expl)
  792. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  793. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  794. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  795. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  796. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  797. ex_tunings = traj.results.runs[run_name].ex_tunings
  798. inh_hdis = []
  799. exc_hdis = []
  800. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  801. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  802. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  803. (inh_axonal_cloud[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  804. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  805. (inh_axonal_cloud[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  806. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  807. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  808. ex_positions = traj.results.runs[run_name].ex_positions
  809. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  810. (ex_positions[:, 0] <= traj.parameters.input_map.sheet_size - cut_off_dist) & \
  811. (ex_positions[:, 1] >= cut_off_dist) & \
  812. (ex_positions[:, 1] <= traj.parameters.input_map.sheet_size - cut_off_dist)
  813. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  814. exc_hdi_frame[corr_len, seed, label] = np.mean(exc_hdis)
  815. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_hdis)
  816. # TODO: Standard deviation also for the population
  817. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  818. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  819. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  820. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  821. markersize = 4.
  822. exc_style_dict = {
  823. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  824. POLARIZED: ['red', 'solid', '^', markersize],
  825. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  826. }
  827. inh_style_dict = {
  828. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  829. POLARIZED: ['blue', 'solid', 'o', markersize],
  830. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  831. }
  832. # colors = ['blue', 'grey', 'lightblue']
  833. # linestyles = ['solid', 'dashed', 'solid']
  834. # markers = [verts, '', 'o']
  835. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='perlin_map', load=True)
  836. width = 2 * panel_size
  837. height = 1.2 * panel_size
  838. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  839. for label in sorted(label_range, reverse=True):
  840. if label == NO_SYNAPSES:
  841. no_conn_hdi = exc_hdi_n_and_seed_mean[get_closest_scale(traj, 200), label]
  842. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  843. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  844. textcoords="offset points",
  845. va="top", \
  846. ha="right",
  847. color="dimgrey")
  848. continue
  849. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  850. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  851. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  852. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  853. corr_len_range = exc_hdi_mean.keys().to_numpy()
  854. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  855. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  856. fit_corr_len = [corr_len_fit_dict[corr_len] for corr_len in corr_len_range]
  857. ax.plot(fit_corr_len, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  858. markersize=exc_mar_size, alpha=0.5)
  859. plt.fill_between(fit_corr_len, exc_hdi_mean - exc_hdi_std,
  860. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  861. ax.plot(fit_corr_len, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  862. markersize=inh_mar_size, alpha=0.5)
  863. plt.fill_between(fit_corr_len, inh_hdi_mean - inh_hdi_std,
  864. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  865. ax.set_xlabel('correlation length (um)')
  866. ax.set_ylabel('HDI')
  867. ax.axvline(corr_len_fit_dict[get_closest_scale(traj, 200.0)], color='k', linewidth=0.5, zorder=0)
  868. ax.set_ylim(0.0, 1.0)
  869. # ax.set_xlim(0.0, np.max(corr_len_range))
  870. remove_frame(ax, ["right", "top"])
  871. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  872. row_labels=None,
  873. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  874. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  875. fig.subplots_adjust(bottom=0.2, left=0.2)
  876. # plt.legend()
  877. if save_figs:
  878. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_corr_len_scaled' + FIGURE_SAVE_FORMAT)
  879. plt.close(fig)
  880. def get_phase_difference(total_difference):
  881. """
  882. Map accumulated phase difference to shortest possible difference
  883. :param total_difference:
  884. :return: relative_difference
  885. """
  886. return (total_difference + np.pi) % (2 * np.pi) - np.pi
  887. def plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names, figsize=(9, 9)):
  888. # The plot that Imre wanted
  889. n_bins = traj.parameters.input.number_of_directions
  890. fig, ax = plt.subplots(1, 1, figsize=figsize)
  891. dir_bins = np.linspace(-np.pi, np.pi, n_bins + 1)
  892. plot_fr_array = []
  893. labels = []
  894. similarity_threshold = np.pi / 6.
  895. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  896. for run_idx, run_name in enumerate(plot_run_names):
  897. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  898. labels.append(short_labels(label))
  899. fr_similar_tunings = []
  900. fr_different_tunings = []
  901. ex_tunings = traj.results.runs[run_name].ex_tunings
  902. firing_rate_array = traj.results[run_name].firing_rate_array
  903. for tuning, firing_rates in zip(ex_tunings, firing_rate_array):
  904. for idx, dir in enumerate(directions):
  905. if np.abs(get_phase_difference(tuning - dir)) <= similarity_threshold:
  906. fr_similar_tunings.append(firing_rates[idx])
  907. elif np.abs(get_phase_difference(tuning + np.pi - dir)) <= similarity_threshold:
  908. fr_different_tunings.append(firing_rates[idx])
  909. plot_fr_array.append([np.mean(fr_similar_tunings), np.mean(fr_different_tunings)])
  910. x = np.arange(3) # the label locations
  911. width = 0.35 # the width of the bars
  912. plot_fr_array = np.array(plot_fr_array)
  913. # rects1 = ax.bar(x - width / 2, plot_fr_array[:, 0], width,
  914. # label=r'$\theta_{pref} \pm 30°$')
  915. # rects2 = ax.bar(x + width / 2, plot_fr_array[:, 1], width,
  916. # label=r'$\theta_{opp} \pm 30°$')
  917. rects1 = ax.bar(x - width / 2, plot_fr_array[:, 0], width,
  918. label=r'sim.')
  919. rects2 = ax.bar(x + width / 2, plot_fr_array[:, 1], width,
  920. label=r'diff.')
  921. ax.set_xticks(x)
  922. ax.set_xticklabels(labels)
  923. ax.spines['right'].set_visible(False)
  924. ax.spines['top'].set_visible(False)
  925. # ax.set_title('Mean firing rate for tunings similar and different to input')
  926. # ax.set_ylabel('Mean firing rate')
  927. ax.annotate(r'$\overline{\mathrm{fr}}$ (Hz)', (0.05, 1.0), xycoords='axes fraction', va="bottom", ha="right")
  928. leg = ax.legend(loc="upper left", bbox_to_anchor=(0.0, 1.2), handlelength=1, fontsize="medium")
  929. leg.get_frame().set_linewidth(0.0)
  930. def autolabel(rects):
  931. """Attach a text label above each bar in *rects*, displaying its height."""
  932. for rect in rects:
  933. height = rect.get_height()
  934. ax.annotate('{}'.format(np.round(height)),
  935. xy=(rect.get_x() + rect.get_width() / 2, height),
  936. xytext=(0, 3), # 3 points vertical offset
  937. textcoords="offset points",
  938. ha='center', va='bottom')
  939. autolabel(rects1)
  940. autolabel(rects2)
  941. fig.tight_layout()
  942. if save_figs:
  943. plt.savefig(FIGURE_SAVE_PATH + 'SUPPLEMENT_B_firing_rate_similar_vs_diff_tuning' + FIGURE_SAVE_FORMAT, dpi=200)
  944. plt.close(fig)
  945. def get_firing_rates_along_preferred_axis(traj, run_name, neuron_idx):
  946. firing_rates = traj.results[run_name].firing_rate_array[neuron_idx, :]
  947. tuning = traj.results[run_name].ex_tunings[neuron_idx]
  948. anti_tuning = tuning + np.pi if tuning + np.pi < np.pi else tuning - np.pi
  949. tuning_idx = np.argmin(np.abs(directions - tuning))
  950. anti_tuning_idx = np.argmin(np.abs(directions - anti_tuning))
  951. firing_at_the_preferred_direction = firing_rates[tuning_idx]
  952. firing_at_the_opposite_direction = firing_rates[anti_tuning_idx]
  953. return firing_at_the_preferred_direction, firing_at_the_opposite_direction
  954. def get_hdi(traj, run_name, neuron_idx, type):
  955. return traj.results.runs[run_name].head_direction_indices[neuron_idx] if type=="ex" else traj.results.runs[
  956. run_name].inh_head_direction_indices[neuron_idx]
  957. def plot_colorbar(figsize=(2, 2), figname=None):
  958. azimuth_no = 360
  959. zenith_no = 15
  960. azimuths = np.linspace(-180, 180, azimuth_no)
  961. zeniths = np.linspace(0.85, 1, zenith_no)
  962. values = azimuths * np.ones((zenith_no, azimuth_no))
  963. fig, ax = plt.subplots(subplot_kw=dict(projection='polar'), figsize=figsize)
  964. ax.pcolormesh(azimuths * np.pi / 180.0, zeniths, values, cmap=head_direction_input_colormap)
  965. # ax.set_yticks([])
  966. ax.set_thetagrids([0, 90, 180, 270])
  967. ax.tick_params(pad=-2)
  968. ax.set_ylim(0, 1)
  969. ax.grid(True)
  970. y_tick_labels = []
  971. ax.set_yticklabels(y_tick_labels)
  972. gridlines = ax.yaxis.get_gridlines()
  973. [line.set_linewidth(0.0) for line in gridlines]
  974. gridlines = ax.xaxis.get_gridlines()
  975. [line.set_linewidth(0.0) for line in gridlines]
  976. ax.axes.spines["polar"].set_visible(False)
  977. plt.subplots_adjust(left=0.25, right=0.75, bottom=0.25, top=0.75)
  978. # plt.show()
  979. if figname is not None:
  980. plt.savefig(figname, transparent=True)
  981. plt.close(fig)
  982. if __name__ == "__main__":
  983. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  984. NO_LOADING = 0
  985. FULL_LOAD = 2
  986. traj.f_load(filename=os.path.join(DATA_FOLDER, TRAJ_NAME + ".hdf5"), load_parameters=FULL_LOAD,
  987. load_results=NO_LOADING)
  988. traj.v_auto_load = True
  989. save_figs = True
  990. print("# Plotting script polarized interneurons")
  991. print()
  992. map_length_scale = 200.0
  993. map_seed = 1
  994. exemplary_head_direction = 0
  995. print("## Map specifications")
  996. print("\tinput map scale: {:.1f} um".format(map_length_scale))
  997. print("\tmap seed: {:d}".format(map_seed))
  998. print()
  999. print("## Input specification")
  1000. print("\tselected head direction: {:.0f}°".format(exemplary_head_direction))
  1001. print()
  1002. print("## Selected simulations")
  1003. plot_scale = get_closest_scale(traj, map_length_scale)
  1004. par_dict = {'seed': map_seed, 'scale': plot_scale}
  1005. plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
  1006. run_name_dict = {}
  1007. for run_name in plot_run_names:
  1008. traj.f_set_crun(run_name)
  1009. run_name_dict[traj.derived_parameters.runs[run_name].morphology.morph_label] = run_name
  1010. for network_type, run_name in run_name_dict.items():
  1011. print("{:s}: {:s}".format(network_type, run_name))
  1012. directions = get_input_head_directions(traj)
  1013. direction_idx = np.argmin(np.abs(np.array(directions) - np.deg2rad(exemplary_head_direction)))
  1014. selected_neuron_excitatory = 1052
  1015. selected_inhibitory_neuron = 28
  1016. print("## Figure specification")
  1017. print("\tpanel size: {:.2f} cm".format(panel_size * cm_per_inch))
  1018. print()
  1019. plot_colorbar(figsize=(0.8 * panel_size, 0.8 * panel_size), figname=FIGURE_SAVE_PATH + "A_i_colormap.svg")
  1020. plot_input_map(traj, run_name_dict[POLARIZED], figname="A_i_exemplary_input_map"+FIGURE_SAVE_FORMAT,
  1021. figsize=(panel_size, panel_size))
  1022. # plot_example_input_maps(traj, figsize=(2 * panel_size, 2 * panel_size))
  1023. plot_axonal_clouds(traj, plot_run_names)
  1024. #
  1025. plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, selected_neuron_excitatory)
  1026. in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron)
  1027. #
  1028. hdi_means = plot_hdi_histogram_combined_and_overlayed(
  1029. traj, plot_run_names,
  1030. selected_neuron_excitatory,
  1031. selected_inhibitory_neuron,
  1032. cut_off_dist=100.)
  1033. #
  1034. plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_excitatory)
  1035. plot_polar_plot_inhibitory(traj, plot_run_names, selected_inhibitory_neuron)
  1036. plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names, figsize=(1.2*panel_size, 1.2*panel_size))
  1037. plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, traj.f_get_run_names(), cut_off_dist=100.)
  1038. plot_exc_and_inh_hdi_over_fit_corr_len(traj, traj.f_get_run_names(), cut_off_dist=100.)
  1039. if not save_figs:
  1040. plt.show()
  1041. traj.f_restore_default()