populations_method.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. import rlxnix as rlx
  5. import matplotlib.pyplot as plt
  6. import matplotlib.gridspec as gridspec
  7. from .figure_style import *
  8. from ..util import mutual_info, load_white_noise_stim
  9. from IPython import embed
  10. class WhitenoiseDataset(object):
  11. def __init__(self, name) -> None:
  12. self._name = name
  13. self._baserate = 0.0
  14. self._rate_modulation = 0.0
  15. self._cv = 0.0
  16. self._all_spikes = None
  17. self._spike_responses = []
  18. self._isihist = None
  19. self._coherence = None
  20. self._trial_data = None
  21. self._rates = None
  22. self._delay = 0.0
  23. self._time = None
  24. self._mi = None
  25. self._kernel = 0.00125
  26. self._stimfile = None
  27. self._mi_expectation = 0.0
  28. self._mi_spike = 0.0
  29. self._selected_trials = None
  30. def layout_figure():
  31. axes = {}
  32. fig = plt.figure(figsize=(6.9, 6.))
  33. gs = gridspec.GridSpec(ncols=3, nrows=1, width_ratios=[5, 0.05, 5])
  34. cols = ["left", "", "right"]
  35. inner_cols = ["responses", "", "coherence"]
  36. ylabels = ["firing rate [Hz]", "", "coherence"]
  37. xlabels = ["time [ms]", "", "frequency [Hz]"]
  38. for i in range(3):
  39. col = cols[i]
  40. if i == 1:
  41. continue
  42. if col not in axes:
  43. axes[col] = {}
  44. height_ratios = np.ones(6)
  45. height_ratios[-2] = 0.125
  46. sgs = gs[i].subgridspec(nrows=6, ncols=3, width_ratios=[3, 0.25, 2], height_ratios=height_ratios)
  47. for j in range(4):
  48. row = f"row_{j}"
  49. if row not in axes[col]:
  50. axes[col][row] = {}
  51. for k, inner_col in enumerate(inner_cols):
  52. if inner_col == "":
  53. continue
  54. ax = fig.add_subplot(sgs[j, k])
  55. if j == 0 and k == 0:
  56. ax.text(-0.375, 1.05, "A" if i == 0 else "C", transform=ax.transAxes, ha="left",
  57. fontsize=label_size, fontweight="bold")
  58. despine(ax, ["top", "right"], hide_ticks=False)
  59. ax.set_ylabel(ylabels[k])
  60. if j == 3:
  61. ax.set_xlabel(xlabels[k])
  62. axes[col][row][inner_col] = ax
  63. axes[col]["row_4"] = {}
  64. ax = fig.add_subplot(sgs[5, 0])
  65. ax.text(-0.375, 1.05, "B" if col == "left" else "D", transform=ax.transAxes,
  66. fontsize=label_size, fontweight="bold", ha="left")
  67. despine(ax, ["top", "right"], hide_ticks=False)
  68. ax.set_ylabel(ylabels[0])
  69. ax.set_xlabel(xlabels[0])
  70. axes[col]["row_4"]["response"] = ax
  71. ax = fig.add_subplot(sgs[5, 2])
  72. despine(ax, ["top", "right"], hide_ticks=False)
  73. ax.set_ylabel(ylabels[2])
  74. ax.set_xlabel(xlabels[2])
  75. axes[col]["row_4"]["coherence"] = ax
  76. fig.subplots_adjust(left=0.075, right=0.985, top=0.975, bottom=0.06)
  77. return fig, axes
  78. def find_homogeneous_cells(df):
  79. selection = []
  80. cells = df.dataset_id.unique()
  81. for cell in cells:
  82. trials = df[df.dataset_id == cell]
  83. selection.append(trials[0:1])
  84. selection = pd.concat(selection)
  85. similar_cells = selection[(selection.rate_modulation > 40) & (selection.rate_modulation < 60) &
  86. (selection.baserate > 100) & (selection.baserate < 150) &
  87. (selection.cv > 0.3) & (selection.cv < 0.9) &
  88. (selection.inverted == False)]
  89. selection_trials = []
  90. for i in range(len(similar_cells)):
  91. id = similar_cells.iloc[i].dataset_id
  92. selection_trials.append(df[(df.dataset_id == id) & (~df.inverted) &
  93. (df.rate_modulation > 40) & (df.rate_modulation < 60) &
  94. (df.baserate > 100) & (df.baserate < 150) &
  95. (df.cv > 0.3) & (df.cv < 0.9)
  96. ])
  97. selection_trials = pd.concat(selection_trials)
  98. return selection_trials
  99. def find_heterogeneous_cells(df):
  100. selection = []
  101. cells = df.dataset_id.unique()
  102. for cell in cells:
  103. trials = df[df.dataset_id == cell]
  104. if np.any(trials.inverted):
  105. continue
  106. selection.append(trials[0:1])
  107. selection = pd.concat(selection)
  108. cells = []
  109. cells.append(selection[selection.baserate == np.max(selection.baserate)].dataset_id.values[0])
  110. cells.append(selection.iloc[np.argmin(np.abs(selection.baserate - np.percentile(selection.baserate, 25)))].dataset_id)
  111. cells.append(selection.iloc[np.argmin(np.abs(selection.rate_modulation - np.percentile(selection.rate_modulation, 25)))].dataset_id)
  112. cells.append(selection[selection.rate_modulation == np.max(selection.rate_modulation)].dataset_id.values[0])
  113. selection_trials = []
  114. for id in cells:
  115. selection_trials.append(df[(df.dataset_id == id)])
  116. selection_trials = pd.concat(selection_trials)
  117. return selection_trials
  118. def read_dataset(c, trial_info, data_location="raw_data", binwidth= 0.0005, kernel_sigma= 0.00125):
  119. ds = WhitenoiseDataset(c)
  120. d = rlx.Dataset(os.path.join(data_location, c + ".nix"))
  121. b = d.repro_runs("Baseline")[0]
  122. ds._baserate = b.baseline_rate
  123. ds._rate_modulation = np.mean(trial_info.rate_modulation)
  124. ds._cv = b.baseline_cv
  125. ds._isihist = np.histogram(np.diff(b.spikes()), bins=np.arange(0.0, 0.02, binwidth))
  126. trace_name = "spikes-1" if "spikes-1" in d.event_traces else "Spikes-1"
  127. all_spikes = d.event_traces[trace_name].data_array[:]
  128. ds._all_spikes = all_spikes
  129. ds._rates = []
  130. ds._spike_responses = []
  131. ds._kernel = binwidth
  132. for i in range(min(len(trial_info), 10)):
  133. ti = trial_info.iloc[i]
  134. spikes = all_spikes[(all_spikes >= ti.start_time) & (all_spikes < ti.end_time)] - ti.start_time
  135. ds._spike_responses.append(spikes)
  136. ds._stimfile = ti.stimfile
  137. ds._delay = np.mean(trial_info.delay)
  138. time, stimulus = load_white_noise_stim(os.path.join("stimuli", ti.stimfile))
  139. f, c, mis, rates, _ = mutual_info(ds._spike_responses, 0.0, [False for i in range(len(trial_info))],
  140. stimulus, freq_bin_edges=[300], kernel_sigma=kernel_sigma)
  141. ds._rates = rates
  142. ds._time = time
  143. ds._coherence = (f, c)
  144. ds._mi = mis[0]
  145. ds._mi_spike = ds._mi / ds._baserate
  146. return ds
  147. def plot_responses(ds, axis, highlight=None, xticklabels=False, spikes_color=None):
  148. spikes_axis = axis.twinx()
  149. despine(spikes_axis, ["top", "right"], hide_ticks=False)
  150. spikes_axis.yaxis.set_ticks([])
  151. spikes_color = "silver" if spikes_color is None else spikes_color
  152. event_list = spikes_axis.eventplot(ds._spike_responses[:10], color=spikes_color, linewidths=0.2)
  153. axis.plot(ds._time, np.mean(ds._rates, axis=1), color="white", lw=2)
  154. axis.plot(ds._time, np.mean(ds._rates, axis=1), color="tab:blue", lw=1.0)
  155. axis.set_xlim([0, 0.150])
  156. axis.set_xticks(np.arange(0, 0.151, 0.050))
  157. axis.set_xticks(np.arange(0, 0.151, 0.025), minor=True)
  158. if not xticklabels:
  159. axis.set_xticklabels([])
  160. else:
  161. axis.set_xticklabels(np.arange(0, 151, 50))
  162. axis.set_ylim([0, 800])
  163. axis.set_yticks(np.arange(0, 751, 250))
  164. axis.set_yticks(np.arange(0, 751, 50), minor=True)
  165. spikes_axis.set_xlim([0, 0.150])
  166. axis.set_zorder(spikes_axis.get_zorder() + 1)
  167. axis.set_frame_on(False)
  168. axis.text(1., 1.0, f"rate: {ds._baserate:.1f} $\pm$ {ds._rate_modulation:.1f} Hz",
  169. transform=axis.transAxes, fontsize=legend_fontsize, ha="right", va="top")
  170. if highlight is None:
  171. return
  172. for h in highlight:
  173. el = event_list[h]
  174. el.set_color("tab:orange")
  175. el.set_linewidths(0.75)
  176. def plot_isih(ds, axis, binwidth=0.00025):
  177. axis.bar(ds._isihist[1][:-1] + np.diff(ds._isihist[1])/2, ds._isihist[0], width=binwidth)
  178. def plot_coherence(ds, axis, xticklabels=False):
  179. f, c = ds._coherence
  180. kernel = np.ones(5) / 5
  181. c = np.convolve(c, kernel, mode="same")
  182. axis.plot(f, c, lw=1.0, color="tab:blue")
  183. axis.set_xlim([0, 300])
  184. axis.set_xticks(np.arange(0, 301, 150))
  185. axis.set_xticks(np.arange(0, 301, 50), minor=True)
  186. axis.set_ylim([0, 1.])
  187. axis.set_yticks(np.arange(0.0, 1.01, 0.5))
  188. axis.set_yticks(np.arange(0.0, 1.01, 0.1), minor=True)
  189. ypos = 0.6 if np.mean(c[(f > 250) & (f < 300)]) > 0.7 else 1.0
  190. axis.text(1.0, ypos, f"{ds._mi:.1f} bit/s\n{ds._mi_spike:.1f} bit/spike", fontsize=legend_fontsize,
  191. ha="right", va="top", transform=axis.transAxes)
  192. if not xticklabels:
  193. axis.set_xticklabels([])
  194. def get_population_response(datasets, kernel_sigma=0.00125):
  195. trial_count = 10
  196. spike_times = []
  197. count = 0
  198. total_mi = 0.0
  199. selected_trials = {}
  200. spike_count = 0.0
  201. while count < trial_count:
  202. cell_index = count % len(datasets)
  203. if cell_index not in selected_trials:
  204. selected_trials[cell_index] = []
  205. ds = datasets[cell_index]
  206. num_trials = len(ds._spike_responses)
  207. trial = None
  208. while trial is None or trial in selected_trials[cell_index]:
  209. trial = np.random.randint(0, num_trials)
  210. spike_times.append(ds._spike_responses[trial] - ds._delay)
  211. selected_trials[cell_index].append(trial)
  212. total_mi += ds._mi
  213. spike_count += len(ds._spike_responses[trial])
  214. count += 1
  215. time, stimulus = load_white_noise_stim(os.path.join("stimuli", ds._stimfile))
  216. f, c, mis, rates, _ = mutual_info(spike_times, 0.0, [False for i in range(len(spike_times))],
  217. stimulus, freq_bin_edges=[300], kernel_sigma=kernel_sigma)
  218. results = WhitenoiseDataset("homogeneous population")
  219. results._spike_responses = spike_times
  220. results._mi = mis[0]
  221. results._delay = 0.0
  222. results._baserate = np.mean(np.mean(rates, axis=1))
  223. results._rate_modulation = np.std(np.mean(rates, axis=1))
  224. results._coherence = (f, c)
  225. results._rates = rates
  226. results._time = time
  227. results._mi_expectation = total_mi / trial_count
  228. results._mi_spike = results._mi / results._baserate
  229. results._selected_trials = selected_trials
  230. return results
  231. def plot_population(selection, axes, num_cells=4, col="left"):
  232. # plot the rasterplot of e.g. 10 trials
  233. # highlight one trial
  234. # plot the average +- std firing rate
  235. datasets = []
  236. data_location = "raw_data"
  237. cells = selection.dataset_id.unique()
  238. for i in range(num_cells):
  239. c = cells[i]
  240. datasets.append(read_dataset(c, selection[selection.dataset_id == c], data_location=data_location))
  241. ds = datasets[i]
  242. population_response = get_population_response(datasets)
  243. for i in range(num_cells):
  244. plot_responses(datasets[i], axes[col][f"row_{i}"]["responses"], highlight=population_response._selected_trials[i], xticklabels=i==3)
  245. # plot_isih(ds, axes[col][f"row_{i}"]["isi"])
  246. plot_coherence(datasets[i], axes[col][f"row_{i}"]["coherence"], xticklabels=i==3)
  247. plot_responses(population_response, axes[col]["row_4"]["response"], spikes_color="tab:orange", xticklabels=True)
  248. plot_coherence(population_response, axes[col]["row_4"]["coherence"], xticklabels=True)
  249. def plot_methods_plot(args):
  250. if not os.path.exists(args.trials):
  251. raise ValueError(f"Whitenoise trials data file not found! ({args.trials})")
  252. fig, axes = layout_figure()
  253. df = pd.read_csv(args.trials, sep=";", index_col=0)
  254. df = df[(df.duration == 10.0) & (df.contrast == 10.0)]
  255. homogeneous_cells = find_homogeneous_cells(df)
  256. plot_population(homogeneous_cells, axes)
  257. heterogeneous_cells = find_heterogeneous_cells(df)
  258. plot_population(heterogeneous_cells, axes, col="right")
  259. if args.nosave:
  260. plt.show()
  261. else:
  262. fig.savefig(args.outfile)
  263. plt.close()
  264. def command_line_parser(subparsers):
  265. whitenoise_trials = os.path.join("derived_data", "whitenoise_trials.csv")
  266. population_methods_parser = subparsers.add_parser("population_coding_method", help="")
  267. population_methods_parser.add_argument("-t", "--trials", default=whitenoise_trials)
  268. population_methods_parser.add_argument("-o", "--outfile", default=os.path.join("figures", "population_methods.pdf"))
  269. population_methods_parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  270. population_methods_parser.set_defaults(func=plot_methods_plot)
  271. def main():
  272. embed()
  273. df = pd.read_csv("../../derived_data/heterogeneous_populationcoding.csv", sep=";", index_col=0)
  274. if __name__ == "__main__":
  275. main()