# heterogeneous_populations.py
# Analysis of stimulus-response mutual information in heterogeneous populations
# assembled from white-noise-driven single-cell recordings.
import itertools
import os

import numpy as np
import pandas as pd
import rlxnix as rlx
from joblib import Parallel, delayed

from ..util import load_white_noise_stim, mutual_info
# Module-level cache mapping cell/dataset id -> spike-time array, so that each
# .nix file is read from disk at most once across repeated population draws.
spike_buffer = {}
# File name of the white-noise stimulus; used both to locate the stimulus file
# and to filter the trial table to matching recordings.
stimulus_name = "gwn300Hz10s0.3.dat"
  9. def population_spikes(df, data_location, population_size, subtract_delay=True):
  10. """Assembles the spike times for a 'population' of neurons.
  11. Parameters
  12. ----------
  13. df : pd.DataFrame
  14. The DataFrame containing the information about where to find the recorded white noise responses.
  15. population_size : int
  16. The size of the population, i.e. the number of randomly selected trials.
  17. subtract_delay : bool, optional
  18. Defines whether the neuronal delay should be subtracted from the spike times, by default True
  19. Returns
  20. -------
  21. list:
  22. list of spike times of the randomly selected trials.
  23. list:
  24. list containing whether or not the response needs to be inverted.
  25. list:
  26. the respective cell ids
  27. list:
  28. the respective trial ids.
  29. """
  30. cell_indices = np.arange(len(df.dataset_id.unique()))
  31. np.random.shuffle(cell_indices)
  32. cells = df.dataset_id.unique()[cell_indices[:population_size]]
  33. spikes = []
  34. inversion_needed = []
  35. cell_ids = []
  36. trial_ids = []
  37. for c in cells:
  38. if c in spike_buffer.keys():
  39. # print(f"Found spikes of cell {c} in buffer!")
  40. all_spikes = spike_buffer[c]
  41. else:
  42. # print(f"Reading spikes of cell {c} from file!")
  43. dataset = rlx.Dataset(os.path.join(data_location, c + ".nix"))
  44. trace_name = "spikes-1" if "spikes-1" in dataset.event_traces else "Spikes-1"
  45. all_spikes = dataset.event_traces[trace_name].data_array[:]
  46. spike_buffer[c] = all_spikes
  47. cell_trials = df[df.dataset_id == c]
  48. trial_index = np.random.randint(min(cell_trials.index), max(cell_trials.index)+1)
  49. trial = cell_trials[cell_trials.index == trial_index]
  50. trial_spikes = all_spikes[(all_spikes >= trial.start_time.values[0]) & (all_spikes < trial.end_time.values[0])] - trial.start_time.values[0]
  51. if subtract_delay:
  52. trial_spikes -= trial.delay.values[0] # compensate for the delay between stimulus and response
  53. spikes.append(trial_spikes[trial_spikes > 0])
  54. inversion_needed.append(trial.inverted.values[0])
  55. cell_ids.append(c)
  56. trial_ids.append(trial.stimulus_index.values[0])
  57. return spikes, inversion_needed, cell_ids, trial_ids
  58. def mutual_info_heterogeneous_population(df, data_location, stim_location, population_size=10, delay=0.01, repeats=10,
  59. kernel_sigma=0.00125, saving=False, result_location=".", delay_type="equal"):
  60. """_summary_
  61. Parameters
  62. ----------
  63. df : _type_
  64. _description_
  65. data_location : _type_
  66. _description_
  67. stim_location : _type_
  68. _description_
  69. population_size : int, optional
  70. _description_, by default 10
  71. delay : float, optional
  72. _description_, by default 0.01
  73. repeats : int, optional
  74. _description_, by default 10
  75. kernel_sigma : float, optional
  76. _description_, by default 0.00125
  77. saving : bool, optional
  78. _description_, by default False
  79. result_location : str, optional
  80. _description_, by default "."
  81. delay_type : str, optional
  82. _description_, by default "equal"
  83. Returns
  84. -------
  85. _type_
  86. _description_
  87. """
  88. coherences = []
  89. frequency = []
  90. _, stimulus = load_white_noise_stim(os.path.join(stim_location, stimulus_name))
  91. results = []
  92. for i in range(repeats):
  93. r = {"pop_size": population_size, "delay": delay, "true_delay": 0.0, "kernel_sigma": kernel_sigma, "num_inversions": 0, "population_rate": 0.0, "rate_modulation": 0.0,
  94. "snr": 0.0, "mi":0.0, "mi_100": 0.0, "mi_200": 0.0, "mi_300": 0.0, "cell_ids":[], "mtag_ids":[], "trial_ids":[], "result_file":"" }
  95. print("population size: %i, delay: %.5fs, kernel: %.5f, repeat: %i" % (population_size, delay, kernel_sigma, i), end="\r" )
  96. subtract_delay = delay > 0.0
  97. pop_spikes, inversion_needed, cell_ids, trial_ids = population_spikes(df, data_location, population_size, subtract_delay)
  98. f, c, mis, rates, true_delay = mutual_info(pop_spikes, delay, inversion_needed, stimulus, freq_bin_edges=[0, 100, 200, 300], kernel_sigma=kernel_sigma, delay_type=delay_type)
  99. coherences.append(c)
  100. frequency = f
  101. smoothing_kernel = np.ones(19)/19
  102. sc = np.convolve(c, smoothing_kernel, mode="same")
  103. max_coh = np.max(sc[f <= 300])
  104. peak_f = f[np.where(sc == max_coh)]
  105. upper_cutoff = f[(sc <= max_coh/np.sqrt(2)) & (f > peak_f)][0]
  106. lower_cutoff = f[(sc <= max_coh/np.sqrt(2)) & (f < peak_f)][-1] if len(f[(sc <= max_coh/np.sqrt(2)) & (f < peak_f)]) > 0 else 0.0
  107. avg_response = np.mean(rates, axis=1) if population_size > 1 else rates
  108. avg_rate = np.mean(avg_response)
  109. rate_error = np.std(rates, axis=1) if population_size > 1 else None
  110. variability = np.mean(rate_error, axis=0) if rate_error is not None else None
  111. snr = np.std(avg_rate) / variability if variability is not None else None
  112. outfile_name = "pop_size%i_repeat%i_delay%.5f.npz" % (population_size, i, delay)
  113. outfile_full = os.path.join(result_location, outfile_name)
  114. if saving:
  115. np.savez_compressed(outfile_full, coherences=coherences, frequencies=frequency, avg_rate=avg_rate, rate_std=rate_error)
  116. r["true_delay"] = true_delay
  117. r["num_inversions"] = np.sum(inversion_needed)
  118. r["population_rate"] = avg_rate
  119. r["rate_modulation"] = np.std(avg_response)
  120. r["peak_coh"] = max_coh
  121. r["upper_cutoff"] = upper_cutoff
  122. r["lower_cutoff"] = lower_cutoff
  123. r["peak_freq"] = peak_f
  124. r["snr"] = snr if snr is not None else 0.0
  125. r["variability"] = variability if variability is not None else 0.0
  126. r["mi"] = mis[0]
  127. r["mi_100"] = mis[1]
  128. r["mi_200"] = mis[2]
  129. r["mi_300"] = mis[3]
  130. r["cell_ids"] = cell_ids
  131. r["trial_ids"] = trial_ids
  132. r["result_file"] = outfile_name if not saving else ""
  133. results.append(r)
  134. return results
  135. def run_heterogeneous_population_analysis(whitenoise_trial_file, data_location, stim_location,
  136. population_sizes=[2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30],
  137. delays=[0.0, 0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009, 0.01, 0.0125, 0.015],
  138. kernels=[0.000125, 0.00025, 0.0005, 0.001, 0.002, 0.003, 0.005, 0.01, 0.015, 0.02, 0.025],
  139. repeats=50, saving=False, num_cores=1):
  140. """Runs the population analysis for heterogeneous populations.
  141. Parameters
  142. ----------
  143. whitenoise_trial_file : str
  144. The dataframe containing the information of the white noise recordings.
  145. data_location : str
  146. The path to the raw data.
  147. stim_location : str
  148. The path to the folder containing the stimuli.
  149. population_sizes : list of ints, optional
  150. List of population sizes, by default [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
  151. delays : list, optional
  152. list of artificial delays added to the spike times. -1 indicates that the original delay should not be changed, by default [0.0, 0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009, 0.01, 0.0125, 0.015]
  153. kernels : list, optional
  154. list of Gaussian kernel standard deviations used for firing rate estimation, by default [0.000125, 0.00025, 0.0005, 0.001, 0.002, 0.003, 0.005, 0.01, 0.015, 0.02, 0.025]
  155. repeats : int, optional
  156. Maximum number of repetitions for each combination of population sizes, delays, kernels. For each repetition a new set of responses will be used. By default 50
  157. saving : bool, optional
  158. Flag that indicates whether or not coherence spectra, firing rates etc should be saved, by default False
  159. num_cores : int, optional
  160. number of parallel jobs spawned to run the analyses, by default 1
  161. Returns
  162. -------
  163. pd.DataFrame
  164. The DataFrame containing the analysis results.
  165. """
  166. result_dicts = []
  167. whitenoise_trials = pd.read_csv(whitenoise_trial_file, sep=";", index_col=0)
  168. whitenoise_trials = whitenoise_trials[(whitenoise_trials.stimfile == stimulus_name) & (whitenoise_trials.duration == 10)]
  169. conditions = []
  170. for k in kernels:
  171. for ps in population_sizes:
  172. for d in delays:
  173. conditions.append((k, ps, d))
  174. processed_list = Parallel(n_jobs=num_cores)(delayed(mutual_info_heterogeneous_population)(whitenoise_trials, data_location, stim_location,
  175. population_size=condition[1],
  176. delay=condition[2],
  177. repeats=repeats,
  178. kernel_sigma=condition[0],
  179. saving=saving,
  180. delay_type="gaussian") for condition in conditions)
  181. for pr in processed_list:
  182. if pr is not None:
  183. result_dicts.extend(pr)
  184. df = pd.DataFrame(result_dicts)
  185. return df