paper_figures_orientation_map.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655
  1. import os
  2. import matplotlib.legend as mlegend
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import pandas as pd
  6. from matplotlib.patches import Ellipse
  7. from matplotlib.patches import Polygon
  8. from matplotlib.patches import Rectangle
  9. from mpl_toolkits.axes_grid1 import ImageGrid
  10. from pypet import Trajectory
  11. from pypet.brian2 import Brian2MonitorResult
  12. from scripts.interneuron_placement import get_position_mesh, Pickle, get_correct_position_mesh
  13. from scripts.model_figure.plot_circular_colorbar import plot_colorbar
  14. from scripts.spatial_maps.simplex_input_tuning_correlation.correlation_length_distance_perlin import \
  15. correlation_length_fit_dict
  16. from scripts.spatial_network.figures_spatial_head_direction_network_orientation_map import plot_hdi_in_space
  17. from scripts.spatial_network.perlin.figure_utils import remove_frame, remove_ticks, add_length_scale, cm_per_inch, \
  18. panel_size, head_direction_input_colormap
  19. from scripts.spatial_network.orientation_map.run_orientation_map import DATA_FOLDER, TRAJ_NAME, \
  20. get_input_head_directions, POLARIZED, CIRCULAR, NO_SYNAPSES
# Shared matplotlib style sheet for every figure produced by this module
# (a relative path; presumably resolved against the current working directory — confirm).
plt.style.use('figures.mplstyle')
# Prefix for saved figure files. The plotting functions append the file name directly
# (FIGURE_SAVE_PATH + 'name.png'), so the trailing '/' is required.
FIGURE_SAVE_PATH = '../../../figures/supplementary_orientation_map/'
# Global switch checked by the plotting functions: when True they call plt.savefig
# with a path under FIGURE_SAVE_PATH.
save_figs = False
  24. def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
  25. """
  26. Place a table legend on the axes.
  27. Creates a legend where the labels are not directly placed with the artists,
  28. but are used as row and column headers, looking like this:
  29. title_label | col_labels[1] | col_labels[2] | col_labels[3]
  30. -------------------------------------------------------------
  31. row_labels[1] |
  32. row_labels[2] | <artists go there>
  33. row_labels[3] |
  34. Parameters
  35. ----------
  36. ax : `matplotlib.axes.Axes`
  37. The artist that contains the legend table, i.e. current axes instant.
  38. col_labels : list of str, optional
  39. A list of labels to be used as column headers in the legend table.
  40. `len(col_labels)` needs to match `ncol`.
  41. row_labels : list of str, optional
  42. A list of labels to be used as row headers in the legend table.
  43. `len(row_labels)` needs to match `len(handles) // ncol`.
  44. title_label : str, optional
  45. Label for the top left corner in the legend table.
  46. ncol : int
  47. Number of columns.
  48. Other Parameters
  49. ----------------
  50. Refer to `matplotlib.legend.Legend` for other parameters.
  51. """
  52. #################### same as `matplotlib.axes.Axes.legend` #####################
  53. handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
  54. if len(extra_args):
  55. raise TypeError('legend only accepts two non-keyword arguments')
  56. if col_labels is None and row_labels is None:
  57. ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
  58. ax.legend_._remove_method = ax._remove_legend
  59. return ax.legend_
  60. #################### modifications for table legend ############################
  61. else:
  62. ncol = kwargs.pop('ncol')
  63. handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
  64. title_label = [title_label]
  65. # blank rectangle handle
  66. extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
  67. # empty label
  68. empty = [""]
  69. # number of rows infered from number of handles and desired number of columns
  70. nrow = len(handles) // ncol
  71. # organise the list of handles and labels for table construction
  72. if col_labels is None:
  73. assert nrow == len(
  74. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  75. nrow, len(row_labels))
  76. leg_handles = extra * nrow
  77. leg_labels = row_labels
  78. elif row_labels is None:
  79. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  80. ncol, len(col_labels))
  81. leg_handles = []
  82. leg_labels = []
  83. else:
  84. assert nrow == len(
  85. row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (
  86. nrow, len(row_labels))
  87. assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (
  88. ncol, len(col_labels))
  89. leg_handles = extra + extra * nrow
  90. leg_labels = title_label + row_labels
  91. for col in range(ncol):
  92. if col_labels is not None:
  93. leg_handles += extra
  94. leg_labels += [col_labels[col]]
  95. leg_handles += handles[col * nrow:(col + 1) * nrow]
  96. leg_labels += empty * nrow
  97. # Create legend
  98. ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol + int(row_labels is not None),
  99. handletextpad=handletextpad, **kwargs)
  100. ax.legend_._remove_method = ax._remove_legend
  101. return ax.legend_
  102. def get_closest_correlation_length(traj, correlation_length):
  103. available_lengths = sorted(list(set(traj.f_get("correlation_length").f_get_range())))
  104. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
  105. if closest_length != correlation_length:
  106. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  107. correlation_length, closest_length))
  108. corr_len = closest_length
  109. return corr_len
  110. def colorbar(mappable):
  111. from mpl_toolkits.axes_grid1 import make_axes_locatable
  112. import matplotlib.pyplot as plt
  113. last_axes = plt.gca()
  114. ax = mappable.axes
  115. fig = ax.figure
  116. divider = make_axes_locatable(ax)
  117. cax = divider.append_axes("right", size="5%", pad=0.05)
  118. cbar = fig.colorbar(mappable, cax=cax)
  119. plt.sca(last_axes)
  120. return cbar
  121. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, exemplary_excitatory_neuron_id):
  122. max_val = 0
  123. for run_name in plot_run_names:
  124. fr_array = traj.results.runs[run_name].firing_rate_array
  125. f_rates = fr_array[:, direction_idx]
  126. run_max_val = np.max(f_rates)
  127. if run_max_val > max_val:
  128. # if traj.derived_parameters.runs[run_name].morphology.morph_label == POLARIZED:
  129. # n_id_max_rate = np.argmax(f_rates)
  130. max_val = run_max_val
  131. n_id_polar_plot = exemplary_excitatory_neuron_id
  132. # Mark the neuron that is shown in polar plot
  133. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  134. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  135. # Vertices for the plotted triangle
  136. tr_scale = 30.
  137. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  138. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  139. tr_vertices = np.array(
  140. [[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  141. height = 1 * panel_size
  142. width = 3 * panel_size
  143. fig = plt.figure(figsize=(width, height))
  144. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  145. cbar_mode="single",
  146. cbar_size="7%", cbar_pad=panel_size / 10.0,
  147. nrows_ncols=(1, 3))
  148. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  149. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  150. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  151. firing_rate_array = traj.results[run_name].firing_rate_array
  152. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  153. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  154. number_of_excitatory_neurons_per_row)),
  155. vmin=0, vmax=max_val, cmap='Reds')
  156. # ax.set_title(adjust_label(label))
  157. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  158. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  159. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  160. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  161. for spine in ax.spines.values():
  162. spine.set_edgecolor("grey")
  163. spine.set_linewidth(1)
  164. remove_ticks(ax)
  165. # fig.suptitle('spatial firing rate map', fontsize=16)
  166. ax.cax.colorbar(c)
  167. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  168. # fig.tight_layout()
  169. if save_figs:
  170. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_excitatory.png', dpi=300)
  171. plt.close(fig)
  172. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron):
  173. max_val = 0
  174. for run_name in plot_run_names:
  175. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  176. f_rates = fr_array[:, direction_idx]
  177. run_max_val = np.max(f_rates)
  178. if run_max_val > max_val:
  179. max_val = run_max_val
  180. n_id_polar_plot = selected_inhibitory_neuron
  181. # Mark the neuron that is shown in Polar plot
  182. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  183. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  184. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  185. width = 3 * panel_size
  186. height = panel_size
  187. fig = plt.figure(figsize=(width, height))
  188. axes = ImageGrid(fig, (0.05, 0.15, 0.85, 0.7), axes_pad=panel_size / 6.0, cbar_location="right",
  189. cbar_mode="single",
  190. cbar_size="7%", cbar_pad=panel_size / 10.0,
  191. nrows_ncols=(1, 3))
  192. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  193. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  194. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  195. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  196. X, Y = get_correct_position_mesh(inh_positions)
  197. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  198. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  199. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  200. number_of_inhibitory_neurons_per_row)),
  201. vmin=0, vmax=max_val, cmap='Blues')
  202. # ax.set_title(adjust_label(label))
  203. circle_r = 40.
  204. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  205. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  206. for spine in ax.spines.values():
  207. spine.set_edgecolor("grey")
  208. spine.set_linewidth(0.5)
  209. remove_ticks(ax)
  210. # fig.colorbar(c, ax=ax, label="f (Hz)")
  211. # fig.suptitle('spatial firing rate map', fontsize=16)
  212. ax.cax.colorbar(c)
  213. ax.cax.annotate("fr (Hz)", xy=(1, 1), xytext=(3, 3), xycoords="axes fraction", textcoords="offset points")
  214. # fig.tight_layout()
  215. if save_figs:
  216. plt.savefig(FIGURE_SAVE_PATH + 'C_firing_rate_map_inhibitory.png', dpi=300)
  217. plt.close(fig)
  218. return max_val
  219. def plot_hdi_over_tuning(traj, plot_run_names):
  220. fig, ax = plt.subplots(1, 1)
  221. for run_idx, run_name in enumerate(plot_run_names):
  222. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  223. ex_tunings = traj.results.runs[run_name].ex_tunings
  224. ex_tunings_plt = np.array(ex_tunings)
  225. sort_ids = ex_tunings_plt.argsort()
  226. ex_tunings_plt = ex_tunings_plt[sort_ids]
  227. head_direction_indices = traj.results[run_name].head_direction_indices
  228. hdi_plt = head_direction_indices
  229. hdi_plt = hdi_plt[sort_ids]
  230. ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
  231. ax.legend()
  232. ax.set_xlabel("Angles (rad)")
  233. ax.set_ylabel("head direction index")
  234. ax.set_title('hdi over input tuning', fontsize=16)
  235. if save_figs:
  236. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png')
  237. plt.close(fig)
  238. def normal_labels(label):
  239. if label == POLARIZED:
  240. label = 'polar'
  241. elif label == NO_SYNAPSES:
  242. label = 'no interneurons'
  243. return label
  244. def short_labels(label):
  245. if label == POLARIZED:
  246. label = 'polar'
  247. elif label == CIRCULAR:
  248. label = 'circ.'
  249. elif label == "no conn":
  250. label = "no inh."
  251. return label
  252. def plot_input_map(traj, run_name, figsize=(panel_size, panel_size), figname='input_map.png'):
  253. n_ex = int(np.sqrt(traj.N_E))
  254. width, height = figsize
  255. fig = plt.figure(figsize=(width, height))
  256. axes = ImageGrid(fig, (0.15, 0.1, 0.7, 0.8), axes_pad=panel_size / 3.0, cbar_location="right", cbar_mode=None,
  257. nrows_ncols=(1, 1))
  258. ax = axes[0]
  259. traj.f_set_crun(run_name)
  260. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  261. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  262. scale_length = traj.morphology.long_axis * 2
  263. traj.f_restore_default()
  264. ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  265. for spine in ax.spines.values():
  266. spine.set_edgecolor("grey")
  267. spine.set_linewidth(0.5)
  268. remove_ticks(ax)
  269. start_scale_x = 100
  270. end_scale_x = start_scale_x + scale_length
  271. start_scale_y = -70
  272. end_scale_y = start_scale_y
  273. add_length_scale(ax, scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  274. ax.cax.set_visible(False)
  275. if save_figs:
  276. plt.savefig(FIGURE_SAVE_PATH + figname)
  277. plt.close(fig)
  278. def plot_axonal_clouds(traj, plot_run_names):
  279. n_ex = int(np.sqrt(traj.N_E))
  280. height = 1 * panel_size
  281. width = 3 * panel_size
  282. cluster_positions = [(250, 250), (750, 200), (450, 600)]
  283. cluster_sizes = [4, 4, 4]
  284. selected_neurons = []
  285. inhibitory_positions = traj.results.runs[plot_run_names[0]].inhibitory_axonal_cloud_array[:, :2]
  286. for cluster_position, number_of_neurons_in_cluster in zip(cluster_positions, cluster_sizes):
  287. selection = get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster,
  288. inhibitory_positions)
  289. selected_neurons.extend(selection)
  290. fig = plt.figure(figsize=(width, height))
  291. axes = ImageGrid(fig, 111, axes_pad=0.15, cbar_location="right", cbar_mode="single", cbar_size="7%",
  292. nrows_ncols=(1, 3))
  293. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2, 0, 1]]):
  294. traj.f_set_crun(run_name)
  295. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  296. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  297. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  298. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  299. inhibitory_axonal_cloud_array]
  300. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  301. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  302. # ax.set_title(normal_labels(label))
  303. ax.set_aspect('equal')
  304. remove_frame(ax)
  305. remove_ticks(ax)
  306. # fig.colorbar(c, ax=ax, label="Tuning")
  307. if label != NO_SYNAPSES and axonal_clouds is not None:
  308. for i, p in enumerate(axonal_clouds):
  309. ell = p.get_ellipse()
  310. if i in selected_neurons:
  311. edgecolor = 'black'
  312. alpha = 1
  313. zorder = 10
  314. linewidth = 1
  315. else:
  316. edgecolor = 'gray'
  317. alpha = 0.5
  318. zorder = 1
  319. linewidth = 0.3
  320. ell.set_edgecolor(edgecolor)
  321. ell.set_alpha(alpha)
  322. ell.set_zorder(zorder)
  323. ell.set_linewidth(linewidth)
  324. ax.add_artist(ell)
  325. traj.f_set_crun(plot_run_names[0])
  326. scale_length = traj.morphology.long_axis * 2
  327. traj.f_restore_default()
  328. start_scale_x = 100
  329. end_scale_x = start_scale_x + scale_length
  330. start_scale_y = -70
  331. end_scale_y = start_scale_y
  332. add_length_scale(axes[0], scale_length, start_scale_x, end_scale_x, start_scale_y, end_scale_y)
  333. axes[1].set_yticklabels([])
  334. axes[2].set_yticklabels([])
  335. axes[0].cax.set_visible(False)
  336. traj.f_restore_default()
  337. if save_figs:
  338. plt.savefig(FIGURE_SAVE_PATH + 'B_i_axonal_clouds.png')
  339. plt.close(fig)
  340. def get_neurons_close_to_given_position(cluster_position, number_of_neurons_in_cluster, positions):
  341. position = np.array(cluster_position)
  342. distance_vectors = positions - np.expand_dims(position, 0).repeat(positions.shape[0],
  343. axis=0)
  344. distances = np.linalg.norm(distance_vectors, axis=1)
  345. selection = list(np.argpartition(distances, number_of_neurons_in_cluster)[:number_of_neurons_in_cluster])
  346. return selection
  347. def plot_orientation_maps_diff_scales(traj):
  348. n_ex = int(np.sqrt(traj.N_E))
  349. scale_run_names = []
  350. plot_scales = [0.0, 100.0, 200.0, 300.0]
  351. for scale in plot_scales:
  352. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj, scale), 'long_axis': 100.}
  353. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  354. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  355. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  356. traj.f_set_crun(run_name)
  357. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  358. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  359. # TODO: Why was this transposed for plotting? (now changed)
  360. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  361. ax.set_title('Correlation length: {}'.format(scale))
  362. fig.colorbar(c, ax=ax, label="Tuning")
  363. # fig.suptitle('axonal cloud', fontsize=16)
  364. traj.f_restore_default()
  365. if save_figs:
  366. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png')
  367. plt.close(fig)
  368. def plot_orientation_maps_diff_scales_with_ellipse(traj):
  369. n_ex = int(np.sqrt(traj.N_E))
  370. scale_run_names = []
  371. plot_scales = [0.0, 400.0, 800.0]
  372. real_scales = [get_closest_correlation_length(traj, scale) for scale in plot_scales]
  373. for scale in plot_scales:
  374. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj, scale), 'long_axis': 100.}
  375. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  376. print(scale_run_names)
  377. inhibitory_positions = traj.results.runs[scale_run_names[1]].inhibitory_axonal_cloud_array[:, :2]
  378. selected_polar_neuron = get_neurons_close_to_given_position((300, 300), 1, inhibitory_positions)[0]
  379. print(selected_polar_neuron)
  380. selected_circular_neuron = get_neurons_close_to_given_position((600, 600), 1, inhibitory_positions)[0]
  381. width = panel_size
  382. height = 1.5 * panel_size
  383. fig = plt.figure(figsize=(width, height))
  384. axes = ImageGrid(fig, 111, axes_pad=0.15,
  385. nrows_ncols=(len(plot_scales), 1))
  386. for ax, run_name, scale in zip(axes, scale_run_names, real_scales):
  387. traj.f_set_crun(run_name)
  388. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  389. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  390. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  391. inhibitory_axonal_cloud_array]
  392. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  393. # TODO: Why was this transposed for plotting? (now changed)
  394. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap=head_direction_input_colormap)
  395. # ax.set_title('Correlation length: {}'.format(scale))
  396. # fig.colorbar(c, ax=ax, label="Tuning")
  397. ax.set_xticks([])
  398. ax.set_yticks([])
  399. p1 = axonal_clouds[selected_polar_neuron]
  400. ell = p1.get_ellipse()
  401. ell._linewidth = 2.
  402. ax.add_artist(ell)
  403. p2 = axonal_clouds[selected_circular_neuron]
  404. circ_r = 2 * np.sqrt(p2.a * p2.b)
  405. circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
  406. circ._linewidth = 2.
  407. ax.add_artist(circ)
  408. ax.annotate("{:.0f}".format(scale), xy=(1.0, 0.5), xytext=(2, 0), xycoords="axes fraction", textcoords="offset "
  409. "points",
  410. va="center", ha="left")
  411. remove_frame(ax)
  412. # fig.suptitle('axonal cloud', fontsize=16)
  413. traj.f_restore_default()
  414. add_length_scale(axes[1], 200, -300, -100, 50, 50)
  415. axes[0].annotate("input maps", xy=(-0.5, 1.0), xycoords="axes fraction", rotation=90, ha="left", va="top")
  416. fig.tight_layout()
  417. if save_figs:
  418. plt.savefig(FIGURE_SAVE_PATH + 'F_orientation_maps_diff_scales_with_ellipse.png')
  419. plt.close(fig)
  420. def plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_idx,
  421. figname=FIGURE_SAVE_PATH + 'D_polar_plot_excitatory.png'):
  422. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  423. directions_plt = list(directions)
  424. directions_plt.append(directions[0])
  425. height = panel_size
  426. width = panel_size
  427. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  428. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  429. # sorted_ids = np.argsort(head_direction_indices)
  430. # plot_n_idx = sorted_ids[-75]
  431. plot_n_idx = selected_neuron_idx
  432. line_styles = ['dotted', 'solid', 'dashed']
  433. colors = ['r', 'lightsalmon', 'grey']
  434. line_widths = [1.5, 1.5, 1]
  435. zorders = [10, 2, 1]
  436. max_rate = 0.0
  437. for run_idx, run_name in enumerate(plot_run_names):
  438. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  439. hdi = traj.results.runs[run_name].head_direction_indices[selected_neuron_idx]
  440. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  441. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  442. run_max_rate = np.max(rate_plot)
  443. if run_max_rate > max_rate:
  444. max_rate = run_max_rate
  445. rate_plot.append(rate_plot[0])
  446. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  447. label='{:s} {:.2f}'.format(short_labels(label), hdi),
  448. color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  449. # ax.set_title('Firing Rate')
  450. # ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
  451. # TODO: Set ticks for polar
  452. ticks = [40., 80.]
  453. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  454. enumerate(ticks)], angle=60)
  455. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  456. ax.xaxis.grid(linewidth=0.4)
  457. ax.yaxis.grid(linewidth=0.4)
  458. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.15), handlelength=1, fontsize="medium")
  459. leg.get_frame().set_linewidth(0.0)
  460. ax.axes.spines["polar"].set_visible(False)
  461. if save_figs:
  462. plt.savefig(figname)
  463. plt.close(fig)
  464. def plot_polar_plot_inhibitory(traj, plot_run_names, selected_neuron_idx, figname=FIGURE_SAVE_PATH +
  465. 'D_polar_plot_inhibitory.png'):
  466. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  467. directions_plt = list(directions)
  468. directions_plt.append(directions[0])
  469. height = panel_size
  470. width = panel_size
  471. fig, ax = plt.subplots(1, 1, figsize=(height, width), subplot_kw=dict(projection='polar'))
  472. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  473. # sorted_ids = np.argsort(head_direction_indices)
  474. # plot_n_idx = sorted_ids[-75]
  475. plot_n_idx = selected_neuron_idx
  476. line_styles = ['dotted', 'solid']
  477. colors = ['b', 'lightblue']
  478. line_widths = [1.5, 1.5]
  479. zorders = [10, 2]
  480. for run_idx, run_name in enumerate(plot_run_names[:2]):
  481. # ax = axes[max_hdi_idx, run_idx]
  482. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  483. hdi = traj.results.runs[run_name].inh_head_direction_indices[selected_neuron_idx]
  484. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  485. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  486. rate_plot.append(rate_plot[0])
  487. ax.plot(directions_plt, rate_plot, linewidth=line_widths[run_idx],
  488. label='{:s} {:.2f}'.format(short_labels(label), hdi),
  489. color=colors[run_idx], linestyle=line_styles[run_idx], zorder=zorders[run_idx])
  490. # ax.set_title('Inh. Firing Rate')
  491. # TODO: Set ticks for polar
  492. # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
  493. ticks = [40., 80., 120.]
  494. ax.set_rgrids(ticks, labels=["{:.0f} Hz".format(ticklabel) if idx == len(ticks) - 1 else "" for idx, ticklabel in
  495. enumerate(ticks)], angle=60)
  496. ax.set_thetagrids([0, 90, 180, 270], labels=[])
  497. ax.xaxis.grid(linewidth=0.4)
  498. ax.yaxis.grid(linewidth=0.4)
  499. leg = ax.legend(loc="lower right", bbox_to_anchor=(1.15, -0.15), handlelength=1, fontsize="medium")
  500. leg.get_frame().set_linewidth(0.0)
  501. ax.axes.spines["polar"].set_visible(False)
  502. if save_figs:
  503. plt.savefig(figname)
  504. plt.close(fig)
  505. def plot_hdi_over_corr_len(traj, plot_run_names):
  506. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  507. seed_expl = traj.f_get('seed').f_get_range()
  508. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  509. label_range = set(label_expl)
  510. hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  511. hdi_frame.index.names = ["corr_len", "seed", "label"]
  512. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  513. ex_tunings = traj.results.runs[run_name].ex_tunings
  514. head_direction_indices = traj.results[run_name].head_direction_indices
  515. hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  516. # TODO: Standart deviation also for the population
  517. hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
  518. hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
  519. # Ellipsoid markers
  520. rx, ry = 5., 12.
  521. # area = rx * ry * np.pi * 2.
  522. area = 1.
  523. theta = np.arange(0, 2 * np.pi + 0.01, 0.1)
  524. verts = np.column_stack([rx / area * np.cos(theta), ry / area * np.sin(theta)])
  525. style_dict = {
  526. NO_SYNAPSES: ['grey', 'dashed', '', 0],
  527. POLARIZED: ['blue', 'solid', verts, 10.],
  528. CIRCULAR: ['lightblue', 'solid', 'o', 8.]
  529. }
  530. # colors = ['blue', 'grey', 'lightblue']
  531. # linestyles = ['solid', 'dashed', 'solid']
  532. # markers = [verts, '', 'o']
  533. fig, ax = plt.subplots(1, 1)
  534. for label in label_range:
  535. hdi_mean = hdi_exc_n_and_seed_mean[:, label]
  536. hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
  537. corr_len_range = hdi_mean.keys().to_numpy()
  538. col, lin, mar, mar_size = style_dict[label]
  539. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  540. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  541. hdi_mean + hdi_std, alpha=0.4, color=col)
  542. ax.set_xlabel('Correlation length')
  543. ax.set_ylabel('Head Direction Index')
  544. ax.axvline(206.9, color='k', linewidth=0.5)
  545. ax.set_ylim(0.0, 1.0)
  546. ax.set_xlim(0.0, 400.)
  547. ax.legend()
  548. if save_figs:
  549. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len_scaled.png', dpi=200)
  550. plt.close(fig)
  551. def plot_hdi_histogram_excitatory(traj, plot_run_names):
  552. labels = []
  553. hdis = []
  554. colors = ['black', 'red', 'green']
  555. for run_idx, run_name in enumerate(plot_run_names):
  556. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  557. labels.append(label)
  558. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  559. hdis.append(head_direction_indices)
  560. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  561. ax.hist(hdis, color=colors, label=labels, bins=30)
  562. for hdi, color in zip(hdis, colors):
  563. mean_hdi = np.mean(hdi)
  564. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  565. ax.set_xlabel("HDI")
  566. ax.legend()
  567. fig.tight_layout()
  568. if save_figs:
  569. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
  570. plt.close(fig)
  571. def plot_hdi_violin_excitatory(traj, plot_run_names):
  572. labels = []
  573. hdis = []
  574. colors = ['black', 'red', 'green']
  575. no_conn_hdi = 0.
  576. for run_idx, run_name in enumerate(plot_run_names):
  577. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  578. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  579. if label == NO_SYNAPSES:
  580. no_conn_hdi = np.mean(head_direction_indices)
  581. else:
  582. labels.append(label)
  583. hdis.append(sorted(head_direction_indices))
  584. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  585. # hdis = np.array(hdis)
  586. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  587. viol_plt['cmeans'].set_color('black')
  588. for pc in viol_plt['bodies']:
  589. pc.set_facecolor('red')
  590. pc.set_edgecolor('black')
  591. pc.set_alpha(0.7)
  592. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  593. ax.annotate(NO_SYNAPSES, xy=(0.45, 0.48), xycoords='axes fraction')
  594. ax.set_xticks(np.arange(1, len(labels) + 1))
  595. ax.set_xticklabels(labels)
  596. ax.set_ylabel('HDI')
  597. fig.tight_layout()
  598. if save_figs:
  599. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_excitatory.png', dpi=200)
  600. plt.close(fig)
  601. def plot_hdi_violin_inhibitory(traj, plot_run_names):
  602. labels = []
  603. hdis = []
  604. colors = ['black', 'red']
  605. for run_idx, run_name in enumerate(plot_run_names):
  606. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  607. if label != NO_SYNAPSES:
  608. labels.append(label)
  609. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  610. hdis.append(sorted(head_direction_indices))
  611. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  612. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  613. viol_plt['cmeans'].set_color('black')
  614. for pc in viol_plt['bodies']:
  615. pc.set_facecolor('blue')
  616. pc.set_edgecolor('black')
  617. pc.set_alpha(0.7)
  618. ax.set_xticks(np.arange(1, len(labels) + 1))
  619. ax.set_xticklabels(labels)
  620. ax.set_ylabel('HDI')
  621. fig.tight_layout()
  622. if save_figs:
  623. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_inhibitory.png', dpi=200)
  624. plt.close(fig)
  625. def plot_hdi_violin_combined(traj, plot_run_names):
  626. labels = []
  627. inh_hdis = []
  628. exc_hdis = []
  629. no_conn_hdi = 0.
  630. colors = ['black', 'red']
  631. for run_idx, run_name in enumerate(plot_run_names):
  632. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  633. if label != NO_SYNAPSES:
  634. labels.append(label)
  635. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  636. inh_hdis.append(sorted(inh_head_direction_indices))
  637. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  638. exc_hdis.append(sorted(exc_head_direction_indices))
  639. else:
  640. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  641. no_conn_hdi = np.mean(exc_head_direction_indices)
  642. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  643. inh_viol_plt = ax.violinplot(inh_hdis, showmeans=True, showextrema=False)
  644. # viol_plt['cmeans'].set_color('black')
  645. #
  646. # for pc in viol_plt['bodies']:
  647. # pc.set_facecolor('blue')
  648. # pc.set_edgecolor('black')
  649. # pc.set_alpha(0.7)
  650. for b in inh_viol_plt['bodies']:
  651. m = np.mean(b.get_paths()[0].vertices[:, 0])
  652. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  653. b.set_color('b')
  654. exc_viol_plt = ax.violinplot(exc_hdis, showmeans=True, showextrema=False)
  655. for b in exc_viol_plt['bodies']:
  656. m = np.mean(b.get_paths()[0].vertices[:, 0])
  657. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  658. b.set_color('r')
  659. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  660. ax.annotate(NO_SYNAPSES, xy=(0.45, 0.48), xycoords='axes fraction')
  661. ax.set_xticks(np.arange(1, len(labels) + 1))
  662. ax.set_xticklabels(labels)
  663. ax.set_ylabel('HDI')
  664. fig.tight_layout()
  665. if save_figs:
  666. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined.svg', dpi=200)
  667. plt.close(fig)
  668. def plot_hdi_violin_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id):
  669. labels = []
  670. inh_hdis = []
  671. exc_hdis = []
  672. no_conn_hdi = 0.
  673. in_polar_plot_hdi = []
  674. ex_polar_plot_hdi = []
  675. colors = ['black', 'red']
  676. for run_idx, run_name in enumerate(plot_run_names):
  677. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  678. if label != NO_SYNAPSES:
  679. labels.append(label)
  680. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  681. inh_hdis.append(sorted(inh_head_direction_indices))
  682. in_polar_plot_hdi.append(inh_head_direction_indices[in_polar_plot_id])
  683. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  684. exc_hdis.append(sorted(exc_head_direction_indices))
  685. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  686. else:
  687. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  688. no_conn_hdi = np.mean(exc_head_direction_indices)
  689. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  690. fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.5))
  691. inh_ell_viol_plt = ax.violinplot(inh_hdis[0], showmeans=True, showextrema=False)
  692. for b in inh_ell_viol_plt['bodies']:
  693. m = np.mean(b.get_paths()[0].vertices[:, 0])
  694. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  695. b.set_color('b')
  696. mean_line = inh_ell_viol_plt['cmeans']
  697. mean_line.set_color('b')
  698. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  699. exc_ell_viol_plt = ax.violinplot(exc_hdis[0], showmeans=True, showextrema=False)
  700. for b in exc_ell_viol_plt['bodies']:
  701. m = np.mean(b.get_paths()[0].vertices[:, 0])
  702. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  703. b.set_color('r')
  704. mean_line = exc_ell_viol_plt['cmeans']
  705. mean_line.set_color('r')
  706. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  707. inh_cir_viol_plt = ax.violinplot(inh_hdis[1], showmeans=True, showextrema=False)
  708. for b in inh_cir_viol_plt['bodies']:
  709. m = np.mean(b.get_paths()[0].vertices[:, 0])
  710. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  711. b.set_color('b')
  712. mean_line = inh_cir_viol_plt['cmeans']
  713. mean_line.set_color('b')
  714. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  715. exc_cir_viol_plt = ax.violinplot(exc_hdis[1], showmeans=True, showextrema=False)
  716. for b in exc_cir_viol_plt['bodies']:
  717. m = np.mean(b.get_paths()[0].vertices[:, 0])
  718. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  719. b.set_color('r')
  720. mean_line = exc_cir_viol_plt['cmeans']
  721. mean_line.set_color('r')
  722. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  723. ax.axhline(no_conn_hdi, 0.5, 1., color='black', linestyle='--')
  724. ax.axvline(1.0, color='k')
  725. ax.annotate(NO_SYNAPSES, xy=(0.75, 0.415), xycoords='axes fraction')
  726. ax.set_xlim(0.5, 1.5)
  727. ax.set_ylim(0.0, 1.0)
  728. ax.set_xticks([0.75, 1.25])
  729. ax.set_xticklabels([CIRCULAR, POLARIZED])
  730. ax.set_ylabel('HDI')
  731. fig.tight_layout()
  732. if save_figs:
  733. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
  734. plt.close(fig)
  735. return ex_polar_plot_hdi, in_polar_plot_hdi
  736. def plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id, cut_off_dist):
  737. labels = []
  738. inh_hdis = []
  739. exc_hdis = []
  740. no_conn_hdi = 0.
  741. for run_idx, run_name in enumerate(plot_run_names):
  742. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  743. if label != NO_SYNAPSES:
  744. labels.append(normal_labels(label))
  745. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  746. inh_axonal_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  747. inh_cut_off_ids = (inh_axonal_cloud[:, 0] >= cut_off_dist) & \
  748. (inh_axonal_cloud[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  749. (inh_axonal_cloud[:, 1] >= cut_off_dist) & \
  750. (inh_axonal_cloud[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  751. # print(inh_positions)
  752. inh_hdis.append(sorted(inh_head_direction_indices[inh_cut_off_ids]))
  753. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  754. ex_positions = traj.results.runs[run_name].ex_positions
  755. # print(ex_positions)
  756. exc_cut_off_ids = (ex_positions[:, 0] >= cut_off_dist) & \
  757. (ex_positions[:, 0] <= traj.parameters.map.sheet_size - cut_off_dist) & \
  758. (ex_positions[:, 1] >= cut_off_dist) & \
  759. (ex_positions[:, 1] <= traj.parameters.map.sheet_size - cut_off_dist)
  760. exc_hdis.append(sorted(exc_head_direction_indices[exc_cut_off_ids]))
  761. else:
  762. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  763. no_conn_hdi = np.mean(exc_head_direction_indices)
  764. # Look for a representative excitatory neuron
  765. hdi_mean_dict = {}
  766. excitatory_hdi_means = [np.mean(hdis) for hdis in exc_hdis]
  767. hdi_mean_dict["polar_exc"] = excitatory_hdi_means[0]
  768. hdi_mean_dict["circular_exc"] = excitatory_hdi_means[1]
  769. inhibitory_hdi_means = [np.mean(hdis) for hdis in inh_hdis]
  770. hdi_mean_dict["polar_inh"] = inhibitory_hdi_means[0]
  771. hdi_mean_dict["circular_inh"] = inhibitory_hdi_means[1]
  772. # # fig = plt.figure(figsize=(3.5, 3.5))
  773. # # gs1 = gridspec.GridSpec(1, 2)
  774. #
  775. # fig = plt.figure(constrained_layout=True)
  776. # gs = gridspec.GridSpec(ncols=1, nrows=2, hspace=0.0, wspace=0.0, figure=fig)
  777. # # gs.update(wspace=0.0, hspace=0.0) # set the spacing between axes.
  778. # # f2_ax1 = fig.add_subplot(gs[0, 0])
  779. # # f2_ax2 = fig.add_subplot(gs[0, 1])
  780. # print(gs.get_subplot_params())
  781. width = 2 * panel_size
  782. height = 1.2 * panel_size
  783. fig, axes = plt.subplots(2, 1, figsize=(width, height))
  784. plt.subplots_adjust(wspace=0, hspace=0.1)
  785. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  786. max_density = 0
  787. for i in range(2):
  788. # i = i + 1 # grid spec indexes from 0
  789. # ax = fig.add_subplot(gs[i])
  790. ax = axes[i]
  791. density_e, _, _ = ax.hist(exc_hdis[i], color='r', edgecolor='r', alpha=0.3, bins=bins, density=True)
  792. density_i, _, _ = ax.hist(inh_hdis[i], color='b', edgecolor='b', alpha=0.3, bins=bins, density=True)
  793. max_density = np.max([max_density, np.max(density_e), np.max(density_i)])
  794. ax.axvline(np.mean(exc_hdis[i]), color='r')
  795. ax.axvline(np.mean(inh_hdis[i]), color='b')
  796. ax.axvline(no_conn_hdi, color='dimgrey', linestyle='--', linewidth=1.5)
  797. ax.set_ylabel(labels[i], rotation='vertical')
  798. # plt.axis('on')
  799. if i == 0:
  800. ax.set_xticklabels([])
  801. else:
  802. ax.set_xlabel('head direction index')
  803. remove_frame(ax, ["top", "right", "bottom"])
  804. max_density = 1.2 * max_density
  805. fig.subplots_adjust(left=0.2, right=0.95, bottom=0.2)
  806. axes[0].annotate('% cells', (0, 1.0), xycoords='axes fraction', va="bottom", ha="right")
  807. axes[1].annotate("no ihn.\n{:.2f}".format(no_conn_hdi), xy=(no_conn_hdi, max_density),
  808. xytext=(-2, 0), xycoords="data",
  809. textcoords="offset points",
  810. va="top", ha="right", color="dimgrey")
  811. for i, ax in enumerate(axes):
  812. ax.annotate("{:.2f}".format(np.mean(exc_hdis[i])), xy=(np.mean(exc_hdis[i]), max_density),
  813. xytext=(2, 0), xycoords="data",
  814. textcoords="offset points",
  815. va="top", ha="left", color="r")
  816. # i_ha = "left" if i == 1 else "right"
  817. # i_offset = 2 if i == 1 else -2
  818. i_ha = "right"
  819. i_offset = -1
  820. ax.annotate("{:.2f}".format(np.mean(inh_hdis[i])), xy=(np.mean(inh_hdis[i]), max_density),
  821. xytext=(i_offset, 0), xycoords="data",
  822. textcoords="offset points",
  823. va="top", ha=i_ha, color="b")
  824. for ax in axes:
  825. ax.set_ylim(0, max_density)
  826. # plt.annotate('probability density', (-0.2,1.5), xycoords='axes fraction', rotation=90, fontsize=18)
  827. if save_figs:
  828. plt.savefig(FIGURE_SAVE_PATH + 'E_hdi_histogram_combined_and_overlayed_cutoff_{}um.png'.format(cut_off_dist))
  829. plt.close(fig)
  830. return hdi_mean_dict
  831. def get_neurons_with_given_hdi(polar_hdi, circular_hdi, max_number_of_suggestions, plot_run_names, traj, type):
  832. polar_run_name = plot_run_names[0]
  833. circular_run_name = plot_run_names[1]
  834. polar_ex_hdis = traj.results.runs[polar_run_name].head_direction_indices if type == "ex" else traj.results.runs[
  835. polar_run_name].inh_head_direction_indices
  836. circular_ex_hdis = traj.results.runs[circular_run_name].head_direction_indices if type == "ex" else \
  837. traj.results.runs[
  838. polar_run_name].inh_head_direction_indices
  839. neuron_indices = get_indices_of_closest_values(polar_ex_hdis, polar_hdi,
  840. circular_ex_hdis,
  841. circular_hdi, 0.1 * np.abs(
  842. polar_hdi - circular_hdi), max_number_of_suggestions)
  843. return neuron_indices
  844. def get_indices_of_closest_values(first_list, first_value, second_list, second_value, absolute_tolerance_list_one,
  845. number_of_indices):
  846. is_close_in_list_one = np.abs(first_list - first_value) < absolute_tolerance_list_one
  847. indices_close_in_list_one = np.where(is_close_in_list_one)[0]
  848. indices_closest_in_list_two = indices_close_in_list_one[np.argpartition(np.abs(second_list[
  849. indices_close_in_list_one] - second_value),
  850. number_of_indices)]
  851. return indices_closest_in_list_two[:number_of_indices]
  852. def plot_hdi_histogram_inhibitory(traj, plot_run_names, in_polar_plot_id):
  853. labels = []
  854. hdis = []
  855. colors = ['black', 'red']
  856. for run_idx, run_name in enumerate(plot_run_names):
  857. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  858. if label != NO_SYNAPSES:
  859. labels.append(label)
  860. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  861. print('inh {}: {}'.format(label, head_direction_indices[in_polar_plot_id]))
  862. hdis.append(head_direction_indices)
  863. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  864. ax.hist(hdis, color=colors, label=labels, bins=30)
  865. for hdi, color in zip(hdis, colors):
  866. mean_hdi = np.mean(hdi)
  867. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  868. ax.set_xlabel("HDI")
  869. ax.legend()
  870. fig.tight_layout()
  871. if save_figs:
  872. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
  873. plt.close(fig)
  874. def filter_run_names_by_par_dict(traj, par_dict):
  875. run_name_list = []
  876. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  877. traj.f_set_crun(run_name)
  878. paramters_equal = True
  879. for key, val in par_dict.items():
  880. if (traj.par[key] != val):
  881. paramters_equal = False
  882. if paramters_equal:
  883. run_name_list.append(run_name)
  884. traj.f_restore_default()
  885. return run_name_list
  886. def plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, plot_run_names):
  887. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  888. seed_expl = traj.f_get('seed').f_get_range()
  889. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  890. label_range = set(label_expl)
  891. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  892. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  893. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  894. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  895. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  896. ex_tunings = traj.results.runs[run_name].ex_tunings
  897. head_direction_indices = traj.results[run_name].head_direction_indices
  898. # TODO: Actual correlation lengths
  899. # actual_corr_len = get_correlation_length(ex_tunings.reshape((30,30)), 450, 30)
  900. exc_hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  901. inh_head_direction_indices = traj.results[run_name].inh_head_direction_indices
  902. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_head_direction_indices)
  903. # TODO: Standard deviation also for the population
  904. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  905. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  906. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  907. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  908. markersize = 4.
  909. exc_style_dict = {
  910. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  911. POLARIZED: ['red', 'solid', '^', markersize],
  912. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  913. }
  914. inh_style_dict = {
  915. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  916. POLARIZED: ['blue', 'solid', 'o', markersize],
  917. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  918. }
  919. # colors = ['blue', 'grey', 'lightblue']
  920. # linestyles = ['solid', 'dashed', 'solid']
  921. # markers = [verts, '', 'o']
  922. # corr_len_fit_dict = correlation_length_fit_dict(traj, load=True)
  923. width = 2 * panel_size
  924. height = 1.2 * panel_size
  925. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  926. for label in sorted(label_range, reverse=True):
  927. if label == NO_SYNAPSES:
  928. no_conn_hdi = exc_hdi_n_and_seed_mean[1, label]
  929. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  930. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  931. textcoords="offset points",
  932. va="top", \
  933. ha="right",
  934. color="dimgrey")
  935. continue
  936. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  937. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  938. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  939. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  940. corr_len_range = exc_hdi_mean.keys().to_numpy()
  941. print(label)
  942. for corr_len, ex_hdi, in_hdi in zip(corr_len_range, exc_hdi_mean, inh_hdi_mean):
  943. print("length: {:.2f} um, ex hdi: {:.2f} and in hdi {:.2f}".format(corr_len, ex_hdi, in_hdi))
  944. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  945. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  946. simplex_grid_scale = corr_len_range * np.sqrt(2)
  947. ax.plot(simplex_grid_scale, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  948. markersize=exc_mar_size, alpha=0.5)
  949. plt.fill_between(simplex_grid_scale, exc_hdi_mean - exc_hdi_std,
  950. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  951. ax.plot(simplex_grid_scale, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  952. markersize=inh_mar_size, alpha=0.5)
  953. plt.fill_between(simplex_grid_scale, inh_hdi_mean - inh_hdi_std,
  954. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  955. ax.set_xlabel('simplex grid scale')
  956. ax.set_ylabel('head direction index')
  957. ax.axvline(get_closest_correlation_length(traj, 200.0) * np.sqrt(2), color='k', linewidth=0.5, zorder=0)
  958. ax.set_ylim(0.0, 1.0)
  959. # ax.set_xlim(0.0, np.max(corr_len_range))
  960. remove_frame(ax, ["right", "top"])
  961. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  962. row_labels=None,
  963. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  964. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  965. fig.subplots_adjust(bottom=0.2, left=0.2)
  966. # plt.legend()
  967. if save_figs:
  968. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_grid_scale.png')
  969. plt.close(fig)
  970. def plot_exc_and_inh_hdi_over_fit_corr_len(traj, plot_run_names):
  971. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  972. seed_expl = traj.f_get('seed').f_get_range()
  973. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  974. label_range = set(label_expl)
  975. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  976. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  977. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  978. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  979. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  980. ex_tunings = traj.results.runs[run_name].ex_tunings
  981. head_direction_indices = traj.results[run_name].head_direction_indices
  982. # TODO: Actual correlation lengths
  983. # actual_corr_len = get_correlation_length(ex_tunings.reshape((30,30)), 450, 30)
  984. exc_hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  985. inh_head_direction_indices = traj.results[run_name].inh_head_direction_indices
  986. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_head_direction_indices)
  987. # TODO: Standard deviation also for the population
  988. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  989. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  990. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  991. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  992. markersize = 4.
  993. exc_style_dict = {
  994. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  995. POLARIZED: ['red', 'solid', '^', markersize],
  996. CIRCULAR: ['lightsalmon', 'solid', '^', markersize]
  997. }
  998. inh_style_dict = {
  999. NO_SYNAPSES: ['dimgrey', 'dashed', '', 0],
  1000. POLARIZED: ['blue', 'solid', 'o', markersize],
  1001. CIRCULAR: ['lightblue', 'solid', 'o', markersize]
  1002. }
  1003. # colors = ['blue', 'grey', 'lightblue']
  1004. # linestyles = ['solid', 'dashed', 'solid']
  1005. # markers = [verts, '', 'o']
  1006. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='pinwheel', load=True)
  1007. width = 2 * panel_size
  1008. height = 1.2 * panel_size
  1009. fig, ax = plt.subplots(1, 1, figsize=(width, height))
  1010. for label in sorted(label_range, reverse=True):
  1011. if label == NO_SYNAPSES:
  1012. no_conn_hdi = exc_hdi_n_and_seed_mean[1, label]
  1013. ax.axhline(no_conn_hdi, color='grey', linestyle='--')
  1014. ax.annotate(short_labels(label), xy=(1.0, no_conn_hdi), xytext=(0, -2), xycoords='axes fraction',
  1015. textcoords="offset points",
  1016. va="top", \
  1017. ha="right",
  1018. color="dimgrey")
  1019. continue
  1020. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  1021. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  1022. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  1023. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  1024. corr_len_range = exc_hdi_mean.keys().to_numpy()
  1025. print(label)
  1026. for corr_len, ex_hdi, in_hdi in zip(corr_len_range, exc_hdi_mean, inh_hdi_mean):
  1027. print("length: {:.2f} um, ex hdi: {:.2f} and in hdi {:.2f}".format(corr_len, ex_hdi, in_hdi))
  1028. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  1029. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  1030. fit_corr_len = [corr_len_fit_dict[corr_len] for corr_len in corr_len_range]
  1031. print(corr_len_range)
  1032. print(fit_corr_len)
  1033. ax.plot(fit_corr_len, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin,
  1034. markersize=exc_mar_size, alpha=0.5)
  1035. plt.fill_between(fit_corr_len, exc_hdi_mean - exc_hdi_std,
  1036. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  1037. ax.plot(fit_corr_len, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin,
  1038. markersize=inh_mar_size, alpha=0.5)
  1039. plt.fill_between(fit_corr_len, inh_hdi_mean - inh_hdi_std,
  1040. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  1041. ax.set_xlabel('correlation length')
  1042. ax.set_ylabel('head direction index')
  1043. ax.axvline(corr_len_fit_dict[get_closest_correlation_length(traj, 200.0)], color='k', linewidth=0.5, zorder=0)
  1044. ax.set_ylim(0.0, 1.0)
  1045. # ax.set_xlim(0.0, np.max(corr_len_range))
  1046. remove_frame(ax, ["right", "top"])
  1047. tablelegend(ax, ncol=2, bbox_to_anchor=(1.1, 1.1), loc="upper right",
  1048. row_labels=None,
  1049. col_labels=[short_labels(label) for label in sorted(label_range - {"no conn"}, reverse=True)],
  1050. title_label='', borderaxespad=0, handlelength=2, edgecolor='white')
  1051. fig.subplots_adjust(bottom=0.2, left=0.2)
  1052. # plt.legend()
  1053. if save_figs:
  1054. plt.savefig(FIGURE_SAVE_PATH + 'F_hdi_over_corr_len_scaled.png')
  1055. plt.close(fig)
  1056. def plot_in_degree_map(traj, plot_run_names):
  1057. n_ex = int(np.sqrt(traj.N_E))
  1058. max_degree = 0
  1059. for run_name in plot_run_names:
  1060. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  1061. exc_degree = np.sum(ie_adjacency, axis=0)
  1062. run_max_degree = np.max(exc_degree)
  1063. if run_max_degree > max_degree:
  1064. max_degree = run_max_degree
  1065. fig, axes = plt.subplots(1, 2, figsize=(9., 4.5))
  1066. for ax, run_name in zip(axes, plot_run_names[:-1]):
  1067. traj.f_set_crun(run_name)
  1068. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1069. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  1070. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  1071. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  1072. exc_degree = np.sum(ie_adjacency, axis=0)
  1073. c = ax.pcolor(X, Y, np.reshape(exc_degree, (number_of_excitatory_neurons_per_row,
  1074. number_of_excitatory_neurons_per_row)), vmin=0, vmax=max_degree,
  1075. cmap='hot')
  1076. ax.set_title(label)
  1077. fig.colorbar(c, ax=ax, label="in/out-degree")
  1078. fig.suptitle('in/out-degree', fontsize=16)
  1079. traj.f_restore_default()
  1080. if save_figs:
  1081. plt.savefig(FIGURE_SAVE_PATH + 'in_degree_map.png', dpi=200)
  1082. plt.close(fig)
  1083. def plot_spatial_hdi_map(traj, plot_run_names):
  1084. max_val = 0
  1085. for run_name in plot_run_names:
  1086. hdis = traj.results.runs[run_name].head_direction_indices
  1087. run_max_val = np.max(hdis)
  1088. if run_max_val > max_val:
  1089. max_val = run_max_val
  1090. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  1091. for ax, run_name in zip(axes, plot_run_names):
  1092. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1093. positions = traj.results.runs[run_name].ex_positions
  1094. head_direction_indices = traj.results[run_name].head_direction_indices
  1095. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  1096. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  1097. ax.set_title(label)
  1098. fig.colorbar(c, ax=ax, label="head direction index")
  1099. fig.suptitle('spatial HDI map', fontsize=16)
  1100. if save_figs:
  1101. plt.savefig(FIGURE_SAVE_PATH + 'spatial_hdi_map.png', dpi=200)
  1102. plt.close(fig)
  1103. def plot_exc_spatial_hdi_map(traj, plot_run_names):
  1104. max_val = 0
  1105. for run_name in plot_run_names:
  1106. hdis = traj.results.runs[run_name].head_direction_indices
  1107. run_max_val = np.max(hdis)
  1108. if run_max_val > max_val:
  1109. max_val = run_max_val
  1110. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  1111. for ax, run_name in zip(axes, plot_run_names):
  1112. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1113. positions = traj.results.runs[run_name].ex_positions
  1114. head_direction_indices = traj.results[run_name].head_direction_indices
  1115. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  1116. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  1117. ax.set_title(label)
  1118. fig.colorbar(c, ax=ax, label="head direction index")
  1119. fig.suptitle('spatial exc. HDI map', fontsize=16)
  1120. if save_figs:
  1121. plt.savefig(FIGURE_SAVE_PATH + 'spatial_exc_hdi_map.png', dpi=200)
  1122. plt.close(fig)
  1123. def plot_inh_spatial_hdi_map(traj, plot_run_names):
  1124. max_val = 0
  1125. for run_name in plot_run_names:
  1126. hdis = traj.results.runs[run_name].inh_head_direction_indices
  1127. run_max_val = np.max(hdis)
  1128. if run_max_val > max_val:
  1129. max_val = run_max_val
  1130. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  1131. for ax, run_name in zip(axes, plot_run_names):
  1132. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1133. ax_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  1134. positions = [[x, y] for x, y, phi in ax_cloud]
  1135. head_direction_indices = traj.results[run_name].inh_head_direction_indices
  1136. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  1137. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  1138. ax.set_title(label)
  1139. fig.colorbar(c, ax=ax, label="head direction index")
  1140. fig.suptitle('spatial inh. HDI map', fontsize=16)
  1141. if save_figs:
  1142. plt.savefig(FIGURE_SAVE_PATH + 'spatial_inh_hdi_map.png', dpi=200)
  1143. plt.close(fig)
  1144. def get_phase_difference(total_difference):
  1145. """
  1146. Map accumulated phase difference to shortest possible difference
  1147. :param total_difference:
  1148. :return: relative_difference
  1149. """
  1150. return (total_difference + np.pi) % (2 * np.pi) - np.pi
  1151. def plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names):
  1152. # The plot that Imre wanted
  1153. n_bins = traj.parameters.input.number_of_directions
  1154. fig, ax = plt.subplots(1, 1, figsize=(9, 9))
  1155. dir_bins = np.linspace(-np.pi, np.pi, n_bins + 1)
  1156. print(dir_bins)
  1157. plot_fr_array = []
  1158. labels = []
  1159. similarity_threshold = np.pi / 6.
  1160. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  1161. for run_idx, run_name in enumerate(plot_run_names):
  1162. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  1163. labels.append(label)
  1164. fr_similar_tunings = []
  1165. fr_different_tunings = []
  1166. ex_tunings = traj.results.runs[run_name].ex_tunings
  1167. firing_rate_array = traj.results[run_name].firing_rate_array
  1168. for tuning, firing_rates in zip(ex_tunings, firing_rate_array):
  1169. for idx, dir in enumerate(directions):
  1170. if np.abs(get_phase_difference(tuning - dir)) <= similarity_threshold:
  1171. fr_similar_tunings.append(firing_rates[idx])
  1172. elif np.abs(get_phase_difference(tuning + np.pi - dir)) <= similarity_threshold:
  1173. fr_different_tunings.append(firing_rates[idx])
  1174. plot_fr_array.append([np.mean(fr_similar_tunings), np.mean(fr_different_tunings)])
  1175. x = np.arange(3) # the label locations
  1176. width = 0.35 # the width of the bars
  1177. plot_fr_array = np.array(plot_fr_array)
  1178. rects1 = ax.bar(x - width / 2, plot_fr_array[:, 0], width,
  1179. label='theta pref +/- {}°'.format(np.round(similarity_threshold / np.pi * 180)))
  1180. rects2 = ax.bar(x + width / 2, plot_fr_array[:, 1], width,
  1181. label='theta pref + 180° +/- {}°'.format(np.round(similarity_threshold / np.pi * 180)))
  1182. ax.set_xticks(x)
  1183. ax.set_xticklabels(labels)
  1184. ax.set_title('Mean firing rate for tunings similar and different to input')
  1185. ax.set_ylabel('Mean firing rate')
  1186. ax.legend()
  1187. def autolabel(rects):
  1188. """Attach a text label above each bar in *rects*, displaying its height."""
  1189. for rect in rects:
  1190. height = rect.get_height()
  1191. ax.annotate('{}'.format(np.round(height)),
  1192. xy=(rect.get_x() + rect.get_width() / 2, height),
  1193. xytext=(0, 3), # 3 points vertical offset
  1194. textcoords="offset points",
  1195. ha='center', va='bottom')
  1196. autolabel(rects1)
  1197. autolabel(rects2)
  1198. fig.tight_layout()
  1199. if save_figs:
  1200. plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_similar_vs_diff_tuning.png', dpi=200)
  1201. plt.close(fig)
  1202. def get_firing_rates_along_preferred_axis(traj, run_name, neuron_idx):
  1203. firing_rates = traj.results[run_name].firing_rate_array[neuron_idx, :]
  1204. tuning = traj.results[run_name].ex_tunings[neuron_idx]
  1205. anti_tuning = tuning + np.pi if tuning + np.pi < np.pi else tuning - np.pi
  1206. tuning_idx = np.argmin(np.abs(directions - tuning))
  1207. anti_tuning_idx = np.argmin(np.abs(directions - anti_tuning))
  1208. firing_at_the_preferred_direction = firing_rates[tuning_idx]
  1209. firing_at_the_opposite_direction = firing_rates[anti_tuning_idx]
  1210. return firing_at_the_preferred_direction, firing_at_the_opposite_direction
  1211. def get_hdi(traj, run_name, neuron_idx, type):
  1212. return traj.results.runs[run_name].head_direction_indices[neuron_idx] if type=="ex" else traj.results.runs[
  1213. run_name].inh_head_direction_indices[neuron_idx]
  1214. if __name__ == "__main__":
  1215. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  1216. NO_LOADING = 0
  1217. FULL_LOAD = 2
  1218. traj.f_load(filename=os.path.join(DATA_FOLDER, TRAJ_NAME + ".hdf5"), load_parameters=FULL_LOAD,
  1219. load_results=NO_LOADING)
  1220. traj.v_auto_load = True
  1221. save_figs = True
  1222. print("# Plotting script polarized interneurons")
  1223. print()
  1224. map_length_scale = 200.0
  1225. map_seed = 1
  1226. exemplary_head_direction = 0
  1227. corr_len_fit_dict = correlation_length_fit_dict(traj, map_type='pinwheel', load=True)
  1228. plt.plot(corr_len_fit_dict.keys(),corr_len_fit_dict.values())
  1229. plt.show()
  1230. abbrechen
  1231. print("## Map specifications")
  1232. print("\tcorrelation length: {:.1f} um".format(map_length_scale))
  1233. print("\tmap seed: {:d}".format(map_seed))
  1234. print()
  1235. print("## Input specification")
  1236. print("\tselected head direction: {:.0f}°".format(exemplary_head_direction))
  1237. print()
  1238. print("## Selected simulations")
  1239. plot_corr_len = get_closest_correlation_length(traj, map_length_scale)
  1240. par_dict = {'seed': map_seed, 'correlation_length': plot_corr_len}
  1241. plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
  1242. run_name_dict = {}
  1243. for run_name in plot_run_names:
  1244. traj.f_set_crun(run_name)
  1245. run_name_dict[traj.derived_parameters.runs[run_name].morphology.morph_label] = run_name
  1246. for network_type, run_name in run_name_dict.items():
  1247. print("{:s}: {:s}".format(network_type, run_name))
  1248. directions = get_input_head_directions(traj)
  1249. direction_idx = np.argmin(np.abs(np.array(directions) - np.deg2rad(exemplary_head_direction)))
  1250. selected_neuron_excitatory = 1052
  1251. selected_inhibitory_neuron = 28
  1252. print("## Figure specification")
  1253. print("\tpanel size: {:.2f} cm".format(panel_size * cm_per_inch))
  1254. print()
  1255. #
  1256. plot_colorbar(figsize=(0.8 * panel_size, 0.8 * panel_size), figname=FIGURE_SAVE_PATH + "A_i_colormap.svg")
  1257. plot_input_map(traj, run_name_dict[POLARIZED], figname="A_i_exemplary_input_map.png",
  1258. figsize=(panel_size, panel_size))
  1259. plot_axonal_clouds(traj, plot_run_names)
  1260. #
  1261. plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names, selected_neuron_excitatory)
  1262. in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names, selected_inhibitory_neuron)
  1263. #
  1264. hdi_means = plot_hdi_histogram_combined_and_overlayed(
  1265. traj, plot_run_names,
  1266. selected_neuron_excitatory,
  1267. selected_inhibitory_neuron,
  1268. cut_off_dist=100.)
  1269. #
  1270. number_of_suggestions = 0
  1271. representative_excitatory_neuron_indices = get_neurons_with_given_hdi(hdi_means["polar_exc"], hdi_means[
  1272. "circular_exc"],
  1273. number_of_suggestions, plot_run_names,
  1274. traj, "ex")
  1275. representative_inhibitory_neuron_indices = get_neurons_with_given_hdi(hdi_means["polar_inh"],
  1276. hdi_means["circular_inh"],
  1277. number_of_suggestions, plot_run_names,
  1278. traj, "in")
  1279. plot_polar_plot_excitatory(traj, plot_run_names, selected_neuron_excitatory)
  1280. for suggested_excitatory_neuron in representative_excitatory_neuron_indices:
  1281. preferred_polar, opposite_polar = get_firing_rates_along_preferred_axis(
  1282. traj, run_name_dict["ellipsoid"], suggested_excitatory_neuron)
  1283. preferred_circular, opposite_circular = get_firing_rates_along_preferred_axis(
  1284. traj, run_name_dict["circular"], suggested_excitatory_neuron)
  1285. plot_polar_plot_excitatory(traj, plot_run_names, suggested_excitatory_neuron,
  1286. figname=FIGURE_SAVE_PATH + "X_polar_plot_excitatory_neuron_signal_increase_{"
  1287. ":0>2.0f}_noise_decrease_{"
  1288. ":0>2.0f}_neuron_id_{"
  1289. ":d}.png".format(
  1290. preferred_polar-preferred_circular,
  1291. opposite_circular-opposite_polar, suggested_excitatory_neuron))
  1292. plot_polar_plot_inhibitory(traj, plot_run_names, selected_inhibitory_neuron)
  1293. for suggested_inhibitory_neuron in representative_inhibitory_neuron_indices:
  1294. polar_hdi = get_hdi(traj, run_name_dict["ellipsoid"], suggested_inhibitory_neuron, "in")
  1295. circular_hdi = get_hdi(traj, run_name_dict["circular"], suggested_inhibitory_neuron, "in")
  1296. plot_polar_plot_inhibitory(traj, plot_run_names, suggested_inhibitory_neuron,
  1297. figname=FIGURE_SAVE_PATH + "X_polar_plot_inhibitory_neuron_polarized_{"
  1298. ":.2f}_circular_{:.2f}_neuron_id_{"
  1299. ":d}.png".format(polar_hdi, circular_hdi,
  1300. suggested_inhibitory_neuron))
  1301. plot_firing_rate_similar_vs_diff_tuning(traj, plot_run_names)
  1302. plot_orientation_maps_diff_scales_with_ellipse(traj)
  1303. plot_exc_and_inh_hdi_over_simplex_grid_scale(traj, traj.f_get_run_names())
  1304. plot_exc_and_inh_hdi_over_fit_corr_len(traj, traj.f_get_run_names())
  1305. if not save_figs:
  1306. plt.show()
  1307. traj.f_restore_default()