util.py

import os

import numpy as np
import matplotlib.mlab as mlab
from scipy.interpolate import interp1d


def read_file_list(filename):
    """Reads the file containing the dataset names.

    Parameters
    ----------
    filename : str
        The file name

    Returns
    -------
    list
        The dataset names

    Raises
    ------
    FileNotFoundError
        If the file could not be found
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(f"File {filename} was not found!")
    datasets = []
    with open(filename, "r") as f:
        for l in f:
            if len(l.strip()) == 0:
                continue
            if "invivo" not in l:
                dataset = l.strip() + "-invivo-1"
            else:
                dataset = l.strip()
            datasets.append(dataset)
    return datasets


def threshold_crossings(data, time, threshold=0.0):
    """Detects rising threshold crossings in the data. Returns the times and the indices of the crossings.

    Parameters
    ----------
    data : np.array
        Vector of data values
    time : np.array
        Vector of respective recording times
    threshold : float, optional
        The threshold value, by default 0.0

    Returns
    -------
    np.array
        The crossing times
    np.array
        The respective indices in data and time
    """
    # note: np.roll wraps around, so the first sample is compared to the last one
    indices = np.where((data > threshold) & (np.roll(data, 1) <= threshold))[0]
    times = time[indices]
    return times, indices
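

# Usage sketch (not part of the original file): detect the rising zero
# crossings of a 10 Hz sine wave sampled at 20 kHz.
# >>> t = np.arange(0.0, 1.0, 1. / 20000.)
# >>> crossing_times, crossing_indices = threshold_crossings(np.sin(2. * np.pi * 10. * t), t)
# >>> print(len(crossing_times))  # one rising crossing per cycle
# 10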


def line(p1, p2):
    """Coefficients (a, b, c) of the line a*x + b*y = c through the points p1 and p2."""
    a = p1[1] - p2[1]
    b = p2[0] - p1[0]
    c = p1[0] * p2[1] - p2[0] * p1[1]
    return a, b, -c


def intersection(l1, l2):
    """Intersection point (x, y) of the two lines, or None if they are parallel."""
    # Cramer's rule applied to the two line equations
    d = l1[0] * l2[1] - l1[1] * l2[0]
    dx = l1[2] * l2[1] - l1[1] * l2[2]
    dy = l1[0] * l2[2] - l1[2] * l2[0]
    if d != 0:
        return dx / d, dy / d
    return None


def detect_eod_times(time, eod, threshold=0.0, running_average=0):
    """Detects the EOD threshold crossings. Applies some optional smoothing before threshold detection.
    Uses linear interpolation to estimate the real times of the threshold crossings. The resulting EOD
    times therefore have a higher temporal precision than given by the sampling of the data.

    Parameters
    ----------
    time : np.array
        Vector containing the recording timestamps in seconds
    eod : np.array
        Vector with the recorded EOD values in mV
    threshold : float, optional
        The EOD threshold in mV, by default 0.0
    running_average : int, optional
        The width of a running average used to smooth the recorded EOD, by default 0, i.e. no smoothing

    Returns
    -------
    np.array
        Vector of the interpolated EOD threshold crossing times
    np.array
        The indices of the crossings that were used
    """
    if running_average > 0:
        eod = np.convolve(eod, np.ones(int(running_average)) / int(running_average), mode="same")
    _, idx = threshold_crossings(eod, time, threshold)
    threshold_line = line([time[0], threshold], [time[-1], threshold])
    max_index = min(len(eod), len(time)) - 1
    # exclude crossings at the edges; the interpolation needs index - 1 and index + 1
    valid_indices = idx[(idx > 0) & (idx < max_index)]
    times = np.zeros(valid_indices.shape)
    for j, index in enumerate(valid_indices):
        l1 = line([time[index - 1], eod[index - 1]], [time[index + 1], eod[index + 1]])
        intersec = intersection(l1, threshold_line)
        if intersec is None:  # lines are parallel, fall back to the previous time
            times[j] = times[j - 1]
            continue
        times[j] = intersec[0]
    return times, valid_indices
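

# Usage sketch (not part of the original file): estimate EOD times from a mock
# 800 Hz EOD trace; the interpolated times are finer than the 20 kHz sampling grid.
# >>> t = np.arange(0.0, 0.1, 1. / 20000.)
# >>> eod_times, eod_indices = detect_eod_times(t, np.sin(2. * np.pi * 800. * t), running_average=3)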


def gaussKernel(sigma, dt):
    """Creates a Gaussian kernel with a given standard deviation and an integral of 1.

    Parameters
    ----------
    sigma : float
        The standard deviation of the kernel in seconds
    dt : float
        The temporal resolution of the kernel, given in seconds.

    Returns
    -------
    np.array
        The kernel in the range -4 to +4 sigma
    """
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y
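

# Sanity check (not part of the original file): the kernel values times the
# stepsize should sum to approximately 1, i.e. the kernel integrates to one.
# >>> k = gaussKernel(0.005, 1. / 20000.)
# >>> print(np.round(np.sum(k) / 20000., 2))
# 1.0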


def spike_triggered_average(spikes, stimulus_time, stimulus_amplitudes, delay=0.03, samplingrate=20000):
    """Estimates the spike-triggered average, i.e. the average stimulus in a window of +/- delay
    around the spike times. Spikes closer than delay to the stimulus edges are skipped."""
    sta_time = np.arange(-delay, delay, 1. / samplingrate)
    sta = np.zeros(sta_time.shape)
    count = 0
    for sp in spikes:
        if sp < delay or sp > stimulus_time[-1] - delay:
            continue
        start_index = int(np.round((sp - delay) * samplingrate))
        # derive the end from the snippet length; rounding start and end separately
        # can yield snippets that are one sample too long or too short
        end_index = start_index + len(sta_time)
        sta += stimulus_amplitudes[start_index:end_index]
        count += 1
    if count > 0:
        sta /= count
    return sta_time, sta


def get_temporal_shift(sta_time, sta):
    """Estimates the temporal shift between stimulus and response from the position of the
    dominant peak of the spike-triggered average. Returns the absolute shift and whether
    the STA is inverted, i.e. dominated by a negative peak."""
    min_time = sta_time[np.argmin(sta)]
    max_time = sta_time[np.argmax(sta)]
    min_sta = np.min(sta)
    max_sta = np.max(sta)
    shift = 0.0
    inverted = False
    if np.abs(min_sta) > np.abs(max_sta) and min_time > max_time:
        shift = min_time
        inverted = True
    else:
        shift = max_time
    return np.abs(shift), inverted


def load_stim(filename):
    """
    Loads a data file saved by relacs. Returns a tuple of dictionaries
    containing the data and the header information.

    :param filename: Filename of the data file.
    :type filename: string
    :returns: a tuple of dictionaries containing the header information and the data.
    :rtype: tuple
    """
    with open(filename, 'r') as fid:
        L = [l.strip() for l in fid.readlines()]
    ret = []
    dat = {}
    X = []
    keyon = False
    currkey = None
    for l in L:
        # if empty line or comment and we have data recorded, close the current block
        if (not l or l.startswith('#')) and len(X) > 0:
            keyon = False
            currkey = None
            dat['data'] = np.array(X)
            ret.append(dat)
            X = []
            dat = {}
        if '---' in l:
            continue
        if l.startswith('#'):
            if ":" in l:
                tmp = [e.strip() for e in l[1:].split(':')]
                if currkey is None:
                    dat[tmp[0]] = tmp[1]
                else:
                    dat[currkey][tmp[0]] = tmp[1]
            elif "=" in l:
                tmp = [e.strip() for e in l[1:].split('=')]
                if currkey is None:
                    dat[tmp[0]] = tmp[1]
                else:
                    dat[currkey][tmp[0]] = tmp[1]
            elif l[1:].lower().startswith('key'):
                dat['key'] = []
                keyon = True
            elif keyon:
                dat['key'].append(tuple([e.strip() for e in l[1:].split()]))
            else:
                currkey = l[1:].strip()
                dat[currkey] = {}
        elif l:  # if l != ''
            keyon = False
            currkey = None
            X.append([float(e) for e in l.split()])
    if len(X) > 0:
        dat['data'] = np.array(X)
    else:
        dat['data'] = []
    ret.append(dat)
    return tuple(ret)


def load_white_noise_stim(stim_name, stim_duration=10, sampling_rate=20000, folder="stimuli"):
    """Loads a white noise stimulus file and resamples it to the desired sampling rate
    via linear interpolation."""
    if stim_duration == 0.0:
        print("Stimulus duration must be larger than zero!")
        return None, None
    if not os.path.exists(folder):
        folder = os.path.expanduser(os.path.join("~", "data", "stimuli"))
    stim_file = stim_name.split('/')[-1]
    full_file = os.path.join(folder, stim_file)
    if not os.path.exists(full_file):
        print("Stimulus file does not exist")
        return None, None
    s = load_stim(full_file)
    x_org = s[0]['data'][:, 0]
    inter = interp1d(x_org, s[0]['data'][:, 1])
    x_new = np.linspace(0.0, stim_duration, int(stim_duration * sampling_rate))
    # clamp times beyond the recorded range to the last stimulus time to avoid extrapolation
    x_new[x_new > x_org[-1]] = x_org[-1]
    stimulus = inter(x_new)
    return x_new, stimulus


def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
    """Convert spike times to a firing rate estimated by kernel convolution with a Gaussian kernel.

    Args:
        spikes (np.array): the spike times in seconds
        duration (float): the trial duration in seconds
        sigma (float, optional): standard deviation of the Gaussian kernel. Defaults to 0.005.
        dt (float, optional): desired temporal resolution of the firing rate. Defaults to 1./20000.

    Returns:
        np.array: the firing rate
    """
    binary = np.zeros(int(np.round(duration / dt)))
    indices = np.asarray(np.round(spikes / dt), dtype=int)  # np.int was removed in NumPy 1.24
    binary[indices[indices < len(binary)]] = 1
    kernel = gaussKernel(sigma, dt)
    rate = np.convolve(kernel, binary, mode="same")
    return rate
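

# Usage sketch (not part of the original file): the time-averaged rate of a
# regular 100 Hz spike train should come out close to 100 Hz.
# >>> spike_times = np.arange(0.005, 1.0, 0.01)
# >>> print(np.round(np.mean(firing_rate(spike_times, 1.0))))
# 100.0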


def coherence_rate(coherence, frequency, start_freq, stop_freq):
    """Lower-bound estimate of the mutual information rate carried between start_freq and
    stop_freq, computed from the stimulus-response coherence as -sum(log2(1 - coherence)) * deltaf."""
    deltaf = np.mean(np.diff(frequency))
    spectrum = -np.log2(np.ones(coherence.shape) - coherence) * deltaf
    info = np.sum(spectrum[(frequency >= start_freq) & (frequency < stop_freq)])
    return info


def mutual_info(spike_responses, delay, inversion_needed, stimulus, freq_bin_edges, kernel_sigma=0.00125,
                delay_type="equal", stepsize=1./20000, trial_duration=10.):
    """Estimates the mutual information between stimulus and population response from the
    stimulus-response coherence.

    Args:
        spike_responses (list of list of spike times): the spike responses of the individual trials
        delay (float): the temporal delay/jitter added to the spike times, in seconds
        inversion_needed (list of bool): for each trial, whether the firing rate needs to be inverted
        stimulus (np.array): the stimulus waveform
        freq_bin_edges (np.array): the edges of the frequency bins in which the information is integrated
        kernel_sigma (float, optional): standard deviation of the Gaussian kernel used for the rate
            estimation. Defaults to 0.00125.
        delay_type (str, optional): "equal" draws delays uniformly from [0, 2 * delay], "gaussian"
            from a normal distribution with standard deviation delay. Defaults to "equal".
        stepsize (float, optional): temporal resolution of the firing rates. Defaults to 1./20000.
        trial_duration (float, optional): the trial duration in seconds. Defaults to 10.

    Returns:
        np.array: array containing the frequency axis of the coherence function
        np.array: the average coherence function
        np.array: mutual information values, same length as freq_bin_edges; the first entry
            covers the full range up to freq_bin_edges[-1]
        np.array: the single-trial firing rates, shape (time, trials)
        float: the average of the applied delays
    """
    population_rates = None
    rng = np.random.default_rng()
    delays = np.zeros(len(spike_responses))
    for i, (sr, invert) in enumerate(zip(spike_responses, inversion_needed)):
        if isinstance(sr, list):
            sr = np.array(sr)
        jitter = 0.0
        if "equal" in delay_type:
            jitter = rng.random() * delay * 2
        elif "gaussian" in delay_type:
            jitter = rng.standard_normal() * delay  # rng.standard_random does not exist
        delays[i] = jitter
        sr = sr + jitter  # do not modify the caller's spike times in place
        rate = firing_rate(sr, trial_duration, kernel_sigma, dt=stepsize)
        if invert:  # mirror the rate about its mean
            avg = np.mean(rate)
            rate = 2 * avg - rate
        if population_rates is None:
            population_rates = np.zeros((len(rate), len(spike_responses)))
        population_rates[:, i] = rate
    c, f = mlab.cohere(np.mean(population_rates, axis=1), stimulus, NFFT=2**14, noverlap=2**13,
                       Fs=1./stepsize, detrend=mlab.detrend_mean, window=mlab.window_hanning)
    mis = np.zeros(len(freq_bin_edges))
    mis[0] = coherence_rate(c, f, 0, freq_bin_edges[-1])
    for i in range(1, len(freq_bin_edges)):
        mis[i] = coherence_rate(c, f, freq_bin_edges[i - 1], freq_bin_edges[i])
    return f, c, mis, population_rates, np.mean(delays)


class LIF(object):

    def __init__(self, stepsize=0.0001, offset=1.6, tau_m=0.010, tau_a=0.02, da=0.0, D=3.5):
        self.stepsize = stepsize  # simulation stepsize [s]
        self.offset = offset  # offset current [nA]
        self.tau_m = tau_m  # membrane time constant [s]
        self.tau_a = tau_a  # adaptation time constant [s]
        self.da = da  # increment in adaptation current [nA]
        self.D = D  # noise intensity
        self.v_threshold = 1.0  # spiking threshold
        self.v_reset = 0.0  # reset voltage after spiking
        self.i_a = 0.0  # current adaptation current
        self.v = self.v_reset  # current membrane voltage
        self.t = 0.0  # current time [s]
        self.membrane_voltage = []
        self.spike_times = []

    def _reset(self):
        self.i_a = 0.0
        self.v = self.v_reset
        self.t = 0.0
        self.membrane_voltage = []
        self.spike_times = []

    def _lif(self, stimulus, noise):
        """
        Euler solution of the membrane equation with adaptation current and noise.
        """
        # exponential decay of the adaptation current (Euler step of di_a/dt = -i_a/tau_a)
        self.i_a -= self.stepsize / self.tau_a * self.i_a
        self.v += self.stepsize * (-self.v + stimulus + noise + self.offset - self.i_a) / self.tau_m
        self.membrane_voltage.append(self.v)

    def _next(self, stimulus):
        """
        Work horse which delegates to the Euler step and collects the spike times.
        """
        # uniform noise in [-0.5, 0.5) scaled by D; the original expression looked like a
        # translation of C's rand() % 10000 and did not produce the intended distribution
        noise = self.D * (np.random.rand() - 0.5)
        self._lif(stimulus, noise)
        self.t += self.stepsize
        if self.v > self.v_threshold and len(self.membrane_voltage) > 1:
            self.v = self.v_reset
            self.membrane_voltage[-1] = 2.0  # mark the spike in the voltage trace
            self.spike_times.append(self.t)
            self.i_a += self.da

    def run_const_stim(self, steps, stimulus):
        """
        LIF simulation with a constant stimulus.
        """
        self._reset()
        for i in range(steps):
            self._next(stimulus)
        time = np.arange(len(self.membrane_voltage)) * self.stepsize
        return time, np.array(self.membrane_voltage), np.array(self.spike_times)

    def run_stimulus(self, stimulus):
        """
        LIF simulation with a predefined stimulus trace.
        """
        self._reset()
        for s in stimulus:
            self._next(s)
        time = np.arange(len(self.membrane_voltage)) * self.stepsize
        return time, np.array(self.membrane_voltage), np.array(self.spike_times)

    def __str__(self):
        out = '\n'.join(["stepsize: \t" + str(self.stepsize),
                         "offset:\t\t" + str(self.offset),
                         "tau_m:\t\t" + str(self.tau_m),
                         "tau_a:\t\t" + str(self.tau_a),
                         "da:\t\t" + str(self.da),
                         "D:\t\t" + str(self.D),
                         "v_threshold:\t" + str(self.v_threshold),
                         "v_reset:\t" + str(self.v_reset)])
        return out

    def __repr__(self):
        return self.__str__()
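

# Usage sketch (not part of the original file): simulate one second of the LIF
# model (10000 steps of 0.1 ms) with a constant stimulus of zero; spiking is
# driven by the offset current and the noise.
# >>> lif = LIF(stepsize=0.0001, offset=1.6, D=0.5)
# >>> time, voltage, spike_times = lif.run_const_stim(10000, 0.0)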


if __name__ == "__main__":
    # stim = "gwn300Hz10s0.3.dat"
    # time, stimulus = load_white_noise_stim(stim, 10, 40000)
    # np.savez("noise_stimulus.npz", time=time, stim=stim)
    import matplotlib.pyplot as plt
    from response_features import despine

    fig = plt.figure(figsize=(1.5, 1.5))
    ax = fig.add_subplot(111)
    k = gaussKernel(0.125, 0.0001)
    ax.plot(k)
    ax.set_yticklabels([])
    ax.set_xlim([0, 10000])
    ax.set_xticks([0, 2500, 5000, 7500, 10000])
    ax.set_xticklabels([r"-2$\pi$", r"-$\pi$", "0", r"$\pi$", r"2$\pi$"])
    despine(ax, ["top", "right"], False)
    fig.subplots_adjust(bottom=0.2)
    fig.savefig("gaussian.pdf")
    plt.close()