- import numpy as np
- import pandas as pd
- import quantities as pq
- from matplotlib import pyplot as plt
- from scipy.interpolate import interp1d
- from statsmodels.stats.multitest import fdrcorrection
- from scipy.stats import norm, sem
- from sklearn.mixture import GaussianMixture
- from neo import SpikeTrain
- from elephant.statistics import instantaneous_rate
- from elephant.kernels import GaussianKernel
- from parameters import *
- ## Data handling
- def df2keys(df):
- """Return MSE keys for all entries in a DataFrame."""
- df = df.reset_index()
- keys = [key for key in df.columns if key in ['m', 's', 'e', 'u']]
- return [{key: val for key, val in zip(keys, vals)} for vals in df[keys].values]
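- # Illustrative example (hypothetical DataFrame with MSEU index levels), showing the list of
- # key dicts that df2keys returns:
- # >>> df = pd.DataFrame({'x': [1, 2]}, index=pd.MultiIndex.from_tuples(
- # ...     [('m1', 1, 1, 3), ('m1', 1, 1, 7)], names=['m', 's', 'e', 'u']))
- # >>> df2keys(df)
- # [{'m': 'm1', 's': 1, 'e': 1, 'u': 3}, {'m': 'm1', 's': 1, 'e': 1, 'u': 7}]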
- def key2idx(key):
- """Return DataFrame index tuple for the given key."""
- return tuple([val for k, val in key.items()])
- def df_metadata(df):
- """Print metadata for a DataFrame."""
- df = df.reset_index()
- print("No. mice: {}".format(len(df.groupby('m'))))
- print("No. series: {}".format(len(df.groupby(['m', 's']))))
- print("No. experiments: {}".format(len(df.groupby(['m', 's', 'e']))))
- if 'u' in df.columns:
- print("No. units: {}".format(len(df.groupby(['m', 's', 'e', 'u']))))
- for idx, sub_df in df.groupby(['m', 's', 'e']):
- print("\n{} s{:02d} e{:02d}".format(idx[0], idx[1], idx[2]))
- print(" {} units".format(len(sub_df)))
- def load_data(data, conditions, **kwargs):
- """
- Load data of a specified type and recording region, pooling across all
- requested conditions.
- """
- dfs = []
- for condition in conditions:
- filename = '{}_{}'.format(data, condition)
- for kw, arg in kwargs.items():
- filename = filename + '_{}'.format(arg)
- filename = filename + '.pkl'
- print("Loading: ", filename)
- df = pd.read_pickle(DATAPATH + filename)
- if 'condition' not in df.columns:
- df['condition'] = condition
- dfs.append(df)
- return pd.concat(dfs, axis='index')
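- # Illustrative usage (hypothetical data type, conditions, and keyword argument): with DATAPATH
- # defined in parameters.py, load_data('coupling', ['sit', 'run'], region='LGN') would read
- # 'coupling_sit_LGN.pkl' and 'coupling_run_LGN.pkl' and pool them into a single DataFrame.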
- ## Plotting utils
- def set_plotsize(w, h=None, ax=None):
- """
- Set the size of a matplotlib axes object in cm.
- Parameters
- ----------
- w, h : float
- Desired width and height of the plot in cm; if h is None, the axes will be
- square.
- ax : matplotlib.axes
- Axes to resize, if None the output of plt.gca() will be re-sized.
- Notes
- -----
- - Use after subplots_adjust (if adjustment is needed)
- - Matplotlib axis size is determined by the figure size and the subplot
- margins (r, l; given as a fraction of the figure size), i.e.
- w_ax = w_fig * (r - l)
- """
- if h is None: # assume square
- h = w
- w /= 2.54 # convert cm to inches
- h /= 2.54
- if not ax: # get current axes
- ax = plt.gca()
- # get margins
- l = ax.figure.subplotpars.left
- r = ax.figure.subplotpars.right
- t = ax.figure.subplotpars.top
- b = ax.figure.subplotpars.bottom
- # set fig dimensions to produce desired ax dimensions
- figw = float(w)/(r-l)
- figh = float(h)/(t-b)
- ax.figure.set_size_inches(figw, figh)
- def clip_axes_to_ticks(ax=None, spines=['left', 'bottom'], ext={}):
- """
- Clip the axis lines to end at the minimum and maximum tick values.
- Parameters
- ----------
- ax : matplotlib.axes
- Axes to clip; if None, the output of plt.gca() will be used.
- spines : list
- Spines to keep and clip; spines not included in this list will be removed.
- Valid values are 'left', 'bottom', 'right', 'top'.
- ext : dict
- For each axis in ext.keys() ('left', 'bottom', 'right', 'top'),
- the axis line will be extended beyond the last tick by the value
- specified, e.g. {'left': [0.1, 0.2]} will result in an axis line
- that extends 0.1 units beyond the bottom tick and 0.2 units beyond
- the top tick.
- """
- if ax is None:
- ax = plt.gca()
- spines2ax = {
- 'left': ax.yaxis,
- 'top': ax.xaxis,
- 'right': ax.yaxis,
- 'bottom': ax.xaxis
- }
- all_spines = ['left', 'bottom', 'right', 'top']
- for spine in spines:
- low = min(spines2ax[spine].get_majorticklocs())
- high = max(spines2ax[spine].get_majorticklocs())
- if spine in ext.keys():
- low += ext[spine][0]
- high += ext[spine][1]
- ax.spines[spine].set_bounds(low, high)
- for spine in [spine for spine in all_spines if spine not in spines]:
- ax.spines[spine].set_visible(False)
- def p2stars(p):
- if p <= 0.0001:
- return '***'
- elif p <= 0.001:
- return '**'
- elif p <= 0.05:
- return '*'
- else:
- return ''
- def violin_plot(dists, colors, ax=None, logscale=False):
- if type(colors) is list:
- assert len(colors) == len(dists)
- if ax is None:
- fig, ax = plt.subplots()
- violins = ax.violinplot(dists, showmedians=True, showextrema=False)
- for violin, color in zip(violins['bodies'], colors):
- violin.set_facecolor('none')
- violin.set_edgecolor(color)
- violin.set_alpha(1)
- violin.set_linewidth(2)
- violins['cmedians'].set_color('black')
- for pos, dist in enumerate(dists):
- median = np.median(dist)
- text = f'{median:.2f}'
- if logscale:
- text = f'{10 ** median:.2f}'
- ax.text(pos + 1.4, median, text, va='center', ha='center', rotation=-90, fontsize=LABELFONTSIZE)
- ax.set_xticks(np.arange(len(dists)) + 1)
- ax.tick_params(bottom=False)
- return ax
- def pupil_area_rate_heatmap(df, cmap='gray', max_rate='high', example=None):
- """
- Plot a heatmap of event rates (tonic spikes or bursts) where each row is a neuron and each column is a pupil size bin.
- Parameters
- ----------
- df : pandas.DataFrame
- Dataframe with neurons in the rows and mean firing rates for each pupil size bin in a column called 'area_means'.
-
- cmap : str or Matplotlib colormap object
- max_rate : str
- Is the max event rate expected to be at 'high' or 'low' pupil sizes?
- example : dict
- If not None, the MSEU key passed will be used to highlight the example neuron.
- """
- fig = plt.figure()
-
- # Find out which pupil size bin has the highest firing rate
- df['tuning_max'] = df['area_means'].apply(np.argmax)
- # Min-max normalize firing rates for each neuron
- df['tuning_norm'] = df['area_means'].apply(lambda x: (x - x.min()) / (x.max() - x.min()))
- # Get heatmap for neurons with significant rate differences across pupil size bins
- df_sig = df.query('area_p <= 0.05').sort_values('tuning_max')
- heatmap_sig = np.row_stack(df_sig['tuning_norm'])
- n_sig = len(df_sig) # number of neurons with significant differences
- # Make axis with size proportional to the fraction of significant neurons
- n_units = len(df) # total number of neurons
- ax1_height = n_sig / n_units
- ax1 = fig.add_axes([0.1, 0.9 - (0.76 * ax1_height), 0.8, (0.76 * ax1_height)])
- # Plot heatmap for significant neurons
- mat = ax1.matshow(heatmap_sig, cmap=cmap, aspect='auto')
- # Scatter plot marking pupil size bin with maximum
- ax1.scatter(df_sig['tuning_max'], np.arange(len(df_sig)) - 0.5, s=0.5, color='black', zorder=3)
- # Format axes
- ax1.set_xticks([])
- yticks = np.insert(np.arange(0, n_sig, 20), -1, n_sig) # ticks every 20 neurons, and final count
- ax1.set_yticks(yticks - 0.5)
- ax1.set_yticklabels(yticks)
- ax1.set_ylabel('Neurons')
- # Make a colorbar
- cbar = fig.colorbar(mat, ax=ax1, ticks=[0, 1], location='top', shrink=0.75)
- cbar.ax.set_xticklabels(['min', 'max'])
- cbar.ax.set_xlabel('Spikes', labelpad=-5)
- # Draw dashed line and label for neurons considered to have 'monotonic' modulation profiles
- # (highest event rate at one of the pupil size extremes)
- if max_rate == 'high':
- n_mon = len(df_sig.query('tuning_max >= 9'))
- print("Monotonic increasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
- ax1.axvline(8.5, lw=2, ls='--', color='white')
- ax1.set_title(r'Monotonic$\rightarrow$', fontsize=LABELFONTSIZE, loc='right', pad=0)
- elif max_rate == 'low':
- n_mon = len(df_sig.query('tuning_max <= 1'))
- print("Monotonic decreasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
- ax1.axvline(0.5, lw=2, ls='--', color='white')
- ax1.set_title(r'$\leftarrow$Monotonic', fontsize=LABELFONTSIZE, loc='left', pad=0)
- # Get heatmap for neurons without significant rate differences across pupil sizes
- df_ns = df.query('area_p > 0.05').sort_values('tuning_max')
- heatmap_ns = np.row_stack(df_ns['tuning_norm'])
-
- # Make axis with size proportional to the fraction of non-significant neurons
- n_ns = len(df_ns)
- ax2_height = n_ns / n_units
- ax2 = fig.add_axes([0.1, 0.1, 0.8, (0.76 * ax2_height)])
- # Plot heatmap for neurons without significant rate differences across pupil sizes
- mat = ax2.matshow(heatmap_ns, cmap='Greys', vmax=2, aspect='auto')
- ax2.scatter(df_ns['tuning_max'], np.arange(n_ns) - 0.5, s=0.5, color='black', zorder=3)
- # Format axes
- ax2.xaxis.set_ticks_position('bottom')
- ax2.set_xticks([-0.5, 4.5, 9.5])
- ax2.set_xticklabels([0, 0.5, 1]) # x-axis ticks mark percentiles of pupil size range
- ax2.set_xlim(right=9.5)
- ax2.set_xlabel('Pupil size (norm.)')
- yticks = np.insert(np.arange(0, n_ns, 20), -1, n_ns) # ticks every 20 neurons, and final count
- ax2.set_yticks(yticks - 0.5)
- ax2.set_yticklabels(yticks)
- # Highlight example neuron by circling its scatter dot
- if example is not None:
- try: # first check if example neuron is among significant neurons
- is_ex = df_sig.index == tuple([v for k, v in example.items()])
- assert is_ex.any()
- ex_max = df_sig['tuning_max'][is_ex]
- ax1.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
- except AssertionError: # example neuron is not among the significant neurons
- is_ex = df_ns.index == tuple([v for k, v in example.items()])
- ex_max = df_ns['tuning_max'][is_ex]
- ax2.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
- return fig
- def cumulative_histogram(data, bins, color='C0', ax=None):
- """Convenience function, cleaner looking plot than plt.hist(..., cumulative=True)."""
- if ax is None:
- fig, ax = plt.subplots()
- weights = np.ones_like(data) / len(data)
- counts, _ = np.histogram(data, bins=bins, weights=weights)
- ax.plot(bins[:-1], np.cumsum(counts), color=color)
- return ax
- def cumulative_hist(x, bins, density=True, ax=None, **kwargs):
- weights = np.ones_like(x)
- if density:
- weights = weights / len(x)
- counts, _ = np.histogram(x, bins=bins, weights=weights)
- if ax is None:
- fig, ax = plt.subplots()
- xs = np.insert(bins, np.arange(len(bins) - 1), bins[:-1])
- ys = np.insert(np.insert(np.cumsum(counts), np.arange(len(counts)), np.cumsum(counts)), 0, 0)
- ax.plot(xs, ys, lw=2, **kwargs)
- ax.set_xticks(bins + 0.5)
- ax.set_yticks([0, 0.5, 1])
- return ax, counts
- def phase_coupling_scatter(df, ax=None):
- """Phase-frequency scatter plot for phase coupling."""
- if ax is None:
- fig, ax = plt.subplots()
- for event in ['tonicspk', 'burst']:
- df_sig = df.query(f'{event}_sig == True')
- ax.scatter(np.log10(df_sig['freq']), df_sig[f'{event}_phase'], ec=COLORS[event], fc='none', lw=0.5, s=3)
- n_sig = max([len(df.query(f'{event}_sig == True').groupby(['m', 's', 'e', 'u'])) for event in ['tonicspk', 'burst']])
-
- ax.set_xticks(FREQUENCYTICKS)
- ax.set_xticklabels(FREQUENCYTICKLABELS)
- ax.set_xlim(left=-3.075)
- ax.set_xlabel("Inverse timescale (s$^{-1}$)")
-
- ax.set_yticks(PHASETICKS)
- ax.set_yticklabels(PHASETICKLABELS)
- ax.set_ylim([-np.pi - 0.15, np.pi + 0.15])
- ax.set_ylabel("Preferred phase")
- return ax
- def plot_circhist(angles, ax=None, bins=np.linspace(0, 2 * np.pi, 17), density=True, **kwargs):
- """Plot a circular histogram."""
- if ax is None:
- fig, ax = plt.subplots(subplot_kw={'polar':True})
- weights = np.ones_like(angles)
- if density:
- weights /= len(angles)
- counts, bins = np.histogram(angles, bins=bins, weights=weights)
- xs = bins + (np.pi / (len(bins) - 1))
- ys = np.append(counts, counts[0])
- ax.plot(xs, ys, **kwargs)
- ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
- ax.set_xticklabels(['0', '\u03C0/2', '\u03C0', '3\u03C0/2'])
- ax.tick_params(axis='x', pad=-5)
- return ax, counts
- def coupling_strength_line_plot(df, agg=np.mean, err=sem, logscale=True, ax=None, **kwargs):
- """
- Line plot showing average burst and tonic spike coupling strengths and SE across timescale bins.
- """
- if ax is None:
- fig, ax = plt.subplots()
- for event in ['burst', 'tonicspk']:
- df_sig = df.query(f'({event}_sig == True) & ({event}_strength > 0)').copy()
- strengths = sort_data(df_sig[f'{event}_strength'], df_sig['freq'], bins=FREQUENCYBINS)
- ys = np.array([agg(s) for s in strengths])
- yerr = np.array([err(s) for s in strengths])
- if not logscale:
- ax.plot(FREQUENCYXPOS, ys, color=COLORS[event], **kwargs)
- ax.plot(FREQUENCYXPOS, ys + yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
- ax.plot(FREQUENCYXPOS, ys - yerr, color=COLORS[event], lw=0.5, ls='--', **kwargs)
- else:
- ax.plot(FREQUENCYXPOS, np.log10(ys), color=COLORS[event], **kwargs)
- ax.plot(FREQUENCYXPOS, np.log10(ys + yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
- ax.plot(FREQUENCYXPOS, np.log10(ys - yerr), color=COLORS[event], lw=0.5, ls='--', **kwargs)
- ax.set_xticks(FREQUENCYTICKS)
- ax.set_xticklabels(FREQUENCYTICKLABELS)
- ax.set_xlim(left=-3.1)
- ax.set_xlabel('Timescale (Hz)')
- ax.set_ylabel('Coupling strength')
- return ax
-
- ## Util
- def zero_runs(a):
- """
- Return an array with shape (m, 2), where m is the number of "runs" of zeros
- in a. The first column is the index of the first 0 in each run, the second
- is the index of the first nonzero element after the run.
- """
- # Create an array that is 1 where a is 0, and pad each end with an extra 0.
- iszero = np.concatenate(([0], np.equal(a, 0).view(np.int8), [0]))
- absdiff = np.abs(np.diff(iszero))
- # Runs start and end where absdiff is 1.
- ranges = np.where(absdiff == 1)[0].reshape(-1, 2)
- return ranges
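- # Illustrative example, worked by hand from the implementation above:
- # >>> zero_runs(np.array([1, 1, 0, 0, 0, 2, 0, 0]))
- # array([[2, 5],
- #        [6, 8]])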
- def merge_ranges(ranges, dt=1):
- """
- Given a set of ranges [start, stop], return new set of ranges where all
- overlapping ranges are merged.
- """
- tpts = np.arange(ranges.min(), ranges.max(), dt) # array of time points
- tc = np.ones_like(tpts) # time course of ranges
- for t0, t1 in ranges: # for each range
- i0, i1 = tpts.searchsorted([t0, t1])
- tc[i0:i1] = 0 # set values in range to 0
- new_ranges = zero_runs(tc) # get indices of continuous stretches of zero
- if new_ranges[-1, -1] == len(tpts): # fix end-point
- new_ranges[-1, -1] = len(tpts) - 1
- return tpts[new_ranges]
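- # Illustrative example: the first two ranges overlap and are merged; note that the final stop
- # value is clipped to the last point of the time grid defined by dt:
- # >>> merge_ranges(np.array([[0, 5], [3, 8], [10, 12]]), dt=1)
- # array([[ 0,  8],
- #        [10, 11]])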
- def continuous_runs(data, max0len=1, min1len=1, min1prop=0):
- """
- Get start and stop indices of stretches of (relatively) continuous data.
- Parameters
- ----------
- data : ndarray
- 1D boolean array
- max0len : int
- maximum length (in data pts) of False stretches to ignore
- min1len : int
- minimum length (in data pts) of True runs to keep
- min1prop : float
- minimum proportion of True data in the run necessary for it
- to be considered
- Returns
- -------
- out : ndarray
- (m, 2) array of start and stop indices, where m is the number of runs of
- continuous True values
- """
- # get ranges of True values
- one_ranges = zero_runs(~data)
- if len(one_ranges) == 0:
- return np.array([[]])
- # merge ranges that are separated by fewer than max0len False values
- one_ranges[:, 1] += (max0len - 1)
- one_ranges = merge_ranges(one_ranges)
- # return indices to normal
- one_ranges[:, 1] -= (max0len - 1)
- one_ranges[-1, -1] += 1
- # remove ranges that are too short
- lengths = np.diff(one_ranges, axis=1).ravel()
- one_ranges = one_ranges[lengths >= min1len]
- # remove ranges that don't have sufficient proportion True
- prop = np.array([data[i0:i1].sum() / (i1 - i0) for (i0, i1) in one_ranges])
- return one_ranges[prop >= min1prop]
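- # Illustrative example: with max0len=2, single-sample gaps in the boolean trace are bridged
- # and the remaining runs of True values are returned as (start, stop) index pairs:
- # >>> moving = np.array([True, True, False, True, True, True, False, False, True])
- # >>> continuous_runs(moving, max0len=2)
- # array([[0, 6],
- #        [8, 9]])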
- def switch_ranges(ranges, dt=1, minval=0, maxval=None):
- """
- Given a set of (start, stop) pairs, return a new set of pairs for values
- outside the given ranges.
- Parameters
- ----------
- ranges : ndarray
- N x 2 array containing start and stop values in the first and second
- columns respectively
- dt : float
- step size between consecutive values; the returned ranges are offset from the
- input ranges by dt
- minval, maxval : int
- the minimum and maximum possible values, if maxval is None it is assumed
- that the maximum possible value is the maximum value in the input ranges
- Returns
- -------
- out : ndarray
- M x 2 array containing start and stop values of all ranges outside of
- the input ranges
- """
- if ranges.shape[1] == 0:
- return np.array([[minval, maxval]])
- assert (ranges.ndim == 2) & (ranges.shape[1] == 2), "A two-column array is expected"
- maxval = ranges.max() if maxval is None else maxval
- # get new ranges
- new_ranges = np.zeros_like(ranges)
- new_ranges[:,0] = ranges[:,0] - dt # new stop values
- new_ranges[:,1] = ranges[:,1] + dt # new start values
- # fix boundaries
- new_ranges = new_ranges.ravel()
- if new_ranges[0] >= (minval + dt): # first new stop within allowed range
- new_ranges = np.concatenate((np.array([minval]), new_ranges))
- else:
- new_ranges = new_ranges[1:]
- if new_ranges[-1] <= (maxval - dt): # last new start within allowed range
- new_ranges = np.concatenate((new_ranges, np.array([maxval])))
- else:
- new_ranges = new_ranges[:-1]
- return new_ranges.reshape((int(len(new_ranges) / 2), 2))
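- # Illustrative example: the complement of the input ranges within [minval, maxval]:
- # >>> switch_ranges(np.array([[2, 4], [7, 9]]), dt=1, minval=0, maxval=12)
- # array([[ 0,  1],
- #        [ 5,  6],
- #        [10, 12]])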
- def shuffle_bins(x, binwidth=1):
- """
- Randomly shuffle bins of an array.
- """
- # bin start indices
- bins_i0 = np.arange(0, len(x), binwidth)
- # shuffled bins
- np.random.shuffle(bins_i0)
- # concatenate shuffled bins
- shf = np.concatenate([x[i0:(i0 + binwidth)] for i0 in bins_i0])
- return shf
- def take_data_in_bouts(series, data, bouts, trange=None, dt=2, dt0=0, dt1=0, concatenate=True, norm=False):
- if series['%s_bouts' % bouts].shape[1] < 1:
- return np.array([])
- header, _ = data.split('_')
- data_in_bouts = []
- for t0, t1 in series['%s_bouts' % bouts]:
- t0 -= dt0
- t1 += dt1
- if trange == 'start':
- t1 = t0 + dt
- elif trange == 'end':
- t0 = t1 - dt
- elif trange == 'middle':
- t0 = t0 + dt
- t1 = t1 - dt
- if t1 <= t0:
- continue
- if t0 < series['%s_tpts' % header].min():
- continue
- if t1 > series['%s_tpts' % header].max():
- continue
- i0, i1 = series['%s_tpts' % header].searchsorted([t0, t1])
- data_in_bout = series[data][i0:i1].copy()
- if norm:
- data_in_bout = data_in_bout / series[data].max()
- data_in_bouts.append(data_in_bout)
- if concatenate: # guard against the case where every bout was skipped
- data_in_bouts = np.concatenate(data_in_bouts) if len(data_in_bouts) > 0 else np.array([])
- return data_in_bouts
- def get_trials(series, stim_id=0, opto=False, multi_stim='warn'):
- if opto:
- opto = np.isin(series['trial_id'], series['opto_trials'])
- else:
- opto = ~np.isin(series['trial_id'], series['opto_trials'])
- if stim_id < 0:
- stim = np.ones_like(series['stim_id']).astype('bool')
- else:
- stim = series['stim_id'] == stim_id
- series['trial_on_times'] = series['trial_on_times'][opto & stim]
- series['trial_off_times'] = series['trial_off_times'][opto & stim]
- return series
- def sort_data(data, sort_vals, bins=10):
- if type(bins) == int:
- nbins = bins
- bin_edges = np.linspace(sort_vals.min(), sort_vals.max(), nbins + 1)
- else:
- nbins = len(bins) - 1
- bin_edges = bins
- digitized_vals = np.digitize(sort_vals, bins=bin_edges).clip(1, nbins)
- return [data[digitized_vals == val] for val in np.arange(nbins) + 1]
- def apply_sort_data(series, data_col, sort_col, bins=10):
- return sort_data(series[data_col], series[sort_col], bins)
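- # Illustrative example: values of `data` grouped by which of two equal-width bins their
- # corresponding `sort_vals` fall into:
- # >>> sort_data(np.arange(6), np.array([0.1, 0.9, 0.5, 0.2, 0.8, 0.4]), bins=2)
- # [array([0, 3, 5]), array([1, 2, 4])]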
- ## Statistics
- def get_binned_rates(spk_rates, pupil_area, sort=False, nbins=10):
- # Get bins based on percentiles to reduce the effect of outliers
- min_area, max_area = np.percentile(pupil_area, [2.5, 97.5])
- #min_area, max_area = pupil_area.min(), pupil_area.max()
- area_bins = np.linspace(min_area, max_area, nbins + 1)
- # Bin pupil area
- binned_area = np.digitize(pupil_area, bins=area_bins).clip(1, nbins) - 1
- # Bin rates according to pupil area
- binned_rates = np.array([spk_rates[binned_area == bin_i] for bin_i in np.arange(nbins)], dtype=object)
- if sort:
- sorted_inds = np.argsort([rates.mean() if len(rates) > 0 else 0 for rates in binned_rates])
- binned_rates = binned_rates[sorted_inds]
- binned_area = np.squeeze([np.where(sorted_inds == area_bin)[0] for area_bin in binned_area])
- return area_bins, binned_area, binned_rates
- def rescale(x, method='min_max'):
- # Set re-scaling method
- if method == 'z_score':
- return (x - x.mean()) / x.std()
- elif method == 'min_max':
- return (x - x.min()) / (x.max() - x.min())
- def correlogram(ts1, ts2=None, tau_max=1, dtau=0.01, return_tpts=False):
- if ts2 is None:
- ts2 = ts1.copy()
- auto = True
- else:
- auto = False
- tau_max = (tau_max // dtau) * dtau
- tbins = np.arange(-tau_max, tau_max + dtau, dtau)
- ccg = np.zeros(len(tbins) - 1)
- for t0 in ts1:
- dts = ts2 - t0
- if auto:
- dts = dts[dts != 0]
- ccg += np.histogram(dts[np.abs(dts) <= tau_max], bins=tbins)[0]
- ccg /= len(ts1)
- if not return_tpts:
- return ccg
- else:
- tpts = tbins[:-1] + dtau / 2
- return tpts, ccg
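- # Illustrative example: spikes with a regular 1 s spacing give an autocorrelogram with mass at
- # lags of +/-1 s (values are counts per reference spike):
- # >>> tpts, ccg = correlogram(np.array([0., 1., 2.]), tau_max=1, dtau=0.5, return_tpts=True)
- # >>> ccg
- # array([0.66666667, 0.        , 0.        , 0.66666667])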
- def angle_subtract(a1, a2, period=(2 * np.pi)):
- return (a1 - a2) % period
- def circmean(alpha, w=None, axis=None):
- """
- Compute mean resultant vector of circular data.
- Parameters
- ----------
- alpha : ndarray
- array of angles
- w : ndarray
- array of weights, must be same shape as alpha
- axis : int, None
- axis across which to compute mean
- Returns
- -------
- mrl : ndarray
- mean resultant vector length
- theta : ndarray
- mean resultant vector angle
- """
- # weights default to ones
- if w is None:
- w = np.ones_like(alpha)
- w[np.isnan(alpha)] = 0
- # compute weighted mean
- mean_vector = np.nansum(w * np.exp(1j * alpha), axis=axis) / w.sum(axis=axis)
- mrl = np.abs(mean_vector) # length
- theta = np.angle(mean_vector) # angle
- return mrl, theta
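- # Illustrative example: two angles at 0 and pi/2 give a mean angle of pi/4 and a resultant
- # vector length of cos(pi/4):
- # >>> mrl, theta = circmean(np.array([0., np.pi / 2]))
- # >>> print(round(float(mrl), 3), round(float(theta), 3))
- # 0.707 0.785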
- def circmean_angle(alpha, **kwargs):
- return circmean(alpha, **kwargs)[1]
- def circhist(angles, n_bins=8, proportion=True, wrap=False):
- bins = np.linspace(-np.pi, np.pi, n_bins + 1, endpoint=True)
- weights = np.ones(len(angles))
- if proportion:
- weights /= len(angles)
- counts, bins = np.histogram(angles, bins=bins, weights=weights)
- if wrap:
- counts = np.concatenate([counts, [counts[0]]])
- bins = np.concatenate([bins, [bins[0]]])
- return counts, bins
- def unbiased_variance(data):
- if len(data) <= 1:
- return np.nan
- else:
- return np.var(data) * len(data) / (len(data) - 1)
- def se_median(sample, n_resamp=1000):
- """Standard error of the median."""
- medians = np.full(n_resamp, np.nan)
- for i in range(n_resamp):
- resample = np.random.choice(sample, len(sample), replace=True)
- medians[i] = np.median(resample)
- return np.std(medians)
- def coupling_summary(df):
- """Print some basic statistics for phase coupling."""
- # Either spike type
- units = df.query('(tonicspk_p == tonicspk_p) or (burst_p == burst_p)').groupby(['m', 's', 'e', 'u'])
- n_sig = units.apply(lambda x: any(x[f'tonicspk_sig']) or any(x[f'burst_sig'])).sum()
- prop_sig = n_sig / len(units)
- print(f"Neurons with significant coupling: {prop_sig:.3f} ({n_sig}/{len(units)})")
- # For each spike type
- for spk_type in ['tonicspk', 'burst']:
- units = df.dropna(subset=f'{spk_type}_p').groupby(['m', 's', 'e', 'u'])
- n_sig = units.apply(lambda x: any(x[f'{spk_type}_sig'])).sum()
- prop_sig = n_sig / len(units)
- print(f"{spk_type.capitalize()} prop. significant: {prop_sig:.3f} ({n_sig}/{len(units)})")
- n_cpds = units.apply(lambda x: x[f'{spk_type}_sig'].sum())
- print(f"{spk_type.capitalize()} num. CPDs per neuron: {n_cpds.mean():.2f}, {n_cpds.std():.2f}")
- def kl_divergence(p, q):
- return np.sum(np.where(p != 0, p * np.log(p / q), 0))
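- # Illustrative example (p and q given as discrete probability vectors over the same bins):
- # kl_divergence(np.array([0.5, 0.5]), np.array([0.9, 0.1])) -> approximately 0.511 (natural-log units).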
- def match_distributions(x1, x2, x1_bins, x2_bins):
- """
- For two time series, x1 & x2, return index ranges of sub-sampled time
- periods such that the distribution of x2 is matched across
- bins of x1.
- """
- x1_nbins = len(x1_bins) - 1
- x2_nbins = len(x2_bins) - 1
- # bin x1
- x1_binned = np.digitize(x1, x1_bins).clip(1, x1_nbins) - 1
- # get continuous periods where x1 visits each bin
- x1_ranges = [zero_runs(~np.equal(x1_binned, x1_bin)) for x1_bin in np.arange(x1_nbins)]
- # get mean of x2 for each x1 bin visit
- x2_means = [np.array([np.mean(x2[i0:i1]) for i0, i1 in x1_bin]) for x1_bin in x1_ranges]
- # find minimum common distribution across x1 bins
- x2_counts = np.row_stack([np.histogram(means, bins=x2_bins)[0] for means in x2_means])
- x2_mcd = x2_counts.min(axis=0)
- # bin x2 means
- x2_means_binned = [np.digitize(means, bins=x2_bins).clip(1, x2_nbins) - 1 for means in x2_means]
- x2_means_in_bins = [[means[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for means, binned_means in zip(x2_means, x2_means_binned)]
- x1_ranges_in_bins = [[ranges[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)] for ranges, binned_means in zip(x1_ranges, x2_means_binned)]
- # loop over x2 bins
- matched_ranges = [[] for _ in range(x1_nbins)] # one list of matched ranges per x1 bin
- for x2_bin in np.arange(x2_nbins):
- # find the x1 bin matching the MCD
- seed_x1_bin = np.where(x2_counts[:, x2_bin] == x2_mcd[x2_bin])[0][0]
- assert len(x2_means_in_bins[seed_x1_bin][x2_bin]) == x2_mcd[x2_bin]
- # for each bin visit, find the closest matching mean in the other x1 bins
- target_means = x2_means_in_bins[seed_x1_bin][x2_bin]
- target_ranges = x1_ranges_in_bins[seed_x1_bin][x2_bin]
- for target_mean, target_range in zip(target_means, target_ranges):
- matched_ranges[seed_x1_bin].append(target_range)
- for x1_bin in np.delete(np.arange(x1_nbins), seed_x1_bin):
- matching_ind = np.abs(x2_means_in_bins[x1_bin][x2_bin] - target_mean).argmin()
- matched_ranges[x1_bin].append(x1_ranges_in_bins[x1_bin][x2_bin][matching_ind])
- # delete the matching period
- x2_means_in_bins[x1_bin][x2_bin] = np.delete(x2_means_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
- x1_ranges_in_bins[x1_bin][x2_bin] = np.delete(x1_ranges_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
- return [np.row_stack(ranges) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges]
- ## Signal processing
- def normalized_xcorr(a, b, dt=None, ts=None):
- """
- Compute Pearson r between two arrays at various lags
- Parameters
- ----------
- a, b : ndarray
- The arrays to correlate.
- dt : float
- The time step between samples in the arrays.
- ts : list
- If not None, only the xcorr between the specified lags will be
- returned.
- Returns
- -------
- xcorr, lags : ndarray
- The cross correlation and corresponding lags between a and b.
- Positive lags indicate that a is delayed relative to b.
- """
- assert len(a) == len(b)
- n = len(a)
- a_norm = (a - a.mean()) / a.std()
- b_norm = (b - b.mean()) / b.std()
- xcorr = np.correlate(a_norm, b_norm, 'full') / n
- lags = np.arange(-n + 1, n)
- if dt is not None:
- lags = lags * dt
- if ts is not None:
- assert len(ts) == 2
- i0, i1 = lags.searchsorted(ts)
- xcorr = xcorr[i0:i1]
- lags = lags[i0:i1]
- return xcorr, lags
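- # Illustrative example: `a` is `b` delayed by one sample, so the correlation peaks at a
- # positive lag of +1 (in samples, since dt is not given):
- # >>> a = np.array([0., 0., 1., 0., 0.])
- # >>> b = np.array([0., 1., 0., 0., 0.])
- # >>> xcorr, lags = normalized_xcorr(a, b)
- # >>> int(lags[np.argmax(xcorr)])
- # 1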
- def interpolate(y, x_old, x_new, axis=0, fill_value='extrapolate'):
- """
- Use linear interpolation to re-sample 1D data.
- """
- # get interpolation function
- func = interp1d(x_old, y, axis=axis, fill_value=fill_value)
- # get new y-values
- y_interpolated = func(x_new)
- return y_interpolated
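- # Illustrative example: linear re-sampling onto a finer grid:
- # >>> interpolate(np.array([0., 1., 4.]), np.array([0., 1., 2.]), np.array([0., 0.5, 1., 1.5, 2.]))
- # array([0. , 0.5, 1. , 2.5, 4. ])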
- def interpolate_and_normalize(y, x_old, x_new):
- """
- Perform linear interpolation and min-max normalization.
- """
- y_new = interpolate(y, x_old, x_new)
- return (y_new - y_new.min()) / (y_new.max() - y_new.min())
- def match_signal_length(a, b, a_tpts, b_tpts):
- """
- Given two signals, truncate to match the length of the shortest.
- """
- t1 = min(a_tpts.max(), b_tpts.max())
- a1 = a[:a_tpts.searchsorted(t1)]
- a1_tpts = a_tpts[:a_tpts.searchsorted(t1)]
- b1 = b[:b_tpts.searchsorted(t1)]
- b1_tpts = b_tpts[:b_tpts.searchsorted(t1)]
- return a1, b1, a1_tpts, b1_tpts
- def times_to_counts(series, columns, t0=None, t1=None, dt=0.25):
- if type(t0) == str: # use the time range of the named column
- t0, t1 = series[t0].min(), series[t0].max()
- else: # get overlapping time range for all columns
- if t0 is None:
- t0 = max([series[f'{col.split("_")[0]}_times'].min() for col in columns])
- if t1 is None:
- t1 = min([series[f'{col.split("_")[0]}_times'].max() for col in columns])
- # Set time base
- tbins = np.arange(t0, t1, dt)
- tpts = tbins[:-1] + (dt / 2)
- for col in columns:
- header = col.split('_')[0]
- times = series[f'{header}_times']
- counts, _ = np.histogram(times, bins=tbins)
- series[f'{header}_counts'] = counts
- series[f'{header}_tpts'] = tpts
- return series
- def resample_timeseries(y, tpts, dt=0.25):
- tpts_new = np.arange(tpts.min(), tpts.max(), dt)
- return tpts_new, interpolate(y, tpts, tpts_new)
- def _resample_data(series, columns, t0=None, t1=None, dt=0.25):
- if type(t0) == str:
- t0, t1 = series[t0].min(), series[t0].max()
- elif t0 is None: # get overlapping time range for all columns
- t0 = max([series[f'{col.split("_")[0]}_tpts'].min() for col in columns])
- t1 = min([series[f'{col.split("_")[0]}_tpts'].max() for col in columns])
- # Set new time base
- tbins = np.arange(t0, t1, dt)
- tpts_new = tbins[:-1] + (dt / 2)
- # Interpolate and re-sample each column
- for col in columns:
- header = col.split('_')[0]
- data = series[col]
- tpts = series[f'{header}_tpts']
- series[col] = interpolate(data, tpts, tpts_new)
- series[f'{header}_tpts'] = tpts_new
- return series
- ## Neural activity
- def get_mean_rates(df):
- df['mean_rate'] = df.apply(
- lambda x:
- len(x['spk_times']) / (x['spk_tinfo'][1] - x['spk_tinfo'][0]),
- axis='columns'
- )
- return df
- def get_mean_rate_threshold(df, alpha=0.025):
- if 'mean_rate' not in df.columns:
- df = get_mean_rates(df)
- rates = np.log10(df['mean_rate'])
- gmm = GaussianMixture(n_components=2)
- gmm.fit(rates[..., np.newaxis])
- (mu, var) = (gmm.means_.max(), gmm.covariances_.squeeze()[gmm.means_.argmax()])
- threshold = mu + norm.ppf(alpha) * np.sqrt(var)
- return threshold
- def filter_units(df, threshold):
- if 'mean_rate' not in df.columns:
- df = get_mean_rates(df)
- return df.query(f'mean_rate >= {threshold}')
- def get_raster(series, events, spike_type='spk', pre=0, post=1):
- events = series['%s_times' % events]
- spks = series['%s_times' % spike_type]
- raster = np.array([spks[(spks > t0 - pre) & (spks < t0 + post)] - t0 for t0 in events], dtype='object')
- return raster
- def get_psth(events, spikes, pre=0, post=1, dt=0.001, bw=0.01, baseline=[]):
- rate_kernel = GaussianKernel(bw*pq.s)
- tpts = np.arange(pre, post, dt)
- psth = np.full((len(events), len(tpts)), np.nan)
- for i, t0 in enumerate(events):
- rel_ts = spikes - t0
- rel_ts = rel_ts[(rel_ts >= pre) & (rel_ts <= post)]
- try:
- rate = instantaneous_rate(
- SpikeTrain(rel_ts, t_start=pre, t_stop=post, units='s'),
- sampling_period=dt*pq.s,
- kernel=rate_kernel
- )
- except: # rate estimation can fail (e.g. no spikes in the window); leave this trial as NaN
- continue
- psth[i] = rate.squeeze()
- if baseline:
- b0, b1 = tpts.searchsorted(baseline)
- baseline_rate = psth[:, b0:b1].mean(axis=1)
- psth = (psth.T - baseline_rate).T
- return psth, tpts
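- # Illustrative usage (hypothetical event and spike time arrays; window from -0.5 s to 1.5 s
- # around each event, 1 ms sampling, 20 ms smoothing kernel):
- # >>> psth, tpts = get_psth(trial_onsets, spk_times, pre=-0.5, post=1.5, bw=0.02)
- # psth has shape (len(trial_onsets), len(tpts)); rows where rate estimation failed remain NaN.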
- def apply_get_psth(series, events, spike_type, **kwargs):
- events = series['{}_times'.format(events)]
- spikes = series['{}_times'.format(spike_type)]
- psth, tpts = get_psth(events, spikes, **kwargs)
- return psth
- def get_responses(events, data, tpts, pre=0, post=1, baseline=[]):
- dt = np.round(np.diff(tpts).mean(), 3) # round to nearest ms
- i_pre, i_post = int(pre / dt), int(post / dt)
- responses = np.full((len(events), i_pre + i_post), np.nan)
- for j, t0 in enumerate(events):
- i = tpts.searchsorted(t0)
- i0, i1 = i - i_pre, i + i_post
- if i0 < 0:
- continue
- if i1 > len(data):
- break
- responses[j] = data[i0:i1]
- tpts = np.linspace(pre, post, responses.shape[1])
- if baseline:
- b0, b1 = tpts.searchsorted(baseline)
- baseline_resp = responses[:, b0:b1].mean(axis=1)
- responses = (responses.T - baseline_resp).T
- return responses, tpts
- def apply_get_responses(series, events, data, **kwargs):
- events = series[f'{events}_times']
- tpts = series[f'{data.split("_")[0]}_tpts']
- data = series[f'{data}']
- responses, tpts = get_responses(events, data, tpts, **kwargs)
- return responses
|