eval_sim_data.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881
  1. # -*- coding: utf-8 -*-
  2. from .sim_data import SimInfo
  3. import numpy as np
  4. from scipy.fftpack import fftn
  5. from scipy.signal import convolve2d
  6. import os
  7. #import plot3d
  8. import pylab as plt
  9. import random
  10. from .connection import Connection
  11. import pickle
  12. from .idl import dist
  13. from objsimpy.distance import cyclic_distance
  14. from scipy.ndimage import measurements
  15. from matplotlib.colors import Normalize
  16. from .colormaps import cmap_color_circle
  17. from objsimpy.stim.stim_movie import StimulusSet
  18. from objsimpy.stim.stim_sequence import InputTrace
  19. from itertools import product
  20. import scipy.ndimage
  21. import logging
  22. logger = logging.getLogger(__name__)
  23. def pickled_results(func):
  24. def new_func(filename, *args, **kwargs):
  25. reevaluate = False
  26. if 'reevaluate' in kwargs:
  27. reevaluate = kwargs['reevaluate']
  28. del kwargs['reevaluate']
  29. if not os.path.isfile(filename):
  30. reevaluate = True
  31. if reevaluate:
  32. logger.info("calculating ...")
  33. result = func(*args, **kwargs)
  34. logger.info("pickling result to: {}".format(filename))
  35. f = open(filename, "wb")
  36. pickle.dump(result, f)
  37. f.close()
  38. return result
  39. else:
  40. logger.info("loading pickled result from: {}".format(filename))
  41. f = open(filename, "rb")
  42. result = pickle.load(f)
  43. f.close
  44. return result
  45. return new_func
  46. class Psth():
  47. def __init__(self, psth=None, time_start=0):
  48. self.psth = psth
  49. self.time_start = time_start
  50. self.time_stop = psth.shape[2] + time_start - 1
  51. self.time_steps = psth.shape[2]
  52. self.n_neurons = psth.shape[0]
  53. self.n_stimuli = psth.shape[1]
  54. #print "n_neurons = ", self.n_neurons
  55. def generate_spikes_and_input(psth, n_stimulus_repetitions=None):
  56. psth_remaining = psth.psth.copy()
  57. time = 0
  58. neuron_numbers = []
  59. spike_times = []
  60. inputs = np.array([])
  61. if n_stimulus_repetitions is None or n_stimulus_repetitions < psth_remaining.max():
  62. n_stimulus_repetitions = psth_remaining.max()
  63. #print "corrected n_stimulus_repetitions to ", n_stimulus_repetitions
  64. no_stim_time = 10
  65. stimuli = range(psth.n_stimuli) * n_stimulus_repetitions
  66. #print "stimuli = ", stimuli
  67. while len(stimuli) > 0:
  68. inputs = np.hstack((inputs, np.array([-1]*no_stim_time)))
  69. time += no_stim_time
  70. # pick a stimulus
  71. stimulus = random.choice(stimuli)
  72. stimuli.remove(stimulus)
  73. #print stimuli
  74. repetitions_left = sum(np.array(stimuli) == stimulus)
  75. if psth.time_start < 0:
  76. inputs = np.hstack((inputs, np.array([-1]*(-1*psth.time_start))))
  77. inputs = np.hstack((inputs, np.array([stimulus] * (1 + psth.time_stop))))
  78. else:
  79. inputs = np.hstack((inputs, ([stimulus] * (1 + psth.time_stop - psth.time_start))))
  80. # generate spikes
  81. for t in range(psth.time_steps):
  82. for neuron in range(psth.n_neurons):
  83. spike_probability = psth_remaining[neuron, stimulus, t] / repetitions_left
  84. if random.random() < spike_probability:
  85. neuron_numbers.append(neuron)
  86. spike_times.append(time)
  87. psth_remaining[neuron, stimulus, t] -= 1
  88. time += 1
  89. return ((spike_times, neuron_numbers), InputTrace(nparray=inputs), n_stimulus_repetitions)
  90. def generate_dummy_psth(psth_start=-40, psth_stop=100, n_stimuli=2, n_neurons=2):
  91. n_time_steps = psth_stop - psth_start + 1
  92. psth = np.zeros(n_time_steps * n_neurons * n_stimuli).reshape(n_neurons, n_stimuli, n_time_steps)
  93. for neuron in range(n_neurons):
  94. for stimulus in range(n_stimuli):
  95. psth[neuron, stimulus, :] = exp_decrease(psth_start, psth_stop, 4.9*(stimulus+1.) + 6.4 * np.sin(0.9*np.pi*(neuron+1)/n_neurons), 5+neuron*stimulus)
  96. psth = psth.round()
  97. return Psth(psth, psth_start)
  98. def exp_decrease(start, stop, initial=1.0, time_factor=1.0):
  99. if start < 0:
  100. values = np.hstack((np.zeros(-start + 1), initial * np.exp(-time_factor * np.linspace(0, 1, stop))))
  101. else:
  102. values = np.hstack((np.zeros(1), initial * np.exp(-time_factor * np.linspace(0, 1, stop))))
  103. return values
  104. def calc_psth(spikes, input, psth_start=0, psth_stop=10):
  105. stim_onsets = input.get_stim_onset()
  106. n_neurons = max(spikes[1]) + 1
  107. n_stimuli = max(input.nparray) + 1
  108. n_time_steps = psth_stop - psth_start + 1
  109. psth_array_size = n_neurons * n_stimuli * n_time_steps
  110. #print "psth_array_size=", psth_array_size
  111. #print "n_neurons = ", n_neurons
  112. #print "n_stimuli = ", n_stimuli
  113. #print "n_time_steps = ", n_time_steps
  114. psth = np.zeros(psth_array_size).reshape(n_neurons, n_stimuli, n_time_steps)
  115. spike_times = spikes[0]
  116. neuron_numbers = spikes[1]
  117. spike_indices = np.arange(len(spike_times))
  118. stim_frequencies = {}
  119. for onset in stim_onsets:
  120. onset_time = onset[0]
  121. stimulus_number = onset[1]
  122. # count stimulus repetitions
  123. if stimulus_number in stim_frequencies:
  124. stim_frequencies[stimulus_number] += 1
  125. else:
  126. stim_frequencies[stimulus_number] = 1
  127. # find all spikes between (onset_time + psth_start) and (onset_time + psth_stop)
  128. relevant_spikes_mask = ((onset_time + psth_start) <= spike_times) * (spike_times <= (onset_time + psth_stop))
  129. relevant_spike_indices = spike_indices[relevant_spikes_mask]
  130. for index in relevant_spike_indices:
  131. psth[neuron_numbers[index], stimulus_number, spike_times[index] - onset_time - psth_start] += 1
  132. #print stim_frequencies
  133. # devide by number of stimulus repetitions
  134. for stim in stim_frequencies:
  135. psth[:,stim,:] /= stim_frequencies[stim]
  136. return Psth(psth=psth)
  137. def calc_total_psth(spikes, input, psth_start=0, psth_stop=10, n_neurons=None):
  138. n_time_steps = psth_stop - psth_start + 1
  139. total_psth = np.zeros(n_time_steps)
  140. stim_onsets = input.get_stim_onset()
  141. n_stimulus_events = len(stim_onsets)
  142. stim_times_array = np.arange(n_time_steps) + psth_start
  143. time_psth_time_map = {}
  144. for onset in stim_onsets:
  145. onset_time = onset[0]
  146. #print "onset = ", onset
  147. for t in stim_times_array:
  148. absolute_time = t + onset_time
  149. time_psth_time_map.setdefault(absolute_time, []).append(t)
  150. #print time_psth_time_map
  151. spike_times = spikes[0]
  152. neuron_numbers = spikes[1]
  153. for spike_time in spike_times:
  154. if spike_time in time_psth_time_map:
  155. total_psth[time_psth_time_map[spike_time]] += 1
  156. #print "total psth = ", total_psth
  157. # devide by number of stimulus repetitions
  158. total_psth /= n_stimulus_events
  159. if n_neurons is None:
  160. n_neurons = max(neuron_numbers) + 1
  161. total_psth /= n_neurons
  162. return total_psth
  163. def calc_response_strength(spikes, input, psth_start=0, psth_stop=10):
  164. stim_onsets = input.get_stim_onset()
  165. n_neurons = max(spikes[1]) + 1
  166. n_stimuli = max(input.nparray) + 1
  167. n_time_steps = psth_stop - psth_start + 1
  168. response_array_size = n_neurons * n_stimuli
  169. logger.info("response_array_size = {}".format(response_array_size))
  170. logger.info("n_neurons = {}".format(n_neurons))
  171. logger.info("n_stimuli = {}".format(n_stimuli))
  172. logger.info("n_time_steps = {}".format(n_time_steps))
  173. stim_responses = np.zeros(response_array_size).reshape(n_neurons, n_stimuli)
  174. spike_times = spikes[0]
  175. neuron_numbers = spikes[1]
  176. spike_indices = np.arange(len(spike_times))
  177. stim_frequencies = {}
  178. for onset in stim_onsets:
  179. onset_time = onset[0]
  180. stimulus_number = onset[1]
  181. # count stimulus repetitions
  182. stim_frequencies[stimulus_number] = 1 + stim_frequencies.get(stimulus_number, 0)
  183. # find all spikes between (onset_time + psth_start) and (onset_time + psth_stop)
  184. relevant_spikes_mask = ((onset_time + psth_start) <= spike_times) * (spike_times <= (onset_time + psth_stop))
  185. relevant_spike_indices = spike_indices[relevant_spikes_mask]
  186. for index in relevant_spike_indices:
  187. stim_responses[neuron_numbers[index], stimulus_number] += 1
  188. # devide by number of stimulus repetitions
  189. for stim in stim_frequencies:
  190. stim_responses[:,stim] /= stim_frequencies[stim]
  191. return stim_responses
  192. def calc_response_strength2(spikes, input, psth_start=0, psth_stop=10, n_neurons=None):
  193. stim_onsets = input.get_stim_onset()
  194. #stim_onsets = stim_onsets[:-2] # remove this later!!!!!!!!!!
  195. if n_neurons is None:
  196. n_neurons = max(spikes[1]) + 1
  197. n_stimuli = max(input.nparray) + 1
  198. n_time_steps = psth_stop - psth_start + 1
  199. response_array_size = n_neurons * n_stimuli
  200. #print "response_array_size=", response_array_size
  201. #print "n_neurons = ", n_neurons
  202. #print "n_stimuli = ", n_stimuli
  203. #print "n_time_steps = ", n_time_steps
  204. stim_responses = np.zeros(response_array_size).reshape(n_neurons, n_stimuli)
  205. spike_times = spikes[0]
  206. neuron_numbers = spikes[1]
  207. spike_indices = np.arange(len(spike_times))
  208. stim_frequencies = {}
  209. time_stimulus_map = {}
  210. stim_times_array = np.arange(n_time_steps) + psth_start
  211. logger.debug("stim_onsets={}".format(stim_onsets))
  212. for onset in stim_onsets:
  213. onset_time = onset[0]
  214. stimulus_number = onset[1]
  215. # count stimulus repetitions
  216. stim_frequencies[stimulus_number] = 1 + stim_frequencies.get(stimulus_number, 0)
  217. stim_times = onset_time + stim_times_array
  218. for t in stim_times:
  219. time_stimulus_map.setdefault(t, []).append(stimulus_number)
  220. for spike_index in spike_indices:
  221. if spike_times[spike_index] in time_stimulus_map:
  222. for stim in time_stimulus_map[spike_times[spike_index]]:
  223. stim_responses[neuron_numbers[spike_index], stim] += 1
  224. # devide by number of stimulus repetitions
  225. for stim in stim_frequencies:
  226. logger.debug("stim = {}, freq = {}".format(stim, stim_frequencies[stim]))
  227. stim_responses[:,stim] /= stim_frequencies[stim]
  228. return stim_responses
  229. calc_response_strength2_PICKLED = pickled_results(calc_response_strength2)
  230. def calc_preferred_stimuli_from_response(response_strength, target_nx=100, target_ny=100, mark_low_responses=None):
  231. """Calculates preferred stimuli from spike responses
  232. parameters:
  233. @param response_strength: 2d array with dimensions (number_of_neurons, number_of_stimuli)
  234. @param target_nx: x dimension of target layer
  235. @param target_nx: y dimension of target layer
  236. @param mark_low_responses: if this parameter is not None, neurons with low response will be markt with the given value
  237. @return: preferred_stimuli, selectivity
  238. """
  239. # add jitter, otherwise at same response_strength lower stimulus number would always
  240. # be counted as preferred (see doc of argmax)
  241. jitter_amplitude = 0.0001
  242. jitter = jitter_amplitude * np.random.random(response_strength.shape)
  243. preferred_stimuli = (response_strength + jitter).argmax(axis=1).reshape(target_ny, target_nx)
  244. # calculate selectivity
  245. max_responses = response_strength.max(axis=1)
  246. min_responses = response_strength.min(axis=1)
  247. selectivity = (max_responses - min_responses) / (max_responses + min_responses)
  248. selectivity = selectivity.reshape(target_ny, target_nx)
  249. # mark low responses
  250. if mark_low_responses is not None:
  251. total_mean_response = response_strength.mean()
  252. low_response = response_strength.max(axis=1) < total_mean_response
  253. preferred_stimuli[low_response.reshape(target_nx, target_ny)] = mark_low_responses
  254. selectivity[low_response]=0
  255. return preferred_stimuli, selectivity
  256. calc_preferred_stimuli_from_response_PICKLED = pickled_results(calc_preferred_stimuli_from_response)
  257. def calc_pref_stim_maps_from_response(response_strength, stim_nx, stim_ny, target_nx=100, target_ny=100):
  258. """Calculate preferred x and y stimulus maps from response strength
  259. @param response_strength: 2d array with dimensions (number_of_neurons, number_of_stimuli)
  260. @param stim_nx: x dimension of 2d stimulus space
  261. @param stim_nx: y dimension of 2d stimulus space
  262. @param target_nx: x dimension of target layer
  263. @param target_nx: y dimension of target layer
  264. @return: preferred_stimuli_x, preferred_stimuli_y, selectivity
  265. """
  266. low_res_mark = -2
  267. preferred_stimuli, selectivity = calc_preferred_stimuli_from_response(response_strength, target_nx, target_ny, low_res_mark)
  268. low_response = np.where(preferred_stimuli==low_res_mark)
  269. preferred_stimuli_x = preferred_stimuli % stim_nx
  270. preferred_stimuli_y = preferred_stimuli / stim_nx
  271. # mark map neurons with too low response with low_res_mark
  272. preferred_stimuli_x[low_response] = low_res_mark
  273. preferred_stimuli_y[low_response] = low_res_mark
  274. return preferred_stimuli_x, preferred_stimuli_y, selectivity
  275. calc_pref_stim_maps_from_response_PICKLED = pickled_results(calc_pref_stim_maps_from_response)
  276. def calc_preferred_stim_from_weights(weight_file, stim_file, fname_stimcorr_pickle=None, stimcorr_reeval=True):
  277. con = Connection()
  278. con.load_weight_file(weight_file)
  279. incomming_weight_sum = con.get_incomming_weight_sum()
  280. stimset = StimulusSet(stim_file)
  281. if fname_stimcorr_pickle is not None:
  282. response_strength = calc_weight_stim_correlation_PICKLED(fname_stimcorr_pickle, con, stimset.pics, reevaluate=stimcorr_reeval)
  283. else:
  284. response_strength = calc_weight_stim_correlation(con, stimset.pics)
  285. max_response = response_strength.max(axis=1)
  286. min_response = response_strength.min(axis=1)
  287. selectivity = (max_response - min_response) / (max_response + min_response)
  288. # add jitter, otherwise at same response_strength lower stimulus number would always
  289. # be counted as preferred (see doc of argmax)
  290. jitter_amplitude = 0.0001
  291. jitter = jitter_amplitude * np.random.random(response_strength.shape)
  292. target_nx = con.get_target_nx()
  293. target_ny = con.get_target_ny()
  294. preferred_stimuli = (response_strength + jitter).argmax(axis=1).reshape(target_ny, target_nx)
  295. selectivity = selectivity.reshape(target_ny, target_nx)
  296. return preferred_stimuli, selectivity, incomming_weight_sum
  297. calc_preferred_stim_from_weights_PICKLED = pickled_results(calc_preferred_stim_from_weights)
  298. def calc_pref_stim_maps_from_weights(weight_file, stim_file, stim_nx=20, stim_ny=20, fname_stimcorr_pickle=None, stimcorr_reeval=True):
  299. preferred_stimuli, selectivity, inc_sum = calc_preferred_stim_from_weights(weight_file, stim_file, fname_stimcorr_pickle)
  300. pref_x = preferred_stimuli % stim_nx
  301. pref_y = preferred_stimuli / stim_nx
  302. return pref_x, pref_y, selectivity, inc_sum
  303. calc_pref_stim_maps_from_weights_PICKLED = pickled_results(calc_pref_stim_maps_from_weights)
  304. def calc_weight_stim_correlation(connection, stim_pics):
  305. n_target = connection.get_n_target()
  306. n_stimuli = len(stim_pics)
  307. response_strength = np.zeros((n_target, n_stimuli))
  308. for target in range(n_target):
  309. weights = connection.get_incomming_matrix(target)
  310. for stim in range(n_stimuli):
  311. response = weights * stim_pics[stim]
  312. response_strength[target, stim] += response.sum()
  313. return response_strength
  314. calc_weight_stim_correlation_PICKLED = pickled_results(calc_weight_stim_correlation)
  315. def plot_psth(psth, ax=None):
  316. logging.info("plotting psth")
  317. if ax is None:
  318. fig = plt.figure()
  319. ax = fig.add_subplot(1,1,1)
  320. n_neurons = psth.shape[0]
  321. n_stim = psth.shape[1]
  322. n_time_steps = psth.shape[2]
  323. for neuron in range(n_neurons):
  324. for stim in range(n_stim):
  325. ax.plot(psth[neuron,stim,:], 'x-')
  326. def calc_patch_sizes(weight_file, stim_file, stim_nx, stim_ny, neurons_nx, neurons_ny):
  327. logging.info("calculating patch sizes")
  328. con = Connection()
  329. con.load_weight_file(weight_file)
  330. stimset = StimulusSet(stim_file)
  331. pickle_file = weight_file + "_stimcorr.pickle"
  332. response_strength = calc_weight_stim_correlation_PICKLED(pickle_file, con, stimset.pics)
  333. logging.info("response_strength.shape = {}".format(response_strength.shape))
  334. n_neurons, n_stim = response_strength.shape
  335. if stim_nx * stim_ny != n_stim:
  336. raise Exception("stimulus dimensions don't match with response strength array!")
  337. if neurons_nx * neurons_ny != n_neurons:
  338. raise Exception("neuron map dimensions don't match with number of neurons!")
  339. # calculate single para map
  340. response_strength_xy = response_strength.reshape(n_neurons, stim_ny, stim_nx)
  341. x_responses = response_strength_xy.sum(1)
  342. y_responses = response_strength_xy.sum(2)
  343. x_patch_sizes = [calc_patch_size(response) for response in x_responses]
  344. y_patch_sizes = [calc_patch_size(response) for response in y_responses]
  345. return x_patch_sizes, y_patch_sizes
  346. def calc_patch_size(image):
  347. ''' berechnet die patch size anhand der Fourier-Transformation
  348. Implementierung nur sinnvoll für quadratische images
  349. '''
  350. nx, ny = image.shape
  351. if nx != ny:
  352. raise ValueError("image must be quadratic! (image.shape[0] == image.shape[1]")
  353. #print "nx=" + str(nx)
  354. #print "ny=" + str(ny)
  355. mwfree_image = image - np.mean(image)
  356. fft_image = np.abs(fftn(mwfree_image))
  357. #plt.imshow(abs(fft_image), interpolation='nearest')
  358. #plt.show()
  359. max_ind = fft_image.argmax()
  360. #print "max_ind=" + str(max_ind)
  361. d = dist(ny, nx)
  362. #print "d.shape=" + str(d.shape)
  363. max_x = max_ind % nx
  364. max_y = max_ind // nx
  365. max_dist = d[max_y, max_x]
  366. # frequency in cycles per pixel
  367. freq = max_dist / nx
  368. # wave length in pixel
  369. wave_length = 1./freq
  370. return wave_length
  371. def evaluate_patch_dynamics(spike_data, patch_border_threshold=0.3, patch_noise_ratio_threshold=10., dt=100., t_start=None):
  372. """Calculate activity patch meausres: size, speed, amplitude"""
  373. spike_movie, time_scale = spike_data.calc_spike_movie(dt=dt, t_start=t_start, time_scale=True)
  374. n_frames = spike_movie.shape[2]
  375. patch_positions = []
  376. patch_sizes = []
  377. patch_amplitudes = []
  378. background_noise = []
  379. total_activity = []
  380. n_neurons = spike_movie.shape[0] * spike_movie.shape[1]
  381. for i in range(n_frames):
  382. frame = spike_movie[:,:,i]
  383. peak_pos, patch_pos, patch_size, patch_indices = find_biggest_patch(frame, threshold=patch_border_threshold)
  384. patch_spikes = [frame[x, y] for x, y in patch_indices]
  385. non_patch_spikes = np.sum(frame) - np.sum(patch_spikes)
  386. total_frame_activity = np.sum(frame)
  387. total_activity.append(total_frame_activity)
  388. mean_non_patch_spikes = non_patch_spikes / (n_neurons - len(patch_indices))
  389. background_noise_hz = 1000. * mean_non_patch_spikes / dt # in Hz
  390. amplitude = np.mean(patch_spikes)
  391. patch_amplitude_hz = 1000. * amplitude / dt # in Hz
  392. patch_amplitudes.append(patch_amplitude_hz)
  393. if patch_amplitude_hz < patch_noise_ratio_threshold*background_noise_hz:
  394. patch_size = 0
  395. patch_sizes.append(patch_size)
  396. background_noise.append(background_noise_hz)
  397. if patch_size > 0:
  398. patch_positions.append(np.array(patch_pos))
  399. else:
  400. patch_positions.append(np.array([-1, -1]))
  401. sqrt_patch_sizes = np.sqrt(patch_sizes)
  402. valid_patches = np.where(sqrt_patch_sizes>0)
  403. n_valid_patches = len(valid_patches[0])
  404. median_sqrt_patch_size = np.median(sqrt_patch_sizes[valid_patches])
  405. speed_sec = calc_patch_speed(patch_positions,
  406. sqrt_patch_sizes,
  407. dt,
  408. 0.5*median_sqrt_patch_size,
  409. spike_movie.shape)
  410. valid_speed_patches = np.where(sqrt_patch_sizes[:-1]>0)
  411. median_speed = np.median(speed_sec[valid_speed_patches])
  412. return {'patch_size': sqrt_patch_sizes,
  413. 'patch_size_median': median_sqrt_patch_size,
  414. 'n_valid_patches': n_valid_patches,
  415. 'n_frames': n_frames,
  416. 'patch_speed': speed_sec,
  417. 'patch_speed_median': median_speed,
  418. 'patch_amplitude': patch_amplitudes,
  419. 'patch_positions': patch_positions,
  420. 'background_noise': background_noise,
  421. 'time_scale': time_scale,
  422. }
  423. def find_biggest_patch(frame, threshold=0.5, sigma=1.5):
  424. """Find biggest activity patch in image (e.g. single spike movie frame)
  425. @param frame: two dimensional ndarray
  426. @param threshold: relative patch threshold between mean and max activity (values must be between 0 and 1)
  427. all contiguous neurons around the maximum activity neuron
  428. which spike above the threshold belong to the patch
  429. @return: peak position, patch size, and coordinates of all neurons that belong to the patch
  430. """
  431. if threshold <= 0 or threshold >= 1:
  432. raise ValueError("Threshold must be in open interval (0,1)")
  433. rows, cols = frame.shape
  434. frame = scipy.ndimage.filters.gaussian_filter(frame, sigma, order=0, output=None, mode='wrap', truncate=4.0)
  435. peak_pos = np.unravel_index(np.argmax(frame), frame.shape)
  436. peak_value = frame[peak_pos]
  437. mean_value = np.mean(frame)
  438. # shift peak to center
  439. shift_rows = int(0.5*rows - peak_pos[0])
  440. shift_cols = int(0.5*cols - peak_pos[1])
  441. frame = np.roll(frame, shift_rows, axis=0)
  442. frame = np.roll(frame, shift_cols, axis=1)
  443. abs_threshold = mean_value + threshold * (peak_value-mean_value)
  444. # create mask of neurons above threshold
  445. above_threshold_pixels = frame > abs_threshold
  446. if sum(above_threshold_pixels.flat) == 0:
  447. return (0,0), (0,0), 0, []
  448. # and now: a miracle occurs
  449. # find indices of masked array entries, flood fill to detect contiguous regions
  450. # http://stackoverflow.com/questions/9440921/identify-contiguous-regions-in-2d-numpy-array
  451. labels, numL = scipy.ndimage.label(above_threshold_pixels)
  452. label_indices = [(labels == i).nonzero() for i in xrange(1, numL+1)]
  453. # find biggest patch
  454. patch_sizes = [len(li[0]) for li in label_indices]
  455. ind_biggest_patch = np.argmax(patch_sizes)
  456. patch_size = patch_sizes[ind_biggest_patch]
  457. patch_indices = label_indices[ind_biggest_patch]
  458. # calculate patch position based on centered patch,
  459. # then shift back coordinates to original position
  460. patch_pos = [(np.mean(patch_indices[0]) - shift_rows) % rows,
  461. (np.mean(patch_indices[1]) - shift_cols) % cols
  462. ]
  463. # patch was shifted to image center for flood fill to work correctly
  464. # now correct indices for original (unshifted) patch position
  465. patch_indices = zip((patch_indices[0] - shift_rows) % rows,
  466. (patch_indices[1] - shift_cols) % cols,
  467. )
  468. return peak_pos, patch_pos, patch_size, patch_indices
  469. def calc_patch_speed(patch_positions, patch_sizes, delta_t, patch_threshold, layer_dim):
  470. """Calculate patch speed from x and y positions"""
  471. pos_x, pos_y = zip(*patch_positions)
  472. pos_x = np.array(pos_x)
  473. pos_y = np.array(pos_y)
  474. pos_diff_x = cyclic_distance(pos_x[1:], pos_x[0:-1], layer_dim[0])
  475. pos_diff_y = cyclic_distance(pos_y[1:], pos_y[0:-1], layer_dim[1])
  476. pos_diff = np.sqrt(pos_diff_x**2 + pos_diff_y**2)
  477. speed_sec = 1000. * pos_diff / delta_t # in neurons per second
  478. # mark invalid speed values
  479. invalid_speed = -1
  480. speed_sec[pos_x[1:] < 0] = invalid_speed # no valid patch position
  481. speed_sec[pos_x[0:-1] < 0] = invalid_speed # no valid patch position
  482. speed_sec[patch_sizes[0:-1] < patch_threshold] = invalid_speed # patch too small
  483. return speed_sec
  484. def find_pinwheels_pref_x_prototype_matching(movie_file, weight_file, num_stimuli = [20,20], plot_path = None, preferred_stim_pickle_file = None):
  485. if preferred_stim_pickle_file == None:
  486. preferences = calc_pref_stim_maps_from_weights(weight_file, movie_file, stim_nx = num_stimuli[0], stim_ny = num_stimuli[1])
  487. else:
  488. preferences = calc_pref_stim_maps_from_weights_PICKLED(preferred_stim_pickle_file, weight_file, movie_file, stim_nx = num_stimuli[0], stim_ny = num_stimuli[1])
  489. prototypes = set_up_prototypes(num_stimuli[0])
  490. bank = set_up_pw_prototype_bank(prototypes)
  491. matched_img = do_pattern_matching_pref_stim_prototype_bank(preferences[0], bank, num_stimuli = num_stimuli[0])
  492. pinwheels = eval_pattern_matching(matched_img)
  493. if plot_path != None:
  494. preferences[0][pinwheels] = -5
  495. preferences[1][pinwheels] = -5
  496. norm = Normalize(vmin = 0.)
  497. plt.imshow(preferences[0], origin = 'lower', interpolation = 'Nearest', cmap = cmap_color_circle(), norm = norm)
  498. plt.colorbar()
  499. plt.show()
  500. plt.savefig(os.path.join(plot_path, 'pref_x_with_pws.png'))
  501. norm = Normalize(vmin = 0.)
  502. plt.imshow(preferences[1], origin = 'lower', interpolation = 'Nearest', cmap = cmap_color_circle(), norm = norm)
  503. plt.colorbar()
  504. plt.show()
  505. plt.savefig(os.path.join(plot_path, 'pref_y_with_pws.png'))
  506. return pinwheels
  507. def set_up_prototypes(nx):
  508. r = [-3.,-2.,-1.,0,1.,2.,3.]
  509. proto = np.array([[(np.arctan2(i,j)/np.pi-1.)/2.*-nx for i in r] for j in r])
  510. proto2 = np.zeros((len(r),len(r)))
  511. c0 = 0
  512. c1 = 0
  513. for i in r:
  514. c1 = 0
  515. for j in r:
  516. phi = (np.arctan2(i,j)/np.pi -1.)/(-2.)
  517. if phi>=0.125:
  518. phi = phi - 0.125
  519. else:
  520. phi = 0.875 + phi
  521. proto2[c0,c1] = phi*nx
  522. c1 += 1
  523. c0 += 1
  524. return [proto,proto2]
  525. def set_up_pw_prototype_bank(protos):
  526. num_protos = len(protos)
  527. bank = np.zeros((8*num_protos,protos[0].shape[0],protos[0].shape[1]))
  528. for i in range(num_protos):
  529. bank[0+i*8] = protos[i]
  530. bank[1+i*8] = np.transpose(bank[0+i*8])
  531. bank[2+i*8] = np.flipud(bank[0+i*8])
  532. bank[3+i*8] = np.fliplr(bank[0+i*8])
  533. bank[4+i*8] = np.transpose(bank[2+i*8])
  534. bank[5+i*8] = np.transpose(bank[3+i*8])
  535. bank[6+i*8] = np.fliplr(bank[2+i*8])
  536. bank[7+i*8] = np.fliplr(bank[5+i*8])
  537. return bank
  538. def do_pattern_matching_pref_stim_prototype_bank(pref_stim_img, bank, num_stimuli, measure = "float"):
  539. more_img = np.zeros((3*pref_stim_img.shape[0],3*pref_stim_img.shape[1]))
  540. for i in range(3):
  541. for j in range(3):
  542. more_img[i*pref_stim_img.shape[0]:(i+1)*pref_stim_img.shape[0],j*pref_stim_img.shape[1]:(j+1)*pref_stim_img.shape[1]] = pref_stim_img
  543. diff = np.zeros([bank.shape[1], bank.shape[2]])
  544. diff_sum = 0.
  545. con_img = np.zeros((pref_stim_img.shape[0],pref_stim_img.shape[1]))
  546. dhalf = bank.shape[1]/2
  547. for i in range(pref_stim_img.shape[0],2*pref_stim_img.shape[0]):
  548. for j in range(pref_stim_img.shape[1],2*pref_stim_img.shape[1]):
  549. for b in range(bank.shape[0]):
  550. curr_img = more_img[i-dhalf:i+dhalf+bank.shape[1]%2,j-dhalf:j+dhalf+bank.shape[2]%2]
  551. num_count = 0
  552. for num in range(num_stimuli):
  553. if num in curr_img:
  554. num_count += 1
  555. diff = abs(curr_img-bank[b])
  556. diff[np.where(diff>num_stimuli/2)] = num_stimuli - diff[np.where(diff>num_stimuli/2)]
  557. diff_sum = np.sum(diff)
  558. measure = ((diff_sum/(bank[0].size*num_stimuli/2.)-0.5)*4)**2
  559. if (num_count>=0.8*num_stimuli) and measure>=2.:
  560. if measure == 'binary':
  561. con_img[i-pref_stim_img.shape[0],j-pref_stim_img.shape[1]] = 1
  562. else:
  563. con_img[i-pref_stim_img.shape[0],j-pref_stim_img.shape[1]] += measure
  564. return con_img
  565. def eval_pattern_matching(matched_img):
  566. more_img = np.zeros((3*matched_img.shape[0],3*matched_img.shape[1]))
  567. for i in range(3):
  568. for j in range(3):
  569. more_img[i*matched_img.shape[0]:(i+1)*matched_img.shape[0],j*matched_img.shape[1]:(j+1)*matched_img.shape[1]] = matched_img
  570. s = [[1,1,1],
  571. [1,1,1],
  572. [1,1,1]]
  573. labeled, num = measurements.label(more_img, structure = s)
  574. for i in range(1,num+1):
  575. if not more_img[np.where(labeled == i)].any()>0.8*np.amax(more_img):
  576. more_img[np.where(np.where(labeled == i))] = 0
  577. labeled, num = measurements.label(more_img, structure = s)
  578. centers = measurements.center_of_mass(more_img, labeled, range(1,num+1))
  579. if len(centers)>0:
  580. more_pw = np.zeros((2,len(centers)))
  581. more_pw[0] = np.round(centers)[:,0]
  582. more_pw[1] = np.round(centers)[:,1]
  583. more_pw = tuple(more_pw.astype(int))
  584. more_img[:,:] = 0
  585. more_img[more_pw] = 1
  586. pw_matrix = more_img[matched_img.shape[0]:2*matched_img.shape[0],matched_img.shape[1]:2*matched_img.shape[1]]
  587. pinwheels = np.where(pw_matrix==1)
  588. else:
  589. pinwheels = []
  590. return pinwheels
  591. def calc_curl(pref_stim, pref_x, n, m):
  592. more_x = np.zeros((2*pref_x.shape[0],2*pref_x.shape[1]))
  593. for i in range(2):
  594. for j in range(2):
  595. more_x[i*pref_stim.shape[0]:int((i+0.5)*pref_stim.shape[0]),j*pref_stim.shape[1]:int((j+1)*pref_stim.shape[1])] = pref_x[pref_stim.shape[0]/2:]
  596. more_x[int((i+0.5)*pref_stim.shape[0]):(i+1)*pref_stim.shape[0],int((j)*pref_stim.shape[1]):(j+1)*pref_stim.shape[1]] = pref_x[:pref_stim.shape[0]/2]
  597. more_x = np.roll(more_x, pref_stim.shape[1]/2)
  598. more_img = np.zeros((2*pref_stim.shape[0],2*pref_stim.shape[1]))
  599. for i in range(2):
  600. for j in range(2):
  601. more_img[i*pref_stim.shape[0]:int((i+0.5)*pref_stim.shape[0]),j*pref_stim.shape[1]:int((j+1)*pref_stim.shape[1])] = pref_stim[pref_stim.shape[0]/2:]
  602. more_img[int((i+0.5)*pref_stim.shape[0]):(i+1)*pref_stim.shape[0],int((j)*pref_stim.shape[1]):(j+1)*pref_stim.shape[1]] = pref_stim[:pref_stim.shape[0]/2]
  603. more_img = np.roll(more_img, pref_stim.shape[1]/2)
  604. viel_img = more_img.copy()
  605. more_img = np.logical_or(more_img == n,more_img == m)
  606. labeled, num = measurements.label(more_img)
  607. label_indices = [(labeled == i).nonzero() for i in xrange(1, num+1)]
  608. centers = measurements.center_of_mass(more_img, labeled, range(1,num+1))
  609. patch_sizes = [len(label_indices[i][0]) for i in range(num)]
  610. biggest_patches = np.where(abs(patch_sizes-np.amax(patch_sizes))<2)
  611. for patch in biggest_patches[0]:
  612. i = np.array(centers)[patch].astype(int)
  613. #for i in np.array(centers)[biggest_patches].astype(int):
  614. if (i[0] >= (pref_stim.shape[0]/2)) and (i[0] < (pref_stim.shape[0]/2+pref_stim.shape[0])) and (i[1] >= (pref_stim.shape[1]/2)) and (i[1] < (pref_stim.shape[1]/2+pref_stim.shape[0])):
  615. a = np.zeros((viel_img.shape[0],viel_img.shape[1]))
  616. a[label_indices[patch]] = 1
  617. g = np.ones((3,3))
  618. f = convolve2d(a,g,mode='same',boundary='wrap')
  619. a = f
  620. a = np.zeros((viel_img.shape[0],viel_img.shape[1]))
  621. a[np.where(f>=3)] = 1
  622. f = convolve2d(a,g,mode='same',boundary='wrap')
  623. a[np.where(f<=8)] = 0
  624. b = np.diff(a).astype(int)
  625. c = np.diff(a,axis=0).astype(int)
  626. d = np.where(c[:,1:]+b[1:] != 0)
  627. d = np.array(d) + 1
  628. e = (d[0],d[1])
  629. tree = scipy.spatial.KDTree(np.array(e).T)
  630. mat = tree.sparse_distance_matrix(tree,70.).todense()
  631. start = 0
  632. minimum = 0.
  633. ordered = np.zeros((d[0].size,2))
  634. order = np.zeros(d[0].size, dtype=int)
  635. order[0] = start
  636. ordered[0] = np.array(d[:,start])
  637. for j in range(1,d[0].size):
  638. start, minimum = find_next_nn(mat,minimum, start, order)
  639. order[j] = start
  640. ordered[j] = np.array(d[:,start])
  641. integral = 0
  642. for j in range(order.size):
  643. integral += cycl_dist(more_x[int(ordered[j,0]), int(ordered[j,1])],more_x[int(ordered[j-1,0]), int(ordered[j-1,1])])
  644. integral /= float(pref_x.max()+1)
  645. ordered = ordered.T
  646. ordered = (ordered[0] + pref_stim.shape[0]/2, ordered[1] + pref_stim.shape[1]/2)
  647. ordered = (ordered[0].astype(int)%pref_stim.shape[0], ordered[1].astype(int)%pref_stim.shape[1])
  648. center = np.array([i[0],i[1]]).astype(int) - np.array([pref_stim.shape[0]/2, pref_stim.shape[1]/2]).astype(int)
  649. patch_indices = (np.array(label_indices[patch][0]) + pref_stim.shape[0]/2, np.array(label_indices[patch][1]) + pref_stim.shape[1]/2)
  650. patch_indices = (patch_indices[0]%pref_stim.shape[0], patch_indices[1]%pref_stim.shape[1])
  651. #patch_indices = np.sum(label_indices[patch], -np.array([pref_stim.shape[0]/2, pref_stim.shape[1]/2]).astype(int))
  652. center = (np.array([center[0]]),np.array([center[1]]))
  653. return ordered, integral
  654. def find_next_nn(mat,minimum, start, order):
  655. indices = np.where(mat[start]>minimum)
  656. minimum = mat[start][indices][0].min()
  657. ind = np.where(mat[start][0]==minimum)[1].copy()
  658. if ind[0,0] in order:
  659. if ind.size > 1:
  660. if ind[0,1] in order:
  661. if ind.size > 2:
  662. if ind[0,2] in order:
  663. start, minimum = find_next_nn(mat,minimum, start, order)
  664. else:
  665. start = ind[0,2]
  666. minimum = 0.
  667. else:
  668. start, minimum = find_next_nn(mat,minimum, start, order)
  669. else:
  670. start = ind[0,1]
  671. minimum = 0.
  672. else:
  673. start, minimum = find_next_nn(mat,minimum, start, order)
  674. else:
  675. start = ind[0,0]
  676. minimum = 0.
  677. return start, minimum
  678. def cycl_dist(a,b,n_stim=20):
  679. d = a-b
  680. if d>n_stim/2:
  681. d = n_stim - d
  682. if d<-n_stim/2:
  683. d = -n_stim - d
  684. return d
  685. def calc_pinwheels_pwcharge(pref_map_in, num_stimuli):
  686. """ radius of circle for charge calculation chosen to be 6. This is big enough to tolerate
  687. small discontiuities, but small enough to avoid overlap between two neighboring pinwheels """
  688. pref_map = np.array(pref_map_in).copy()
  689. r = 6
  690. a = np.amax(pref_map.shape)
  691. circle1 = np.array([[i,j] for i, j in product(range(a/2,-a/2,-1), range(-a/2,0)) if abs(i*i+j*j - r*r) < r])
  692. circle2 = np.array([[i,j] for i, j in product(range(-a/2,a/2), range(0,a/2)) if abs(i*i+j*j - r*r) < r])
  693. circle = np.concatenate((circle1,circle2))
  694. circ_ind = circle.T.copy()
  695. """ create new extended map with double the size of the original map and do the calculations with it in order to avoid boundary effects """
  696. more_map = np.zeros((2*pref_map.shape[0],2*pref_map.shape[1]))
  697. for i in range(2):
  698. for j in range(2):
  699. more_map[i*pref_map.shape[0]:int((i+0.5)*pref_map.shape[0]),j*pref_map.shape[1]:int((j+1)*pref_map.shape[1])] = pref_map[pref_map.shape[0]/2:].copy()
  700. more_map[int((i+0.5)*pref_map.shape[0]):(i+1)*pref_map.shape[0],int((j)*pref_map.shape[1]):(j+1)*pref_map.shape[1]] = pref_map[:pref_map.shape[0]/2].copy()
  701. more_map = np.roll(more_map, pref_map.shape[1]/2)
  702. """ calculate charge landcape with moving the circle over all possible positions on the extended map """
  703. charge = np.zeros((len(range(r,more_map.shape[0]-r)),len(range(r,more_map.shape[1]-r))))
  704. for i,j in product(range(r,more_map.shape[0]-r),range(r,more_map.shape[1]-r)):
  705. circle_coords = circle.copy()
  706. circle_coords[:,0] += i
  707. circle_coords[:,1] += j
  708. charge[(i-r),(j-r)] = calc_charge(more_map, circle_coords, num_stimuli)
  709. """ disregard areas of low charge values due to random fluctuations """
  710. charge[np.where(np.abs(charge)<0.8)] = 0.0
  711. """ convolve charge landscape with np.ones((3,3)) to create connected areas of high charge
  712. and avoid different areas close to eachother """
  713. kernel = np.ones((3,3))
  714. charge_conv = convolve2d(charge, kernel, mode='same', boundary='wrap')
  715. """ disregard areas that only had one single point of high charge or two with medium charge"""
  716. charge_conv[np.where(np.abs(charge_conv)<2)] = 0.0
  717. """ use label to identify connected areas belonging to one pinwheel """
  718. labeled, num = measurements.label(charge_conv)#, kernel)
  719. label_indices = [(labeled == i).nonzero() for i in xrange(1, num+1)]
  720. patch_sizes = [len(label_indices[i][0]) for i in range(num)]
  721. """ neighborhood condition: there must be at least 5 points in a labeled region.
  722. These 5 points had at least two high charge points in their vicinity. This disregards single missing points in a region
  723. but ensures original points of high chrage did have a short distance to eachother and thereby belong to the same pw.
  724. (This corresponds to some fancy geometrical forms and prevents algorithm from finding pws in salt and pepper pattern)"""
  725. labels = []
  726. for i in range(num):
  727. if patch_sizes[i]>4:
  728. labels.append(i+1)
  729. if len(labels)>0:
  730. """ calculate pw coordinates as centers of mass of the regions in coordinate system of the extended map"""
  731. centers = measurements.center_of_mass(charge_conv, labeled, labels)
  732. more_pw = np.zeros((2,len(centers)))
  733. more_pw[0] = np.round(centers)[:,0].astype(int)
  734. more_pw[1] = np.round(centers)[:,1].astype(int)
  735. more_pw = tuple(more_pw.astype(int))
  736. more_pw_matrix = np.zeros(more_map.shape)
  737. more_pw_matrix[more_pw] = 1
  738. """ transform to original coordinate frame """
  739. d = [pref_map.shape[0]/2-r, pref_map.shape[1]/2-r]
  740. pw_matrix = more_pw_matrix[d[0]:(pref_map.shape[0]+d[0]), d[1]:(pref_map.shape[1]+d[1])].copy()
  741. pinwheels = np.where(pw_matrix==1)
  742. else:
  743. pinwheels = ()
  744. """ return pinwheel coordinates """
  745. return pinwheels
  746. def calc_charge(map_matrix, ordered_surface_coords, num_stimuli):
  747. values = np.zeros(ordered_surface_coords.shape[0])
  748. charge = 0.
  749. for k in range(ordered_surface_coords.shape[0]):
  750. delta = cycl_dist(map_matrix[ordered_surface_coords[k,0],ordered_surface_coords[k,1]], map_matrix[ordered_surface_coords[k-1,0],ordered_surface_coords[k-1,1]])
  751. values[k] = map_matrix[ordered_surface_coords[k,0],ordered_surface_coords[k,1]]
  752. """ continuity condition: only accept values for the integral if there are no major jumps in values (eg in random map) """
  753. if abs(delta) < 4:
  754. charge += delta
  755. """ wheel condition: only accept high pinwheel value if almost all values were on the surface """
  756. if np.unique(values).size < num_stimuli-num_stimuli/10.:
  757. charge = 0.0
  758. """ normalization """
  759. charge /= float(num_stimuli)
  760. return charge
# Script entry point: executing this module directly performs no work.
if __name__ == "__main__":
    pass