"""General use classes and functions""" import os import pickle import numpy as np from numpy import pi import pandas as pd from scipy.stats import linregress from scipy.optimize import curve_fit import scipy.ndimage.filters as filters import matplotlib.pyplot as plt from matplotlib import colors from matplotlib.ticker import FuncFormatter from matplotlib.offsetbox import AnchoredText RESPDFNAMES = ['mviresp', 'grtresp', 'sponresp', 'bestmviresp', 'bestgrtresp', 'grttunresp'] FITSDFNAMES = ['fits', 'sampfits'] MIDFNAMES = ['mviFMI', 'grtFMI', 'mviRMI', 'grtRMI', 'maxFMI', 'maxRMI', 'iposmi'] POSRUNDFNAMES = ['ipos_st8', 'ipos_opto', 'upos', 'runspeed'] CELLTYPEDFNAMES = ['celltype', 'cellscreenpos'] FIGDFNAMES = ['fig1', 'fig2', 'fig3', 'fig4', 'fig5', 'fig6', 'fig1S33S1', 'fig1S4mvi', 'fig1S4grt', 'fig1S5mvi', 'fig1S5grt', 'fig5S2'] DFNAMES = RESPDFNAMES + FITSDFNAMES + MIDFNAMES + POSRUNDFNAMES + CELLTYPEDFNAMES + FIGDFNAMES EXPORTDFNAMES = MIDFNAMES + FIGDFNAMES # to .csv STRNAMES = ['mvimseustrs', 'grtmseustrs', 'mvigrtmsustrs'] def load(name, subfolder=''): """Return variable (e.g. a DataFrame) from a pickle""" path = os.path.dirname(__file__) fname = os.path.join(path, 'pickles', subfolder, name+'.pickle') with open(fname, 'rb') as f: val = pd.read_pickle(f) # backward compatible with pickles generated by pandas < 1.0 return val def load_all(subfolder=''): """Load all variables (DFNAMES+STRNAMES) from pickles, return name:val dict""" name2val = {} for name in DFNAMES+STRNAMES: print('Loading', name) try: val = load(name, subfolder=subfolder) name2val[name] = val except FileNotFoundError as err: print(err) return name2val def save(ns, dfnames=None, strnames=None, subfolder=''): """Save DFNAMES, STRNAMES and optionally FIGDFNAMES to disk, given a namespace ns (e.g. locals()). Useful for reducing unnecessary recalculation and/or in case of lack of database connection""" if subfolder == '' and ns['EXPTYPE'] != 'pvmvis': subfolder = ns['EXPTYPE'] # prevent accidentally overwriting the default PV pickles path = os.path.join(os.path.dirname(__file__), 'pickles', subfolder) if dfnames == None: dfnames = DFNAMES assert type(dfnames) in [list, tuple] for dfname in dfnames: fname = os.path.join(path, dfname+'.pickle') print('Saving', fname) if dfname not in ns: print('WARNING: %r not found' % dfname) else: ns[dfname].to_pickle(fname) if strnames == None: strnames = STRNAMES assert type(strnames) in [list, tuple] for strname in strnames: fname = os.path.join(path, strname+'.pickle') print('Saving', fname) with open(fname, 'wb') as f: pickle.dump(ns[strname], f) def export2csv(ns, dfnames=None, subfolder=''): """Given a namespace ns (e.g. 
def export2csv(ns, dfnames=None, subfolder=''):
    """Given a namespace ns (e.g. locals()), export all figure-specific and modulation
    index DataFrames to .csv for Steffen to analyze in R. dfnames: list of strings"""
    path = os.path.join(os.path.dirname(__file__), 'csv', subfolder)
    if dfnames == None:
        dfnames = EXPORTDFNAMES
    for dfname in dfnames:
        fname = os.path.join(path, dfname+'.csv')
        print('Saving', fname)
        if dfname not in ns:
            print('WARNING: %r not found' % dfname)
        else:
            df = ns[dfname]
            columns = list(df.columns)
            for col in columns:
                if np.all(pd.isna(df[col])):
                    print('WARNING: found empty column %r in df %r' % (col, dfname))
            if dfname == 'fig1':
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['rate02s'] = df.explode('rate02s')['rate02s'].values
                newdf['rate35s'] = df.explode('rate35s')['rate35s'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                newdf['blankburstratios'] = df.explode('blankburstratios')[
                                            'blankburstratios'].values
                df = newdf
            elif dfname == 'fig3':
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')  # includes both non-blank and blank trials
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                newdf['blankcondrates'] = df.explode('blankcondrates')['blankcondrates'].values
                newdf['blankburstratios'] = df.explode('blankburstratios')[
                                            'blankburstratios'].values
                newdf['blankcondburstratios'] = df.explode('blankcondburstratios')[
                                                'blankcondburstratios'].values
                df = newdf
            elif dfname in ['fig1S4mvi', 'fig1S4grt', 'fig1S5mvi', 'fig1S5grt']:
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')  # includes both non-blank and blank trials
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                df = newdf
            elif dfname == 'fig4S1':
                # explode so that each trial gets its own row, as in the other branches:
                newdf = df.explode('trialis')
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                df = newdf
            elif dfname == 'fig5':
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                df = newdf
            elif dfname.startswith('ipos_'):
                # exclude export of pupil area trial matrices and timepoints:
                columns.remove('area_trialmat')
                columns.remove('area_trialts')
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['area_trialmean'] = df.explode('area_trialmean')['area_trialmean'].values
                df = newdf
            df.to_csv(fname, columns=columns)

def desat(hexcolor, alpha):
    """Manually desaturate hex RGB color. Plotting directly with alpha results in saturated
    colors appearing to be plotted on top of desaturated colors, even when plotted in a
    lower layer"""
    return mixalpha(hexcolor, alpha)
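
# Worked example (values approximate, for illustration only): desat() mixes the color
# toward the assumed white background via mixalpha() below, rgb_out = alpha*rgb + (1 - alpha)*bg,
# so black at 50% alpha comes out as a mid gray:
#
#     desat('#000000', 0.5)      # -> '#808080'
#     mixalpha('#ff0000', 0.25)  # -> a pale red: 0.25*red mixed with 0.75*white
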
def mixalpha(hexcolor, alpha=1, bg='#ffffff'):
    """Mix alpha into hexcolor, assuming background color.
    See https://stackoverflow.com/a/21576659/2020363"""
    rgb = np.array(colors.hex2color(hexcolor))  # convert to float RGB array
    bg = np.array(colors.hex2color(bg))
    rgb = alpha*rgb + (1 - alpha)*bg  # mix it
    return colors.rgb2hex(rgb)

def axes_disable_scientific(axes, axiss=None):
    """Disable scientific notation for both axes labels, useful for log-log plots.
    See https://stackoverflow.com/a/49306588/3904031"""
    if axiss == None:
        axiss = [axes.xaxis, axes.yaxis]
    for axis in axiss:
        ff = FuncFormatter(lambda y, _: '{:.16g}'.format(y))
        axis.set_major_formatter(ff)

def ms2msstr(msu):
    """Convert ms dictionary to ms string"""
    return "%s_s%02d" % (msu['m'], msu['s'])

def msu2msustr(msu):
    """Convert msu dictionary to msu string"""
    return "%s_s%02d_u%02d" % (msu['m'], msu['s'], msu['u'])

def mseustr2msustr(mseustr):
    """Convert an mseu string to an msu string, i.e. drop the experiment ID"""
    msustr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2, 3, 5]])
    return msustr

def mseustrs2msustrs(mseustrs):
    """Convert a sequence of mseu strings to msu strings, i.e. drop the experiment ID"""
    msustrs = []
    for mseustr in mseustrs:
        msustr = mseustr2msustr(mseustr)
        msustrs.append(msustr)
    return msustrs

def mseustr2mstr(mseustr):
    """Convert an mseu string to a mouse string"""
    mstr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2]])
    return mstr

def mseustrs2mstrs(mseustrs):
    """Convert a sequence of mseu strings to mouse strings"""
    mstrs = []
    for mseustr in mseustrs:
        mstr = mseustr2mstr(mseustr)
        mstrs.append(mstr)
    return mstrs

def mseustr2msestr(mseustr):
    """Convert an mseu string to an mse string"""
    msestr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2, 3, 4]])
    return msestr

def mseustrs2msestrs(mseustrs):
    """Convert a sequence of mseu strings to mse strings"""
    msestrs = []
    for mseustr in mseustrs:
        msestr = mseustr2msestr(mseustr)
        msestrs.append(msestr)
    return msestrs

def findmse(mseustrs, msestr):
    """Return boolean array of all entries in mseustrs that match the experiment
    described by msestr"""
    mse = msestr.split('_')
    assert len(mse) == 5  # strain, year, number, series, experiment
    hits = np.tile(False, len(mseustrs))
    for i, mseustr in enumerate(mseustrs):
        mseu = mseustr.split('_')
        assert len(mseu) == 6
        if mseu[:5] == mse:
            hits[i] = True
    return hits
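
# Naming convention sketch (the identifier below is made up for illustration): an mseu
# string has 6 underscore-separated fields, strain_year_number_series_experiment_unit,
# and the converters above simply drop fields by index:
#
#     mseustr2msustr('PVCre_2018_0001_s03_e02_u08')  # -> 'PVCre_2018_0001_s03_u08'
#     mseustr2mstr('PVCre_2018_0001_s03_e02_u08')    # -> 'PVCre_2018_0001'
#     mseustr2msestr('PVCre_2018_0001_s03_e02_u08')  # -> 'PVCre_2018_0001_s03_e02'
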
def fitmodel(ctrlfit, optofit, ctrltest, optotest, model=None):
    """Fit a model to ctrl and opto fit signals, test with ctrl and opto test signals"""
    if model == 'linear':
        mm, b, rr, p, stderr = linregress(ctrlfit, optofit)
        rsq = rr * rr
        raise RuntimeError('R2 for linear fit needs to be calculated on test data')
    elif model == 'threshlin':
        p, pcov = curve_fit(threshlin, xdata=ctrlfit, ydata=optofit, p0=None)
        mm, b = p
        # for each sample, calc rsq on the test data:
        rsq = rsquared(optotest, threshlin(ctrltest, *p))
    else:
        raise ValueError("Unknown model %r" % model)
    return mm, b, rsq

def threshlin(x, m, b):
    """Return threshold linear model"""
    y = m * x + b
    y[y < 0] = 0
    return y

def rsquared(targets, predictions):
    """Return the r-squared value for the fit"""
    residuals = targets - predictions
    residual_variance = np.sum(residuals**2)
    variance_of_targets = np.sum((targets - np.mean(targets))**2)
    if variance_of_targets == 0:
        rsq = np.nan
    else:
        rsq = 1 - (residual_variance / variance_of_targets)
    return rsq

def residual_rsquared(targets, residuals):
    """Return the r-squared value for the fit, given the target values and residuals.
    Minor variation of djd.model.rsquared()"""
    ssres = np.sum(residuals**2)
    sstot = np.sum((targets - np.mean(targets))**2)
    if sstot == 0:
        rsq = np.nan
    else:
        rsq = 1 - (ssres / sstot)
    return rsq

def linear_loss(params, x, y):
    """Linear loss function, for use with scipy.optimize.least_squares()"""
    assert len(params) == 2
    m, b = params  # unpack
    return m*x + b - y

def get_max_snr(mvirespr, mseustr, kind, st8):
    """Find SNR for mseustr, kind, st8 combination, take max across opto conditions"""
    mvirowis = ((mvirespr['mseu'] == mseustr) &
                (mvirespr['kind'] == kind) &
                (mvirespr['st8'] == st8))
    mvirows = mvirespr[mvirowis]
    assert len(mvirows) == 2
    fbsnr = mvirows[mvirows['opto'] == False]['snr'].iloc[0]  # feedback
    supsnr = mvirows[mvirows['opto'] == True]['snr'].iloc[0]  # suppression
    maxsnr = max(fbsnr, supsnr)  # take the max of the two conditions
    return maxsnr

def intround(n):
    """Round to the nearest integer, return an integer. Works on arrays.
    Saves on parentheses, nothing more"""
    if np.iterable(n):  # it's a sequence, return as an int64 array
        return np.int64(np.round(n))
    else:  # it's a scalar, return as normal Python int
        return int(round(n))

def split_tranges(tranges, width, tres):
    """Split up tranges into lots of smaller (typically overlapping) tranges, with width
    and tres. Usually, tres < width, but this also works for width < tres. Test with:

    print(split_tranges([(0,100)], 1, 10))
    print(split_tranges([(0,100)], 10, 1))
    print(split_tranges([(0,100)], 10, 10))
    print(split_tranges([(0,100)], 10, 8))
    print(split_tranges([(0,100)], 3, 10))
    print(split_tranges([(0,100)], 10, 3))
    print(split_tranges([(0,100)], 3, 8))
    print(split_tranges([(0,100)], 8, 3))
    """
    newtranges = []
    for trange in tranges:
        t0, t1 = trange
        assert width < (t1 - t0)
        # calculate left and right edges of subtranges that fall within trange:
        # This is tricky: find maximum left edge such that the corresponding maximum right
        # edge goes as close as possible to t1 without exceeding it:
        tend = (t1-width+tres) // tres*tres  # there might be a nicer way, but this works
        ledges = np.arange(t0, tend, tres)
        redges = ledges + width
        subtranges = [ (le, re) for le, re in zip(ledges, redges) ]
        newtranges.append(subtranges)
    return np.vstack(newtranges)

def wrap_raster(raster, t0, t1, newdt, offsets=[0, 0]):
    """Extract event times in raster (list or array of arrays) between t0 and t1 (s),
    and wrap into extra rows such that event times never exceed newdt (s)"""
    t1floor = t1 - t1 % newdt
    t0s = np.arange(t0, t1floor, newdt)  # end exclusive
    t1s = t0s + newdt
    tranges = np.column_stack([t0s, t1s])
    wrappedraster = []
    for row in raster:
        dst = []  # init list to collect events for this row
        for trange in tranges:
            # search within trange, but take into account desired offsets:
            si0, si1 = row.searchsorted(trange + offsets)
            # get spike times relative to start of trange:
            dst.append(row[si0:si1] - trange[0])
        wrappedraster.extend(dst)
    # convert from list to object array to enable fancy indexing:
    return np.asarray(wrappedraster)

def cf():
    """Close all figures"""
    plt.close('all')
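
# Worked example for split_tranges() (toy numbers): overlapping 4 s windows stepped at
# 2 s resolution over a single 10 s trange,
#
#     split_tranges([(0, 10)], width=4, tres=2)
#     # -> array([[ 0,  4],
#     #           [ 2,  6],
#     #           [ 4,  8],
#     #           [ 6, 10]])
#
# i.e. left edges at np.arange(t0, tend, tres), right edges a fixed width later, with the
# last window ending as close to t1 as possible without exceeding it.
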
def saveall(path=None, format='png'):
    """Save all open figures to chosen path, pop up dialog box if path is None"""
    if path is None:  # query with dialog box for a path
        from matplotlib import rcParams
        startpath = os.path.expanduser(rcParams['savefig.directory'])  # get default
        path = choose_path(startpath, msg="Choose a folder to save to")
        if not path:  # dialog box was cancelled
            return  # don't do anything
        rcParams['savefig.directory'] = path  # update default
    fs = [ plt.figure(i) for i in plt.get_fignums() ]
    for f in fs:
        fname = f.canvas.get_window_title() + '.' + format
        fname = fname.replace(' ', '_')
        fullfname = os.path.join(path, fname)
        print(fullfname)
        f.savefig(fullfname)

def lastcmd():
    """Return a string containing the last command entered by the user in the IPython
    shell. Useful for generating plot titles"""
    ip = get_ipython()
    return ip._last_input_line

def wintitle(titlestr=None, f=None):
    """Set title of current MPL window, defaults to last command entered"""
    if titlestr is None:
        titlestr = lastcmd()
    if f is None:
        f = plt.gcf()
    f.canvas.set_window_title(titlestr)
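
# Interactive-use sketch (paths are hypothetical): these are IPython/matplotlib
# conveniences, typically called at the end of a plotting session:
#
#     wintitle()              # window title defaults to the last IPython input line
#     saveall('/tmp/figs')    # dump every open figure there as .png
#     cf()                    # then close them all
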
def simpletraster(raster, dt=5, offsets=[0, 0], s=1, clr='k',
                  scatter=False, scattermarker='|', scattersize=10,
                  burstis=None, burstclr='r', axisbg='w', alpha=1,
                  inchespersec=1.5, inchespertrial=1/25,
                  ax=None, figsize=None, title=False, xaxis=True, label=None):
    """Create a simple trial raster plot. Each entry in raster is a list of spike times
    relative to the start of each trial.

    dt : trial duration (s)
    offsets : offsets relative to trial start and end (s)
    s : tick linewidths
    clr : tick color, either a single color or a sequence of colors, one per trial
    scatter : whether to use the ax.scatter() command, which plots much faster and uses
              much less memory, but with potentially vertically overlapping ticks.
              Otherwise, default to the slower ax.eventplot()
    burstis : burst indices, as returned by FiringPattern().burst_ratio()
    """
    ntrials = len(raster)
    spiketrialis, c = [], []
    # get raster tick color of each trial:
    if type(clr) == str:  # all trials have the same color
        clr = list(colors.to_rgba(clr))
        clr[3] = alpha  # apply alpha, so that we can control alpha per tick
        trialclrs = [clr]*ntrials
    else:  # each trial has potentially a different color
        assert type(clr) in [list, np.ndarray]
        assert len(clr) == ntrials
        trialclrs = []
        for trialclr in clr:
            trialclr = list(colors.to_rgba(trialclr))
            trialclr[3] = alpha  # apply alpha, so that we can control alpha per tick
            trialclrs.append(trialclr)
    burstclr = colors.to_rgba(burstclr)  # keep full saturation for burst spikes
    # collect 1-based trial info, one entry per spike:
    for triali, rastertrial in enumerate(raster):
        nspikes = len(rastertrial)
        spiketrialis.append(np.tile(triali+1, nspikes))  # 1-based
        trialclr = trialclrs[triali]
        spikecolors = np.tile(trialclr, (nspikes, 1))
        if burstis is not None:
            bis = burstis[triali]
            if len(bis) > 0:
                spikecolors[bis] = burstclr
        c.append(spikecolors)
    # convert each list of arrays to a single flat array:
    raster = np.hstack(raster)
    spiketrialis = np.hstack(spiketrialis)
    c = np.concatenate(c)
    xmin, xmax = offsets[0], dt + offsets[1]
    totaldt = xmax - xmin  # total raster duration, including offsets
    if ax == None:
        if figsize is None:
            figwidth = min(1 + totaldt*inchespersec, 12)
            figheight = min(1 + ntrials*inchespertrial, 12)
            figsize = figwidth, figheight
        f, ax = plt.subplots(figsize=figsize)
    if scatter:
        # scatter doesn't carefully control vertical spacing, allows vertical overlap of ticks:
        ax.scatter(raster, spiketrialis, marker=scattermarker, c=c, s=scattersize, label=label)
    else:
        # eventplot is slower, but does a better job:
        raster = raster[:, np.newaxis]  # somehow eventplot requires an extra unitary dimension
        if len(raster) == 0:
            print("No spikes for eventplot %r" % title)  # prevent TypeError from eventplot()
        else:
            ax.eventplot(raster, lineoffsets=spiketrialis, colors=c, linewidth=s, label=label)
    ax.set_xlim(xmin, xmax)
    # -1 inverts the y axis, +1 ensures last trial is fully visible:
    ax.set_ylim(ntrials+1, -1)
    ax.set_facecolor(axisbg)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Trial')
    if label:
        ax.legend(loc="best")
    if title:
        #ax.set_title(title)
        wintitle(title)
    if xaxis != True:
        if xaxis == False:
            renderer = f.canvas.get_renderer()
            bbox = ax.xaxis.get_tightbbox(renderer).transformed(f.dpi_scale_trans.inverted())
            xaxis = bbox.height
        figheight = figheight - xaxis
        ax.get_xaxis().set_visible(False)
        ax.spines['bottom'].set_visible(False)
        f.set_figheight(figheight)
    #f.tight_layout(pad=0.3)  # crop figure to contents, doesn't seem to do anything any more
    #f.show()
    return ax

def raster2psth(raster, bins, binw, tres, kernel='gauss'):
    """Convert a spike trial raster to a peri-stimulus time histogram (PSTH). To calculate
    the PSTH of a subset of trials, pass a raster containing only that subset.

    Parameters
    ----------
    raster : spike trial raster as a sequence of arrays of spike times (s), one array
             per trial
    bins : 2D array of start and stop PSTH bin edge times (s), one row per bin. Bins may
           or may not be overlapping. Typically generated using util.split_tranges()
    binw : PSTH bin width (s) that was used to generate bins
    tres : temporal resolution (s) that was used to generate bins, only used if
           kernel=='gauss'
    kernel : smoothing kernel : None or 'gauss'

    Returns
    -------
    psth : peri-stimulus time histogram (Hz), normalized by bin width and number of trials
    """
    # make sure raster has nested iterables, i.e. list of arrays, or array of arrays, etc.,
    # even if there's only one array inside raster representing only one trial:
    if len(raster) > 0:  # not an empty raster
        trial0 = raster[0]
        if type(trial0) not in (np.ndarray, list):
            raise ValueError("Ensure that raster is a sequence of arrays of spike times,\n"
                             "one per trial. If you're passing only a single extracted trial,\n"
                             "make sure to pass it within e.g. a list of length 1")
    # now it's safe to assume that len(raster) represents the number of included trials,
    # and not erroneously the number of spikes in a single unnested array of spike times:
    ntrials = len(raster)
    if ntrials == 0:  # empty raster
        spikes = np.asarray(raster)
    else:
        spikes = np.hstack(raster)  # flatten across trials
    spikes.sort()
    spikeis = spikes.searchsorted(bins)  # where bin edges fall in spikes
    # convert to rate: number of spikes in each bin, normalized by binw:
    psth = (spikeis[:, 1] - spikeis[:, 0]) / binw
    if kernel is None:  # rectangular bins
        pass
    elif kernel == 'gauss':  # apply Gaussian filtering
        sigma = binw / 2  # set sigma to half the bin width (sec)
        sigmansamples = sigma / tres  # sigma as multiple of number of samples (unitless)
        psth = filters.gaussian_filter1d(psth, sigma=sigmansamples)
    else:
        raise ValueError('Unknown kernel %r' % kernel)
    # normalize by number of trials:
    if ntrials != 0:
        psth = psth / ntrials
    return psth
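
# Sketch of a typical call chain (bin parameters are arbitrary examples, `raster` is
# assumed to be a list of spike-time arrays, one per trial): build overlapping PSTH bins
# with split_tranges(), then compute a Gaussian-smoothed PSTH,
#
#     binw, tres = 0.02, 0.001                      # 20 ms bins at 1 ms resolution
#     bins = split_tranges([(0, 5)], binw, tres)    # one (start, stop) row per bin
#     psth = raster2psth(raster, bins, binw, tres)  # Hz, normalized by binw and ntrials
#     t = bins.mean(axis=1)                         # bin centers, e.g. for plotting
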
def raster2freqcomp(raster, dt, f, mean='scalar'):
    """Extract a frequency component from spike raster (one row of spike times per trial).
    Adapted from getHarmsResps.m and UnitGetHarm.m

    Parameters
    ----------
    raster : spike raster as a sequence of arrays of spike times (s), one array per trial
    dt : trial duration (s)
    f : frequency to extract (Hz), f=0 extracts mean firing rate
    mean : 'scalar': compute mean of amplitudes of each trial's vector (mean(abs)),
           i.e. find frequency component f separately for each trial, then take
           average amplitude.
           'vector': compute mean across all trials before calculating amplitude
           (abs(mean)), equivalent to first calculating PSTH from all rows of raster

    Returns
    -------
    r : peak-to-peak amplitude of frequency component f
    theta : angle of frequency component f (rad)

    Examples
    --------
    >>> inphase = np.array([0, 1, 2, 3, 4])  # spike times (s)
    >>> outphase = np.array([0.5, 1.5, 2.5, 3.5, 4.5])  # spike times (s)
    >>> raster2freqcomp([inphase], 5, 1)  # single trial, 'mean' is irrelevant
    (2.0, -4.898587196589412e-16)
    >>> raster2freqcomp([outphase], 5, 1)
    (2.0, 3.1415926535897927)
    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='scalar')
    (2.0, 1.5707963267948961)
    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='vector')
    (1.2246467991473544e-16, 1.5707963267948966)

    Using f=0 returns mean firing rate:

    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='scalar')
    (1.0, 0.0)
    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='vector')
    (1.0, 0.0)
    """
    ntrials = len(raster)
    res, ims = np.zeros(ntrials), np.zeros(ntrials)  # init real and imaginary components
    for triali, spikes in enumerate(raster):  # iterate over trials
        if len(spikes) == 0:
            continue
        spikes = np.asarray(spikes)  # in case raster is a list of lists
        if spikes.max() > dt:
            print('spikes exceeding dt:', spikes[spikes > dt])
        # discard rare spikes in raster that for some reason (screen vsyncs?) fall outside
        # the expected trial duration:
        spikes = spikes[spikes <= dt]
        omega = 2 * np.pi * f  # angular frequency (rad/s)
        res[triali] = (np.cos(omega*spikes)).sum() / dt
        ims[triali] = (np.sin(omega*spikes)).sum() / dt
    Hs = (res + ims*1j)  # array of complex numbers
    if f != 0:  # not just the degenerate mean firing rate case
        Hs = 2 * Hs  # convert to peak-to-peak
    if mean == 'scalar':
        Hamplitudes = np.abs(Hs)  # ntrials long
        r = np.nanmean(Hamplitudes)  # mean of amplitudes
        theta = np.nanmean(np.angle(Hs))  # mean of angles
        #rstd = np.nanstd(Hamplitudes)  # stdev of amplitudes
    elif mean == 'vector':
        Hmean = np.nanmean(Hs)  # single complex number
        r = np.abs(Hmean)  # corresponds to PSTH amplitude
        theta = np.angle(Hmean)  # angle of mean vector
        #rstd = np.nanstd(Hs)  # single scalar, corresponds to PSTH stdev
    else:
        raise ValueError('Unknown `mean` method %r' % mean)
    ## NOTE: another way to calculate theta might be:
    #theta = np.arctan2(np.nanmean(np.imag(Hs)), np.nanmean(np.real(Hs)))
    return r, theta

def sparseness(x):
    """Sparseness measure, from Vinje and Gallant, 2000. This is basically 1 minus the
    ratio of the square of the sums over the sum of the squares of the values in
    signal x"""
    if x.sum() == 0:
        return 0
    n = len(x)
    return (1 - (x.sum()/n)**2 / np.sum((x**2)/n)) / (1 - 1/n)

def reliability(signals, average='mean', ignore_nan=True):
    """Calculate reliability across trials in signals, one row per trial, by finding the
    average Pearson's rho between all pairwise combinations of trial signals

    Returns
    -------
    reliability : float
    rhos : ndarray
    """
    ntrials = len(signals)
    if ntrials < 2:
        # can't calculate reliability with less than 2 trials:
        return np.nan, np.array([])
    if ignore_nan:
        rhos = pairwisecorr_nan(signals)
    else:
        rhos, _ = pairwisecorr(signals)
    if average == 'mean':
        rel = np.nanmean(rhos)
    elif average == 'median':
        rel = np.nanmedian(rhos)
    else:
        raise ValueError('Unknown average %r' % average)
    return rel, rhos
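
# Quick sanity checks (toy vectors, not real data): for sparseness(), a response
# concentrated in a single bin is maximally sparse, a flat response is minimally sparse;
# reliability() returns (average pairwise rho, rhos), so two identical trials correlate
# perfectly:
#
#     sparseness(np.array([0., 0., 0., 10.]))  # -> 1.0
#     sparseness(np.array([5., 5., 5., 5.]))   # -> 0.0
#     reliability(np.array([[1., 2., 3.], [1., 2., 3.]]), ignore_nan=False)  # -> (1.0, array([1.]))
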
def snr_baden2016(signals):
    """Return signal-to-noise ratio, aka Berens quality index (QI), from Baden2016, of a
    set of signals. Take ratio of the temporal variance of the trial averaged signal
    (i.e. PSTH), to the average across trials of the variance in time of each trial.
    Ranges from 0 to 1. Ignores NaNs."""
    assert signals.ndim == 2
    signal = np.nanvar(np.nanmean(signals, axis=0))  # reduce row axis, calc var across time
    noise = np.nanmean(np.nanvar(signals, axis=1))  # reduce time axis, calc mean across trials
    if signal == 0:
        return 0
    else:
        return signal / noise

def get_psth_peaks_gac(ts, t, psth, thresh, sigma=0.02, alpha=1.0, minpoints=5,
                       lowp=16, highp=84, checkthresh=True, verbose=True):
    """Extract PSTH peaks from spike times ts collapsed across trials, by clustering them
    using gradient ascent clustering (GAC, Swindale2014). Then, optionally check each
    peak against its amplitude in the PSTH (and its time stamps t), to ensure it passes
    thresh. Also extract the left and right edges of each peak, based on where each
    peak's mass falls between lowp and highp percentiles. sigma is the clustering
    bandwidth used by GAC, in this case in seconds. Note that very narrow peaks will be
    missed if the resolution of the PSTH isn't high enough (TRES=0.0001 is plenty)"""
    from spyke.gac import gac  # .pyx file
    ts2d = np.float32(ts[:, None])  # convert to 2D (one row per spike), contig float32
    # get cluster IDs and positions corresponding to spikets, cpos is indexed into using
    # cids:
    cids, cpos = gac(ts2d, sigma=sigma, alpha=alpha, minpoints=minpoints, returncpos=True,
                     verbose=verbose)
    ucids = np.unique(cids)  # unique cluster IDs across all spikets
    ucids = ucids[ucids >= 0]  # exclude junk cluster -1
    #npeaks = len(ucids)  # but not all of them will necessarily cross the PSTH threshold
    peakis, lis, ris = [], [], []
    for ucid, pos in zip(ucids, cpos):  # clusters are numbered in order of decreasing size
        spikeis, = np.where(cids == ucid)
        cts = ts[spikeis]  # this cluster's spike times
        # search all spikes for argmax, same as using lowp=0 and highp=100:
        #li, ri = t.searchsorted([cts[0], cts[-1]])
        # search only within the percentiles for argmax:
        lt, rt = np.percentile(cts, [lowp, highp])
        li, ri = t.searchsorted([lt, rt])
        if li == ri:
            # start and end indices are identical, cluster probably falls before first or
            # after last spike time:
            assert li == 0 or li == len(psth)
            continue  # no peak to be found in psth for this cluster
        localpsth = psth[li:ri]
        # indices of all local peaks within percentiles in psth:
        #allpeakiis, = argrelextrema(localpsth, np.greater)
        #if len(allpeakiis) == 0:
        #    continue  # no peaks found for this cluster
        # find peakii closest to pos:
        #peakii = allpeakiis[abs((t[li + allpeakiis] - pos)).argmin()]
        # find biggest peak:
        #peakii = allpeakiis[localpsth[allpeakiis].argmax()]
        peakii = localpsth.argmax()  # find max point
        if peakii == 0 or peakii == len(localpsth)-1:
            continue  # skip "peak" that's really just a start or end point of localpsth
        peaki = li + peakii
        if checkthresh and psth[peaki] < thresh:
            continue  # skip peak that doesn't meet thresh
        if peaki in peakis:
            continue  # this peak has already been detected by a preceding, larger, cluster
        peakis.append(peaki)
        lis.append(li)
        ris.append(ri)
        if verbose:
            print('.', end='')  # indicate a peak has been found
    return np.asarray(peakis), np.asarray(lis), np.asarray(ris)
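
# Toy check for snr_baden2016() (made-up signals): repeating the same trace across trials
# gives QI = 1, while adding independent trial-to-trial noise pushes QI toward 1/ntrials,
#
#     t = np.linspace(0, 2*np.pi, 100)
#     signals = np.tile(np.sin(t), (10, 1))              # 10 identical trials
#     snr_baden2016(signals)                             # -> 1.0
#     snr_baden2016(signals + np.random.randn(10, 100))  # -> well below 1
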
def pairwisecorr(signals, weight=False, invalid='ignore'):
    """Calculate Pearson correlations between all pairs of rows in 2D signals array.
    See np.seterr() for possible values of `invalid`"""
    assert signals.ndim == 2
    assert len(signals) >= 2  # at least two rows, i.e. at least one pair
    N = len(signals)
    # potentially allow 0/0 (nan) rhomat entries by ignoring 'invalid' errors
    # (not 'divide'):
    oldsettings = np.seterr(invalid=invalid)
    rhomat = np.corrcoef(signals)  # full correlation matrix
    np.seterr(**oldsettings)  # restore previous numpy error settings
    uti = np.triu_indices(N, k=1)
    rhos = rhomat[uti]  # pull out the upper triangle
    if weight:
        sums = signals.sum(axis=1)
        # weight each pair by the one with the least signal:
        weights = np.vstack((sums[uti[0]], sums[uti[1]])).min(axis=0)
        weights = weights / weights.sum()  # normalize, ensure float division
        return rhos, weights
    else:
        return rhos, None

def pairwisecorr_nan(signals):
    """Calculate Pearson correlations between all pairs of rows in 2D signals array,
    while skipping NaNs. Relies on Pandas DataFrame method"""
    assert signals.ndim == 2
    assert len(signals) >= 2  # at least two rows, i.e. at least one pair
    N = len(signals)
    rhomat = np.array(pd.DataFrame(signals.T).corr())  # full correlation matrix
    return np.triu(rhomat, k=1)  # non-unique entries zeroed

def vector_OSI(oris, rates):
    """Vector averaging method for calculating orientation selectivity index (OSI).
    See Bonhoeffer1995, Swindale1998 and neuropy.neuron.Tune.pref(). A reasonable use
    case is to take a model tuning curve, calculate its values at a fine ori resolution
    (say 1 deg), and use those as the input rates here.

    Parameters
    ----------
    oris : orientations (degrees, potentially ranging 0 to 360)
           Attention: for orientation data, ori should always range 0 to 180!
           Only for direction data (e.g. from drifting gratings) can it go 0 to 360;
           it will then return the orientation selectivity of the data. For the
           direction selectivity, ori has to again range 0 to 180!
    rates : corresponding firing rates

    Returns
    -------
    r : length of net vector average as fraction of total firing
    """
    orisrad = 2 * oris * np.pi/180  # double the angle, convert from deg to rad
    x = (rates*np.cos(orisrad)).sum()
    y = (rates*np.sin(orisrad)).sum()
    n = rates.sum()
    r = np.sqrt(x**2+y**2) / n  # fraction of total firing
    return r

def percentile_ci(data, alpha=0.05, func=np.percentile, **kwargs):
    """Simple percentile method for confidence intervals. No assumptions about shape
    of distribution"""
    data = np.array(data)  # accept lists & tuples
    lower, med, upper = func(data, [100*alpha, 50, 100*(1-alpha)], **kwargs)
    return med, lower, upper

def sum_of_gaussians(x, dp, rp, rn, r0, sigma):
    """ORITUNE sum of two Gaussians living on a circle, for orientation tuning.

    x : orientations, between 0 and 360
    dp : preferred direction (between 0 and 360)
    rp : response to the preferred direction
    rn : response to the opposite direction
    r0 : background response (useful only in some cases)
    sigma : tuning width
    """
    angles_p = 180 / pi * np.angle(np.exp(1j*(x-dp) * pi / 180))
    angles_n = 180 / pi * np.angle(np.exp(1j*(x-dp+180) * pi / 180))
    y = (r0 + rp*np.exp(-angles_p**2 / (2*sigma**2))
            + rn*np.exp(-angles_n**2 / (2*sigma**2)))
    return y
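
# Fitting sketch (synthetic data, parameter values are arbitrary): the module already
# imports curve_fit, so a direction tuning curve can be fit with sum_of_gaussians() and
# then summarized with vector_OSI() at fine orientation resolution,
#
#     oris = np.arange(0, 360, 30)                          # stimulus directions (deg)
#     rates = sum_of_gaussians(oris, dp=90, rp=20, rn=10, r0=2, sigma=25)  # noiseless toy data
#     p0 = [90, 20, 10, 2, 25]                              # initial parameter guess
#     popt, pcov = curve_fit(sum_of_gaussians, oris, rates, p0=p0)
#     fineoris = np.arange(360)                             # 1 deg resolution
#     osi = vector_OSI(fineoris, sum_of_gaussians(fineoris, *popt))
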