import argparse

import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from tqdm import tqdm

from parameters import NIMFCYCLES, NSHUFFLES, DATAPATH
from util import load_data, interpolate, normalized_xcorr
from hht import HHT
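
# Overview (inferred from the code below): for each pupil recording, decompose
# the pupil-area trace into IMFs via the Hilbert-Huang transform, correlate
# the raw trace and each IMF with running speed, and test the peak correlation
# against a shuffle-based null distribution.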

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    parser.add_argument('-t', '--tranges', default='')
    args = parser.parse_args()
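
    # Load pupil and ball (running) data for this experiment, indexed by the
    # (m, s, e) identifier columns (presumably mouse / session / experiment);
    # df_all pools run traces across recordings for the shuffle null below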
    df_pupil = load_data('pupil', [args.e_name]).set_index(['m', 's', 'e'])
    df_run = load_data('ball', [args.e_name]).set_index(['m', 's', 'e'])
    df_all = pd.read_pickle(DATAPATH + 'run.pkl')

    # TODO: move to parameters
    max_lag = 30  # seconds

    seriess = []
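
    # Correlate the raw pupil trace and each pupil IMF with running speed,
    # one recording at a time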
    for idx, row in tqdm(df_pupil.iterrows(), total=len(df_pupil)):
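
        # Pull the pupil trace and estimate its sampling rate from the
        # timestamps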
        pupil_area = row['pupil_area']
        pupil_tpts = row['pupil_tpts']
        pupil_dt = np.diff(pupil_tpts).mean()
        pupil_fs = 1 / pupil_dt
        t0, t1 = pupil_tpts.min(), pupil_tpts.max()

        # Prepare run data, skipping recordings without a matching run trace
        try:
            run_speed, run_tpts = df_run.loc[idx, ['run_speed', 'run_tpts']]
        except KeyError:
            print("No run data found for", idx)
            continue
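
        # Clip the run trace to the pupil recording window and resample it
        # onto the pupil timebase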
        i0, i1 = run_tpts.searchsorted([t0, t1])
        run_speed = interpolate(run_speed[i0:i1], run_tpts[i0:i1], pupil_tpts)

        # Get IMFs: empirical mode decomposition, then Hilbert spectral
        # analysis; IMFs that fail the phase-bin visit criterion are dropped
        hht = HHT(pupil_area, pupil_fs)
        hht.emd()
        hht.hsa()
        hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
        imfs = hht.imfs.T
        imf_freqs = hht.characteristic_frequency
        imf_power = hht.power_ratio
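
        # Optionally restrict the analysis to running or sitting bouts,
        # trimming each bout's edges (2 s from the start and end of run bouts;
        # 4 s from the start and 2 s from the end of sit bouts), presumably to
        # exclude transition periods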
        if args.tranges:
            tranges = df_run.loc[idx, '%s_bouts' % args.tranges]
            if args.tranges == 'run':
                ext = np.ones_like(tranges) * np.array([2, -2])
            elif args.tranges == 'sit':
                ext = np.ones_like(tranges) * np.array([4, -2])
            else:
                raise ValueError("Unrecognized bout type: %s" % args.tranges)
            tranges = tranges + ext
            # Keep only bouts that remain non-empty after trimming
            tranges = np.vstack([trange for trange in tranges if trange[0] < trange[1]])
            iranges = pupil_tpts.searchsorted(tranges)
            imfs = np.column_stack([imfs[:, i0:i1] for i0, i1 in iranges])
            pupil_area = np.concatenate([pupil_area[i0:i1] for i0, i1 in iranges])
            run_speed = np.concatenate([run_speed[i0:i1] for i0, i1 in iranges])
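
        # Row 0 of the stack is the raw pupil trace; rows 1..n are its IMFs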
        for i, imf in enumerate(np.vstack([pupil_area, imfs])):
            data = {
                'm': idx[0],
                's': idx[1],
                'e': idx[2],
                'imf': i,
            }
            if i == 0:
                data['freq'] = data['power'] = np.nan
            else:
                data['freq'] = imf_freqs[i - 1]
                data['power'] = imf_power[i - 1]
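
            # Cross-correlate this component with running speed at lags up to
            # +/-max_lag seconds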
            xcorr, lags = normalized_xcorr(imf, run_speed, dt=pupil_dt, ts=[-max_lag, max_lag])
            data['xcorr'] = xcorr
            data['xcorr_lags'] = lags
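
            # Null distribution: correlate this component with run traces
            # drawn at random from the pooled dataframe and resampled to the
            # pupil dt; constant draws are skipped because pearsonr is
            # undefined for them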
            r_null = np.full(NSHUFFLES, np.nan)
            j = 0
            while j < NSHUFFLES:
                tpts, signal = df_all.iloc[np.random.choice(len(df_all))]
                signal = interpolate(signal, tpts, np.arange(tpts.min(), tpts.max(), pupil_dt))
                i_max = min(len(imf), len(signal))
                if len(np.unique(signal[:i_max])) == 1:
                    continue
                r_null[j], _ = pearsonr(imf[:i_max], signal[:i_max])
                j += 1

            # Get search window for the peak: within one characteristic
            # period of the IMF, or the full lag range for the raw trace
            if i == 0:
                i0, i1 = 0, len(xcorr)
            else:
                T = 1 / imf_freqs[i - 1]
                i0, i1 = lags.searchsorted([-T, T])
            # Find the lag and value of the largest absolute correlation
            i_peak = np.abs(xcorr[i0:i1]).argmax()
            xcorr_peak = lags[i0:i1][i_peak]
            data['xcorr_peak'] = xcorr_peak
            xcorr_max = xcorr[i0:i1][i_peak]
            # Compare to the null distribution: two-tailed test, significant
            # if the observed peak falls in either 2.5% tail of the shuffles
            p = (r_null > xcorr_max).sum() / NSHUFFLES
            data['xcorr_p'] = p
            data['xcorr_sig'] = (p < 0.025) | (p > 0.975)

            seriess.append(pd.Series(data=data))

    df_corr = pd.DataFrame(seriess)
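
    # The output filename encodes the experiment name and any bout restriction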
    if not args.tranges:
        filename = 'imfcorr_{}.pkl'.format(args.e_name)
    else:
        filename = 'imfcorr_{}_{}.pkl'.format(args.e_name, args.tranges)
    df_corr.to_pickle(DATAPATH + filename)