util.py

  1. """General use classes and functions"""
  2. import os
  3. import pickle
  4. import numpy as np
  5. from numpy import pi
  6. import pandas as pd
  7. from scipy.stats import linregress
  8. from scipy.optimize import curve_fit
  9. import scipy.ndimage.filters as filters
  10. import matplotlib.pyplot as plt
  11. from matplotlib import colors
  12. from matplotlib.ticker import FuncFormatter
  13. from matplotlib.offsetbox import AnchoredText
  14. RESPDFNAMES = ['mviresp', 'grtresp', 'sponresp', 'bestmviresp', 'bestgrtresp', 'grttunresp']
  15. FITSDFNAMES = ['fits', 'sampfits']
  16. MIDFNAMES = ['mviFMI', 'grtFMI', 'mviRMI', 'grtRMI', 'maxFMI', 'maxRMI', 'iposmi']
  17. POSRUNDFNAMES = ['ipos_st8', 'ipos_opto', 'upos', 'runspeed']
  18. CELLTYPEDFNAMES = ['celltype', 'cellscreenpos']
  19. FIGDFNAMES = ['fig1', 'fig2', 'fig3', 'fig4', 'fig5', 'fig6',
  20. 'fig1S33S1', 'fig1S4mvi', 'fig1S4grt', 'fig1S5mvi', 'fig1S5grt', 'fig5S2']
  21. DFNAMES = RESPDFNAMES + FITSDFNAMES + MIDFNAMES + POSRUNDFNAMES + CELLTYPEDFNAMES + FIGDFNAMES
  22. EXPORTDFNAMES = MIDFNAMES + FIGDFNAMES # to .csv
  23. STRNAMES = ['mvimseustrs', 'grtmseustrs', 'mvigrtmsustrs']


def load(name, subfolder=''):
    """Return variable (e.g. a DataFrame) from a pickle"""
    path = os.path.dirname(__file__)
    fname = os.path.join(path, 'pickles', subfolder, name+'.pickle')
    with open(fname, 'rb') as f:
        val = pd.read_pickle(f)  # backward compatible with pickles generated by pandas < 1.0
    return val


def load_all(subfolder=''):
    """Load all variables (DFNAMES+STRNAMES) from pickles, return name:val dict"""
    name2val = {}
    for name in DFNAMES+STRNAMES:
        print('Loading', name)
        try:
            val = load(name, subfolder=subfolder)
            name2val[name] = val
        except FileNotFoundError as err:
            print(err)
    return name2val


def save(ns, dfnames=None, strnames=None, subfolder=''):
    """Save DFNAMES, STRNAMES and optionally FIGDFNAMES to disk, given a namespace ns
    (e.g. locals()). Useful for reducing unnecessary recalculation and/or in case of
    lack of database connection"""
    if subfolder == '' and ns['EXPTYPE'] != 'pvmvis':
        subfolder = ns['EXPTYPE']  # prevent accidentally overwriting the default PV pickles
    path = os.path.join(os.path.dirname(__file__), 'pickles', subfolder)
    if dfnames is None:
        dfnames = DFNAMES
    assert type(dfnames) in [list, tuple]
    for dfname in dfnames:
        fname = os.path.join(path, dfname+'.pickle')
        print('Saving', fname)
        if dfname not in ns:
            print('WARNING: %r not found' % dfname)
        else:
            ns[dfname].to_pickle(fname)
    if strnames is None:
        strnames = STRNAMES
    assert type(strnames) in [list, tuple]
    for strname in strnames:
        fname = os.path.join(path, strname+'.pickle')
        print('Saving', fname)
        with open(fname, 'wb') as f:
            pickle.dump(ns[strname], f)
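
# Illustrative round trip (a sketch, not part of the original module; assumes the
# pickles/ folder exists and the calling namespace defines EXPTYPE and the named
# DataFrames):
#   save(locals(), dfnames=['fits'])  # pickle the caller's 'fits' DataFrame
#   fits = load('fits')               # ...and read it back later
#   name2val = load_all()             # or load everything that has been pickled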


def export2csv(ns, dfnames=None, subfolder=''):
    """Given a namespace ns (e.g. locals()), export all figure-specific and modulation index
    DataFrames to .csv for Steffen to analyze in R, dfnames: list of strings"""
    path = os.path.join(os.path.dirname(__file__), 'csv', subfolder)
    if dfnames is None:
        dfnames = EXPORTDFNAMES
    for dfname in dfnames:
        fname = os.path.join(path, dfname+'.csv')
        print('Saving', fname)
        if dfname not in ns:
            print('WARNING: %r not found' % dfname)
        else:
            df = ns[dfname]
            columns = list(df.columns)
            for col in columns:
                if np.all(pd.isna(df[col])):
                    print('WARNING: found empty column %r in df %r' % (col, dfname))
            if dfname == 'fig1':
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['rate02s'] = df.explode('rate02s')['rate02s'].values
                newdf['rate35s'] = df.explode('rate35s')['rate35s'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                newdf['blankburstratios'] = df.explode('blankburstratios')['blankburstratios'].values
                df = newdf
            elif dfname == 'fig3':
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')  # includes both non-blank and blank trials
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                newdf['blankcondrates'] = df.explode('blankcondrates')['blankcondrates'].values
                newdf['blankburstratios'] = df.explode('blankburstratios')['blankburstratios'].values
                newdf['blankcondburstratios'] = df.explode('blankcondburstratios')[
                    'blankcondburstratios'].values
                df = newdf
            elif dfname in ['fig1S4mvi', 'fig1S4grt', 'fig1S5mvi', 'fig1S5grt']:
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')  # includes both non-blank and blank trials
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                df = newdf
            elif dfname == 'fig4S1':
                # explode trialis so that each trial gets its own row:
                newdf = df.explode('trialis')
                newdf['blankrates'] = df.explode('blankrates')['blankrates'].values
                df = newdf
            elif dfname == 'fig5':
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['rates'] = df.explode('rates')['rates'].values
                newdf['burstratios'] = df.explode('burstratios')['burstratios'].values
                df = newdf
            elif dfname.startswith('ipos_'):
                # exclude export of pupil area trial matrices and timepoints:
                columns.remove('area_trialmat')
                columns.remove('area_trialts')
                # explode various columns so that each trial gets its own row,
                # can only explode one column at a time:
                newdf = df.explode('trialis')
                newdf['area_trialmean'] = df.explode('area_trialmean')['area_trialmean'].values
                df = newdf
            df.to_csv(fname, columns=columns)


def desat(hexcolor, alpha):
    """Manually desaturate hex RGB color. Plotting directly with alpha results in saturated
    colors appearing to be plotted over top of desaturated colors, even when plotted in a
    lower layer"""
    return mixalpha(hexcolor, alpha)


def mixalpha(hexcolor, alpha=1, bg='#ffffff'):
    """Mix alpha into hexcolor, assuming background color.
    See https://stackoverflow.com/a/21576659/2020363"""
    rgb = np.array(colors.hex2color(hexcolor))  # convert to float RGB array
    bg = np.array(colors.hex2color(bg))
    rgb = alpha*rgb + (1 - alpha)*bg  # mix it
    return colors.rgb2hex(rgb)
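
# Illustrative check (not part of the original module): mixing pure red at alpha=0.5
# over the default white background yields the expected half-saturated pink:
#   >>> mixalpha('#ff0000', 0.5)
#   '#ff8080'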


def axes_disable_scientific(axes, axiss=None):
    """Disable scientific notation for both axes labels, useful for log-log plots.
    See https://stackoverflow.com/a/49306588/3904031"""
    if axiss is None:
        axiss = [axes.xaxis, axes.yaxis]
    for axis in axiss:
        ff = FuncFormatter(lambda y, _: '{:.16g}'.format(y))
        axis.set_major_formatter(ff)


def ms2msstr(msu):
    """Convert ms dictionary to ms string"""
    return "%s_s%02d" % (msu['m'], msu['s'])


def msu2msustr(msu):
    """Convert msu dictionary to msu string"""
    return "%s_s%02d_u%02d" % (msu['m'], msu['s'], msu['u'])


def mseustr2msustr(mseustr):
    """Convert an mseu string to an msu string, i.e. drop the experiment ID"""
    msustr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2, 3, 5]])
    return msustr


def mseustrs2msustrs(mseustrs):
    """Convert a sequence of mseu strings to msu strings, i.e. drop the experiment ID"""
    msustrs = []
    for mseustr in mseustrs:
        msustr = mseustr2msustr(mseustr)
        msustrs.append(msustr)
    return msustrs


def mseustr2mstr(mseustr):
    """Convert an mseu string to a mouse string"""
    mstr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2]])
    return mstr


def mseustrs2mstrs(mseustrs):
    """Convert a sequence of mseu strings to mouse strings"""
    mstrs = []
    for mseustr in mseustrs:
        mstr = mseustr2mstr(mseustr)
        mstrs.append(mstr)
    return mstrs


def mseustr2msestr(mseustr):
    """Convert an mseu string to an mse string"""
    msestr = '_'.join(np.array(mseustr.split('_'))[[0, 1, 2, 3, 4]])
    return msestr


def mseustrs2msestrs(mseustrs):
    """Convert a sequence of mseu strings to mse strings"""
    msestrs = []
    for mseustr in mseustrs:
        msestr = mseustr2msestr(mseustr)
        msestrs.append(msestr)
    return msestrs


def findmse(mseustrs, msestr):
    """Return boolean array of all entries in mseustrs that match the experiment
    described by msestr"""
    mse = msestr.split('_')
    assert len(mse) == 5  # strain, year, number, series, experiment
    hits = np.tile(False, len(mseustrs))
    for i, mseustr in enumerate(mseustrs):
        mseu = mseustr.split('_')
        assert len(mseu) == 6
        if mseu[:5] == mse:
            hits[i] = True
    return hits
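
# Illustrative ID handling (a sketch; the 6 underscore-separated fields are
# strain_year_number_series_experiment_unit, per the asserts in findmse() above,
# and the specific ID below is hypothetical):
#   >>> mseustr2msustr('PVCre_2018_0003_s03_e04_u05')  # drop the experiment field
#   'PVCre_2018_0003_s03_u05'
#   >>> mseustr2mstr('PVCre_2018_0003_s03_e04_u05')    # keep only the mouse fields
#   'PVCre_2018_0003'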


def fitmodel(ctrlfit, optofit, ctrltest, optotest, model=None):
    """Fit a model to ctrl and opto fit signals, test with ctrl and opto test signals"""
    if model == 'linear':
        mm, b, rr, p, stderr = linregress(ctrlfit, optofit)
        rsq = rr * rr
        raise RuntimeError('R2 for linear fit needs to be calculated on test data')
    elif model == 'threshlin':
        p, pcov = curve_fit(threshlin, xdata=ctrlfit, ydata=optofit, p0=None)
        mm, b = p
        # for each sample, calc rsq on the test data:
        rsq = rsquared(optotest, threshlin(ctrltest, *p))
    else:
        raise ValueError("Unknown model %r" % model)
    return mm, b, rsq


def threshlin(x, m, b):
    """Return threshold linear model"""
    y = m * x + b
    y[y < 0] = 0
    return y


def rsquared(targets, predictions):
    """Return the r-squared value for the fit"""
    residuals = targets - predictions
    residual_variance = np.sum(residuals**2)
    variance_of_targets = np.sum((targets - np.mean(targets))**2)
    if variance_of_targets == 0:
        rsq = np.nan
    else:
        rsq = 1 - (residual_variance / variance_of_targets)
    return rsq
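
# Illustrative values (not from the original module; note threshlin() expects an
# array, since it assigns into y):
#   >>> threshlin(np.array([-1.0, 0.0, 2.0]), m=1.0, b=0.5)  # clips negatives to 0
#   array([0. , 0.5, 2.5])
#   >>> rsquared(np.array([1., 2., 3.]), np.array([1., 2., 3.]))  # perfect fit
#   1.0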


def residual_rsquared(targets, residuals):
    """Return the r-squared value for the fit, given the target values and residuals.
    Minor variation of djd.model.rsquared()"""
    ssres = np.sum(residuals**2)
    sstot = np.sum((targets - np.mean(targets))**2)
    if sstot == 0:
        rsq = np.nan
    else:
        rsq = 1 - (ssres / sstot)
    return rsq


def linear_loss(params, x, y):
    """Linear loss function, for use with scipy.optimize.least_squares()"""
    assert len(params) == 2
    m, b = params  # unpack
    return m*x + b - y


def get_max_snr(mvirespr, mseustr, kind, st8):
    """Find SNR for mseustr, kind, st8 combination, take max across opto conditions"""
    mvirowis = ((mvirespr['mseu'] == mseustr) &
                (mvirespr['kind'] == kind) &
                (mvirespr['st8'] == st8))
    mvirows = mvirespr[mvirowis]
    assert len(mvirows) == 2
    fbsnr = mvirows[mvirows['opto'] == False]['snr'].iloc[0]  # feedback
    supsnr = mvirows[mvirows['opto'] == True]['snr'].iloc[0]  # suppression
    maxsnr = max(fbsnr, supsnr)  # take the max of the two conditions
    return maxsnr


def intround(n):
    """Round to the nearest integer, return an integer. Works on arrays.
    Saves on parentheses, nothing more"""
    if np.iterable(n):  # it's a sequence, return as an int64 array
        return np.int64(np.round(n))
    else:  # it's a scalar, return as normal Python int
        return int(round(n))
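
# Illustrative check (not in the original module; note that both Python's round()
# and np.round() use banker's rounding, so exact .5 values round to the nearest
# even integer):
#   >>> intround(2.6)
#   3
#   >>> intround(np.array([0.5, 1.5, 2.5]))
#   array([0, 2, 2])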


def split_tranges(tranges, width, tres):
    """Split up tranges into lots of smaller (typically overlapping) tranges, with width and
    tres. Usually, tres < width, but this also works for width < tres.
    Test with:
    print(split_tranges([(0,100)], 1, 10))
    print(split_tranges([(0,100)], 10, 1))
    print(split_tranges([(0,100)], 10, 10))
    print(split_tranges([(0,100)], 10, 8))
    print(split_tranges([(0,100)], 3, 10))
    print(split_tranges([(0,100)], 10, 3))
    print(split_tranges([(0,100)], 3, 8))
    print(split_tranges([(0,100)], 8, 3))
    """
    newtranges = []
    for trange in tranges:
        t0, t1 = trange
        assert width < (t1 - t0)
        # calculate left and right edges of subtranges that fall within trange:
        # This is tricky: find maximum left edge such that the corresponding maximum right
        # edge goes as close as possible to t1 without exceeding it:
        tend = (t1-width+tres) // tres*tres  # there might be a nicer way, but this works
        ledges = np.arange(t0, tend, tres)
        redges = ledges + width
        subtranges = [(le, re) for le, re in zip(ledges, redges)]
        newtranges.append(subtranges)
    return np.vstack(newtranges)
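
# Illustrative output shape (not in the original module): a 100 s range split into
# 10 s windows sliding by 3 s gives one (left, right) row per window, with the last
# right edge at or below 100:
#   >>> split_tranges([(0, 100)], 10, 3).shape  # 31 windows: left edges 0, 3, ..., 90
#   (31, 2)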


def wrap_raster(raster, t0, t1, newdt, offsets=[0, 0]):
    """Extract event times in raster (list or array of arrays) between t0 and t1 (s),
    and wrap into extra rows such that event times never exceed newdt (s)"""
    t1floor = t1 - t1 % newdt
    t0s = np.arange(t0, t1floor, newdt)  # end exclusive
    t1s = t0s + newdt
    tranges = np.column_stack([t0s, t1s])
    wrappedraster = []
    for row in raster:
        dst = []  # init list to collect events for this row
        for trange in tranges:
            # search within trange, but take into account desired offsets:
            si0, si1 = row.searchsorted(trange + offsets)
            # get spike times relative to start of trange:
            dst.append(row[si0:si1] - trange[0])
        # convert from list to object array to enable fancy indexing:
        wrappedraster.extend(dst)
    return np.asarray(wrappedraster)
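
# Illustrative wrap (a sketch, not in the original module): one 10 s trial rewrapped
# into 5 s rows, with each output row holding times relative to its own window:
#   >>> row = np.array([0.5, 4.0, 6.0, 9.5])
#   >>> wrap_raster([row], 0, 10, 5)  # rows: [0.5, 4.0] and [1.0, 4.5]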


def cf():
    """Close all figures"""
    plt.close('all')


def saveall(path=None, format='png'):
    """Save all open figures to chosen path, pop up dialog box if path is None"""
    if path is None:  # query with dialog box for a path
        from matplotlib import rcParams
        startpath = os.path.expanduser(rcParams['savefig.directory'])  # get default
        path = choose_path(startpath, msg="Choose a folder to save to")
        if not path:  # dialog box was cancelled
            return  # don't do anything
        rcParams['savefig.directory'] = path  # update default
    fs = [plt.figure(i) for i in plt.get_fignums()]
    for f in fs:
        fname = f.canvas.get_window_title() + '.' + format
        fname = fname.replace(' ', '_')
        fullfname = os.path.join(path, fname)
        print(fullfname)
        f.savefig(fullfname)


def lastcmd():
    """Return a string containing the last command entered by the user in the
    IPython shell. Useful for generating plot titles"""
    from IPython import get_ipython
    ip = get_ipython()
    return ip._last_input_line


def wintitle(titlestr=None, f=None):
    """Set title of current MPL window, defaults to last command entered"""
    if titlestr is None:
        titlestr = lastcmd()
    if f is None:
        f = plt.gcf()
    f.canvas.set_window_title(titlestr)


def simpletraster(raster, dt=5, offsets=[0, 0], s=1, clr='k',
                  scatter=False, scattermarker='|', scattersize=10,
                  burstis=None, burstclr='r',
                  axisbg='w', alpha=1, inchespersec=1.5, inchespertrial=1/25,
                  ax=None, figsize=None, title=False, xaxis=True, label=None):
    """
    Create a simple trial raster plot. Each entry in raster is a list of spike times
    relative to the start of each trial.

    dt : trial duration (s)
    offsets : offsets relative to trial start and end (s)
    s : tick linewidths
    clr : tick color, either a single color or a sequence of colors, one per trial
    scatter : whether to use original ax.scatter() command to plot much faster and use much
              less memory, but with potentially vertically overlapping ticks. Otherwise,
              default to slower ax.eventplot()
    burstis : burst indices, as returned by FiringPattern().burst_ratio()
    """
    ntrials = len(raster)
    spiketrialis, c = [], []
    # get raster tick color of each trial:
    if type(clr) == str:  # all trials have the same color
        clr = list(colors.to_rgba(clr))
        clr[3] = alpha  # apply alpha, so that we can control alpha per tick
        trialclrs = [clr]*ntrials
    else:  # each trial has potentially a different color
        assert type(clr) in [list, np.ndarray]
        assert len(clr) == ntrials
        trialclrs = []
        for trialclr in clr:
            trialclr = list(colors.to_rgba(trialclr))
            trialclr[3] = alpha  # apply alpha, so that we can control alpha per tick
            trialclrs.append(trialclr)
    burstclr = colors.to_rgba(burstclr)  # keep full saturation for burst spikes
    # collect 1-based trial info, one entry per spike:
    for triali, rastertrial in enumerate(raster):
        nspikes = len(rastertrial)
        spiketrialis.append(np.tile(triali+1, nspikes))  # 1-based
        trialclr = trialclrs[triali]
        spikecolors = np.tile(trialclr, (nspikes, 1))
        if burstis is not None:
            bis = burstis[triali]
            if len(bis) > 0:
                spikecolors[bis] = burstclr
        c.append(spikecolors)
    # convert each list of arrays to a single flat array:
    raster = np.hstack(raster)
    spiketrialis = np.hstack(spiketrialis)
    c = np.concatenate(c)
    xmin, xmax = offsets[0], dt + offsets[1]
    totaldt = xmax - xmin  # total raster duration, including offsets
    if ax is None:
        if figsize is None:
            figwidth = min(1 + totaldt*inchespersec, 12)
            figheight = min(1 + ntrials*inchespertrial, 12)
            figsize = figwidth, figheight
        f, ax = plt.subplots(figsize=figsize)
    if scatter:
        # scatter doesn't carefully control vertical spacing, allows vertical overlap of ticks:
        ax.scatter(raster, spiketrialis, marker=scattermarker, c=c, s=scattersize, label=label)
    else:
        # eventplot is slower, but does a better job:
        raster = raster[:, np.newaxis]  # somehow eventplot requires an extra unitary dimension
        if len(raster) == 0:
            print("No spikes for eventplot %r" % title)  # prevent TypeError from eventplot()
        else:
            ax.eventplot(raster, lineoffsets=spiketrialis, colors=c, linewidth=s, label=label)
    ax.set_xlim(xmin, xmax)
    # -1 inverts the y axis, +1 ensures last trial is fully visible:
    ax.set_ylim(ntrials+1, -1)
    ax.set_facecolor(axisbg)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Trial')
    if label:
        ax.legend(loc="best")
    if title:
        #ax.set_title(title)
        wintitle(title)
    if xaxis != True:
        if xaxis == False:
            renderer = f.canvas.get_renderer()
            bbox = ax.xaxis.get_tightbbox(renderer).transformed(f.dpi_scale_trans.inverted())
            xaxis = bbox.height
        figheight = figheight - xaxis
        ax.get_xaxis().set_visible(False)
        ax.spines['bottom'].set_visible(False)
        f.set_figheight(figheight)
    #f.tight_layout(pad=0.3)  # crop figure to contents, doesn't seem to do anything any more
    #f.show()
    return ax
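
# Illustrative call (a sketch with a toy two-trial raster, not in the original module):
#   raster = [np.array([0.2, 1.1, 3.7]), np.array([0.4, 2.0])]
#   ax = simpletraster(raster, dt=5, title='toy_raster')  # eventplot, one row per trial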


def raster2psth(raster, bins, binw, tres, kernel='gauss'):
    """Convert a spike trial raster to a peri-stimulus time histogram (PSTH).
    To calculate the PSTH of a subset of trials, pass a raster containing only that subset.

    Parameters
    ----------
    raster : spike trial raster as a sequence of arrays of spike times (s), one array per trial
    bins : 2D array of start and stop PSTH bin edge times (s), one row per bin.
           Bins may or may not be overlapping. Typically generated using util.split_tranges()
    binw : PSTH bin width (s) that was used to generate bins
    tres : temporal resolution (s) that was used to generate bins, only used if kernel=='gauss'
    kernel : smoothing kernel : None or 'gauss'

    Returns
    -------
    psth : peri-stimulus time histogram (Hz), normalized by bin width and number of trials
    """
    # make sure raster has nested iterables, i.e. list of arrays, or array of arrays, etc.,
    # even if there's only one array inside raster representing only one trial:
    if len(raster) > 0:  # not an empty raster
        trial0 = raster[0]
        if type(trial0) not in (np.ndarray, list):
            raise ValueError("Ensure that raster is a sequence of arrays of spike times,\n"
                             "one per trial. If you're passing only a single extracted trial,\n"
                             "make sure to pass it within e.g. a list of length 1")
    # now it's safe to assume that len(raster) represents the number of included trials,
    # and not erroneously the number of spikes in a single unnested array of spike times:
    ntrials = len(raster)
    if ntrials == 0:  # empty raster
        spikes = np.asarray(raster)
    else:
        spikes = np.hstack(raster)  # flatten across trials
    spikes.sort()
    spikeis = spikes.searchsorted(bins)  # where bin edges fall in spikes
    # convert to rate: number of spikes in each bin, normalized by binw:
    psth = (spikeis[:, 1] - spikeis[:, 0]) / binw
    if kernel is None:  # rectangular bins
        pass
    elif kernel == 'gauss':  # apply Gaussian filtering
        sigma = binw / 2  # set sigma to half the bin width (sec)
        sigmansamples = sigma / tres  # sigma as multiple of number of samples (unitless)
        psth = filters.gaussian_filter1d(psth, sigma=sigmansamples)
    else:
        raise ValueError('Unknown kernel %r' % kernel)
    # normalize by number of trials:
    if ntrials != 0:
        psth = psth / ntrials
    return psth
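
# Illustrative pipeline (a sketch with toy numbers, not from the original module):
# build overlapping bins with split_tranges(), then convert a two-trial raster to a
# rate in Hz. With kernel=None, each bin is simply count / binw / ntrials:
#   >>> bins = split_tranges([(0, 1)], width=0.1, tres=0.02)  # 100 ms bins, 20 ms steps
#   >>> raster = [np.array([0.05, 0.51]), np.array([0.52])]
#   >>> psth = raster2psth(raster, bins, binw=0.1, tres=0.02, kernel=None)
#   >>> round(psth.max(), 6)  # 2 spikes in a 100 ms bin over 2 trials -> 10 Hz
#   10.0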


def raster2freqcomp(raster, dt, f, mean='scalar'):
    """Extract a frequency component from spike raster (one row of spike times per trial).
    Adapted from getHarmsResps.m and UnitGetHarm.m

    Parameters
    ----------
    raster : spike raster as a sequence of arrays of spike times (s), one array per trial
    dt : trial duration (s)
    f : frequency to extract (Hz), f=0 extracts mean firing rate
    mean : 'scalar': compute mean of amplitudes of each trial's vector (mean(abs)), i.e. find
                     frequency component f separately for each trial, then take average
                     amplitude.
           'vector': compute mean across all trials before calculating amplitude (abs(mean)),
                     equivalent to first calculating PSTH from all rows of raster

    Returns
    -------
    r : peak-to-peak amplitude of frequency component f
    theta : angle of frequency component f (rad)

    Examples
    --------
    >>> inphase = np.array([0, 1, 2, 3, 4])  # spike times (s)
    >>> outphase = np.array([0.5, 1.5, 2.5, 3.5, 4.5])  # spike times (s)
    >>> raster2freqcomp([inphase], 5, 1)  # single trial, 'mean' is irrelevant
    (2.0, -4.898587196589412e-16)
    >>> raster2freqcomp([outphase], 5, 1)
    (2.0, 3.1415926535897927)
    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='scalar')
    (2.0, 1.5707963267948961)
    >>> raster2freqcomp([inphase, outphase], 5, 1, mean='vector')
    (1.2246467991473544e-16, 1.5707963267948966)

    Using f=0 returns mean firing rate:

    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='scalar')
    (1.0, 0.0)
    >>> raster2freqcomp([inphase, outphase], 5, 0, mean='vector')
    (1.0, 0.0)
    """
    ntrials = len(raster)
    res, ims = np.zeros(ntrials), np.zeros(ntrials)  # init real and imaginary components
    for triali, spikes in enumerate(raster):  # iterate over trials
        if len(spikes) == 0:
            continue
        spikes = np.asarray(spikes)  # in case raster is a list of lists
        if spikes.max() > dt:
            print('spikes exceeding dt:', spikes[spikes > dt])
            # discard rare spikes in raster that for some reason (screen vsyncs?) fall outside
            # the expected trial duration:
            spikes = spikes[spikes <= dt]
        omega = 2 * np.pi * f  # angular frequency (rad/s)
        res[triali] = (np.cos(omega*spikes)).sum() / dt
        ims[triali] = (np.sin(omega*spikes)).sum() / dt
    Hs = (res + ims*1j)  # array of complex numbers
    if f != 0:  # not just the degenerate mean firing rate case
        Hs = 2 * Hs  # convert to peak-to-peak
    if mean == 'scalar':
        Hamplitudes = np.abs(Hs)  # ntrials long
        r = np.nanmean(Hamplitudes)  # mean of amplitudes
        theta = np.nanmean(np.angle(Hs))  # mean of angles
        #rstd = np.nanstd(Hamplitudes)  # stdev of amplitudes
    elif mean == 'vector':
        Hmean = np.nanmean(Hs)  # single complex number
        r = np.abs(Hmean)  # corresponds to PSTH amplitude
        theta = np.angle(Hmean)  # angle of mean vector
        #rstd = np.nanstd(Hs)  # single scalar, corresponds to PSTH stdev
    else:
        raise ValueError('Unknown `mean` method %r' % mean)
    ## NOTE: another way to calculate theta might be:
    #theta = np.arctan2(np.nanmean(np.imag(Hs)), np.nanmean(np.real(Hs)))
    return r, theta


def sparseness(x):
    """Sparseness measure, from Vinje and Gallant, 2000. This is basically 1 minus the ratio
    of the square of the mean over the mean of the squares of the values in signal x,
    normalized by (1 - 1/n) so that it ranges from 0 (dense) to 1 (sparse)"""
    if x.sum() == 0:
        return 0
    n = len(x)
    return (1 - (x.sum()/n)**2 / np.sum((x**2)/n)) / (1 - 1/n)
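
# Illustrative extremes (not in the original module): a uniform signal is maximally
# dense, a single-point signal is maximally sparse:
#   >>> sparseness(np.array([1., 1., 1., 1.]))
#   0.0
#   >>> sparseness(np.array([0., 0., 0., 4.]))
#   1.0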


def reliability(signals, average='mean', ignore_nan=True):
    """Calculate reliability across trials in signals, one row per trial, by finding the
    average Pearson's rho between all pairwise combinations of trial signals

    Returns
    -------
    reliability : float
    rhos : ndarray
    """
    ntrials = len(signals)
    if ntrials < 2:
        return np.nan, None  # can't calculate reliability with fewer than 2 trials
    if ignore_nan:
        rhos = pairwisecorr_nan(signals)
    else:
        rhos, _ = pairwisecorr(signals)
    if average == 'mean':
        rel = np.nanmean(rhos)
    elif average == 'median':
        rel = np.nanmedian(rhos)
    else:
        raise ValueError('Unknown average %r' % average)
    return rel, rhos


def snr_baden2016(signals):
    """Return signal-to-noise ratio, aka Berens quality index (QI), from Baden2016, of a set
    of signals. Take the ratio of the temporal variance of the trial-averaged signal (i.e.
    the PSTH) to the average across trials of the variance in time of each trial. Ranges
    from 0 to 1. Ignores NaNs."""
    assert signals.ndim == 2
    signal = np.nanvar(np.nanmean(signals, axis=0))  # reduce row axis, calc var across time
    noise = np.nanmean(np.nanvar(signals, axis=1))  # reduce time axis, calc mean across trials
    if signal == 0:
        return 0
    else:
        return signal / noise
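
# Illustrative extreme (not in the original module): identical trials give QI = 1,
# since the PSTH then has exactly as much temporal variance as the average trial:
#   >>> snr_baden2016(np.array([[0., 1., 0.], [0., 1., 0.]]))
#   1.0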


def get_psth_peaks_gac(ts, t, psth, thresh, sigma=0.02, alpha=1.0, minpoints=5,
                       lowp=16, highp=84, checkthresh=True, verbose=True):
    """Extract PSTH peaks from spike times ts collapsed across trials, by clustering them
    using gradient ascent clustering (GAC, Swindale2014). Then, optionally check each peak
    against its amplitude in the PSTH (and its time stamps t), to ensure it passes thresh.
    Also extract the left and right edges of each peak, based on where each peak's mass falls
    between lowp and highp percentiles.

    sigma is the clustering bandwidth used by GAC, in this case in seconds.

    Note that very narrow peaks will be missed if the resolution of the PSTH isn't high enough
    (TRES=0.0001 is plenty)"""
    from spyke.gac import gac  # .pyx file
    ts2d = np.float32(ts[:, None])  # convert to 2D (one row per spike), contig float32
    # get cluster IDs and positions corresponding to spikets, cpos is indexed into using
    # cids:
    cids, cpos = gac(ts2d, sigma=sigma, alpha=alpha, minpoints=minpoints, returncpos=True,
                     verbose=verbose)
    ucids = np.unique(cids)  # unique cluster IDs across all spikets
    ucids = ucids[ucids >= 0]  # exclude junk cluster -1
    #npeaks = len(ucids)  # but not all of them will necessarily cross the PSTH threshold
    peakis, lis, ris = [], [], []
    for ucid, pos in zip(ucids, cpos):  # clusters are numbered in order of decreasing size
        spikeis, = np.where(cids == ucid)
        cts = ts[spikeis]  # this cluster's spike times
        # search all spikes for argmax, same as using lowp=0 and highp=100:
        #li, ri = t.searchsorted([cts[0], cts[-1]])
        # search only within the percentiles for argmax:
        lt, rt = np.percentile(cts, [lowp, highp])
        li, ri = t.searchsorted([lt, rt])
        if li == ri:
            # start and end indices are identical, cluster probably falls before first or
            # after last spike time:
            assert li == 0 or li == len(psth)
            continue  # no peak to be found in psth for this cluster
        localpsth = psth[li:ri]
        # indices of all local peaks within percentiles in psth:
        #allpeakiis, = argrelextrema(localpsth, np.greater)
        #if len(allpeakiis) == 0:
        #    continue  # no peaks found for this cluster
        # find peakii closest to pos:
        #peakii = allpeakiis[abs((t[li + allpeakiis] - pos)).argmin()]
        # find biggest peak:
        #peakii = allpeakiis[localpsth[allpeakiis].argmax()]
        peakii = localpsth.argmax()  # find max point
        if peakii == 0 or peakii == len(localpsth)-1:
            continue  # skip "peak" that's really just a start or end point of localpsth
        peaki = li + peakii
        if checkthresh and psth[peaki] < thresh:
            continue  # skip peak that doesn't meet thresh
        if peaki in peakis:
            continue  # this peak has already been detected by a preceding, larger, cluster
        peakis.append(peaki)
        lis.append(li)
        ris.append(ri)
        if verbose:
            print('.', end='')  # indicate a peak has been found
    return np.asarray(peakis), np.asarray(lis), np.asarray(ris)


def pairwisecorr(signals, weight=False, invalid='ignore'):
    """Calculate Pearson correlations between all pairs of rows in 2D signals array.
    See np.seterr() for possible values of `invalid`"""
    assert signals.ndim == 2
    assert len(signals) >= 2  # at least two rows, i.e. at least one pair
    N = len(signals)
    # potentially allow 0/0 (nan) rhomat entries by ignoring 'invalid' errors
    # (not 'divide'):
    oldsettings = np.seterr(invalid=invalid)
    rhomat = np.corrcoef(signals)  # full correlation matrix
    np.seterr(**oldsettings)  # restore previous numpy error settings
    uti = np.triu_indices(N, k=1)
    rhos = rhomat[uti]  # pull out the upper triangle
    if weight:
        sums = signals.sum(axis=1)
        # weight each pair by the one with the least signal:
        weights = np.vstack((sums[uti[0]], sums[uti[1]])).min(axis=0)  # all pairs
        weights = weights / weights.sum()  # normalize, ensure float division
        return rhos, weights
    else:
        return rhos, None


def pairwisecorr_nan(signals):
    """Calculate Pearson correlations between all pairs of rows in 2D signals array,
    while skipping NaNs. Relies on Pandas DataFrame method"""
    assert signals.ndim == 2
    assert len(signals) >= 2  # at least two rows, i.e. at least one pair
    N = len(signals)
    rhomat = np.array(pd.DataFrame(signals.T).corr())  # full correlation matrix
    return np.triu(rhomat, k=1)  # non-unique entries zeroed


def vector_OSI(oris, rates):
    """Vector averaging method for calculating orientation selectivity index (OSI).
    See Bonhoeffer1995, Swindale1998 and neuropy.neuron.Tune.pref().
    Reasonable use case is to take model tuning curve, calculate its values at a
    fine ori resolution (say 1 deg), and use that as input rates here.

    Parameters
    ----------
    oris : orientations (degrees, potentially ranging 0 to 360)
           Attention: for orientation data, ori should always range 0 to 180! Only for
           direction data (e.g. from drifting gratings) can it go 0 to 360; it will then
           return the orientation selectivity of the data. For the direction selectivity,
           ori has to again range 0 to 180!
    rates : corresponding firing rates

    Returns
    -------
    r : length of net vector average as fraction of total firing
    """
    orisrad = 2 * oris * np.pi/180  # double the angle, convert from deg to rad
    x = (rates*np.cos(orisrad)).sum()
    y = (rates*np.sin(orisrad)).sum()
    n = rates.sum()
    r = np.sqrt(x**2 + y**2) / n  # fraction of total firing
    return r
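
# Illustrative extremes (not in the original module): a cell that fires only at one
# orientation is perfectly selective, one that fires equally at orthogonal
# orientations is unselective:
#   >>> vector_OSI(np.array([0., 90.]), np.array([10., 0.]))
#   1.0
#   >>> round(vector_OSI(np.array([0., 90.]), np.array([10., 10.])), 12)
#   0.0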


def percentile_ci(data, alpha=0.05, func=np.percentile, **kwargs):
    """Simple percentile method for confidence intervals. No assumptions about the shape
    of the distribution. Returns the median and the 100*alpha and 100*(1-alpha)
    percentiles of data"""
    data = np.array(data)  # accept lists & tuples
    lower, med, upper = func(data, [100*alpha, 50, 100*(1-alpha)], **kwargs)
    return med, lower, upper
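
# Illustrative call (not in the original module): with the default alpha=0.05 the
# interval spans the 5th to 95th percentile of the data:
#   >>> percentile_ci(np.arange(101))
#   (50.0, 5.0, 95.0)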


def sum_of_gaussians(x, dp, rp, rn, r0, sigma):
    """ORITUNE sum of two gaussians living on a circle, for orientation tuning.
    x are the orientations, between 0 and 360.
    dp is the preferred direction (between 0 and 360).
    rp is the response to the preferred direction.
    rn is the response to the opposite direction.
    r0 is the background response (useful only in some cases).
    sigma is the tuning width."""
    angles_p = 180 / pi * np.angle(np.exp(1j*(x-dp) * pi / 180))
    angles_n = 180 / pi * np.angle(np.exp(1j*(x-dp+180) * pi / 180))
    y = (r0 + rp*np.exp(-angles_p**2 / (2*sigma**2)) + rn*np.exp(-angles_n**2 / (2*sigma**2)))
    return y
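
# Illustrative evaluation (not in the original module): at the preferred direction the
# response approaches r0 + rp, at the opposite direction r0 + rn (the cross terms are
# negligible for sigma much narrower than 180 deg):
#   >>> y = sum_of_gaussians(np.array([0., 180.]), dp=0, rp=10, rn=5, r0=1, sigma=30)
#   >>> np.round(y, 3)
#   array([11.,  6.])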