import argparse
import os
import sys

import numpy as np
import pandas as pd
import quantities as pq
from elephant.kernels import GaussianKernel
from elephant.statistics import instantaneous_rate
from neo import SpikeTrain
from tqdm import tqdm

from parameters import DATAPATH, MINRATE, NSPIKES, TRIGGEREDAVERAGES, NSHUFFLES, SHUFFLE_BINWIDTH
from util import load_data, filter_units, shuffle_bins
if __name__ == "__main__":
    # Compute event-triggered firing-rate averages for each unit in one
    # experiment/region, test them against bin-shuffled null distributions,
    # and pickle the per-unit results to DATAPATH.

    # Command-line options: experiment name, brain region, and which trigger
    # events / spike types to analyze.
    parser = argparse.ArgumentParser()
    parser.add_argument('e_name')
    parser.add_argument('region')
    parser.add_argument('-t', '--triggers', nargs='+', default=['saccade', 'run', 'sit'])
    parser.add_argument('-s', '--spk_types', nargs='+', default=['tonicspk', 'burst'])
    args = parser.parse_args()

    # Load running-bout data; column 0 of each bout is its start (run onset),
    # column 1 its end (sit onset).
    df_run = load_data('ball', [args.e_name], region=args.region)
    df_run['run_times'] = df_run['run_bouts'].apply(lambda x: x[:, 0])
    df_run['sit_times'] = df_run['run_bouts'].apply(lambda x: x[:, 1])

    # Merge in pupil data (provides saccade times).
    df_pupil = load_data('pupil', [args.e_name], region=args.region)
    df_triggers = pd.merge(df_run, df_pupil)

    # Optionally merge in trial-onset times.
    if 'trial_on' in args.triggers:
        df_trials = load_data('trials', [args.e_name], region=args.region)
        df_trials.rename(columns={'trial_on_time': 'trial_on_times'}, inplace=True)
        df_triggers = pd.merge(df_triggers, df_trials)

    # Burst-triggered averages are not implemented: bail out early
    # (SystemExit with no argument -> exit status 0, same as the original
    # inline `import sys; sys.exit()`).
    if 'burst' in args.triggers:
        sys.exit()

    df_triggers.set_index(['m', 's', 'e'], inplace=True)

    # Load spike data and drop units firing below MINRATE.
    df_spikes = load_data('spikes', [args.e_name], region=args.region).set_index(['m', 's', 'e'])
    df_spikes = filter_units(df_spikes, MINRATE)

    seriess = []
    for idx, unit in tqdm(df_spikes.iterrows(), total=len(df_spikes)):
        data = {
            'm': idx[0],
            's': idx[1],
            'e': idx[2],
            'u': unit['u'],
        }
        for trigger in args.triggers:
            # NOTE(review): raises KeyError when this experiment lacks the
            # trigger column — a commented-out try/except in the original
            # suggested skipping instead; confirm the desired behavior.
            trigger_times = df_triggers.loc[idx][f'{trigger}_times']
            pars = TRIGGEREDAVERAGES[trigger]
            fs = pars['dt'] * pq.s
            kernel = GaussianKernel(pars['bw'] * pq.s)
            # Loop-invariant: hoisted out of the per-shuffle loop below.
            shuffle_binwidth = int(SHUFFLE_BINWIDTH / pars['dt'])
            for spk_type in args.spk_types:
                # Smoothed instantaneous rate over the whole experiment.
                spk_times = unit[f'{spk_type}_times']
                if len(spk_times) < NSPIKES:
                    continue
                t0 = min([unit['spk_tinfo'][0], spk_times.min()])
                t1 = max([unit['spk_tinfo'][1], spk_times.max()])
                spk_train = SpikeTrain(spk_times, t_start=t0 * pq.s, t_stop=t1 * pq.s, units='s')
                inst_rate = instantaneous_rate(spk_train, sampling_period=fs, kernel=kernel)
                inst_rate = inst_rate.squeeze().magnitude

                # Keep only triggers whose full [pre, post] window fits within
                # the recording (pars['pre'] is presumably negative — confirm).
                spk_tpts = np.linspace(t0, t1, inst_rate.shape[0])
                trigger_times = trigger_times[trigger_times < (spk_tpts.max() - pars['post'])]
                trigger_times = trigger_times[trigger_times > (spk_tpts.min() - pars['pre'])]

                # Slice a rate snippet around each trigger (searchsorted is
                # computed once instead of twice as in the original).
                trigger_idxs = spk_tpts.searchsorted(trigger_times)
                i0s = trigger_idxs + int(pars['pre'] / pars['dt'])
                i1s = trigger_idxs + int(pars['post'] / pars['dt'])
                # np.vstack replaces np.row_stack, which was removed in NumPy 2.0.
                responses = np.vstack([inst_rate[i0:i1] for i0, i1 in zip(i0s, i1s)])

                # Baseline-normalize each response by its mean in the baseline window.
                response_tpts = np.linspace(pars['pre'], pars['post'], responses.shape[1])
                b0, b1 = response_tpts.searchsorted(pars['baseline'])
                responses = (responses.T - responses[:, b0:b1].mean(axis=1)).T
                triggered_average = responses.mean(axis=0)

                # Build a null distribution of triggered averages from
                # bin-shuffled rates; flag significance when the observed
                # average exits the shuffled 95% CI anywhere.
                triggered_average_shf = np.full((NSHUFFLES, triggered_average.shape[0]), np.nan)
                for shf_i in range(NSHUFFLES):
                    inst_rate_shf = shuffle_bins(inst_rate, shuffle_binwidth)
                    responses_shf = np.vstack([inst_rate_shf[i0:i1] for i0, i1 in zip(i0s, i1s)])
                    responses_shf = (responses_shf.T - responses_shf[:, b0:b1].mean(axis=1)).T
                    triggered_average_shf[shf_i] = responses_shf.mean(axis=0)
                ci_low, ci_high = np.percentile(triggered_average_shf, [2.5, 97.5], axis=0)
                sig = (triggered_average < ci_low).any() | (triggered_average > ci_high).any()

                data[f'{trigger}_{spk_type}_response'] = triggered_average
                data[f'{trigger}_{spk_type}_tpts'] = response_tpts
                data[f'{trigger}_{spk_type}_sig'] = sig
        seriess.append(pd.Series(data=data))

    # Collect per-unit results and write them to disk.
    df_resp = pd.DataFrame(seriess)
    filename = f'responses_{args.e_name}_{args.region}.pkl'
    df_resp.to_pickle(DATAPATH + filename)