util.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. import os
  2. import numpy as np
  3. import matplotlib.mlab as mlab
  4. from scipy.interpolate import interp1d
  5. def read_file_list(filename):
  6. """Reads the files containing the dataset names.
  7. Parameters
  8. ----------
  9. filename : str
  10. The file name
  11. Returns
  12. -------
  13. list
  14. The dataset names
  15. Raises
  16. ------
  17. FileNotFoundError
  18. If the file could not be found
  19. """
  20. if os.path.exists(filename):
  21. datasets = []
  22. with open(filename, "r") as f:
  23. for l in f:
  24. if len(l.strip()) == 0:
  25. continue
  26. if "invivo" not in l:
  27. dataset = l.strip() + "-invivo-1"
  28. else:
  29. dataset = l.strip()
  30. datasets.append(dataset)
  31. return datasets
  32. else:
  33. raise FileNotFoundError(f"File {filename} was not found!")
  34. def threshold_crossings(data, time, threshold=0.0):
  35. """Detects rising threshold crossings in the data. Returns the time and the indices of the crossings.
  36. Parameters
  37. ----------
  38. data : np.array
  39. Vector of data values
  40. time : np.array
  41. Vector of respective recording times
  42. threshold : float, optional
  43. The threshold value, by default 0.0
  44. Returns
  45. -------
  46. np.array
  47. The crossing times
  48. np.array
  49. The respective indices in data and time
  50. """
  51. indices = np.where((data > threshold) & (np.roll(data, 1) <= threshold))[0]
  52. times = time[indices]
  53. return times, indices
  54. def line(p1, p2):
  55. a = (p1[1] - p2[1])
  56. b = (p2[0] - p1[0])
  57. c = (p1[0]*p2[1] - p2[0]*p1[1])
  58. return a, b, -c
  59. def intersection(l1, l2):
  60. d = l1[0] * l2[1] - l1[1] * l2[0]
  61. dx = l1[2] * l2[1] - l1[1] * l2[2]
  62. dy = l1[0] * l2[2] - l1[2] * l2[0]
  63. if d != 0:
  64. x = dx / d
  65. y = dy / d
  66. return x, y
  67. else:
  68. return False
  69. def detect_eod_times(time, eod, threshold=0.0, running_average=0):
  70. """Detects the EOD threshold crossings. Applies some optional smoothing before threshold detection.
  71. Uses linear interpolation to estimate the real times of the threshold crossings. Resulting EOD times,
  72. will have a higher temporal precision than given by the sampling of the data.
  73. Parameters
  74. ----------
  75. time : np.array
  76. vector containing the recording timestamps in seconds
  77. eod : np.array
  78. vector with the recorded EOD values in mV
  79. threshold : float, optional
  80. The EOD threshold in mV, by default 0.0
  81. running_average : int, optional
  82. the width of a running average used to smooth the recorded EOD, by default 0, no smoothing
  83. Returns
  84. -------
  85. np.array
  86. vector of EOD threshold crossings
  87. """
  88. if running_average > 0:
  89. eod = np.convolve(eod, np.ones(int(running_average))/int(running_average), mode="same")
  90. _, idx = threshold_crossings(eod, time, threshold)
  91. threshold_line = line([time[0], threshold], [time[-1], threshold])
  92. max_index = min(len(eod), len(time)) - 1
  93. valid_indices = idx[idx < max_index]
  94. times = np.zeros(valid_indices.shape)
  95. for j, index in enumerate(valid_indices):
  96. l1 = line([time[index - 1], eod[index - 1]], [time[index + 1], eod[index + 1]])
  97. intersec = intersection(l1, threshold_line)
  98. if isinstance(intersec, bool):
  99. times[j] = times[j-1]
  100. continue
  101. times[j] = intersection(l1, threshold_line)[0]
  102. return times, idx
  103. def gaussKernel(sigma, dt):
  104. """ Creates a Gaussian kernel with a given standard deviation and an integral of 1.
  105. Parameters
  106. ----------
  107. sigma : float
  108. The standard deviation of the kernel in seconds
  109. dt : float
  110. The temporal resolution of the kernel, given in seconds.
  111. Returns:
  112. np.array
  113. The kernel in the range -4 to +4 sigma
  114. """
  115. x = np.arange(-4. * sigma, 4. * sigma, dt)
  116. y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
  117. return y
  118. def spike_triggered_average(spikes, stimulus_time, stimulus_amplitudes, delay=0.03, samplingrate=20000):
  119. sta_time = np.arange(-delay, delay, 1/samplingrate)
  120. sta = np.zeros(sta_time.shape)
  121. count = 0
  122. for sp in spikes:
  123. if sp < delay or sp > stimulus_time[-1] - delay:
  124. continue
  125. start_index = int(np.round((sp - delay) * samplingrate))
  126. end_index = int(np.round((sp + delay) * samplingrate))
  127. sta += stimulus_amplitudes[start_index:end_index]
  128. count += 1
  129. sta /= count
  130. return sta_time, sta
  131. def get_temporal_shift(sta_time, sta):
  132. min_time = sta_time[np.argmin(sta)]
  133. max_time = sta_time[np.argmax(sta)]
  134. min_sta = sta[np.argmin(sta)]
  135. max_sta = sta[np.argmax(sta)]
  136. shift = 0.0
  137. inverted = False
  138. if np.abs(min_sta) > np.abs(max_sta) and min_time > max_time:
  139. shift = min_time
  140. inverted = True
  141. else:
  142. shift = max_time
  143. return np.abs(shift), inverted
  144. def load_stim(filename):
  145. """
  146. Loads a data file saved by relacs. Returns a tuple of dictionaries
  147. containing the data and the header information
  148. :param filename: Filename of the data file.
  149. :type filename: string
  150. :returns: a tuple of dictionaries containing the head information and the data.
  151. :rtype: tuple
  152. """
  153. with open(filename, 'r') as fid:
  154. L = [l.lstrip().rstrip() for l in fid.readlines()]
  155. ret = []
  156. dat = {}
  157. X = []
  158. keyon = False
  159. currkey = None
  160. for l in L:
  161. # if empty line and we have data recorded
  162. if (not l or l.startswith('#')) and len(X) > 0:
  163. keyon = False
  164. currkey = None
  165. dat['data'] = np.array(X)
  166. ret.append(dat)
  167. X = []
  168. dat = {}
  169. if '---' in l:
  170. continue
  171. if l.startswith('#'):
  172. if ":" in l:
  173. tmp = [e.rstrip().lstrip() for e in l[1:].split(':')]
  174. if currkey is None:
  175. dat[tmp[0]] = tmp[1]
  176. else:
  177. dat[currkey][tmp[0]] = tmp[1]
  178. elif "=" in l:
  179. tmp = [e.rstrip().lstrip() for e in l[1:].split('=')]
  180. if currkey is None:
  181. dat[tmp[0]] = tmp[1]
  182. else:
  183. dat[currkey][tmp[0]] = tmp[1]
  184. elif l[1:].lower().startswith('key'):
  185. dat['key'] = []
  186. keyon = True
  187. elif keyon:
  188. dat['key'].append(tuple([e.lstrip().rstrip() for e in l[1:].split()]))
  189. else:
  190. currkey = l[1:].rstrip().lstrip()
  191. dat[currkey] = {}
  192. elif l: # if l != ''
  193. keyon = False
  194. currkey = None
  195. X.append([float(e) for e in l.split()])
  196. if len(X) > 0:
  197. dat['data'] = np.array(X)
  198. else:
  199. dat['data'] = []
  200. ret.append(dat)
  201. return tuple(ret)
  202. def load_white_noise_stim(stim_name, stim_duration=10, sampling_rate=20000, folder="stimuli"):
  203. if stim_duration == 0.0:
  204. print("Stimulus duration must be larger than zero!")
  205. return None, None
  206. if not os.path.exists(folder):
  207. folder = os.path.expanduser(os.path.join("~", "data", "stimuli"))
  208. stim_file = stim_name.split('/')[-1]
  209. full_file = os.path.sep.join([folder, stim_file])
  210. if not os.path.exists(full_file):
  211. print("Stimulus file does not exist")
  212. return None, None
  213. s = load_stim(full_file)
  214. inter = interp1d(s[0]['data'][:, 0], s[0]['data'][:, 1])
  215. x_org = s[0]['data'][:, 0]
  216. x_new = np.linspace(0.0, stim_duration, int(stim_duration * sampling_rate))
  217. x_new[x_new > x_org[-1]] = x_org[-1]
  218. stimulus = inter(x_new)
  219. return x_new, stimulus
  220. def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
  221. """Convert spike times to a firing rate estimated by kernel convolution with a Gaussian kernel.
  222. Args:
  223. spikes (np.array): the spike times
  224. duration (float): the trial duration
  225. sigma (float, optional): standard deviation of the Gaussian kernel. Defaults to 0.005.
  226. dt (float, optional): desired temporal resolution of the firing rate. Defaults to 1./20000..
  227. Returns:
  228. np.array: the firing rate
  229. """
  230. binary = np.zeros(int(np.round(duration/dt)))
  231. indices = np.asarray(np.round(spikes / dt), dtype=np.int)
  232. binary[indices[indices < len(binary)]] = 1
  233. kernel = gaussKernel(sigma, dt)
  234. rate = np.convolve(kernel, binary, mode="same")
  235. return rate
  236. def coherence_rate(coherence, frequency, start_freq, stop_freq):
  237. """Estimates the coherence rate as a lower bound estimation of the mutual information between stimulus and response.
  238. Parameters
  239. ----------
  240. coherence : _type_
  241. _description_
  242. frequency : _type_
  243. _description_
  244. start_freq : _type_
  245. _description_
  246. stop_freq : _type_
  247. _description_
  248. Returns
  249. -------
  250. _type_
  251. _description_
  252. """
  253. deltaf = np.mean(np.diff(frequency))
  254. spectrum = -np.log2((np.ones(coherence.shape) - coherence)) * deltaf
  255. info = np.sum(spectrum[(frequency >= start_freq) & (frequency < stop_freq)])
  256. return info
  257. def mutual_info(spike_responses, artificial_delay, inversion_needed, stimulus, freq_bin_edges, kernel_sigma=0.00125, delay_type="equal", stepsize=1./20000, trial_duration=10.):
  258. """Estimates a the lower bound mutual information carried by the responses about the stimulus. The MI is estimated using the stimulus response coherence. See Borst & Theunissen, 2001.
  259. Parameters
  260. ----------
  261. spike_responses : list of np.ndarrays of float
  262. The spike times
  263. artificial_delay : float
  264. The artificial delay that is added to the spike times. If delay is less or equal 0.0 no delay will be added
  265. inversion_needed : bool
  266. Whether or not the firing rate is inverted, may happen due to the receptor location, stimulus geometry
  267. stimulus : np.ndarray
  268. The stimulus waveform.
  269. freq_bin_edges : list of float
  270. The edges of the frequency bands to be analysed.
  271. kernel_sigma : float, optional
  272. Standard deviation of the Gaussian kernel used for firing rate estimation, by default 0.00125
  273. delay_type : str, optional
  274. Type of destribution of the delays, by default "equal", alternatively "gaussian"
  275. stepsize : float, optional
  276. Temporal stepsize of the stimulus trace, by default 1./20000
  277. trial_duration : float, optional
  278. The duration of a recorded trial, by default 10. The stimulus waveform may be longer than the responses are actually recorded.
  279. Returns
  280. -------
  281. np.array:
  282. array containing the frequency axis of the coherence function
  283. np.array:
  284. the average coherence function
  285. np.array:
  286. mutual information values, shape is len of freq_bin_edges + 1
  287. np.array:
  288. avg_rate; the average firing rate
  289. np.array:
  290. rate_error; std of firing rates as function of time
  291. float:
  292. true_delay; avg of the real delays
  293. np.array:
  294. variability; the time-averaged rate
  295. """
  296. population_rates = None
  297. rng = np.random.default_rng()
  298. delays = np.zeros(len(spike_responses))
  299. for i, (sr, invert) in enumerate(zip(spike_responses, inversion_needed)):
  300. if isinstance(sr, list):
  301. sr = np.array(sr)
  302. dt = 0.0
  303. if artificial_delay > 0.0:
  304. if "equal" in delay_type:
  305. dt = rng.random() * artificial_delay * 2
  306. elif "gaussian" in delay_type:
  307. dt = rng.standard_random() * artificial_delay
  308. delays[i] = dt
  309. sr += dt
  310. rate = firing_rate(sr, trial_duration, kernel_sigma, dt=stepsize)
  311. if invert:
  312. avg = np.mean(rate)
  313. rate -= avg
  314. rate *= -1
  315. rate += avg
  316. if population_rates is None:
  317. population_rates = np.zeros((len(rate), len(spike_responses)))
  318. population_rates[:, i] = rate
  319. c, f = mlab.cohere(np.mean(population_rates, axis=1), stimulus, NFFT=2**14, noverlap=2**13,
  320. Fs=1./stepsize, detrend=mlab.detrend_mean, window=mlab.window_hanning)
  321. mis = np.zeros(len(freq_bin_edges))
  322. mis[0] = coherence_rate(c, f, 0, freq_bin_edges[-1])
  323. for i in range(1, len(freq_bin_edges)):
  324. mis[i] = coherence_rate(c, f, freq_bin_edges[i-1], freq_bin_edges[i])
  325. return f, c, mis, population_rates, np.mean(delays)
  326. class LIF(object):
  327. def __init__(self, stepsize=0.0001, offset=1.6, tau_m=0.010, tau_a=0.02, da=0.0, D=3.5):
  328. self.stepsize = stepsize # simulation stepsize [s]
  329. self.offset = offset # offset curent [nA]
  330. self.tau_m = tau_m # membrane time_constant [s]
  331. self.tau_a = tau_a # adaptation time_constant [s]
  332. self.da = da # increment in adaptation current [nA]
  333. self.D = D # noise intensity
  334. self.v_threshold = 1.0 # spiking threshold
  335. self.v_reset = 0.0 # reset voltage after spiking
  336. self.i_a = 0.0 # current adaptation current
  337. self.v = self.v_reset # current membrane voltage
  338. self.t = 0.0 # current time [s]
  339. self.membrane_voltage = []
  340. self.spike_times = []
  341. def _reset(self):
  342. self.i_a = 0.0
  343. self.v = self.v_reset
  344. self.t = 0.0
  345. self.membrane_voltage = []
  346. self.spike_times = []
  347. def _lif(self, stimulus, noise):
  348. """
  349. euler solution of the membrane equation with adaptation current and noise
  350. """
  351. self.i_a -= self.i_a - self.stepsize/self.tau_a * (self.i_a)
  352. self.v += self.stepsize * ( -self.v + stimulus + noise + self.offset - self.i_a)/self.tau_m;
  353. self.membrane_voltage.append(self.v)
  354. def _next(self, stimulus):
  355. """
  356. working horse which delegates to the euler and gets the spike times
  357. """
  358. noise = self.D * (float(np.random.randn() % 10000) - 5000.0)/10000
  359. self._lif(stimulus, noise)
  360. self.t += self.stepsize
  361. if self.v > self.v_threshold and len(self.membrane_voltage) > 1:
  362. self.v = self.v_reset
  363. self.membrane_voltage[len(self.membrane_voltage)-1] = 2.0
  364. self.spike_times.append(self.t)
  365. self.i_a += self.da
  366. def run_const_stim(self, steps, stimulus):
  367. """
  368. lif simulation with constant stimulus.
  369. """
  370. self._reset()
  371. for i in range(steps):
  372. self._next(stimulus);
  373. time = np.arange(len(self.membrane_voltage))*self.stepsize
  374. return time, np.array(self.membrane_voltage), np.array(self.spike_times)
  375. def run_stimulus(self, stimulus):
  376. """
  377. lif simulation with a predefined stimulus trace.
  378. """
  379. self._reset()
  380. for s in stimulus:
  381. self._next(s);
  382. time = np.arange(len(self.membrane_voltage)) * self.stepsize
  383. return time, np.array(self.membrane_voltage), np.array(self.spike_times)
  384. def __str__(self):
  385. out = '\n'.join(["stepsize: \t" + str(self.stepsize),
  386. "offset:\t\t" + str(self.offset),
  387. "tau_m:\t\t" + str(self.tau_m),
  388. "tau_a:\t\t" + str(self.tau_a),
  389. "da:\t\t" + str(self.da),
  390. "D:\t\t" + str(self.D),
  391. "v_threshold:\t" + str(self.v_threshold),
  392. "v_reset:\t" + str(self.v_reset)])
  393. return out
  394. def __repr__(self):
  395. return self.__str__()