15 KB

  ###############################################################################
  ## plot the baseline and driven response properties as a function of ##
  ## the receptive field position ##
  import os
  import numpy as np
  import pandas as pd
  import scipy.stats as stats
  import matplotlib.pyplot as plt
  import matplotlib.image as mpimg
  import matplotlib.gridspec as gridspec
  from matplotlib.ticker import FormatStrFormatter
  from .figure_style import *
  # Font size handed as `fontsize` to the axis labels of all panels.
  label_size = 8
  # Help text of the "property_correlations" sub-command. The sentence used to
  # stop after "and the data files "; it now names the default data files.
  fig2_help = (
      "Plots the properties of the baseline response and the stimulus driven response as a function of the "
      "receptive field position along the fish's rostro-caudal axis. Depends on fishsketch.png (expected in the "
      "folder ../figures/) and the data files derived_data/figure2_baseline_properties.csv and "
      "derived_data/figure2_driven_properties.csv (or the files passed via -bf and -df)."
  )
  def get_feature_values(df, features):
      """Average feature columns per dataset and contrast for rows with pop_size == 1.

      Parameters
      ----------
      df : pandas.DataFrame
          Table with the columns ``dataset_id``, ``pop_size``, ``contrast``
          and one column per requested feature. Rows containing NaN are
          ignored.
      features : iterable of str
          Names of the columns to average.

      Returns
      -------
      dict
          Maps each feature name to a numpy array holding one mean per
          (dataset, contrast) combination, ordered by dataset and then by
          contrast in order of first appearance.
      numpy.ndarray
          The unique dataset ids, collected *before* the NaN rows are
          dropped.
      """
      datasets = df.dataset_id.unique()
      df = df.dropna()
      single = df[df.pop_size == 1]

      collected = {name: [] for name in features}
      for dataset in datasets:
          of_dataset = single[single.dataset_id == dataset]
          for contrast in of_dataset.contrast.unique():
              trials = of_dataset[of_dataset.contrast == contrast]
              for name, means in collected.items():
                  means.append(np.mean(trials[name].values))

      feature_dict = {name: np.array(means) for name, means in collected.items()}
      return feature_dict, datasets
  def plotLinregress(ax, x, y, bonferroni_factor, color="tab:red", label="", feature=""):
      """Fit a line to (x, y), draw it on ``ax`` and report the statistics.

      Parameters
      ----------
      ax : axes-like
          Object providing ``plot`` and ``legend``; the fit is drawn here.
      x, y : numpy.ndarray
          Data passed to ``scipy.stats.linregress``.
      bonferroni_factor : number
          The p-value is multiplied by this factor and capped at 1. With a
          factor of exactly 1 the uncorrected p-value is shown as "p",
          otherwise the corrected one is shown as "p*".
      color : str, optional
          Color of the regression line.
      label : str, optional
          Optional prefix of the legend entry.
      feature : str, optional
          Name used in the printed summary of ``y``.

      Returns
      -------
      slope, intercept
          Parameters of the fitted line.
      """
      xRange = np.asarray([x.min(), x.max()])
      fit = stats.linregress(x, y)
      slope, intercept = fit.slope, fit.intercept

      # Bonferroni correction, capped because a probability cannot exceed 1.
      corrected_p = min(fit.pvalue * bonferroni_factor, 1)
      if bonferroni_factor == 1:
          p_name, p_shown = "p", fit.pvalue
      else:
          p_name, p_shown = r"p$^*$", corrected_p

      prefix = "%s, " % label if label else ""
      legend_text = r'%s r: %.2f, %s: %.2f' % (prefix, fit.rvalue, p_name, p_shown)

      ax.plot(xRange, intercept + slope * xRange, color=color, label=legend_text, ls="-")
      ax.legend(loc='upper left', frameon=False, ncol=1, fontsize=6, bbox_to_anchor=(0., 1.3))

      # Summary of the dependent variable (not of the fit) for the text output.
      print("%s: minimum: %.2f, maximum: %.2f, mean: %.2f, std: %.2f" % (feature, np.min(y), np.max(y), np.mean(y), np.std(y)))
      return slope, intercept
  def plotMeanProperty(ax, bins, x, y):
      """Plot the binned mean of ``y`` over ``x`` with a +/- one std band.

      Bins are half-open ``[start, end)`` except for the last one, which also
      includes its right edge. This matches the convention of
      ``numpy.histogram`` / ``Axes.hist``, so a sample lying exactly on the
      upper edge (e.g. a relative position of 1.0) is no longer dropped.

      Parameters
      ----------
      ax : axes-like
          Object providing ``plot`` and ``fill_between``.
      bins : array-like
          Monotonically increasing bin edges (need not be equally spaced).
      x : array-like
          Values used to assign the samples to the bins.
      y : array-like
          Values that are averaged within each bin.

      Returns
      -------
      numpy.ndarray
          Mean of ``y`` per bin; NaN for bins without samples.
      """
      edges = np.asarray(bins, dtype=float)
      x = np.asarray(x)
      y = np.asarray(y)
      n_bins = len(edges) - 1

      # Empty bins stay NaN so that they show up as gaps in the plot.
      my = np.full(n_bins, np.nan)
      sy = np.full(n_bins, np.nan)
      for i, (start, end) in enumerate(zip(edges[:-1], edges[1:])):
          if i == n_bins - 1:
              inside = (x >= start) & (x <= end)  # closed last bin
          else:
              inside = (x >= start) & (x < end)
          yy = y[inside]
          if len(yy) > 0:
              my[i] = np.mean(yy)
              sy[i] = np.std(yy)

      # Midpoints of each bin; also correct for unequally spaced edges.
      centers = (edges[:-1] + edges[1:]) / 2
      ax.plot(centers, my, '--', lw=0.5, color='black')
      ax.fill_between(centers, my + sy, my - sy, color='black', alpha=0.1)
      return my
  def remove_bounds(x, y, label="firing rate"):
      """Print the regression statistics after dropping the extreme x positions.

      All samples whose ``x`` equals the smallest or the largest value of
      ``x`` are excluded, a linear regression is computed on the remaining
      samples and its r- and p-value are printed. Nothing is returned.

      Parameters
      ----------
      x, y : numpy.ndarray
          Paired one-dimensional data.
      label : str, optional
          Prefix of the printed line.
      """
      lowest, highest = np.sort(x)[[0, -1]]
      keep = (x != lowest) & (x != highest)
      fit = stats.linregress(x[keep], y[keep])
      r, p = fit.rvalue, fit.pvalue
      print(f"{label}: removed min and max position, r: {r:.2f}, p: {p:.2f}")
  def plot_baseline_properties(axes, df, markersize, correctionFactor):
      """Draw the baseline-response panels as a function of receptor position.

      Panel 0 shows the histogram of receptor positions together with the
      normalised cumulative count; panels 1-4 show firing rate, CV of the
      interspike intervals, burst fraction and vector strength, each as a
      scatter plot with binned mean and regression line.

      Parameters
      ----------
      axes : sequence of matplotlib axes
          Five axes, indexed 0..4, one per panel.
      df : pandas.DataFrame
          Needs the columns ``receptor_pos_relative``, ``firing_rate``,
          ``burst_fraction`` and ``vector_strength`` plus the CV column
          (see the review note in the body).
      markersize : number
          Marker size handed to ``scatter``.
      correctionFactor : number
          Bonferroni factor forwarded to ``plotLinregress``.
      """
      xPositions = df.receptor_pos_relative.values
      posBins = np.linspace(0, 1, 10)

      # --- histogram of receptor positions with cumulative count ---
      ax = axes[0]
      counts, posBins, _ = ax.hist(xPositions, posBins, color="tab:blue", histtype='stepfilled', alpha=0.75)
      posCenters = posBins[:-1] + (posBins[1] - posBins[0]) / 2
      twinAx = ax.twinx()
      cumCounts = np.cumsum(counts)
      cumCounts /= cumCounts.max()
      twinAx.plot(posCenters, cumCounts, color="tab:red", alpha=0.75)
      twinAx.set_yticks([])
      twinAx.set_ylim([0, 1.5])
      twinAx.patch.set_visible(False)
      despine(twinAx, ["top", "right", "bottom", "left"], True)
      ax.set_xticks(np.arange(0, 1.1, 0.2))
      ax.set_xlim([0, 1.])
      ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
      ax.set_xticklabels([])
      ax.set_ylabel("# of cells", fontsize=label_size)
      ax.set_ylim([0, 60])
      ax.set_yticks([0, 20, 40])
      ax.set_yticks([0, 10, 20, 30, 40], minor=True)
      ax.yaxis.set_label_coords(-0.15, 0.5)

      def scatter_panel(ax, y_values, feature, ylabel, ylim, yticks, yticks_minor, bottom=False):
          # One property panel: binned mean, regression line, raw data, axis styling.
          plotMeanProperty(ax, posBins, xPositions, y_values)
          plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature=feature)
          ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
          ax.set_ylabel(ylabel, fontsize=label_size)
          ax.set_ylim(*ylim)
          ax.set_yticks(yticks)
          ax.set_yticks(yticks_minor, minor=True)
          ax.set_xlim([0, 1.0])
          ax.set_xticks(np.arange(0.0, 1.01, 0.2))
          ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
          if bottom:
              # only the lowest panel carries tick labels and the x-axis label
              ax.set_xlabel("receptor position [rel.]", fontsize=label_size + 1)
              ax.xaxis.set_label_coords(0.5, -0.35)
          else:
              ax.set_xticklabels([])
          ax.yaxis.set_label_coords(-0.15, 0.5)

      # --- firing rate ---
      y_values = df.firing_rate.values
      remove_bounds(xPositions, y_values)
      maxFR = np.ceil(df.firing_rate.max() / 200) * 200  # round the axis limit up to a multiple of 200 Hz
      scatter_panel(axes[1], y_values, "firing rate", 'firing rate [Hz]', (0, maxFR),
                    np.arange(0, maxFR + 1, 200), np.arange(0, maxFR + 1, 100))

      # --- CV of the interspike interval ---
      # NOTE(review): the right-hand side of this assignment was missing in the
      # source; the column name `cv` is an assumption - confirm against the
      # baseline data frame.
      y_values = df.cv.values
      scatter_panel(axes[2], y_values, "CV", r'CV$_{ISI}$', (0, 1.5),
                    np.arange(0, 1.51, 0.5), np.arange(0, 1.51, 0.25))

      # --- burst fraction ---
      y_values = df.burst_fraction.values
      remove_bounds(xPositions, y_values, label="burst fraction")
      scatter_panel(axes[3], y_values, "burst fraction", 'burst fraction', (0, 1.0),
                    np.arange(0, 1.01, 0.5), np.arange(0, 1.01, 0.25))

      # --- vector strength ---
      y_values = df.vector_strength.values
      scatter_panel(axes[4], y_values, "vector strength", 'vector strength', (0.6, 1.0),
                    np.arange(0.6, 1.01, 0.2), np.arange(0.6, 1.01, 0.1), bottom=True)
  def plot_driven_properties(axes, df, markersize, correctionFactor):
      """Draw the stimulus-driven response panels as a function of receptor position.

      Panels 0-4 show response modulation, response variability, lower and
      upper gain cutoff, maximum gain and mutual information. Every data
      series is drawn as scatter plot, regression line and binned mean.

      Parameters
      ----------
      axes : sequence of matplotlib axes
          Five axes, indexed 0..4, one per panel.
      df : pandas.DataFrame
          Needs the columns ``receptor_pos_relative``,
          ``response_modulation``, ``response_variability``,
          ``gain_cutoff_lower``, ``gain_cutoff_upper``, ``gain_maximum`` and
          ``mi``.
      markersize : number
          Marker size handed to ``scatter``.
      correctionFactor : number
          Bonferroni factor forwarded to ``plotLinregress``.
      """
      positions = df.receptor_pos_relative.values
      bin_edges = np.linspace(0, 1, 10)

      def draw_series(ax, values, feature, dot_color="tab:blue", line_color="tab:red", label=""):
          # raw data, regression line and binned mean of one data series
          ax.scatter(positions, values, color=dot_color, s=markersize)
          plotLinregress(ax, positions, values, correctionFactor, color=line_color, label=label, feature=feature)
          plotMeanProperty(ax, bin_edges, positions, values)

      def style_axis(ax, ylabel, ylim, yticks, yticks_minor, formatter=None, bottom=False):
          # shared tick/label styling; only the bottom panel keeps x tick labels
          ax.set_xticks(np.arange(0, 1.1, 0.2))
          ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
          if not bottom:
              ax.set_xticklabels([])
          ax.set_ylim(ylim)
          ax.set_yticks(yticks)
          ax.set_yticks(yticks_minor, minor=True)
          ax.set_ylabel(ylabel, fontsize=label_size)
          if bottom:
              ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
          ax.yaxis.set_label_coords(-0.15, 0.5)
          if formatter is not None:
              ax.yaxis.set_major_formatter(FormatStrFormatter(formatter))
          if bottom:
              ax.xaxis.set_label_coords(0.5, -0.35)

      # response modulation
      draw_series(axes[0], df.response_modulation.values, "modulation")
      style_axis(axes[0], "modulation [Hz]", [0, 300], np.arange(0, 301, 100), np.arange(0, 300, 50))

      # response variability
      draw_series(axes[1], df.response_variability.values, "variability")
      style_axis(axes[1], "variability [Hz]", [0, 300], np.arange(0, 301, 100), np.arange(0, 300, 50))

      # lower and upper cutoff share one panel
      draw_series(axes[2], df.gain_cutoff_lower.values, "lower cutoff",
                  dot_color="tab:blue", line_color="tab:blue", label="lower")
      draw_series(axes[2], df.gain_cutoff_upper.values, "upper cutoff",
                  dot_color="tab:red", line_color="tab:red", label="upper")
      style_axis(axes[2], "cutoff [Hz]", [0, 325], np.arange(0, 301, 100), np.arange(0, 300, 50))

      # maximum gain
      gain = df.gain_maximum.values / 1000  # convert to kHz/mV
      draw_series(axes[3], gain, "gain")
      style_axis(axes[3], "gain [kHz/mV]", [0, 4], np.arange(0, 4.1, 2), np.arange(0, 4.1, 1), formatter='%.1f')

      # mutual information
      draw_series(axes[4], df.mi.values, "mutual info.")
      style_axis(axes[4], "mutual info. [bit/s]", [0, 600], np.arange(0, 601, 200), np.arange(0, 601, 100),
                 formatter='%i', bottom=True)
  def layout_figure():
      """Create the two-column figure with five panels per column.

      The left sub-figure holds the baseline panels (labels A-E) and the fish
      sketch, the right one the driven-response panels (labels F-J).

      Returns
      -------
      fig : matplotlib figure
          The newly created figure.
      baseline_axes : list
          Axes of the left column, top to bottom.
      driven_axes : list
          Axes of the right column, top to bottom.

      Raises
      ------
      FileNotFoundError
          If ``figures/fishsketch.png`` does not exist relative to the
          current working directory.
      """
      sketch_file = "figures/fishsketch.png"
      # Check first: raising after the figure exists would leave a half-built
      # figure open in pyplot.
      if not os.path.exists(sketch_file):
          raise FileNotFoundError("Missing fishsketch.png file.")

      baseline_subfig_labels = ["A", "B", "C", "D", "E"]
      driven_subfig_labels = ["F", "G", "H", "I", "J"]
      subplots_per_col = max([len(baseline_subfig_labels), len(driven_subfig_labels)])

      fig = plt.figure(figsize=(5.1, 5.75))
      subfigs = fig.subfigures(1, 2, wspace=0.05)
      gr = gridspec.GridSpec(subplots_per_col, 1, hspace=0.5, left=0.225)

      baseline_axes = []
      driven_axes = []
      for i, (baseline_label, driven_label) in enumerate(zip(baseline_subfig_labels, driven_subfig_labels)):
          baseline_ax = subfigs[0].add_subplot(gr[i, 0])
          driven_ax = subfigs[1].add_subplot(gr[i, 0])
          # panel letters sit above the upper left corner of each axis
          baseline_ax.text(-0.25, 1.05, baseline_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                           transform=baseline_ax.transAxes, ha="center", va="bottom")
          driven_ax.text(-0.25, 1.05, driven_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
                         transform=driven_ax.transAxes, ha="center", va="bottom")
          despine(baseline_ax, ["top", "right"], False)
          despine(driven_ax, ["top", "right"], False)
          baseline_axes.append(baseline_ax)
          driven_axes.append(driven_ax)

      # fish sketch on top of the baseline column
      pic_ax = subfigs[0].add_axes((0.23, 0.88, 0.75, 0.15))
      img = mpimg.imread(sketch_file)
      pic_ax.imshow(img)
      despine(pic_ax, ["top", "bottom", "left", "right"], True)
      return fig, baseline_axes, driven_axes
  def plot_position_correlations(args):
      """Load both result tables, draw the figure and show or save it.

      Parameters
      ----------
      args : namespace
          Needs the attributes ``driven_frame`` and ``baseline_frame``
          (paths of ';'-separated CSV files), ``outfile`` (figure file name)
          and ``nosave`` (if true the figure is shown instead of saved).

      Raises
      ------
      ValueError
          If one of the two CSV files does not exist.
      """
      if not os.path.exists(args.driven_frame) or not os.path.exists(args.baseline_frame):
          raise ValueError(f"Results data frame(s) not found! ({args.driven_frame} or {args.baseline_frame})")
      driven_df = pd.read_csv(args.driven_frame, sep=";", index_col=0)
      baseline_df = pd.read_csv(args.baseline_frame, sep=";", index_col=0)

      markersize = 0.5
      fig, baseline_axes, driven_axes = layout_figure()
      print("Baseline response:")
      plot_baseline_properties(baseline_axes, baseline_df, markersize, correctionFactor=4)
      print("Driven response:")
      # NOTE(review): plot_driven_properties runs six regressions (the cutoff
      # panel has two) while the correction factor is 5 - confirm this is intended.
      plot_driven_properties(driven_axes, driven_df, markersize, correctionFactor=5)
      fig.subplots_adjust(left=0.2, right=0.975, top=0.95, bottom=0.075)

      if args.nosave:
          # "--nosave" is documented as "just showing"; the branch body was missing
          plt.show()
      else:
          fig.savefig(args.outfile, dpi=500)
      plt.close(fig)
  def command_line_parser(subparsers):
      """Register the ``property_correlations`` sub-command.

      Parameters
      ----------
      subparsers : argparse sub-parsers action
          Object returned by ``ArgumentParser.add_subparsers()``.

      Returns
      -------
      argparse.ArgumentParser
          The new sub-parser; its default ``func`` is
          ``plot_position_correlations``.
      """
      baseline_csv = os.path.join("derived_data", "figure2_baseline_properties.csv")
      driven_csv = os.path.join("derived_data", "figure2_driven_properties.csv")
      figure_file = os.path.join("figures", "property_position_relation.pdf")

      sub = subparsers.add_parser("property_correlations", help=fig2_help)
      sub.add_argument(
          "-df", "--driven_frame",
          default=driven_csv,
          help=f"Full file name of a CSV table readable with pandas that contains the position data and coding properties ({driven_csv})",
      )
      sub.add_argument(
          "-bf", "--baseline_frame",
          default=baseline_csv,
          help=f"Full file name of a CSV table readable with pandas that holds the baseline properties and positions (defaults to {baseline_csv}).",
      )
      sub.add_argument(
          "-o", "--outfile",
          type=str,
          default=figure_file,
          help="The filename of the figure",
      )
      sub.add_argument(
          "-n", "--nosave",
          action='store_true',
          help="no saving of the figure, just showing",
      )
      sub.set_defaults(func=plot_position_correlations)
      return sub