# util.py

import numpy as np
import pandas as pd
import quantities as pq
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
from statsmodels.stats.multitest import fdrcorrection
from scipy.stats import norm, sem
from sklearn.mixture import GaussianMixture
from neo import SpikeTrain
from elephant.statistics import instantaneous_rate
from elephant.kernels import GaussianKernel

from parameters import *


## Data handling

def df2keys(df):
    """Return MSEU keys for all entries in a DataFrame."""
    df = df.reset_index()
    keys = [key for key in df.columns if key in ['m', 's', 'e', 'u']]
    return [{key: val for key, val in zip(keys, vals)} for vals in df[keys].values]

def key2idx(key):
    """Return DataFrame index tuple for the given key."""
    return tuple([val for k, val in key.items()])
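
# Example (illustrative sketch on a toy DataFrame, not part of the original module):
# round-tripping MSEU identifiers through df2keys/key2idx.
#
#   df = pd.DataFrame({'m': ['M1', 'M1'], 's': [1, 2], 'e': [1, 1], 'u': [3, 5]})
#   keys = df2keys(df)   # [{'m': 'M1', 's': 1, 'e': 1, 'u': 3}, {'m': 'M1', 's': 2, 'e': 1, 'u': 5}]
#   key2idx(keys[0])     # ('M1', 1, 1, 3)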

def print_Ns(df):
    """
    Print the number of mice, recording sessions, experiments, and neurons
    contained in a dataset.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with 'm', 's', 'e', and 'u' identifiers in the columns or index.
    """
    df = df.reset_index()
    print(f"No. mice: {len(df.groupby('m'))}")
    print(f"No. series: {len(df.groupby(['m', 's']))}")
    print(f"No. experiments: {len(df.groupby(['m', 's', 'e']))}")
    if 'u' in df.columns:
        print(f"No. units: {len(df.groupby(['m', 's', 'e', 'u']))}")
        # Print the number of units from each experiment
        for idx, experiment in df.groupby(['m', 's', 'e']):
            print(f"\n{idx[0]} s{idx[1]:02d} e{idx[2]:02d}")
            print(f" {len(experiment)} units")

def load_data(data, conditions, ext=''):
    """
    Load data of a specified type, concatenating all the requested conditions.

    Parameters
    ----------
    data : str
        String specifying the type of data to load (filename prefix).
    conditions : list
        List of strings specifying the experimental conditions to load.
    ext : str
        String specifying the filename extension to load; can specify the
        spike type (e.g. 'burst') or data sub-sample (e.g. 'sizematched').
    """
    dfs = []
    for condition in conditions:
        filename = '{}_{}'.format(data, condition)
        if ext:
            filename = filename + '_{}'.format(ext)
        filename = filename + '.pkl'
        df = pd.read_pickle(DATAPATH + filename)
        if 'condition' not in df.columns:
            df['condition'] = condition
        dfs.append(df)
    return pd.concat(dfs, axis='index')


## Plotting utils

def set_plotsize(w, h=None, ax=None):
    """
    Set the size of a matplotlib axes object in cm.

    Parameters
    ----------
    w, h : float
        Desired width and height of the plot; if height is None, the axes will
        be square.
    ax : matplotlib.axes
        Axes to resize; if None, the output of plt.gca() will be re-sized.

    Notes
    -----
    - Use after subplots_adjust (if adjustment is needed).
    - Matplotlib axes size is determined by the figure size and the subplot
      margins (r, l; given as a fraction of the figure size), i.e.
      w_ax = w_fig * (r - l)
    """
    if h is None:  # assume square
        h = w
    w /= 2.54  # convert cm to inches
    h /= 2.54
    if not ax:  # get current axes
        ax = plt.gca()
    # get margins
    l = ax.figure.subplotpars.left
    r = ax.figure.subplotpars.right
    t = ax.figure.subplotpars.top
    b = ax.figure.subplotpars.bottom
    # set fig dimensions to produce desired ax dimensions
    figw = float(w) / (r - l)
    figh = float(h) / (t - b)
    ax.figure.set_size_inches(figw, figh)
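
# Example usage (illustrative, with made-up dimensions): resize the current
# axes to 4 cm x 3 cm after plotting.
#
#   fig, ax = plt.subplots()
#   ax.plot([0, 1], [0, 1])
#   set_plotsize(w=4, h=3, ax=ax)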

def clip_axes_to_ticks(ax=None, spines=['left', 'bottom'], ext={}):
    """
    Clip the axis lines to end at the minimum and maximum tick values.

    Parameters
    ----------
    ax : matplotlib.axes
        Axes to modify; if None, the output of plt.gca() will be used.
    spines : list
        Spines to keep and clip; spines not included in this list will be
        removed. Valid values include 'left', 'bottom', 'right', 'top'.
    ext : dict
        For each spine in ext.keys() ('left', 'bottom', 'right', 'top'),
        the axis line will be extended beyond the last tick by the value
        specified, e.g. {'left': [0.1, 0.2]} will result in an axis line
        that extends 0.1 units beyond the bottom tick and 0.2 units beyond
        the top tick.
    """
    if ax is None:
        ax = plt.gca()
    spines2ax = {
        'left': ax.yaxis,
        'top': ax.xaxis,
        'right': ax.yaxis,
        'bottom': ax.xaxis
    }
    all_spines = ['left', 'bottom', 'right', 'top']
    for spine in spines:
        low = min(spines2ax[spine].get_majorticklocs())
        high = max(spines2ax[spine].get_majorticklocs())
        if spine in ext.keys():
            low += ext[spine][0]
            high += ext[spine][1]
        ax.spines[spine].set_bounds(low, high)
    for spine in [spine for spine in all_spines if spine not in spines]:
        ax.spines[spine].set_visible(False)

def p2stars(p):
    """Convert a p-value into a significance-star string."""
    if p <= 0.0001:
        return '***'
    elif p <= 0.001:
        return '**'
    elif p <= 0.05:
        return '*'
    else:
        return ''

def violin_plot(dists, colors, ax=None, logscale=False):
    if type(colors) is list:
        assert len(colors) == len(dists)
    if ax is None:
        fig, ax = plt.subplots()
    violins = ax.violinplot(dists, showmedians=True, showextrema=False)
    for violin, color in zip(violins['bodies'], colors):
        violin.set_facecolor('none')
        violin.set_edgecolor(color)
        violin.set_alpha(1)
        violin.set_linewidth(2)
    violins['cmedians'].set_color('black')
    for pos, dist in enumerate(dists):
        median = np.median(dist)
        text = f'{median:.2f}'
        if logscale:
            text = f'{10 ** median:.2f}'
        ax.text(pos + 1.4, median, text, va='center', ha='center', rotation=-90, fontsize=LABELFONTSIZE)
    ax.set_xticks(np.arange(len(dists)) + 1)
    ax.tick_params(bottom=False)
    return ax

def pupil_area_rate_heatmap(df, cmap='gray', max_rate='high', example=None):
    """
    Plot a heatmap of event rates (tonic spikes or bursts) where each row is a
    neuron and each column is a pupil size bin.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with neurons in the rows and mean firing rates for each pupil
        size bin in a column called 'area_means'.
    cmap : str or Matplotlib colormap object
    max_rate : str
        Is the max event rate expected to be at 'high' or 'low' pupil sizes?
    example : dict
        If not None, the MSEU key passed will be used to highlight an example neuron.
    """
    fig = plt.figure()
    # Find out which pupil size bin has the highest firing rate
    df['tuning_max'] = df['area_means'].apply(np.argmax)
    # Min-max normalize firing rates for each neuron
    df['tuning_norm'] = df['area_means'].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
    # Get heatmap for neurons with significant rate differences across pupil size bins
    df_sig = df.query('area_p <= 0.05').sort_values('tuning_max')
    heatmap_sig = np.vstack(df_sig['tuning_norm'])
    n_sig = len(df_sig)  # number of neurons with significant differences
    # Make axis with size proportional to the fraction of significant neurons
    n_units = len(df)  # total number of neurons
    ax1_height = n_sig / n_units
    ax1 = fig.add_axes([0.1, 0.9 - (0.76 * ax1_height), 0.8, (0.76 * ax1_height)])
    # Plot heatmap for significant neurons
    mat = ax1.matshow(heatmap_sig, cmap=cmap, aspect='auto')
    # Scatter plot marking pupil size bin with maximum
    ax1.scatter(df_sig['tuning_max'], np.arange(len(df_sig)) - 0.5, s=0.5, color='black', zorder=3)
    # Format axes
    ax1.set_xticks([])
    yticks = np.insert(np.arange(0, n_sig, 20), -1, n_sig)  # ticks every 20 neurons, and final count
    ax1.set_yticks(yticks - 0.5)
    ax1.set_yticklabels(yticks)
    ax1.set_ylabel('Neurons')
    # Make a colorbar
    cbar = fig.colorbar(mat, ax=ax1, ticks=[0, 1], location='top', shrink=0.75)
    cbar.ax.set_xticklabels(['min', 'max'])
    cbar.ax.set_xlabel('Spikes', labelpad=-5)
    # Draw dashed line and label for neurons considered to have 'monotonic' modulation profiles
    # (highest event rate at one of the pupil size extremes)
    if max_rate == 'high':
        n_mon = len(df_sig.query('tuning_max >= 9'))
        print("Monotonic increasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
        ax1.axvline(8.5, lw=2, ls='--', color='white')
        ax1.set_title(r'Monotonic$\rightarrow$', fontsize=LABELFONTSIZE, loc='right', pad=0)
    elif max_rate == 'low':
        n_mon = len(df_sig.query('tuning_max <= 1'))
        print("Monotonic decreasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
        ax1.axvline(0.5, lw=2, ls='--', color='white')
        ax1.set_title(r'$\leftarrow$Monotonic', fontsize=LABELFONTSIZE, loc='left', pad=0)
    # Get heatmap for neurons without significant rate differences across pupil sizes
    df_ns = df.query('area_p > 0.05').sort_values('tuning_max')
    heatmap_ns = np.vstack(df_ns['tuning_norm'])
    # Make axis with size proportional to the fraction of non-significant neurons
    n_ns = len(df_ns)
    ax2_height = n_ns / n_units
    ax2 = fig.add_axes([0.1, 0.1, 0.8, (0.76 * ax2_height)])
    # Plot heatmap for neurons without significant rate differences across pupil sizes
    mat = ax2.matshow(heatmap_ns, cmap='Greys', vmax=2, aspect='auto')
    ax2.scatter(df_ns['tuning_max'], np.arange(n_ns) - 0.5, s=0.5, color='black', zorder=3)
    # Format axes
    ax2.xaxis.set_ticks_position('bottom')
    ax2.set_xticks([-0.5, 4.5, 9.5])
    ax2.set_xticklabels([0, 0.5, 1])  # x-axis ticks mark percentiles of pupil size range
    ax2.set_xlim(right=9.5)
    ax2.set_xlabel('Pupil size (norm.)')
    yticks = np.insert(np.arange(0, n_ns, 20), -1, n_ns)  # ticks every 20 neurons, and final count
    ax2.set_yticks(yticks - 0.5)
    ax2.set_yticklabels(yticks)
    # Highlight example neuron by circling its scatter dot
    if example is not None:
        try:  # first check if the example neuron is among the significant neurons
            is_ex = df_sig.index == tuple([v for k, v in example.items()])
            assert is_ex.any()
            ex_max = df_sig['tuning_max'][is_ex]
            ax1.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
        except AssertionError:
            is_ex = df_ns.index == tuple([v for k, v in example.items()])
            ex_max = df_ns['tuning_max'][is_ex]
            ax2.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
    return fig

def cumulative_histogram(data, bins, color='C0', ax=None):
    """Convenience function; cleaner looking plot than plt.hist(..., cumulative=True)."""
    if ax is None:
        fig, ax = plt.subplots()
    weights = np.ones_like(data) / len(data)
    counts, _ = np.histogram(data, bins=bins, weights=weights)
    ax.plot(bins[:-1], np.cumsum(counts), color=color)
    return ax

def cumulative_hist(x, bins, density=True, ax=None, **kwargs):
    weights = np.ones_like(x)
    if density:
        weights = weights / len(x)
    counts, _ = np.histogram(x, bins=bins, weights=weights)
    if ax is None:
        fig, ax = plt.subplots()
    xs = np.insert(bins, np.arange(len(bins) - 1), bins[:-1])
    ys = np.insert(np.insert(np.cumsum(counts), np.arange(len(counts)), np.cumsum(counts)), 0, 0)
    ax.plot(xs, ys, lw=2, **kwargs)
    ax.set_xticks(bins + 0.5)
    ax.set_yticks([0, 0.5, 1])
    return ax, counts
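
# Example usage (illustrative sketch with synthetic data): step-style
# cumulative histogram of 100 random values over unit-width bins.
#
#   rng = np.random.default_rng(0)
#   ax, counts = cumulative_hist(rng.normal(size=100), bins=np.arange(-3, 4))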

def phase_coupling_scatter(df, ax=None):
    """Phase-frequency scatter plot for phase coupling."""
    if ax is None:
        fig, ax = plt.subplots()
    for event in ['tonicspk', 'burst']:
        df_sig = df.query(f'{event}_sig == True')
        ax.scatter(np.log10(df_sig['freq']), df_sig[f'{event}_phase'], ec=COLORS[event], fc='none', lw=0.5, s=3)
    n_sig = max([len(df.query(f'{event}_sig == True').groupby(['m', 's', 'e', 'u'])) for event in ['tonicspk', 'burst']])
    ax.set_xticks(FREQUENCYTICKS)
    ax.set_xticklabels(FREQUENCYTICKLABELS)
    ax.set_xlim(left=-3.075)
    ax.set_xlabel("Inverse timescale (s$^{-1}$)")
    ax.set_yticks(PHASETICKS)
    ax.set_yticklabels(PHASETICKLABELS)
    ax.set_ylim([-np.pi - 0.15, np.pi + 0.15])
    ax.set_ylabel("Preferred phase")
    return ax

def plot_circhist(angles, ax=None, bins=np.linspace(0, 2 * np.pi, 17), density=True, **kwargs):
    """Plot a circular histogram."""
    if ax is None:
        fig, ax = plt.subplots(subplot_kw={'polar': True})
    weights = np.ones_like(angles)
    if density:
        weights /= len(angles)
    counts, bins = np.histogram(angles, bins=bins, weights=weights)
    xs = bins + (np.pi / (len(bins) - 1))
    ys = np.append(counts, counts[0])
    ax.plot(xs, ys, **kwargs)
    ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
    ax.set_xticklabels(['0', '\u03C0/2', '\u03C0', '3\u03C0/2'])
    ax.tick_params(axis='x', pad=-5)
    return ax, counts

def coupling_strength_line_plot(df, agg=np.mean, err=sem, logscale=True, ax=None, **kwargs):
    """
    Line plot showing average burst and tonic spike coupling strengths and SE
    across timescale bins.
    """
    if ax is None:
        fig, ax = plt.subplots()
    for event in ['burst', 'tonicspk']:
        df_sig = df.query(f'({event}_sig == True) & ({event}_strength > 0)').copy()
        strengths = sort_data(df_sig[f'{event}_strength'], df_sig['freq'], bins=FREQUENCYBINS)
        ys = np.array([agg(s) for s in strengths])
        yerr = np.array([err(s) for s in strengths])
        if not logscale:
            ax.plot(FREQUENCYXPOS, ys, color=COLORS[event], **kwargs)
            ax.plot(FREQUENCYXPOS, ys + yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
            ax.plot(FREQUENCYXPOS, ys - yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
        else:
            ax.plot(FREQUENCYXPOS, np.log10(ys), color=COLORS[event], **kwargs)
            ax.plot(FREQUENCYXPOS, np.log10(ys + yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
            ax.plot(FREQUENCYXPOS, np.log10(ys - yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
    ax.set_xticks(FREQUENCYTICKS)
    ax.set_xticklabels(FREQUENCYTICKLABELS)
    ax.set_xlim(left=-3.1)
    ax.set_xlabel('Timescale (Hz)')
    ax.set_ylabel('Coupling strength')
    return ax

def compare_coupling_strength_preferences(df, conds, spk_type):
    """
    Make a pie chart showing the proportion of neurons that either lose their
    coupling, retain the IMF to which they most strongly couple, or shift the
    IMF to which they most strongly couple across two conditions. Also
    highlight the proportion of neurons in the latter two categories that
    increase their coupling in the second condition compared to the first.

    Parameters
    ----------
    df : pandas.DataFrame
        Merged dataframe combining the output of two phase coupling analyses
        performed on the same set of neurons/IMFs across two conditions.
        Columns must have suffixes indicating the condition they come from.
    conds : list
        List of strings indicating suffixes for each of the two conditions.
    spk_type : str
        String indicating for which spike type coupling is to be compared.
    """
    # Start by restricting to neuron-IMF pairs with significant coupling in the first condition
    df_sig = df.query(f'({spk_type}_sig_{conds[0]} == True)')

    # Define some helper functions to be applied to each neuron
    def _pref_timescale(unit):
        # Find out if the IMF with the strongest coupling is the same across conditions
        pref_cond0 = np.argmax(unit[f'{spk_type}_strength_{conds[0]}'])
        pref_cond1 = np.argmax(unit[f'{spk_type}_strength_{conds[1]}'])
        return pref_cond0 != pref_cond1

    def _max_strength(unit):
        # Get the difference in strength for the strongest IMF across conditions
        strength_cond0 = unit[f'{spk_type}_strength_{conds[0]}'].max()
        strength_cond1 = unit[f'{spk_type}_strength_{conds[1]}'].max()
        return strength_cond1 - strength_cond0

    # Group data by neuron
    units = df_sig.groupby(['m', 's', 'e', 'u'])
    # Does the neuron have any significant coupling?
    sig = (units.apply(lambda x: x[f'{spk_type}_sig_{conds[1]}'].sum() > 0))
    # Does the strongest timescale shift across conditions?
    shift = units.apply(_pref_timescale)
    # What is the difference in max strength between conditions?
    strength = units.apply(_max_strength)
    # Get slices of the pie
    no_coupling = (~sig).sum()
    imf_same_down = (sig & ~shift & (strength < 0)).sum()  # max strength lower in second condition
    imf_same_up = (sig & ~shift & (strength >= 0)).sum()   # max strength higher in second condition
    imf_shift_down = (sig & shift & (strength < 0)).sum()  # change in IMF with strongest coupling
    imf_shift_up = (sig & shift & (strength >= 0)).sum()
    total = np.sum([no_coupling, imf_shift_down, imf_shift_up, imf_same_up, imf_same_down])
    assert total == len(units)  # check categories are mutually exclusive
    # Print proportions in each category
    print(f"No coupling: {(no_coupling / total * 100):.1f}%")
    print(f"Timescale same: {(np.sum([imf_same_up, imf_same_down]) / total * 100):.1f}%")
    print(f"Timescale shift: {(np.sum([imf_shift_up, imf_shift_down]) / total * 100):.1f}%")
    print(f"Increased coupling: {(np.sum([imf_shift_up, imf_same_up]) / total * 100):.1f}%\n")
    # Make pie chart
    fig, ax = plt.subplots()
    slices, text = ax.pie(
        [no_coupling, imf_shift_down, imf_shift_up, imf_same_up, imf_same_down],
        explode=[0., 0., 0.1, 0.1, 0.],  # pop out slices representing max strength increase
        colors=['gray', COLORS[spk_type], COLORS[spk_type], COLORS[spk_type], COLORS[spk_type]]
    )
    slices[1].set_alpha(0.5)  # faded color for slices representing strongest IMF shift
    slices[2].set_alpha(0.5)
    plt.legend(slices[::2], ['No coupling', 'Timescale shift', 'Timescale retained'], loc=(1, 0.75), frameon=False)
    ax.xaxis.set_visible(False)
    return ax


## Util

def zero_runs(a):
    """
    Return an array with shape (m, 2), where m is the number of "runs" of zeros
    in a. The first column is the index of the first 0 in each run, the second
    is the index of the first nonzero element after the run.
    """
    # Create an array that is 1 where a is 0, and pad each end with an extra 0.
    iszero = np.concatenate(([0], np.equal(a, 0).view(np.int8), [0]))
    absdiff = np.abs(np.diff(iszero))
    # Runs start and end where absdiff is 1.
    ranges = np.where(absdiff == 1)[0].reshape(-1, 2)
    return ranges
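
# Example (illustrative): zeros at indices 1-2 and 4-6 give two runs.
#
#   zero_runs(np.array([1, 0, 0, 2, 0, 0, 0, 3]))
#   # -> array([[1, 3],
#   #           [4, 7]])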

def merge_ranges(ranges, dt=1):
    """
    Given a set of ranges [start, stop], return a new set of ranges where all
    overlapping ranges are merged.
    """
    tpts = np.arange(ranges.min(), ranges.max(), dt)  # array of time points
    tc = np.ones_like(tpts)  # time course of ranges
    for t0, t1 in ranges:  # for each range
        i0, i1 = tpts.searchsorted([t0, t1])
        tc[i0:i1] = 0  # set values in range to 0
    new_ranges = zero_runs(tc)  # get indices of continuous stretches of zeros
    if new_ranges[-1, -1] == len(tpts):  # fix end-point
        new_ranges[-1, -1] = len(tpts) - 1
    return tpts[new_ranges]
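
# Example (illustrative): overlapping ranges are fused into one; exact endpoints
# are resolved on the dt-spaced time grid, so they may shift by up to one dt.
#
#   merge_ranges(np.array([[0., 5.], [3., 8.], [10., 12.]]), dt=1)
#   # -> two ranges: one covering ~0-8 (the merged pair) and one covering ~10-12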

def continuous_runs(data, max0len=1, min1len=1, min1prop=0):
    """
    Get start and stop indices of stretches of (relatively) continuous data.

    Parameters
    ----------
    data : ndarray
        1D boolean array.
    max0len : int
        Maximum length (in data pts) of False stretches to ignore.
    min1len : int
        Minimum length (in data pts) of True runs to keep.
    min1prop : int
        Minimum proportion of True data in the run necessary for it
        to be considered.

    Returns
    -------
    out : ndarray
        (m, 2) array of start and stop indices, where m is the number of runs
        of continuous True values.
    """
    # get ranges of True values
    one_ranges = zero_runs(~data)
    if len(one_ranges) == 0:
        return np.array([[]])
    # merge ranges that are separated by fewer than max0len False values
    one_ranges[:, 1] += (max0len - 1)
    one_ranges = merge_ranges(one_ranges)
    # return indices to normal
    one_ranges[:, 1] -= (max0len - 1)
    one_ranges[-1, -1] += 1
    # remove ranges that are too short
    lengths = np.diff(one_ranges, axis=1).ravel()
    one_ranges = one_ranges[lengths >= min1len]
    # remove ranges that don't have a sufficient proportion of True values
    prop = np.array([data[i0:i1].sum() / (i1 - i0) for (i0, i1) in one_ranges])
    return one_ranges[prop >= min1prop]

def switch_ranges(ranges, dt=1, minval=0, maxval=None):
    """
    Given a set of (start, stop) pairs, return a new set of pairs for values
    outside the given ranges.

    Parameters
    ----------
    ranges : ndarray
        N x 2 array containing start and stop values in the first and second
        columns respectively.
    dt : float
        Step size separating the new ranges from the input ranges.
    minval, maxval : int
        The minimum and maximum possible values; if maxval is None it is
        assumed that the maximum possible value is the maximum value in the
        input ranges.

    Returns
    -------
    out : ndarray
        M x 2 array containing start and stop values of all ranges outside of
        the input ranges.
    """
    if ranges.shape[1] == 0:
        return np.array([[minval, maxval]])
    assert (ranges.ndim == 2) & (ranges.shape[1] == 2), "A two-column array is expected"
    maxval = ranges.max() if maxval is None else maxval
    # get new ranges
    new_ranges = np.zeros_like(ranges)
    new_ranges[:, 0] = ranges[:, 0] - dt  # new stop values
    new_ranges[:, 1] = ranges[:, 1] + dt  # new start values
    # fix boundaries
    new_ranges = new_ranges.ravel()
    if new_ranges[0] >= (minval + dt):  # first new stop within allowed range
        new_ranges = np.concatenate((np.array([minval]), new_ranges))
    else:
        new_ranges = new_ranges[1:]
    if new_ranges[-1] <= (maxval - dt):  # last new start within allowed range
        new_ranges = np.concatenate((new_ranges, np.array([maxval])))
    else:
        new_ranges = new_ranges[:-1]
    return new_ranges.reshape((int(len(new_ranges) / 2), 2))
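
# Example (illustrative): complement of two ranges within [0, 10] on an
# integer grid (dt=1).
#
#   switch_ranges(np.array([[2, 4], [6, 8]]), dt=1, minval=0, maxval=10)
#   # -> array([[ 0,  1],
#   #           [ 5,  5],
#   #           [ 9, 10]])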

def shuffle_bins(x, binwidth=1):
    """
    Randomly shuffle bins of an array.
    """
    # bin start indices
    bins_i0 = np.arange(0, len(x), binwidth)
    # shuffled bins
    np.random.shuffle(bins_i0)
    # concatenate shuffled bins
    shf = np.concatenate([x[i0:(i0 + binwidth)] for i0 in bins_i0])
    return shf

def take_data_in_bouts(series, data, bouts, trange=None, dt=2, dt0=0, dt1=0, concatenate=True, norm=False):
    if series['%s_bouts' % bouts].shape[1] < 1:
        return np.array([])
    header, _ = data.split('_')
    data_in_bouts = []
    for t0, t1 in series['%s_bouts' % bouts]:
        t0 -= dt0
        t1 += dt1
        if trange == 'start':
            t1 = t0 + dt
        elif trange == 'end':
            t0 = t1 - dt
        elif trange == 'middle':
            t0 = t0 + dt
            t1 = t1 - dt
        if t1 <= t0:
            continue
        if t0 < series['%s_tpts' % header].min():
            continue
        if t1 > series['%s_tpts' % header].max():
            continue
        i0, i1 = series['%s_tpts' % header].searchsorted([t0, t1])
        data_in_bout = series[data][i0:i1].copy()
        if norm:
            data_in_bout = data_in_bout / series[data].max()
        data_in_bouts.append(data_in_bout)
    if concatenate:
        data_in_bouts = np.concatenate(data_in_bouts)
    return data_in_bouts

def get_trials(series, stim_id=0, opto=False, multi_stim='warn'):
    if opto:
        opto = np.isin(series['trial_id'], series['opto_trials'])
    else:
        opto = ~np.isin(series['trial_id'], series['opto_trials'])
    if stim_id < 0:
        stim = np.ones_like(series['stim_id']).astype('bool')
    else:
        stim = series['stim_id'] == stim_id
    series['trial_on_times'] = series['trial_on_times'][opto & stim]
    series['trial_off_times'] = series['trial_off_times'][opto & stim]
    return series

def sort_data(data, sort_vals, bins=10):
    """Split data into groups according to binned values of sort_vals."""
    if type(bins) == int:
        nbins = bins
        bin_edges = np.linspace(sort_vals.min(), sort_vals.max(), nbins + 1)
    else:
        nbins = len(bins) - 1
        bin_edges = bins
    digitized_vals = np.digitize(sort_vals, bins=bin_edges).clip(1, nbins)
    return [data[digitized_vals == val] for val in np.arange(nbins) + 1]
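
# Example (illustrative): splitting values into two groups by their sorting variable.
#
#   sort_data(np.arange(10), np.arange(10), bins=2)
#   # -> [array([0, 1, 2, 3, 4]), array([5, 6, 7, 8, 9])]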

def apply_sort_data(series, data_col, sort_col, bins=10):
    return sort_data(series[data_col], series[sort_col], bins)


## Statistics

def get_binned_rates(spk_rates, pupil_area, sort=False, nbins=10):
    # Get bins based on percentiles to eliminate the effect of outliers
    min_area, max_area = np.percentile(pupil_area, [2.5, 97.5])
    #min_area, max_area = pupil_area.min(), pupil_area.max()
    area_bins = np.linspace(min_area, max_area, nbins + 1)
    # Bin pupil area
    binned_area = np.digitize(pupil_area, bins=area_bins).clip(1, nbins) - 1
    # Bin rates according to pupil area
    binned_rates = np.array([spk_rates[binned_area == bin_i] for bin_i in np.arange(nbins)], dtype=object)
    if sort:
        sorted_inds = np.argsort([rates.mean() if len(rates) > 0 else 0 for rates in binned_rates])
        binned_rates = binned_rates[sorted_inds]
        binned_area = np.squeeze([np.where(sorted_inds == area_bin)[0] for area_bin in binned_area])
    return area_bins, binned_area, binned_rates

def rescale(x, method='min_max'):
    # Select the re-scaling method
    if method == 'z_score':
        return (x - x.mean()) / x.std()
    elif method == 'min_max':
        return (x - x.min()) / (x.max() - x.min())
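
# Example (illustrative): min-max rescaling maps the data onto [0, 1].
#
#   rescale(np.array([2., 4., 6.]), method='min_max')
#   # -> array([0. , 0.5, 1. ])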

def correlogram(ts1, ts2=None, tau_max=1, dtau=0.01, return_tpts=False):
    if ts2 is None:  # compute an auto-correlogram
        ts2 = ts1.copy()
        auto = True
    else:
        auto = False
    tau_max = (tau_max // dtau) * dtau
    tbins = np.arange(-tau_max, tau_max + dtau, dtau)
    ccg = np.zeros(len(tbins) - 1)
    for t0 in ts1:
        dts = ts2 - t0
        if auto:
            dts = dts[dts != 0]  # drop the zero-lag self count
        ccg += np.histogram(dts[np.abs(dts) <= tau_max], bins=tbins)[0]
    ccg /= len(ts1)
    if not return_tpts:
        return ccg
    else:
        tpts = tbins[:-1] + dtau / 2
        return tpts, ccg
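
# Example usage (illustrative sketch with synthetic spike times): auto-correlogram
# of a random spike train with 10 ms bins; ccg values are the average number of
# other spikes per lag bin, per reference spike.
#
#   rng = np.random.default_rng(0)
#   spk_times = np.sort(rng.uniform(0, 100, size=500))
#   tpts, acg = correlogram(spk_times, tau_max=0.5, dtau=0.01, return_tpts=True)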

def angle_subtract(a1, a2, period=(2 * np.pi)):
    """
    Compute the pair-wise difference between two sets of angles on a given period.
    """
    return (a1 - a2) % period

def circmean(alpha, w=None, axis=None):
    """
    Compute the mean resultant vector of circular data.

    Parameters
    ----------
    alpha : ndarray
        Array of angles.
    w : ndarray
        Array of weights, must be the same shape as alpha.
    axis : int, None
        Axis across which to compute the mean.

    Returns
    -------
    mrl : ndarray
        Mean resultant vector length.
    theta : ndarray
        Mean resultant vector angle.
    """
    # weights default to ones
    if w is None:
        w = np.ones_like(alpha)
    w[np.isnan(alpha)] = 0
    # compute weighted mean
    mean_vector = np.nansum(w * np.exp(1j * alpha), axis=axis) / w.sum(axis=axis)
    mrl = np.abs(mean_vector)  # length
    theta = np.angle(mean_vector)  # angle
    return mrl, theta
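
# Example (illustrative): two angles a quarter-turn apart give a resultant
# vector pointing between them with length cos(pi/4).
#
#   circmean(np.array([0., np.pi / 2]))
#   # -> (0.707..., 0.785...)   i.e. (mrl, theta ~ pi/4)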

def circmean_angle(alpha, **kwargs):
    return circmean(alpha, **kwargs)[1]

def circhist(angles, n_bins=8, proportion=True, wrap=False):
    bins = np.linspace(-np.pi, np.pi, n_bins + 1, endpoint=True)
    weights = np.ones(len(angles))
    if proportion:
        weights /= len(angles)
    counts, bins = np.histogram(angles, bins=bins, weights=weights)
    if wrap:
        counts = np.concatenate([counts, [counts[0]]])
        bins = np.concatenate([bins, [bins[0]]])
    return counts, bins

def unbiased_variance(data):
    if len(data) <= 1:
        return np.nan
    else:
        return np.var(data) * len(data) / (len(data) - 1)

def se_median(sample, n_resamp=1000):
    """Standard error of the median, estimated by bootstrap resampling."""
    medians = np.full(n_resamp, np.nan)
    for i in range(n_resamp):
        resample = np.random.choice(sample, len(sample), replace=True)
        medians[i] = np.median(resample)
    return np.std(medians)

def coupling_summary(df):
    """
    Print the number of units with significant phase coupling for either spike
    type to at least one IMF, the number of units with significant coupling for
    each spike type, and the mean number of IMFs to which a single unit has
    significant coupling for each spike type.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe output from the phase_tuning.py analysis script.
    """
    # For either spike type
    units = df.groupby(['m', 's', 'e', 'u'])
    # Check each unit for significance
    n_sig = units.apply(lambda x: any(x['tonicspk_sig']) or any(x['burst_sig'])).sum()
    prop_sig = n_sig / len(units)
    print(f"Neurons with significant coupling: {prop_sig:.3f} ({n_sig}/{len(units)})")
    # For each spike type
    for spk_type in ['tonicspk', 'burst']:
        # Remove units where coupling for the spike type was not assessed
        units = df.dropna(subset=f'{spk_type}_p').groupby(['m', 's', 'e', 'u'])
        n_sig = units.apply(lambda x: any(x[f'{spk_type}_sig'])).sum()
        prop_sig = n_sig / len(units)
        print(f"Neurons with significant {spk_type} coupling: {prop_sig:.3f} ({n_sig}/{len(units)})")
        n_imfs = units.apply(lambda x: x[f'{spk_type}_sig'].sum())
        print(f"    Mean number of significantly coupled IMFs per neuron: {n_imfs.mean():.2f} (SD {n_imfs.std():.2f})")

def kl_divergence(p, q):
    """Kullback-Leibler divergence between two discrete distributions (natural log)."""
    return np.sum(np.where(p != 0, p * np.log(p / q), 0))
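
# Example (illustrative): KL divergence between two coin-flip distributions,
# in nats (natural log).
#
#   kl_divergence(np.array([0.5, 0.5]), np.array([0.9, 0.1]))
#   # -> ~0.51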

def match_distributions(x1, x2, x1_bins, x2_bins):
    """
    For two time series, x1 & x2, return indices of sub-sampled time
    periods such that the distribution of x2 is matched across
    bins of x1.
    """
    x1_nbins = len(x1_bins) - 1
    x2_nbins = len(x2_bins) - 1
    # bin x1
    x1_binned = np.digitize(x1, x1_bins).clip(1, x1_nbins) - 1
    # get continuous periods where x1 visits each bin
    x1_ranges = [zero_runs(~np.equal(x1_binned, x1_bin)) for x1_bin in np.arange(x1_nbins)]
    # get mean of x2 for each x1 bin visit
    x2_means = [np.array([np.mean(x2[i0:i1]) for i0, i1 in x1_bin]) for x1_bin in x1_ranges]
    # find minimum common distribution across x1 bins
    x2_counts = np.vstack([np.histogram(means, bins=x2_bins)[0] for means in x2_means])
    x2_mcd = x2_counts.min(axis=0)
    # bin x2 means
    x2_means_binned = [np.digitize(means, bins=x2_bins).clip(1, x2_nbins) - 1 for means in x2_means]
    x2_means_in_bins = [[means[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for means, binned_means in zip(x2_means, x2_means_binned)]
    x1_ranges_in_bins = [[ranges[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for ranges, binned_means in zip(x1_ranges, x2_means_binned)]
    # loop over x2 bins
    matched_ranges = [[] for _ in np.arange(x1_nbins)]
    for x2_bin in np.arange(x2_nbins):
        # find the x1 bin matching the MCD
        seed_x1_bin = np.where(x2_counts[:, x2_bin] == x2_mcd[x2_bin])[0][0]
        assert len(x2_means_in_bins[seed_x1_bin][x2_bin]) == x2_mcd[x2_bin]
        # for each bin visit, find the closest matching mean in the other x1 bins
        target_means = x2_means_in_bins[seed_x1_bin][x2_bin]
        target_ranges = x1_ranges_in_bins[seed_x1_bin][x2_bin]
        for target_mean, target_range in zip(target_means, target_ranges):
            matched_ranges[seed_x1_bin].append(target_range)
            for x1_bin in np.delete(np.arange(x1_nbins), seed_x1_bin):
                matching_ind = np.abs(x2_means_in_bins[x1_bin][x2_bin] - target_mean).argmin()
                matched_ranges[x1_bin].append(x1_ranges_in_bins[x1_bin][x2_bin][matching_ind])
                # delete the matching period
                x2_means_in_bins[x1_bin][x2_bin] = np.delete(x2_means_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
                x1_ranges_in_bins[x1_bin][x2_bin] = np.delete(x1_ranges_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
    return [np.vstack(ranges) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges]


## Signal processing

def normalized_xcorr(a, b, dt=None, ts=None):
    """
    Compute the Pearson r between two arrays at various lags.

    Parameters
    ----------
    a, b : ndarray
        The arrays to correlate.
    dt : float
        The time step between samples in the arrays.
    ts : list
        If not None, only the xcorr between the specified lags will be
        returned.

    Returns
    -------
    xcorr, lags : ndarray
        The cross-correlation and corresponding lags between a and b.
        Positive lags indicate that a is delayed relative to b.
    """
    assert len(a) == len(b)
    n = len(a)
    a_norm = (a - a.mean()) / a.std()
    b_norm = (b - b.mean()) / b.std()
    xcorr = np.correlate(a_norm, b_norm, 'full') / n
    lags = np.arange(-n + 1, n)
    if dt is not None:
        lags = lags * dt
    if ts is not None:
        assert len(ts) == 2
        i0, i1 = lags.searchsorted(ts)
        xcorr = xcorr[i0:i1]
        lags = lags[i0:i1]
    return xcorr, lags
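
# Example usage (illustrative sketch with a synthetic signal): correlating a
# signal with a delayed copy of itself; per the convention above, the peak
# falls at a positive lag when `a` lags behind `b`.
#
#   t = np.arange(0, 10, 0.01)
#   b = np.sin(2 * np.pi * 0.5 * t)
#   a = np.roll(b, 50)                  # a is b delayed by 50 samples
#   xcorr, lags = normalized_xcorr(a, b, dt=0.01)
#   lags[np.argmax(xcorr)]              # ~0.5 s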

def interpolate(y, x_old, x_new, axis=0, fill_value='extrapolate'):
    """
    Use linear interpolation to re-sample 1D data.
    """
    # get interpolation function
    func = interp1d(x_old, y, axis=axis, fill_value=fill_value)
    # get new y-values
    y_interpolated = func(x_new)
    return y_interpolated

def interpolate_and_normalize(y, x_old, x_new):
    """
    Perform linear interpolation and min-max normalization.
    """
    y_new = interpolate(y, x_old, x_new)
    return (y_new - y_new.min()) / (y_new.max() - y_new.min())

def match_signal_length(a, b, a_tpts, b_tpts):
    """
    Given two signals, truncate both to the time range of the shorter one.
    """
    t1 = min(a_tpts.max(), b_tpts.max())
    a1 = a[:a_tpts.searchsorted(t1)]
    a1_tpts = a_tpts[:a_tpts.searchsorted(t1)]
    b1 = b[:b_tpts.searchsorted(t1)]
    b1_tpts = b_tpts[:b_tpts.searchsorted(t1)]
    return a1, b1, a1_tpts, b1_tpts

def times_to_counts(series, columns, t0=None, t1=None, dt=0.25):
    if type(t0) == str:
        t0, t1 = series[t0].min(), series[t0].max()
    else:
        if t0 is None:  # get overlapping time range for all columns
            t0 = max([series[f'{col.split("_")[0]}_times'].min() for col in columns])
        if t1 is None:
            t1 = min([series[f'{col.split("_")[0]}_times'].max() for col in columns])
    # Set time base
    tbins = np.arange(t0, t1, dt)
    tpts = tbins[:-1] + (dt / 2)
    for col in columns:
        header = col.split('_')[0]
        times = series[f'{header}_times']
        counts, _ = np.histogram(times, bins=tbins)
        series[f'{header}_counts'] = counts
        series[f'{header}_tpts'] = tpts
    return series

def resample_timeseries(y, tpts, dt=0.25):
    tpts_new = np.arange(tpts.min(), tpts.max(), dt)
    return tpts_new, interpolate(y, tpts, tpts_new)
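
# Example usage (illustrative sketch with synthetic data): re-sample an
# irregularly sampled trace onto a regular 0.25 s time base.
#
#   rng = np.random.default_rng(0)
#   tpts = np.sort(rng.uniform(0, 60, size=500))
#   y = np.sin(tpts)
#   tpts_new, y_new = resample_timeseries(y, tpts, dt=0.25)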

def _resample_data(series, columns, t0=None, t1=None, dt=0.25):
    if type(t0) == str:
        t0, t1 = series[t0].min(), series[t0].max()
    elif t0 is None:  # get overlapping time range for all columns
        t0 = max([series[f'{col.split("_")[0]}_tpts'].min() for col in columns])
        t1 = min([series[f'{col.split("_")[0]}_tpts'].max() for col in columns])
    # Set new time base
    tbins = np.arange(t0, t1, dt)
    tpts_new = tbins[:-1] + (dt / 2)
    # Interpolate and re-sample each column
    for col in columns:
        header = col.split('_')[0]
        data = series[col]
        tpts = series[f'{header}_tpts']
        series[col] = interpolate(data, tpts, tpts_new)
        series[f'{header}_tpts'] = tpts_new
    return series


## Neural activity

def get_mean_rates(df):
    df['mean_rate'] = df.apply(
        lambda x:
            len(x['spk_times']) / (x['spk_tinfo'][1] - x['spk_tinfo'][0]),
        axis='columns'
    )
    return df

def get_mean_rate_threshold(df, alpha=0.025):
    if 'mean_rate' not in df.columns:
        df = get_mean_rates(df)
    rates = np.log10(df['mean_rate'])
    gmm = GaussianMixture(n_components=2)
    gmm.fit(rates.values[:, np.newaxis])
    (mu, var) = (gmm.means_.max(), gmm.covariances_.squeeze()[gmm.means_.argmax()])
    threshold = mu + norm.ppf(alpha) * np.sqrt(var)
    return threshold

def filter_units(df, threshold):
    if 'mean_rate' not in df.columns:
        df = get_mean_rates(df)
    return df.query(f'mean_rate >= {threshold}')

def apply_get_raster(series, events, spike_type='spk', pre=0, post=1):
    events = series['%s_times' % events]
    spks = series['%s_times' % spike_type]
    raster = np.array([spks[(spks > t0 - pre) & (spks < t0 + post)] - t0 for t0 in events], dtype='object')
    return raster

def get_psth(events, spikes, pre=0, post=1, dt=0.001, bw=0.01, baseline=[]):
    rate_kernel = GaussianKernel(bw * pq.s)
    tpts = np.arange(pre, post, dt)
    psth = np.full((len(events), len(tpts)), np.nan)
    for i, t0 in enumerate(events):
        rel_ts = spikes - t0
        rel_ts = rel_ts[(rel_ts >= pre) & (rel_ts <= post)]
        try:
            rate = instantaneous_rate(
                SpikeTrain(rel_ts, t_start=pre, t_stop=post, units='s'),
                sampling_period=dt * pq.s,
                kernel=rate_kernel
            )
        except Exception:
            continue
        psth[i] = rate.squeeze()
    if baseline:
        b0, b1 = tpts.searchsorted(baseline)
        baseline_rate = psth[:, b0:b1].mean(axis=1)
        psth = (psth.T - baseline_rate).T
    return psth, tpts
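
# Example usage (illustrative sketch with synthetic spike/event times; relies on
# the neo/elephant imports above):
#
#   events = np.array([1.0, 5.0, 9.0])                  # stimulus onset times (s)
#   spikes = np.sort(np.random.uniform(0, 10, 200))     # spike times (s)
#   psth, tpts = get_psth(events, spikes, pre=-0.5, post=1.0, bw=0.02)
#   # psth has one row per event; average over rows for the trial-mean rate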

def apply_get_psth(series, events, spike_type, **kwargs):
    events = series['{}_times'.format(events)]
    spikes = series['{}_times'.format(spike_type)]
    psth, tpts = get_psth(events, spikes, **kwargs)
    return psth

def get_responses(events, data, tpts, pre=0, post=1, baseline=[]):
    dt = np.round(np.diff(tpts).mean(), 3)  # round to nearest ms
    i_pre, i_post = int(pre / dt), int(post / dt)
    responses = np.full((len(events), i_pre + i_post), np.nan)
    for j, t0 in enumerate(events):
        i = tpts.searchsorted(t0)
        i0, i1 = i - i_pre, i + i_post
        if i0 < 0:
            continue
        if i1 > len(data):
            break
        responses[j] = data[i0:i1]
    tpts = np.linspace(pre, post, responses.shape[1])
    if baseline:
        b0, b1 = tpts.searchsorted(baseline)
        baseline_resp = responses[:, b0:b1].mean(axis=1)
        responses = (responses.T - baseline_resp).T
    return responses, tpts

def apply_get_responses(series, events, data, **kwargs):
    events = series[f'{events}_times']
    tpts = series[f'{data.split("_")[0]}_tpts']
    data = series[data]
    responses, tpts = get_responses(events, data, tpts, **kwargs)
    return responses

def rvr(responses):
    """
    Compute the 'response variability ratio': the ratio of the variance of the
    mean response to the mean across-trial variance (computed per timepoint).

    Parameters
    ----------
    responses : ndarray
        2D array of responses where each row is a trial.
    """
    return np.nanmean(responses, axis=0).var() / np.nanvar(responses, axis=0).mean()
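
# Example (illustrative): trials that are mirror images of each other have a
# flat mean response, so rvr is 0; identical, strongly modulated trials have no
# across-trial variance and drive rvr toward infinity.
#
#   unreliable = np.array([[0., 1., 2.], [2., 1., 0.]])
#   rvr(unreliable)   # -> 0.0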