receptive_field_estimation.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import os
  2. import logging
  3. import numpy as np
  4. import pandas as pd
  5. import rlxnix as rlx
  6. from scipy.signal import csd
  7. from scipy.optimize import curve_fit
  8. from joblib import Parallel, delayed
  9. from ..util import read_file_list, firing_rate
  10. def gauss(x,a,x0,sigma):
  11. return a*np.exp(-(x-x0)**2/(2*sigma**2))
  12. def receptive_field_position(rr, dt=1./20000., dataset_name=""):
  13. logging.info(f"\tReceptive field position...")
  14. if len(rr.stimuli) == 0:
  15. return None
  16. durations = rr.stimulus_durations
  17. fish_length = rr.fish_length
  18. dfs = rr.stimulus_deltafs
  19. x_positions, y_positions, z_positions = rr.stimulus_positions
  20. stim_positions_relative = x_positions / fish_length
  21. power_at_df = []
  22. positions = []
  23. relative_positions = []
  24. for i, stim in enumerate(rr.stimuli):
  25. if "ReceptiveField" not in stim.name:
  26. continue
  27. if y_positions[i] != 0.0:
  28. logging.debug(f"skipping position ({x_positions[i]}, {y_positions[i]})")
  29. continue
  30. spike_times = rr.spikes(i)
  31. rate = firing_rate(spike_times, durations[i], sigma=0.001, dt=dt)
  32. freq, psd = csd(rate, rate, fs=1./dt, nperseg=2**14, noverlap=2**13)
  33. power_at_df.append(np.sum(psd[(freq > dfs[i] - 2) & (freq < dfs[i] + 2)]))
  34. positions.append(x_positions[i])
  35. relative_positions.append(stim_positions_relative[i])
  36. if len(np.unique(relative_positions)) < 3:
  37. logging.info(f"{dataset_name}.{rr.name}: less than 3 positions, skipping")
  38. return None
  39. relative_positions = np.array(relative_positions)
  40. power_at_df = np.array(power_at_df)
  41. rf_position = 0.0
  42. rf_sigma = 0.0
  43. position_uncertainty = 0.0
  44. try:
  45. popt, pcov = curve_fit(gauss, relative_positions, power_at_df,
  46. p0=[1, 0.5, 0.25])
  47. rf_position = popt[1]
  48. rf_sigma = popt[2]
  49. if np.any(np.isinf(pcov)):
  50. logging.info(f"{dataset_name}.{rr.name} fitting was not successful!")
  51. position_uncertainty = 1.0
  52. position_uncertainty = np.sqrt(np.diag(pcov))[1]
  53. except Exception as e:
  54. sorted_positions = np.sort(np.unique(relative_positions))
  55. means = np.zeros_like(sorted_positions)
  56. for i, pos in enumerate(sorted_positions):
  57. means[i] = np.mean(power_at_df[relative_positions == pos])
  58. rf_position = sorted_positions[means == np.max(means)][0]
  59. rf_sigma = 1.0
  60. position_uncertainty = 1.0
  61. logging.info(f"Receptive field: position fit failed for {dataset_name} RePro run {rr.name}, falling back to maximum power position...")
  62. results = {"total_length": fish_length, "trials":len(rr.stimuli),
  63. "receptor_pos_absolute": rf_position * fish_length,
  64. "receptor_pos_relative": rf_position,
  65. "receptive_field_sigma": rf_sigma,
  66. "position_uncertainty": position_uncertainty,
  67. }
  68. return results
  69. def receptive_field_estimation(dataset_name, data_folder):
  70. logging.info("Analyzing cell {dataset_name} ...")
  71. filename = os.path.join(data_folder, dataset_name + ".nix")
  72. if not os.path.exists(filename):
  73. logging.error(f"NIX file for dataset {dataset_name} not found {filename}!")
  74. return None
  75. d = rlx.Dataset(filename)
  76. dt = d.data_traces[0].sampling_interval
  77. results = []
  78. for rf in d.repro_runs("ReceptiveField"):
  79. r = receptive_field_position(rf, dt, dataset_name)
  80. if r is not None:
  81. results.append(r)
  82. if len(results) == 0:
  83. logging.warning(f"No receptive field results for {dataset_name}!")
  84. return None
  85. best_index = 0
  86. min_error = 100000
  87. for i, res in enumerate(results):
  88. if res["receptor_pos_absolute"] == 0.0:
  89. continue
  90. err = res["position_uncertainty"]
  91. if err < min_error:
  92. best_index = i
  93. min_error = err
  94. best_results = results[best_index]
  95. best_results["dataset_id"] = dataset_name
  96. return best_results
  97. def run_receptive_field_analysis(list_folder, data_folder, num_cores=1):
  98. datasets = read_file_list(os.path.join(list_folder, "baseline_datasets.dat"))
  99. processed_list = Parallel(n_jobs=num_cores)(delayed(receptive_field_estimation)(dataset, data_folder) for dataset in datasets[:])
  100. results = [r for r in processed_list if r is not None]
  101. # results = receptive_field_estimation("2018-08-14-ac-invivo-1", data_folder)
  102. df = pd.DataFrame(results)
  103. return df