12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001 |
- import numpy as np
- import pandas as pd
- import quantities as pq
- from matplotlib import pyplot as plt
- from scipy.interpolate import interp1d
- from statsmodels.stats.multitest import fdrcorrection
- from scipy.stats import norm, sem
- from sklearn.mixture import GaussianMixture
- from neo import SpikeTrain
- from elephant.statistics import instantaneous_rate
- from elephant.kernels import GaussianKernel
- from parameters import *
- ## Data handling
def df2keys(df):
    """
    Return MSE(U) keys for all entries in a DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with any of the 'm', 's', 'e', 'u' identifiers in the
        columns or index.

    Returns
    -------
    keys : list of dict
        One dictionary per row, mapping each identifier found in the
        dataframe to its value in that row.
    """
    identifiers = ('m', 's', 'e', 'u')
    # Move index levels into the columns so identifiers are found either way
    flat = df.reset_index()
    present = [col for col in flat.columns if col in identifiers]
    return [dict(zip(present, row)) for row in flat[present].values]
def key2idx(key):
    """
    Return DataFrame index tuple for the given key.

    The tuple holds the values of `key` in the dictionary's own order.
    """
    return tuple(key.values())
def print_Ns(df):
    """
    Print the number of mice, recording sessions, experiments, and neurons
    contained in a dataset.

    When a 'u' identifier is present, the number of rows belonging to each
    experiment is printed as well.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with 'm', 's', 'e', and 'u' identifiers in the columns or index.
    """
    flat = df.reset_index()
    experiment_levels = ['m', 's', 'e']
    print(f"No. mice: {len(flat.groupby('m'))}")
    print(f"No. series: {len(flat.groupby(['m', 's']))}")
    print(f"No. experiments: {len(flat.groupby(experiment_levels))}")
    if 'u' not in flat.columns:
        return
    print(f"No. units: {len(flat.groupby(experiment_levels + ['u']))}")
    # Print number of units from each experiment
    for (mouse, series, experiment), rows in flat.groupby(experiment_levels):
        print(f"\n{mouse} s{series:02d} e{experiment:02d}")
        print(f" {len(rows)} units")
def load_data(data, conditions, ext=''):
    """
    Load data of a specified type, concatenating all the requested conditions.

    Files are read from DATAPATH and are named '<data>_<condition>.pkl', or
    '<data>_<condition>_<ext>.pkl' when `ext` is given.

    Parameters
    ----------
    data : str
        String specifying the type of data, used as the filename prefix.
    conditions : list
        List of string specifying the experimental conditions to load.
    ext : str
        String specifying the filename extension to load, can specify the
        spike type (e.g. 'burst') or data sub-sample (e.g. 'sizematched').

    Returns
    -------
    df : pandas.DataFrame
        Row-wise concatenation of the dataframes loaded for each condition.
        A 'condition' column is added to any dataframe that lacks one.
    """
    suffix = f'_{ext}' if ext else ''
    frames = []
    for condition in conditions:
        # NOTE: pickle files can execute code on load, only read trusted data
        frame = pd.read_pickle(DATAPATH + f'{data}_{condition}{suffix}.pkl')
        if 'condition' not in frame.columns:
            frame['condition'] = condition
        frames.append(frame)
    return pd.concat(frames, axis='index')
- ## Plotting utils
def set_plotsize(w, h=None, ax=None):
    """
    Set the size of a matplotlib axes object in cm.

    The figure owning the axes is resized such that, given the current
    subplot margins, the axes end up with the requested dimensions.

    Parameters
    ----------
    w, h : float
        Desired width and height of plot, if height is None, the axis will be
        square.
    ax : matplotlib.axes
        Axes to resize, if None the output of plt.gca() will be re-sized.

    Notes
    -----
    - Use after subplots_adjust (if adjustment is needed)
    - Matplotlib axis size is determined by the figure size and the subplot
      margins (r, l; given as a fraction of the figure size), i.e.
      w_ax = w_fig * (r - l)
    """
    cm_per_inch = 2.54
    if h is None:  # assume square
        h = w
    width_in = w / cm_per_inch
    height_in = h / cm_per_inch
    if not ax:  # get current axes
        ax = plt.gca()
    # margins are fractions of the figure size
    margins = ax.figure.subplotpars
    # figure dimensions that produce the desired axes dimensions
    fig_width = float(width_in) / (margins.right - margins.left)
    fig_height = float(height_in) / (margins.top - margins.bottom)
    ax.figure.set_size_inches(fig_width, fig_height)
def clip_axes_to_ticks(ax=None, spines=('left', 'bottom'), ext=None):
    """
    Clip the axis lines to end at the minimum and maximum tick values.

    Parameters
    ----------
    ax : matplotlib.axes
        Axes to modify, if None the output of plt.gca() will be used.
    spines : sequence of str
        Axes to keep and clip, axes not included here will be hidden.
        Valid values include 'left', 'bottom', 'right', 'top'.
    ext : dict, optional
        For each axis in ext.keys() ('left', 'bottom', 'right', 'top'),
        the two values given are added to the lowest and highest tick
        position respectively to obtain the ends of the axis line, e.g.
        {'left': [-0.1, 0.2]} results in an axis line that extends 0.1 units
        beyond the bottom tick and 0.2 units beyond the top tick.
    """
    if ax is None:
        ax = plt.gca()
    # Immutable/None defaults avoid sharing one mutable object across calls
    ext = {} if ext is None else ext
    spine_axis = {
        'left': ax.yaxis,
        'bottom': ax.xaxis,
        'right': ax.yaxis,
        'top': ax.xaxis,
    }
    for spine in spines:
        ticks = spine_axis[spine].get_majorticklocs()
        low, high = min(ticks), max(ticks)
        if spine in ext:
            low += ext[spine][0]
            high += ext[spine][1]
        ax.spines[spine].set_bounds(low, high)
    # Hide every spine that was not requested
    for spine in spine_axis:
        if spine not in spines:
            ax.spines[spine].set_visible(False)
def p2stars(p):
    """
    Convert a p-value into a string of significance stars.

    Returns '***' for p <= 0.0001, '**' for p <= 0.001, '*' for p <= 0.05,
    and an empty string otherwise.
    """
    # Thresholds are checked from most to least stringent
    for threshold, stars in ((0.0001, '***'), (0.001, '**'), (0.05, '*')):
        if p <= threshold:
            return stars
    return ''
def violin_plot(dists, colors, ax=None, logscale=False):
    """
    Draw outlined violins with the median of each distribution written
    next to it.

    Parameters
    ----------
    dists : sequence of array-like
        The distributions to plot, one violin each.
    colors : list
        Edge color for each violin; when a list is passed it must have the
        same length as `dists`.
    ax : matplotlib.axes
        Axes to plot on, if None a new figure is created.
    logscale : bool
        If True, the median label shows 10 ** median rather than the median.

    Returns
    -------
    ax : matplotlib.axes
    """
    if type(colors) is list:
        assert len(colors) == len(dists)
    if ax is None:
        _, ax = plt.subplots()
    parts = ax.violinplot(dists, showmedians=True, showextrema=False)
    # Outline-only violins
    for body, color in zip(parts['bodies'], colors):
        body.set_facecolor('none')
        body.set_edgecolor(color)
        body.set_alpha(1)
        body.set_linewidth(2)
    parts['cmedians'].set_color('black')
    # Write the median to the right of each violin
    for pos, dist in enumerate(dists):
        median = np.median(dist)
        label = f'{10 ** median:.2f}' if logscale else f'{median:.2f}'
        ax.text(pos + 1.4, median, label, va='center', ha='center', rotation=-90, fontsize=LABELFONTSIZE)
    ax.set_xticks(np.arange(len(dists)) + 1)
    ax.tick_params(bottom=False)
    return ax
def pupil_area_rate_heatmap(df, cmap='gray', max_rate='high', example=None):
    """
    Plot a heatmap of event rates (tonic spikes or bursts) where each row is
    a neuron and each column is a pupil size bin.

    Neurons with significant rate differences across bins (area_p <= 0.05)
    are drawn in a top panel, the remaining neurons in a bottom panel. Panel
    heights are proportional to the number of neurons they contain.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with neurons in the rows, mean firing rates for each pupil
        size bin in a column called 'area_means', and a p-value in a column
        called 'area_p'. NOTE: the columns 'tuning_max' and 'tuning_norm'
        are added to this dataframe in place.
    cmap : str or Matplotlib colormap object
        Colormap for the panel of significant neurons.
    max_rate : str
        Is the max event rate expected to at 'high' or 'low' pupil sizes?
    example : dict
        If not None, MSEU key passed will be used to highlight example neuron.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Notes
    -----
    The tick positions and the 'monotonic' thresholds are hard-coded for
    bin indices 0-9, i.e. the layout presumably expects 10 pupil size bins.
    """
    fig = plt.figure()

    # Find out which pupil size bin has the highest firing rate
    df['tuning_max'] = df['area_means'].apply(np.argmax)
    # Min-max normalize firing rates for each neuron
    df['tuning_norm'] = df['area_means'].apply(lambda x: (x - x.min()) / (x.max() - x.min()))

    # Get heatmap for neurons with significant rate differences across pupil size bins
    df_sig = df.query('area_p <= 0.05').sort_values('tuning_max')
    # np.vstack is the non-deprecated name of np.row_stack
    heatmap_sig = np.vstack(df_sig['tuning_norm'])
    n_sig = len(df_sig)  # number of neurons with significant differences

    # Make axis with size proportional to the fraction of significant neurons
    n_units = len(df)  # total number of neurons
    ax1_height = n_sig / n_units
    ax1 = fig.add_axes([0.1, 0.9 - (0.76 * ax1_height), 0.8, (0.76 * ax1_height)])
    # Plot heatmap for significant neurons
    mat = ax1.matshow(heatmap_sig, cmap=cmap, aspect='auto')
    # Scatter plot marking pupil size bin with maximum
    ax1.scatter(df_sig['tuning_max'], np.arange(len(df_sig)) - 0.5, s=0.5, color='black', zorder=3)
    # Format axes
    ax1.set_xticks([])
    yticks = np.insert(np.arange(0, n_sig, 20), -1, n_sig)  # ticks every 20 neurons, and final count
    ax1.set_yticks(yticks - 0.5)
    ax1.set_yticklabels(yticks)
    ax1.set_ylabel('Neurons')
    # Make a colorbar
    cbar = fig.colorbar(mat, ax=ax1, ticks=[0, 1], location='top', shrink=0.75)
    cbar.ax.set_xticklabels(['min', 'max'])
    cbar.ax.set_xlabel('Spikes', labelpad=-5)

    # Print dotted line and label for neurons considered to have 'monotonic' modulation profiles
    # (highest event rate at one of the pupil size extremes)
    if max_rate == 'high':
        n_mon = len(df_sig.query('tuning_max >= 9'))
        print("Monotonic increasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
        ax1.axvline(8.5, lw=2, ls='--', color='white')
        ax1.set_title(r'Monotonic$\rightarrow$', fontsize=LABELFONTSIZE, loc='right', pad=0)
    elif max_rate == 'low':
        n_mon = len(df_sig.query('tuning_max <= 1'))
        print("Monotonic decreasing: %d/%d (%.1f)" % (n_mon, n_sig, n_mon / n_sig * 100))
        ax1.axvline(0.5, lw=2, ls='--', color='white')
        ax1.set_title(r'$\leftarrow$Monotonic', fontsize=LABELFONTSIZE, loc='left', pad=0)

    # Get heatmap for neurons without significant rate differences across pupil sizes
    df_ns = df.query('area_p > 0.05').sort_values('tuning_max')
    heatmap_ns = np.vstack(df_ns['tuning_norm'])

    # Make axis with size proportional to the fraction of non-significant neurons
    n_ns = len(df_ns)
    ax2_height = n_ns / n_units
    ax2 = fig.add_axes([0.1, 0.1, 0.8, (0.76 * ax2_height)])
    # Plot heatmap for neurons without significant rate differences across pupil sizes
    mat = ax2.matshow(heatmap_ns, cmap='Greys', vmax=2, aspect='auto')
    ax2.scatter(df_ns['tuning_max'], np.arange(n_ns) - 0.5, s=0.5, color='black', zorder=3)
    # Format axes
    ax2.xaxis.set_ticks_position('bottom')
    ax2.set_xticks([-0.5, 4.5, 9.5])
    ax2.set_xticklabels([0, 0.5, 1])  # x-axis ticks mark percentiles of pupil size range
    ax2.set_xlim(right=9.5)
    ax2.set_xlabel('Pupil size (norm.)')
    yticks = np.insert(np.arange(0, n_ns, 20), -1, n_ns)  # ticks every 20 neurons, and final count
    ax2.set_yticks(yticks - 0.5)
    ax2.set_yticklabels(yticks)

    # Highlight example neuron by circling scatter dot
    if example is not None:
        example_idx = tuple(example.values())
        # Look among the significant neurons first, then the non-significant.
        # An explicit membership test replaces the former assert + bare
        # except, which broke under `python -O` and masked unrelated errors.
        for frame, axis in ((df_sig, ax1), (df_ns, ax2)):
            is_ex = frame.index == example_idx
            if np.any(is_ex):
                ex_max = frame['tuning_max'][is_ex]
                axis.scatter(ex_max, np.where(is_ex)[0], lw=2, s=1, fc='none', ec='magenta', zorder=4)
                break
    return fig
def cumulative_histogram(data, bins, color='C0', ax=None):
    """
    Convenience function, cleaner looking plot than plt.hist(..., cumulative=True).

    Plots the cumulative proportion of `data` falling into `bins` as a line
    positioned at the left edge of each bin.

    Parameters
    ----------
    data : ndarray
    bins : ndarray
        Bin edges.
    color : matplotlib color
    ax : matplotlib.axes
        Axes to plot on, if None a new figure is created.

    Returns
    -------
    ax : matplotlib.axes
    """
    if ax is None:
        _, ax = plt.subplots()
    # Each sample contributes 1/N so that the counts are proportions
    proportions, _ = np.histogram(data, bins=bins, weights=np.ones_like(data) / len(data))
    cumulative = np.cumsum(proportions)
    ax.plot(bins[:-1], cumulative, color=color)
    return ax
def cumulative_hist(x, bins, density=True, ax=None, **kwargs):
    """
    Plot a cumulative histogram as a step-like line.

    Parameters
    ----------
    x : ndarray
        Data to histogram.
    bins : ndarray
        Bin edges.
    density : bool
        If True, each sample is weighted by 1 / len(x) so that counts are
        proportions.
    ax : matplotlib.axes
        Axes to plot on, if None a new figure is created.
    **kwargs
        Passed on to ax.plot (which is also given lw=2).

    Returns
    -------
    ax : matplotlib.axes
    counts : ndarray
        The (non-cumulative) weighted count in each bin.
    """
    weights = np.ones_like(x)
    if density:
        weights = weights / len(x)
    counts, _ = np.histogram(x, bins=bins, weights=weights)
    if ax is None:
        _, ax = plt.subplots()
    cumulative = np.cumsum(counts)
    # Every edge but the last appears twice, giving the line its steps
    xs = np.repeat(bins, 2)[:-1]
    # Every cumulative value appears twice, preceded by a starting zero
    ys = np.insert(np.repeat(cumulative, 2), 0, 0)
    ax.plot(xs, ys, lw=2, **kwargs)
    ax.set_xticks(bins + 0.5)
    ax.set_yticks([0, 0.5, 1])
    return ax, counts
def phase_coupling_scatter(df, ax=None):
    """
    Phase-frequency scatter plot for phase coupling.

    For tonic spikes and bursts separately, every row with significant
    coupling is drawn as an open marker at (log10 of 'freq', preferred phase).

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with a 'freq' column and, for each event type in
        ('tonicspk', 'burst'), the columns '<event>_sig' and '<event>_phase'.
    ax : matplotlib.axes
        Axes to plot on, if None a new figure is created.

    Returns
    -------
    ax : matplotlib.axes
    """
    if ax is None:
        fig, ax = plt.subplots()

    # The former per-iteration count of significant units was never used
    # and has been dropped.
    for event in ['tonicspk', 'burst']:
        df_sig = df.query(f'{event}_sig == True')
        ax.scatter(np.log10(df_sig['freq']), df_sig[f'{event}_phase'], ec=COLORS[event], fc='none', lw=0.5, s=3)

    ax.set_xticks(FREQUENCYTICKS)
    ax.set_xticklabels(FREQUENCYTICKLABELS)
    ax.set_xlim(left=-3.075)
    ax.set_xlabel("Inverse timescale (s$^{-1}$)")

    ax.set_yticks(PHASETICKS)
    ax.set_yticklabels(PHASETICKLABELS)
    ax.set_ylim([-np.pi - 0.15, np.pi + 0.15])
    ax.set_ylabel("Preferred phase")
    return ax
def plot_circhist(angles, ax=None, bins=np.linspace(0, 2 * np.pi, 17), density=True, **kwargs):
    """
    Plot a circular histogram.

    The histogram is drawn as a closed line: the count of the first bin is
    repeated at the end, and points are placed at the bin edges shifted by
    pi / n_bins.

    Parameters
    ----------
    angles : ndarray
        Angles to histogram.
    ax : matplotlib.axes
        Axes to plot on, if None a new figure with a polar axes is created.
    bins : ndarray
        Bin edges.
    density : bool
        If True, each angle is weighted by 1 / len(angles) so that counts
        are proportions.
    **kwargs
        Passed on to ax.plot.

    Returns
    -------
    ax : matplotlib.axes
    counts : ndarray
        The weighted count in each bin.
    """
    if ax is None:
        fig, ax = plt.subplots(subplot_kw={'polar': True})
    # Float weights: integer-typed angles used to make the in-place
    # division below fail with a casting error.
    weights = np.ones_like(angles, dtype=float)
    if density:
        weights = weights / len(angles)
    counts, bins = np.histogram(angles, bins=bins, weights=weights)
    xs = bins + (np.pi / (len(bins) - 1))
    ys = np.append(counts, counts[0])  # close the loop
    ax.plot(xs, ys, **kwargs)
    ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
    ax.set_xticklabels(['0', '\u03C0/2', '\u03C0', '3\u03C0/2'])
    ax.tick_params(axis='x', pad=-5)
    return ax, counts
def coupling_strength_line_plot(df, agg=np.mean, err=sem, logscale=True, ax=None, **kwargs):
    """
    Line plot showing average burst and tonic spike coupling strengths and SE across timescale bins.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with a 'freq' column and, for each event type in
        ('burst', 'tonicspk'), the columns '<event>_sig' and
        '<event>_strength'. Only rows with significant coupling and
        strength > 0 are used.
    agg : callable
        Function computing the central value of the strengths in a bin.
    err : callable
        Function computing the error of the strengths in a bin.
    logscale : bool
        If True, log10 of the values is plotted.
    ax : matplotlib.axes
        Axes to plot on, if None a new figure is created.
    **kwargs
        Passed on to every ax.plot call.

    Returns
    -------
    ax : matplotlib.axes
    """
    if ax is None:
        _, ax = plt.subplots()
    # Transformation applied to all plotted values
    scale = np.log10 if logscale else (lambda values: values)
    for event in ['burst', 'tonicspk']:
        selected = df.query(f'({event}_sig == True) & ({event}_strength > 0)').copy()
        binned = sort_data(selected[f'{event}_strength'], selected['freq'], bins=FREQUENCYBINS)
        center = np.array([agg(vals) for vals in binned])
        spread = np.array([err(vals) for vals in binned])
        color = COLORS[event]
        ax.plot(FREQUENCYXPOS, scale(center), color=color, **kwargs)
        # Thin dashed lines mark center +/- error
        ax.plot(FREQUENCYXPOS, scale(center + spread), color=color, lw=0.5, ls='--', **kwargs)
        ax.plot(FREQUENCYXPOS, scale(center - spread), color=color, lw=0.5, ls='--', **kwargs)
    ax.set_xticks(FREQUENCYTICKS)
    ax.set_xticklabels(FREQUENCYTICKLABELS)
    ax.set_xlim(left=-3.1)
    ax.set_xlabel('Timescale (Hz)')
    ax.set_ylabel('Coupling strength')
    return ax
def compare_coupling_strength_preferences(df, conds, spk_type):
    """
    Make a pie chart showing the proportion of neurons that either lose their
    coupling, retain the IMF to which they most strongly couple, or shift the
    IMF to which they most strongly couple across two conditions. Also
    highlight the proportion of neurons in the latter two categories that
    increase their coupling in the second condition compared to the first.

    The proportions of each category are printed as well.

    Parameters
    ----------
    df : pandas.DataFrame
        Merged dataframe combining the output of two phase coupling analyses
        performed on the same set of neurons/ IMFs across two conditions.
        Columns must have suffixes indicating the condition they come from.
    conds : list
        List of strings indicating suffixes for each of the two conditions.
    spk_type : str
        String indicating for which spike type coupling is to be compared.

    Returns
    -------
    ax : matplotlib.axes
    """
    cond0, cond1 = conds[0], conds[1]
    sig0, sig1 = f'{spk_type}_sig_{cond0}', f'{spk_type}_sig_{cond1}'
    strength0, strength1 = f'{spk_type}_strength_{cond0}', f'{spk_type}_strength_{cond1}'

    # Start by restricting to neuron-IMF pairs with significant coupling in the first condition
    df_sig = df.query(f'({sig0} == True)')

    # Helper functions applied to each neuron
    def _preference_shifted(unit):
        # Is the IMF with strongest coupling different across conditions?
        return np.argmax(unit[strength0]) != np.argmax(unit[strength1])

    def _max_strength_change(unit):
        # Difference in strength of the strongest IMF across conditions
        return unit[strength1].max() - unit[strength0].max()

    # Group data by neuron
    units = df_sig.groupby(['m', 's', 'e', 'u'])
    # Does the neuron have any significant coupling in the second condition?
    sig = units.apply(lambda unit: unit[sig1].sum() > 0)
    # Does the strongest timescale shift across conditions?
    shift = units.apply(_preference_shifted)
    # What is the difference in max strength between conditions?
    strength = units.apply(_max_strength_change)
    weaker = strength < 0      # max strength lower in second condition
    stronger = strength >= 0   # max strength higher in second condition

    # Get slices of the pie
    no_coupling = (~sig).sum()
    imf_same_down = (sig & ~shift & weaker).sum()
    imf_same_up = (sig & ~shift & stronger).sum()
    imf_shift_down = (sig & shift & weaker).sum()
    imf_shift_up = (sig & shift & stronger).sum()
    wedge_sizes = [no_coupling, imf_shift_down, imf_shift_up, imf_same_up, imf_same_down]
    total = np.sum(wedge_sizes)
    assert total == len(units)  # check categories are mutually exclusive

    # Print proportions in each category
    print(f"No coupling: {(no_coupling / total * 100):.1f}%")
    print(f"Timescale same: {(np.sum([imf_same_up, imf_same_down]) / total * 100):.1f}%")
    print(f"Timescale shift: {(np.sum([imf_shift_up, imf_shift_down]) / total * 100):.1f}%")
    print(f"Increased coupling: {(np.sum([imf_shift_up, imf_same_up]) / total * 100):.1f}%\n")

    # Make pie chart
    fig, ax = plt.subplots()
    slices, text = ax.pie(
        wedge_sizes,
        explode=[0., 0., 0.1, 0.1, 0.],  # pop-out slices representing max strength increase
        colors=['gray'] + [COLORS[spk_type]] * 4,
    )
    # Faded color for slices representing strongest IMF shift
    for wedge in slices[1:3]:
        wedge.set_alpha(0.5)
    plt.legend(slices[::2], ['No coupling', 'Timescale shift', 'Timescale retained'], loc=(1, 0.75), frameon=False)
    ax.xaxis.set_visible(False)
    return ax
- ## Util
def zero_runs(a):
    """
    Return an array with shape (m, 2), where m is the number of "runs" of zeros
    in a. The first column is the index of the first 0 in each run, the second
    is the index of the first nonzero element after the run.
    """
    # Indicator that is 1 where a is 0, padded with a 0 on both ends so that
    # runs touching the array boundaries are also delimited.
    is_zero = np.equal(a, 0).astype(np.int8)
    padded = np.concatenate(([0], is_zero, [0]))
    # The indicator changes value exactly where a run starts or ends
    edges = np.flatnonzero(np.diff(padded) != 0)
    return edges.reshape(-1, 2)
def merge_ranges(ranges, dt=1):
    """
    Given a set of ranges [start, stop], return new set of ranges where all
    overlapping ranges are merged.

    Parameters
    ----------
    ranges : ndarray
        N x 2 array of start and stop values.
    dt : float
        Step of the time grid, spanning ranges.min() to ranges.max(), on
        which the ranges are merged.

    Returns
    -------
    out : ndarray
        M x 2 array of merged start and stop values taken from the grid. A
        range reaching the end of the grid is given the last grid point as
        its stop value.
    """
    tpts = np.arange(ranges.min(), ranges.max(), dt)  # array of time points
    outside = np.ones_like(tpts)  # 0 wherever a time point lies in a range
    for start, stop in ranges:
        first, last = tpts.searchsorted([start, stop])
        outside[first:last] = 0
    # Continuous stretches of zero are the merged ranges (as grid indices)
    merged = zero_runs(outside)
    n_tpts = len(tpts)
    if merged[-1, -1] == n_tpts:  # fix end-point, which would index past the grid
        merged[-1, -1] = n_tpts - 1
    return tpts[merged]
def continuous_runs(data, max0len=1, min1len=1, min1prop=0):
    """
    Get start and stop indices of stretches of (relatively) continuous data.

    Parameters
    ----------
    data : ndarray
        1D boolean array
    max0len : int
        maximum length (in data pts) of False stretches to ignore
    min1len : int
        minimum length (in data pts) of True runs to keep
    min1prop : float
        minimum proportion of True data in the run necessary for it
        to be considered

    Returns
    -------
    out : ndarray
        (m, 2) array of start and stop indices, where m is the number runs of
        continuous True values. If `data` contains no True values an empty
        array of shape (1, 0) is returned.
    """
    # get ranges of True values
    runs = zero_runs(~data)
    if len(runs) == 0:
        return np.array([[]])
    # merge ranges that are separated by < max0len of False: stretch each
    # stop value, merge whatever now overlaps, then undo the stretch
    pad = max0len - 1
    runs[:, 1] += pad
    runs = merge_ranges(runs)
    runs[:, 1] -= pad
    runs[-1, -1] += 1
    # remove ranges that are too short
    long_enough = np.diff(runs, axis=1).ravel() >= min1len
    runs = runs[long_enough]
    # remove ranges that don't have sufficient proportion True
    prop = np.array([data[i0:i1].sum() / (i1 - i0) for (i0, i1) in runs])
    return runs[prop >= min1prop]
def switch_ranges(ranges, dt=1, minval=0, maxval=None):
    """
    Given a set of (start, stop) pairs, return a new set of pairs for values
    outside the given ranges.

    Parameters
    ----------
    ranges : ndarray
        N x 2 array containing start and stop values in the first and second
        columns respectively
    dt : float
        distance kept between the input ranges and the returned ranges
    minval, maxval : int
        the minimum and maximum possible values, if maxval is None it is assumed
        that the maximum possible value is the maximum value in the input ranges

    Returns
    -------
    out : ndarray
        M x 2 array containing start and stop values of all ranges outside of
        the input ranges
    """
    if ranges.shape[1] == 0:
        # no input ranges: everything is outside
        return np.array([[minval, maxval]])
    assert (ranges.ndim == 2) and (ranges.shape[1] == 2), "A two-column array is expected"
    if maxval is None:
        maxval = ranges.max()
    # An input start becomes the stop of the preceding gap, an input stop
    # becomes the start of the following gap
    shifted = np.zeros_like(ranges)
    shifted[:, 0] = ranges[:, 0] - dt
    shifted[:, 1] = ranges[:, 1] + dt
    edges = shifted.ravel()
    # fix boundaries
    if edges[0] >= (minval + dt):  # first new stop within allowed range
        edges = np.concatenate((np.array([minval]), edges))
    else:
        edges = edges[1:]
    if edges[-1] <= (maxval - dt):  # last new start within allowed range
        edges = np.concatenate((edges, np.array([maxval])))
    else:
        edges = edges[:-1]
    return edges.reshape((len(edges) // 2, 2))
def shuffle_bins(x, binwidth=1):
    """
    Randomly shuffle bins of an array.

    The array is cut into consecutive bins of `binwidth` elements (the last
    bin may be shorter), and the bins are concatenated in random order. The
    order of elements within a bin is preserved.
    """
    starts = np.arange(0, len(x), binwidth)  # bin start indices
    np.random.shuffle(starts)
    chunks = [x[start:start + binwidth] for start in starts]
    return np.concatenate(chunks)
def take_data_in_bouts(series, data, bouts, trange=None, dt=2, dt0=0, dt1=0, concatenate=True, norm=False):
    """
    Extract the samples of a time series that fall within a set of bouts.

    Parameters
    ----------
    series : pandas.Series or dict
        Container holding the bouts under '<bouts>_bouts' (N x 2 array of
        start and stop times), the data under `data`, and the corresponding
        time points under '<header>_tpts', where <header> is the part of
        `data` before the underscore.
    data : str
        Name of the data entry, must be of the form '<header>_<name>'.
    bouts : str
        Name prefix of the bouts entry.
    trange : {None, 'start', 'end', 'middle'}
        Part of each bout to take: the first `dt` ('start'), the last `dt`
        ('end'), the bout with `dt` removed from both ends ('middle', bouts
        that become empty are skipped), or the whole bout (None).
    dt : float
        Duration used by `trange`.
    dt0, dt1 : float
        Amount by which each bout is extended before its start and after its
        end, applied before `trange`.
    concatenate : bool
        If True, return a single concatenated array, else a list with one
        array per bout.
    norm : bool
        If True, divide the data by the maximum of the whole data entry.

    Returns
    -------
    data_in_bouts : ndarray or list of ndarray
        Bouts reaching beyond the available time points are skipped. An empty
        array is returned if there are no bouts, or if `concatenate` is True
        and no bout could be used.
    """
    bout_times = series['%s_bouts' % bouts]
    if bout_times.shape[1] < 1:
        return np.array([])
    header, _ = data.split('_')
    tpts = series['%s_tpts' % header]
    t_min, t_max = tpts.min(), tpts.max()
    values = series[data]
    data_in_bouts = []
    for t0, t1 in bout_times:
        t0 -= dt0
        t1 += dt1
        if trange == 'start':
            t1 = t0 + dt
        elif trange == 'end':
            t0 = t1 - dt
        elif trange == 'middle':
            t0 = t0 + dt
            t1 = t1 - dt
            if t1 <= t0:
                continue
        # skip bouts that are not fully covered by the time points
        if t0 < t_min or t1 > t_max:
            continue
        i0, i1 = tpts.searchsorted([t0, t1])
        data_in_bout = values[i0:i1].copy()
        if norm:
            data_in_bout = data_in_bout / values.max()
        data_in_bouts.append(data_in_bout)
    if concatenate:
        if not data_in_bouts:
            # np.concatenate raises on an empty list; mirror the empty
            # result returned above when there are no bouts at all
            return np.array([])
        data_in_bouts = np.concatenate(data_in_bouts)
    return data_in_bouts
def get_trials(series, stim_id=0, opto=False, multi_stim='warn'):
    """
    Restrict the trial on- and off-times of a series to selected trials.

    Parameters
    ----------
    series : pandas.Series or dict
        Container with the entries 'trial_id', 'opto_trials', 'stim_id',
        'trial_on_times', and 'trial_off_times'. NOTE: the two trial time
        entries are overwritten in place.
    stim_id : int
        Only trials with this stimulus ID are kept, a negative value keeps
        trials of all stimuli.
    opto : bool
        If True keep only trials listed in 'opto_trials', else keep only
        trials not listed there.
    multi_stim : str
        Currently unused.

    Returns
    -------
    series
        The input container, with trial times restricted.
    """
    in_opto = np.isin(series['trial_id'], series['opto_trials'])
    opto_mask = in_opto if opto else ~in_opto
    if stim_id < 0:
        stim_mask = np.ones_like(series['stim_id']).astype('bool')
    else:
        stim_mask = series['stim_id'] == stim_id
    keep = opto_mask & stim_mask
    for entry in ('trial_on_times', 'trial_off_times'):
        series[entry] = series[entry][keep]
    return series
def sort_data(data, sort_vals, bins=10):
    """
    Sort data into bins according to a second set of values.

    Parameters
    ----------
    data : ndarray
        The data to sort.
    sort_vals : ndarray
        Values according to which the data are binned, same length as data.
    bins : int or array-like
        If an integer (Python or NumPy), the number of equally spaced bins
        spanning the range of `sort_vals`. Otherwise the bin edges.

    Returns
    -------
    binned : list of ndarray
        One entry per bin containing the data whose sort value falls into
        that bin. Values beyond the outermost edges are assigned to the
        first or last bin.
    """
    # isinstance also accepts NumPy integers (e.g. the result of len() or
    # shape arithmetic), which the former exact type check rejected
    if isinstance(bins, (int, np.integer)):
        nbins = int(bins)
        bin_edges = np.linspace(sort_vals.min(), sort_vals.max(), nbins + 1)
    else:
        nbins = len(bins) - 1
        bin_edges = bins
    # np.digitize returns 0 / nbins + 1 for values beyond the edges (and for
    # the maximum itself), clip those into the outermost bins
    digitized_vals = np.digitize(sort_vals, bins=bin_edges).clip(1, nbins)
    return [data[digitized_vals == val] for val in range(1, nbins + 1)]
def apply_sort_data(series, data_col, sort_col, bins=10):
    """
    Apply sort_data to two entries of a series.

    The entry `data_col` is binned according to the entry `sort_col`, see
    sort_data. Convenient for use with DataFrame.apply.
    """
    data, sort_vals = series[data_col], series[sort_col]
    return sort_data(data, sort_vals, bins)
- ## Statistics
def get_binned_rates(spk_rates, pupil_area, sort=False, nbins=10):
    """
    Bin firing rates according to the simultaneously measured pupil area.

    Parameters
    ----------
    spk_rates : ndarray
        Firing rate time series.
    pupil_area : ndarray
        Pupil area time series, same length as spk_rates.
    sort : bool
        If True, bins are re-ordered by increasing mean rate (empty bins
        count as 0), and the bin indices are re-labelled accordingly.
    nbins : int
        Number of equally spaced pupil area bins.

    Returns
    -------
    area_bins : ndarray
        The nbins + 1 bin edges, spanning the 2.5th to the 97.5th percentile
        of the pupil area.
    binned_area : ndarray
        Bin index (0-based) of each pupil area sample. Samples beyond the
        outermost edges are assigned to the first or last bin.
    binned_rates : ndarray
        Object array with the rates belonging to each bin.
    """
    # Get bins based on percentiles to eliminate effect of outliers
    area_lo, area_hi = np.percentile(pupil_area, [2.5, 97.5])
    area_bins = np.linspace(area_lo, area_hi, nbins + 1)
    # Bin pupil area
    binned_area = np.digitize(pupil_area, bins=area_bins).clip(1, nbins) - 1
    # Bin rates according to pupil area
    rates_per_bin = [spk_rates[binned_area == bin_i] for bin_i in np.arange(nbins)]
    binned_rates = np.array(rates_per_bin, dtype=object)
    if sort:
        mean_rates = [rates.mean() if len(rates) > 0 else 0 for rates in binned_rates]
        order = np.argsort(mean_rates)
        binned_rates = binned_rates[order]
        # Re-label each sample with the new position of its bin
        binned_area = np.squeeze([np.where(order == old_bin)[0] for old_bin in binned_area])
    return area_bins, binned_area, binned_rates
def rescale(x, method='min_max'):
    """
    Re-scale an array.

    Parameters
    ----------
    x : ndarray
        The data to re-scale.
    method : {'min_max', 'z_score'}
        'min_max' maps the data onto the range [0, 1], 'z_score' subtracts
        the mean and divides by the standard deviation.

    Returns
    -------
    out : ndarray

    Raises
    ------
    ValueError
        If `method` is not one of the supported methods (previously None was
        returned silently).
    """
    if method == 'z_score':
        return (x - x.mean()) / x.std()
    if method == 'min_max':
        return (x - x.min()) / (x.max() - x.min())
    raise ValueError(f"Unknown re-scaling method: {method!r}")
def correlogram(ts1, ts2=None, tau_max=1, dtau=0.01, return_tpts=False):
    """
    Compute the cross-correlogram of two sets of event times.

    Parameters
    ----------
    ts1 : ndarray
        Reference event times.
    ts2 : ndarray
        Target event times, if None the auto-correlogram of ts1 is computed
        (time differences of exactly zero are then excluded).
    tau_max : float
        Maximum lag, rounded down to a multiple of dtau.
    dtau : float
        Bin width of the correlogram.
    return_tpts : bool
        If True, also return the lags at the bin centers.

    Returns
    -------
    tpts : ndarray
        Bin centers, only returned if return_tpts is True.
    ccg : ndarray
        Number of target events in each lag bin, divided by the number of
        reference events.
    """
    auto = ts2 is None
    if auto:
        ts2 = ts1.copy()
    tau_max = (tau_max // dtau) * dtau
    tbins = np.arange(-tau_max, tau_max + dtau, dtau)
    ccg = np.zeros(len(tbins) - 1)
    for t_ref in ts1:
        lags = ts2 - t_ref
        if auto:
            lags = lags[lags != 0]
        in_window = lags[np.abs(lags) <= tau_max]
        ccg += np.histogram(in_window, bins=tbins)[0]
    ccg /= len(ts1)
    if return_tpts:
        return tbins[:-1] + dtau / 2, ccg
    return ccg
def angle_subtract(a1, a2, period=(2 * np.pi)):
    """
    Compute the pair-wise difference between two sets of angles on a given period.

    The result, a1 - a2, is wrapped to the range [0, period).
    """
    difference = a1 - a2
    return difference % period
def circmean(alpha, w=None, axis=None):
    """
    Compute mean resultant vector of circular data.

    NaN angles are ignored: they are given a weight of zero and do not
    contribute to the sum.

    Parameters
    ----------
    alpha : ndarray
        array of angles
    w : ndarray
        array of weights, must be same shape as alpha; defaults to ones.
        The array passed in is not modified.
    axis : int, None
        axis across which to compute mean

    Returns
    -------
    mrl : ndarray
        mean resultant vector length
    theta : ndarray
        mean resultant vector angle
    """
    # weights default to ones; user weights are copied because NaN entries
    # are zeroed below, which previously altered the caller's array
    if w is None:
        w = np.ones_like(alpha)
    else:
        w = np.array(w, dtype=float)
    w[np.isnan(alpha)] = 0
    # compute weighted mean
    mean_vector = np.nansum(w * np.exp(1j * alpha), axis=axis) / w.sum(axis=axis)
    mrl = np.abs(mean_vector)  # length
    theta = np.angle(mean_vector)  # angle
    return mrl, theta
def circmean_angle(alpha, **kwargs):
    """
    Return only the angle of the mean resultant vector, see circmean.

    Keyword arguments are passed on to circmean.
    """
    _, theta = circmean(alpha, **kwargs)
    return theta
def circhist(angles, n_bins=8, proportion=True, wrap=False):
    """
    Histogram of angles using equally spaced bins from -pi to pi.

    Parameters
    ----------
    angles : ndarray
        Angles to histogram.
    n_bins : int
        Number of bins.
    proportion : bool
        If True, return the proportion rather than the number of angles in
        each bin.
    wrap : bool
        If True, the first count and the first bin edge are appended to the
        end of the respective output.

    Returns
    -------
    counts, bins : ndarray
    """
    edges = np.linspace(-np.pi, np.pi, n_bins + 1, endpoint=True)
    n_angles = len(angles)
    weights = np.ones(n_angles)
    if proportion:
        weights /= n_angles
    counts, edges = np.histogram(angles, bins=edges, weights=weights)
    if wrap:
        counts = np.concatenate([counts, counts[:1]])
        edges = np.concatenate([edges, edges[:1]])
    return counts, edges
def unbiased_variance(data):
    """
    Return the variance of data with Bessel's correction, i.e. normalized by
    N - 1 rather than N. Returns NaN if data has fewer than two elements.
    """
    n = len(data)
    if n <= 1:
        return np.nan
    return np.var(data) * n / (n - 1)
def se_median(sample, n_resamp=1000):
    """
    Standard error of the median.

    Estimated as the standard deviation of the medians of `n_resamp`
    bootstrap resamples (drawn with replacement, same size as the sample).
    """
    n = len(sample)
    medians = np.array([
        np.median(np.random.choice(sample, n, replace=True))
        for _ in range(n_resamp)
    ])
    return np.std(medians)
def coupling_summary(df):
    """
    Print the number of units with significant phase coupling for either spike
    type to at least one IMF, the number of units with significant coupling for
    each spike type, and the mean (and standard deviation of the) number of
    IMFs to which a single unit has significant coupling for each spike type.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe output from the phase_tuning.py analysis script.
    """
    unit_levels = ['m', 's', 'e', 'u']

    # For either spike type: check each unit for significance
    units = df.groupby(unit_levels)
    n_units = len(units)
    n_sig = units.apply(lambda unit: any(unit['tonicspk_sig']) or any(unit['burst_sig'])).sum()
    prop_sig = n_sig / n_units
    print(f"Neurons with significant coupling: {prop_sig:.3f} ({n_sig}/{n_units})")

    # For each spike type
    for spk_type in ['tonicspk', 'burst']:
        sig_col = f'{spk_type}_sig'
        # Remove units where coupling for the spike type was not assessed
        units = df.dropna(subset=[f'{spk_type}_p']).groupby(unit_levels)
        n_units = len(units)
        n_sig = units.apply(lambda unit: any(unit[sig_col])).sum()
        prop_sig = n_sig / n_units
        print(f"Neurons with significant {spk_type} coupling: {prop_sig:.3f} ({n_sig}/{n_units})")
        # Number of significant couplings of each unit
        n_cpds = units.apply(lambda unit: unit[sig_col].sum())
        print(f" Mean number of CPDs per neuron: {n_cpds.mean():.2f}, {n_cpds.std():.2f}")
def kl_divergence(p, q):
    """
    Return the Kullback-Leibler divergence sum(p * log(p / q)) between two
    discrete distributions, using the natural logarithm. Entries where p is
    zero contribute zero.
    """
    terms = p * np.log(p / q)
    # where p == 0 the term above is not finite, replace it by 0
    return np.sum(np.where(p != 0, terms, 0))
def match_distributions(x1, x2, x1_bins, x2_bins):
    """
    For two time series, x1 & x2, return indices of sub-sampled time
    periods such that the distribution of x2 is matched across
    bins of x1.

    Parameters
    ----------
    x1, x2 : ndarray
        Time series of equal length.
    x1_bins, x2_bins : ndarray
        Bin edges for x1 and x2.

    Returns
    -------
    matched_ranges : list of ndarray
        One entry per x1 bin, each an (m, 2) array of start and stop indices
        of the selected visits to that bin (an empty array if none were
        selected). Every x1 bin contributes the same number of visits for
        each x2 bin.
    """
    x1_nbins = len(x1_bins) - 1
    x2_nbins = len(x2_bins) - 1
    # bin x1
    x1_binned = np.digitize(x1, x1_bins).clip(1, x1_nbins) - 1
    # get continuous periods where x1 visits each bin
    x1_ranges = [zero_runs(~np.equal(x1_binned, x1_bin)) for x1_bin in np.arange(x1_nbins)]
    # get mean of x2 for each x1 bin visit
    x2_means = [np.array([np.mean(x2[i0:i1]) for i0, i1 in x1_bin]) for x1_bin in x1_ranges]
    # find minimum common distribution across x1 bins
    # (np.vstack is the non-deprecated name of np.row_stack)
    x2_counts = np.vstack([np.histogram(means, bins=x2_bins)[0] for means in x2_means])
    x2_mcd = x2_counts.min(axis=0)
    # bin x2 means
    x2_means_binned = [np.digitize(means, bins=x2_bins).clip(1, x2_nbins) - 1 for means in x2_means]
    x2_means_in_bins = [
        [means[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)]
        for means, binned_means in zip(x2_means, x2_means_binned)
    ]
    x1_ranges_in_bins = [
        [ranges[binned_means == x2_bin] for x2_bin in np.arange(x2_nbins)]
        for ranges, binned_means in zip(x1_ranges, x2_means_binned)
    ]
    # one list of matched ranges per x1 bin (this was hard-coded to four
    # lists, which failed for more than four x1 bins)
    matched_ranges = [[] for _ in range(x1_nbins)]
    # loop over x2 bins
    for x2_bin in np.arange(x2_nbins):
        # find the x1 bin matching the MCD
        seed_x1_bin = np.where(x2_counts[:, x2_bin] == x2_mcd[x2_bin])[0][0]
        assert len(x2_means_in_bins[seed_x1_bin][x2_bin]) == x2_mcd[x2_bin]
        # for each bin visit, find the closest matching mean in the other x1 bins
        target_means = x2_means_in_bins[seed_x1_bin][x2_bin]
        target_ranges = x1_ranges_in_bins[seed_x1_bin][x2_bin]
        for target_mean, target_range in zip(target_means, target_ranges):
            matched_ranges[seed_x1_bin].append(target_range)
            for x1_bin in np.delete(np.arange(x1_nbins), seed_x1_bin):
                matching_ind = np.abs(x2_means_in_bins[x1_bin][x2_bin] - target_mean).argmin()
                matched_ranges[x1_bin].append(x1_ranges_in_bins[x1_bin][x2_bin][matching_ind])
                # delete the matching period so that it is not selected twice
                x2_means_in_bins[x1_bin][x2_bin] = np.delete(x2_means_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
                x1_ranges_in_bins[x1_bin][x2_bin] = np.delete(x1_ranges_in_bins[x1_bin][x2_bin], matching_ind, axis=0)
    return [np.vstack(ranges) if len(ranges) > 0 else np.array([]) for ranges in matched_ranges]
- ## Signal processing
def normalized_xcorr(a, b, dt=None, ts=None):
    """
    Compute Pearson r between two arrays at various lags

    Parameters
    ----------
    a, b : array_like
        The 1D arrays to correlate; they must have the same length.
    dt : float
        The time step between samples in the arrays. If given, lags are
        returned in units of time rather than samples.
    ts : list
        If not None, a pair (start, stop); only the xcorr between the
        specified lags will be returned.

    Return
    ------
    xcorr, lags : ndarray
        The cross correlation and corresponding lags between a and b.
        Positive lags indicate that a is delayed relative to b.

    Raises
    ------
    ValueError
        If a and b differ in length, or ts does not have two elements.

    Notes
    -----
    Every lag is divided by the full length n, not by the number of
    overlapping samples. A constant input has zero standard deviation and
    yields NaNs.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if len(a) != len(b):
        raise ValueError(
            f"a and b must have the same length, got {len(a)} and {len(b)}"
        )
    n = len(a)
    # z-score both signals so the zero-lag value is the Pearson r
    a_norm = (a - a.mean()) / a.std()
    b_norm = (b - b.mean()) / b.std()
    xcorr = np.correlate(a_norm, b_norm, 'full') / n
    lags = np.arange(-n + 1, n)
    if dt is not None:
        lags = lags * dt
    if ts is not None:
        if len(ts) != 2:
            raise ValueError("ts must contain exactly two lags (start, stop)")
        i0, i1 = lags.searchsorted(ts)
        xcorr = xcorr[i0:i1]
        lags = lags[i0:i1]
    return xcorr, lags
def interpolate(y, x_old, x_new, axis=0, fill_value='extrapolate'):
    """
    Use linear interpolation to re-sample data along one axis.

    Parameters
    ----------
    y : array_like
        The values to re-sample.
    x_old : array_like
        The sample points of y along `axis`.
    x_new : array_like
        The points at which to evaluate the interpolant.
    axis : int
        The axis of y that corresponds to x_old.
    fill_value : 'extrapolate' or value accepted by interp1d
        How to handle points of x_new outside the range of x_old. With the
        default these are linearly extrapolated; otherwise they are set to
        fill_value.

    Return
    ------
    y_interpolated : ndarray
        y evaluated at x_new.
    """
    # bounds_error=False is needed for a non-'extrapolate' fill_value to
    # take effect; by default interp1d raises on out-of-range points.
    func = interp1d(x_old, y, axis=axis, bounds_error=False,
                    fill_value=fill_value)
    # get new y-values
    return func(x_new)
def interpolate_and_normalize(y, x_old, x_new):
    """
    Re-sample y onto x_new by linear interpolation, then min-max normalize.

    The minimum and maximum are taken over the whole re-sampled array, so
    the result spans 0 to 1. A constant signal has zero range and divides
    by zero.
    """
    resampled = interpolate(y, x_old, x_new)
    lowest = resampled.min()
    value_range = resampled.max() - lowest
    return (resampled - lowest) / value_range
def match_signal_length(a, b, a_tpts, b_tpts):
    """
    Truncate two signals at the earlier of their two final timepoints.

    The cut index for each signal is found with searchsorted on its
    timepoints (assumed sorted), so samples at or after the common end time
    are dropped from both signals.

    Return
    ------
    a1, b1, a1_tpts, b1_tpts
        The truncated signals followed by their truncated timepoints.
    """
    t_end = min(a_tpts.max(), b_tpts.max())
    a_stop = a_tpts.searchsorted(t_end)
    b_stop = b_tpts.searchsorted(t_end)
    return a[:a_stop], b[:b_stop], a_tpts[:a_stop], b_tpts[:b_stop]
def times_to_counts(series, columns, t0=None, t1=None, dt=0.25):
    """
    Bin event times into counts on a common time base.

    For every column, the part of its name before the first underscore
    (the header) selects `series['<header>_times']`. The binned counts and
    the bin centers are written back to `series` as `'<header>_counts'` and
    `'<header>_tpts'`.

    Parameters
    ----------
    series : mapping
        Container holding the `'<header>_times'` arrays; modified in place.
    columns : list of str
        Names whose headers identify the time arrays to bin.
    t0 : str, float or None
        Start of the time base. If a string, it names an entry of `series`
        whose minimum and maximum define both ends of the time base (t1 is
        then ignored). If None, the latest first event across columns.
    t1 : float or None
        End of the time base. If None, the earliest last event across
        columns.
    dt : float
        Bin width.

    Return
    ------
    series
        The same object that was passed in, with the new entries added.
    """
    headers = [col.split('_')[0] for col in columns]
    if isinstance(t0, str):
        reference = series[t0]
        t0, t1 = reference.min(), reference.max()
    else:
        # Fill in each missing bound independently so that the default call
        # (both None) uses the overlapping time range of all columns.
        if t0 is None:
            t0 = max(series[f'{header}_times'].min() for header in headers)
        if t1 is None:
            t1 = min(series[f'{header}_times'].max() for header in headers)
    # Set time base: tbins are bin edges, tpts the bin centers
    tbins = np.arange(t0, t1, dt)
    tpts = tbins[:-1] + (dt / 2)
    for header in headers:
        counts, _ = np.histogram(series[f'{header}_times'], bins=tbins)
        series[f'{header}_counts'] = counts
        series[f'{header}_tpts'] = tpts
    return series
def resample_timeseries(y, tpts, dt=0.25):
    """
    Re-sample y onto a regular grid with spacing dt.

    The grid starts at the first timepoint and stops before the last one
    (np.arange excludes its end point).

    Return
    ------
    tpts_new, y_new
        The new timepoints and y linearly interpolated at those points.
    """
    start, stop = tpts.min(), tpts.max()
    tpts_new = np.arange(start, stop, dt)
    y_new = interpolate(y, tpts, tpts_new)
    return tpts_new, y_new
def _resample_data(series, columns, t0=None, t1=None, dt=0.25):
    """
    Re-sample several signals onto one shared, regularly spaced time base.

    For every column, the part of its name before the first underscore
    (the header) selects its timepoints, `series['<header>_tpts']`. Each
    column is linearly interpolated at the new timepoints and written back
    to `series`, together with the new `'<header>_tpts'`.

    Parameters
    ----------
    series : mapping
        Container holding the columns and their timepoints; modified in
        place.
    columns : list of str
        Names of the entries to re-sample.
    t0 : str, float or None
        Start of the time base. If a string, it names an entry of `series`
        whose minimum and maximum define both ends of the time base (t1 is
        then ignored). If None, the latest first timepoint across columns.
    t1 : float or None
        End of the time base. If None, the earliest last timepoint across
        columns.
    dt : float
        Spacing of the new time base.

    Return
    ------
    series
        The same object that was passed in.
    """
    headers = [col.split('_')[0] for col in columns]
    if isinstance(t0, str):
        reference = series[t0]
        t0, t1 = reference.min(), reference.max()
    else:
        # Fill in each missing bound independently from the overlapping
        # time range of all columns; bounds given by the caller are kept.
        if t0 is None:
            t0 = max(series[f'{header}_tpts'].min() for header in headers)
        if t1 is None:
            t1 = min(series[f'{header}_tpts'].max() for header in headers)
    # Set new time base (centers of bins of width dt)
    tbins = np.arange(t0, t1, dt)
    tpts_new = tbins[:-1] + (dt / 2)
    # Capture the original timepoints first: columns sharing a header must
    # all be interpolated against the old time base, not the new one.
    tpts_old = {header: series[f'{header}_tpts'] for header in headers}
    # Interpolate and re-sample each column
    for col, header in zip(columns, headers):
        series[col] = interpolate(series[col], tpts_old[header], tpts_new)
    for header in tpts_old:
        series[f'{header}_tpts'] = tpts_new
    return series
- ## Neural activity
def get_mean_rates(df):
    """
    Add a 'mean_rate' column to df and return df.

    For each row the rate is the number of entries in 'spk_times' divided
    by the interval 'spk_tinfo'[1] - 'spk_tinfo'[0] (presumably the stop
    and start of the recording -- confirm against the data loader).
    """
    def _row_rate(unit):
        t_info = unit['spk_tinfo']
        return len(unit['spk_times']) / (t_info[1] - t_info[0])

    df['mean_rate'] = df.apply(_row_rate, axis='columns')
    return df
def get_mean_rate_threshold(df, alpha=0.025):
    """
    Estimate a firing-rate threshold from a two-component Gaussian mixture.

    A two-component mixture is fit to log10 of the 'mean_rate' column
    (computed with get_mean_rates if missing). The threshold is
    mu + ppf(alpha) * sigma of the component with the larger mean.

    Parameters
    ----------
    df : DataFrame
        One row per unit.
    alpha : float
        Probability passed to norm.ppf (presumably scipy.stats.norm, making
        the threshold the alpha quantile of that component -- confirm).

    Return
    ------
    threshold : float
        The threshold in log10 units of 'mean_rate'.

    Notes
    -----
    NOTE(review): filter_units compares the raw 'mean_rate' with its
    threshold, so this value presumably needs converting with
    10 ** threshold first -- confirm against the callers.
    The mixture is fit without a fixed random state, so results may vary
    slightly between calls.
    """
    if 'mean_rate' not in df.columns:
        df = get_mean_rates(df)
    # Convert to a plain (n_units, 1) array: indexing a Series with
    # [..., np.newaxis] is multi-dimensional indexing, which raises in
    # pandas >= 2.0.
    log_rates = np.log10(np.asarray(df['mean_rate'], dtype=float)).reshape(-1, 1)
    gmm = GaussianMixture(n_components=2)
    gmm.fit(log_rates)
    # use the component with the larger mean
    mu = gmm.means_.max()
    var = gmm.covariances_.squeeze()[gmm.means_.argmax()]
    threshold = mu + norm.ppf(alpha) * np.sqrt(var)
    return threshold
def filter_units(df, threshold):
    """
    Keep only the rows whose 'mean_rate' is at least `threshold`.

    Parameters
    ----------
    df : DataFrame
        One row per unit; 'mean_rate' is computed with get_mean_rates if
        the column is missing.
    threshold : float
        Minimum mean rate, in the same units as the 'mean_rate' column.

    Return
    ------
    DataFrame
        The rows of df that pass the threshold.
    """
    if 'mean_rate' not in df.columns:
        df = get_mean_rates(df)
    # Compare numerically rather than formatting the threshold into a query
    # string, which breaks for values such as inf or nan.
    return df[df['mean_rate'] >= threshold]
def apply_get_raster(series, events, spike_type='spk', pre=0, post=1):
    """
    Collect event-aligned spike times for every event.

    Parameters
    ----------
    series : mapping
        Container holding `'<events>_times'` and `'<spike_type>_times'`.
    events : str
        Header of the event times to align to.
    spike_type : str
        Header of the spike times to use.
    pre, post : float
        Window around each event: spikes strictly between t0 - pre and
        t0 + post are kept.

    Return
    ------
    raster : ndarray of object, shape (n_events,)
        One array per event with the spike times relative to that event.
    """
    event_times = series['%s_times' % events]
    spks = series['%s_times' % spike_type]
    # Fill a pre-allocated object array: np.array(list_of_arrays) would
    # collapse to a 2D array whenever all trials have the same spike count.
    raster = np.empty(len(event_times), dtype='object')
    for i, t0 in enumerate(event_times):
        in_window = (spks > t0 - pre) & (spks < t0 + post)
        raster[i] = spks[in_window] - t0
    return raster
def get_psth(events, spikes, pre=0, post=1, dt=0.001, bw=0.01, baseline=None):
    """
    Compute a kernel-smoothed firing rate around each event.

    Parameters
    ----------
    events : array_like
        Event times to align to.
    spikes : ndarray
        Spike times.
    pre, post : float
        Start and end of the window relative to each event; spikes with
        pre <= t - t0 <= post are used.
    dt : float
        Sampling period of the rate estimate.
    bw : float
        Width parameter passed to GaussianKernel (in seconds).
    baseline : sequence of two floats, optional
        If given and non-empty, the mean rate between these two relative
        times is subtracted from each trial.

    Return
    ------
    psth : ndarray, shape (n_events, n_timepoints)
        One rate estimate per event. Trials for which the rate estimation
        raised are left as NaN.
    tpts : ndarray
        The timepoints np.arange(pre, post, dt).
    """
    rate_kernel = GaussianKernel(bw * pq.s)
    tpts = np.arange(pre, post, dt)
    psth = np.full((len(events), len(tpts)), np.nan)
    for i, t0 in enumerate(events):
        rel_ts = spikes - t0
        rel_ts = rel_ts[(rel_ts >= pre) & (rel_ts <= post)]
        try:
            rate = instantaneous_rate(
                SpikeTrain(rel_ts, t_start=pre, t_stop=post, units='s'),
                sampling_period=dt * pq.s,
                kernel=rate_kernel,
            )
        except Exception:
            # Best effort: leave this trial as NaN. Unlike a bare except,
            # KeyboardInterrupt and SystemExit still propagate.
            continue
        psth[i] = rate.squeeze()
    if baseline:
        b0, b1 = tpts.searchsorted(baseline)
        baseline_rate = psth[:, b0:b1].mean(axis=1)
        psth = psth - baseline_rate[:, np.newaxis]
    return psth, tpts
def apply_get_psth(series, events, spike_type, **kwargs):
    """
    Look up event and spike times in `series` and return their PSTH.

    `events` and `spike_type` are headers: the times are read from
    `'<header>_times'`. Extra keyword arguments go to get_psth; only the
    PSTH is returned, the timepoints are discarded.
    """
    event_times = series['{}_times'.format(events)]
    spike_times = series['{}_times'.format(spike_type)]
    psth, _ = get_psth(event_times, spike_times, **kwargs)
    return psth
def get_responses(events, data, tpts, pre=0, post=1, baseline=None):
    """
    Extract event-aligned segments of a regularly sampled signal.

    Parameters
    ----------
    events : array_like
        Event times to align to.
    data : ndarray
        The signal.
    tpts : ndarray
        Sorted timepoints of the samples in data.
    pre, post : float
        Extent of the window before and after each event. It is converted
        to int(pre / dt) and int(post / dt) samples.
    baseline : sequence of two floats, optional
        If given and non-empty, the mean response between these two points
        of the returned time axis is subtracted from each trial.

    Return
    ------
    responses : ndarray, shape (n_events, n_samples)
        One row per event. Events whose window extends beyond either end
        of data are left as NaN.
    tpts : ndarray
        np.linspace(pre, post, n_samples).

    Notes
    -----
    NOTE(review): the window starts int(pre / dt) samples *before* the
    event while the returned time axis starts at +pre, so the two only
    agree for pre=0 -- confirm the intended sign convention of `pre`.
    """
    # mean sample spacing, rounded to 3 decimals (nearest ms if tpts is in s)
    dt = np.round(np.diff(tpts).mean(), 3)
    i_pre, i_post = int(pre / dt), int(post / dt)
    n_samples = len(data)
    responses = np.full((len(events), i_pre + i_post), np.nan)
    for j, t0 in enumerate(events):
        i = tpts.searchsorted(t0)
        i0, i1 = i - i_pre, i + i_post
        # Skip, rather than stop at, windows running off either end so that
        # later in-range events are still extracted when events are unsorted.
        if i0 < 0 or i1 > n_samples:
            continue
        responses[j] = data[i0:i1]
    tpts = np.linspace(pre, post, responses.shape[1])
    if baseline:
        b0, b1 = tpts.searchsorted(baseline)
        baseline_resp = responses[:, b0:b1].mean(axis=1)
        responses = responses - baseline_resp[:, np.newaxis]
    return responses, tpts
def apply_get_responses(series, events, data, **kwargs):
    """
    Look up events and a signal in `series` and return aligned responses.

    Event times are read from `'<events>_times'`, the signal from `data`,
    and its timepoints from `'<header>_tpts'`, where the header is the part
    of `data` before the first underscore. Extra keyword arguments go to
    get_responses; only the responses are returned, the time axis is
    discarded.
    """
    event_times = series[f'{events}_times']
    header = data.split("_")[0]
    signal_tpts = series[f'{header}_tpts']
    signal = series[f'{data}']
    responses, _ = get_responses(event_times, signal, signal_tpts, **kwargs)
    return responses
def rvr(responses):
    """
    Compute the 'response variability ratio': the variance over time of the
    trial-averaged response, divided by the time-average of the across-trial
    variance. NaNs are ignored when averaging and taking variance across
    trials.

    Parameters
    ----------
    responses : ndarray
        2D array of responses where each row is a trial.

    Return
    ------
    float
        The ratio of the two variances.
    """
    mean_response = np.nanmean(responses, axis=0)
    trial_variance = np.nanvar(responses, axis=0)
    return mean_response.var() / trial_variance.mean()
|