# driven_response_properties.py
  1. import os
  2. import logging
  3. import nixio as nix
  4. import numpy as np
  5. import pandas as pd
  6. from pydantic import NoneIsAllowedError
  7. import rlxnix as rlx
  8. from joblib import Parallel, delayed
  9. from .baseline_response_properties import burst_fraction
  10. from ..util import detect_eod_times, firing_rate, spike_triggered_average, get_temporal_shift, read_file_list
  11. from IPython import embed
# Name of the Gaussian white-noise stimulus file (300 Hz cutoff, 10 s, 30 %
# contrast) that driven-response runs must have used to be analyzed here.
stimulus_name = "gwn300Hz10s0.3.dat"
  13. def noise_features(filestim_runs):
  14. features = []
  15. stim = None
  16. stim_cutoff = None
  17. stim_times = None
  18. stim_ampls = None
  19. for i, f in enumerate(filestim_runs):
  20. f.stimulus_folder = "stimuli"
  21. if len(f.stimuli) < 2:
  22. continue
  23. if stimulus_name not in f.stimulus_filename:
  24. continue
  25. if f.stimulus_filename != stim:
  26. stim = f.stimulus_filename
  27. stim_ampls, stim_times = f.load_stimulus(stimulus_index=1)
  28. if "/" in stim:
  29. stim_name = stim.split("/")[-1]
  30. elif "\\" in stim:
  31. stim_name = stim.split("\\")[-1]
  32. else:
  33. stim_name = stim
  34. if "Hz" in stim_name:
  35. stim_cutoff = float(stim_name.split("Hz")[0].split("gwn")[-1])
  36. if stim_ampls is None or stim_times is None:
  37. continue
  38. stimulus_spikes = None
  39. for j, fs in enumerate(f.stimuli):
  40. if fs.duration < 1.0:
  41. continue
  42. spike_times = f.spikes(stimulus_index=j)
  43. if len(spike_times) < 5:
  44. logging.warning(f"Not enough spikes in stimulus {fs}")
  45. continue
  46. feature_dict = {"stimulus_index": None, "start_time":None, "end_time": None,
  47. "stimfile": None, "contrast": None, "cutoff": None, "duration": 0.0,
  48. "firing_rate": None, "rate_modulation": None, "inverted": False,
  49. "delay": 0.0}
  50. feature_dict["stimulus_index"] = f"{i}_{j}"
  51. feature_dict["start_time"] = fs.start_time
  52. feature_dict["end_time"] = fs.stop_time
  53. feature_dict["stimfile"] = stim_name
  54. feature_dict["contrast"] = f.contrast[0]
  55. feature_dict["duration"] = fs.duration
  56. feature_dict["cutoff"] = stim_cutoff
  57. if stimulus_spikes is None:
  58. stimulus_spikes = spike_times
  59. else:
  60. stimulus_spikes = np.append(stimulus_spikes, spike_times)
  61. rate = firing_rate(spike_times, fs.duration, sigma=0.005)
  62. feature_dict["rate_modulation"] = np.std(rate)
  63. feature_dict["firing_rate"] = np.mean(rate)
  64. features.append(feature_dict)
  65. sta_time, sta = spike_triggered_average(stimulus_spikes, stim_times, stim_ampls)
  66. delay, inverted = get_temporal_shift(sta_time, sta)
  67. feature_dict["inverted"] = inverted
  68. feature_dict["delay"] = delay
  69. for feat in features:
  70. feat["inverted"] = inverted
  71. feat["delay"] = delay
  72. return features
  73. def get_baseline_features(dataset, baseline_df):
  74. baserate, cv, burstiness = None, None, None
  75. baseline_runs = dataset.repro_runs("BaselineActivity")
  76. if dataset.name in baseline_df.dataset_id: # for datasets for which we do have the receptive fields
  77. baserate = baseline_df.firing_rate[baseline_df.dataset_id == dataset.name]
  78. cv = baseline_df.cv[baseline_df.dataset_id == dataset.name]
  79. burstiness = baseline_df.burst_fraction[baseline_df.dataset_id == dataset.name]
  80. elif len(baseline_runs) > 0: # for datasets without receptive field measurements
  81. baserate = baseline_runs[0].baseline_rate
  82. cv = baseline_runs[0].baseline_cv
  83. eod_frequency = None
  84. try:
  85. eod_frequency = baseline_runs[0].eod_frequency
  86. except:
  87. pass
  88. if eod_frequency is None:
  89. print("Detecting eod times manually...")
  90. eod, time = baseline_runs[0].eod()
  91. eod -= np.mean(eod)
  92. eod_times, _ = detect_eod_times(time, eod, .5 * np.max(eod))
  93. eod_frequency = len(eod_times) / baseline_runs[0].duration
  94. if eod_frequency is None:
  95. embed()
  96. burstiness = burst_fraction(baseline_runs[0].spikes(), 1./eod_frequency)
  97. else: # should only occur for some old datasets
  98. logging.info(f"No baseline repro in dataset{dataset.name}. Trying to fix this...")
  99. min_time = 9999
  100. for r in dataset.repro_runs():
  101. if r.start_time < min_time:
  102. min_time = r.start_time # data time of the first repro that is not baseline
  103. if min_time > 10:
  104. spike_trace = "spikes-1" if "spikes-1" in dataset.event_traces else "Spikes-1"
  105. if spike_trace not in dataset.event_traces:
  106. return None, None, None
  107. spike_event_trace = dataset.event_traces[spike_trace]
  108. spike_times = spike_event_trace.data_array.get_slice([0.0], [min_time], nix.DataSliceMode.Data)[:]
  109. baserate = len(spike_times) / min_time
  110. isis = np.diff(spike_times)
  111. cv = np.std(isis) / np.mean(isis)
  112. eod_trace = dataset.data_traces["EOD"]
  113. eod = eod_trace.data_array.get_slice([0.0], [min_time], nix.DataSliceMode.Data)[:]
  114. time = np.array(eod_trace.data_array.dimensions[0].axis(len(eod)))
  115. eod_times, _ = detect_eod_times(time, eod, 0.5 * np.max(eod))
  116. eod_frequency = len(eod_times) / min_time
  117. burstiness = burst_fraction(spike_times, 1./eod_frequency)
  118. return baserate, cv, burstiness
  119. def get_features(dataset_name, data_folder, baseline_df):
  120. print(dataset_name)
  121. features = []
  122. baseline_feats = {"dataset_id": dataset_name, "baserate":None, "cv":None, "burstiness":None}
  123. dataset = rlx.Dataset(os.path.join(data_folder, dataset_name + ".nix"))
  124. filestimulus_runs = dataset.repro_runs("FileStimulus")
  125. if len(filestimulus_runs) == 0:
  126. logging.error(f"Dataset {dataset_name} has no FileStimulus recordings. Skipping dataset!")
  127. return None
  128. baserate, cv, burstiness = get_baseline_features(dataset, baseline_df)
  129. if baserate is None:
  130. logging.warning(f"Dataset {dataset.name} has no BaselineActivity recording. Skipping dataset!")
  131. baseline_feats["dataset_id"] = dataset_name
  132. baseline_feats["baserate"] = baserate
  133. baseline_feats["cv"] = cv
  134. baseline_feats["burstiness"] = burstiness
  135. features = noise_features(filestimulus_runs)
  136. for feat in features:
  137. feat.update(baseline_feats)
  138. return features
  139. def run_driven_response_analysis(file_list_folder: str, data_folder: str, results_folder: str, num_cores: int = 1):
  140. datasets = read_file_list(os.path.join(file_list_folder, "noise_datasets.dat"))
  141. baseline_properties = pd.read_csv(os.path.join(results_folder, "baseline_properties.csv"), sep=";", index_col=0)
  142. processed_list = Parallel(n_jobs=num_cores)(delayed(get_features)(dataset, data_folder, baseline_properties) for dataset in datasets)
  143. results = []
  144. for pr in processed_list:
  145. if pr is not None:
  146. results.extend(pr)
  147. df = pd.DataFrame(results)
  148. return df