paper_figures_orientation_map.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767
  1. import itertools
  2. import os
  3. import matplotlib.legend as mlegend
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import pandas as pd
  7. from matplotlib.patches import Ellipse
  8. from matplotlib.patches import Polygon
  9. from matplotlib.patches import Rectangle
  10. from mpl_toolkits.axes_grid1 import ImageGrid
  11. from pypet import Trajectory
  12. from pypet.brian2 import Brian2MonitorResult
  13. from scripts.spatial_network import get_position_mesh, Pickle, get_correct_position_mesh
  14. from scripts.spatial_maps.correlation_length_fit.correlation_length_fit import \
  15. correlation_length_fit_dict
  16. from scripts.spatial_network.figures_spatial_head_direction_network_orientation_map import plot_hdi_in_space
  17. from scripts.model_figure.figure_utils import remove_frame, remove_ticks, add_length_scale, cm_per_inch, \
  18. panel_size, head_direction_input_colormap
  19. from scripts.spatial_network.supplement_pinwheel_map.run_simulation_pinwheel_map import DATA_FOLDER, TRAJ_NAME, \
  20. get_input_head_directions, POLARIZED, CIRCULAR, NO_SYNAPSES
  21. plt.style.use('figures.mplstyle')
  22. FIGURE_SAVE_PATH = '../../../figures/supplementary_orientation_map/'
  23. save_figs = False
  24. def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
  25. """
  26. Place a table legend on the axes.
  27. Creates a legend where the labels are not directly placed with the artists,
  28. but are used as row and column headers, looking like this:
  29. title_label | col_labels[1] | col_labels[2] | col_labels[3]
  30. -------------------------------------------------------------
  31. row_labels[1] |
  32. row_labels[2] | <artists go there>
  33. row_labels[3] |
  34. Parameters
  35. ----------
  36. ax : `matplotlib.axes.Axes`
  37. The artist that contains the legend table, i.e. current axes instant.
  38. col_labels : list of str, optional
  39. A list of labels to be used as column headers in the legend table.
  40. `len(col_labels)` needs to match `ncol`.
  41. row_labels : list of str, optional
  42. A list of labels to be used as row headers in the legend table.
  43. `len(row_labels)` needs to match `len(handles) // ncol`.
  44. title_label : str, optional
  45. Label for the top left corner in the legend table.
  46. ncol : int
  47. Number of columns.
  48. Other Parameters
  49. ----------------
  50. Refer to `matplotlib.legend.Legend` for other parameters.
  51. """
  52. #################### same as `matplotlib.axes.Axes.legend` #####################
  53. handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
  54. if len(extra_args):
  55. raise TypeError('legend only accepts two non-keyword arguments')
  56. if col_labels is None and row_labels is None:
  57. ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
  58. ax.legend_._remove_method = ax._remove_legend
  59. return ax.legend_
  60. #################### modifications for table legend ############################
  61. else:
  62. ncol = kwargs.pop('ncol')
  63. handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
  64. title_label = [title_label]
  65. # blank rectangle handle
  66. extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
  67. # empty label
  68. empty = [""]
  69. # number of rows infered from number of handles and desired number of columns
  70. nrow = len(handles) // ncol
  71. # organise the list of handles and labels for table construction
  72. if col_labels is None:
  73. assert nrow == len(
  74. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  75. nrow, len(row_labels))
  76. leg_handles = extra * nrow
  77. leg_labels = row_labels
  78. elif row_labels is None:
  79. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  80. ncol, len(col_labels))
  81. leg_handles = []
  82. leg_labels = []
  83. else:
  84. assert nrow == len(
  85. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  86. nrow, len(row_labels))
  87. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  88. ncol, len(col_labels))
  89. leg_handles = extra + extra * nrow
  90. leg_labels = title_label + row_labels
  91. for col in range(ncol):
  92. if col_labels is not None:
  93. leg_handles += extra
  94. leg_labels += [col_labels[col]]
  95. leg_handles += handles[col * nrow:(col + 1) * nrow]
  96. leg_labels += empty * nrow
  97. # Create legend
  98. ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol + int(row_labels is not None),
  99. handletextpad=handletextpad, **kwargs)
  100. ax.legend_._remove_method = ax._remove_legend
  101. return ax.legend_
  102. def get_closest_correlation_length(traj, correlation_length):
  103. available_lengths = sorted(list(set(traj.f_get("correlation_length").f_get_range())))
  104. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
  105. if closest_length != correlation_length:
  106. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  107. correlation_length, closest_length))
  108. corr_len = closest_length
  109. return corr_len
  110. def get_closest_fit_correlation_length(traj, fit_corr_len):
  111. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='pinwheel', load=True)
  112. available_lengths = list(corr_len_fit_dict.values())
  113. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - fit_corr_len))]
  114. closest_corresponding_scale = list(corr_len_fit_dict.keys())[list(corr_len_fit_dict.values()).index(closest_length)]
  115. if closest_length != fit_corr_len:
  116. print("Warning: desired fit correlation length {:.1f} not available. Taking {:.1f} instead".format(
  117. fit_corr_len, closest_length))
  118. return closest_corresponding_scale
  119. def colorbar(mappable):
  120. from mpl_toolkits.axes_grid1 import make_axes_locatable
  121. import matplotlib.pyplot as plt
  122. last_axes = plt.gca()
  123. ax = mappable.axes
  124. fig = ax.figure
  125. divider = make_axes_locatable(ax)
  126. cax = divider.append_axes("right", size="5%", pad=0.05)
  127. cbar = fig.colorbar(mappable, cax=cax)
  128. plt.sca(last_axes)
  129. return cbar
  130. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, exemplary_excitatory_neuron_id):
  131. max_val = 0
  132. for run_name in plot_run_names:
  133. fr_array = traj.results.runs[run_name].firing_rate_array
  134. f_rates = fr_array[:, direction_idx]
  135. run_max_val = np.max(f_rates)
  136. if run_max_val > max_val:
  137. # if traj.derived_parameters.runs[run_name].morphology.morph_label == POLARIZED:
  138. # n_id_max_rate = np.argmax(f_rates)
  139. max_val = run_max_val
  140. n_id_polar_plot = exemplary_excitatory_neuron_id
  141. # Mark the neuron that is shown in polar plot
  142. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  143. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  144. # Vertices for the plotted triangle
  145. tr_scale = 30.
  146. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  147. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  148. tr_vertices = np.array(
  149. [[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  150. height = 1 * panel_size
  151. width = 3 * panel_size
  152. fig = plt.figure(figsize=(width, height))
  153. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  154. cbar_mode="single",
  155. cbar_size="7%", cbar_pad=panel_size / 10.0,
  156. nrows_ncols=(1, 3))
  157. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  158. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  159. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  160. firing_rate_array = traj.results[run_name].firing_rate_array
  161. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  162. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  163. number_of_excitatory_neurons_per_row)),
  164. vmin=0, vmax=max_val, cmap='Reds')
  165. # ax.set_title(adjust_label(label))
  166. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  167. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  168. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  169. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  170. for spine in ax.spines.values():
  171. spine.set_edgecolor("grey")
  172. spine.set_linewidth(1)
  173. remove_ticks(ax)
  174. # fig.suptitle('spatial firing rate map', fontsize=16)
  175. ax.cax.colorbar(c)
  176. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  177. # fig.tight_layout()
  178. if save_figs:
  179. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_excitatory.png', dpi=300)
  180. plt.close(fig)
  181. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron):
  182. max_val = 0
  183. for run_name in plot_run_names:
  184. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  185. f_rates = fr_array[:, direction_idx]
  186. run_max_val = np.max(f_rates)
  187. if run_max_val > max_val:
  188. max_val = run_max_val
  189. n_id_polar_plot = selected_inhibitory_neuron
  190. # Mark the neuron that is shown in Polar plot
  191. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  192. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  193. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  194. width = 3 * panel_size
  195. height = panel_size
  196. fig = plt.figure(figsize=(width, height))
  197. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  198. cbar_mode="single",
  199. cbar_size="7%", cbar_pad=panel_size / 10.0,
  200. nrows_ncols=(1, 3))
  201. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  202. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  203. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  204. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  205. X, Y = get_correct_position_mesh(inh_positions)
  206. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  207. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  208. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  209. number_of_inhibitory_neurons_per_row)),
  210. vmin=0, vmax=max_val, cmap='Blues')
  211. # ax.set_title(adjust_label(label))
  212. circle_r = 40.
  213. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  214. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  215. for spine in ax.spines.values():
  216. spine.set_edgecolor("grey")
  217. spine.set_linewidth(0.5)
  218. remove_ticks(ax)
  219. # fig.colorbar(c, ax=ax, label="f (Hz)")
  220. # fig.suptitle('spatial firing rate map', fontsize=16)
  221. ax.cax.colorbar(c)
  222. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  223. # fig.tight_layout()
  224. if save_figs:
  225. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_inhibitory.png', dpi=300)
  226. plt.close(fig)
  227. return max_val
  228. def plot_hdi_over_tuning(traj, plot_run_names):
  229. fig, ax = plt.subplots(1, 1)
  230. for run_idx, run_name in enumerate(plot_run_names):
  231. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  232. ex_tunings = traj.results.runs[run_name].ex_tunings
  233. ex_tunings_plt = np.array(ex_tunings)
  234. sort_ids = ex_tunings_plt.argsort()
  235. ex_tunings_plt = ex_tunings_plt[sort_ids]
  236. head_direction_indices = traj.results[run_name].head_direction_indices
  237. hdi_plt = head_direction_indices
  238. hdi_plt = hdi_plt[sort_ids]
  239. ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
  240. ax.legend()
  241. ax.set_xlabel("Angles (rad)")
  242. ax.set_ylabel("head direction index")
  243. ax.set_title('hdi over input tuning', fontsize=16)
  244. if save_figs:
  245. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png')
  246. plt.close(fig)
  247. def normal_labels(label):
  248. if label == POLARIZED:
  249. label = 'polar'
  250. elif label == NO_SYNAPSES:
  251. label = 'no interneurons'
  252. return label
  253. def short_labels(label):
  254. if label == POLARIZED:
  255. label = 'polar'
  256. elif label == CIRCULAR:
  257. label = 'circ.'
  258. elif label == "no conn":
  259. label = "no inh."
  260. return label
  261. def plot_input_map(traj, run_name, figsize=(panel_size, panel_size), figname='input_map.png'):
  262. n_ex = int(np.sqrt(traj.N_E))
  263. width, height = figsize
  264. fig = plt.figure(figsize=(width, height))
  265. axes = ImageGrid(fig, (0.15, 0.1, 0.7, 0.8), axes_pad=panel_size / 3.0, cbar_location="right", cbar_mode=None,
  266. nrows_ncols=(1, 1))
  267. ax = axes[0]
  268. traj.f_set_crun(run_name)
  269. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  270. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  271. scale_length = traj.morphology.long_axis * 2
  272. traj.f_restore_default()
  273. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  274. for spine in ax.spines.values():
  275. spine.set_edgecolor("grey")
  276. spine.set_linewidth(0.5)
  277. remove_ticks(ax)
  278. start_scale_x = 100
  279. end_scale_x = start_scale_x + scale_length
  280. start_scale_y = -70
  281. end_scale_y = start_scale_y
  282. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  283. ax.cax.set_visible(False)
  284. if save_figs:
  285. plt.savefig(FIGURE_SAVE_PATH + figname)
  286. plt.close(fig)
  287. def plot_example_input_maps(traj, figsize=(panel_size, panel_size), figname='pinwheel_example_input_maps.png'):
  288. n_ex = int(np.sqrt(traj.N_E))
  289. width, height = figsize
  290. fig = plt.figure(figsize=(width, height))
  291. axes = ImageGrid(fig, (0.1, 0.1, 0.85, 0.85), axes_pad=panel_size / 12.0, nrows_ncols=(3, 3))
  292. fit_corr_len_list = [20, 50, 100]
  293. seed_list = [1, 2, 3]
  294. corr_and_seed = itertools.product(fit_corr_len_list, seed_list)
  295. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='pinwheel', load=True)
  296. # for corr_len_idx, corr_len in enumerate(corr_len_list):
  297. # for seed_idx, seed in enumerate(seed_list):
  298. for idx, (ax, (fit_corr_len, seed)) in enumerate(zip(axes, corr_and_seed)):
  299. # ax = axes[corr_len_idx + seed_idx]
  300. corresp_scale = get_closest_fit_correlation_length(traj, fit_corr_len)
  301. par_dict = {'seed': seed, 'correlation_length': corresp_scale}
  302. run_name = filter_run_names_by_par_dict(traj, par_dict)[0]
  303. print(run_name)
  304. traj.f_set_crun(run_name)
  305. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  306. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  307. scale_length = traj.morphology.long_axis * 2
  308. traj.f_restore_default()
  309. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  310. for spine in ax.spines.values():
  311. spine.set_edgecolor("grey")
  312. spine.set_linewidth(0.5)
  313. remove_ticks(ax)
  314. ax.set_ylabel('{:3.0f} um'.format(corr_len_fit_dict[corresp_scale]))
  315. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  316. plt_polar_id = 255
  317. plt_polar_coordinates = inhibitory_axonal_cloud_array[plt_polar_id]
  318. plt_polar_neuron = Pickle(plt_polar_coordinates[0],
  319. plt_polar_coordinates[1],
  320. traj.morphology.long_axis,
  321. traj.morphology.short_axis,
  322. plt_polar_coordinates[2])
  323. circ_radius = np.sqrt(traj.morphology.long_axis * traj.morphology.short_axis)
  324. plt_circ_id = 65
  325. plt_circ_coordinates = inhibitory_axonal_cloud_array[plt_circ_id]
  326. plt_circ_neuron = Pickle(plt_circ_coordinates[0],
  327. plt_circ_coordinates[1],
  328. circ_radius,
  329. circ_radius,
  330. plt_circ_coordinates[2])
  331. plt_interneurons = [plt_polar_neuron, plt_circ_neuron]
  332. for p in plt_interneurons:
  333. ell = p.get_ellipse()
  334. edgecolor = 'black'
  335. alpha = 1
  336. zorder = 10
  337. linewidth = 1.
  338. ell.set_edgecolor(edgecolor)
  339. ell.set_alpha(alpha)
  340. ell.set_zorder(zorder)
  341. ell.set_linewidth(linewidth)
  342. ax.add_artist(ell)
  343. if idx == 8:
  344. start_scale_x = 100
  345. end_scale_x = start_scale_x + scale_length
  346. start_scale_y = -70
  347. end_scale_y = start_scale_y
  348. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  349. ax.cax.set_visible(False)
  350. if save_figs:
  351. plt.savefig(FIGURE_SAVE_PATH + figname)
  352. plt.close(fig)
  353. def plot_axonal_clouds(traj, plot_run_names):
  354. n_ex = int(np.sqrt(traj.N_E))
  355. height = 1 * panel_size
  356. width = 3 * panel_size
  357. cluster_positions = [(250, 250), (750, 200), (450, 600)]
  358. cluster_sizes = [4, 4, 4]
  359. selected_neurons = []
  360. inhibitory_positions = traj.results.runs[plot_run_names[0]].inhibitory_axonal_cloud_array[:, :2]
  361. for cluster_position, number_of_neurons_in_cluster in zip(cluster_positions, cluster_sizes):
  362. selection = get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster,
  363. inhibitory_positions)
  364. selected_neurons.extend(selection)
  365. fig = plt.figure(figsize=(width, height))
  366. axes = ImageGrid(fig, 111, axes_pad=0.15, cbar_location="right", cbar_mode="single", cbar_size="7%",
  367. nrows_ncols=(1, 3))
  368. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  369. traj.f_set_crun(run_name)
  370. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  371. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  372. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  373. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  374. inhibitory_axonal_cloud_array]
  375. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  376. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  377. # ax.set_title(normal_labels(label))
  378. ax.set_aspect('equal')
  379. remove_frame(ax)
  380. remove_ticks(ax)
  381. # fig.colorbar(c, ax=ax, label="Tuning")
  382. if label != NO_SYNAPSES and axonal_clouds is not None:
  383. for i, p in enumerate(axonal_clouds):
  384. ell = p.get_ellipse()
  385. if i in selected_neurons:
  386. edgecolor = 'black'
  387. alpha = 1
  388. zorder = 10
  389. linewidth = 1
  390. else:
  391. edgecolor = 'gray'
  392. alpha = 0.5
  393. zorder = 1
  394. linewidth = 0.3
  395. ell.set_edgecolor(edgecolor)
  396. ell.set_alpha(alpha)
  397. ell.set_zorder(zorder)
  398. ell.set_linewidth(linewidth)
  399. ax.add_artist(ell)
  400. traj.f_set_crun(plot_run_names[0])
  401. scale_length = traj.morphology.long_axis * 2
  402. traj.f_restore_default()
  403. start_scale_x = 100
  404. end_scale_x = start_scale_x + scale_length
  405. start_scale_y = -70
  406. end_scale_y = start_scale_y
  407. add_length_scale(axes[0], scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  408. axes[1].set_yticklabels([])
  409. axes[2].set_yticklabels([])
  410. axes[0].cax.set_visible(False)
  411. traj.f_restore_default()
  412. if save_figs:
  413. plt.savefig(FIGURE_SAVE_PATH + 'B_i_axonal_clouds.png')
  414. plt.close(fig)
  415. def get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster, positions):
  416. position = np.array(cluster_position)
  417. distance_vectors = positions - np.expand_dims(position, 0).repeat(positions.shape[0],
  418. axis=0)
  419. distances = np.linalg.norm(distance_vectors, axis=1)
  420. selection = list(np.argpartition(distances, number_of_neurons_in_cluster)[:number_of_neurons_in_cluster])
  421. return selection
  422. def plot_orientation_maps_diff_scales(traj):
  423. n_ex = int(np.sqrt(traj.N_E))
  424. scale_run_names = []
  425. plot_scales = [0.0, 100.0, 200.0, 300.0]
  426. for scale in plot_scales:
  427. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj, scale), 'long_axis': 100.}
  428. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  429. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  430. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  431. traj.f_set_crun(run_name)
  432. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  433. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  434. # TODO: Why was this transposed for plotting? (now changed)
  435. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  436. ax.set_title('Correlation length: {}'.format(scale))
  437. fig.colorbar(c, ax=ax, label="Tuning")
  438. # fig.suptitle('axonal cloud', fontsize=16)
  439. traj.f_restore_default()
  440. if save_figs:
  441. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png')
  442. plt.close(fig)
  443. def plot_orientation_maps_diff_scales_with_ellipse(traj):
  444. n_ex = int(np.sqrt(traj.N_E))
  445. scale_run_names = []
  446. plot_scales = [0.0, 400.0, 800.0]
  447. real_scales = [get_closest_correlation_length(traj, scale) for scale in plot_scales]
  448. for scale in plot_scales:
  449. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj, scale), 'long_axis': 100.}
  450. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  451. print(scale_run_names)
  452. inhibitory_positions = traj.results.runs[scale_run_names[1]].inhibitory_axonal_cloud_array[:, :2]
  453. selected_polar_neuron = get_neurons_close_to_given_position((300, 300), 1, inhibitory_positions)[0]
  454. print(selected_polar_neuron)
  455. selected_circular_neuron = get_neurons_close_to_given_position((600, 600), 1, inhibitory_positions)[0]
  456. width = panel_size
  457. height = 1.5 * panel_size
  458. fig = plt.figure(figsize=(width, height))
  459. axes = ImageGrid(fig, 111, axes_pad=0.15,
  460. nrows_ncols=(len(plot_scales), 1))
  461. for ax, run_name, scale in zip(axes, scale_run_names, real_scales):
  462. traj.f_set_crun(run_name)
  463. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  464. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  465. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  466. inhibitory_axonal_cloud_array]
  467. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  468. # TODO: Why was this transposed for plotting? (now changed)
  469. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  470. # ax.set_title('Correlation length: {}'.format(scale))
  471. # fig.colorbar(c, ax=ax, label="Tuning")
  472. ax.set_xticks([])
  473. ax.set_yticks([])
  474. p1 = axonal_clouds[selected_polar_neuron]
  475. ell = p1.get_ellipse()
  476. ell._linewidth = 2.
  477. ax.add_artist(ell)
  478. p2 = axonal_clouds[selected_circular_neuron]
  479. circ_r = 2 * np.sqrt(p2.a * p2.b)
  480. circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
  481. circ._linewidth = 2.
  482. ax.add_artist(circ)
  483. ax.annotate("{:.0f}".format(scale), xy=(1.0, 0.5), xytext=(2, 0), xycoords="axes fraction", textcoords="offset "
  484. "points",
  485. va="center", ha="left")
  486. remove_frame(ax)
  487. # fig.suptitle('axonal cloud', fontsize=16)
  488. traj.f_restore_default()
  489. add_length_scale(axes[1], 200, -300, -100, 50, 50)
  490. axes[0].annotate("input maps", xy=(-0.5, 1.0), xycoords="axes fraction", rotation=90, ha="left", va="top")
  491. fig.tight_layout()
  492. if save_figs:
  493. plt.savefig(FIGURE_SAVE_PATH + 'F_orientation_maps_diff_scales_with_ellipse.png')
  494. plt.close(fig)
  495. def plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_idx,
  496. figname=FIGURE_SAVE_PATH + 'D_polar_plot_excitatory.png'):
  497. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  498. directions_plt = list(directions)
  499. directions_plt.append(directions[0])
  500. height = panel_size
  501. width = panel_size
  502. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  503. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  504. # sorted_ids = np.argsort(head_direction_indices)
  505. # plot_n_idx = sorted_ids[-75]
  506. plot_n_idx = selected_neuron_idx
  507. line_styles = ['dotted', 'solid', 'dashed']
  508. colors = ['r', 'lightsalmon', 'grey']
  509. line_widths = [1.5, 1.5, 1]
  510. zorders = [10, 2, 1]
  511. max_rate = 0.0
  512. for run_idx, run_name in enumerate(plot_run_names):
  513. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  514. hdi = traj.results.runs[run_name].head_direction_indices[selected_neuron_idx]
  515. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  516. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  517. run_max_rate = np.max(rate_plot)
  518. if run_max_rate > max_rate:
  519. max_rate = run_max_rate
  520. rate_plot.append(rate_plot[0])
  521. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  522. label='{:s} {:.2f}'.format(short_labels(label), hdi),
  523. color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  524. # ax.set_title('Firing Rate')
  525. # ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
  526. # TODO: Set ticks for polar
  527. ticks = [40., 80.]
  528. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  529. enumerate(ticks)], angle=60)
  530. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  531. ax.xaxis.grid(linewidth=0.4)
  532. ax.yaxis.grid(linewidth=0.4)
  533. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.15), handlelength=1, fontsize="medium")
  534. leg.get_frame().set_linewidth(0.0)
  535. ax.axes.spines["polar"].set_visible(False)
  536. if save_figs:
  537. plt.savefig(figname)
  538. plt.close(fig)
  539. def plot_polar_plot_inhibitory(traj, plot_run_names, selected_neuron_idx, figname=FIGURE_SAVE_PATH +
  540. 'D_polar_plot_inhibitory.png'):
  541. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  542. directions_plt = list(directions)
  543. directions_plt.append(directions[0])
  544. height = panel_size
  545. width = panel_size
  546. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  547. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  548. # sorted_ids = np.argsort(head_direction_indices)
  549. # plot_n_idx = sorted_ids[-75]
  550. plot_n_idx = selected_neuron_idx
  551. line_styles = ['dotted', 'solid']
  552. colors = ['b', 'lightblue']
  553. line_widths = [1.5, 1.5]
  554. zorders = [10, 2]
  555. for run_idx, run_name in enumerate(plot_run_names[:2]):
  556. # ax = axes[max_hdi_idx, run_idx]
  557. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  558. hdi = traj.results.runs[run_name].inh_head_direction_indices[selected_neuron_idx]
  559. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  560. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  561. rate_plot.append(rate_plot[0])
  562. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  563. label='{:s} {:.2f}'.format(short_labels(label), hdi),
  564. color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  565. # ax.set_title('Inh. Firing Rate')
  566. # TODO: Set ticks for polar
  567. # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
  568. ticks = [40., 80., 120.]
  569. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  570. enumerate(ticks)], angle=60)
  571. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  572. ax.xaxis.grid(linewidth=0.4)
  573. ax.yaxis.grid(linewidth=0.4)
  574. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.15), handlelength=1, fontsize="medium")
  575. leg.get_frame().set_linewidth(0.0)
  576. ax.axes.spines["polar"].set_visible(False)
  577. if save_figs:
  578. plt.savefig(figname)
  579. plt.close(fig)
  580. def plot_hdi_over_corr_len(traj, plot_run_names):
  581. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  582. seed_expl = traj.f_get('seed').f_get_range()
  583. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  584. label_range = set(label_expl)
  585. hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  586. hdi_frame.index.names = ["corr_len", "seed", "label"]
  587. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  588. ex_tunings = traj.results.runs[run_name].ex_tunings
  589. head_direction_indices = traj.results[run_name].head_direction_indices
  590. hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  591. # TODO: Standart deviation also for the population
  592. hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
  593. hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
  594. # Ellipsoid markers
  595. rx, ry = 5., 12.
  596. # area = rx * ry * np.pi * 2.
  597. area = 1.
  598. theta = np.arange(0, 2 * np.pi + 0.01, 0.1)
  599. verts = np.column_stack([rx / area * np.cos(theta), ry / area * np.sin(theta)])
  600. style_dict = {
  601. NO_SYNAPSES: ['grey', 'dashed', '', 0],
  602. POLARIZED: ['blue', 'solid', verts, 10.],
  603. CIRCULAR: ['lightblue', 'solid', 'o', 8.]
  604. }
  605. # colors = ['blue', 'grey', 'lightblue']
  606. # linestyles = ['solid', 'dashed', 'solid']
  607. # markers = [verts, '', 'o']
  608. fig, ax = plt.subplots(1, 1)
  609. for label in label_range:
  610. hdi_mean = hdi_exc_n_and_seed_mean[:, label]
  611. hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
  612. corr_len_range = hdi_mean.keys().to_numpy()
  613. col, lin, mar, mar_size = style_dict[label]
  614. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  615. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  616. hdi_mean + hdi_std, alpha=0.4, color=col)
  617. ax.set_xlabel('Correlation length')
  618. ax.set_ylabel('Head Direction Index')
  619. ax.axvline(206.9, color='k', linewidth=0.5)
  620. ax.set_ylim(0.0, 1.0)
  621. ax.set_xlim(0.0, 400.)
  622. ax.legend()
  623. if save_figs:
  624. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len_scaled.png', dpi=200)
  625. plt.close(fig)
  626. def plot_hdi_histogram_excitatory(traj, plot_run_names):
  627. labels = []
  628. hdis = []
  629. colors = ['black', 'red', 'green']
  630. for run_idx, run_name in enumerate(plot_run_names):
  631. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  632. labels.append(label)
  633. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  634. hdis.append(head_direction_indices)
  635. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  636. ax.hist(hdis, color=colors, label=labels, bins=30)
  637. for hdi, color in zip(hdis, colors):
  638. mean_hdi = np.mean(hdi)
  639. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  640. ax.set_xlabel("HDI")
  641. ax.legend()
  642. fig.tight_layout()
  643. if save_figs:
  644. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
  645. plt.close(fig)
  646. def plot_hdi_violin_excitatory(traj, plot_run_names):
  647. labels = []
  648. hdis = []
  649. colors = ['black', 'red', 'green']
  650. no_conn_hdi = 0.
  651. for run_idx, run_name in enumerate(plot_run_names):
  652. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  653. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  654. if label == NO_SYNAPSES:
  655. no_conn_hdi = np.mean(head_direction_indices)
  656. else:
  657. labels.append(label)
  658. hdis.append(sorted(head_direction_indices))
  659. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  660. # hdis = np.array(hdis)
  661. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  662. viol_plt['cmeans'].set_color('black')
  663. for pc in viol_plt['bodies']:
  664. pc.set_facecolor('red')
  665. pc.set_edgecolor('black')
  666. pc.set_alpha(0.7)
  667. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  668. ax.annotate(NO_SYNAPSES, xy=(0.45, 0.48), xycoords='axes fraction')
  669. ax.set_xticks(np.arange(1, len(labels) + 1))
  670. ax.set_xticklabels(labels)
  671. ax.set_ylabel('HDI')
  672. fig.tight_layout()
  673. if save_figs:
  674. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_excitatory.png', dpi=200)
  675. plt.close(fig)
  676. def plot_hdi_violin_inhibitory(traj, plot_run_names):
  677. labels = []
  678. hdis = []
  679. colors = ['black', 'red']
  680. for run_idx, run_name in enumerate(plot_run_names):
  681. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  682. if label != NO_SYNAPSES:
  683. labels.append(label)
  684. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  685. hdis.append(sorted(head_direction_indices))
  686. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  687. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  688. viol_plt['cmeans'].set_color('black')
  689. for pc in viol_plt['bodies']:
  690. pc.set_facecolor('blue')
  691. pc.set_edgecolor('black')
  692. pc.set_alpha(0.7)
  693. ax.set_xticks(np.arange(1, len(labels) + 1))
  694. ax.set_xticklabels(labels)
  695. ax.set_ylabel('HDI')
  696. fig.tight_layout()
  697. if save_figs:
  698. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_inhibitory.png', dpi=200)
  699. plt.close(fig)
  700. def plot_hdi_violin_combined(traj, plot_run_names):
  701. labels = []
  702. inh_hdis = []
  703. exc_hdis = []
  704. no_conn_hdi = 0.
  705. colors = ['black', 'red']
  706. for run_idx, run_name in enumerate(plot_run_names):
  707. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  708. if label != NO_SYNAPSES:
  709. labels.append(label)
  710. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  711. inh_hdis.append(sorted(inh_head_direction_indices))
  712. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  713. exc_hdis.append(sorted(exc_head_direction_indices))
  714. else:
  715. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  716. no_conn_hdi = np.mean(exc_head_direction_indices)
  717. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  718. inh_viol_plt = ax.violinplot(inh_hdis, showmeans=True, showextrema=False)
  719. # viol_plt['cmeans'].set_color('black')
  720. #
  721. # for pc in viol_plt['bodies']:
  722. # pc.set_facecolor('blue')
  723. # pc.set_edgecolor('black')
  724. # pc.set_alpha(0.7)
  725. for b in inh_viol_plt['bodies']:
  726. m = np.mean(b.get_paths()[0].vertices[:, 0])
  727. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  728. b.set_color('b')
  729. exc_viol_plt = ax.violinplot(exc_hdis, showmeans=True, showextrema=False)
  730. for b in exc_viol_plt['bodies']:
  731. m = np.mean(b.get_paths()[0].vertices[:, 0])
  732. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  733. b.set_color('r')
  734. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  735. ax.annotate(NO_SYNAPSES, xy=(0.45, 0.48), xycoords='axes fraction')
  736. ax.set_xticks(np.arange(1, len(labels) + 1))
  737. ax.set_xticklabels(labels)
  738. ax.set_ylabel('HDI')
  739. fig.tight_layout()
  740. if save_figs:
  741. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined.svg', dpi=200)
  742. plt.close(fig)
  743. def plot_hdi_violin_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id):
  744. labels = []
  745. inh_hdis = []
  746. exc_hdis = []
  747. no_conn_hdi = 0.
  748. in_polar_plot_hdi = []
  749. ex_polar_plot_hdi = []
  750. colors = ['black', 'red']
  751. for run_idx, run_name in enumerate(plot_run_names):
  752. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  753. if label != NO_SYNAPSES:
  754. labels.append(label)
  755. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  756. inh_hdis.append(sorted(inh_head_direction_indices))
  757. in_polar_plot_hdi.append(inh_head_direction_indices[in_polar_plot_id])
  758. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  759. exc_hdis.append(sorted(exc_head_direction_indices))
  760. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  761. else:
  762. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  763. no_conn_hdi = np.mean(exc_head_direction_indices)
  764. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  765. fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.5))
  766. inh_ell_viol_plt = ax.violinplot(inh_hdis[0], showmeans=True, showextrema=False)
  767. for b in inh_ell_viol_plt['bodies']:
  768. m = np.mean(b.get_paths()[0].vertices[:, 0])
  769. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  770. b.set_color('b')
  771. mean_line = inh_ell_viol_plt['cmeans']
  772. mean_line.set_color('b')
  773. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  774. exc_ell_viol_plt = ax.violinplot(exc_hdis[0], showmeans=True, showextrema=False)
  775. for b in exc_ell_viol_plt['bodies']:
  776. m = np.mean(b.get_paths()[0].vertices[:, 0])
  777. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  778. b.set_color('r')
  779. mean_line = exc_ell_viol_plt['cmeans']
  780. mean_line.set_color('r')
  781. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  782. inh_cir_viol_plt = ax.violinplot(inh_hdis[1], showmeans=True, showextrema=False)
  783. for b in inh_cir_viol_plt['bodies']:
  784. m = np.mean(b.get_paths()[0].vertices[:, 0])
  785. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  786. b.set_color('b')
  787. mean_line = inh_cir_viol_plt['cmeans']
  788. mean_line.set_color('b')
  789. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  790. exc_cir_viol_plt = ax.violinplot(exc_hdis[1], showmeans=True, showextrema=False)
  791. for b in exc_cir_viol_plt['bodies']:
  792. m = np.mean(b.get_paths()[0].vertices[:, 0])
  793. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  794. b.set_color('r')
  795. mean_line = exc_cir_viol_plt['cmeans']
  796. mean_line.set_color('r')
  797. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  798. ax.axhline(no_conn_hdi, 0.5, 1., color='black', linestyle='--')
  799. ax.axvline(1.0, color='k')
  800. ax.annotate(NO_SYNAPSES, xy=(0.75, 0.415), xycoords='axes fraction')
  801. ax.set_xlim(0.5, 1.5)
  802. ax.set_ylim(0.0, 1.0)
  803. ax.set_xticks([0.75, 1.25])
  804. ax.set_xticklabels([CIRCULAR, POLARIZED])
  805. ax.set_ylabel('HDI')
  806. fig.tight_layout()
  807. if save_figs:
  808. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
  809. plt.close(fig)
  810. return ex_polar_plot_hdi, in_polar_plot_hdi
  811. def plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id, cut_off_dist):
  812. labels = []
  813. inh_hdis = []
  814. exc_hdis = []
  815. no_conn_hdi = 0.
  816. for run_idx, run_name in enumerate(plot_run_names):
  817. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  818. if label != NO_SYNAPSES:
  819. labels.append(normal_labels(label))
  820. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  821. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  822. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  823. (inh_axonal_cloud[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  824. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  825. (inh_axonal_cloud[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  826. # print(inh_positions)
  827. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  828. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  829. ex_positions = traj.results.runs[run_name].ex_positions
  830. # print(ex_positions)
  831. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  832. (ex_positions[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  833. (ex_positions[:, 1] >= cut_off_dist) & \
  834. (ex_positions[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  835. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  836. else:
  837. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  838. no_conn_hdi = np.mean(exc_head_direction_indices)
  839. # Look for a representative excitatory neuron
  840. hdi_mean_dict = {}
  841. excitatory_hdi_means = [np.mean(hdis) for hdis in exc_hdis]
  842. hdi_mean_dict["polar_exc"] = excitatory_hdi_means[0]
  843. hdi_mean_dict["circular_exc"] = excitatory_hdi_means[1]
  844. inhibitory_hdi_means = [np.mean(hdis) for hdis in inh_hdis]
  845. hdi_mean_dict["polar_inh"] = inhibitory_hdi_means[0]
  846. hdi_mean_dict["circular_inh"] = inhibitory_hdi_means[1]
  847. # # fig = plt.figure(figsize=(3.5, 3.5))
  848. # # gs1 = gridspec.GridSpec(1, 2)
  849. #
  850. # fig = plt.figure(constrained_layout=True)
  851. # gs = gridspec.GridSpec(ncols=1, nrows=2, hspace=0.0, wspace=0.0, figure=fig)
  852. # # gs.update(wspace=0.0, hspace=0.0) # set the spacing between axes.
  853. # # f2_ax1 = fig.add_subplot(gs[0, 0])
  854. # # f2_ax2 = fig.add_subplot(gs[0, 1])
  855. # print(gs.get_subplot_params())
  856. width = 2 * panel_size
  857. height = 1.2 * panel_size
  858. fig, axes = plt.subplots(2, 1, figsize=(width, height))
  859. plt.subplots_adjust(wspace=0, hspace=0.1)
  860. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  861. max_density = 0
  862. for i in range(2):
  863. # i = i + 1 # grid spec indexes from 0
  864. # ax = fig.add_subplot(gs[i])
  865. ax = axes[i]
  866. density_e, _, _ = ax.hist(exc_hdis[i], color='r', edgecolor='r', alpha=0.3, bins=bins, density=True)
  867. density_i, _, _ = ax.hist(inh_hdis[i], color='b', edgecolor='b', alpha=0.3, bins=bins, density=True)
  868. max_density = np.max([max_density, np.max(density_e), np.max(density_i)])
  869. ax.axvline(np.mean(exc_hdis[i]), color='r')
  870. ax.axvline(np.mean(inh_hdis[i]), color='b')
  871. ax.axvline(no_conn_hdi, color='dimgrey', linestyle='--', linewidth=1.5)
  872. ax.set_ylabel(labels[i], rotation='vertical')
  873. # plt.axis('on')
  874. if i == 0:
  875. ax.set_xticklabels([])
  876. else:
  877. ax.set_xlabel('head direction index')
  878. remove_frame(ax, ["top", "right", "bottom"])
  879. max_density = 1.2 * max_density
  880. fig.subplots_adjust(left=0.2, right=0.95, bottom=0.2)
  881. axes[0].annotate('% cells', (0, 1.0), xycoords='axes fraction', va="bottom", ha="right")
  882. axes[1].annotate("no ihn.\n{:.2f}".format(no_conn_hdi), xy=(no_conn_hdi, max_density),
  883. xytext=(-2, 0), xycoords="data",
  884. textcoords="offset points",
  885. va="top", ha="right", color="dimgrey")
  886. for i, ax in enumerate(axes):
  887. ax.annotate("{:.2f}".format(np.mean(exc_hdis[i])), xy=(np.mean(exc_hdis[i]), max_density),
  888. xytext=(2, 0), xycoords="data",
  889. textcoords="offset points",
  890. va="top", ha="left", color="r")
  891. # i_ha = "left" if i == 1 else "right"
  892. # i_offset = 2 if i == 1 else -2
  893. i_ha = "right"
  894. i_offset = -1
  895. ax.annotate("{:.2f}".format(np.mean(inh_hdis[i])), xy=(np.mean(inh_hdis[i]), max_density),
  896. xytext=(i_offset, 0), xycoords="data",
  897. textcoords="offset points",
  898. va="top", ha=i_ha, color="b")
  899. for ax in axes:
  900. ax.set_ylim(0, max_density)
  901. # plt.annotate('probability density', (-0.2,1.5), xycoords='axes fraction', rotation=90, fontsize=18)
  902. if save_figs:
  903. plt.savefig(FIGURE_SAVE_PATH + 'E_hdi_histogram_combined_and_overlayed_cutoff_{}um.png'.format(cut_off_dist))
  904. plt.close(fig)
  905. return hdi_mean_dict
  906. def get_neurons_with_given_hdi(polar_hdi, circular_hdi, max_number_of_suggestions, plot_run_names, traj, type):
  907. polar_run_name = plot_run_names[0]
  908. circular_run_name = plot_run_names[1]
  909. polar_ex_hdis = traj.results.runs[polar_run_name].head_direction_indices if type == "ex" else traj.results.runs[
  910. polar_run_name].inh_head_direction_indices
  911. circular_ex_hdis = traj.results.runs[circular_run_name].head_direction_indices if type == "ex" else \
  912. traj.results.runs[
  913. polar_run_name].inh_head_direction_indices
  914. neuron_indices = get_indices_of_closest_values(polar_ex_hdis, polar_hdi,
  915. circular_ex_hdis,
  916. circular_hdi, 0.1 * np.abs(
  917. polar_hdi - circular_hdi), max_number_of_suggestions)
  918. return neuron_indices
  919. def get_indices_of_closest_values(first_list, first_value, second_list, second_value, absolute_tolerance_list_one,
  920. number_of_indices):
  921. is_close_in_list_one = np.abs(first_list - first_value) < absolute_tolerance_list_one
  922. indices_close_in_list_one = np.where(is_close_in_list_one)[0]
  923. indices_closest_in_list_two = indices_close_in_list_one[np.argpartition(np.abs(second_list[
  924. indices_close_in_list_one] - second_value),
  925. number_of_indices)]
  926. return indices_closest_in_list_two[:number_of_indices]
  927. def plot_hdi_histogram_inhibitory(traj, plot_run_names, in_polar_plot_id):
  928. labels = []
  929. hdis = []
  930. colors = ['black', 'red']
  931. for run_idx, run_name in enumerate(plot_run_names):
  932. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  933. if label != NO_SYNAPSES:
  934. labels.append(label)
  935. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  936. print('inh {}: {}'.format(label, head_direction_indices[in_polar_plot_id]))
  937. hdis.append(head_direction_indices)
  938. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  939. ax.hist(hdis, color=colors, label=labels, bins=30)
  940. for hdi, color in zip(hdis, colors):
  941. mean_hdi = np.mean(hdi)
  942. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  943. ax.set_xlabel("HDI")
  944. ax.legend()
  945. fig.tight_layout()
  946. if save_figs:
  947. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
  948. plt.close(fig)
  949. def filter_run_names_by_par_dict(traj, par_dict):
  950. run_name_list = []
  951. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  952. traj.f_set_crun(run_name)
  953. paramters_equal = True
  954. for key, val in par_dict.items():
  955. if (traj.par[key] != val):
  956. paramters_equal = False
  957. if paramters_equal:
  958. run_name_list.append(run_name)
  959. traj.f_restore_default()
  960. return run_name_list
  961. def plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, plot_run_names, cut_off_dist):
  962. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  963. seed_expl = traj.f_get('seed').f_get_range()
  964. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  965. label_range = set(label_expl)
  966. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  967. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  968. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  969. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  970. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  971. ex_tunings = traj.results.runs[run_name].ex_tunings
  972. inh_hdis = []
  973. exc_hdis = []
  974. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  975. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  976. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  977. (inh_axonal_cloud[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  978. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  979. (inh_axonal_cloud[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  980. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  981. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  982. ex_positions = traj.results.runs[run_name].ex_positions
  983. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  984. (ex_positions[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  985. (ex_positions[:, 1] >= cut_off_dist) & \
  986. (ex_positions[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  987. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  988. exc_hdi_frame[corr_len, seed, label] = np.mean(exc_hdis)
  989. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_hdis)
  990. # TODO: Standard deviation also for the population
  991. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  992. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  993. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  994. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  995. markersize = 4.
  996. exc_style_dict = {
  997. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  998. POLARIZED: ['red', 'solid', '^', markersize],
  999. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  1000. }
  1001. inh_style_dict = {
  1002. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  1003. POLARIZED: ['blue', 'solid', 'o', markersize],
  1004. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  1005. }
  1006. # colors = ['blue', 'grey', 'lightblue']
  1007. # linestyles = ['solid', 'dashed', 'solid']
  1008. # markers = [verts, '', 'o']
  1009. # corr_len_fit_dict = correlation_length_fit_dict(traj, load=True)
  1010. width = 2 * panel_size
  1011. height = 1.2 * panel_size
  1012. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  1013. for label in sorted(label_range, reverse=True):
  1014. if label == NO_SYNAPSES:
  1015. no_conn_hdi = exc_hdi_n_and_seed_mean[1, label]
  1016. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  1017. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  1018. textcoords="offset points",
  1019. va="top", \
  1020. ha="right",
  1021. color="dimgrey")
  1022. continue
  1023. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  1024. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  1025. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  1026. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  1027. corr_len_range = exc_hdi_mean.keys().to_numpy()
  1028. print(label)
  1029. for corr_len, ex_hdi, in_hdi in zip(corr_len_range, exc_hdi_mean, inh_hdi_mean):
  1030. print("length: {:.2f} um, ex hdi: {:.2f} and in hdi {:.2f}".format(corr_len, ex_hdi, in_hdi))
  1031. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  1032. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  1033. simplex_grid_scale = corr_len_range * np.sqrt(2)
  1034. ax.plot(simplex_grid_scale, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  1035. markersize=exc_mar_size, alpha=0.5)
  1036. plt.fill_between(simplex_grid_scale, exc_hdi_mean - exc_hdi_std,
  1037. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  1038. ax.plot(simplex_grid_scale, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  1039. markersize=inh_mar_size, alpha=0.5)
  1040. plt.fill_between(simplex_grid_scale, inh_hdi_mean - inh_hdi_std,
  1041. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  1042. ax.set_xlabel('simplex grid scale')
  1043. ax.set_ylabel('head direction index')
  1044. ax.axvline(get_closest_correlation_length(traj, 200.0) * np.sqrt(2), color='k', linewidth=0.5, zorder=0)
  1045. ax.set_ylim(0.0, 1.0)
  1046. # ax.set_xlim(0.0, np.max(corr_len_range))
  1047. remove_frame(ax, ["right", "top"])
  1048. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  1049. row_labels=None,
  1050. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  1051. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  1052. fig.subplots_adjust(bottom=0.2, left=0.2)
  1053. # plt.legend()
  1054. if save_figs:
  1055. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_grid_scale.png')
  1056. plt.close(fig)
  1057. def plot_exc_and_inh_hdi_over_fit_corr_len(traj, plot_run_names, cut_off_dist):
  1058. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  1059. seed_expl = traj.f_get('seed').f_get_range()
  1060. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  1061. label_range = set(label_expl)
  1062. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  1063. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  1064. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  1065. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  1066. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  1067. ex_tunings = traj.results.runs[run_name].ex_tunings
  1068. inh_hdis = []
  1069. exc_hdis = []
  1070. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  1071. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  1072. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  1073. (inh_axonal_cloud[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  1074. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  1075. (inh_axonal_cloud[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  1076. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  1077. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  1078. ex_positions = traj.results.runs[run_name].ex_positions
  1079. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  1080. (ex_positions[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  1081. (ex_positions[:, 1] >= cut_off_dist) & \
  1082. (ex_positions[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  1083. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  1084. exc_hdi_frame[corr_len, seed, label] = np.mean(exc_hdis)
  1085. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_hdis)
  1086. # TODO: Standard deviation also for the population
  1087. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  1088. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  1089. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  1090. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  1091. markersize = 4.
  1092. exc_style_dict = {
  1093. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  1094. POLARIZED: ['red', 'solid', '^', markersize],
  1095. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  1096. }
  1097. inh_style_dict = {
  1098. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  1099. POLARIZED: ['blue', 'solid', 'o', markersize],
  1100. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  1101. }
  1102. # colors = ['blue', 'grey', 'lightblue']
  1103. # linestyles = ['solid', 'dashed', 'solid']
  1104. # markers = [verts, '', 'o']
  1105. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='pinwheel', load=True)
  1106. width = 2 * panel_size
  1107. height = 1.2 * panel_size
  1108. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  1109. for label in sorted(label_range, reverse=True):
  1110. if label == NO_SYNAPSES:
  1111. no_conn_hdi = exc_hdi_n_and_seed_mean[1, label]
  1112. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  1113. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  1114. textcoords="offset points",
  1115. va="top", \
  1116. ha="right",
  1117. color="dimgrey")
  1118. continue
  1119. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  1120. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  1121. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  1122. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  1123. corr_len_range = exc_hdi_mean.keys().to_numpy()
  1124. print(label)
  1125. for corr_len, ex_hdi, in_hdi in zip(corr_len_range, exc_hdi_mean, inh_hdi_mean):
  1126. print("length: {:.2f} um, ex hdi: {:.2f} and in hdi {:.2f}".format(corr_len, ex_hdi, in_hdi))
  1127. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  1128. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  1129. fit_corr_len = [corr_len_fit_dict[corr_len] for corr_len in corr_len_range]
  1130. last_shown_point_id = -1
  1131. fit_corr_len = fit_corr_len[:last_shown_point_id]
  1132. exc_hdi_mean = exc_hdi_mean.to_numpy()[:last_shown_point_id]
  1133. exc_hdi_std = exc_hdi_std.to_numpy()[:last_shown_point_id]
  1134. inh_hdi_mean = inh_hdi_mean.to_numpy()[:last_shown_point_id]
  1135. inh_hdi_std = inh_hdi_std.to_numpy()[:last_shown_point_id]
  1136. ax.plot(fit_corr_len, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  1137. markersize=exc_mar_size, alpha=0.5)
  1138. plt.fill_between(fit_corr_len, exc_hdi_mean - exc_hdi_std,
  1139. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  1140. ax.plot(fit_corr_len, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  1141. markersize=inh_mar_size, alpha=0.5)
  1142. plt.fill_between(fit_corr_len, inh_hdi_mean - inh_hdi_std,
  1143. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  1144. ax.set_xlabel('correlation length')
  1145. ax.set_ylabel('head direction index')
  1146. ax.axvline(corr_len_fit_dict[get_closest_correlation_length(traj, 170.0)], color='k', linewidth=0.5, zorder=0)
  1147. ax.set_ylim(0.0, 1.0)
  1148. # ax.set_xlim(0.0, 130.)
  1149. remove_frame(ax, ["right", "top"])
  1150. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  1151. row_labels=None,
  1152. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  1153. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  1154. fig.subplots_adjust(bottom=0.2, left=0.2)
  1155. # plt.legend()
  1156. if save_figs:
  1157. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_corr_len_scaled.png')
  1158. plt.close(fig)
  1159. def plot_in_degree_map(traj, plot_run_names):
  1160. n_ex = int(np.sqrt(traj.N_E))
  1161. max_degree = 0
  1162. for run_name in plot_run_names:
  1163. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  1164. exc_degree = np.sum(ie_adjacency, axis=0)
  1165. run_max_degree = np.max(exc_degree)
  1166. if run_max_degree > max_degree:
  1167. max_degree = run_max_degree
  1168. fig, axes = plt.subplots(1, 2, figsize=(9., 4.5))
  1169. for ax, run_name in zip(axes, plot_run_names[:-1]):
  1170. traj.f_set_crun(run_name)
  1171. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1172. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  1173. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  1174. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  1175. exc_degree = np.sum(ie_adjacency, axis=0)
  1176. c = ax.pcolor(X, Y, np.reshape(exc_degree, (number_of_excitatory_neurons_per_row,
  1177. number_of_excitatory_neurons_per_row)), vmin=0, vmax=max_degree,
  1178. cmap='hot')
  1179. ax.set_title(label)
  1180. fig.colorbar(c, ax=ax, label="in/out-degree")
  1181. fig.suptitle('in/out-degree', fontsize=16)
  1182. traj.f_restore_default()
  1183. if save_figs:
  1184. plt.savefig(FIGURE_SAVE_PATH + 'in_degree_map.png', dpi=200)
  1185. plt.close(fig)
  1186. def plot_spatial_hdi_map(traj, plot_run_names):
  1187. max_val = 0
  1188. for run_name in plot_run_names:
  1189. hdis = traj.results.runs[run_name].head_direction_indices
  1190. run_max_val = np.max(hdis)
  1191. if run_max_val > max_val:
  1192. max_val = run_max_val
  1193. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  1194. for ax, run_name in zip(axes, plot_run_names):
  1195. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1196. positions = traj.results.runs[run_name].ex_positions
  1197. head_direction_indices = traj.results[run_name].head_direction_indices
  1198. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  1199. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  1200. ax.set_title(label)
  1201. fig.colorbar(c, ax=ax, label="head direction index")
  1202. fig.suptitle('spatial HDI map', fontsize=16)
  1203. if save_figs:
  1204. plt.savefig(FIGURE_SAVE_PATH + 'spatial_hdi_map.png', dpi=200)
  1205. plt.close(fig)
  1206. def plot_exc_spatial_hdi_map(traj, plot_run_names):
  1207. max_val = 0
  1208. for run_name in plot_run_names:
  1209. hdis = traj.results.runs[run_name].head_direction_indices
  1210. run_max_val = np.max(hdis)
  1211. if run_max_val > max_val:
  1212. max_val = run_max_val
  1213. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  1214. for ax, run_name in zip(axes, plot_run_names):
  1215. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1216. positions = traj.results.runs[run_name].ex_positions
  1217. head_direction_indices = traj.results[run_name].head_direction_indices
  1218. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  1219. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  1220. ax.set_title(label)
  1221. fig.colorbar(c, ax=ax, label="head direction index")
  1222. fig.suptitle('spatial exc. HDI map', fontsize=16)
  1223. if save_figs:
  1224. plt.savefig(FIGURE_SAVE_PATH + 'spatial_exc_hdi_map.png', dpi=200)
  1225. plt.close(fig)
  1226. def plot_inh_spatial_hdi_map(traj, plot_run_names):
  1227. max_val = 0
  1228. for run_name in plot_run_names:
  1229. hdis = traj.results.runs[run_name].inh_head_direction_indices
  1230. run_max_val = np.max(hdis)
  1231. if run_max_val > max_val:
  1232. max_val = run_max_val
  1233. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  1234. for ax, run_name in zip(axes, plot_run_names):
  1235. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1236. ax_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  1237. positions = [[x, y] for x, y, phi in ax_cloud]
  1238. head_direction_indices = traj.results[run_name].inh_head_direction_indices
  1239. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  1240. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  1241. ax.set_title(label)
  1242. fig.colorbar(c, ax=ax, label="head direction index")
  1243. fig.suptitle('spatial inh. HDI map', fontsize=16)
  1244. if save_figs:
  1245. plt.savefig(FIGURE_SAVE_PATH + 'spatial_inh_hdi_map.png', dpi=200)
  1246. plt.close(fig)
  1247. def get_phase_difference(total_difference):
  1248. """
  1249. Map accumulated phase difference to shortest possible difference
  1250. :param total_difference:
  1251. :return: relative_difference
  1252. """
  1253. return (total_difference + np.pi) % (2 * np.pi) - np.pi
  1254. def plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names):
  1255. # The plot that Imre wanted
  1256. n_bins = traj.parameters.input.number_of_directions
  1257. fig, ax = plt.subplots(1, 1, figsize=(9, 9))
  1258. dir_bins = np.linspace(-np.pi, np.pi, n_bins + 1)
  1259. print(dir_bins)
  1260. plot_fr_array = []
  1261. labels = []
  1262. similarity_threshold = np.pi / 6.
  1263. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  1264. for run_idx, run_name in enumerate(plot_run_names):
  1265. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1266. labels.append(label)
  1267. fr_similar_tunings = []
  1268. fr_different_tunings = []
  1269. ex_tunings = traj.results.runs[run_name].ex_tunings
  1270. firing_rate_array = traj.results[run_name].firing_rate_array
  1271. for tuning, firing_rates in zip(ex_tunings, firing_rate_array):
  1272. for idx, dir in enumerate(directions):
  1273. if np.abs(get_phase_difference(tuning - dir)) <= similarity_threshold:
  1274. fr_similar_tunings.append(firing_rates[idx])
  1275. elif np.abs(get_phase_difference(tuning + np.pi - dir)) <= similarity_threshold:
  1276. fr_different_tunings.append(firing_rates[idx])
  1277. plot_fr_array.append([np.mean(fr_similar_tunings), np.mean(fr_different_tunings)])
  1278. x = np.arange(3) # the label locations
  1279. width = 0.35 # the width of the bars
  1280. plot_fr_array = np.array(plot_fr_array)
  1281. rects1 = ax.bar(x - width / 2, plot_fr_array[:, 0], width,
  1282. label='theta pref +/- {}°'.format(np.round(similarity_threshold / np.pi * 180)))
  1283. rects2 = ax.bar(x + width / 2, plot_fr_array[:, 1], width,
  1284. label='theta pref + 180° +/- {}°'.format(np.round(similarity_threshold / np.pi * 180)))
  1285. ax.set_xticks(x)
  1286. ax.set_xticklabels(labels)
  1287. ax.set_title('Mean firing rate for tunings similar and different to input')
  1288. ax.set_ylabel('Mean firing rate')
  1289. ax.legend()
  1290. def autolabel(rects):
  1291. """Attach a text label above each bar in *rects*, displaying its height."""
  1292. for rect in rects:
  1293. height = rect.get_height()
  1294. ax.annotate('{}'.format(np.round(height)),
  1295. xy=(rect.get_x() + rect.get_width() / 2, height),
  1296. xytext=(0, 3), # 3 points vertical offset
  1297. textcoords="offset points",
  1298. ha='center', va='bottom')
  1299. autolabel(rects1)
  1300. autolabel(rects2)
  1301. fig.tight_layout()
  1302. if save_figs:
  1303. plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_similar_vs_diff_tuning.png', dpi=200)
  1304. plt.close(fig)
  1305. def get_firing_rates_along_preferred_axis(traj, run_name, neuron_idx):
  1306. firing_rates = traj.results[run_name].firing_rate_array[neuron_idx, :]
  1307. tuning = traj.results[run_name].ex_tunings[neuron_idx]
  1308. anti_tuning = tuning + np.pi if tuning + np.pi < np.pi else tuning - np.pi
  1309. tuning_idx = np.argmin(np.abs(directions - tuning))
  1310. anti_tuning_idx = np.argmin(np.abs(directions - anti_tuning))
  1311. firing_at_the_preferred_direction = firing_rates[tuning_idx]
  1312. firing_at_the_opposite_direction = firing_rates[anti_tuning_idx]
  1313. return firing_at_the_preferred_direction, firing_at_the_opposite_direction
  1314. def get_hdi(traj, run_name, neuron_idx, type):
  1315. return traj.results.runs[run_name].head_direction_indices[neuron_idx] if type=="ex" else traj.results.runs[
  1316. run_name].inh_head_direction_indices[neuron_idx]
if __name__ == "__main__":
    # Open the trajectory stored in DATA_FOLDER/<TRAJ_NAME>.hdf5. Parameters are loaded
    # completely, results are not loaded up front (load_results=NO_LOADING); with
    # v_auto_load enabled they are presumably fetched from disk on first access.
    traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
    NO_LOADING = 0
    FULL_LOAD = 2
    traj.f_load(filename=os.path.join(DATA_FOLDER, TRAJ_NAME + ".hdf5"), load_parameters=FULL_LOAD,
                load_results=NO_LOADING)
    traj.v_auto_load = True
    # Module-level flag read by the plotting functions: when True each figure is saved
    # (and closed); when False the figures stay open and are shown at the end.
    save_figs = True
    print("# Plotting script polarized interneurons")
    print()
    # Exemplary simulation: map correlation length (um), map seed, and the input
    # head direction (degrees) used for the single-direction panels.
    map_length_scale = 170.0
    map_seed = 1
    exemplary_head_direction = 0
    # Disabled snippet: plot the fitted correlation lengths.
    # corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='pinwheel', load=True)
    # plt.plot(corr_len_fit_dict.keys(),corr_len_fit_dict.values())
    # plt.show()
    # abbrechen  (German for "abort"; presumably an undefined name that halts the script when uncommented)
    print("## Map specifications")
    print("\tcorrelation length: {:.1f} um".format(map_length_scale))
    print("\tmap seed: {:d}".format(map_seed))
    print()
    print("## Input specification")
    print("\tselected head direction: {:.0f}°".format(exemplary_head_direction))
    print()
    print("## Selected simulations")
    # Runs with the chosen seed and the simulated correlation length closest to
    # map_length_scale; run_name_dict maps each morphology label to its run name.
    plot_corr_len = get_closest_correlation_length(traj, map_length_scale)
    par_dict = {'seed': map_seed, 'correlation_length': plot_corr_len}
    plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
    run_name_dict = {}
    for run_name in plot_run_names:
        # NOTE(review): f_set_crun is not undone inside this loop.
        traj.f_set_crun(run_name)
        run_name_dict[traj.derived_parameters.runs[run_name].morphology.morph_label] = run_name
    for network_type, run_name in run_name_dict.items():
        print("{:s}: {:s}".format(network_type, run_name))
    # Input head directions; they are compared against np.deg2rad(...), i.e. radians.
    # direction_idx is the input direction closest to exemplary_head_direction.
    directions = get_input_head_directions(traj)
    direction_idx = np.argmin(np.abs(np.array(directions) - np.deg2rad(exemplary_head_direction)))
    # Indices of the example neurons passed to the single-neuron plots below.
    selected_neuron_excitatory = 1052
    selected_inhibitory_neuron = 28
    print("## Figure specification")
    print("\tpanel size: {:.2f} cm".format(panel_size * cm_per_inch))
    print()
    plot_input_map(traj, run_name_dict[POLARIZED], figname="A_i_exemplary_input_map.png",
                   figsize=(panel_size, panel_size))
    plot_axonal_clouds(traj, plot_run_names)
    plot_example_input_maps(traj, figsize=(2 * panel_size, 2 * panel_size))
    plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, selected_neuron_excitatory)
    # NOTE(review): in_max_rate is not used below.
    in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron)
    # HDI histograms; the returned dict of means is used for the neuron lookup below.
    hdi_means = plot_hdi_histogram_combined_and_overlayed(
        traj, plot_run_names,
        selected_neuron_excitatory,
        selected_inhibitory_neuron,
        cut_off_dist=100.)
    # Look up neurons for the polarized/circular HDI means (see get_neurons_with_given_hdi).
    # NOTE(review): the returned index lists are not used below.
    number_of_suggestions = 0
    representative_excitatory_neuron_indices = get_neurons_with_given_hdi(hdi_means["polar_exc"],
                                                                          hdi_means["circular_exc"],
                                                                          number_of_suggestions, plot_run_names,
                                                                          traj, "ex")
    representative_inhibitory_neuron_indices = get_neurons_with_given_hdi(hdi_means["polar_inh"],
                                                                          hdi_means["circular_inh"],
                                                                          number_of_suggestions, plot_run_names,
                                                                          traj, "in")
    plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_excitatory)
    plot_polar_plot_inhibitory(traj, plot_run_names, selected_inhibitory_neuron)
    plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names)
    plot_orientation_maps_diff_scales_with_ellipse(traj)
    # The figure below uses all runs of the trajectory, not only the selected ones.
    # plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, traj.f_get_run_names(), cut_off_dist=100.)
    plot_exc_and_inh_hdi_over_fit_corr_len(traj, traj.f_get_run_names(), cut_off_dist=100.)
    if not save_figs:
        plt.show()
    traj.f_restore_default()