123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- import os
- import numpy as np
- import pandas as pd
- from tqdm import tqdm
- import argparse
- from scipy.interpolate import interp1d
- from util import zero_runs
- from hht import HHT
- from parameters import *
- from util import (load_data, filter_units, switch_ranges, merge_ranges, angle_subtract,
- circmean, get_trials, shuffle_bins)
def resample(data, old_tpts, new_tpts, axis=0, fill_value='extrapolate'):
    """
    Re-sample `data` onto a new time base using linear interpolation.

    Parameters
    ----------
    data : ndarray
        values sampled at `old_tpts` along `axis`
    old_tpts : 1D array
        original sample times
    new_tpts : 1D array
        times at which to evaluate the interpolant
    axis : int
        axis of `data` that runs along time
    fill_value : str or float
        passed to `scipy.interpolate.interp1d`; the default 'extrapolate'
        linearly extends the signal beyond the original time range

    Returns
    -------
    out : ndarray
        `data` evaluated at `new_tpts`, with `axis` resized to len(new_tpts)
    """
    interpolant = interp1d(old_tpts, data, axis=axis, fill_value=fill_value)
    return interpolant(new_tpts)
def phase2rank(alpha):
    """
    Map angles onto evenly spaced circular ranks.

    The k-th smallest angle (0-based) is assigned the value 2*pi*k/n, so the
    output is uniformly spread over [0, 2*pi) however `alpha` is distributed.
    Tied angles receive distinct ranks in argsort order.

    Parameters
    ----------
    alpha : 1D array
        angles to rank

    Returns
    -------
    out : ndarray
        circular ranks, same length as `alpha`
    """
    n_angles = len(alpha)
    # Argsort of the sorting permutation is its inverse: for every element,
    # its 0-based position in sorted order (its linear rank).
    linear_rank = np.argsort(np.argsort(alpha))
    return 2 * np.pi * linear_rank / n_angles
def rank2phase(rank, alpha):
    """
    Convert a circular rank back to a phase of the original distribution.

    Inverse of `phase2rank`: the circular rank is mapped to the nearest
    linear rank k, and the k-th smallest element of `alpha` is returned.

    Parameters
    ----------
    rank : float or ndarray
        circular rank(s) in radians; because ranks are circular, values
        outside [0, 2*pi) are wrapped (e.g. 2*pi is the same as 0)
    alpha : 1D array
        the angles whose ranks were computed

    Returns
    -------
    out : float or ndarray
        element(s) of `alpha` at the given circular rank(s)
    """
    n = len(alpha)
    # convert circular rank to linear rank
    linrank = n * rank / 2 / np.pi
    # A rank that rounds up to n (i.e. sits just below 2*pi) is circularly the
    # same position as rank 0; wrap it instead of indexing past the end.
    linidx = np.round(linrank).astype('int') % n
    return np.sort(alpha)[linidx]
def inds2train(inds, length):
    """
    Build a binary (0/1) train from event indices.

    Parameters
    ----------
    inds : 1D array
        indices at which events occur; repeated indices still yield a 1
    length : int
        number of samples in the output train

    Returns
    -------
    event_train : ndarray
        uint8 array of length `length`, 1 at `inds` and 0 elsewhere
    """
    train = np.zeros(shape=length, dtype=np.uint8)
    train[inds] = 1
    return train
def times2train(evts, tpts):
    """
    Bin event times into a binary (0/1) train on a sorted time base.

    Events at or outside the first/last time point are discarded. Each kept
    event is assigned to the index of the first time point >= the event time
    (numpy `searchsorted` with side='left'); several events falling on the
    same index produce a single 1.

    Parameters
    ----------
    evts : 1D array
        event times
    tpts : 1D array
        sorted time base in which events occur

    Returns
    -------
    event_train : ndarray
        uint8 array of length len(tpts)
    """
    lo, hi = tpts.min(), tpts.max()
    # keep only events strictly inside the time base
    kept = evts[np.logical_and(evts > lo, evts < hi)]
    hit_idx = np.searchsorted(tpts, kept)
    train = np.zeros(len(tpts), dtype=np.uint8)
    train[hit_idx] = 1
    return train
def modified_mrl2(alpha, w=None, axis=0):
    """
    A bias-free measure of the squared mean resultant length [1].

    Parameters
    ----------
    alpha : ndarray
        array of angles
    w : ndarray
        array of weights, must be same shape as alpha
    axis : int, None
        axis across which to compute mean; None pools all elements

    Returns
    -------
    out : ndarray
        bias-corrected squared mean resultant length; NaN where fewer than
        two angles are available (the n / (n - 1) correction is undefined)

    Notes
    -----
    - taking the square-root of this measure does *not* provide a bias-free
      measure of the mean resultant length, see [1].
    - n is the raw number of samples along `axis`, also when weights `w` are
      given. NOTE(review): confirm this is the intended effective n for the
      weighted case.

    References
    ----------
    [1] Kutil, R. (2012). Biased and unbiased estimation of the circular mean
    resultant length and its variance. Statistics, 46(4), 549-561.
    """
    alpha = np.asarray(alpha)
    # circmean returns (mean resultant length, mean angle); assumes it accepts
    # the same `axis` values (including None) as this function
    mrl, _ = circmean(alpha, w=w, axis=axis)
    # alpha.shape[None] is a TypeError, so pool all elements for axis=None
    n = alpha.size if axis is None else alpha.shape[axis]
    if n < 2:
        # 1 / (n - 1) would divide by zero; the estimator is undefined here
        return np.full(np.shape(mrl), np.nan)[()]
    return (n / (n - 1)) * (mrl ** 2 - (1 / n))
def phase_tuning(phase, spk_train, shuffle_binwidth=1000, n_shuffles=1000):
    """
    Quantify how strongly spikes are locked to a phase time-series.

    Phases are first converted to circular ranks (`phase2rank`) so that a
    non-uniform phase distribution does not by itself produce apparent
    tuning. Tuning strength is the bias-corrected squared mean resultant
    length (`modified_mrl2`) of the ranks at spike times. Significance is
    assessed against spike trains whose time bins were shuffled with
    `shuffle_bins`.

    Parameters
    ----------
    phase : 1D array
        phase angle at every sample
    spk_train : 1D array
        spike train aligned with `phase`; samples equal to 1 count as spikes
    shuffle_binwidth : int
        bin width forwarded to `shuffle_bins` (the caller in this script
        passes a number of samples)
    n_shuffles : int
        number of shuffled spike trains to draw

    Returns
    -------
    r : float
        bias-corrected squared mean resultant length of the spike ranks
    theta : float
        preferred phase: the element of `phase` at the circular mean rank
    p : float
        fraction of shuffles whose strength strictly exceeds `r` (can be
        exactly 0 when no shuffle exceeds it)

    Raises
    ------
    ValueError
        if `phase` and `spk_train` differ in length
    """
    phase = np.asarray(phase)
    spk_train = np.asarray(spk_train)
    # Validate before doing any work; unlike `assert`, this survives `python -O`
    if len(phase) != len(spk_train):
        raise ValueError(
            "phase and spk_train must have the same length ({} != {})".format(
                len(phase), len(spk_train)
            )
        )
    # Circular ranks shifted from [0, 2*pi) to [-pi, pi)
    ranks = phase2rank(phase) - np.pi
    # Get phase ranks where spikes occur
    spk_ranks = ranks[spk_train == 1]
    # Compute modified mean vector length
    r = modified_mrl2(spk_ranks)
    # Compute mean rank, then undo the -pi shift and convert back to phase
    _, mean_rank = circmean(spk_ranks)
    theta = rank2phase(mean_rank + np.pi, phase)
    # Null distribution: tuning strength of bin-shuffled spike trains
    r_shf = np.full(n_shuffles, np.nan)
    for shf_i in range(n_shuffles):
        spk_train_shf = shuffle_bins(spk_train, shuffle_binwidth)
        r_shf[shf_i] = modified_mrl2(ranks[spk_train_shf == 1])
    p = (r_shf > r).sum() / n_shuffles
    return r, theta, p
if __name__ == "__main__":
    # Time-range modes accepted on the command line.
    # NOTE(review): the 'opto', 'half1/2' and 'split1/2' branches below are
    # unreachable because they are not listed here ('opto' would additionally
    # need df_trials loaded) — confirm whether they should be enabled.
    valid_tranges = ['run', 'sit', 'desync', 'sizematched', 'nosaccade', 'noopto']

    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
    parser.add_argument('-t', '--tranges', default='')
    args = parser.parse_args()
    if args.tranges and args.tranges not in valid_tranges:
        # parser.error exits with a usage message; an assert would vanish under -O
        parser.error("invalid --tranges {!r}; choose from {}".format(args.tranges, valid_tranges))

    df_pupil = load_data('pupil', [args.e_name])
    df_pupil.set_index(['m', 's', 'e'], inplace=True)
    df_spikes = load_data('spikes', [args.e_name])
    df_spikes.set_index(['m', 's', 'e'], inplace=True)
    df_spikes = filter_units(df_spikes, MINRATE)
    # TODO: find a better way to integrate saccades
    if 'saccade' in args.spk_types:
        # Give every unit its experiment's saccade times so that 'saccade' can be
        # handled like any other '<spk_type>_times' column below
        df_spikes['saccade_times'] = [df_pupil.loc[idx]['saccade_times'] for idx in df_spikes.index]

    # Load data for requested time ranges. 'nosaccade' reads saccade_times from
    # df_pupil, which is already loaded above, so it needs no extra load.
    if args.tranges in ['desync', 'sizematched']:
        df_hht = load_data('hht', [args.e_name]).set_index(['m', 's', 'e'])
    elif args.tranges in ['run', 'sit']:
        df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
    elif args.tranges in ['noopto']:
        df_trials = load_data('trials', [args.e_name])
        df_trials.rename(columns={'trial_on_time': 'trial_on_times', 'trial_off_time': 'trial_off_times'}, inplace=True)
        df_trials = df_trials.apply(get_trials, stim_id=-1, axis='columns').set_index(['m', 's', 'e'])

    seriess = []
    for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
        pupil_area = row['pupil_area']
        pupil_tpts = row['pupil_tpts']
        t0, t1 = pupil_tpts.min(), pupil_tpts.max()

        # Get IMFs
        pupil_fs = 1 / np.diff(pupil_tpts).mean()
        hht = HHT(pupil_area, pupil_fs)
        hht.emd()
        # Get phases and frequencies
        hht.hsa()
        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
        imf_phases = hht.phase.T  # iterated below as one phase series per IMF
        imf_freqs = hht.characteristic_frequency
        imf_power = hht.power_ratio

        # Get time-ranges: `tranges` becomes a sequence indexed by IMF whose
        # entries are iterables of [start, stop] time pairs
        if args.tranges in ['run', 'sit']:
            try:
                tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
            except KeyError:
                print("No run data found for ", idx)
                continue
            # Offset each bout's start by dt0 and its stop by dt1 (behavioural
            # exclusion windows)
            if args.tranges == 'run':
                dt0 = BEHAVEXCLUSIONS['run'][1]
                dt1 = BEHAVEXCLUSIONS['sit'][0]
            else:
                dt0 = BEHAVEXCLUSIONS['sit'][1]
                dt1 = BEHAVEXCLUSIONS['run'][0]
            tranges = np.asarray(tranges, dtype=float).reshape(-1, 2)
            tranges = tranges + np.array([dt0, dt1])
            # Drop bouts the offsets made empty or inverted. Boolean masking keeps
            # a (0, 2) array when nothing survives; np.row_stack([]) raised here.
            tranges = tranges[tranges[:, 0] < tranges[:, 1]]
            tranges = [tranges] * hht.n_imfs
        elif args.tranges in ['desync', 'sizematched']:
            try:
                tranges = df_hht.loc[idx, '%s_bouts' % args.tranges]
            except KeyError:
                print("No HHT data found for ", idx)
                continue
        elif args.tranges in ['nosaccade']:
            saccade_times = df_pupil.loc[idx, 'saccade_times']
            # Exclusion window around every saccade, merged where they overlap
            saccade_tranges = np.column_stack([saccade_times, saccade_times])
            saccade_tranges = saccade_tranges + np.array(BEHAVEXCLUSIONS['saccade'])
            saccade_tranges = merge_ranges(saccade_tranges, dt=(1 / pupil_fs))
            # Ranges between the exclusion windows within the recording
            # (switch_ranges presumably returns the complement; see util)
            tranges = switch_ranges(
                saccade_tranges,
                dt=(1 / pupil_fs),
                minval=pupil_tpts.min(),
                maxval=pupil_tpts.max()
            )
            tranges = [tranges] * hht.n_imfs
        elif args.tranges in ['opto', 'noopto']:
            try:
                # trial_ids / opto_trials are unused, but reading them makes a
                # missing column skip this experiment, as before
                trial_ids = df_trials.loc[idx, 'trial_id']
                opto_trials = df_trials.loc[idx, 'opto_trials']
                trial_on_time = df_trials.loc[idx, 'trial_on_times']
                trial_off_time = df_trials.loc[idx, 'trial_off_times']
            except KeyError:
                print("No trial data found for ", idx)
                continue
            # NOTE(review): 'opto' and 'noopto' produce identical ranges here;
            # any trial selection must already have happened in get_trials
            tranges = np.column_stack([trial_on_time, trial_off_time])
            tranges = [tranges] * hht.n_imfs
        elif args.tranges in ['half1', 'half2']:
            half_length = (t1 - t0) / 2
            if args.tranges == 'half1':
                tranges = [np.array([[t0, t0 + half_length]]) for imf in range(hht.n_imfs)]
            else:
                tranges = [np.array([[t0 + half_length, t1]]) for imf in range(hht.n_imfs)]
        elif args.tranges in ['split1', 'split2']:
            # Cycle boundaries are where an IMF's phase wraps by more than -pi
            imf_cycles = [pupil_tpts[np.where(np.diff(phase) < -np.pi)[0]] for phase in imf_phases]
            imf_cycles = [np.concatenate([pupil_tpts[:1], cycles, pupil_tpts[-1:]]) for cycles in imf_cycles]
            cycle_tranges = [np.column_stack([cycles[:-1], cycles[1:]]) for cycles in imf_cycles]
            # Alternate cycles go to the two splits
            if args.tranges == 'split1':
                tranges = [cycles[0::2] for cycles in cycle_tranges]
            else:
                tranges = [cycles[1::2] for cycles in cycle_tranges]

        # Get units for this experiment
        try:
            df_units = df_spikes.loc[idx]
        except KeyError:
            print("Spikes missing for {}".format(idx))
            continue

        for _, unit in df_units.iterrows():
            # Unit time base, restricted to the span of the pupil recording
            unit_tpts = np.arange(*unit['spk_tinfo'])
            i0, i1 = unit_tpts.searchsorted([t0, t1])
            unit_tpts = unit_tpts[i0:i1]
            imf_phases_resamp = resample(imf_phases, pupil_tpts, unit_tpts, axis=1)
            spk_trains = {
                spk_type: times2train(unit['{}_times'.format(spk_type)], unit_tpts)
                for spk_type in args.spk_types
            }
            # Shuffle bin width in unit samples (same for every IMF)
            unit_fs = 1 / np.diff(unit_tpts).mean()
            binwidth = np.floor(unit_fs * SHUFFLE_BINWIDTH).astype('int')

            for imf_i, phase in enumerate(imf_phases_resamp):
                data = {
                    'm': idx[0],
                    's': idx[1],
                    'e': idx[2],
                    'u': unit['u'],
                    'imf': imf_i + 1,
                    'freq': imf_freqs[imf_i],
                    'power': imf_power[imf_i]
                }
                # Get time ranges to analyze (whole recording by default)
                if args.tranges:
                    tranges_imf = tranges[imf_i]
                else:
                    tranges_imf = np.array([[t0, t1]])
                iranges_imf = unit_tpts.searchsorted(tranges_imf)

                for spk_type, spk_train in spk_trains.items():
                    # Take only data in ranges
                    if len(iranges_imf) > 0:
                        phase_clipped = np.concatenate([phase[start:stop] for start, stop in iranges_imf])
                        train_clipped = np.concatenate([spk_train[start:stop] for start, stop in iranges_imf])
                    else:
                        phase_clipped = np.array([])
                        train_clipped = np.array([])
                    # Check that there are enough spikes to do analysis
                    nspikes = train_clipped.sum()
                    if nspikes < NSPIKES:
                        r = theta = p = np.nan
                    else:
                        r, theta, p = phase_tuning(
                            phase_clipped, train_clipped,
                            shuffle_binwidth=binwidth, n_shuffles=NSHUFFLES
                        )
                    data['_'.join([spk_type, 'n'])] = nspikes
                    data['_'.join([spk_type, 'strength'])] = r
                    data['_'.join([spk_type, 'phase'])] = theta
                    data['_'.join([spk_type, 'p'])] = p
                seriess.append(pd.Series(data=data))

    df_tuning = pd.DataFrame(seriess)
    if args.tranges:
        filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, args.tranges)
    elif args.spk_types == ['saccade']:
        filename = 'phasetuning_{}_{}.pkl'.format(args.e_name, 'saccades')
    else:
        filename = 'phasetuning_{}.pkl'.format(args.e_name)
    # os.path.join works whether or not DATAPATH ends with a separator
    df_tuning.to_pickle(os.path.join(DATAPATH, filename))
|