"""Methods figure for population coding of white-noise stimuli.

Builds two example populations from the white-noise trial table: a
"homogeneous" one (cells with similar baseline rate, rate modulation and CV;
left column) and a "heterogeneous" one (cells picked by maximum and
25th-percentile baseline rate / rate modulation; right column). For each
population, single-cell rasters, firing rates and coherences are plotted next
to the pooled population response.
"""
import os

import numpy as np
import pandas as pd
import rlxnix as rlx
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# NOTE(review): despine, label_size and legend_fontsize used below are
# presumably provided by this star import.
from .figure_style import *
from ..util import mutual_info, load_white_noise_stim
from IPython import embed


class WhitenoiseDataset:
    """Container for one cell's (or one population's) white-noise response data.

    The constructor only sets placeholders; the attributes are filled in by
    ``read_dataset`` (single cells) or ``get_population_response`` (pooled
    populations).
    """

    def __init__(self, name) -> None:
        self._name = name                # dataset id or population label
        self._baserate = 0.0             # firing rate, shown as "rate: X ± Y Hz"
        self._rate_modulation = 0.0      # the "± Y" part of that label
        self._cv = 0.0                   # baseline CV (cells only)
        self._all_spikes = None          # all spike times of the recording (cells only)
        self._spike_responses = []       # list of per-trial spike-time arrays
        self._isihist = None             # np.histogram result of baseline ISIs (cells only)
        self._coherence = None           # (frequencies, coherence) from mutual_info
        self._trial_data = None          # not used in this module
        self._rates = None               # firing rates returned by mutual_info
        self._delay = 0.0                # mean response delay from the trial table
        self._time = None                # time axis returned with the stimulus
        self._mi = None                  # mutual information, displayed as bit/s
        self._kernel = 0.00125           # kernel width used for the rate estimate
        self._stimfile = None            # stimulus file name (cells only)
        self._mi_expectation = 0.0       # mean single-cell MI of the pooled trials
        self._mi_spike = 0.0             # _mi / _baserate, displayed as bit/spike
        self._selected_trials = None     # {cell index: [trial indices]} (populations only)


def layout_figure():
    """Create the figure and a nested dictionary of its axes.

    The figure has a "left" and a "right" panel separated by a narrow spacer
    column. Each panel contains four single-cell rows ``"row_0"`` ..
    ``"row_3"`` (each with a ``"responses"`` and a ``"coherence"`` axis),
    a thin spacer row, and ``"row_4"`` with the population ``"response"`` and
    ``"coherence"`` axes.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : dict
        ``axes[col][row][name]`` -> matplotlib Axes.
    """
    axes = {}
    fig = plt.figure(figsize=(6.9, 6.))
    gs = gridspec.GridSpec(ncols=3, nrows=1, width_ratios=[5, 0.05, 5])
    cols = ["left", "", "right"]
    inner_cols = ["responses", "", "coherence"]
    ylabels = ["firing rate [Hz]", "", "coherence"]
    xlabels = ["time [ms]", "", "frequency [Hz]"]
    for i, col in enumerate(cols):
        if i == 1:  # narrow spacer column between the two panels
            continue
        axes.setdefault(col, {})
        # six sub-rows: four single cells, a thin spacer, the population row
        height_ratios = np.ones(6)
        height_ratios[-2] = 0.125
        sgs = gs[i].subgridspec(nrows=6, ncols=3, width_ratios=[3, 0.25, 2],
                                height_ratios=height_ratios)
        for j in range(4):
            row = f"row_{j}"
            axes[col].setdefault(row, {})
            for k, inner_col in enumerate(inner_cols):
                if inner_col == "":  # spacer between responses and coherence
                    continue
                ax = fig.add_subplot(sgs[j, k])
                if j == 0 and k == 0:
                    ax.text(-0.375, 1.05, "A" if i == 0 else "C", transform=ax.transAxes,
                            ha="left", fontsize=label_size, fontweight="bold")
                despine(ax, ["top", "right"], hide_ticks=False)
                ax.set_ylabel(ylabels[k])
                if j == 3:
                    ax.set_xlabel(xlabels[k])
                axes[col][row][inner_col] = ax

        # population row (sub-row 5; sub-row 4 is the thin spacer)
        axes[col]["row_4"] = {}
        ax = fig.add_subplot(sgs[5, 0])
        ax.text(-0.375, 1.05, "B" if col == "left" else "D", transform=ax.transAxes,
                fontsize=label_size, fontweight="bold", ha="left")
        despine(ax, ["top", "right"], hide_ticks=False)
        ax.set_ylabel(ylabels[0])
        ax.set_xlabel(xlabels[0])
        axes[col]["row_4"]["response"] = ax

        ax = fig.add_subplot(sgs[5, 2])
        despine(ax, ["top", "right"], hide_ticks=False)
        ax.set_ylabel(ylabels[2])
        ax.set_xlabel(xlabels[2])
        axes[col]["row_4"]["coherence"] = ax

    fig.subplots_adjust(left=0.075, right=0.985, top=0.975, bottom=0.06)
    return fig, axes


def _homogeneous_trial_mask(frame):
    """Boolean mask of non-inverted rows whose baseline properties lie in the homogeneous range."""
    return ((frame.rate_modulation > 40) & (frame.rate_modulation < 60)
            & (frame.baserate > 100) & (frame.baserate < 150)
            & (frame.cv > 0.3) & (frame.cv < 0.9)
            & ~frame.inverted)


def find_homogeneous_cells(df):
    """Select the trials of cells with similar baseline properties.

    A cell qualifies if its first trial satisfies the homogeneous criteria
    (rate modulation in (40, 60), base rate in (100, 150), CV in (0.3, 0.9),
    not inverted). Of each qualifying cell, only the trials that satisfy the
    same criteria are returned, grouped by cell in order of first appearance.

    Parameters
    ----------
    df : pandas.DataFrame
        White-noise trials table.

    Returns
    -------
    pandas.DataFrame
        The selected trials; empty if no cell qualifies.
    """
    # first trial of every cell, in order of first appearance
    first_trials = df.groupby("dataset_id", sort=False).head(1)
    similar_ids = first_trials.loc[_homogeneous_trial_mask(first_trials), "dataset_id"]
    if len(similar_ids) == 0:
        return df.iloc[0:0]
    eligible = df[_homogeneous_trial_mask(df)]
    return pd.concat([eligible[eligible.dataset_id == cell_id] for cell_id in similar_ids])


def find_heterogeneous_cells(df):
    """Select the trials of up to four cells with diverse baseline properties.

    Only cells without any inverted trial are considered, each represented by
    its first trial. Picked are the cells with the maximum base rate, the base
    rate closest to the 25th percentile, the rate modulation closest to the
    25th percentile, and the maximum rate modulation. A cell picked more than
    once is included only once.

    Parameters
    ----------
    df : pandas.DataFrame
        White-noise trials table.

    Returns
    -------
    pandas.DataFrame
        All trials of the picked cells; empty if no cell is eligible.
    """
    first_trials = []
    for cell_id in df.dataset_id.unique():
        trials = df[df.dataset_id == cell_id]
        if np.any(trials.inverted):
            continue
        first_trials.append(trials[0:1])
    if not first_trials:
        return df.iloc[0:0]
    selection = pd.concat(first_trials)

    baserate = selection.baserate
    modulation = selection.rate_modulation
    picks = [
        selection[baserate == np.max(baserate)].dataset_id.values[0],
        selection.iloc[np.argmin(np.abs(baserate - np.percentile(baserate, 25)))].dataset_id,
        selection.iloc[np.argmin(np.abs(modulation - np.percentile(modulation, 25)))].dataset_id,
        selection[modulation == np.max(modulation)].dataset_id.values[0],
    ]
    # drop duplicate picks (keeping order) so no cell's trials appear twice
    picks = list(dict.fromkeys(picks))
    return pd.concat([df[df.dataset_id == cell_id] for cell_id in picks])


def read_dataset(c, trial_info, data_location="raw_data", binwidth=0.0005, kernel_sigma=0.00125):
    """Load one cell's baseline and white-noise responses and compute its coherence and MI.

    Parameters
    ----------
    c : str
        Dataset id; the recording is read from ``<data_location>/<c>.nix``.
    trial_info : pandas.DataFrame
        White-noise trials of this cell. Columns read here: ``start_time``,
        ``end_time``, ``stimfile``, ``rate_modulation`` and ``delay``. At most
        the first 10 trials are used.
    data_location : str
        Folder containing the ``.nix`` files.
    binwidth : float
        Bin width of the baseline ISI histogram (bins span 0 to 0.02).
    kernel_sigma : float
        Kernel width passed to ``mutual_info`` and stored as ``_kernel``.

    Returns
    -------
    WhitenoiseDataset

    Raises
    ------
    ValueError
        If ``trial_info`` contains no trials.
    """
    if len(trial_info) == 0:
        raise ValueError(f"No white-noise trials given for dataset {c}!")
    ds = WhitenoiseDataset(c)
    d = rlx.Dataset(os.path.join(data_location, c + ".nix"))

    baseline = d.repro_runs("Baseline")[0]
    ds._baserate = baseline.baseline_rate
    ds._rate_modulation = np.mean(trial_info.rate_modulation)
    ds._cv = baseline.baseline_cv
    ds._isihist = np.histogram(np.diff(baseline.spikes()), bins=np.arange(0.0, 0.02, binwidth))

    # the spike trace name differs in capitalization between recordings
    trace_name = "spikes-1" if "spikes-1" in d.event_traces else "Spikes-1"
    all_spikes = d.event_traces[trace_name].data_array[:]
    ds._all_spikes = all_spikes
    # store the rate-estimation kernel, not the ISI-histogram bin width
    ds._kernel = kernel_sigma

    ds._spike_responses = []
    for i in range(min(len(trial_info), 10)):
        ti = trial_info.iloc[i]
        in_trial = (all_spikes >= ti.start_time) & (all_spikes < ti.end_time)
        # spike times relative to trial onset
        ds._spike_responses.append(all_spikes[in_trial] - ti.start_time)
        ds._stimfile = ti.stimfile
    ds._delay = np.mean(trial_info.delay)

    time, stimulus = load_white_noise_stim(os.path.join("stimuli", ds._stimfile))
    # one flag per spike train actually passed (presumably "inverted" flags)
    flags = [False] * len(ds._spike_responses)
    freqs, coherence, mis, rates, _ = mutual_info(ds._spike_responses, 0.0, flags, stimulus,
                                                  freq_bin_edges=[300], kernel_sigma=kernel_sigma)
    ds._rates = rates
    ds._time = time
    ds._coherence = (freqs, coherence)
    ds._mi = mis[0]
    ds._mi_spike = ds._mi / ds._baserate
    return ds


def plot_responses(ds, axis, highlight=None, xticklabels=False, spikes_color=None):
    """Plot a spike raster with the trial-averaged firing rate on top.

    The raster (first 10 trials) is drawn on a twin axis behind ``axis``; the
    mean rate is drawn on ``axis``. The x range covers 0 to 0.15, labelled
    0 to 150 (ms).

    Parameters
    ----------
    ds : WhitenoiseDataset
    axis : matplotlib.axes.Axes
    highlight : iterable of int, optional
        Raster rows (trial indices) to draw thicker and in orange.
    xticklabels : bool
        Whether to label the x ticks.
    spikes_color : color, optional
        Raster color, "silver" by default.
    """
    spikes_axis = axis.twinx()
    despine(spikes_axis, ["top", "right"], hide_ticks=False)
    spikes_axis.yaxis.set_ticks([])
    spikes_color = "silver" if spikes_color is None else spikes_color
    event_list = spikes_axis.eventplot(ds._spike_responses[:10], color=spikes_color, linewidths=0.2)

    mean_rate = np.mean(ds._rates, axis=1)
    axis.plot(ds._time, mean_rate, color="white", lw=2)  # white halo for contrast
    axis.plot(ds._time, mean_rate, color="tab:blue", lw=1.0)
    axis.set_xlim([0, 0.150])
    axis.set_xticks(np.arange(0, 0.151, 0.050))
    axis.set_xticks(np.arange(0, 0.151, 0.025), minor=True)
    if not xticklabels:
        axis.set_xticklabels([])
    else:
        axis.set_xticklabels(np.arange(0, 151, 50))
    axis.set_ylim([0, 800])
    axis.set_yticks(np.arange(0, 751, 250))
    axis.set_yticks(np.arange(0, 751, 50), minor=True)
    spikes_axis.set_xlim([0, 0.150])
    # draw the rate curve above the raster of the twin axis
    axis.set_zorder(spikes_axis.get_zorder() + 1)
    axis.set_frame_on(False)
    # raw f-string: "\p" is not a valid escape sequence in a normal string
    axis.text(1., 1.0, rf"rate: {ds._baserate:.1f} $\pm$ {ds._rate_modulation:.1f} Hz",
              transform=axis.transAxes, fontsize=legend_fontsize, ha="right", va="top")
    if highlight is None:
        return
    for h in highlight:
        el = event_list[h]
        el.set_color("tab:orange")
        el.set_linewidths(0.75)


def plot_isih(ds, axis, binwidth=0.00025):
    """Draw the baseline ISI histogram stored in ``ds._isihist`` as a bar plot.

    NOTE(review): the default bar width (0.00025) is half of ``read_dataset``'s
    default histogram bin width (0.0005); pass ``binwidth`` to match.
    """
    counts, edges = ds._isihist
    axis.bar(edges[:-1] + np.diff(edges) / 2, counts, width=binwidth)


def plot_coherence(ds, axis, xticklabels=False):
    """Plot the smoothed coherence (0-300 Hz) and annotate MI in bit/s and bit/spike."""
    f, c = ds._coherence
    # 5-point moving average
    kernel = np.ones(5) / 5
    c = np.convolve(c, kernel, mode="same")
    axis.plot(f, c, lw=1.0, color="tab:blue")
    axis.set_xlim([0, 300])
    axis.set_xticks(np.arange(0, 301, 150))
    axis.set_xticks(np.arange(0, 301, 50), minor=True)
    axis.set_ylim([0, 1.])
    axis.set_yticks(np.arange(0.0, 1.01, 0.5))
    axis.set_yticks(np.arange(0.0, 1.01, 0.1), minor=True)
    # move the annotation down if the curve is high at the right edge
    ypos = 0.6 if np.mean(c[(f > 250) & (f < 300)]) > 0.7 else 1.0
    axis.text(1.0, ypos, f"{ds._mi:.1f} bit/s\n{ds._mi_spike:.1f} bit/spike",
              fontsize=legend_fontsize, ha="right", va="top", transform=axis.transAxes)
    if not xticklabels:
        axis.set_xticklabels([])


def get_population_response(datasets, kernel_sigma=0.00125, trial_count=10,
                            name="homogeneous population"):
    """Pool randomly drawn trials of several cells into one population response.

    Trials are drawn round-robin over ``datasets`` (trial ``k`` comes from
    cell ``k % len(datasets)``) and without replacement within each cell.
    Each drawn spike train is shifted by its cell's ``_delay``.

    Parameters
    ----------
    datasets : list of WhitenoiseDataset
    kernel_sigma : float
        Kernel width passed to ``mutual_info``.
    trial_count : int
        Number of trials pooled into the population response.
    name : str
        Name of the returned dataset.

    Returns
    -------
    WhitenoiseDataset
        ``_selected_trials`` maps dataset index -> list of drawn trial indices;
        ``_mi_expectation`` is the mean single-cell MI over the drawn trials.

    Raises
    ------
    ValueError
        If ``datasets`` is empty, ``trial_count`` < 1, or a cell has fewer
        trials than it would be drawn.
    """
    if len(datasets) == 0:
        raise ValueError("Need at least one dataset to build a population response!")
    if trial_count < 1:
        raise ValueError(f"trial_count must be at least 1, got {trial_count}!")
    num_cells = len(datasets)
    # Cell i is drawn len(range(i, trial_count, num_cells)) times; without
    # enough trials the rejection sampling below would never terminate.
    for cell_index, cell in enumerate(datasets):
        draws = len(range(cell_index, trial_count, num_cells))
        if draws > len(cell._spike_responses):
            raise ValueError(f"Dataset {cell._name} has only {len(cell._spike_responses)} "
                             f"trials, but {draws} are needed!")

    spike_times = []
    selected_trials = {}
    total_mi = 0.0
    for count in range(trial_count):
        cell_index = count % num_cells
        ds = datasets[cell_index]
        chosen = selected_trials.setdefault(cell_index, [])
        trial = None
        while trial is None or trial in chosen:
            trial = np.random.randint(0, len(ds._spike_responses))
        chosen.append(trial)
        spike_times.append(ds._spike_responses[trial] - ds._delay)
        total_mi += ds._mi

    # NOTE(review): the stimulus is taken from the last sampled cell; this
    # assumes all pooled cells were driven by the same stimulus file.
    time, stimulus = load_white_noise_stim(os.path.join("stimuli", ds._stimfile))
    freqs, coherence, mis, rates, _ = mutual_info(spike_times, 0.0, [False] * len(spike_times),
                                                  stimulus, freq_bin_edges=[300],
                                                  kernel_sigma=kernel_sigma)
    mean_rate = np.mean(rates, axis=1)
    results = WhitenoiseDataset(name)
    results._spike_responses = spike_times
    results._mi = mis[0]
    results._delay = 0.0  # delays were already removed from the spike times
    results._baserate = np.mean(mean_rate)
    results._rate_modulation = np.std(mean_rate)
    results._coherence = (freqs, coherence)
    results._rates = rates
    results._time = time
    results._mi_expectation = total_mi / trial_count
    results._mi_spike = results._mi / results._baserate
    results._selected_trials = selected_trials
    return results


def plot_population(selection, axes, num_cells=4, col="left", data_location="raw_data"):
    """Plot single-cell responses/coherences and the pooled population response.

    The first ``num_cells`` cells of ``selection`` (in order of appearance)
    are loaded. The trials drawn for the population response are highlighted
    in each cell's raster.

    Parameters
    ----------
    selection : pandas.DataFrame
        Trials of the candidate cells (see ``find_*_cells``).
    axes : dict
        Axes dictionary as returned by ``layout_figure``.
    num_cells : int
        Number of cells in the population; the layout provides four rows.
    col : str
        Panel to draw into ("left" or "right").
    data_location : str
        Folder containing the ``.nix`` files.

    Raises
    ------
    ValueError
        If ``selection`` contains fewer than ``num_cells`` cells.
    """
    cells = selection.dataset_id.unique()
    if len(cells) < num_cells:
        raise ValueError(f"Need {num_cells} cells for the population, but the selection "
                         f"contains only {len(cells)}!")
    datasets = [read_dataset(cell, selection[selection.dataset_id == cell],
                             data_location=data_location)
                for cell in cells[:num_cells]]
    population_response = get_population_response(datasets)

    last_row = num_cells - 1  # only the bottom cell row gets x tick labels
    for i, ds in enumerate(datasets):
        row_axes = axes[col][f"row_{i}"]
        plot_responses(ds, row_axes["responses"],
                       highlight=population_response._selected_trials.get(i),
                       xticklabels=i == last_row)
        plot_coherence(ds, row_axes["coherence"], xticklabels=i == last_row)
    plot_responses(population_response, axes[col]["row_4"]["response"],
                   spikes_color="tab:orange", xticklabels=True)
    plot_coherence(population_response, axes[col]["row_4"]["coherence"], xticklabels=True)


def plot_methods_plot(args):
    """Create the population-coding methods figure and show or save it.

    Only trials with ``duration == 10.0`` and ``contrast == 10.0`` are used.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``trials`` (path of the ';'-separated trials table), ``nosave``
        and ``outfile``.

    Raises
    ------
    ValueError
        If the trials file does not exist.
    """
    if not os.path.exists(args.trials):
        raise ValueError(f"Whitenoise trials data file not found! ({args.trials})")
    fig, axes = layout_figure()
    df = pd.read_csv(args.trials, sep=";", index_col=0)
    df = df[(df.duration == 10.0) & (df.contrast == 10.0)]

    homogeneous_cells = find_homogeneous_cells(df)
    plot_population(homogeneous_cells, axes)
    heterogeneous_cells = find_heterogeneous_cells(df)
    plot_population(heterogeneous_cells, axes, col="right")

    if args.nosave:
        plt.show()
    else:
        fig.savefig(args.outfile)
    plt.close(fig)


def command_line_parser(subparsers):
    """Register the ``population_coding_method`` sub-command on ``subparsers``."""
    whitenoise_trials = os.path.join("derived_data", "whitenoise_trials.csv")
    population_methods_parser = subparsers.add_parser("population_coding_method", help="")
    population_methods_parser.add_argument("-t", "--trials", default=whitenoise_trials)
    population_methods_parser.add_argument("-o", "--outfile", default=os.path.join("figures", "population_methods.pdf"))
    population_methods_parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
    population_methods_parser.set_defaults(func=plot_methods_plot)


def main():
    """Debug entry point: load the population-coding table, then open an IPython shell."""
    df = pd.read_csv("../../derived_data/heterogeneous_populationcoding.csv", sep=";", index_col=0)
    embed()  # ``df`` is available in the interactive shell


if __name__ == "__main__":
    main()