"""Uniformly distributed OpenSimplex ("Perlin") tuning maps, plus helpers to plot them and to fit the correlation length of their autocorrelation."""
- import numpy as np
- import matplotlib.pyplot as plt
- from opensimplex import OpenSimplex
- from pypet import Trajectory
- from scipy.signal import correlate2d
- from scipy.optimize import leastsq
- from scripts.interneuron_placement import create_grid_of_excitatory_neurons, \
- create_interneuron_sheet_entropy_max_orientation, get_correct_position_mesh
- from scripts.spatial_maps.orientation_maps.orientation_map import OrientationMap
- from scripts.spatial_maps.orientation_maps.orientation_map_generator_pypet import TRAJ_NAME_ORIENTATION_MAPS
# Default folder prefix used by get_orientation_map() to locate the pypet
# HDF5 trajectory when no data_folder argument is given (relative path).
DATA_FOLDER = "../../data/"
class UniformPerlinMap:
    """Tuning map built from 2-D OpenSimplex noise with a flattened histogram.

    The raw noise samples are replaced by their rank, binned into
    ``2 * x_dim`` levels, and rescaled so that the stored angles are evenly
    distributed over ``[-pi, pi)``.

    NOTE: the noise is sampled on an ``x_dim`` x ``x_dim`` grid covering
    ``[0, sheet_x]`` along both axes; ``y_dim`` and ``sheet_y`` are only used
    by the plotting mesh and by :meth:`get_tuning`. Both dimensions must be
    at least 2 for the grid spacing to be defined.

    Attributes are plain values copied from the constructor arguments, plus
    ``map``: a 2-D ``numpy`` array of angles indexed as ``map[id_x, id_y]``.
    """

    def __init__(self, x_dim, y_dim, corr_len, sheet_x=0, sheet_y=0, rnd_seed=1):
        """Sample the noise field and convert it into uniformly spread angles.

        Args:
            x_dim: Number of grid points along x (also used for y sampling).
            y_dim: Number of grid points along y (geometry only, see class note).
            corr_len: Length scale by which sample coordinates are divided
                before being passed to the noise generator.
            sheet_x: Extent of the sheet along x; also the sampled extent.
            sheet_y: Extent of the sheet along y (geometry only).
            rnd_seed: Seed handed to ``OpenSimplex``.
        """
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.corr_len = corr_len
        self.sheet_x = sheet_x
        self.sheet_y = sheet_y
        self.rnd_seed = rnd_seed

        noise = OpenSimplex(seed=self.rnd_seed)
        nrow = self.x_dim
        scale = corr_len  # TODO: Probably this needs to be a linear interpolation
        coords = np.linspace(0, self.sheet_x, nrow)

        # Row-major sampling: the first coordinate varies slowest.
        raw = np.array([noise.noise2d(i / scale, j / scale)
                        for i in coords for j in coords])

        # Rank-equalize into 2 * nrow levels, then shift/scale to [-1, 1).
        levels = self._equalize_ranks(raw, nrow * 2)
        p_map = (levels - nrow) / nrow
        self.map = p_map.reshape(nrow, -1)
        self.map *= np.pi

    @staticmethod
    def _equalize_ranks(values, n_levels):
        """Map each value to the index of the rank bin it falls into.

        Every element receives a level in ``0 .. n_levels - 1`` according to
        its rank. When the number of values is a multiple of ``n_levels`` all
        bins have the same size; otherwise bin sizes differ by at most one,
        so no element is ever left with its raw value.

        Args:
            values: 1-D array-like of numbers.
            n_levels: Number of output levels.

        Returns:
            Float array of the same length holding the level of each value.
        """
        values = np.asarray(values, dtype=float)
        if values.size == 0:
            return values.copy()
        order = np.argsort(values)
        ranks = np.empty(values.size, dtype=np.int64)
        ranks[order] = np.arange(values.size)
        return (ranks * n_levels // values.size).astype(float)

    def get_meshgrid_of_neuron_positions(self):
        """Return cell-edge coordinates ``(X, Y)`` suitable for ``pcolormesh``.

        The edges are shifted by half a grid spacing so that each neuron
        position sits in the centre of its cell. Each axis has one more edge
        than it has grid points.
        """
        dx = self.sheet_x / (self.x_dim - 1)
        dy = self.sheet_y / (self.y_dim - 1)
        # linspace fixes the number of edges; a float-step arange may
        # produce one element too many because of rounding.
        x_edges = np.linspace(-dx / 2., self.sheet_x + dx / 2., self.x_dim + 1)
        y_edges = np.linspace(-dy / 2., self.sheet_y + dy / 2., self.y_dim + 1)
        X, Y = np.meshgrid(x_edges, y_edges)
        return X, Y

    def get_tuning(self, x, y):
        """Return the map angle of the grid cell containing position (x, y)."""
        id_x = int(x * (self.x_dim - 1) / self.sheet_x)
        id_y = int(y * (self.y_dim - 1) / self.sheet_y)
        return self.map[id_x, id_y]

    def get_tuning_by_id(self, id_x, id_y):
        """Return the map angle stored at grid indices (id_x, id_y)."""
        return self.map[id_x, id_y]

    def plot_map(self, ax=None):
        """Draw the map as a colour mesh with a colour bar.

        Args:
            ax: Axes to draw into; a new figure is created when omitted.
        """
        if ax is None:
            _, ax = plt.subplots(1, 1)

        X, Y = self.get_meshgrid_of_neuron_positions()
        Z = np.array(self.map[:self.x_dim, :self.y_dim], dtype=float)

        # Pass the colormap explicitly instead of changing the global default,
        # and attach the colour bar to the figure that owns ``ax``.
        pcm = ax.pcolormesh(X, Y, Z.T, cmap='twilight')
        ax.figure.colorbar(pcm, ax=ax)
def get_orientation_map(correlation_length, seed, sheet_size, N_E, data_folder=None):
    """Load a stored orientation map from the pypet trajectory file.

    The stored correlation length closest to ``correlation_length`` is used;
    a warning is printed when it is not an exact match.

    Args:
        correlation_length: Desired correlation length.
        seed: Seed of the stored run to load.
        sheet_size: Sheet extent passed to ``OrientationMap`` for both axes.
        N_E: Number of excitatory neurons; ``int(sqrt(N_E)) + 1`` grid points
            per axis are passed to ``OrientationMap``.
        data_folder: Folder prefix of the HDF5 file; defaults to ``DATA_FOLDER``.

    Returns:
        An ``OrientationMap`` whose ``angle_grid`` is the stored map.

    Raises:
        ValueError: If no stored run matches the chosen correlation length
            and ``seed``.
    """
    if data_folder is None:
        data_folder = DATA_FOLDER

    traj = Trajectory(filename=data_folder + TRAJ_NAME_ORIENTATION_MAPS + ".hdf5")
    traj.f_load(index=-1, load_parameters=2, load_results=2)

    available_lengths = sorted(set(traj.f_get("corr_len").f_get_range()))
    closest_length = available_lengths[np.argmin(np.abs(np.array(available_lengths) - correlation_length))]
    if closest_length != correlation_length:
        print("Warning: desired correlation length {:.1f} not available. Taking {:.1f} instead".format(
            correlation_length, closest_length))
    corr_len = closest_length

    matching_runs = traj.f_find_idx(['corr_len', 'seed'],
                                    lambda x, y: x == corr_len and y == seed)

    # TODO: Since it has only one entry, maybe iterator can be replaced
    found = False
    map_angle_grid = None
    for idx in matching_runs:
        traj.v_idx = idx
        map_angle_grid = traj.crun.map
        found = True
    if not found:
        # Previously this fell through to an unbound local variable.
        raise ValueError(
            "No stored orientation map for corr_len={} and seed={}".format(corr_len, seed))

    number_of_excitatory_neurons_per_row = int(np.sqrt(N_E))
    orientation_map = OrientationMap(number_of_excitatory_neurons_per_row + 1,
                                     number_of_excitatory_neurons_per_row + 1,
                                     corr_len, sheet_size, sheet_size, seed)
    orientation_map.angle_grid = map_angle_grid
    return orientation_map
def plot_auto_corr(map, ax=None):
    """Plot the non-negative-lag quadrant of a map's 2-D autocorrelation.

    Args:
        map: Object providing ``x_dim``, ``y_dim``, ``get_tuning_by_id`` and
            ``get_meshgrid_of_neuron_positions``.
        ax: Axes to draw into; a new figure is created when omitted.

    Returns:
        The full autocorrelation (all lags), scaled so its maximum is 1.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    X, Y = map.get_meshgrid_of_neuron_positions()

    Z = np.array([[map.get_tuning_by_id(x_idx, y_idx) for y_idx in range(map.y_dim)]
                  for x_idx in range(map.x_dim)], dtype=float)

    Z = correlate2d(Z, Z)
    Z = Z / np.max(Z)

    # The full correlation has 2 * dim - 1 lags per axis with zero lag at
    # index dim - 1, so this slice keeps the non-negative lags for any map
    # size (the former fixed offset of 30 only fitted 31 x 31 maps).
    quadrant = Z[map.x_dim - 1:, map.y_dim - 1:]

    # Explicit colormap and owning figure instead of global pyplot state.
    pcm = ax.pcolormesh(X, Y, quadrant.T, cmap='viridis')
    ax.figure.colorbar(pcm, ax=ax)

    return Z
def get_auto_corr(map):
    """Return the full 2-D autocorrelation of a map, scaled to a maximum of 1.

    Args:
        map: Object providing ``x_dim``, ``y_dim`` and ``get_tuning_by_id``.

    Returns:
        Array holding the correlation of the tuning grid with itself, divided
        by its largest entry.
    """
    tuning_grid = np.zeros((map.x_dim, map.y_dim))
    # Fill one column (fixed y index) at a time.
    for col in range(tuning_grid.shape[1]):
        tuning_grid[:, col] = [map.get_tuning_by_id(row, col)
                               for row in range(tuning_grid.shape[0])]

    auto_corr = correlate2d(tuning_grid, tuning_grid)
    return auto_corr / np.max(auto_corr)
def f(x, u, v, z_data):
    """Residuals of a radial exponential decay model against ``z_data``.

    Args:
        x: Parameter sequence; ``x[0]`` is the decay (correlation) length.
        u: Lag coordinates along the first direction.
        v: Lag coordinates along the second direction.
        z_data: Observed values at the given lags.

    Returns:
        Flattened array ``exp(-sqrt(u**2 + v**2) / x[0]) - z_data``.
    """
    decay_length = x[0]
    radius = np.sqrt(u**2 + v**2)
    residuals = np.exp(-(radius / decay_length)) - z_data
    return residuals.flatten()
def fit_correlation(map, data):
    """Fit the exponential decay model ``f`` to autocorrelation data.

    Args:
        map: Object providing ``sheet_x``, ``sheet_y``, ``x_dim`` and ``y_dim``.
        data: Autocorrelation values for non-negative lags, indexed as
            ``data[x_lag, y_lag]`` with ``x_dim * y_dim`` entries in total.

    Returns:
        The tuple returned by ``scipy.optimize.leastsq``; its first element
        holds the fitted correlation length.

    Raises:
        ValueError: If ``data`` does not have ``x_dim * y_dim`` entries.
    """
    # Here we need the actual positions and not ticks for plotting.
    # linspace fixes the number of lag positions; a float-step arange can
    # return one element too many because of rounding.
    u_range = np.linspace(0, map.sheet_x, map.x_dim)
    v_range = np.linspace(0, map.sheet_y, map.y_dim)

    # 'ij' ordering makes the flattened lags line up with data[x_lag, y_lag].
    u, v = np.meshgrid(u_range, v_range, indexing='ij')
    u = u.flatten()
    v = v.flatten()

    data = np.asarray(data).flatten()
    if data.size != u.size:
        raise ValueError(
            "data has {} entries but the map grid has {}".format(data.size, u.size))

    result = leastsq(f, [50.5], args=(u, v, data))

    return result
def test_plot_fit_func(corr_len, map, ax=None):
    """Plot the radial exponential decay model for a given correlation length.

    Args:
        corr_len: Decay length of the model ``exp(-sqrt(X**2 + Y**2) / corr_len)``.
        map: Object providing ``get_meshgrid_of_neuron_positions``.
        ax: Axes to draw into; a new figure is created when omitted.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    X, Y = map.get_meshgrid_of_neuron_positions()

    # Z is evaluated on the mesh itself, so it already has the mesh's shape.
    # It is plotted without a transpose, which only worked for square meshes
    # (where Z is symmetric).
    Z = np.exp(-(np.sqrt(X**2 + Y**2) / corr_len))

    # Explicit colormap and owning figure instead of global pyplot state.
    pcm = ax.pcolormesh(X, Y, Z, cmap='viridis')
    ax.figure.colorbar(pcm, ax=ax)
def plot_example(map):
    """Show a map, its autocorrelation and the fitted decay model side by side.

    Args:
        map: Map object providing ``plot_map``, ``x_dim`` and ``y_dim`` plus
            whatever ``plot_auto_corr``, ``fit_correlation`` and
            ``test_plot_fit_func`` require of it.
    """
    fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.5))

    map.plot_map(ax=axes[0])
    axes[0].set_title('map')

    map_auto_corr = plot_auto_corr(map, ax=axes[1])
    axes[1].set_title('auto correlation')

    # Zero lag of the full autocorrelation sits at index dim - 1 on each axis;
    # slicing there keeps the non-negative lags for any map size.
    non_negative_lags = map_auto_corr[map.x_dim - 1:, map.y_dim - 1:]
    map_corr_len = fit_correlation(map, non_negative_lags)

    axes[2].set_title('fit of auto correl.')
    test_plot_fit_func(map_corr_len[0][0], map, ax=axes[2])
def get_correlation_length(tunings, size, dim):
    """Estimate the correlation length of a square grid of tuning values.

    The autocorrelation of ``tunings`` is computed, scaled to a maximum of 1,
    and the decay model ``f`` is fitted to its non-negative-lag quadrant.

    Args:
        tunings: 2-D array-like of shape ``(dim, dim)``.
        size: Extent of the sheet along each axis.
        dim: Number of grid points along each axis.

    Returns:
        The fitted correlation length.

    Raises:
        ValueError: If ``tunings`` does not have shape ``(dim, dim)``.
    """
    tunings = np.asarray(tunings)
    if tunings.shape != (dim, dim):
        raise ValueError(
            "tunings has shape {} but expected {}".format(tunings.shape, (dim, dim)))

    tun_auto_corr = correlate2d(tunings, tunings)
    tun_auto_corr = tun_auto_corr / np.max(tun_auto_corr)

    # linspace fixes the number of lag positions at dim; a float-step arange
    # can return one element too many because of rounding.
    lags = np.linspace(0, size, dim)
    u, v = np.meshgrid(lags, lags, indexing='ij')
    u = u.flatten()
    v = v.flatten()

    # Zero lag of the full autocorrelation is at index dim - 1 on each axis
    # (the former fixed offset of 30 only fitted dim == 31).
    data = tun_auto_corr[dim - 1:, dim - 1:].flatten()

    tun_corr_len = leastsq(f, [50.5], args=(u, v, data))

    return tun_corr_len[0][0]
def plot_map_and_axons(map, n_inh, inhibitory_axon_long_axis=100, inhibitory_axon_short_axis=25):
    """Plot a tuning map together with the placed inhibitory axonal clouds.

    All geometry is taken from ``map`` itself rather than from module-level
    globals, so the function also works when imported from another module.

    Args:
        map: Map object providing ``sheet_x``, ``sheet_y``, ``x_dim`` and
            ``get_tuning``.
        n_inh: Number of interneurons passed to the placement routine.
        inhibitory_axon_long_axis: Long-axis value passed to the placement
            routine.
        inhibitory_axon_short_axis: Short-axis value passed to the placement
            routine.
    """
    ex_positions, ex_tunings = create_grid_of_excitatory_neurons(
        map.sheet_x, map.sheet_y, map.x_dim, map.get_tuning)

    inhibitory_axonal_clouds, _ = create_interneuron_sheet_entropy_max_orientation(
        ex_positions, ex_tunings, n_inh, inhibitory_axon_long_axis,
        inhibitory_axon_short_axis, map.sheet_x,
        map.sheet_y, trial_orientations=30)

    X, Y = get_correct_position_mesh(ex_positions)

    fig, ax = plt.subplots(1, 1)

    # NOTE(review): assumes the excitatory grid has x_dim neurons per row and
    # per column, as the previous (dim, dim) reshape did - confirm.
    head_dir_preference = np.array(ex_tunings).reshape((map.x_dim, map.x_dim))
    # TODO: Why was this transposed for plotting? (now changed)
    ax.pcolor(X, Y, head_dir_preference, vmin=-np.pi, vmax=np.pi, cmap='hsv')

    for cloud in inhibitory_axonal_clouds:
        ax.add_artist(cloud.get_ellipse())

    ax.set_title('Perlin, size: {}'.format(map.sheet_x))
    ax.set_aspect('equal')
if __name__ == '__main__':
    # Example script: draw a Perlin map with interneuron axons for two sheet
    # sizes. All settings are kept as module-level names.
    inhibitory_axon_long_axis = 100
    inhibitory_axon_short_axis = 25
    corr_len = 200

    # (grid points per axis, sheet size, number of interneurons)
    for dim, size, n_inh in ((30, 450, 100), (90, 900, 400)):
        perlin_map = UniformPerlinMap(dim, dim, corr_len, size, size, 1)
        plot_map_and_axons(perlin_map, n_inh)

        # Disabled: the stored orientation map for the same settings.
        # orientation_map = get_orientation_map(corr_len, 1, size, dim * dim)
        # orientation_map.plot_map()
        # plt.gca().set_title('Orientation, size: {}'.format(size))

    # Title of the most recently created axes (the larger sheet).
    plt.gca().set_title('Perlin, size: {}'.format(size))

    # Disabled: fitted versus requested correlation length for both map types.
    # corr_len_range = np.linspace(1.0, 400.0, 30, endpoint=True)
    # p_map_corr_lengths = []
    # o_map_corr_lengths = []
    # for corr_len in corr_len_range:
    #     perlin_map = UniformPerlinMap(dim + 1, dim + 1, corr_len, size, size, 1)
    #     p_map_auto_corr = get_auto_corr(perlin_map)
    #     p_map_corr_len = fit_correlation(perlin_map, p_map_auto_corr[30:, 30:])
    #     print('perlin: ', p_map_corr_len)
    #     p_map_corr_lengths.append(p_map_corr_len[0][0])
    #
    #     orientation_map = get_orientation_map(corr_len, 1, size, dim * dim)
    #
    #     o_map_auto_corr = get_auto_corr(orientation_map)
    #     o_map_corr_len = fit_correlation(orientation_map, o_map_auto_corr[30:, 30:])
    #     print('orientation ', o_map_corr_len)
    #     o_map_corr_lengths.append(o_map_corr_len[0][0])
    #
    # fig, ax = plt.subplots(1, 1)
    #
    # ax.plot(corr_len_range, o_map_corr_lengths, label='orientation')
    # ax.plot(corr_len_range, p_map_corr_lengths, label='perlin')
    # plt.legend()

    plt.show()
|