123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- import os
- import numpy as np
- import matplotlib.mlab as mlab
- from scipy.interpolate import interp1d
def read_file_list(filename):
    """Read the file containing the dataset names.

    Each non-empty line is taken as one dataset name; names that do not
    already contain "invivo" get the suffix "-invivo-1" appended.

    Parameters
    ----------
    filename : str
        The file name

    Returns
    -------
    list
        The dataset names

    Raises
    ------
    FileNotFoundError
        If the file could not be found
    """
    if not os.path.exists(filename):
        # include the offending path in the message; the original f-string
        # carried no placeholder and always printed a fixed text
        raise FileNotFoundError(f"File {filename} was not found!")
    datasets = []
    with open(filename, "r") as f:
        for raw_line in f:
            name = raw_line.strip()
            if not name:
                continue
            # complete the dataset name if the invivo suffix is missing
            if "invivo" not in name:
                name = name + "-invivo-1"
            datasets.append(name)
    return datasets
def threshold_crossings(data, time, threshold=0.0):
    """Detect rising threshold crossings in the data.

    A crossing is a sample that lies above the threshold while its
    predecessor lies at or below it. Returns the times and the indices
    of the crossings.

    Parameters
    ----------
    data : np.array
        Vector of data values
    time : np.array
        Vector of respective recording times
    threshold : float, optional
        The threshold value, by default 0.0

    Returns
    -------
    np.array
        The crossing times
    np.array
        The respective indices in data and time
    """
    # Compare each sample (from the second on) with its predecessor.
    # The previous np.roll-based comparison wrapped around: it reported a
    # spurious crossing at index 0 whenever the LAST sample was at/below
    # threshold while the first was above it.
    indices = np.where((data[1:] > threshold) & (data[:-1] <= threshold))[0] + 1
    times = time[indices]
    return times, indices
def line(p1, p2):
    """Return the coefficients (a, b, c) of the line a*x + b*y = c
    passing through the two points ``p1`` and ``p2`` (each an (x, y) pair)."""
    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]
    return y1 - y2, x2 - x1, -(x1 * y2 - x2 * y1)
def intersection(l1, l2):
    """Intersect two lines given as (a, b, c) coefficient triples
    (see ``line``). Returns the (x, y) intersection point, or False if
    the lines are parallel."""
    det = l1[0] * l2[1] - l1[1] * l2[0]
    if det == 0:
        # parallel (or identical) lines have no unique intersection
        return False
    x = (l1[2] * l2[1] - l1[1] * l2[2]) / det
    y = (l1[0] * l2[2] - l1[2] * l2[0]) / det
    return x, y
def detect_eod_times(time, eod, threshold=0.0, running_average=0):
    """Detect the EOD threshold crossings. Applies some optional smoothing
    before threshold detection.

    Uses linear interpolation (the secant through the neighbors of each
    crossing sample) to estimate the real times of the threshold crossings.
    Resulting EOD times will have a higher temporal precision than given by
    the sampling of the data.

    Parameters
    ----------
    time : np.array
        vector containing the recording timestamps in seconds
    eod : np.array
        vector with the recorded EOD values in mV
    threshold : float, optional
        The EOD threshold in mV, by default 0.0
    running_average : int, optional
        the width of a running average used to smooth the recorded EOD,
        by default 0, no smoothing

    Returns
    -------
    np.array
        vector of interpolated EOD threshold crossing times
    np.array
        the raw sample indices of the detected crossings
    """
    if running_average > 0:
        # box-car smoothing with a normalized kernel
        width = int(running_average)
        eod = np.convolve(eod, np.ones(width) / width, mode="same")
    _, idx = threshold_crossings(eod, time, threshold)
    threshold_line = line([time[0], threshold], [time[-1], threshold])
    max_index = min(len(eod), len(time)) - 1
    # interpolation needs a valid neighbor on BOTH sides; index 0 would
    # otherwise silently wrap to the last sample via index - 1
    valid_indices = idx[(idx > 0) & (idx < max_index)]
    times = np.zeros(valid_indices.shape)
    for j, index in enumerate(valid_indices):
        secant = line([time[index - 1], eod[index - 1]],
                      [time[index + 1], eod[index + 1]])
        intersec = intersection(secant, threshold_line)
        if intersec is False:
            # degenerate (parallel) case: fall back to the previous time
            times[j] = times[j - 1]
            continue
        # reuse the already computed intersection instead of recomputing it
        times[j] = intersec[0]
    return times, idx
def gaussKernel(sigma, dt):
    """Create a Gaussian kernel with a given standard deviation and an
    integral of one.

    Parameters
    ----------
    sigma : float
        The standard deviation of the kernel in seconds
    dt : float
        The temporal resolution of the kernel, given in seconds.

    Returns:
    np.array
        The kernel sampled in the range -4 to +4 sigma
    """
    support = np.arange(-4.0 * sigma, 4.0 * sigma, dt)
    norm = 1.0 / (np.sqrt(2.0 * np.pi) * sigma)
    return norm * np.exp(-(support / sigma) ** 2 / 2.0)
def spike_triggered_average(spikes, stimulus_time, stimulus_amplitudes, delay=0.03, samplingrate=20000):
    """Estimate the spike-triggered average (STA) of the stimulus.

    For every spike that lies at least ``delay`` away from both ends of the
    stimulus, the stimulus snippet centered on the spike time is accumulated;
    the sum is then normalized by the number of accumulated snippets.

    Parameters
    ----------
    spikes : np.array
        Spike times in seconds.
    stimulus_time : np.array
        Time axis of the stimulus in seconds.
    stimulus_amplitudes : np.array
        Stimulus values, sampled at ``samplingrate``.
    delay : float, optional
        Half-width of the STA window in seconds, by default 0.03.
    samplingrate : int, optional
        Sampling rate of the stimulus in Hz, by default 20000.

    Returns
    -------
    np.array
        Time axis of the STA, ranging from -delay to +delay.
    np.array
        The spike-triggered average; all zeros if no spike qualified.
    """
    sta_time = np.arange(-delay, delay, 1.0 / samplingrate)
    sta = np.zeros(sta_time.shape)
    count = 0
    for sp in spikes:
        if sp < delay or sp > stimulus_time[-1] - delay:
            continue  # snippet would reach beyond the recording
        start_index = int(np.round((sp - delay) * samplingrate))
        # derive the end from the window length; rounding start and end
        # independently could make the snippet one sample too long or short
        # and break the element-wise accumulation
        end_index = start_index + len(sta_time)
        if end_index > len(stimulus_amplitudes):
            continue
        sta += stimulus_amplitudes[start_index:end_index]
        count += 1
    if count > 0:  # avoid a division by zero when no spike qualified
        sta /= count
    return sta_time, sta
def get_temporal_shift(sta_time, sta):
    """Find the latency of the dominant STA peak.

    If the negative peak has the larger magnitude AND occurs later than the
    positive peak, the response counts as inverted and the negative peak's
    latency is used; otherwise the positive peak's latency is used.

    Returns
    -------
    float
        Absolute latency of the chosen peak.
    bool
        True if the response is considered inverted.
    """
    i_min = np.argmin(sta)
    i_max = np.argmax(sta)
    inverted = bool(np.abs(sta[i_min]) > np.abs(sta[i_max])
                    and sta_time[i_min] > sta_time[i_max])
    shift = sta_time[i_min] if inverted else sta_time[i_max]
    return np.abs(shift), inverted
def load_stim(filename):
    """
    Loads a data file saved by relacs. Returns a tuple of dictionaries
    containing the data and the header information

    :param filename: Filename of the data file.
    :type filename: string
    :returns: a tuple of dictionaries containing the head information and the data.
    :rtype: tuple
    """
    # Read all lines, stripping surrounding whitespace.
    with open(filename, 'r') as fid:
        L = [l.lstrip().rstrip() for l in fid.readlines()]
    ret = []        # finished sections (header entries plus a 'data' array)
    dat = {}        # header/metadata of the section currently being parsed
    X = []          # numeric rows collected for the current section
    keyon = False   # True while consuming rows of a '# Key' column table
    currkey = None  # name of the current nested header sub-section, if any
    for l in L:
        # if empty line and we have data recorded
        # -> the current section is complete; store it and start a new one
        if (not l or l.startswith('#')) and len(X) > 0:
            keyon = False
            currkey = None
            dat['data'] = np.array(X)
            ret.append(dat)
            X = []
            dat = {}
        if '---' in l:
            # separator lines carry no information
            continue
        if l.startswith('#'):
            # header line: 'key: value', 'key=value', a 'Key' table marker,
            # a key-table row, or the name of a nested sub-section
            if ":" in l:
                # NOTE(review): split(':') produces more than two parts for
                # values that themselves contain a colon; only the first part
                # after the key is kept — confirm this is acceptable.
                tmp = [e.rstrip().lstrip() for e in l[1:].split(':')]
                if currkey is None:
                    dat[tmp[0]] = tmp[1]
                else:
                    dat[currkey][tmp[0]] = tmp[1]
            elif "=" in l:
                tmp = [e.rstrip().lstrip() for e in l[1:].split('=')]
                if currkey is None:
                    dat[tmp[0]] = tmp[1]
                else:
                    dat[currkey][tmp[0]] = tmp[1]
            elif l[1:].lower().startswith('key'):
                # start of the column-description table
                dat['key'] = []
                keyon = True
            elif keyon:
                # one row of the column-description table
                dat['key'].append(tuple([e.lstrip().rstrip() for e in l[1:].split()]))
            else:
                # any other comment line opens a nested header sub-section
                currkey = l[1:].rstrip().lstrip()
                dat[currkey] = {}
        elif l:  # if l != ''
            # numeric data row: leaving the header ends any sub-section/table
            keyon = False
            currkey = None
            X.append([float(e) for e in l.split()])
    # store the trailing section that is not followed by another header;
    # note that 'data' is an empty list (not an empty array) if no rows came
    if len(X) > 0:
        dat['data'] = np.array(X)
    else:
        dat['data'] = []
    ret.append(dat)
    return tuple(ret)
def load_white_noise_stim(stim_name, stim_duration=10, sampling_rate=20000, folder="stimuli"):
    """Load a relacs white-noise stimulus file and resample it.

    The stimulus file is searched in ``folder``; if that folder does not
    exist, ``~/data/stimuli`` is used instead. The stimulus is linearly
    interpolated onto a regular time axis of ``stim_duration`` seconds
    sampled at ``sampling_rate`` Hz.

    Returns
    -------
    np.array or None
        The resampled time axis (None on failure).
    np.array or None
        The resampled stimulus (None on failure).
    """
    if stim_duration == 0.0:
        print("Stimulus duration must be larger than zero!")
        return None, None
    # fall back to the default data location if the given folder is missing
    search_folder = folder
    if not os.path.exists(search_folder):
        search_folder = os.path.expanduser(os.path.join("~", "data", "stimuli"))
    stim_path = os.path.sep.join([search_folder, stim_name.split('/')[-1]])
    if not os.path.exists(stim_path):
        print("Stimulus file does not exist")
        return None, None
    sections = load_stim(stim_path)
    original_time = sections[0]['data'][:, 0]
    interpolator = interp1d(original_time, sections[0]['data'][:, 1])
    resampled_time = np.linspace(0.0, stim_duration, int(stim_duration * sampling_rate))
    # clamp to the last original sample so interp1d does not extrapolate
    resampled_time[resampled_time > original_time[-1]] = original_time[-1]
    return resampled_time, interpolator(resampled_time)
def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
    """Convert spike times to a firing rate estimated by kernel convolution with a Gaussian kernel.

    Args:
        spikes (np.array): the spike times in seconds
        duration (float): the trial duration in seconds
        sigma (float, optional): standard deviation of the Gaussian kernel. Defaults to 0.005.
        dt (float, optional): desired temporal resolution of the firing rate. Defaults to 1./20000..

    Returns:
        np.array: the firing rate
    """
    binary = np.zeros(int(np.round(duration / dt)))
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # int is the intended dtype here
    indices = np.asarray(np.round(spikes / dt), dtype=int)
    # spikes that round beyond the trial end are dropped
    binary[indices[indices < len(binary)]] = 1
    kernel = gaussKernel(sigma, dt)
    rate = np.convolve(kernel, binary, mode="same")
    return rate
def coherence_rate(coherence, frequency, start_freq, stop_freq):
    """Integrate the coherence-based information density over a frequency band.

    The coherence spectrum is converted to the lower-bound information
    density -log2(1 - coherence) scaled by the frequency resolution, then
    summed over the half-open band [start_freq, stop_freq).
    """
    deltaf = np.mean(np.diff(frequency))
    density = -np.log2(np.ones(coherence.shape) - coherence) * deltaf
    in_band = (frequency >= start_freq) & (frequency < stop_freq)
    return np.sum(density[in_band])
def mutual_info(spike_responses, delay, inversion_needed, stimulus, freq_bin_edges, kernel_sigma=0.00125, delay_type="equal", stepsize=1./20000, trial_duration=10.):
    """Estimate the stimulus-response coherence and band-wise mutual information.

    Each spike response is shifted by a random delay, converted to a firing
    rate, optionally mirrored around its mean, and the coherence between the
    population-average rate and the stimulus is computed.

    Args:
        spike_responses (list of list of spike times): one list/array of spike times per trial
        delay (float): scale of the random delay added to each response, in seconds
        inversion_needed (iterable of bool): per response, whether its rate must be mirrored around its mean
        stimulus (np.array): stimulus waveform sampled at 1/stepsize
        freq_bin_edges (iterable of float): edges of the frequency bands for the information estimate
        kernel_sigma (float, optional): sigma of the firing-rate kernel. Defaults to 0.00125.
        delay_type (str, optional): "equal" draws delays uniformly from [0, 2*delay),
            "gaussian" draws standard-normal delays scaled by delay. Defaults to "equal".
        stepsize (float, optional): temporal resolution in seconds. Defaults to 1./20000.
        trial_duration (float, optional): trial duration in seconds. Defaults to 10..

    Returns:
        np.array: frequency axis of the coherence function
        np.array: the coherence function
        np.array: mutual information values, len(freq_bin_edges); entry 0 covers [0, freq_bin_edges[-1])
        np.array: population firing rates, shape (time, trials)
        float: the average applied delay
    """
    population_rates = None
    rng = np.random.default_rng()
    delays = np.zeros(len(spike_responses))
    for i, (sr, invert) in enumerate(zip(spike_responses, inversion_needed)):
        if isinstance(sr, list):
            sr = np.array(sr)
        dt = 0.0
        if "equal" in delay_type:
            dt = rng.random() * delay * 2
        elif "gaussian" in delay_type:
            # BUGFIX: Generator has no standard_random(); standard_normal()
            # is the standard-normal draw intended here
            dt = rng.standard_normal() * delay
        delays[i] = dt
        # shift a copy: "sr += dt" would mutate the caller's array in place
        sr = sr + dt
        rate = firing_rate(sr, trial_duration, kernel_sigma, dt=stepsize)
        if invert:
            # mirror the rate around its mean
            avg = np.mean(rate)
            rate = 2.0 * avg - rate
        if population_rates is None:
            population_rates = np.zeros((len(rate), len(spike_responses)))
        population_rates[:, i] = rate
    c, f = mlab.cohere(np.mean(population_rates, axis=1), stimulus, NFFT=2**14, noverlap=2**13,
                       Fs=1./stepsize, detrend=mlab.detrend_mean, window=mlab.window_hanning)
    mis = np.zeros(len(freq_bin_edges))
    # entry 0 holds the broadband information up to the last edge
    mis[0] = coherence_rate(c, f, 0, freq_bin_edges[-1])
    for k in range(1, len(freq_bin_edges)):
        mis[k] = coherence_rate(c, f, freq_bin_edges[k-1], freq_bin_edges[k])
    return f, c, mis, population_rates, np.mean(delays)
class LIF(object):
    """Leaky integrate-and-fire neuron with an adaptation current and
    additive noise, integrated with the Euler method."""

    def __init__(self, stepsize=0.0001, offset=1.6, tau_m=0.010, tau_a=0.02, da=0.0, D=3.5):
        self.stepsize = stepsize  # simulation stepsize [s]
        self.offset = offset  # offset curent [nA]
        self.tau_m = tau_m  # membrane time_constant [s]
        self.tau_a = tau_a  # adaptation time_constant [s]
        self.da = da  # increment in adaptation current [nA]
        self.D = D  # noise intensity
        self.v_threshold = 1.0  # spiking threshold
        self.v_reset = 0.0  # reset voltage after spiking
        self.i_a = 0.0  # current adaptation current
        self.v = self.v_reset  # current membrane voltage
        self.t = 0.0  # current time [s]
        self.membrane_voltage = []  # recorded voltage trace of the current run
        self.spike_times = []  # recorded spike times of the current run [s]

    def _reset(self):
        # Restore the initial state before starting a new simulation run.
        self.i_a = 0.0
        self.v = self.v_reset
        self.t = 0.0
        self.membrane_voltage = []
        self.spike_times = []

    def _lif(self, stimulus, noise):
        """
        euler solution of the membrane equation with adaptation current and noise
        """
        # NOTE(review): this update algebraically collapses to
        # i_a = (stepsize/tau_a) * i_a, i.e. the adaptation current is almost
        # zeroed every step instead of decaying exponentially; the usual Euler
        # decay would be "i_a -= stepsize/tau_a * i_a" — confirm intent.
        # (Harmless with the default da=0.0, where i_a stays 0 anyway.)
        self.i_a -= self.i_a - self.stepsize/self.tau_a * (self.i_a)
        self.v += self.stepsize * ( -self.v + stimulus + noise + self.offset - self.i_a)/self.tau_m;
        self.membrane_voltage.append(self.v)

    def _next(self, stimulus):
        """
        working horse which delegates to the euler and gets the spike times
        """
        # NOTE(review): np.random.randn() already yields a standard-normal
        # float; "% 10000" on it maps negative draws near 10000, so after the
        # -5000/10000 transform the noise is bimodal around roughly +/-0.5.
        # This looks like a leftover from porting integer "rand() % 10000"
        # code — verify the intended noise distribution.
        noise = self.D * (float(np.random.randn() % 10000) - 5000.0)/10000
        self._lif(stimulus, noise)
        self.t += self.stepsize
        if self.v > self.v_threshold and len(self.membrane_voltage) > 1:
            # threshold crossed: reset the voltage, overwrite the last trace
            # sample with a spike marker (2.0), record the spike time, and
            # increment the adaptation current
            self.v = self.v_reset
            self.membrane_voltage[len(self.membrane_voltage)-1] = 2.0
            self.spike_times.append(self.t)
            self.i_a += self.da

    def run_const_stim(self, steps, stimulus):
        """
        lif simulation with constant stimulus.

        Returns the time axis, the membrane voltage trace, and the spike times.
        """
        self._reset()
        for i in range(steps):
            self._next(stimulus);
        time = np.arange(len(self.membrane_voltage))*self.stepsize
        return time, np.array(self.membrane_voltage), np.array(self.spike_times)

    def run_stimulus(self, stimulus):
        """
        lif simulation with a predefined stimulus trace.

        Returns the time axis, the membrane voltage trace, and the spike times.
        """
        self._reset()
        for s in stimulus:
            self._next(s);
        time = np.arange(len(self.membrane_voltage)) * self.stepsize
        return time, np.array(self.membrane_voltage), np.array(self.spike_times)

    def __str__(self):
        # Human-readable summary of the model parameters.
        out = '\n'.join(["stepsize: \t" + str(self.stepsize),
                         "offset:\t\t" + str(self.offset),
                         "tau_m:\t\t" + str(self.tau_m),
                         "tau_a:\t\t" + str(self.tau_a),
                         "da:\t\t" + str(self.da),
                         "D:\t\t" + str(self.D),
                         "v_threshold:\t" + str(self.v_threshold),
                         "v_reset:\t" + str(self.v_reset)])
        return out

    def __repr__(self):
        return self.__str__()
-
if __name__ == "__main__":
    # Previously used to generate and cache the white-noise stimulus:
    # stim = "gwn300Hz10s0.3.dat"
    # time, stimulus = load_white_noise_stim(stim, 10, 40000)
    # np.savez("noise_stimulus.npz", time=time, stim=stim)

    # Plot a Gaussian kernel (sigma = 0.125 s at dt = 0.1 ms, i.e. 10000
    # samples over +/- 4 sigma) and save the figure as gaussian.pdf.
    import matplotlib.pyplot as plt
    from response_features import despine
    fig = plt.figure(figsize=(1.5, 1.5))
    ax = fig.add_subplot(111)
    k = gaussKernel(0.125, 0.0001)
    ax.plot(k)
    ax.set_yticklabels([])
    ax.set_xlim([0, 10000])
    ax.set_xticks([0, 2500, 5000, 7500, 10000])
    # NOTE(review): the x axis is labeled in multiples of pi although the
    # curve is a Gaussian kernel in sample units — presumably intentional for
    # a generic schematic figure; confirm.
    ax.set_xticklabels([r"-2$\pi$", r"-$\pi$" , "0", r"$\pi$", r"2$\pi$"])
    despine(ax, ["top", "right"], False)
    fig.subplots_adjust(bottom=0.2)
    fig.savefig("gaussian.pdf")
    plt.close()
|