- """General use classes and functions"""
- import os
- import pickle
- import numpy as np
- from numpy import pi
- import pandas as pd
- from scipy.stats import linregress
- from scipy.optimize import curve_fit
- import scipy.ndimage.filters as filters
- import matplotlib.pyplot as plt
- from matplotlib import colors
- from matplotlib.ticker import FuncFormatter
- from matplotlib.offsetbox import AnchoredText
- RESPDFNAMES = ['mviresp', 'grtresp', 'sponresp', 'bestmviresp', 'bestgrtresp', 'grttunresp']
- FITSDFNAMES = ['fits', 'sampfits']
- MIDFNAMES = ['mviFMI', 'grtFMI', 'mviRMI', 'grtRMI', 'maxFMI', 'maxRMI', 'iposmi']
- POSRUNDFNAMES = ['ipos_st8', 'ipos_opto', 'upos', 'runspeed']
- CELLTYPEDFNAMES = ['celltype', 'cellscreenpos']
- FIGDFNAMES = ['fig1', 'fig2', 'fig3', 'fig4', 'fig5', 'fig6',
- 'fig1S33S1', 'fig1S4mvi', 'fig1S4grt', 'fig1S5mvi', 'fig1S5grt', 'fig5S2']
- DFNAMES = RESPDFNAMES + FITSDFNAMES + MIDFNAMES + POSRUNDFNAMES + CELLTYPEDFNAMES + FIGDFNAMES
- EXPORTDFNAMES = MIDFNAMES + FIGDFNAMES # to .csv
- STRNAMES = ['mvimseustrs', 'grtmseustrs', 'mvigrtmsustrs']
def load(name, subfolder=''):
    """Return variable (e.g. a DataFrame) from a pickle"""
    path = os.path.dirname(__file__)
    fname = os.path.join(path, 'pickles', subfolder, name+'.pickle')
    with open(fname, 'rb') as f:
        val = pd.read_pickle(f)  # backward compatible with pickles generated by pandas < 1.0
    return val

def load_all(subfolder=''):
    """Load all variables (DFNAMES+STRNAMES) from pickles, return name:val dict"""
    name2val = {}
    for name in DFNAMES+STRNAMES:
        print('Loading', name)
        try:
            val = load(name, subfolder=subfolder)
            name2val[name] = val
        except FileNotFoundError as err:
            print(err)
    return name2val

def save(ns, dfnames=None, strnames=None, subfolder=''):
    """Save DFNAMES, STRNAMES and optionally FIGDFNAMES to disk, given a namespace ns
    (e.g. locals()). Useful for avoiding unnecessary recalculation, and for working
    without a database connection"""
    if subfolder == '' and ns['EXPTYPE'] != 'pvmvis':
        subfolder = ns['EXPTYPE']  # prevent accidentally overwriting the default PV pickles
    path = os.path.join(os.path.dirname(__file__), 'pickles', subfolder)
    if dfnames is None:
        dfnames = DFNAMES
    assert type(dfnames) in [list, tuple]
    for dfname in dfnames:
        fname = os.path.join(path, dfname+'.pickle')
        print('Saving', fname)
        if dfname not in ns:
            print('WARNING: %r not found' % dfname)
        else:
            ns[dfname].to_pickle(fname)
    if strnames is None:
        strnames = STRNAMES
    assert type(strnames) in [list, tuple]
    for strname in strnames:
        fname = os.path.join(path, strname+'.pickle')
        print('Saving', fname)
        with open(fname, 'wb') as f:
            pickle.dump(ns[strname], f)

def export2csv(ns, dfnames=None, subfolder=''):
    """Given a namespace ns (e.g. locals()), export all figure-specific and modulation
    index DataFrames to .csv for Steffen to analyze in R. dfnames: list of strings"""
    path = os.path.join(os.path.dirname(__file__), 'csv', subfolder)
    if dfnames is None:
        dfnames = EXPORTDFNAMES
    for dfname in dfnames:
        fname = os.path.join(path, dfname+'.csv')
        print('Saving', fname)
        if dfname not in ns:
            print('WARNING: %r not found' % dfname)
        else:
            df = ns[dfname]
            columns = list(df.columns)
            for col in columns:
                if np.all(pd.isna(df[col])):
                    print('WARNING: found empty column %r in df %r' % (col, dfname))
            if dfname == 'fig1':
                # explode various columns so that each trial gets its own row;
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['rate02s'] = df.explode('rate02s')['rate02s'].values
                newdf['rate35s'] = df.explode('rate35s')['rate35s'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                newdf['blankburstratios'] = df.explode('blankburstratios')[
                    'blankburstratios'].values
                df = newdf
            elif dfname == 'fig3':
                # explode various columns so that each trial gets its own row;
                # can only explode one column at a time:
                newdf = df.explode('trialis')  # includes both non-blank and blank trials
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                newdf['blankcondrates'] = df.explode('blankcondrates')['blankcondrates'].values
                newdf['blankburstratios'] = df.explode('blankburstratios')[
                    'blankburstratios'].values
                newdf['blankcondburstratios'] = df.explode('blankcondburstratios')[
                    'blankcondburstratios'].values
                df = newdf
            elif dfname in ['fig1S4mvi', 'fig1S4grt', 'fig1S5mvi', 'fig1S5grt']:
                # explode various columns so that each trial gets its own row;
                # can only explode one column at a time:
                newdf = df.explode('trialis')  # includes both non-blank and blank trials
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                df = newdf
            elif dfname == 'fig4S1':
                # NOTE: the newdf init was missing in the original code; assume this
                # branch follows the same explode pattern as the other figures:
                newdf = df.explode('trialis')
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                df = newdf
            elif dfname == 'fig5':
                # explode various columns so that each trial gets its own row;
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                df = newdf
            elif dfname.startswith('ipos_'):
                # exclude export of pupil area trial matrices and timepoints:
                columns.remove('area_trialmat')
                columns.remove('area_trialts')
                # explode various columns so that each trial gets its own row;
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['area_trialmean'] = df.explode('area_trialmean')['area_trialmean'].values
                df = newdf
            df.to_csv(fname, columns=columns)

def desat(hexcolor, alpha):
    """Manually desaturate hex RGB color. Plotting directly with alpha results in saturated
    colors appearing to be plotted over top of desaturated colors, even when plotted in a
    lower layer"""
    return mixalpha(hexcolor, alpha)

def mixalpha(hexcolor, alpha=1, bg='#ffffff'):
    """Mix alpha into hexcolor, assuming background color bg.
    See https://stackoverflow.com/a/21576659/2020363"""
    rgb = np.array(colors.hex2color(hexcolor))  # convert to float RGB array
    bg = np.array(colors.hex2color(bg))
    rgb = alpha*rgb + (1 - alpha)*bg  # mix it
    return colors.rgb2hex(rgb)

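# Example (illustrative, not from the original module): mixing 50% alpha of pure
# red into the default white background gives a half-desaturated pink; desat() is
# an alias for the same operation:
# >>> mixalpha('#ff0000', alpha=0.5)
# '#ff8080'
# >>> desat('#ff0000', 0.5)
# '#ff8080'
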
def axes_disable_scientific(axes, axiss=None):
    """Disable scientific notation for both axis labels, useful for log-log plots.
    See https://stackoverflow.com/a/49306588/3904031"""
    if axiss is None:
        axiss = [axes.xaxis, axes.yaxis]
    for axis in axiss:
        ff = FuncFormatter(lambda y, _: '{:.16g}'.format(y))
        axis.set_major_formatter(ff)

def ms2msstr(ms):
    """Convert ms dictionary to ms string"""
    return "%s_s%02d" % (ms['m'], ms['s'])

def msu2msustr(msu):
    """Convert msu dictionary to msu string"""
    return "%s_s%02d_u%02d" % (msu['m'], msu['s'], msu['u'])

def mseustr2msustr(mseustr):
    """Convert an mseu string to an msu string, i.e. drop the experiment ID"""
    msustr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2, 3, 5]])
    return msustr

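# Example (illustrative): mseu strings are assumed to have six '_'-separated fields
# (strain, year, number, series, experiment, unit); the ID below is made up:
# >>> mseustr2msustr('PVCre_2017_0006_s03_e12_u17')
# 'PVCre_2017_0006_s03_u17'
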
def mseustrs2msustrs(mseustrs):
    """Convert a sequence of mseu strings to msu strings, i.e. drop the experiment ID"""
    msustrs = []
    for mseustr in mseustrs:
        msustr = mseustr2msustr(mseustr)
        msustrs.append(msustr)
    return msustrs

def mseustr2mstr(mseustr):
    """Convert an mseu string to a mouse string"""
    mstr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2]])
    return mstr

def mseustrs2mstrs(mseustrs):
    """Convert a sequence of mseu strings to mouse strings"""
    mstrs = []
    for mseustr in mseustrs:
        mstr = mseustr2mstr(mseustr)
        mstrs.append(mstr)
    return mstrs

def mseustr2msestr(mseustr):
    """Convert an mseu string to an mse string"""
    msestr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2, 3, 4]])
    return msestr

def mseustrs2msestrs(mseustrs):
    """Convert a sequence of mseu strings to mse strings"""
    msestrs = []
    for mseustr in mseustrs:
        msestr = mseustr2msestr(mseustr)
        msestrs.append(msestr)
    return msestrs

def findmse(mseustrs, msestr):
    """Return boolean array of all entries in mseustrs that match the experiment
    described by msestr"""
    mse = msestr.split('_')
    assert len(mse) == 5  # strain, year, number, series, experiment
    hits = np.tile(False, len(mseustrs))
    for i, mseustr in enumerate(mseustrs):
        mseu = mseustr.split('_')
        assert len(mseu) == 6
        if mseu[:5] == mse:
            hits[i] = True
    return hits

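# Example (illustrative, with made-up IDs): select only the entries recorded in
# experiment e12 of series s03:
# >>> mseustrs = ['PVCre_2017_0006_s03_e12_u17', 'PVCre_2017_0006_s03_e13_u17']
# >>> findmse(mseustrs, 'PVCre_2017_0006_s03_e12')
# array([ True, False])
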
def fitmodel(ctrlfit, optofit, ctrltest, optotest, model=None):
    """Fit a model to ctrl and opto fit signals, test with ctrl and opto test signals"""
    if model == 'linear':
        mm, b, rr, p, stderr = linregress(ctrlfit, optofit)
        rsq = rr * rr  # NOTE: this is the R2 on the fit data, hence the error below
        raise RuntimeError('R2 for linear fit needs to be calculated on test data')
    elif model == 'threshlin':
        p, pcov = curve_fit(threshlin, xdata=ctrlfit, ydata=optofit, p0=None)
        mm, b = p
        # for each sample, calc rsq on the test data:
        rsq = rsquared(optotest, threshlin(ctrltest, *p))
    else:
        raise ValueError("Unknown model %r" % model)
    return mm, b, rsq

def threshlin(x, m, b):
    """Return threshold linear model"""
    y = m * x + b
    y[y < 0] = 0
    return y

def rsquared(targets, predictions):
    """Return the r-squared value for the fit"""
    residuals = targets - predictions
    residual_variance = np.sum(residuals**2)
    variance_of_targets = np.sum((targets - np.mean(targets))**2)
    if variance_of_targets == 0:
        rsq = np.nan
    else:
        rsq = 1 - (residual_variance / variance_of_targets)
    return rsq

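# Example (illustrative sketch of the 'threshlin' path in fitmodel()): fit
# threshlin() to synthetic, noise-free data with curve_fit(), then score the fit
# with rsquared(); this should recover m=2, b=-2 and an rsq of ~1.0:
# >>> x = np.array([0., 1., 2., 3., 4.])
# >>> y = threshlin(x, m=2, b=-2)  # array([0., 0., 2., 4., 6.])
# >>> p, _ = curve_fit(threshlin, xdata=x, ydata=y)
# >>> rsquared(y, threshlin(x, *p))  # ~1.0
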
def residual_rsquared(targets, residuals):
    """Return the r-squared value for the fit, given the target values and residuals.
    Minor variation of djd.model.rsquared()"""
    ssres = np.sum(residuals**2)
    sstot = np.sum((targets - np.mean(targets))**2)
    if sstot == 0:
        rsq = np.nan
    else:
        rsq = 1 - (ssres / sstot)
    return rsq

def linear_loss(params, x, y):
    """Linear loss function, for use with scipy.optimize.least_squares()"""
    assert len(params) == 2
    m, b = params  # unpack
    return m*x + b - y

def get_max_snr(mvirespr, mseustr, kind, st8):
    """Find SNR for mseustr, kind, st8 combination, take max across opto conditions"""
    mvirowis = ((mvirespr['mseu'] == mseustr) &
                (mvirespr['kind'] == kind) &
                (mvirespr['st8'] == st8))
    mvirows = mvirespr[mvirowis]
    assert len(mvirows) == 2
    fbsnr = mvirows[mvirows['opto'] == False]['snr'].iloc[0]  # feedback
    supsnr = mvirows[mvirows['opto'] == True]['snr'].iloc[0]  # suppression
    maxsnr = max(fbsnr, supsnr)  # take the max of the two conditions
    return maxsnr

def intround(n):
    """Round to the nearest integer, return an integer. Works on arrays.
    Saves on parentheses, nothing more"""
    if np.iterable(n):  # it's a sequence, return as an int64 array
        return np.int64(np.round(n))
    else:  # it's a scalar, return as a normal Python int
        return int(round(n))

def split_tranges(tranges, width, tres):
    """Split up tranges into lots of smaller (typically overlapping) tranges, with width
    and tres. Usually, tres < width, but this also works for width < tres.
    Test with:

    print(split_tranges([(0,100)], 1, 10))
    print(split_tranges([(0,100)], 10, 1))
    print(split_tranges([(0,100)], 10, 10))
    print(split_tranges([(0,100)], 10, 8))
    print(split_tranges([(0,100)], 3, 10))
    print(split_tranges([(0,100)], 10, 3))
    print(split_tranges([(0,100)], 3, 8))
    print(split_tranges([(0,100)], 8, 3))
    """
    newtranges = []
    for trange in tranges:
        t0, t1 = trange
        assert width < (t1 - t0)
        # calculate left and right edges of subtranges that fall within trange.
        # This is tricky: find the maximum left edge such that the corresponding maximum
        # right edge goes as close as possible to t1 without exceeding it:
        tend = (t1-width+tres) // tres*tres  # there might be a nicer way, but this works
        ledges = np.arange(t0, tend, tres)
        redges = ledges + width
        subtranges = [ (le, re) for le, re in zip(ledges, redges) ]
        newtranges.append(subtranges)
    return np.vstack(newtranges)

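# Example (illustrative): splitting one 1 s trange into 0.5 s wide windows sliding
# by 0.25 s produces overlapping subtranges:
# >>> split_tranges([(0, 1)], width=0.5, tres=0.25)
# array([[0.  , 0.5 ],
#        [0.25, 0.75],
#        [0.5 , 1.  ]])
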
def wrap_raster(raster, t0, t1, newdt, offsets=[0, 0]):
    """Extract event times in raster (list or array of arrays) between t0 and t1 (s),
    and wrap into extra rows such that event times never exceed newdt (s)"""
    t1floor = t1 - t1 % newdt
    t0s = np.arange(t0, t1floor, newdt)  # end exclusive
    t1s = t0s + newdt
    tranges = np.column_stack([t0s, t1s])
    wrappedraster = []
    for row in raster:
        dst = []  # init list to collect events for this row
        for trange in tranges:
            # search within trange, but take into account desired offsets:
            si0, si1 = row.searchsorted(trange + offsets)
            # get spike times relative to start of trange:
            dst.append(row[si0:si1] - trange[0])
        # convert from list to object array to enable fancy indexing:
        wrappedraster.extend(dst)
    return np.asarray(wrappedraster)

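# Example (illustrative): wrap a single 4 s trial into 2 s rows; spike times in
# each output row are relative to that row's start:
# >>> row = np.array([0.5, 1.5, 2.5, 3.5])
# >>> wrap_raster([row], t0=0, t1=4, newdt=2)
# array([[0.5, 1.5],
#        [0.5, 1.5]])
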
def cf():
    """Close all figures"""
    plt.close('all')

def saveall(path=None, format='png'):
    """Save all open figures to chosen path, pop up a dialog box if path is None"""
    if path is None:  # query with dialog box for a path
        from matplotlib import rcParams
        startpath = os.path.expanduser(rcParams['savefig.directory'])  # get default
        # NOTE: choose_path() is assumed to be provided elsewhere in the package:
        path = choose_path(startpath, msg="Choose a folder to save to")
        if not path:  # dialog box was cancelled
            return  # don't do anything
        rcParams['savefig.directory'] = path  # update default
    fs = [ plt.figure(i) for i in plt.get_fignums() ]
    for f in fs:
        # canvas.get_window_title() is deprecated in matplotlib; go via the manager:
        fname = f.canvas.manager.get_window_title() + '.' + format
        fname = fname.replace(' ', '_')
        fullfname = os.path.join(path, fname)
        print(fullfname)
        f.savefig(fullfname)

def lastcmd():
    """Return a string containing the last command entered by the user in the
    IPython shell. Useful for generating plot titles"""
    from IPython import get_ipython
    ip = get_ipython()
    return ip._last_input_line

def wintitle(titlestr=None, f=None):
    """Set title of current MPL window, defaults to last command entered"""
    if titlestr is None:
        titlestr = lastcmd()
    if f is None:
        f = plt.gcf()
    # canvas.set_window_title() is deprecated in matplotlib; go via the manager:
    f.canvas.manager.set_window_title(titlestr)

def simpletraster(raster, dt=5, offsets=[0, 0], s=1, clr='k',
                  scatter=False, scattermarker='|', scattersize=10,
                  burstis=None, burstclr='r',
                  axisbg='w', alpha=1, inchespersec=1.5, inchespertrial=1/25,
                  ax=None, figsize=None, title=False, xaxis=True, label=None):
    """Create a simple trial raster plot. Each entry in raster is a list of spike times
    relative to the start of each trial.

    dt : trial duration (s)
    offsets : offsets relative to trial start and end (s)
    s : tick linewidths
    clr : tick color, either a single color or a sequence of colors, one per trial
    scatter : whether to use ax.scatter(), which plots much faster and uses much less
              memory, but allows ticks to overlap vertically; defaults to the slower
              ax.eventplot(), which spaces ticks properly
    burstis : burst indices, as returned by FiringPattern().burst_ratio()
    """
    ntrials = len(raster)
    spiketrialis, c = [], []
    # get raster tick color of each trial:
    if type(clr) == str:  # all trials have the same color
        clr = list(colors.to_rgba(clr))
        clr[3] = alpha  # apply alpha, so that we can control alpha per tick
        trialclrs = [clr]*ntrials
    else:  # each trial potentially has a different color
        assert type(clr) in [list, np.ndarray]
        assert len(clr) == ntrials
        trialclrs = []
        for trialclr in clr:
            trialclr = list(colors.to_rgba(trialclr))
            trialclr[3] = alpha  # apply alpha, so that we can control alpha per tick
            trialclrs.append(trialclr)
    burstclr = colors.to_rgba(burstclr)  # keep full saturation for burst spikes
    # collect 1-based trial info, one entry per spike:
    for triali, rastertrial in enumerate(raster):
        nspikes = len(rastertrial)
        spiketrialis.append(np.tile(triali+1, nspikes))  # 1-based
        trialclr = trialclrs[triali]
        spikecolors = np.tile(trialclr, (nspikes, 1))
        if burstis is not None:
            bis = burstis[triali]
            if len(bis) > 0:
                spikecolors[bis] = burstclr
        c.append(spikecolors)
    # convert each list of arrays to a single flat array:
    raster = np.hstack(raster)
    spiketrialis = np.hstack(spiketrialis)
    c = np.concatenate(c)
    xmin, xmax = offsets[0], dt + offsets[1]
    totaldt = xmax - xmin  # total raster duration, including offsets
    if ax is None:
        if figsize is None:
            figwidth = min(1 + totaldt*inchespersec, 12)
            figheight = min(1 + ntrials*inchespertrial, 12)
            figsize = figwidth, figheight
        f, ax = plt.subplots(figsize=figsize)
    if scatter:
        # scatter doesn't carefully control vertical spacing, allows vertical overlap of
        # ticks:
        ax.scatter(raster, spiketrialis, marker=scattermarker, c=c, s=scattersize, label=label)
    else:
        # eventplot is slower, but does a better job:
        raster = raster[:, np.newaxis]  # eventplot requires an extra unitary dimension
        if len(raster) == 0:
            print("No spikes for eventplot %r" % title)  # prevent TypeError from eventplot()
        else:
            ax.eventplot(raster, lineoffsets=spiketrialis, colors=c, linewidth=s, label=label)
    ax.set_xlim(xmin, xmax)
    # -1 inverts the y axis, +1 ensures last trial is fully visible:
    ax.set_ylim(ntrials+1, -1)
    ax.set_facecolor(axisbg)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Trial')
    if label:
        ax.legend(loc="best")
    if title:
        #ax.set_title(title)
        wintitle(title)
    if xaxis is not True:
        # xaxis=False: hide the x axis and shrink the figure by its measured height;
        # xaxis=<float>: hide the x axis and shrink the figure by that height (inches).
        # Get the figure and its height via ax, since the local f and figheight above
        # are only defined if the figure was created here (fix for undefined names in
        # the original code):
        f = ax.figure
        figheight = f.get_figheight()
        if xaxis is False:
            renderer = f.canvas.get_renderer()
            bbox = ax.xaxis.get_tightbbox(renderer).transformed(f.dpi_scale_trans.inverted())
            xaxis = bbox.height
        figheight = figheight - xaxis
        ax.get_xaxis().set_visible(False)
        ax.spines['bottom'].set_visible(False)
        f.set_figheight(figheight)
    #f.tight_layout(pad=0.3)  # crop figure to contents, doesn't seem to do anything any more
    #f.show()
    return ax

def raster2psth(raster, bins, binw, tres, kernel='gauss'):
    """Convert a spike trial raster to a peri-stimulus time histogram (PSTH).
    To calculate the PSTH of a subset of trials, pass a raster containing only that subset.

    Parameters
    ----------
    raster : spike trial raster as a sequence of arrays of spike times (s), one array per
             trial
    bins : 2D array of start and stop PSTH bin edge times (s), one row per bin.
           Bins may or may not be overlapping. Typically generated using
           util.split_tranges()
    binw : PSTH bin width (s) that was used to generate bins
    tres : temporal resolution (s) that was used to generate bins, only used if
           kernel=='gauss'
    kernel : smoothing kernel : None or 'gauss'

    Returns
    -------
    psth : peri-stimulus time histogram (Hz), normalized by bin width and number of trials
    """
    # make sure raster has nested iterables, i.e. list of arrays, or array of arrays, etc.,
    # even if there's only one array inside raster representing only one trial:
    if len(raster) > 0:  # not an empty raster
        trial0 = raster[0]
        if type(trial0) not in (np.ndarray, list):
            raise ValueError("Ensure that raster is a sequence of arrays of spike times,\n"
                             "one per trial. If you're passing only a single extracted\n"
                             "trial, make sure to pass it within e.g. a list of length 1")
    # now it's safe to assume that len(raster) represents the number of included trials,
    # and not erroneously the number of spikes in a single unnested array of spike times:
    ntrials = len(raster)
    if ntrials == 0:  # empty raster
        spikes = np.asarray(raster)
    else:
        spikes = np.hstack(raster)  # flatten across trials
    spikes.sort()
    spikeis = spikes.searchsorted(bins)  # where bin edges fall in spikes
    # convert to rate: number of spikes in each bin, normalized by binw:
    psth = (spikeis[:, 1] - spikeis[:, 0]) / binw
    if kernel is None:  # rectangular bins
        pass
    elif kernel == 'gauss':  # apply Gaussian filtering
        sigma = binw / 2  # set sigma to half the bin width (s)
        sigmansamples = sigma / tres  # sigma as multiple of number of samples (unitless)
        psth = filters.gaussian_filter1d(psth, sigma=sigmansamples)
    else:
        raise ValueError('Unknown kernel %r' % kernel)
    # normalize by number of trials:
    if ntrials != 0:
        psth = psth / ntrials
    return psth

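# Example (illustrative): build overlapping bins with split_tranges(), then compute
# a Gaussian-smoothed PSTH from a hypothetical two-trial raster; the result is in Hz,
# normalized by the 0.1 s bin width and the 2 trials:
# >>> bins = split_tranges([(0, 5)], width=0.1, tres=0.02)
# >>> raster = [np.array([0.5, 1.2]), np.array([0.6, 1.3])]
# >>> psth = raster2psth(raster, bins, binw=0.1, tres=0.02)
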
def raster2freqcomp(raster, dt, f, mean='scalar'):
    """Extract a frequency component from a spike raster (one row of spike times per
    trial). Adapted from getHarmsResps.m and UnitGetHarm.m

    Parameters
    ----------
    raster : spike raster as a sequence of arrays of spike times (s), one array per trial
    dt : trial duration (s)
    f : frequency to extract (Hz), f=0 extracts mean firing rate
    mean : 'scalar': compute mean of amplitudes of each trial's vector (mean(abs)), i.e.
           find frequency component f separately for each trial, then take the average
           amplitude.
           'vector': compute mean across all trials before calculating amplitude
           (abs(mean)), equivalent to first calculating the PSTH from all rows of raster

    Returns
    -------
    r : peak-to-peak amplitude of frequency component f
    theta : angle of frequency component f (rad)

    Examples
    --------
    >>> inphase = np.array([0, 1, 2, 3, 4])  # spike times (s)
    >>> outphase = np.array([0.5, 1.5, 2.5, 3.5, 4.5])  # spike times (s)
    >>> raster2freqcomp([inphase], 5, 1)  # single trial, `mean` is irrelevant
    (2.0, -4.898587196589412e-16)
    >>> raster2freqcomp([outphase], 5, 1)
    (2.0, 3.1415926535897927)
    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='scalar')
    (2.0, 1.5707963267948961)
    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='vector')
    (1.2246467991473544e-16, 1.5707963267948966)

    Using f=0 returns the mean firing rate:

    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='scalar')
    (1.0, 0.0)
    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='vector')
    (1.0, 0.0)
    """
    ntrials = len(raster)
    res, ims = np.zeros(ntrials), np.zeros(ntrials)  # init real and imaginary components
    for triali, spikes in enumerate(raster):  # iterate over trials
        if len(spikes) == 0:
            continue
        spikes = np.asarray(spikes)  # in case raster is a list of lists
        if spikes.max() > dt:
            print('spikes exceeding dt:', spikes[spikes > dt])
            # discard rare spikes in raster that for some reason (screen vsyncs?) fall
            # outside the expected trial duration:
            spikes = spikes[spikes <= dt]
        omega = 2 * np.pi * f  # angular frequency (rad/s)
        res[triali] = (np.cos(omega*spikes)).sum() / dt
        ims[triali] = (np.sin(omega*spikes)).sum() / dt
    Hs = (res + ims*1j)  # array of complex numbers
    if f != 0:  # not just the degenerate mean firing rate case
        Hs = 2 * Hs  # convert to peak-to-peak
    if mean == 'scalar':
        Hamplitudes = np.abs(Hs)  # ntrials long
        r = np.nanmean(Hamplitudes)  # mean of amplitudes
        theta = np.nanmean(np.angle(Hs))  # mean of angles
        #rstd = np.nanstd(Hamplitudes)  # stdev of amplitudes
    elif mean == 'vector':
        Hmean = np.nanmean(Hs)  # single complex number
        r = np.abs(Hmean)  # corresponds to PSTH amplitude
        theta = np.angle(Hmean)  # angle of mean vector
        #rstd = np.nanstd(Hs)  # single scalar, corresponds to PSTH stdev
    else:
        raise ValueError('Unknown `mean` method %r' % mean)
    ## NOTE: another way to calculate theta might be:
    #theta = np.arctan2(np.nanmean(np.imag(Hs)), np.nanmean(np.real(Hs)))
    return r, theta

def sparseness(x):
    """Sparseness measure, from Vinje and Gallant, 2000. This is essentially 1 minus the
    ratio of the square of the mean over the mean of the squares of the values in signal x,
    normalized by (1 - 1/n) to range from 0 to 1"""
    if x.sum() == 0:
        return 0
    n = len(x)
    return (1 - (x.sum()/n)**2 / np.sum((x**2)/n)) / (1 - 1/n)

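# Example (illustrative): a one-hot signal is maximally sparse, a uniform signal
# minimally so:
# >>> sparseness(np.array([0., 0., 1., 0.]))
# 1.0
# >>> sparseness(np.ones(4))
# 0.0
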
def reliability(signals, average='mean', ignore_nan=True):
    """Calculate reliability across trials in signals, one row per trial, by finding the
    average Pearson's rho between all pairwise combinations of trial signals

    Returns
    -------
    reliability : float
    rhos : ndarray
    """
    ntrials = len(signals)
    if ntrials < 2:
        # can't calculate reliability with less than 2 trials; return the same
        # arity as the normal path (the original returned a bare np.nan here):
        return np.nan, None
    if ignore_nan:
        rhos = pairwisecorr_nan(signals)
    else:
        rhos, _ = pairwisecorr(signals)
    if average == 'mean':
        rel = np.nanmean(rhos)
    elif average == 'median':
        rel = np.nanmedian(rhos)
    else:
        raise ValueError('Unknown average %r' % average)
    return rel, rhos

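# Example (illustrative): two identical trials are perfectly reliable, so the mean
# pairwise rho is ~1.0 (a single pair here):
# >>> signals = np.array([[0., 1., 2., 3.], [0., 1., 2., 3.]])
# >>> rel, rhos = reliability(signals)  # rel ~1.0, rhos ~array([1.])
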
def snr_baden2016(signals):
    """Return signal-to-noise ratio, aka Berens quality index (QI), from Baden2016, of a
    set of signals. Take the ratio of the temporal variance of the trial-averaged signal
    (i.e. the PSTH) to the average across trials of the variance in time of each trial.
    Ranges from 0 to 1. Ignores NaNs."""
    assert signals.ndim == 2
    signal = np.nanvar(np.nanmean(signals, axis=0))  # reduce trial axis, var across time
    noise = np.nanmean(np.nanvar(signals, axis=1))  # reduce time axis, mean across trials
    if signal == 0:
        return 0
    else:
        return signal / noise

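# Example (illustrative): identical trials give QI = 1, since the variance of the
# trial average equals the average single-trial variance:
# >>> trials = np.tile([0., 1., 2., 1.], (5, 1))  # 5 identical trials
# >>> snr_baden2016(trials)
# 1.0
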
def get_psth_peaks_gac(ts, t, psth, thresh, sigma=0.02, alpha=1.0, minpoints=5,
                       lowp=16, highp=84, checkthresh=True, verbose=True):
    """Extract PSTH peaks from spike times ts collapsed across trials, by clustering them
    using gradient ascent clustering (GAC, Swindale2014). Then, optionally check each peak
    against its amplitude in the PSTH (and its time stamps t), to ensure it passes thresh.
    Also extract the left and right edges of each peak, based on where each peak's mass
    falls between the lowp and highp percentiles.

    sigma is the clustering bandwidth used by GAC, in this case in seconds.

    Note that very narrow peaks will be missed if the resolution of the PSTH isn't high
    enough (TRES=0.0001 is plenty)"""
    from spyke.gac import gac  # .pyx file
    ts2d = np.float32(ts[:, None])  # convert to 2D (one row per spike), contig float32
    # get cluster IDs and positions corresponding to spikets; cpos is indexed into using
    # cids:
    cids, cpos = gac(ts2d, sigma=sigma, alpha=alpha, minpoints=minpoints, returncpos=True,
                     verbose=verbose)
    ucids = np.unique(cids)  # unique cluster IDs across all spikets
    ucids = ucids[ucids >= 0]  # exclude junk cluster -1
    #npeaks = len(ucids)  # but not all of them will necessarily cross the PSTH threshold
    peakis, lis, ris = [], [], []
    for ucid, pos in zip(ucids, cpos):  # clusters are numbered in order of decreasing size
        spikeis, = np.where(cids == ucid)
        cts = ts[spikeis]  # this cluster's spike times
        # search all spikes for argmax, same as using lowp=0 and highp=100:
        #li, ri = t.searchsorted([cts[0], cts[-1]])
        # search only within the percentiles for argmax:
        lt, rt = np.percentile(cts, [lowp, highp])
        li, ri = t.searchsorted([lt, rt])
        if li == ri:
            # start and end indices are identical, cluster probably falls before first or
            # after last spike time:
            assert li == 0 or li == len(psth)
            continue  # no peak to be found in psth for this cluster
        localpsth = psth[li:ri]
        # indices of all local peaks within percentiles in psth:
        #allpeakiis, = argrelextrema(localpsth, np.greater)
        #if len(allpeakiis) == 0:
        #    continue  # no peaks found for this cluster
        # find peakii closest to pos:
        #peakii = allpeakiis[abs((t[li + allpeakiis] - pos)).argmin()]
        # find biggest peak:
        #peakii = allpeakiis[localpsth[allpeakiis].argmax()]
        peakii = localpsth.argmax()  # find max point
        if peakii == 0 or peakii == len(localpsth)-1:
            continue  # skip "peak" that's really just a start or end point of localpsth
        peaki = li + peakii
        if checkthresh and psth[peaki] < thresh:
            continue  # skip peak that doesn't meet thresh
        if peaki in peakis:
            continue  # this peak has already been detected by a preceding, larger cluster
        peakis.append(peaki)
        lis.append(li)
        ris.append(ri)
        if verbose:
            print('.', end='')  # indicate a peak has been found
    return np.asarray(peakis), np.asarray(lis), np.asarray(ris)

def pairwisecorr(signals, weight=False, invalid='ignore'):
    """Calculate Pearson correlations between all pairs of rows in a 2D signals array.
    See np.seterr() for possible values of `invalid`"""
    assert signals.ndim == 2
    assert len(signals) >= 2  # at least two rows, i.e. at least one pair
    N = len(signals)
    # potentially allow 0/0 (nan) rhomat entries by ignoring 'invalid' errors
    # (not 'divide'):
    oldsettings = np.seterr(invalid=invalid)
    rhomat = np.corrcoef(signals)  # full correlation matrix
    np.seterr(**oldsettings)  # restore previous numpy error settings
    uti = np.triu_indices(N, k=1)
    rhos = rhomat[uti]  # pull out the upper triangle
    if weight:
        sums = signals.sum(axis=1)
        # weight each pair by the one with the least signal:
        weights = np.vstack((sums[uti[0]], sums[uti[1]])).min(axis=0)  # all pairs
        weights = weights / weights.sum()  # normalize, ensure float division
        return rhos, weights
    else:
        return rhos, None

def pairwisecorr_nan(signals):
    """Calculate Pearson correlations between all pairs of rows in a 2D signals array,
    while skipping NaNs. Relies on the pandas DataFrame .corr() method"""
    assert signals.ndim == 2
    assert len(signals) >= 2  # at least two rows, i.e. at least one pair
    N = len(signals)
    rhomat = np.array(pd.DataFrame(signals.T).corr())  # full correlation matrix
    # return only the unique (upper triangle) entries, consistent with pairwisecorr();
    # the original returned the zero-padded matrix np.triu(rhomat, k=1), whose zeros
    # would bias the mean taken in reliability():
    return rhomat[np.triu_indices(N, k=1)]

def vector_OSI(oris, rates):
    """Vector averaging method for calculating the orientation selectivity index (OSI).
    See Bonhoeffer1995, Swindale1998 and neuropy.neuron.Tune.pref().
    A reasonable use case is to take a model tuning curve, calculate its values at a
    fine ori resolution (say 1 deg), and use those as the input rates here.

    Parameters
    ----------
    oris : orientations (deg, potentially ranging 0 to 360)
           Attention: for orientation data, oris should always range 0 to 180! Only for
           direction data (e.g. from drifting gratings) can it go 0 to 360, in which case
           this returns the orientation selectivity of the data. For direction
           selectivity, oris again has to range 0 to 180!
    rates : corresponding firing rates

    Returns
    -------
    r : length of net vector average as fraction of total firing
    """
    orisrad = 2 * oris * np.pi/180  # double the angle, convert from deg to rad
    x = (rates*np.cos(orisrad)).sum()
    y = (rates*np.sin(orisrad)).sum()
    n = rates.sum()
    r = np.sqrt(x**2+y**2) / n  # fraction of total firing
    return r

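# Example (illustrative): rates sampled at 4 orientations; firing at only one
# orientation gives r = 1, uniform firing gives r = 0 (up to float error):
# >>> oris = np.array([0., 45., 90., 135.])
# >>> vector_OSI(oris, np.array([1., 0., 0., 0.]))
# 1.0
# >>> vector_OSI(oris, np.ones(4))  # ~0.0
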
def percentile_ci(data, alpha=0.05, func=np.percentile, **kwargs):
    """Simple percentile method for confidence intervals. No assumptions about the shape
    of the distribution. Note that the bounds are taken at the 100*alpha and 100*(1-alpha)
    percentiles, i.e. a (1 - 2*alpha) two-sided interval under this convention"""
    data = np.array(data)  # accept lists & tuples
    lower, med, upper = func(data, [100*alpha, 50, 100*(1-alpha)], **kwargs)
    return med, lower, upper

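# Example (illustrative): with the default alpha=0.05, the bounds are the 5th and
# 95th percentiles of the data (values from numpy's default linear interpolation):
# >>> percentile_ci(np.arange(1, 101))
# (50.5, 5.95, 95.05)
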
def sum_of_gaussians(x, dp, rp, rn, r0, sigma):
    """ORITUNE sum of two Gaussians living on a circle, for orientation tuning.
    x : orientations, between 0 and 360 (deg)
    dp : preferred direction, between 0 and 360 (deg)
    rp : response to the preferred direction
    rn : response to the opposite direction
    r0 : background response (useful only in some cases)
    sigma : tuning width (deg)
    """
    angles_p = 180 / pi * np.angle(np.exp(1j*(x-dp) * pi / 180))
    angles_n = 180 / pi * np.angle(np.exp(1j*(x-dp+180) * pi / 180))
    y = (r0 + rp*np.exp(-angles_p**2 / (2*sigma**2)) + rn*np.exp(-angles_n**2 / (2*sigma**2)))
    return y

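# Example (illustrative): a direction tuning curve peaking at dp=90 deg, with a
# weaker peak at the opposite direction (270 deg), on a background rate of 1;
# y is ~11 at x=90 and ~6 at x=270:
# >>> x = np.arange(0, 360, 45.)
# >>> y = sum_of_gaussians(x, dp=90, rp=10, rn=5, r0=1, sigma=30)
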