paper_figures_entropy_minimisation_perlin_map.py 42 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import pandas as pd
  4. from brian2.units import *
  5. from pypet import Trajectory
  6. from pypet.brian2 import Brian2MonitorResult
  7. from scipy.optimize import curve_fit
  8. from matplotlib.patches import Ellipse
  9. from matplotlib.patches import Polygon
  10. from scripts.interneuron_placement import get_position_mesh, Pickle, get_correct_position_mesh
  11. from scripts.spatial_maps.uniform_perlin_map import get_correlation_length
  12. from scripts.spatial_network.figures_spatial_head_direction_network_orientation_map import plot_hdi_in_space
  13. from scripts.spatial_network.minimum_entropy_rule.run_entropy_minimisation_perlin_map import DATA_FOLDER, TRAJ_NAME
  14. from scripts.spatial_network.paper_figures_spatial_head_direction_network_orientation_map import tablelegend
  15. FIGURE_SAVE_PATH = '../../../figures/perlin_minimum_entropy_rule/'
  16. def get_closest_correlation_length(traj, correlation_length):
  17. available_lengths = sorted(list(set(traj.f_get("correlation_length").f_get_range())))
  18. closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
  19. if closest_length != correlation_length:
  20. print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
  21. correlation_length, closest_length))
  22. corr_len = closest_length
  23. return corr_len
  24. def gauss(x, *p):
  25. A, mu, sigma, B = p
  26. return A * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2)) + B
  27. def plot_tuning_curve(traj, direction_idx, plot_run_names):
  28. # p0 is the initial guess for the fitting coefficients (A, mu and sigma above)
  29. p0 = [30., 0., 1., 10.]
  30. fig, ax = plt.subplots(1, 1)
  31. for run_idx, run_name in enumerate(plot_run_names):
  32. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  33. ex_tunings = traj.results.runs[run_name].ex_tunings
  34. coeff_list = []
  35. ex_tunings_plt = np.array(ex_tunings)
  36. sort_ids = ex_tunings_plt.argsort()
  37. ex_tunings_plt = ex_tunings_plt[sort_ids]
  38. firing_rate_array = traj.results[run_name].firing_rate_array
  39. # firing_rate_array = traj.f_get('firing_rate_array')
  40. rates_plt = firing_rate_array[:, direction_idx]
  41. rates_plt = rates_plt[sort_ids]
  42. coeff, var_matrix = curve_fit(gauss, ex_tunings_plt, rates_plt / hertz, p0=p0)
  43. coeff_list.append(coeff)
  44. ax.scatter(ex_tunings_plt, rates_plt / hertz, label=label, alpha=0.3)
  45. ax.plot(ex_tunings_plt, gauss(ex_tunings_plt, *coeff), label=label + '-fit')
  46. ax.legend()
  47. ax.set_xlabel("Angles (rad)")
  48. ax.set_ylabel("f (Hz)")
  49. ax.set_title('tuning curves', fontsize=16)
  50. if save_figs:
  51. plt.savefig(FIGURE_SAVE_PATH + 'tuning_curve.png', dpi=200)
  52. def colorbar(mappable):
  53. from mpl_toolkits.axes_grid1 import make_axes_locatable
  54. import matplotlib.pyplot as plt
  55. last_axes = plt.gca()
  56. ax = mappable.axes
  57. fig = ax.figure
  58. divider = make_axes_locatable(ax)
  59. cax = divider.append_axes("right", size="5%", pad=0.05)
  60. cbar = fig.colorbar(mappable, cax=cax)
  61. plt.sca(last_axes)
  62. return cbar
  63. def plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names):
  64. max_val = 0
  65. for run_name in plot_run_names:
  66. fr_array = traj.results.runs[run_name].firing_rate_array
  67. f_rates = fr_array[:, direction_idx]
  68. run_max_val = np.max(f_rates)
  69. if run_max_val > max_val:
  70. # if traj.derived_parameters.runs[run_name].morphology.morph_label == 'ellipsoid':
  71. # n_id_max_rate = np.argmax(f_rates)
  72. max_val = run_max_val
  73. n_id_polar_plot = 609
  74. # Mark the neuron that is shown in Polar plot
  75. ex_positions = traj.results.runs[plot_run_names[0]].ex_positions
  76. polar_plot_x, polar_plot_y = ex_positions[n_id_polar_plot]
  77. # Vertices for the plotted triangle
  78. tr_scale = 13.
  79. tr_x = tr_scale * np.cos(2. * np.pi / 3. + np.pi / 2.)
  80. tr_y = tr_scale * np.sin(2. * np.pi / 3. + np.pi / 2.) + polar_plot_y
  81. tr_vertices = np.array([[polar_plot_x, polar_plot_y + tr_scale], [tr_x + polar_plot_x, tr_y], [-tr_x + polar_plot_x, tr_y]])
  82. height = 3.5
  83. # color_bar_size = 0.05 * height + 0.05
  84. # width = 3 * height + color_bar_size
  85. width = 10.5
  86. fig, axes = plt.subplots(1, 3, figsize=(width, height))
  87. for ax, run_name in zip(axes, [plot_run_names[i] for i in [2,0,1]]):
  88. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  89. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  90. firing_rate_array = traj.results[run_name].firing_rate_array
  91. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  92. c = ax.pcolor(X, Y, np.reshape(firing_rate_array[:, direction_idx], (number_of_excitatory_neurons_per_row,
  93. number_of_excitatory_neurons_per_row)),
  94. vmin=0, vmax=max_val, cmap='Reds')
  95. ax.set_title(label)
  96. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='k', fill=False, lw=2.))
  97. # ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), 20., 20., color='w', fill=False, lw=1.))
  98. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=2.5, color='k'))
  99. ax.add_artist(Polygon(tr_vertices, closed=True, fill=False, lw=1.5, color='w'))
  100. # fig.suptitle('spatial firing rate map', fontsize=16)
  101. colorbar(c)
  102. fig.tight_layout()
  103. if save_figs:
  104. plt.savefig(FIGURE_SAVE_PATH + 'firing_rate_map.png', dpi=200)
  105. return n_id_polar_plot
  106. def plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names):
  107. max_val = 0
  108. for run_name in plot_run_names:
  109. fr_array = traj.results.runs[run_name].inh_firing_rate_array
  110. f_rates = fr_array[:, direction_idx]
  111. run_max_val = np.max(f_rates)
  112. if run_max_val > max_val:
  113. max_val = run_max_val
  114. n_id_polar_plot = 55
  115. # Mark the neuron that is shown in Polar plot
  116. inhibitory_axonal_cloud_array = traj.results.runs[plot_run_names[1]].inhibitory_axonal_cloud_array
  117. polar_plot_x = inhibitory_axonal_cloud_array[n_id_polar_plot, 0]
  118. polar_plot_y = inhibitory_axonal_cloud_array[n_id_polar_plot, 1]
  119. fig, axes = plt.subplots(1, 2, figsize=(7.0, 3.5))
  120. for ax, run_name in zip(axes, [plot_run_names[i] for i in [0,1]]):
  121. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  122. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  123. inh_positions = [[p[0], p[1]] for p in inhibitory_axonal_cloud_array]
  124. X, Y = get_correct_position_mesh(inh_positions)
  125. inh_firing_rate_array = traj.results[run_name].inh_firing_rate_array
  126. number_of_inhibitory_neurons_per_row = int(np.sqrt(traj.N_I))
  127. c = ax.pcolor(X, Y, np.reshape(inh_firing_rate_array[:, direction_idx], (number_of_inhibitory_neurons_per_row,
  128. number_of_inhibitory_neurons_per_row)),
  129. vmin=0, vmax=max_val, cmap='Blues')
  130. ax.set_title(label)
  131. circle_r = 40.
  132. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='k', fill=False, lw=4.5))
  133. ax.add_artist(Ellipse((polar_plot_x, polar_plot_y), circle_r, circle_r, color='w', fill=False, lw=3))
  134. # fig.colorbar(c, ax=ax, label="f (Hz)")
  135. # fig.suptitle('spatial firing rate map', fontsize=16)
  136. colorbar(c)
  137. fig.tight_layout()
  138. if save_figs:
  139. plt.savefig(FIGURE_SAVE_PATH + 'inh_firing_rate_map.png', dpi=200)
  140. return n_id_polar_plot, max_val
  141. def plot_hdi_over_tuning(traj, plot_run_names):
  142. fig, ax = plt.subplots(1, 1)
  143. for run_idx, run_name in enumerate(plot_run_names):
  144. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  145. ex_tunings = traj.results.runs[run_name].ex_tunings
  146. ex_tunings_plt = np.array(ex_tunings)
  147. sort_ids = ex_tunings_plt.argsort()
  148. ex_tunings_plt = ex_tunings_plt[sort_ids]
  149. head_direction_indices = traj.results[run_name].head_direction_indices
  150. hdi_plt = head_direction_indices
  151. hdi_plt = hdi_plt[sort_ids]
  152. ax.scatter(ex_tunings_plt, hdi_plt, label=label, alpha=0.3)
  153. ax.legend()
  154. ax.set_xlabel("Angles (rad)")
  155. ax.set_ylabel("head direction index")
  156. ax.set_title('hdi over input tuning', fontsize=16)
  157. if save_figs:
  158. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_tuning.png', dpi=200)
  159. def adjust_label(label):
  160. if label == 'ellipsoid':
  161. label = 'polar'
  162. elif label == 'no conn':
  163. label = 'no interneurons'
  164. return label
  165. def plot_axonal_clouds(traj, plot_run_names):
  166. n_ex = int(np.sqrt(traj.N_E))
  167. fig, axes = plt.subplots(1, 3, figsize=(10.5, 3.5))
  168. for ax, run_name in zip(axes, [plot_run_names[i] for i in [0, 1, 2]]):
  169. traj.f_set_crun(run_name)
  170. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  171. X, Y = get_correct_position_mesh(traj.results.runs[run_name].ex_positions)
  172. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  173. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  174. inhibitory_axonal_cloud_array]
  175. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  176. # TODO: Why was this transposed for plotting? (now changed)
  177. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
  178. ax.set_title(adjust_label(label))
  179. ax.set_aspect('equal')
  180. # fig.colorbar(c, ax=ax, label="Tuning")
  181. if label != 'no conn' and axonal_clouds is not None:
  182. for i, p in enumerate(axonal_clouds):
  183. ell = p.get_ellipse()
  184. ax.add_artist(ell)
  185. axes[0].set_ylabel('\u03BCm')
  186. axes[1].set_yticklabels([])
  187. axes[2].set_yticklabels([])
  188. traj.f_restore_default()
  189. fig.tight_layout()
  190. if save_figs:
  191. plt.savefig(FIGURE_SAVE_PATH + 'axonal_clouds.png', dpi=200)
  192. def plot_orientation_maps_diff_scales(traj):
  193. n_ex = int(np.sqrt(traj.N_E))
  194. scale_run_names = []
  195. plot_scales = [0.0, 100.0, 200.0, 300.0]
  196. for scale in plot_scales:
  197. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
  198. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  199. fig, axes = plt.subplots(1, 4, figsize=(18., 4.5))
  200. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  201. traj.f_set_crun(run_name)
  202. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  203. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  204. # TODO: Why was this transposed for plotting? (now changed)
  205. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='twilight')
  206. ax.set_title('Correlation length: {}'.format(scale))
  207. fig.colorbar(c, ax=ax, label="Tuning")
  208. # fig.suptitle('axonal cloud', fontsize=16)
  209. traj.f_restore_default()
  210. if save_figs:
  211. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales.png', dpi=200)
  212. def plot_orientation_maps_diff_scales_with_ellipse(traj):
  213. n_ex = int(np.sqrt(traj.N_E))
  214. scale_run_names = []
  215. plot_scales = [0.0, 100.0, 200.0, 300.0, 400.0]
  216. for scale in plot_scales:
  217. par_dict = {'seed': 1, 'correlation_length': get_closest_correlation_length(traj,scale), 'long_axis': 100.}
  218. scale_run_names.append(*filter_run_names_by_par_dict(traj, par_dict))
  219. print(scale_run_names)
  220. fig, axes = plt.subplots(1, 5, figsize=(18., 4.5))
  221. for ax, run_name, scale in zip(axes, scale_run_names, plot_scales):
  222. traj.f_set_crun(run_name)
  223. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  224. inhibitory_axonal_cloud_array = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  225. axonal_clouds = [Pickle(p[0], p[1], traj.morphology.long_axis, traj.morphology.short_axis, p[2]) for p in
  226. inhibitory_axonal_cloud_array]
  227. head_dir_preference = np.array(traj.results.runs[run_name].ex_tunings).reshape((n_ex, n_ex))
  228. # TODO: Why was this transposed for plotting? (now changed)
  229. c = ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')
  230. # ax.set_title('Correlation length: {}'.format(scale))
  231. # fig.colorbar(c, ax=ax, label="Tuning")
  232. ax.set_xticks([])
  233. ax.set_yticks([])
  234. p1 = axonal_clouds[44]
  235. ell = p1.get_ellipse()
  236. ell._linewidth = 5.
  237. ax.add_artist(ell)
  238. p2 = axonal_clouds[77]
  239. circ_r = 2 * np.sqrt(2500.)
  240. circ = Ellipse((p2.x, p2.y), circ_r, circ_r, fill=False, zorder=2, edgecolor='k')
  241. circ._linewidth = 5.
  242. ax.add_artist(circ)
  243. # fig.suptitle('axonal cloud', fontsize=16)
  244. traj.f_restore_default()
  245. fig.tight_layout()
  246. if save_figs:
  247. plt.savefig(FIGURE_SAVE_PATH + 'orientation_maps_diff_scales_with_ellipse.png', dpi=200)
  248. def plot_excitatory_condensed_polar_plot(traj, plot_run_names, polar_plot_id, ex_polar_plot_hdi):
  249. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  250. directions_plt = list(directions)
  251. directions_plt.append(directions[0])
  252. fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5), subplot_kw=dict(projection='polar'))
  253. # head_direction_indices = traj.results.runs[plot_run_names[0]].head_direction_indices
  254. # sorted_ids = np.argsort(head_direction_indices)
  255. # plot_n_idx = sorted_ids[-75]
  256. plot_n_idx = polar_plot_id
  257. line_styles = ['dotted', 'solid', 'dashed']
  258. colors = ['r', 'lightsalmon', 'grey']
  259. max_rate = 0.0
  260. for run_idx, run_name in enumerate(plot_run_names):
  261. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  262. tuning_vectors = traj.results.runs[run_name].tuning_vectors
  263. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  264. run_max_rate = np.max(rate_plot)
  265. if run_max_rate > max_rate:
  266. max_rate = run_max_rate
  267. rate_plot.append(rate_plot[0])
  268. ax.plot(directions_plt, rate_plot, linewidth=2.5, label='{0}, HDI: {1:.2f}'.format(adjust_label(label), ex_polar_plot_hdi[run_idx]), color=colors[run_idx], linestyle=line_styles[run_idx])
  269. # ax.set_title('Firing Rate')
  270. # ax.plot([0.0, 0.0], [0.0, 1.05 * max_rate], color='red', alpha=0.25, linewidth=4.)
  271. # TODO: Set ticks for polar
  272. ticks = [40., 80.]
  273. ax.set_rticks(ticks)
  274. ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2])
  275. ax.set_xticklabels(['0°', '90°', '', ''])
  276. ax.set_rlabel_position(230)
  277. ax.legend(loc='upper center', bbox_to_anchor=(0.5, 0.15),
  278. fancybox=True, shadow=True)
  279. plt.tight_layout()
  280. if save_figs:
  281. plt.savefig(FIGURE_SAVE_PATH + 'condensed_polar_plot.png', dpi=200)
  282. def plot_inhibitory_condensed_polar_plot(traj, plot_run_names, polar_plot_id, in_polar_plot_hdi):
  283. directions = np.linspace(-np.pi, np.pi, traj.input.number_of_directions, endpoint=False)
  284. directions_plt = list(directions)
  285. directions_plt.append(directions[0])
  286. fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5), subplot_kw=dict(projection='polar'))
  287. # head_direction_indices = traj.results.runs[plot_run_names[0]].inh_head_direction_indices
  288. # sorted_ids = np.argsort(head_direction_indices)
  289. # plot_n_idx = sorted_ids[-75]
  290. plot_n_idx = polar_plot_id
  291. line_styles = ['dotted', 'solid']
  292. colors = ['b', 'lightblue']
  293. for run_idx, run_name in enumerate(plot_run_names[:2]):
  294. # ax = axes[max_hdi_idx, run_idx]
  295. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  296. tuning_vectors = traj.results.runs[run_name].inh_tuning_vectors
  297. rate_plot = [np.linalg.norm(v) for v in tuning_vectors[plot_n_idx]]
  298. rate_plot.append(rate_plot[0])
  299. ax.plot(directions_plt, rate_plot, linewidth=2.5, label='{0}, HDI: {1:.2f}'.format(adjust_label(label), in_polar_plot_hdi[run_idx]), color=colors[run_idx], linestyle=line_styles[run_idx])
  300. # ax.set_title('Inh. Firing Rate')
  301. # TODO: Set ticks for polar
  302. # ticks = [np.round(max_rate / 3.), np.round(max_rate * 2. / 3.), np.round(max_rate)]
  303. ticks = [40., 80., 120.]
  304. ax.set_rticks(ticks)
  305. ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
  306. ax.set_xticklabels(['0°', '90°', '', ''])
  307. ax.set_rlabel_position(230)
  308. ax.legend(loc='upper center', bbox_to_anchor=(0.5, 0.15),
  309. fancybox=True, shadow=True)
  310. plt.tight_layout()
  311. if save_figs:
  312. plt.savefig(FIGURE_SAVE_PATH + 'condensed_inhibitory_polar_plot.png', dpi=200)
  313. def plot_hdi_over_corr_len(traj, plot_run_names):
  314. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  315. seed_expl = traj.f_get('seed').f_get_range()
  316. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  317. label_range = set(label_expl)
  318. hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  319. hdi_frame.index.names = ["corr_len", "seed", "label"]
  320. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  321. ex_tunings = traj.results.runs[run_name].ex_tunings
  322. head_direction_indices = traj.results[run_name].head_direction_indices
  323. actual_corr_len = get_correlation_length(ex_tunings.reshape((30,30)), 450, 30)
  324. hdi_frame[actual_corr_len, seed, label] = np.mean(head_direction_indices)
  325. # TODO: Standart deviation also for the population
  326. hdi_exc_n_and_seed_mean = hdi_frame.groupby(level=[0, 2]).mean()
  327. hdi_exc_n_and_seed_std_dev = hdi_frame.groupby(level=[0, 2]).std()
  328. # Ellipsoid markers
  329. rx, ry = 5., 12.
  330. # area = rx * ry * np.pi * 2.
  331. area = 1.
  332. theta = np.arange(0, 2 * np.pi + 0.01, 0.1)
  333. verts = np.column_stack([rx / area * np.cos(theta), ry / area * np.sin(theta)])
  334. style_dict = {
  335. 'no conn': ['grey', 'dashed', '', 0],
  336. 'ellipsoid': ['blue', 'solid', verts, 10.],
  337. 'circular': ['lightblue', 'solid', 'o', 8.]
  338. }
  339. # colors = ['blue', 'grey', 'lightblue']
  340. # linestyles = ['solid', 'dashed', 'solid']
  341. # markers = [verts, '', 'o']
  342. fig, ax = plt.subplots(1, 1)
  343. for label in label_range:
  344. hdi_mean = hdi_exc_n_and_seed_mean[:, label]
  345. hdi_std = hdi_exc_n_and_seed_std_dev[:, label]
  346. corr_len_range = hdi_mean.keys().to_numpy()
  347. col, lin, mar, mar_size = style_dict[label]
  348. ax.plot(corr_len_range, hdi_mean, label=label, marker=mar, color=col, linestyle=lin, markersize=mar_size)
  349. plt.fill_between(corr_len_range, hdi_mean - hdi_std,
  350. hdi_mean + hdi_std, alpha=0.4, color=col)
  351. ax.set_xlabel('Correlation length')
  352. ax.set_ylabel('Head Direction Index')
  353. ax.axvline(206.9, color='k', linewidth=0.5)
  354. ax.set_ylim(0.0,1.0)
  355. ax.set_xlim(0.0,400.)
  356. ax.legend()
  357. if save_figs:
  358. plt.savefig(FIGURE_SAVE_PATH + 'hdi_over_corr_len_scaled.png', dpi=200)
  359. def plot_hdi_histogram_excitatory(traj, plot_run_names):
  360. labels = []
  361. hdis = []
  362. colors = ['black', 'red', 'green']
  363. for run_idx, run_name in enumerate(plot_run_names):
  364. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  365. labels.append(label)
  366. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  367. hdis.append(head_direction_indices)
  368. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  369. ax.hist(hdis, color=colors, label=labels, bins=30)
  370. for hdi, color in zip(hdis, colors):
  371. mean_hdi = np.mean(hdi)
  372. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  373. ax.set_xlabel("HDI")
  374. ax.legend()
  375. fig.tight_layout()
  376. if save_figs:
  377. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_excitatory.png', dpi=200)
  378. def plot_hdi_violin_excitatory(traj, plot_run_names):
  379. labels = []
  380. hdis = []
  381. colors = ['black', 'red', 'green']
  382. no_conn_hdi = 0.
  383. for run_idx, run_name in enumerate(plot_run_names):
  384. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  385. head_direction_indices = traj.results.runs[run_name].head_direction_indices
  386. if label == 'no conn':
  387. no_conn_hdi = np.mean(head_direction_indices)
  388. else:
  389. labels.append(label)
  390. hdis.append(sorted(head_direction_indices))
  391. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  392. # hdis = np.array(hdis)
  393. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  394. viol_plt['cmeans'].set_color('black')
  395. for pc in viol_plt['bodies']:
  396. pc.set_facecolor('red')
  397. pc.set_edgecolor('black')
  398. pc.set_alpha(0.7)
  399. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  400. ax.annotate('no conn', xy=(0.45,0.48), xycoords='axes fraction')
  401. ax.set_xticks(np.arange(1, len(labels) + 1))
  402. ax.set_xticklabels(labels)
  403. ax.set_ylabel('HDI')
  404. fig.tight_layout()
  405. if save_figs:
  406. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_excitatory.png', dpi=200)
  407. def plot_hdi_violin_inhibitory(traj, plot_run_names):
  408. labels = []
  409. hdis = []
  410. colors = ['black', 'red']
  411. for run_idx, run_name in enumerate(plot_run_names):
  412. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  413. if label != 'no conn':
  414. labels.append(label)
  415. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  416. hdis.append(sorted(head_direction_indices))
  417. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  418. viol_plt = ax.violinplot(hdis, showmeans=True, showextrema=False)
  419. viol_plt['cmeans'].set_color('black')
  420. for pc in viol_plt['bodies']:
  421. pc.set_facecolor('blue')
  422. pc.set_edgecolor('black')
  423. pc.set_alpha(0.7)
  424. ax.set_xticks(np.arange(1, len(labels) + 1))
  425. ax.set_xticklabels(labels)
  426. ax.set_ylabel('HDI')
  427. fig.tight_layout()
  428. if save_figs:
  429. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_inhibitory.png', dpi=200)
  430. def plot_hdi_violin_combined(traj, plot_run_names):
  431. labels = []
  432. inh_hdis = []
  433. exc_hdis = []
  434. no_conn_hdi = 0.
  435. colors = ['black', 'red']
  436. for run_idx, run_name in enumerate(plot_run_names):
  437. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  438. if label != 'no conn':
  439. labels.append(label)
  440. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  441. inh_hdis.append(sorted(inh_head_direction_indices))
  442. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  443. exc_hdis.append(sorted(exc_head_direction_indices))
  444. else:
  445. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  446. no_conn_hdi = np.mean(exc_head_direction_indices)
  447. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  448. inh_viol_plt = ax.violinplot(inh_hdis, showmeans=True, showextrema=False)
  449. # viol_plt['cmeans'].set_color('black')
  450. #
  451. # for pc in viol_plt['bodies']:
  452. # pc.set_facecolor('blue')
  453. # pc.set_edgecolor('black')
  454. # pc.set_alpha(0.7)
  455. for b in inh_viol_plt['bodies']:
  456. m = np.mean(b.get_paths()[0].vertices[:, 0])
  457. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  458. b.set_color('b')
  459. exc_viol_plt = ax.violinplot(exc_hdis, showmeans=True, showextrema=False)
  460. for b in exc_viol_plt['bodies']:
  461. m = np.mean(b.get_paths()[0].vertices[:, 0])
  462. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  463. b.set_color('r')
  464. ax.axhline(no_conn_hdi, color='black', linestyle='--')
  465. ax.annotate('no conn', xy=(0.45, 0.48), xycoords='axes fraction')
  466. ax.set_xticks(np.arange(1, len(labels) + 1))
  467. ax.set_xticklabels(labels)
  468. ax.set_ylabel('HDI')
  469. fig.tight_layout()
  470. if save_figs:
  471. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined.svg', dpi=200)
  472. def plot_hdi_violin_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id):
  473. labels = []
  474. inh_hdis = []
  475. exc_hdis = []
  476. no_conn_hdi = 0.
  477. in_polar_plot_hdi = []
  478. ex_polar_plot_hdi = []
  479. colors = ['black', 'red']
  480. for run_idx, run_name in enumerate(plot_run_names):
  481. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  482. if label != 'no conn':
  483. labels.append(label)
  484. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  485. inh_hdis.append(sorted(inh_head_direction_indices))
  486. in_polar_plot_hdi.append(inh_head_direction_indices[in_polar_plot_id])
  487. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  488. exc_hdis.append(sorted(exc_head_direction_indices))
  489. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  490. else:
  491. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  492. no_conn_hdi = np.mean(exc_head_direction_indices)
  493. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  494. fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.5))
  495. inh_ell_viol_plt = ax.violinplot(inh_hdis[0], showmeans=True, showextrema=False)
  496. for b in inh_ell_viol_plt['bodies']:
  497. m = np.mean(b.get_paths()[0].vertices[:, 0])
  498. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  499. b.set_color('b')
  500. mean_line = inh_ell_viol_plt['cmeans']
  501. mean_line.set_color('b')
  502. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  503. exc_ell_viol_plt = ax.violinplot(exc_hdis[0], showmeans=True, showextrema=False)
  504. for b in exc_ell_viol_plt['bodies']:
  505. m = np.mean(b.get_paths()[0].vertices[:, 0])
  506. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], m, np.inf)
  507. b.set_color('r')
  508. mean_line = exc_ell_viol_plt['cmeans']
  509. mean_line.set_color('r')
  510. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], m, np.inf)
  511. inh_cir_viol_plt = ax.violinplot(inh_hdis[1], showmeans=True, showextrema=False)
  512. for b in inh_cir_viol_plt['bodies']:
  513. m = np.mean(b.get_paths()[0].vertices[:, 0])
  514. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  515. b.set_color('b')
  516. mean_line = inh_cir_viol_plt['cmeans']
  517. mean_line.set_color('b')
  518. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  519. exc_cir_viol_plt = ax.violinplot(exc_hdis[1], showmeans=True, showextrema=False)
  520. for b in exc_cir_viol_plt['bodies']:
  521. m = np.mean(b.get_paths()[0].vertices[:, 0])
  522. b.get_paths()[0].vertices[:, 0] = np.clip(b.get_paths()[0].vertices[:, 0], -np.inf, m)
  523. b.set_color('r')
  524. mean_line = exc_cir_viol_plt['cmeans']
  525. mean_line.set_color('r')
  526. mean_line.get_paths()[0].vertices[:, 0] = np.clip(mean_line.get_paths()[0].vertices[:, 0], -np.inf, m)
  527. ax.axhline(no_conn_hdi, 0.5, 1., color='black', linestyle='--')
  528. ax.axvline(1.0, color='k')
  529. ax.annotate('no conn', xy=(0.75, 0.415), xycoords='axes fraction')
  530. ax.set_xlim(0.5, 1.5)
  531. ax.set_ylim(0.0, 1.0)
  532. ax.set_xticks([0.75, 1.25])
  533. ax.set_xticklabels(['circular', 'ellipsoid'])
  534. ax.set_ylabel('HDI')
  535. fig.tight_layout()
  536. if save_figs:
  537. plt.savefig(FIGURE_SAVE_PATH + 'hdi_violin_combined_and_overlayed.svg', dpi=200)
  538. return ex_polar_plot_hdi, in_polar_plot_hdi
  539. def plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names, ex_polar_plot_id, in_polar_plot_id):
  540. labels = []
  541. inh_hdis = []
  542. exc_hdis = []
  543. no_conn_hdi = 0.
  544. in_polar_plot_hdi = []
  545. ex_polar_plot_hdi = []
  546. colors = ['black', 'red']
  547. for run_idx, run_name in enumerate(plot_run_names):
  548. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  549. if label != 'no conn':
  550. labels.append(adjust_label(label))
  551. inh_head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  552. inh_hdis.append(sorted(inh_head_direction_indices))
  553. in_polar_plot_hdi.append(inh_head_direction_indices[in_polar_plot_id])
  554. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  555. exc_hdis.append(sorted(exc_head_direction_indices))
  556. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  557. else:
  558. exc_head_direction_indices = traj.results.runs[run_name].head_direction_indices
  559. no_conn_hdi = np.mean(exc_head_direction_indices)
  560. ex_polar_plot_hdi.append(exc_head_direction_indices[ex_polar_plot_id])
  561. fig, axes = plt.subplots(2, 1, figsize=(5.5, 4.5))
  562. plt.subplots_adjust(wspace=0, hspace=0)
  563. bins = np.linspace(0.0, 1.0, 21, endpoint=True)
  564. for i in range(2):
  565. # i = i + 1 # grid spec indexes from 0
  566. # ax = fig.add_subplot(gs[i])
  567. ax = axes[i]
  568. ax.hist(exc_hdis[i], color='r', edgecolor='r', alpha=0.3, bins=bins, density=True)
  569. ax.hist(inh_hdis[i], color='b', edgecolor='b', alpha=0.3, bins=bins, density=True)
  570. ax.axvline(np.mean(exc_hdis[i]), color='r')
  571. ax.axvline(np.mean(inh_hdis[i]), color='b')
  572. ax.axvline(no_conn_hdi, color='dimgrey', linestyle='--', linewidth=2.5)
  573. ax.set_ylabel(labels[i], rotation='vertical')
  574. # plt.axis('on')
  575. if i == 0:
  576. ax.set_xticklabels([])
  577. ax.spines['top'].set_visible(False)
  578. else:
  579. ax.set_xlabel('head direction index')
  580. ax.spines['right'].set_visible(False)
  581. plt.gcf().subplots_adjust(bottom=0.15)
  582. plt.gcf().subplots_adjust(left=0.15)
  583. plt.annotate('density', (-0.2,1.15), xycoords='axes fraction', rotation=90, fontsize=18)
  584. # plt.annotate('probability density', (-0.2,1.5), xycoords='axes fraction', rotation=90, fontsize=18)
  585. # fig.tight_layout()
  586. if save_figs:
  587. plt.savefig(FIGURE_SAVE_PATH + 'hdi_historgram_combined_and_overlayed.png', dpi=200)
  588. return ex_polar_plot_hdi, in_polar_plot_hdi
  589. def plot_hdi_histogram_inhibitory(traj, plot_run_names, in_polar_plot_id):
  590. labels = []
  591. hdis = []
  592. colors = ['black', 'red']
  593. for run_idx, run_name in enumerate(plot_run_names):
  594. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  595. if label != 'no conn':
  596. labels.append(label)
  597. head_direction_indices = traj.results.runs[run_name].inh_head_direction_indices
  598. print('inh {}: {}'.format(label, head_direction_indices[in_polar_plot_id]))
  599. hdis.append(head_direction_indices)
  600. fig, ax = plt.subplots(1, 1, figsize=(6, 3))
  601. ax.hist(hdis, color=colors, label=labels, bins=30)
  602. for hdi, color in zip(hdis, colors):
  603. mean_hdi = np.mean(hdi)
  604. ax.axvline(mean_hdi, 0, 1, color=color, linestyle='--')
  605. ax.set_xlabel("HDI")
  606. ax.legend()
  607. fig.tight_layout()
  608. if save_figs:
  609. plt.savefig(FIGURE_SAVE_PATH + 'hdi_histogram_inhibitory.png', dpi=200)
  610. def filter_run_names_by_par_dict(traj, par_dict):
  611. run_name_list = []
  612. for run_idx, run_name in enumerate(traj.f_get_run_names()):
  613. traj.f_set_crun(run_name)
  614. paramters_equal = True
  615. for key, val in par_dict.items():
  616. if (traj.par[key] != val):
  617. paramters_equal = False
  618. if paramters_equal:
  619. run_name_list.append(run_name)
  620. traj.f_restore_default()
  621. return run_name_list
  622. def plot_exc_and_inh_hdi_over_corr_len(traj, plot_run_names):
  623. corr_len_expl = traj.f_get('correlation_length').f_get_range()
  624. seed_expl = traj.f_get('seed').f_get_range()
  625. label_expl = [traj.derived_parameters.runs[run_name].morphology.morph_label for run_name in traj.f_get_run_names()]
  626. label_range = set(label_expl)
  627. exc_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  628. exc_hdi_frame.index.names = ["corr_len", "seed", "label"]
  629. inh_hdi_frame = pd.Series(index=[corr_len_expl, seed_expl, label_expl])
  630. inh_hdi_frame.index.names = ["corr_len", "seed", "label"]
  631. for run_name, corr_len, seed, label in zip(plot_run_names, corr_len_expl, seed_expl, label_expl):
  632. print(label)
  633. ex_tunings = traj.results.runs[run_name].ex_tunings
  634. head_direction_indices = traj.results[run_name].head_direction_indices
  635. #TODO: Actual correlation lengths
  636. # actual_corr_len = get_correlation_length(ex_tunings.reshape((30,30)), 450, 30)
  637. exc_hdi_frame[corr_len, seed, label] = np.mean(head_direction_indices)
  638. inh_head_direction_indices = traj.results[run_name].inh_head_direction_indices
  639. inh_hdi_frame[corr_len, seed, label] = np.mean(inh_head_direction_indices)
  640. # TODO: Standart deviation also for the population
  641. exc_hdi_n_and_seed_mean = exc_hdi_frame.groupby(level=[0, 2]).mean()
  642. exc_hdi_n_and_seed_std_dev = exc_hdi_frame.groupby(level=[0, 2]).std()
  643. inh_hdi_n_and_seed_mean = inh_hdi_frame.groupby(level=[0, 2]).mean()
  644. inh_hdi_n_and_seed_std_dev = inh_hdi_frame.groupby(level=[0, 2]).std()
  645. exc_style_dict = {
  646. 'no conn': ['grey', 'dashed', '', 0],
  647. 'ellipsoid': ['red', 'solid', '^', 8.],
  648. 'circular': ['lightsalmon', 'solid', '^', 8.]
  649. }
  650. inh_style_dict = {
  651. 'no conn': ['grey', 'dashed', '', 0],
  652. 'ellipsoid': ['blue', 'solid', 'o', 8.],
  653. 'circular': ['lightblue', 'solid', 'o', 8.]
  654. }
  655. fig, ax = plt.subplots(1, 1)
  656. for label in label_range:
  657. if label == 'no conn':
  658. ax.axhline(exc_hdi_n_and_seed_mean[1, label], color='grey', linestyle='--')
  659. ax.annotate('input', xy=(1.01, 0.44), xycoords='axes fraction')
  660. continue
  661. exc_hdi_mean = exc_hdi_n_and_seed_mean[:, label]
  662. exc_hdi_std = exc_hdi_n_and_seed_std_dev[:, label]
  663. inh_hdi_mean = inh_hdi_n_and_seed_mean[:, label]
  664. inh_hdi_std = inh_hdi_n_and_seed_std_dev[:, label]
  665. corr_len_range = exc_hdi_mean.keys().to_numpy()
  666. exc_col, exc_lin, exc_mar, exc_mar_size = exc_style_dict[label]
  667. inh_col, inh_lin, inh_mar, inh_mar_size = inh_style_dict[label]
  668. ax.plot(corr_len_range, exc_hdi_mean, label='exc., ' + label, marker=exc_mar, color=exc_col, linestyle=exc_lin, markersize=exc_mar_size, alpha=0.5)
  669. plt.fill_between(corr_len_range, exc_hdi_mean - exc_hdi_std,
  670. exc_hdi_mean + exc_hdi_std, alpha=0.3, color=exc_col)
  671. ax.plot(corr_len_range, inh_hdi_mean, label='inh., ' + label, marker=inh_mar, color=inh_col, linestyle=inh_lin, markersize=inh_mar_size, alpha=0.5)
  672. plt.fill_between(corr_len_range, inh_hdi_mean - inh_hdi_std,
  673. inh_hdi_mean + inh_hdi_std, alpha=0.3, color=inh_col)
  674. ax.set_xlabel('Correlation length')
  675. ax.set_ylabel('Head Direction Index')
  676. ax.axvline(206.9, color='k', linewidth=0.5)
  677. ax.set_ylim(0.0,1.0)
  678. ax.set_xlim(0.0,800.)
  679. tablelegend(ax, ncol=2, bbox_to_anchor=(1, 1),
  680. row_labels=['exc.', 'inh.'],
  681. col_labels=['ellipsoid', 'circular'],
  682. title_label='')
  683. # plt.legend()
  684. if save_figs:
  685. plt.savefig(FIGURE_SAVE_PATH + 'exc_and_inh_hdi_over_corr_len_scaled.png', dpi=200)
  686. def plot_in_degree_map(traj, plot_run_names):
  687. n_ex = int(np.sqrt(traj.N_E))
  688. max_degree = 0
  689. for run_name in plot_run_names:
  690. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  691. exc_degree = np.sum(ie_adjacency, axis=0)
  692. run_max_degree = np.max(exc_degree)
  693. if run_max_degree > max_degree:
  694. max_degree = run_max_degree
  695. fig, axes = plt.subplots(1, 2, figsize=(9., 4.5))
  696. for ax, run_name in zip(axes, plot_run_names[:-1]):
  697. traj.f_set_crun(run_name)
  698. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  699. X, Y = get_position_mesh(traj.results.runs[run_name].ex_positions)
  700. number_of_excitatory_neurons_per_row = int(np.sqrt(traj.N_E))
  701. ie_adjacency = traj.results.runs[run_name].ie_adjacency
  702. exc_degree = np.sum(ie_adjacency, axis=0)
  703. c = ax.pcolor(X, Y, np.reshape(exc_degree, (number_of_excitatory_neurons_per_row,
  704. number_of_excitatory_neurons_per_row)), vmin=0, vmax=max_degree,
  705. cmap='hot')
  706. ax.set_title(label)
  707. fig.colorbar(c, ax=ax, label="in/out-degree")
  708. fig.suptitle('in/out-degree', fontsize=16)
  709. traj.f_restore_default()
  710. if save_figs:
  711. plt.savefig(FIGURE_SAVE_PATH + 'in_degree_map.png', dpi=200)
  712. def plot_spatial_hdi_map(traj, plot_run_names):
  713. max_val = 0
  714. for run_name in plot_run_names:
  715. hdis = traj.results.runs[run_name].head_direction_indices
  716. run_max_val = np.max(hdis)
  717. if run_max_val > max_val:
  718. max_val = run_max_val
  719. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  720. for ax, run_name in zip(axes, plot_run_names):
  721. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  722. positions = traj.results.runs[run_name].ex_positions
  723. head_direction_indices = traj.results[run_name].head_direction_indices
  724. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  725. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  726. ax.set_title(label)
  727. fig.colorbar(c, ax=ax, label="head direction index")
  728. fig.suptitle('spatial HDI map', fontsize=16)
  729. if save_figs:
  730. plt.savefig(FIGURE_SAVE_PATH + 'spatial_hdi_map.png', dpi=200)
  731. def plot_exc_spatial_hdi_map(traj, plot_run_names):
  732. max_val = 0
  733. for run_name in plot_run_names:
  734. hdis = traj.results.runs[run_name].head_direction_indices
  735. run_max_val = np.max(hdis)
  736. if run_max_val > max_val:
  737. max_val = run_max_val
  738. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  739. for ax, run_name in zip(axes, plot_run_names):
  740. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  741. positions = traj.results.runs[run_name].ex_positions
  742. head_direction_indices = traj.results[run_name].head_direction_indices
  743. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  744. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  745. ax.set_title(label)
  746. fig.colorbar(c, ax=ax, label="head direction index")
  747. fig.suptitle('spatial exc. HDI map', fontsize=16)
  748. if save_figs:
  749. plt.savefig(FIGURE_SAVE_PATH + 'spatial_exc_hdi_map.png', dpi=200)
  750. def plot_inh_spatial_hdi_map(traj, plot_run_names):
  751. max_val = 0
  752. for run_name in plot_run_names:
  753. hdis = traj.results.runs[run_name].inh_head_direction_indices
  754. run_max_val = np.max(hdis)
  755. if run_max_val > max_val:
  756. max_val = run_max_val
  757. fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))
  758. for ax, run_name in zip(axes, plot_run_names):
  759. label = traj.derived_parameters.runs[run_name].morphology.morph_label
  760. ax_cloud = traj.results.runs[run_name].inhibitory_axonal_cloud_array
  761. positions = [[x, y] for x, y, phi in ax_cloud]
  762. head_direction_indices = traj.results[run_name].inh_head_direction_indices
  763. print('Mean {}-HDI = {}'.format(label, np.mean(head_direction_indices)))
  764. c = plot_hdi_in_space(ax, positions, head_direction_indices, max_val)
  765. ax.set_title(label)
  766. fig.colorbar(c, ax=ax, label="head direction index")
  767. fig.suptitle('spatial inh. HDI map', fontsize=16)
  768. if save_figs:
  769. plt.savefig(FIGURE_SAVE_PATH + 'spatial_inh_hdi_map.png', dpi=200)
  770. if __name__ == "__main__":
  771. traj = Trajectory(TRAJ_NAME, add_time=False, dynamic_imports=Brian2MonitorResult)
  772. NO_LOADING = 0
  773. FULL_LOAD = 2
  774. traj.f_load(filename=DATA_FOLDER + TRAJ_NAME + ".hdf5", load_parameters=FULL_LOAD, load_results=NO_LOADING)
  775. traj.v_auto_load = True
  776. save_figs = True
  777. plt.rcParams.update({'font.size': 16})
  778. plot_corr_len = get_closest_correlation_length(traj, 200.0)
  779. par_dict = {'seed': 1, 'correlation_length': plot_corr_len}
  780. plot_run_names = filter_run_names_by_par_dict(traj, par_dict)
  781. print(plot_run_names)
  782. direction_idx = 6
  783. dir_indices = [0, 3, 6, 9]
  784. plot_axonal_clouds(traj, plot_run_names)
  785. #
  786. ex_polar_plot_id = plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names)
  787. #
  788. in_polar_plot_id, in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names)
  789. #
  790. ex_polar_plot_hdi, in_polar_plot_hdi = plot_hdi_histogram_combined_and_overlayed(traj, plot_run_names,
  791. ex_polar_plot_id, in_polar_plot_id)
  792. plot_excitatory_condensed_polar_plot(traj, plot_run_names, ex_polar_plot_id, ex_polar_plot_hdi)
  793. plot_inhibitory_condensed_polar_plot(traj, plot_run_names, in_polar_plot_id, in_polar_plot_hdi)
  794. # plot_exc_spatial_hdi_map(traj, plot_run_names)
  795. #
  796. # plot_inh_spatial_hdi_map(traj, plot_run_names)
  797. # plot_axonal_clouds(traj, plot_run_names)
  798. #
  799. # ex_polar_plot_id = plot_firing_rate_map_excitatory(traj, direction_idx, plot_run_names)
  800. #
  801. # in_polar_plot_id, in_max_rate = plot_firing_rate_map_inhibitory(traj, direction_idx, plot_run_names)
  802. #
  803. # plot_orientation_maps_diff_scales_with_ellipse(traj)
  804. #
  805. # plot_hdi_histogram_inhibitory(traj, plot_run_names)
  806. #
  807. # plot_hdi_histogram_excitatory(traj, plot_run_names)
  808. #
  809. # plot_hdi_over_corr_len(traj, traj.f_get_run_names())
  810. #
  811. # plot_excitatory_condensed_polar_plot(traj, plot_run_names, ex_polar_plot_id)
  812. #
  813. # plot_inhibitory_condensed_polar_plot(traj, plot_run_names, in_polar_plot_id, in_max_rate)
  814. #
  815. # plot_hdi_violin_combined(traj, plot_run_names)
  816. #
  817. # plot_hdi_violin_combined_and_overlayed(traj, plot_run_names)
  818. plot_exc_and_inh_hdi_over_corr_len(traj, traj.f_get_run_names())
  819. if not save_figs:
  820. plt.show()
  821. traj.f_restore_default()