123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import sys, os
- sys.path.append(os.path.join(os.getcwd(), '..'))
- sys.path.append(os.path.join(os.getcwd(), '..', '..'))
- from imports import *
- from target import build_tgt_matrix
- import pandas as pd
- def load_session_data(session, load_units=True, load_aeps=True, load_moseq=True):
- all_areas = ['A1', 'PPC', 'HPC']
- animal = session.split('_')[0]
- sessionpath = os.path.join(source, animal, session)
- h5_file = os.path.join(sessionpath, session + '.h5')
- aeps_file = os.path.join(sessionpath, 'AEPs.h5')
- moseq_file = os.path.join(sessionpath, 'moseq.h5')
- report_path = os.path.join(report, 'PSTH', session)
- if not os.path.exists(report_path):
- os.makedirs(report_path)
-
- # load timeline and configuration
- with h5py.File(h5_file, 'r') as f:
- tl = np.array(f['processed']['timeline']) # time, X, Y, speed, etc.
- trials = np.array(f['processed']['trial_idxs']) # t_start_idx, t_end_idx, x_tgt, y_tgt, r_tgt, result
- cfg = json.loads(f['processed'].attrs['parameters'])
- # load units
- unit_names, single_units, spike_times = [], {}, {}
- if load_units:
- with h5py.File(h5_file, 'r') as f:
- unit_names = [x for x in f['units']]
- with h5py.File(h5_file, 'r') as f:
- for unit_name in unit_names:
- spike_times[unit_name] = np.array(f['units'][unit_name][H5NAMES.spike_times['name']])
- single_units[unit_name] = np.array(f['units'][unit_name][H5NAMES.inst_rate['name']])
- #single_units[unit_name] = instantaneous_rate(unit_times, tl[:, 0], k_width=50)
-
- # load AEPs
- areas, aeps, aeps_events, lfp = [], {}, [], {}
- AEP_metrics_lims, AEP_metrics_raw, AEP_metrics_norm = {}, {}, {}
- tgt_matrix = []
-
- if load_aeps:
- with h5py.File(aeps_file, 'r') as f:
- for area in all_areas:
- if not area in f:
- continue
- aeps[area] = np.array(f[area]['aeps'])
- aeps_events = np.array(f['aeps_events'])
- areas = list(aeps.keys())
- # TODO find better way. Remove outliers
- if 'A1' in areas:
- aeps['A1'][aeps['A1'] > 5000] = 5000
- aeps['A1'][aeps['A1'] < -5000] = -5000
- if 'PPC' in areas:
- aeps['PPC'][aeps['PPC'] > 1500] = 1500
- aeps['PPC'][aeps['PPC'] < -1500] = -1500
- if 'HPC' in areas:
- aeps['HPC'][aeps['HPC'] > 1500] = 1500
- aeps['HPC'][aeps['HPC'] < -1500] = -1500
- aeps[areas[0]].shape
- # load LFP
- lfp = {}
- with h5py.File(aeps_file, 'r') as f:
- for area in areas:
- if 'LFP' in f[area]:
- lfp[area] = np.array(f[area]['LFP'])
- # load AEP metrics
- AEP_metrics_lims = dict([(area, {}) for area in areas])
- AEP_metrics_raw = dict([(area, {}) for area in areas])
- AEP_metrics_norm = dict([(area, {}) for area in areas])
- with h5py.File(aeps_file, 'r') as f:
- for area in areas:
- grp = f[area]
- for metric_name in grp['raw']:
- AEP_metrics_raw[area][metric_name] = np.array(grp['raw'][metric_name])
- AEP_metrics_norm[area][metric_name] = np.array(grp['norm'][metric_name])
- AEP_metrics_lims[area][metric_name] = [int(x) for x in grp['raw'][metric_name].attrs['limits'].split(',')]
- # build target matrix
- tgt_matrix = build_tgt_matrix(tl, trials, aeps_events)
- # load moseq
- moseq = []
- if load_moseq:
- with h5py.File(moseq_file, 'r') as f:
- moseq_matrix = np.array(f['moseq'])
- moseq_headers = f['moseq'].attrs['headers']
- moseq_headers = moseq_headers.split(',')
- moseq_headers = [moseq_headers[0]] + [x[1:] for x in moseq_headers[1:]]
- moseq = pd.DataFrame(moseq_matrix, columns=moseq_headers)
-
- return {
- 'tl': tl,
- 'trials': trials,
- 'cfg': cfg,
- 'areas': areas,
- 'aeps': aeps,
- 'aeps_events': aeps_events,
- 'lfp': lfp,
- 'AEP_metrics_lims': AEP_metrics_lims,
- 'AEP_metrics_raw': AEP_metrics_raw,
- 'AEP_metrics_norm': AEP_metrics_norm,
- 'tgt_matrix': tgt_matrix,
- 'single_units': single_units,
- 'spike_times': spike_times,
- 'unit_names': unit_names,
- 'animal': animal,
- 'aeps_file': aeps_file,
- 'moseq_file': moseq_file,
- 'h5_file': h5_file,
- 'report_path': report_path,
- 'moseq': moseq
- }
|