123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- import os
- import numpy as np
- import pandas as pd
- import rlxnix as rlx
- import matplotlib.pyplot as plt
- import matplotlib.gridspec as gridspec
- from .figure_style import *
- from ..util import mutual_info, load_white_noise_stim
- from IPython import embed
class WhitenoiseDataset:
    """Plain record of the white-noise response data of one cell or population.

    The constructor only stores the name and neutral defaults; the analysis
    functions of this module fill in the remaining fields afterwards.
    """

    def __init__(self, name) -> None:
        self._name = name  # label of the cell / population

        # Scalar summary values, all starting at zero.
        self._baserate = 0.0
        self._rate_modulation = 0.0
        self._cv = 0.0

        # Spike data: the complete spike train and the per-trial responses.
        self._all_spikes = None
        self._spike_responses = []  # fresh list per instance

        # Derived results, unset until they have been computed.
        self._isihist = None
        self._coherence = None
        self._trial_data = None
        self._rates = None
        self._delay = 0.0
        self._time = None
        self._mi = None

        # Analysis settings and bookkeeping.
        self._kernel = 0.00125
        self._stimfile = None
        self._mi_expectation = 0.0
        self._mi_spike = 0.0
        self._selected_trials = None
def layout_figure():
    """Create the figure and all axes of the population-coding methods plot.

    The figure has a left and a right half, separated by a narrow spacer
    column. Each half is a 6-row grid: rows 0-3 hold one response panel and
    one coherence panel each, row 4 is a thin spacer, and row 5 holds the
    two panels of the population response.

    Returns:
        (fig, axes) where ``axes[col][row][panel]`` is a matplotlib axes.
        ``col`` is "left" or "right", ``row`` is "row_0" ... "row_4".
        Rows 0-3 use the panel keys "responses" and "coherence"; "row_4"
        uses "response" (singular) and "coherence".
    """
    figure = plt.figure(figsize=(6.9, 6.))
    outer_grid = gridspec.GridSpec(ncols=3, nrows=1, width_ratios=[5, 0.05, 5])

    # (outer grid column, result key, label of top panel, label of population panel);
    # outer column 1 is the spacer and gets no axes.
    halves = [(0, "left", "A", "B"), (2, "right", "C", "D")]
    # (inner grid column, result key, x label, y label);
    # inner column 1 is the spacer and gets no axes.
    panels = [
        (0, "responses", "time [ms]", "firing rate [Hz]"),
        (2, "coherence", "frequency [Hz]", "coherence"),
    ]

    axes = {}
    for outer_col, col, top_label, population_label in halves:
        column_axes = axes.setdefault(col, {})

        # Six equally high rows, except the second to last which is a thin spacer.
        row_heights = np.ones(6)
        row_heights[-2] = 0.125
        inner_grid = outer_grid[outer_col].subgridspec(
            nrows=6, ncols=3, width_ratios=[3, 0.25, 2], height_ratios=row_heights)

        # Rows 0-3: one response and one coherence panel per row.
        for row_index in range(4):
            row_axes = column_axes.setdefault(f"row_{row_index}", {})
            for inner_col, panel_key, xlabel, ylabel in panels:
                ax = figure.add_subplot(inner_grid[row_index, inner_col])
                if row_index == 0 and inner_col == 0:
                    ax.text(-0.375, 1.05, top_label, transform=ax.transAxes, ha="left",
                            fontsize=label_size, fontweight="bold")
                despine(ax, ["top", "right"], hide_ticks=False)
                ax.set_ylabel(ylabel)
                if row_index == 3:
                    # only the lowest of the four rows carries the x label
                    ax.set_xlabel(xlabel)
                row_axes[panel_key] = ax

        # Last grid row: the population panels, stored as "row_4".
        population_axes = {}
        column_axes["row_4"] = population_axes

        _, _, response_xlabel, response_ylabel = panels[0]
        ax = figure.add_subplot(inner_grid[5, 0])
        ax.text(-0.375, 1.05, population_label, transform=ax.transAxes,
                fontsize=label_size, fontweight="bold", ha="left")
        despine(ax, ["top", "right"], hide_ticks=False)
        ax.set_ylabel(response_ylabel)
        ax.set_xlabel(response_xlabel)
        population_axes["response"] = ax

        _, _, coherence_xlabel, coherence_ylabel = panels[1]
        ax = figure.add_subplot(inner_grid[5, 2])
        despine(ax, ["top", "right"], hide_ticks=False)
        ax.set_ylabel(coherence_ylabel)
        ax.set_xlabel(coherence_xlabel)
        population_axes["coherence"] = ax

    figure.subplots_adjust(left=0.075, right=0.985, top=0.975, bottom=0.06)
    return figure, axes
def _homogeneity_mask(frame):
    """Return a boolean Series marking the rows that meet the homogeneity criteria."""
    return (
        (frame.rate_modulation > 40) & (frame.rate_modulation < 60)
        & (frame.baserate > 100) & (frame.baserate < 150)
        & (frame.cv > 0.3) & (frame.cv < 0.9)
        & ~frame.inverted.astype(bool)
    )


def find_homogeneous_cells(df):
    """Select the trials of cells with similar response properties.

    A cell is accepted when its first trial in ``df`` is not inverted and has
    a rate modulation in (40, 60), a baseline rate in (100, 150) and a CV in
    (0.3, 0.9). For every accepted cell all of its trials that meet the same
    criteria are returned.

    Args:
        df: trial table with the columns ``dataset_id``, ``rate_modulation``,
            ``baserate``, ``cv`` and ``inverted``.

    Returns:
        DataFrame with the selected trials, grouped by cell in the order in
        which the cells first appear in ``df``. The frame is empty (same
        columns as ``df``) when no cell qualifies.
    """
    passes = _homogeneity_mask(df)
    # The first row of each dataset_id decides whether the cell is accepted.
    is_first_trial = ~df.dataset_id.duplicated()
    accepted_cells = df.dataset_id[is_first_trial & passes].tolist()
    if not accepted_cells:
        # pd.concat refuses an empty list; return an empty selection instead.
        return df.iloc[0:0]
    return pd.concat([df[(df.dataset_id == cell) & passes] for cell in accepted_cells])
def _cell_with_max(candidates, column):
    """Return the dataset_id of the first candidate holding the maximum of ``column``."""
    values = candidates[column]
    return candidates[values == np.max(values)].dataset_id.values[0]


def _cell_near_percentile(candidates, column, percentile):
    """Return the dataset_id of the candidate closest to a percentile of ``column``."""
    values = candidates[column]
    position = np.argmin(np.abs(values - np.percentile(values, percentile)))
    return candidates.iloc[position].dataset_id


def find_heterogeneous_cells(df):
    """Select the trials of cells that differ in their response properties.

    Cells with any inverted trial are ignored. Each remaining cell is
    represented by its first trial in ``df``, and four cells are picked from
    these: the one with the highest baseline rate, the one closest to the
    25th percentile of the baseline rate, the one closest to the 25th
    percentile of the rate modulation, and the one with the highest rate
    modulation. A cell that wins more than one of these picks is included
    only once, so the result may contain fewer than four cells.

    Args:
        df: trial table with the columns ``dataset_id``, ``baserate``,
            ``rate_modulation`` and ``inverted``.

    Returns:
        DataFrame with all trials of the picked cells, grouped by cell in
        pick order.

    Raises:
        ValueError: if no cell without inverted trials exists.
    """
    first_trials = []
    for cell in df.dataset_id.unique():
        trials = df[df.dataset_id == cell]
        if np.any(trials.inverted):
            continue
        first_trials.append(trials[0:1])
    if not first_trials:
        raise ValueError("No cell without inverted trials found!")
    candidates = pd.concat(first_trials)

    picks = [
        _cell_with_max(candidates, "baserate"),
        _cell_near_percentile(candidates, "baserate", 25),
        _cell_near_percentile(candidates, "rate_modulation", 25),
        _cell_with_max(candidates, "rate_modulation"),
    ]
    # Drop repeated picks (keeping the pick order) so that no trial is
    # returned twice.
    distinct_picks = list(dict.fromkeys(picks))
    return pd.concat([df[df.dataset_id == cell] for cell in distinct_picks])
def read_dataset(c, trial_info, data_location="raw_data", binwidth=0.0005, kernel_sigma=0.00125):
    """Read the recording of one cell and analyse its white-noise responses.

    Opens ``<data_location>/<c>.nix``, takes the baseline properties from the
    first "Baseline" repro run, cuts the spike responses of at most the first
    ten trials out of the complete spike train, and passes them to
    ``mutual_info`` together with the stimulus loaded from the "stimuli"
    folder.

    Args:
        c: dataset name, i.e. the file name without the ".nix" extension.
        trial_info: table with one row per trial; the columns
            ``rate_modulation``, ``delay``, ``start_time``, ``end_time`` and
            ``stimfile`` are read.
        data_location: folder that contains the nix files.
        binwidth: bin width of the baseline ISI histogram, which covers
            the range 0 to 0.02.
        kernel_sigma: forwarded to ``mutual_info``.

    Returns:
        The populated ``WhitenoiseDataset``.

    Raises:
        ValueError: if ``trial_info`` contains no trials.
    """
    if len(trial_info) == 0:
        raise ValueError(f"No trials given for dataset {c}!")
    max_trials = 10

    ds = WhitenoiseDataset(c)
    d = rlx.Dataset(os.path.join(data_location, c + ".nix"))
    baseline = d.repro_runs("Baseline")[0]
    ds._baserate = baseline.baseline_rate
    ds._rate_modulation = np.mean(trial_info.rate_modulation)
    ds._cv = baseline.baseline_cv
    ds._isihist = np.histogram(np.diff(baseline.spikes()), bins=np.arange(0.0, 0.02, binwidth))

    # The spike trace is stored under either of two spellings.
    trace_name = "spikes-1" if "spikes-1" in d.event_traces else "Spikes-1"
    all_spikes = d.event_traces[trace_name].data_array[:]
    ds._all_spikes = all_spikes
    # NOTE(review): this stores binwidth although the class default of _kernel
    # (0.00125) equals the default of kernel_sigma — confirm which is intended.
    ds._kernel = binwidth

    ds._spike_responses = []
    for i in range(min(len(trial_info), max_trials)):
        ti = trial_info.iloc[i]
        in_trial = (all_spikes >= ti.start_time) & (all_spikes < ti.end_time)
        # spike times relative to the start of the trial
        ds._spike_responses.append(all_spikes[in_trial] - ti.start_time)

    # The stimulus file is taken from the last trial that was used.
    ds._stimfile = ti.stimfile
    ds._delay = np.mean(trial_info.delay)
    time, stimulus = load_white_noise_stim(os.path.join("stimuli", ti.stimfile))

    # One flag per response that is actually passed on.
    inversion_flags = [False] * len(ds._spike_responses)
    freqs, coherence, mis, rates, _ = mutual_info(ds._spike_responses, 0.0, inversion_flags,
                                                  stimulus, freq_bin_edges=[300],
                                                  kernel_sigma=kernel_sigma)
    ds._rates = rates
    ds._time = time
    ds._coherence = (freqs, coherence)
    ds._mi = mis[0]
    ds._mi_spike = ds._mi / ds._baserate
    return ds
def plot_responses(ds, axis, highlight=None, xticklabels=False, spikes_color=None):
    """Plot the spike raster and the average firing rate of a dataset.

    The raster of at most the first ten spike responses is drawn on a twin
    axis behind the rate trace. The rate is the mean of ``ds._rates`` along
    axis 1, drawn in blue on top of a wider white line.

    Args:
        ds: dataset providing ``_spike_responses``, ``_time``, ``_rates``,
            ``_baserate`` and ``_rate_modulation``.
        axis: matplotlib axes for the firing rate.
        highlight: indices of raster rows to draw in orange with thicker
            lines; None highlights nothing.
        xticklabels: if True the major x ticks are labelled 0, 50, 100, 150,
            otherwise the labels are removed.
        spikes_color: raster color, "silver" if None.
    """
    spikes_axis = axis.twinx()
    despine(spikes_axis, ["top", "right"], hide_ticks=False)
    spikes_axis.yaxis.set_ticks([])
    spikes_color = "silver" if spikes_color is None else spikes_color
    event_list = spikes_axis.eventplot(ds._spike_responses[:10], color=spikes_color, linewidths=0.2)

    mean_rate = np.mean(ds._rates, axis=1)
    # white underlay to set the rate trace off against the raster
    axis.plot(ds._time, mean_rate, color="white", lw=2)
    axis.plot(ds._time, mean_rate, color="tab:blue", lw=1.0)

    axis.set_xlim([0, 0.150])
    axis.set_xticks(np.arange(0, 0.151, 0.050))
    axis.set_xticks(np.arange(0, 0.151, 0.025), minor=True)
    if not xticklabels:
        axis.set_xticklabels([])
    else:
        # ticks at 0 ... 0.15 are labelled 0 ... 150
        axis.set_xticklabels(np.arange(0, 151, 50))
    axis.set_ylim([0, 800])
    axis.set_yticks(np.arange(0, 751, 250))
    axis.set_yticks(np.arange(0, 751, 50), minor=True)
    spikes_axis.set_xlim([0, 0.150])

    # Draw the rate axes above the raster and keep its background transparent.
    axis.set_zorder(spikes_axis.get_zorder() + 1)
    axis.set_frame_on(False)
    # The backslash is escaped: "\p" is not a valid escape sequence.
    axis.text(1., 1.0, f"rate: {ds._baserate:.1f} $\\pm$ {ds._rate_modulation:.1f} Hz",
              transform=axis.transAxes, fontsize=legend_fontsize, ha="right", va="top")

    if highlight is None:
        return
    for h in highlight:
        el = event_list[h]
        el.set_color("tab:orange")
        el.set_linewidths(0.75)
def plot_isih(ds, axis, binwidth=0.00025):
    """Draw the interspike-interval histogram of a dataset as a bar plot.

    Args:
        ds: dataset whose ``_isihist`` holds ``(counts, bin_edges)``.
        axis: axes to draw into.
        binwidth: width of the bars.
    """
    counts, edges = ds._isihist
    # bars are placed at the bin centers
    centers = edges[:-1] + np.diff(edges) / 2
    axis.bar(centers, counts, width=binwidth)
def plot_coherence(ds, axis, xticklabels=False):
    """Plot the smoothed coherence of a dataset and annotate its information.

    The coherence is smoothed with a five-point moving average and drawn over
    the frequency range 0 to 300. The mutual information (bit/s and
    bit/spike) is written into the upper right corner; the text is moved down
    when the average smoothed coherence between 250 and 300 exceeds 0.7.

    Args:
        ds: dataset providing ``_coherence`` as ``(frequencies, coherence)``
            plus ``_mi`` and ``_mi_spike``.
        axis: axes to draw into.
        xticklabels: if False the x tick labels are removed.
    """
    freqs, raw_coherence = ds._coherence
    box = np.ones(5) / 5
    smoothed = np.convolve(raw_coherence, box, mode="same")
    axis.plot(freqs, smoothed, lw=1.0, color="tab:blue")

    axis.set_xlim([0, 300])
    axis.set_xticks(np.arange(0, 301, 150))
    axis.set_xticks(np.arange(0, 301, 50), minor=True)
    axis.set_ylim([0, 1.])
    axis.set_yticks(np.arange(0.0, 1.01, 0.5))
    axis.set_yticks(np.arange(0.0, 1.01, 0.1), minor=True)

    # Lower the annotation when the curve is high near the right edge.
    upper_band = (freqs > 250) & (freqs < 300)
    if np.mean(smoothed[upper_band]) > 0.7:
        text_y = 0.6
    else:
        text_y = 1.0
    axis.text(1.0, text_y, f"{ds._mi:.1f} bit/s\n{ds._mi_spike:.1f} bit/spike", fontsize=legend_fontsize,
              ha="right", va="top", transform=axis.transAxes)

    if xticklabels:
        return
    axis.set_xticklabels([])
def get_population_response(datasets, kernel_sigma=0.00125):
    """Compose a population response from randomly drawn single-cell trials.

    Ten trials are drawn by cycling through ``datasets``. For each draw a
    random trial of the current cell that was not used before is taken, and
    the cell's delay is subtracted from its spike times. The combined
    responses are passed to ``mutual_info`` together with the stimulus of the
    cell used for the last draw.

    Args:
        datasets: sequence of datasets providing ``_spike_responses``,
            ``_delay``, ``_mi`` and ``_stimfile``.
        kernel_sigma: forwarded to ``mutual_info``.

    Returns:
        A ``WhitenoiseDataset`` holding the population results.
        ``_selected_trials`` maps the index of each cell to the list of
        trial indices drawn from it, and ``_mi_expectation`` is the average
        ``_mi`` of the cells over all draws.

    Raises:
        ValueError: if ``datasets`` is empty or a cell has fewer trials than
            have to be drawn from it.
    """
    trial_count = 10
    if len(datasets) == 0:
        raise ValueError("Cannot compose a population response without datasets!")

    spike_times = []
    total_mi = 0.0
    selected_trials = {}
    for count in range(trial_count):
        cell_index = count % len(datasets)
        ds = datasets[cell_index]
        used_trials = selected_trials.setdefault(cell_index, [])
        num_trials = len(ds._spike_responses)
        if len(used_trials) >= num_trials:
            # Without this check the rejection sampling below never ends.
            raise ValueError(f"Dataset {cell_index} has only {num_trials} trials, "
                             "not enough to draw another unused one!")
        # Draw until a trial turns up that has not been used for this cell.
        trial = np.random.randint(0, num_trials)
        while trial in used_trials:
            trial = np.random.randint(0, num_trials)
        spike_times.append(ds._spike_responses[trial] - ds._delay)
        used_trials.append(trial)
        total_mi += ds._mi

    # ds is the cell of the last draw; its stimulus file is used for all.
    time, stimulus = load_white_noise_stim(os.path.join("stimuli", ds._stimfile))
    f, c, mis, rates, _ = mutual_info(spike_times, 0.0, [False] * len(spike_times),
                                      stimulus, freq_bin_edges=[300], kernel_sigma=kernel_sigma)

    mean_rate = np.mean(rates, axis=1)
    results = WhitenoiseDataset("homogeneous population")
    results._spike_responses = spike_times
    results._mi = mis[0]
    results._delay = 0.0  # the delays were already subtracted above
    results._baserate = np.mean(mean_rate)
    results._rate_modulation = np.std(mean_rate)
    results._coherence = (f, c)
    results._rates = rates
    results._time = time
    results._mi_expectation = total_mi / trial_count
    results._mi_spike = results._mi / results._baserate
    results._selected_trials = selected_trials
    return results
def plot_population(selection, axes, num_cells=4, col="left"):
    """Plot single-cell and population responses into one half of the figure.

    The first ``num_cells`` cells of ``selection`` are read, a population
    response is composed from them, and every cell is drawn into its own row
    ("row_0", "row_1", ...) with the trials used for the population
    highlighted. The population response goes into "row_4". The ISI
    histogram is not plotted.

    Args:
        selection: trial table with a ``dataset_id`` column.
        axes: nested axes dictionary, indexed as ``axes[col][row][panel]``.
        num_cells: number of cells to plot.
        col: key of the figure half to draw into.
    """
    data_location = "raw_data"
    cell_ids = selection.dataset_id.unique()

    datasets = []
    for index in range(num_cells):
        cell_id = cell_ids[index]
        cell_trials = selection[selection.dataset_id == cell_id]
        datasets.append(read_dataset(cell_id, cell_trials, data_location=data_location))

    population_response = get_population_response(datasets)

    column_axes = axes[col]
    for index, dataset in enumerate(datasets):
        row_axes = column_axes[f"row_{index}"]
        with_tick_labels = index == 3  # only the fourth row shows x tick labels
        plot_responses(dataset, row_axes["responses"],
                       highlight=population_response._selected_trials[index],
                       xticklabels=with_tick_labels)
        plot_coherence(dataset, row_axes["coherence"], xticklabels=with_tick_labels)

    population_axes = column_axes["row_4"]
    plot_responses(population_response, population_axes["response"],
                   spikes_color="tab:orange", xticklabels=True)
    plot_coherence(population_response, population_axes["coherence"], xticklabels=True)
def plot_methods_plot(args):
    """Create the population-coding methods figure.

    Reads the semicolon-separated trials table, keeps the trials with a
    duration of 10.0 and a contrast of 10.0, and plots the homogeneous
    population into the left and the heterogeneous population into the right
    half of the figure.

    Args:
        args: parsed command line arguments with ``trials`` (path of the
            trials table), ``outfile`` (path of the figure) and ``nosave``
            (show the figure instead of saving it).

    Raises:
        ValueError: if the trials table does not exist.
    """
    trials_file = args.trials
    if not os.path.exists(trials_file):
        raise ValueError(f"Whitenoise trials data file not found! ({args.trials})")

    fig, axes = layout_figure()
    all_trials = pd.read_csv(trials_file, sep=";", index_col=0)
    wanted = (all_trials.duration == 10.0) & (all_trials.contrast == 10.0)
    trials = all_trials[wanted]

    plot_population(find_homogeneous_cells(trials), axes)
    plot_population(find_heterogeneous_cells(trials), axes, col="right")

    if args.nosave:
        plt.show()
    else:
        fig.savefig(args.outfile)
    plt.close()
def command_line_parser(subparsers):
    """Register the "population_coding_method" sub-command.

    Args:
        subparsers: the object returned by ``ArgumentParser.add_subparsers``.
            The new sub-parser gets the options -t/--trials, -o/--outfile and
            -n/--nosave and dispatches to ``plot_methods_plot`` via ``func``.
    """
    default_trials = os.path.join("derived_data", "whitenoise_trials.csv")
    default_outfile = os.path.join("figures", "population_methods.pdf")

    parser = subparsers.add_parser("population_coding_method", help="")
    parser.add_argument("-t", "--trials", default=default_trials)
    parser.add_argument("-o", "--outfile", default=default_outfile)
    parser.add_argument("-n", "--nosave", action='store_true',
                        help="no saving of the figure, just showing")
    parser.set_defaults(func=plot_methods_plot)
def main():
    """Debugging entry point: open an IPython shell, then read the results table."""
    embed()
    # The table is read only after the shell is left and is not used further.
    population_data = pd.read_csv("../../derived_data/heterogeneous_populationcoding.csv",
                                  sep=";", index_col=0)
# Script entry point: main() runs only when this file is executed directly.
# NOTE(review): the module uses relative imports, so direct execution
# presumably only works as part of its package (python -m ...) — confirm.
if __name__ == "__main__":
    main()
|