123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351 |
###############################################################################
## Plot the baseline and stimulus-driven response properties as a function   ##
## of the receptive field position along the rostro-caudal axis.             ##
- import os
- import numpy as np
- import pandas as pd
- import scipy.stats as stats
- import matplotlib.pyplot as plt
- import matplotlib.image as mpimg
- import matplotlib.gridspec as gridspec
- from matplotlib.ticker import FormatStrFormatter
- from .figure_style import *
# Font size of the axis labels in all panels.
label_size = 8
# Help text of the ``property_correlations`` sub-command (used by command_line_parser).
# argparse %-formats help strings, so this text must not contain a bare "%".
fig2_help = ("Plots the properties of the baseline response and the stimulus driven response as a function of the "
             "receptive field position along the fish's rostro-caudal axis. Depends on fishsketch.png (expected in "
             "the folder figures/) and the data files derived_data/figure2_baseline_properties.csv and "
             "derived_data/figure2_driven_properties.csv (see the --baseline_frame and --driven_frame options).")
def get_feature_values(df, features):
    """Average selected feature columns per (dataset, contrast) for single cells.

    Rows containing any NaN are discarded first. Then, for every dataset in
    order of first appearance and every contrast of that dataset's rows with
    ``pop_size == 1`` (in order of first appearance), the mean of each
    requested feature over those rows is collected.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``dataset_id``, ``contrast``, ``pop_size`` and every
        name listed in ``features``.
    features : iterable of str
        Column names to average.

    Returns
    -------
    feature_dict : dict of str -> numpy.ndarray
        One array per feature, one entry per (dataset, contrast) pair. The
        arrays of all features are aligned.
    datasets : numpy.ndarray
        Unique dataset ids of ``df``. They are taken BEFORE NaN rows are
        dropped, so this can list datasets that contribute no values. The
        array is not aligned with the feature arrays.
    """
    dataset_ids = df.dataset_id.unique()
    clean = df.dropna()
    collected = {name: [] for name in features}
    for dataset_id in dataset_ids:
        single_cells = clean[(clean.dataset_id == dataset_id) & (clean.pop_size == 1)]
        for contrast in single_cells.contrast.unique():
            selection = single_cells[single_cells.contrast == contrast]
            for name in collected:
                collected[name].append(np.mean(selection[name].values))
    return {name: np.array(values) for name, values in collected.items()}, dataset_ids
def plotLinregress(ax, x, y, bonferroni_factor, color="tab:red", label="", feature=""):
    """Draw the least-squares regression line of ``y`` on ``x`` with a legend entry.

    The line spans ``x.min()`` to ``x.max()``. Its legend entry shows the
    correlation coefficient and the p-value. When ``bonferroni_factor`` is not
    1, the p-value is multiplied by it, capped at 1 and shown as p*. Minimum,
    maximum, mean and population standard deviation of ``y`` are printed,
    prefixed by ``feature``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; ``plot`` and ``legend`` are called on it.
    x, y : numpy.ndarray
        Paired samples.
    bonferroni_factor : number
        Multiplier applied to the p-value (number of comparisons).
    color : str
        Line color.
    label : str
        Optional prefix of the legend entry.
    feature : str
        Name used in the printed summary.

    Returns
    -------
    tuple
        ``(slope, intercept)`` of the fit.
    """
    fit = stats.linregress(x, y)
    corrected_p = min(fit.pvalue * bonferroni_factor, 1)
    prefix = "%s, " % label if len(label) > 0 else ""
    if bonferroni_factor == 1:
        p_name, p_shown = "p", fit.pvalue
    else:
        p_name, p_shown = r"p$^*$", corrected_p
    legend_entry = r'%s r: %.2f, %s: %.2f' % (prefix, fit.rvalue, p_name, p_shown)

    line_x = np.asarray([x.min(), x.max()])
    ax.plot(line_x, fit.intercept + fit.slope * line_x, color=color, label=legend_entry, ls="-")
    # Legend sits above the axes, anchored at its upper-left corner.
    ax.legend(loc='upper left', frameon=False, ncol=1, fontsize=6, bbox_to_anchor=(0., 1.3))
    print("%s: minimum: %.2f, maximum: %.2f, mean: %.2f, std: %.2f"
          % (feature, np.min(y), np.max(y), np.mean(y), np.std(y)))
    return fit.slope, fit.intercept
def plotMeanProperty(ax, bins, x, y):
    """Plot the mean and standard deviation of ``y`` within bins of ``x``.

    The per-bin mean is drawn as a thin dashed black line at the bin centres.
    The mean +- standard deviation is drawn as a light grey band.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; ``plot`` and ``fill_between`` are called on it.
    bins : array-like
        Increasing bin edges, giving ``len(bins) - 1`` bins. The bins need
        not be equally wide.
    x, y : array-like
        Paired samples; ``x`` is binned and ``y`` is averaged.

    Returns
    -------
    numpy.ndarray
        Mean of ``y`` per bin, NaN for empty bins.
    """
    edges = np.asarray(bins, dtype=float)
    x = np.asarray(x)
    y = np.asarray(y)
    n_bins = len(edges) - 1
    my = np.full(n_bins, np.nan)
    sy = np.full(n_bins, np.nan)
    for i, (start, end) in enumerate(zip(edges[:-1], edges[1:])):
        # Bins are half-open [start, end) except the last, which is closed.
        # Otherwise samples exactly on the upper edge (e.g. position 1.0) would
        # be dropped. This matches numpy.histogram / Axes.hist.
        below_end = (x <= end) if i == n_bins - 1 else (x < end)
        yy = y[(x >= start) & below_end]
        if yy.size > 0:
            my[i] = np.mean(yy)
            sy[i] = np.std(yy)
    # True midpoints, also correct for unequal bin widths.
    centers = (edges[:-1] + edges[1:]) / 2
    ax.plot(centers, my, '--', lw=0.5, color='black')
    ax.fill_between(centers, my + sy, my - sy, color='black', alpha=0.1)
    return my
def remove_bounds(x, y, label="firing rate"):
    """Print the regression of ``y`` on ``x`` without the extreme positions.

    Every sample whose ``x`` equals the smallest or the largest value in
    ``x`` (all occurrences) is left out. The remaining samples go to
    ``scipy.stats.linregress``, and the resulting correlation coefficient and
    p-value are printed, prefixed by ``label``. Nothing is returned or plotted.
    """
    lowest, highest = np.sort(x)[[0, -1]]
    keep = (x != lowest) & (x != highest)
    fit = stats.linregress(x[keep], y[keep])
    print(f"{label}: removed min and max position, r: {fit.rvalue:.2f}, p: {fit.pvalue:.2f}")
def plot_baseline_properties(axes, df, markersize, correctionFactor):
    """Fill the baseline-response panels as a function of relative receptor position.

    ``axes[0]`` gets a histogram of receptor positions (9 bins on [0, 1]).
    A twin axis overlays the cumulative count normalized to its maximum.
    ``axes[1]`` to ``axes[4]`` show firing rate, CV of the ISI, burst fraction
    and vector strength. Each of these panels gets the binned mean +- std
    (``plotMeanProperty``), a regression line (``plotLinregress``) and the
    cells as a scatter, in that order. For firing rate and burst fraction,
    ``remove_bounds`` also prints a regression without the extreme positions.

    Parameters
    ----------
    axes : sequence of matplotlib Axes
        At least five axes, filled top to bottom.
    df : pandas.DataFrame
        Needs the columns ``receptor_pos_relative``, ``firing_rate``, ``cv``,
        ``burst_fraction`` and ``vector_strength``.
    markersize : float
        Scatter marker size.
    correctionFactor : number
        Bonferroni factor passed to ``plotLinregress``.
    """
    positions = df.receptor_pos_relative.values

    # --- position histogram with cumulative curve ---
    hist_ax = axes[0]
    counts, bin_edges, _ = hist_ax.hist(positions, np.linspace(0, 1, 10), color="tab:blue",
                                        histtype='stepfilled', alpha=0.75)
    bin_centers = bin_edges[:-1] + (bin_edges[1] - bin_edges[0]) / 2
    cum_ax = hist_ax.twinx()
    cumulative = np.cumsum(counts)
    cumulative /= cumulative.max()
    cum_ax.plot(bin_centers, cumulative, color="tab:red", alpha=0.75)
    cum_ax.set_yticks([])
    cum_ax.set_ylim([0, 1.5])
    cum_ax.patch.set_visible(False)
    despine(cum_ax, ["top", "right", "bottom", "left"], True)
    hist_ax.set_xticks(np.arange(0, 1.1, 0.2))
    hist_ax.set_xlim([0, 1.])
    hist_ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
    hist_ax.set_xticklabels([])
    hist_ax.set_ylabel("# of cells", fontsize=label_size)
    # NOTE(review): fixed upper limit; bins with more than 60 cells would be clipped.
    hist_ax.set_ylim([0, 60])
    hist_ax.set_yticks([0, 20, 40])
    hist_ax.set_yticks([0, 10, 20, 30, 40], minor=True)
    hist_ax.yaxis.set_label_coords(-0.15, 0.5)

    def draw_feature(ax, values, feature):
        # Mean band first, then regression line + legend, then cells on top.
        plotMeanProperty(ax, bin_edges, positions, values)
        plotLinregress(ax, positions, values, bonferroni_factor=correctionFactor, feature=feature)
        ax.scatter(positions, values, color="tab:blue", s=markersize)

    def style_axes(ax, ylabel, ylim, major_yticks, minor_yticks, hide_xticklabels=True):
        ax.set_ylabel(ylabel, fontsize=label_size)
        ax.set_ylim(*ylim)
        ax.set_yticks(major_yticks)
        ax.set_yticks(minor_yticks, minor=True)
        ax.set_xlim([0, 1.0])
        ax.set_xticks(np.arange(0.0, 1.01, 0.2))
        ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
        if hide_xticklabels:
            ax.set_xticklabels([])
        ax.yaxis.set_label_coords(-0.15, 0.5)

    # --- firing rate; y range rounded up to the next multiple of 200 Hz ---
    firing_rates = df.firing_rate.values
    remove_bounds(positions.copy(), firing_rates.copy(), label="firing rate")
    draw_feature(axes[1], firing_rates, "firing rate")
    fr_top = np.ceil(df.firing_rate.max() / 200) * 200
    style_axes(axes[1], 'firing rate [Hz]', (0, fr_top),
               np.arange(0, fr_top + 1, 200), np.arange(0, fr_top + 1, 100))

    # --- CV of the interspike intervals ---
    draw_feature(axes[2], df.cv.values, "CV")
    style_axes(axes[2], r'CV$_{ISI}$', (0, 1.5), np.arange(0, 1.51, 0.5), np.arange(0, 1.51, 0.25))

    # --- burst fraction ---
    burst_fractions = df.burst_fraction.values
    remove_bounds(positions.copy(), burst_fractions.copy(), label="burst fraction")
    draw_feature(axes[3], burst_fractions, "burst fraction")
    style_axes(axes[3], 'burst fraction', (0, 1.0), np.arange(0, 1.01, 0.5), np.arange(0, 1.01, 0.25))

    # --- vector strength; bottom panel keeps its x tick labels and gets the x label ---
    vs_ax = axes[4]
    draw_feature(vs_ax, df.vector_strength.values, "vector strength")
    style_axes(vs_ax, 'vector strength', (0.6, 1.0), np.arange(0.6, 1.01, 0.2),
               np.arange(0.6, 1.01, 0.1), hide_xticklabels=False)
    vs_ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)
    vs_ax.xaxis.set_label_coords(0.5, -0.35)
def plot_driven_properties(axes, df, markersize, correctionFactor):
    """Fill the stimulus-driven response panels as a function of relative receptor position.

    The five axes show, top to bottom:
    - response modulation;
    - response variability;
    - lower (blue) and upper (red) gain cutoff in one panel;
    - maximum gain, divided by 1000 to plot in kHz/mV;
    - mutual information.

    Each feature gets the cells as a scatter, a regression line
    (``plotLinregress``) and the binned mean +- std (``plotMeanProperty``),
    in that order.

    Parameters
    ----------
    axes : sequence of matplotlib Axes
        At least five axes, filled top to bottom.
    df : pandas.DataFrame
        Needs the columns ``receptor_pos_relative``, ``response_modulation``,
        ``response_variability``, ``gain_cutoff_lower``, ``gain_cutoff_upper``,
        ``gain_maximum`` and ``mi``.
    markersize : float
        Scatter marker size.
    correctionFactor : number
        Bonferroni factor passed to ``plotLinregress``.
    """
    positions = df.receptor_pos_relative.values
    bin_edges = np.linspace(0, 1, 10)

    def add_feature(ax, values, feature, dot_color="tab:blue", line_color="tab:red", label=""):
        # Cells first, then regression line + legend, then the mean band.
        ax.scatter(positions, values, color=dot_color, s=markersize)
        plotLinregress(ax, positions, values, correctionFactor, color=line_color, label=label, feature=feature)
        plotMeanProperty(ax, bin_edges, positions, values)

    def format_position_ticks(ax, hide_labels=True):
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
        if hide_labels:
            ax.set_xticklabels([])

    def format_value_axis(ax, ylabel, top, major_ticks, minor_ticks):
        ax.set_ylim([0, top])
        ax.set_yticks(major_ticks)
        ax.set_yticks(minor_ticks, minor=True)
        ax.set_ylabel(ylabel, fontsize=label_size)
        ax.yaxis.set_label_coords(-0.15, 0.5)

    # --- response modulation ---
    add_feature(axes[0], df.response_modulation.values, "modulation")
    format_position_ticks(axes[0])
    format_value_axis(axes[0], "modulation [Hz]", 300, np.arange(0, 301, 100), np.arange(0, 300, 50))

    # --- response variability ---
    add_feature(axes[1], df.response_variability.values, "variability")
    format_position_ticks(axes[1])
    format_value_axis(axes[1], "variability [Hz]", 300, np.arange(0, 301, 100), np.arange(0, 300, 50))

    # --- lower and upper gain cutoff share one panel ---
    cutoff_ax = axes[2]
    add_feature(cutoff_ax, df.gain_cutoff_lower.values, "lower cutoff",
                dot_color="tab:blue", line_color="tab:blue", label="lower")
    add_feature(cutoff_ax, df.gain_cutoff_upper.values, "upper cutoff",
                dot_color="tab:red", line_color="tab:red", label="upper")
    format_position_ticks(cutoff_ax)
    format_value_axis(cutoff_ax, "cutoff [Hz]", 325, np.arange(0, 301, 100), np.arange(0, 300, 50))

    # --- maximum gain, divided by 1000 to plot in kHz/mV ---
    gain_ax = axes[3]
    add_feature(gain_ax, df.gain_maximum.values / 1000, "gain")
    format_position_ticks(gain_ax)
    format_value_axis(gain_ax, "gain [kHz/mV]", 4, np.arange(0, 4.1, 2), np.arange(0, 4.1, 1))
    gain_ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

    # --- mutual information; bottom panel keeps x tick labels and gets the x label ---
    mi_ax = axes[4]
    add_feature(mi_ax, df.mi.values, "mutual info.")
    format_position_ticks(mi_ax, hide_labels=False)
    format_value_axis(mi_ax, "mutual info. [bit/s]", 600, np.arange(0, 601, 200), np.arange(0, 601, 100))
    mi_ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)
    mi_ax.yaxis.set_major_formatter(FormatStrFormatter('%i'))
    mi_ax.xaxis.set_label_coords(0.5, -0.35)
def layout_figure():
    """Create the figure with a baseline column (left) and a driven column (right).

    The figure is split into two side-by-side subfigures. Each gets five
    stacked axes labelled A-E (baseline) and F-J (driven). The same GridSpec
    geometry is used for both columns. The fish sketch
    ``figures/fishsketch.png`` is drawn above the left column.

    Returns
    -------
    fig : matplotlib.figure.Figure
    baseline_axes : list of matplotlib Axes
        Left-column axes, top to bottom.
    driven_axes : list of matplotlib Axes
        Right-column axes, top to bottom.

    Raises
    ------
    FileNotFoundError
        If ``figures/fishsketch.png`` does not exist relative to the current
        working directory.
    """
    sketch_file = os.path.join("figures", "fishsketch.png")
    # Check before creating any figure so that a missing file does not leave
    # an unclosed figure behind.
    if not os.path.exists(sketch_file):
        raise FileNotFoundError(f"Missing fishsketch.png file ({sketch_file}).")

    baseline_subfig_labels = ["A", "B", "C", "D", "E"]
    driven_subfig_labels = ["F", "G", "H", "I", "J"]
    subplots_per_col = max(len(baseline_subfig_labels), len(driven_subfig_labels))

    fig = plt.figure(figsize=(5.1, 5.75))
    subfigs = fig.subfigures(1, 2, wspace=0.05)
    gr = gridspec.GridSpec(subplots_per_col, 1, hspace=0.5, left=0.225)
    baseline_axes = []
    driven_axes = []
    for i, (baseline_label, driven_label) in enumerate(zip(baseline_subfig_labels, driven_subfig_labels)):
        baseline_ax = subfigs[0].add_subplot(gr[i, 0])
        driven_ax = subfigs[1].add_subplot(gr[i, 0])
        for ax, panel_label in ((baseline_ax, baseline_label), (driven_ax, driven_label)):
            ax.text(-0.25, 1.05, panel_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                    transform=ax.transAxes, ha="center", va="bottom")
            despine(ax, ["top", "right"], False)
        baseline_axes.append(baseline_ax)
        driven_axes.append(driven_ax)

    pic_ax = subfigs[0].add_axes((0.23, 0.88, 0.75, 0.15))
    pic_ax.imshow(mpimg.imread(sketch_file))
    despine(pic_ax, ["top", "bottom", "left", "right"], True)
    return fig, baseline_axes, driven_axes
def plot_position_correlations(args):
    """Build the property-vs-position figure and show or save it.

    Reads both result tables (``;``-separated CSV, first column as index),
    lays out the figure and plots the baseline and driven panels. Depending
    on ``args.nosave``, the figure is then shown interactively or saved to
    ``args.outfile`` at 500 dpi. The figure is closed afterwards, also when
    plotting or saving fails.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``driven_frame``, ``baseline_frame``, ``nosave`` and ``outfile``.

    Raises
    ------
    ValueError
        If one or both input tables do not exist; the message lists the
        missing path(s).
    """
    missing = [path for path in (args.driven_frame, args.baseline_frame) if not os.path.exists(path)]
    if missing:
        raise ValueError(f"Results data frame(s) not found! ({', '.join(missing)})")
    driven_df = pd.read_csv(args.driven_frame, sep=";", index_col=0)
    baseline_df = pd.read_csv(args.baseline_frame, sep=";", index_col=0)
    markersize = 0.5

    fig, baseline_axes, driven_axes = layout_figure()
    try:
        print("Baseline response:")
        # 4 regressions: firing rate, CV, burst fraction, vector strength.
        plot_baseline_properties(baseline_axes, baseline_df, markersize, correctionFactor=4)
        print("Driven response:")
        # NOTE(review): the driven panels run 6 regressions (modulation, variability,
        # lower cutoff, upper cutoff, gain, MI). 5 is kept as-is; confirm that
        # counting the two cutoffs as one test is intended.
        plot_driven_properties(driven_axes, driven_df, markersize, correctionFactor=5)
        fig.subplots_adjust(left=0.2, right=0.975, top=0.95, bottom=0.075)
        if args.nosave:
            plt.show()
        else:
            fig.savefig(args.outfile, dpi=500)
    finally:
        # Close exactly this figure, once. A bare plt.close() would act on
        # whichever figure happens to be current.
        plt.close(fig)
def command_line_parser(subparsers):
    """Register the ``property_correlations`` sub-command.

    Parameters
    ----------
    subparsers : argparse subparsers action
        Object returned by ``ArgumentParser.add_subparsers()``.

    Returns
    -------
    argparse.ArgumentParser
        The new sub-parser. Its ``func`` default is ``plot_position_correlations``.
    """
    default_bf = os.path.join("derived_data", "figure2_baseline_properties.csv")
    default_df = os.path.join("derived_data", "figure2_driven_properties.csv")
    default_out = os.path.join("figures", "property_position_relation.pdf")
    parser = subparsers.add_parser("property_correlations", help=fig2_help)
    parser.add_argument("-df", "--driven_frame", default=default_df,
                        help=f"Full file name of a CSV table readable with pandas that contains the position data and coding properties (defaults to {default_df}).")
    parser.add_argument("-bf", "--baseline_frame", default=default_bf,
                        help=f"Full file name of a CSV table readable with pandas that holds the baseline properties and positions (defaults to {default_bf}).")
    parser.add_argument("-o", "--outfile", type=str, default=default_out,
                        help=f"The filename of the figure (defaults to {default_out}).")
    parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
    parser.set_defaults(func=plot_position_correlations)
    return parser
|