In [1]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '..'))

import numpy as np
import h5py
import json
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import signal
from scipy import stats
from target import build_tgt_matrix
import pandas as pd

%matplotlib inline

In [2]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
 return false;
}



In [31]:
# Data / report roots. Defaults are the original hard-coded locations; set the
# environment variables to run the notebook on another machine.
source = os.environ.get('AEP_DATA_DIR', '/home/sobolev/nevermind/Andrey/data')
report = os.environ.get('AEP_REPORT_DIR', '/home/sobolev/nevermind/Andrey/analysis/PSTH')

selected_sessions = [
    '009266_hippoSIT_2023-04-17_17-04-17',  # ch17, 20 + 55 correction, 5067 events. Showcase for N2 / N3 mod in target
    '009266_hippoSIT_2023-04-18_10-10-37',  # ch17, 10 + 55 correction, 5682 events
    '009266_hippoSIT_2023-04-18_17-03-10',  # ch17, 6 + 55 correction, 5494 events: FIXME very weird 1-2nd in target, find out
    '009266_hippoSIT_2023-04-19_10-33-51',  # ch17, 4 + 55 correction, 6424 events: FIXME very weird 1-2nd in target, find out
    '009266_hippoSIT_2023-04-20_08-57-39',  # ch1, 1 + 55 correction, 6424 events. Showcase for N2 / N3 mod in target
    '009266_hippoSIT_2023-04-24_16-56-55',  # ch17, 5 + 55* correction, 6165 events, frequency
    '009266_hippoSIT_2023-04-26_08-20-17',  # ch17, 12 + 55* correction, 6095 events, duration - showcase for N2
    '009266_hippoSIT_2023-05-02_12-22-14',  # ch20, 10 + 55 correction, 5976 events, FIXME very weird 1-2nd in target, find out
    '009266_hippoSIT_2023-05-04_09-11-06',  # ch17, 5 + 55* correction, 4487 events, coma session with baseline AEPs
    '009266_hippoSIT_2023-05-04_19-47-15',  # ch20, 2 + 55 correction, 5678 events, duration
]

session_idx = 0  # which entry of selected_sessions this notebook analyses
session = selected_sessions[session_idx]

# The first '_'-separated field of the session name is the animal folder.
animal = session.split('_')[0]
sessionpath = os.path.join(source, animal, session)
aeps_file = os.path.join(sessionpath, 'AEPs.h5')
h5name = os.path.join(sessionpath, session + '.h5')
report_path = os.path.join(report, session)

# exist_ok avoids the exists()/makedirs() race and makes the cell idempotent.
os.makedirs(report_path, exist_ok=True)

In [32]:
AEP_CLIP = 5000  # AEP samples are clipped to [-AEP_CLIP, AEP_CLIP] (crude outlier removal)

# Session timeline, trial table and experiment parameters.
with h5py.File(h5name, 'r') as f:
    tl = np.array(f['processed']['timeline'])  # time, X, Y, speed, etc.
    trials = np.array(f['processed']['trial_idxs'])  # t_start_idx, t_end_idx, x_tgt, y_tgt, r_tgt, result
    cfg = json.loads(f['processed'].attrs['parameters'])

# AEP waveforms, AEP event table and per-event metrics all live in the same
# file, so it is opened only once.
AEP_metrics_lims = {}
AEP_metrics_raw = {}
AEP_metrics_norm = {}
with h5py.File(aeps_file, 'r') as f:
    aeps = np.array(f['aeps'])
    aeps_events = np.array(f['aeps_events'])
    for metric_name in f['raw']:
        AEP_metrics_raw[metric_name] = np.array(f['raw'][metric_name])
        AEP_metrics_norm[metric_name] = np.array(f['norm'][metric_name])
        # 'limits' attribute is a comma-separated list of ints
        AEP_metrics_lims[metric_name] = [int(x) for x in f['raw'][metric_name].attrs['limits'].split(',')]

# TODO find better way. Remove outliers (in place; NaNs, if any, are left as NaN)
np.clip(aeps, -AEP_CLIP, AEP_CLIP, out=aeps)

tgt_dur = cfg['experiment']['target_duration']
tgt_matrix = build_tgt_matrix(tl, aeps_events, tgt_dur)

aeps.shape, tgt_matrix.shape

((5067, 200), (73, 5))

## Is performance dependent on AEP states?

In [33]:
# separate states based on AEP metrics - metric mean before entering
def compute_state_idxs(metric_name, n_pulses=10, metrics=None, targets=None):
    """Split target entries into low / high AEP-metric states.

    For each row of the target matrix, the metric is averaged over the
    ``n_pulses`` AEP events immediately preceding index ``row[2]`` (read here
    as the AEP-event index of the target entry -- TODO confirm against
    ``build_tgt_matrix``). That window mean is compared with the metric's mean
    over all events: strictly greater -> "high", otherwise -> "low".

    Parameters
    ----------
    metric_name : str
        Key into ``metrics`` (e.g. 'P1').
    n_pulses : int
        Number of preceding AEP events to average; must be >= 1.
    metrics : dict of str -> 1-D numpy array, optional
        Metric value per AEP event. Defaults to the notebook-level
        ``AEP_metrics_norm``.
    targets : 2-D array, optional
        Target-entry matrix. Defaults to the notebook-level ``tgt_matrix``.

    Returns
    -------
    (idxs_aeps_low, idxs_aeps_high) : tuple of two lists of int
        Row indices into ``targets``. Entries with fewer than ``n_pulses``
        preceding events appear in neither list.

    Raises
    ------
    ValueError
        If ``n_pulses`` < 1 (the averaging window would be empty).
    """
    if metrics is None:
        metrics = AEP_metrics_norm
    if targets is None:
        targets = tgt_matrix
    if n_pulses < 1:
        raise ValueError(f"n_pulses must be >= 1, got {n_pulses}")

    metric = metrics[metric_name]
    metric_mean = metric.mean()

    idxs_aeps_high, idxs_aeps_low = [], []  # indices to targets rows
    for i, tgt_enter in enumerate(targets):
        entry_idx = int(tgt_enter[2])  # int() so a float-typed matrix can still slice
        if entry_idx - n_pulses < 0:
            continue  # not enough AEP history before this entry
        metric_inst = metric[entry_idx - n_pulses:entry_idx].mean()
        # strict '>' keeps ties (and NaN) in the low group
        if metric_inst > metric_mean:
            idxs_aeps_high.append(i)
        else:
            idxs_aeps_low.append(i)

    return idxs_aeps_low, idxs_aeps_high

In [34]:
# test if the animal will stay in the island depending on high / low AEP metrics
#
# For each metric, two naive rules are scored against tgt_matrix[:, 4]
# (assumed to be a 0/1 outcome code per target entry -- TODO confirm):
#   high_low: predict 1 when the pre-entry metric is above its session mean, else 0
#   low_high: predict 1 when it is at/below the mean, else 0
# Accuracy is computed over *classified* entries only. Entries skipped by
# compute_state_idxs (fewer than n_pulses preceding AEP events) would otherwise
# count as "predicted 0" under both rules and bias both scores.

n_pulses = 8  # number of AEP events before target entry to average over
predictions = {}  # metric -> (accuracy of high_low rule, accuracy of low_high rule)

actual = tgt_matrix[:, 4]  # loop-invariant: observed outcome per target entry

for metric_name in AEP_metrics_norm.keys():
    idxs_aeps_low, idxs_aeps_high = compute_state_idxs(metric_name, n_pulses)

    classified = np.zeros(len(tgt_matrix), dtype=bool)
    classified[idxs_aeps_low] = True
    classified[idxs_aeps_high] = True
    n_classified = classified.sum()

    predicted = np.zeros(len(tgt_matrix))
    predicted[idxs_aeps_high] = 1
    high_low = (predicted == actual)[classified].sum() / n_classified if n_classified else np.nan

    predicted = np.zeros(len(tgt_matrix))
    predicted[idxs_aeps_low] = 1
    low_high = (predicted == actual)[classified].sum() / n_classified if n_classified else np.nan

    predictions[metric_name] = (high_low, low_high)

predictions

{'N1': (0.6164383561643836, 0.3835616438356164),
 'P1': (0.5342465753424658, 0.4657534246575342),
 'P2': (0.6164383561643836, 0.3835616438356164),
 'P3': (0.3835616438356164, 0.6164383561643836)}