homogeneous_populations.py

import os
import math
import random
import logging
import numpy as np
import pandas as pd
import rlxnix as rlx
import itertools as it
from joblib import Parallel, delayed

from code.util import load_white_noise_stim, mutual_info

stimulus_name = "gwn300Hz10s0.3.dat"


def mutual_info_homogenous_population(all_spikes, stimulus_indices, stimulus, kernel, repeats, dataset_name, contrast, cv, burstiness):
    """Workhorse of the analysis: estimates the stimulus-response coherence and
    mutual information for homogeneous populations of increasing size drawn from
    the trials of a single cell.

    Parameters
    ----------
    all_spikes : list
        list of np.arrays holding the spike times
    stimulus_indices : list
        list of stimulus ids
    stimulus : np.array
        the stimulus waveform
    kernel : float
        the sigma of the Gaussian kernel used to estimate the firing rates
    repeats : int
        the maximal number of repeats for each population size
    dataset_name : str
        the name (id) of the analyzed dataset
    contrast : float
        the stimulus contrast (i.e. intensity)
    cv : float
        the coefficient of variation of the interspike intervals of the baseline response
    burstiness : float
        the burstiness of the baseline response

    Returns
    -------
    list of dictionaries
        the analysis results
    """
    results = []
    population_sizes = np.arange(1, len(all_spikes) + 1)
    for ps in population_sizes:
        possible_combinations = None
        shuf = None
        # number of distinct populations of size ps that can be drawn from the trials
        combination_count = math.comb(len(all_spikes), ps)
        if combination_count < 10000:
            # few enough combinations: enumerate them all and shuffle to sample without replacement
            possible_combinations = list(it.combinations(np.arange(len(all_spikes)), ps))
            shuf = np.arange(len(possible_combinations))
            np.random.shuffle(shuf)
            combination_count = len(possible_combinations)
        coherences = []
        count = 0
        while count < repeats and count < combination_count:
            print("\tpopulation_size: %i, repeat: %i" % (ps, count), end="\r", flush=True)
            if possible_combinations is not None:
                combination = possible_combinations[shuf[count]]
            else:
                combination = random.sample(range(len(all_spikes)), ps)
            results_dict = {"dataset_id": dataset_name, "contrast": contrast, "pop_size": ps,
                            "kernel": kernel, "snr": 0.0, "mi": 0.0, "mi_100": 0.0,
                            "mi_200": 0.0, "mi_300": 0.0, "result_file": "n.a."}
            spikes = [all_spikes[i] for i in combination]
            f, coh, mis, rates, _ = mutual_info(spikes, 0.0, np.zeros(len(spikes)), stimulus,
                                                freq_bin_edges=[0, 100, 200, 300], kernel_sigma=kernel)
            coherences.append(coh)
            # smooth the coherence with a boxcar kernel before estimating peak and cutoff frequencies
            smoothing_kernel = np.ones(19) / 19
            sc = np.convolve(coh, smoothing_kernel, mode="same")
            max_coh = np.max(sc[f <= 300])
            peak_f = f[np.where(sc == max_coh)][0]
            # cutoff frequencies: where the smoothed coherence drops below peak / sqrt(2)
            upper_cutoff = f[(sc <= max_coh / np.sqrt(2)) & (f > peak_f)][0]
            lower_candidates = f[(sc <= max_coh / np.sqrt(2)) & (f < peak_f)]
            lower_cutoff = lower_candidates[-1] if len(lower_candidates) > 0 else 0.0
            avg_response = np.mean(rates, axis=1) if ps > 1 else rates
            avg_rate = np.mean(avg_response)
            rate_error = np.std(rates, axis=1) if ps > 1 else None
            variability = np.mean(rate_error) if rate_error is not None else None
            # SNR: temporal modulation of the population rate relative to the across-trial variability
            snr = np.std(avg_response) / variability if variability is not None else None
            results_dict["population_rate"] = avg_rate
            results_dict["rate_modulation"] = np.std(avg_response)
            results_dict["peak_coh"] = max_coh
            results_dict["upper_cutoff"] = upper_cutoff
            results_dict["lower_cutoff"] = lower_cutoff
            results_dict["peak_freq"] = peak_f
            results_dict["variability"] = variability if variability is not None else 0.0
            results_dict["snr"] = snr if snr is not None else 0.0
            results_dict["mi"] = mis[0]
            results_dict["mi_100"] = mis[1]
            results_dict["mi_200"] = mis[2]
            results_dict["mi_300"] = mis[3]
            results_dict["stimulus_indices"] = np.asarray(stimulus_indices)[np.asarray(combination)]
            results_dict["cv"] = cv
            results_dict["burstiness"] = burstiness
            results.append(results_dict)
            count += 1
        print("\n", end="")
    return results
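

# A minimal, hypothetical call of the workhorse function with synthetic spike
# trains. The spike trains, stimulus path, and parameter values below are
# illustrative assumptions, not part of the original analysis:
#
#     rng = np.random.default_rng(42)
#     fake_spikes = [np.sort(rng.uniform(0.0, 10.0, size=1000)) for _ in range(8)]
#     fake_indices = np.arange(len(fake_spikes))
#     _, stim = load_white_noise_stim(os.path.join("stimuli", stimulus_name))
#     res = mutual_info_homogenous_population(fake_spikes, fake_indices, stim,
#                                             kernel=0.001, repeats=5,
#                                             dataset_name="demo", contrast=10.0,
#                                             cv=0.5, burstiness=0.1)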


def get_spikes(trials, dataset):
    """Get all the spike trains that belong to one stimulus condition.

    Parameters
    ----------
    trials : pd.DataFrame
        the trial information
    dataset : rlx.Dataset
        the dataset instance that grants access to the spikes event trace

    Returns
    -------
    list
        the spike times of each trial as np.ndarray, relative to trial onset, in seconds
    np.ndarray
        the trial indices
    """
    # the event trace is named "spikes-1" or "Spikes-1", depending on the recording
    trace_name = "spikes-1" if "spikes-1" in dataset.event_traces else "Spikes-1"
    all_spikes = dataset.event_traces[trace_name].data_array[:]
    spikes = []
    for s, e in zip(trials.start_time.values, trials.end_time.values):
        # cut out the spikes of each trial and align them to trial onset
        spikes.append(all_spikes[(all_spikes >= s) & (all_spikes < e)] - s)
    trial_ids = trials.stimulus_index.unique()
    return spikes, trial_ids


def homogeneous_population_dataset(dataset_name, trial_information, data_location, stim_location, kernel_size=0.001, repeats=10):
    """Run the homogeneous population analysis for all contrasts recorded in one dataset.

    Parameters
    ----------
    dataset_name : str
        name of the dataset
    trial_information : pd.DataFrame
        DataFrame containing the information about the white noise trials recorded in that dataset/cell
    data_location : str
        location of the raw data
    stim_location : str
        location of the stimulus file
    kernel_size : float, optional
        sigma of the Gaussian kernel used to estimate the firing rate, by default 0.001
    repeats : int, optional
        maximal number of different populations created from the trials, by default 10

    Returns
    -------
    list of dictionaries
        the dicts contain the analysis results for each repetition at each population size
    """
    dataset = rlx.Dataset(os.path.join(data_location, dataset_name + ".nix"))
    _, stimulus = load_white_noise_stim(os.path.join(stim_location, stimulus_name))
    contrasts = trial_information.contrast.unique()
    results = []
    for c in contrasts:
        contrast_trials = trial_information[trial_information.contrast == c]
        cv = contrast_trials.cv.unique()[0]
        burstiness = contrast_trials.burstiness.unique()[0]
        trial_spikes, trial_indices = get_spikes(contrast_trials, dataset)
        print(f"\tcontrast {c}, stimulus: {stimulus_name}, number of trials: {len(trial_spikes)}")
        results.extend(mutual_info_homogenous_population(trial_spikes, trial_indices, stimulus, kernel_size,
                                                         repeats, dataset_name, c, cv, burstiness))
    return results


def process_cell(dataset_name, whitenoise_trials, data_location, stim_location, kernel_size=0.001, repeats=10):
    """Entry point for the homogeneous population analysis of a single cell.

    Parameters
    ----------
    dataset_name : str
        the name of the dataset
    whitenoise_trials : pd.DataFrame
        the data frame containing the cell properties as analyzed with "WhiteNoiseOverview.py"
    data_location : str
        path to the raw data files
    stim_location : str
        path to the stimulus files
    kernel_size : float, optional
        sigma of the Gaussian kernel used to estimate the firing rates, by default 0.001
    repeats : int, optional
        number of populations drawn for each population size, by default 10

    Returns
    -------
    list of dictionaries
        the dictionaries contain the analysis results for each repetition and population size
    """
    logging.info(f"Homogeneous populations: processing {dataset_name} ...")
    trials = whitenoise_trials[(whitenoise_trials.dataset_id == dataset_name)]
    results_dicts = homogeneous_population_dataset(dataset_name, trials, data_location, stim_location, kernel_size, repeats)
    return results_dicts


def run_homogeneous_population_analysis(whitenoise_trial_file, data_location, stim_location, kernel_size=0.001, repeats=10, num_cores=1):
    """Analyze the coding performance of homogeneous populations constructed from trials of the same cell.

    Parameters
    ----------
    whitenoise_trial_file : str
        path to the csv file containing information about all trials recorded with the same white noise stimulus
    data_location : str
        location of the raw data
    stim_location : str
        location of the stimulus file
    kernel_size : float, optional
        sigma of the Gaussian kernel used to estimate the firing rates, by default 0.001
    repeats : int, optional
        maximal number of populations drawn for each population size, by default 10
    num_cores : int, optional
        number of cores used to process the cells in parallel, by default 1

    Returns
    -------
    pd.DataFrame
        the analysis results, one row per repetition and population size
    """
    whitenoise_trials = pd.read_csv(whitenoise_trial_file, sep=";", index_col=0)
    # restrict the analysis to the 10 s white noise stimulus defined above
    whitenoise_trials = whitenoise_trials[(whitenoise_trials.stimfile == stimulus_name) & (whitenoise_trials.duration == 10)]
    datasets = whitenoise_trials.dataset_id.unique()
    processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(dataset, whitenoise_trials, data_location, stim_location, kernel_size, repeats) for dataset in datasets)
    results = []
    for pr in processed_list:
        if pr is not None:
            results.extend(pr)
    df = pd.DataFrame(results)
    return df
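

# Minimal usage sketch: runs the full analysis and stores the results.
# The csv file name, directories, and parameter values below are hypothetical
# and need to be adapted to the local data layout.
if __name__ == "__main__":
    df = run_homogeneous_population_analysis("whitenoise_trials.csv",
                                             data_location="data",
                                             stim_location="stimuli",
                                             kernel_size=0.001, repeats=10,
                                             num_cores=4)
    df.to_csv("homogeneous_populations_results.csv", sep=";")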