import os
import logging

import nixio as nix
import numpy as np
import pandas as pd
import rlxnix as rlx
from joblib import Parallel, delayed

from .baseline_response_properties import burst_fraction
from ..util import (detect_eod_times, firing_rate, get_temporal_shift,
                    read_file_list, spike_triggered_average)

# Only FileStimulus runs that used this white-noise stimulus file are analyzed.
stimulus_name = "gwn300Hz10s0.3.dat"


def noise_features(filestim_runs):
    """Extract driven-response features from FileStimulus repro runs.

    For every run that used the expected white-noise stimulus, one feature
    dict is collected per stimulus presentation (rate, rate modulation,
    contrast, cutoff frequency, ...). After all presentations of a run have
    been processed, the spike-triggered average of the pooled spikes is used
    to estimate the response delay and whether the response is inverted;
    these two values are written into the feature dicts of THAT run only.

    Parameters
    ----------
    filestim_runs : list
        rlxnix FileStimulus repro runs of one dataset.

    Returns
    -------
    list of dict
        One dict per valid stimulus presentation.
    """
    features = []
    stim = None
    stim_name = None
    stim_cutoff = None
    stim_times = None
    stim_ampls = None
    for i, f in enumerate(filestim_runs):
        f.stimulus_folder = "stimuli"
        if len(f.stimuli) < 2:
            continue
        if stimulus_name not in f.stimulus_filename:
            continue
        if f.stimulus_filename != stim:
            # New stimulus file: (re-)load the waveform and parse metadata.
            stim = f.stimulus_filename
            stim_ampls, stim_times = f.load_stimulus(stimulus_index=1)
            # Strip the directory part, whatever path separator was recorded.
            if "/" in stim:
                stim_name = stim.split("/")[-1]
            elif "\\" in stim:
                stim_name = stim.split("\\")[-1]
            else:
                stim_name = stim
            if "Hz" in stim_name:
                # e.g. "gwn300Hz10s0.3.dat" -> cutoff 300.0 Hz
                stim_cutoff = float(stim_name.split("Hz")[0].split("gwn")[-1])
        if stim_ampls is None or stim_times is None:
            continue
        stimulus_spikes = None
        run_start = len(features)  # index of this run's first feature dict
        for j, fs in enumerate(f.stimuli):
            if fs.duration < 1.0:
                continue
            spike_times = f.spikes(stimulus_index=j)
            if len(spike_times) < 5:
                logging.warning(f"Not enough spikes in stimulus {fs}")
                continue
            rate = firing_rate(spike_times, fs.duration, sigma=0.005)
            feature_dict = {"stimulus_index": f"{i}_{j}",
                            "start_time": fs.start_time,
                            "end_time": fs.stop_time,
                            "stimfile": stim_name,
                            "contrast": f.contrast[0],
                            "cutoff": stim_cutoff,
                            "duration": fs.duration,
                            "firing_rate": np.mean(rate),
                            "rate_modulation": np.std(rate),
                            "inverted": False,
                            "delay": 0.0}
            # Pool the spikes of all presentations for the STA estimate.
            if stimulus_spikes is None:
                stimulus_spikes = spike_times
            else:
                stimulus_spikes = np.append(stimulus_spikes, spike_times)
            features.append(feature_dict)
        if stimulus_spikes is None:
            # Every presentation of this run was skipped; no STA possible.
            continue
        sta_time, sta = spike_triggered_average(stimulus_spikes, stim_times,
                                                stim_ampls)
        delay, inverted = get_temporal_shift(sta_time, sta)
        # Annotate only the features of THIS run; earlier runs keep their
        # own delay/inversion estimates.
        for feat in features[run_start:]:
            feat["inverted"] = inverted
            feat["delay"] = delay
    return features


def get_baseline_features(dataset, baseline_df):
    """Return baseline firing rate, CV, and burstiness for one dataset.

    Three sources are tried in order: (1) the precomputed baseline table
    (datasets with receptive-field measurements), (2) a BaselineActivity
    repro in the dataset itself, (3) as a fallback for some old datasets,
    the spontaneous activity recorded before the first repro starts.

    Parameters
    ----------
    dataset : rlx.Dataset
        The open relacs/NIX dataset.
    baseline_df : pd.DataFrame
        Precomputed baseline properties with columns ``dataset_id``,
        ``firing_rate``, ``cv``, and ``burst_fraction``.

    Returns
    -------
    tuple
        (baserate, cv, burstiness); all None if nothing could be derived.
    """
    baserate, cv, burstiness = None, None, None
    baseline_runs = dataset.repro_runs("BaselineActivity")
    # NOTE: `in` on a bare Series would test the *index*, hence `.values`.
    if dataset.name in baseline_df.dataset_id.values:
        # Datasets for which we do have the receptive fields.
        selection = baseline_df.dataset_id == dataset.name
        baserate = baseline_df.firing_rate[selection].iloc[0]
        cv = baseline_df.cv[selection].iloc[0]
        burstiness = baseline_df.burst_fraction[selection].iloc[0]
    elif len(baseline_runs) > 0:
        # Datasets without receptive field measurements.
        baserate = baseline_runs[0].baseline_rate
        cv = baseline_runs[0].baseline_cv
        eod_frequency = None
        try:
            eod_frequency = baseline_runs[0].eod_frequency
        except Exception:
            # Some recordings lack the EOD-frequency metadata entirely.
            pass
        if eod_frequency is None:
            print("Detecting eod times manually...")
            eod, time = baseline_runs[0].eod()
            eod -= np.mean(eod)
            eod_times, _ = detect_eod_times(time, eod, .5 * np.max(eod))
            eod_frequency = len(eod_times) / baseline_runs[0].duration
        burstiness = burst_fraction(baseline_runs[0].spikes(),
                                    1. / eod_frequency)
    else:
        # Should only occur for some old datasets: estimate baseline from
        # the data recorded before the first repro started.
        logging.info(f"No baseline repro in dataset {dataset.name}. "
                     "Trying to fix this...")
        min_time = 9999
        for r in dataset.repro_runs():
            if r.start_time < min_time:
                min_time = r.start_time
        # Require at least 10 s of pre-repro data to be useful.
        if min_time > 10:
            spike_trace = ("spikes-1" if "spikes-1" in dataset.event_traces
                           else "Spikes-1")
            if spike_trace not in dataset.event_traces:
                return None, None, None
            spike_event_trace = dataset.event_traces[spike_trace]
            spike_times = spike_event_trace.data_array.get_slice(
                [0.0], [min_time], nix.DataSliceMode.Data)[:]
            baserate = len(spike_times) / min_time
            isis = np.diff(spike_times)
            cv = np.std(isis) / np.mean(isis)
            eod_trace = dataset.data_traces["EOD"]
            eod = eod_trace.data_array.get_slice(
                [0.0], [min_time], nix.DataSliceMode.Data)[:]
            time = np.array(eod_trace.data_array.dimensions[0].axis(len(eod)))
            eod_times, _ = detect_eod_times(time, eod, 0.5 * np.max(eod))
            eod_frequency = len(eod_times) / min_time
            burstiness = burst_fraction(spike_times, 1. / eod_frequency)
    return baserate, cv, burstiness


def get_features(dataset_name, data_folder, baseline_df):
    """Collect noise-driven response features of one dataset.

    Parameters
    ----------
    dataset_name : str
        Dataset id (file stem of the .nix file in ``data_folder``).
    data_folder : str
        Folder containing the .nix recordings.
    baseline_df : pd.DataFrame
        Precomputed baseline properties (see ``get_baseline_features``).

    Returns
    -------
    list of dict or None
        One feature dict per stimulus presentation, each augmented with the
        dataset's baseline properties; None if the dataset is skipped.
    """
    print(dataset_name)
    dataset = rlx.Dataset(os.path.join(data_folder, dataset_name + ".nix"))
    filestimulus_runs = dataset.repro_runs("FileStimulus")
    if len(filestimulus_runs) == 0:
        logging.error(f"Dataset {dataset_name} has no FileStimulus "
                      "recordings. Skipping dataset!")
        return None
    baserate, cv, burstiness = get_baseline_features(dataset, baseline_df)
    if baserate is None:
        # Without a baseline estimate the features cannot be normalized;
        # skip the dataset as the message states.
        logging.warning(f"Dataset {dataset.name} has no BaselineActivity "
                        "recording. Skipping dataset!")
        return None
    baseline_feats = {"dataset_id": dataset_name,
                      "baserate": baserate,
                      "cv": cv,
                      "burstiness": burstiness}
    features = noise_features(filestimulus_runs)
    for feat in features:
        feat.update(baseline_feats)
    return features


def run_driven_response_analysis(file_list_folder: str, data_folder: str,
                                 results_folder: str, num_cores: int = 1):
    """Run the driven-response analysis for all noise datasets.

    Reads the dataset list and the precomputed baseline properties, extracts
    the noise-response features of each dataset (in parallel), and collects
    everything in one DataFrame.

    Parameters
    ----------
    file_list_folder : str
        Folder containing ``noise_datasets.dat`` (one dataset id per line).
    data_folder : str
        Folder containing the .nix recordings.
    results_folder : str
        Folder containing ``baseline_properties.csv``.
    num_cores : int, optional
        Number of parallel jobs, by default 1.

    Returns
    -------
    pd.DataFrame
        One row per analyzed stimulus presentation.
    """
    datasets = read_file_list(os.path.join(file_list_folder,
                                           "noise_datasets.dat"))
    baseline_properties = pd.read_csv(
        os.path.join(results_folder, "baseline_properties.csv"),
        sep=";", index_col=0)
    processed_list = Parallel(n_jobs=num_cores)(
        delayed(get_features)(dataset, data_folder, baseline_properties)
        for dataset in datasets)
    results = []
    for pr in processed_list:
        if pr is not None:
            results.extend(pr)  # skipped datasets contribute nothing
    df = pd.DataFrame(results)
    return df