"""Cross-correlate pupil-size IMFs with running speed.

For each experiment in the pupil dataframe, this script:
  1. decomposes the pupil-area trace into IMFs via the Hilbert-Huang
     transform (EMD + HSA),
  2. cross-correlates the raw trace and each IMF with running speed
     resampled onto the pupil time base,
  3. builds a null distribution of Pearson correlations against randomly
     drawn run-speed traces, and
  4. saves per-IMF statistics to ``imfcorr_<e_name>[_<tranges>].pkl``.

Usage: <script> e_name [-t {run,sit}]
"""
import argparse

import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from tqdm import tqdm

from hht import HHT
from parameters import DATAPATH, NIMFCYCLES, NSHUFFLES
from util import interpolate, load_data, normalized_xcorr

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    parser.add_argument('-t', '--tranges', default='')
    args = parser.parse_args()

    df_pupil = load_data('pupil', [args.e_name]).set_index(['m', 's', 'e'])
    df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
    df_all = pd.read_pickle(DATAPATH + 'run.pkl')

    # TODO: move to parameters
    max_lag = 30  # seconds; half-width of the cross-correlation window

    seriess = []
    for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
        pupil_area = row['pupil_area']
        pupil_tpts = row['pupil_tpts']
        pupil_dt = np.diff(pupil_tpts).mean()
        pupil_fs = 1 / pupil_dt
        t0, t1 = pupil_tpts.min(), pupil_tpts.max()

        # Prepare run data: clip to the pupil recording window and
        # resample onto the pupil time base.
        try:
            run_speed, run_tpts = df_run.loc[idx, ['run_speed', 'run_tpts']]
        except KeyError:
            print("No run data found for ", idx)
            continue
        i0, i1 = run_tpts.searchsorted([t0, t1])
        run_speed = interpolate(run_speed[i0:i1], run_tpts[i0:i1], pupil_tpts)

        # Get IMFs
        hht = HHT(pupil_area, pupil_fs)
        hht.emd()
        hht.hsa()
        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
        imfs = hht.imfs.T
        imf_freqs = hht.characteristic_frequency
        imf_power = hht.power_ratio

        if args.tranges:
            # Restrict the analysis to behavioral bouts, trimming the bout
            # edges (in seconds) to avoid transition periods.
            tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
            if args.tranges == 'run':
                ext = np.ones_like(tranges) * np.array([2, -2])
            elif args.tranges == 'sit':
                ext = np.ones_like(tranges) * np.array([4, -2])
            else:
                # Previously fell through to a NameError on `ext`; fail loudly.
                raise ValueError("Unrecognized tranges: %s" % args.tranges)
            tranges = tranges + ext
            # Keep only bouts that remain non-empty after trimming.
            # (np.vstack replaces np.row_stack, deprecated in NumPy 2.0.)
            tranges = np.vstack([trange for trange in tranges if trange[0] < trange[1]])
            iranges = pupil_tpts.searchsorted(tranges)
            imfs = np.column_stack([imfs[:, j0:j1] for j0, j1 in iranges])
            pupil_area = np.concatenate([pupil_area[j0:j1] for j0, j1 in iranges])
            run_speed = np.concatenate([run_speed[j0:j1] for j0, j1 in iranges])

        # Row 0 is the raw pupil-area trace; rows 1.. are the IMFs.
        for i, imf in enumerate(np.vstack([pupil_area, imfs])):
            data = {
                'm': idx[0],
                's': idx[1],
                'e': idx[2],
                'imf': i
                }
            if i == 0:
                data['freq'] = data['power'] = np.nan
            else:
                data['freq'] = imf_freqs[i - 1]
                data['power'] = imf_power[i - 1]

            xcorr, lags = normalized_xcorr(imf, run_speed, dt=pupil_dt, ts=[-1 * max_lag, max_lag])
            data['xcorr'] = xcorr
            data['xcorr_lags'] = lags

            # Null distribution: Pearson r against randomly drawn run-speed
            # traces, resampled to the pupil sampling interval.
            r_null = np.full(NSHUFFLES, np.nan)
            j = 0
            while j < NSHUFFLES:
                tpts, signal = df_all.iloc[np.random.choice(len(df_all))]
                signal = interpolate(signal, tpts, np.arange(tpts.min(), tpts.max(), pupil_dt))
                i_max = min(len(imf), len(signal))
                if len(np.unique(signal[:i_max])) == 1:
                    # Constant signal: pearsonr is undefined, redraw.
                    continue
                r_null[j], _ = pearsonr(imf[:i_max], signal[:i_max])
                j += 1

            # Get search window for peak: full window for the raw trace,
            # +/- one characteristic period for each IMF.
            if i == 0:
                i0, i1 = 0, len(xcorr)
            else:
                T = 1 / imf_freqs[i - 1]
                i0, i1 = lags.searchsorted([-T, T])
            # Find peak (largest absolute correlation within the window)
            peak_idx = np.abs(xcorr[i0:i1]).argmax()
            data['xcorr_peak'] = lags[i0:i1][peak_idx]
            # Compare to null distribution (two-tailed at alpha = 0.05)
            xcorr_max = xcorr[i0:i1][peak_idx]
            p = (r_null > xcorr_max).sum() / NSHUFFLES
            data['xcorr_p'] = p
            data['xcorr_sig'] = (p < 0.025) | (p > 0.975)
            seriess.append(pd.Series(data=data))

    df_corr = pd.DataFrame(seriess)
    if args.tranges:
        filename = 'imfcorr_{}_{}.pkl'.format(args.e_name, args.tranges)
    else:
        filename = 'imfcorr_{}.pkl'.format(args.e_name)
    df_corr.to_pickle(DATAPATH + filename)