"""Analysis of driven responses to band-limited Gaussian white-noise stimuli."""

import os
import logging

import nixio as nix
import numpy as np
import pandas as pd
import rlxnix as rlx
from joblib import Parallel, delayed

from .baseline_response_properties import burst_fraction
from ..util import detect_eod_times, firing_rate, spike_triggered_average, get_temporal_shift, read_file_list

# Name of the white-noise stimulus file that is analyzed. By the naming
# convention parsed below, "gwn300Hz10s0.3.dat" encodes a Gaussian white
# noise with a 300 Hz cutoff frequency (the number between "gwn" and "Hz").
stimulus_name = "gwn300Hz10s0.3.dat"
def noise_features(filestim_runs):
    """Extract response features from all FileStimulus runs that used the
    white-noise stimulus defined in ``stimulus_name``."""
    features = []
    stim = None
    stim_cutoff = None
    stim_times = None
    stim_ampls = None
    for i, f in enumerate(filestim_runs):
        f.stimulus_folder = "stimuli"
        if len(f.stimuli) < 2:
            continue
        if stimulus_name not in f.stimulus_filename:
            continue
        if f.stimulus_filename != stim:
            # New stimulus file: (re)load the stimulus waveform and parse the
            # cutoff frequency from the file name, e.g. "gwn300Hz10s0.3.dat" -> 300.0
            stim = f.stimulus_filename
            stim_ampls, stim_times = f.load_stimulus(stimulus_index=1)
            if "/" in stim:
                stim_name = stim.split("/")[-1]
            elif "\\" in stim:
                stim_name = stim.split("\\")[-1]
            else:
                stim_name = stim
            if "Hz" in stim_name:
                stim_cutoff = float(stim_name.split("Hz")[0].split("gwn")[-1])
        if stim_ampls is None or stim_times is None:
            continue
        stimulus_spikes = None
        run_start = len(features)  # features collected for this run start at this index
        for j, fs in enumerate(f.stimuli):
            if fs.duration < 1.0:
                continue
            spike_times = f.spikes(stimulus_index=j)
            if len(spike_times) < 5:
                logging.warning(f"Not enough spikes in stimulus {fs}")
                continue
            feature_dict = {"stimulus_index": f"{i}_{j}", "start_time": fs.start_time,
                            "end_time": fs.stop_time, "stimfile": stim_name,
                            "contrast": f.contrast[0], "cutoff": stim_cutoff,
                            "duration": fs.duration, "firing_rate": None,
                            "rate_modulation": None, "inverted": False, "delay": 0.0}
            if stimulus_spikes is None:
                stimulus_spikes = spike_times
            else:
                stimulus_spikes = np.append(stimulus_spikes, spike_times)
            rate = firing_rate(spike_times, fs.duration, sigma=0.005)
            feature_dict["rate_modulation"] = np.std(rate)
            feature_dict["firing_rate"] = np.mean(rate)
            features.append(feature_dict)
        if stimulus_spikes is None:
            continue
        # Estimate the response delay and polarity from the spike-triggered
        # average and apply them to the features collected for this run only.
        sta_time, sta = spike_triggered_average(stimulus_spikes, stim_times, stim_ampls)
        delay, inverted = get_temporal_shift(sta_time, sta)
        for feat in features[run_start:]:
            feat["inverted"] = inverted
            feat["delay"] = delay
    return features
def get_baseline_features(dataset, baseline_df):
    """Return baseline firing rate, CV of the interspike intervals, and burst
    fraction, taken from ``baseline_df`` if available, otherwise computed."""
    baserate, cv, burstiness = None, None, None
    baseline_runs = dataset.repro_runs("BaselineActivity")
    if dataset.name in baseline_df.dataset_id.values:
        # Datasets for which the baseline properties were already analyzed
        # along with the receptive fields.
        selection = baseline_df.dataset_id == dataset.name
        baserate = baseline_df.firing_rate[selection].values[0]
        cv = baseline_df.cv[selection].values[0]
        burstiness = baseline_df.burst_fraction[selection].values[0]
    elif len(baseline_runs) > 0:
        # Datasets without receptive-field measurements: compute the
        # properties from the first baseline recording.
        baserate = baseline_runs[0].baseline_rate
        cv = baseline_runs[0].baseline_cv
        eod_frequency = None
        try:
            eod_frequency = baseline_runs[0].eod_frequency
        except Exception:
            pass
        if eod_frequency is None:
            logging.info("Detecting EOD times manually...")
            eod, time = baseline_runs[0].eod()
            eod -= np.mean(eod)
            eod_times, _ = detect_eod_times(time, eod, 0.5 * np.max(eod))
            eod_frequency = len(eod_times) / baseline_runs[0].duration
        if eod_frequency is None or eod_frequency <= 0:
            logging.error(f"Could not determine the EOD frequency in dataset {dataset.name}!")
            return baserate, cv, None
        burstiness = burst_fraction(baseline_runs[0].spikes(), 1. / eod_frequency)
    else:
        # Should only occur for some old datasets that lack a baseline repro:
        # use the spontaneous activity recorded before the first repro run.
        logging.info(f"No baseline repro in dataset {dataset.name}. Trying to fix this...")
        min_time = np.inf
        for r in dataset.repro_runs():
            if r.start_time < min_time:
                min_time = r.start_time  # start time of the earliest repro run
        if min_time > 10:
            spike_trace = "spikes-1" if "spikes-1" in dataset.event_traces else "Spikes-1"
            if spike_trace not in dataset.event_traces:
                return None, None, None
            spike_event_trace = dataset.event_traces[spike_trace]
            spike_times = spike_event_trace.data_array.get_slice([0.0], [min_time], nix.DataSliceMode.Data)[:]
            baserate = len(spike_times) / min_time
            isis = np.diff(spike_times)
            cv = np.std(isis) / np.mean(isis)
            eod_trace = dataset.data_traces["EOD"]
            eod = eod_trace.data_array.get_slice([0.0], [min_time], nix.DataSliceMode.Data)[:]
            time = np.array(eod_trace.data_array.dimensions[0].axis(len(eod)))
            eod_times, _ = detect_eod_times(time, eod, 0.5 * np.max(eod))
            eod_frequency = len(eod_times) / min_time
            burstiness = burst_fraction(spike_times, 1. / eod_frequency)
    return baserate, cv, burstiness
def get_features(dataset_name, data_folder, baseline_df):
    """Extract the driven-response features of one dataset and attach its
    baseline properties to each feature dictionary."""
    logging.info(f"Processing dataset {dataset_name}")
    dataset = rlx.Dataset(os.path.join(data_folder, dataset_name + ".nix"))
    filestimulus_runs = dataset.repro_runs("FileStimulus")
    if len(filestimulus_runs) == 0:
        logging.error(f"Dataset {dataset_name} has no FileStimulus recordings. Skipping dataset!")
        return None
    baserate, cv, burstiness = get_baseline_features(dataset, baseline_df)
    if baserate is None:
        logging.warning(f"Dataset {dataset.name} has no usable BaselineActivity recording. Baseline features will be missing.")
    baseline_feats = {"dataset_id": dataset_name, "baserate": baserate,
                      "cv": cv, "burstiness": burstiness}
    features = noise_features(filestimulus_runs)
    for feat in features:
        feat.update(baseline_feats)
    return features
def run_driven_response_analysis(file_list_folder: str, data_folder: str, results_folder: str, num_cores: int = 1):
    """Analyze the driven responses of all datasets listed in
    ``noise_datasets.dat`` and return the features as a DataFrame.

    Expects ``baseline_properties.csv`` (with columns including
    ``dataset_id``, ``firing_rate``, ``cv``, and ``burst_fraction``)
    in the results folder.
    """
    datasets = read_file_list(os.path.join(file_list_folder, "noise_datasets.dat"))
    baseline_properties = pd.read_csv(os.path.join(results_folder, "baseline_properties.csv"), sep=";", index_col=0)
    processed_list = Parallel(n_jobs=num_cores)(delayed(get_features)(dataset, data_folder, baseline_properties) for dataset in datasets)
    results = []
    for pr in processed_list:
        if pr is not None:
            results.extend(pr)
    df = pd.DataFrame(results)
    return df
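

# Minimal usage sketch (not part of the original module): the folder names
# below are assumptions and need to be adapted to the local setup. Since this
# module uses relative imports, run it as a module of its package
# (``python -m <package>.<module>``) rather than as a standalone script.
if __name__ == "__main__":
    df = run_driven_response_analysis(file_list_folder="file_lists",
                                      data_folder="data",
                                      results_folder="results",
                                      num_cores=4)
    print(df.head())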