property_correlations.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. ###############################################################################
  2. ## plot the baseline and driven response properties as a function ##
  3. ## receptive field pos ##
  4. import os
  5. import numpy as np
  6. import pandas as pd
  7. import scipy.stats as stats
  8. import matplotlib.pyplot as plt
  9. import matplotlib.image as mpimg
  10. import matplotlib.gridspec as gridspec
  11. from matplotlib.ticker import FormatStrFormatter
  12. from .figure_style import *
  13. label_size = 8
  14. fig2_help="Plots the properties of the baseline response and the stimulus driven response as a function of the receptive field position along the fish's rostro-caudal axis. Depends on fishsketch.png (expected in the folder ../figures/) and the data files "
  15. def get_feature_values(df, features):
  16. """Get the values of different
  17. Parameters
  18. ----------
  19. df : [type]
  20. [description]
  21. features : [type]
  22. [description]
  23. Returns
  24. -------
  25. [type]
  26. [description]
  27. """
  28. datasets = df.dataset_id.unique()
  29. df = df.dropna()
  30. feature_dict = {}
  31. for f in features:
  32. feature_dict[f] = []
  33. for d in datasets:
  34. contrasts = df.contrast[(df.dataset_id == d) & (df.pop_size == 1)].unique()
  35. for c in contrasts:
  36. trials = df[(df.dataset_id == d) & (df.pop_size == 1) & (df.contrast == c)]
  37. for f in feature_dict.keys():
  38. feature_dict[f].append(np.mean(trials[f].values))
  39. for f in feature_dict.keys():
  40. feature_dict[f] = np.array(feature_dict[f])
  41. return feature_dict, datasets
  42. def plotLinregress(ax, x, y, bonferroni_factor, color="tab:red", label="", feature=""):
  43. xRange = np.asarray([x.min(), x.max()])
  44. slope, intercept, *params = stats.linregress(x, y)
  45. p_star = params[1] * bonferroni_factor
  46. if p_star > 1:
  47. p_star = 1
  48. l = r'%s r: %.2f, %s: %.2f' % ("%s, " % label if len(label) > 0 else "", params[0], "p" if bonferroni_factor == 1 else r"p$^*$", params[1] if bonferroni_factor == 1 else p_star)
  49. ax.plot(xRange, intercept + slope * xRange, color=color, label=l, ls="-")
  50. ax.legend(loc='upper left', frameon=False, ncol=1,fontsize=6, bbox_to_anchor=(0., 1.3))
  51. print("%s: minimum: %.2f, maximum: %.2f, mean: %.2f, std: %.2f" % (feature, np.min(y), np.max(y), np.mean(y), np.std(y)))
  52. return slope, intercept
  53. def plotMeanProperty(ax, bins, x, y):
  54. my = np.zeros(len(bins)-1)
  55. sy = np.zeros(len(bins)-1)
  56. for i, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
  57. yy = y[(x >= start) & (x < end)]
  58. if len(yy) > 0:
  59. my[i] = np.mean(yy)
  60. sy[i] = np.std(yy)
  61. else:
  62. my[i] = np.nan
  63. sy[i] = np.nan
  64. centers = bins[:-1]+(bins[1]-bins[0])/2
  65. ax.plot(centers, my, '--', lw=0.5, color='black')
  66. ax.fill_between(centers, my+sy, my-sy, color='black', alpha=0.1)
  67. return my
  68. def remove_bounds(x, y, label="firing rate"):
  69. mask = np.ones(len(x), dtype=bool)
  70. ordered_x = np.sort(x)
  71. min_index = np.where(x == ordered_x[0])[0]
  72. max_index = np.where(x == ordered_x[-1])[0]
  73. mask[min_index] = False
  74. mask[max_index] = False
  75. reduced_x = x[mask]
  76. reduced_y = y[mask]
  77. _, _, *params = stats.linregress(reduced_x, reduced_y)
  78. print(f"{label}: removed min and max position, r: {params[0]:.2f}, p: {params[1]:.2f}")
  79. def plot_baseline_properties(axes, df, markersize, correctionFactor):
  80. xPositions = df.receptor_pos_relative.values
  81. posBins = np.linspace(0, 1, 10)
  82. counts, posBins, _ = axes[0].hist(xPositions, np.linspace(0, 1, 10), color="tab:blue", histtype='stepfilled', alpha=0.75)
  83. posCenters = posBins[:-1]+(posBins[1]-posBins[0])/2
  84. twinAx = axes[0].twinx()
  85. cumCounts = np.cumsum(counts)
  86. cumCounts /= cumCounts.max()
  87. twinAx.plot(posCenters, cumCounts, color="tab:red", alpha=0.75)
  88. twinAx.set_yticks([])
  89. twinAx.set_ylim([0, 1.5])
  90. twinAx.patch.set_visible(False)
  91. despine(twinAx, ["top","right", "bottom", "left"], True)
  92. ax = axes[0]
  93. ax.set_xticks(np.arange(0, 1.1, 0.2))
  94. ax.set_xlim([0, 1.])
  95. ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
  96. ax.set_xticklabels([])
  97. ax.set_ylabel("# of cells", fontsize=label_size)
  98. ax.set_ylim([0, 60])
  99. ax.set_yticks([0, 20, 40])
  100. ax.set_yticks([0, 10, 20, 30, 40], minor=True)
  101. ax.yaxis.set_label_coords(-0.15, 0.5)
  102. # firing rate
  103. y_values = df.firing_rate.values
  104. remove_bounds(xPositions.copy(), y_values.copy())
  105. ax = axes[1]
  106. plotMeanProperty(ax, posBins, xPositions, y_values)
  107. plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="firing rate")
  108. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  109. ax.set_ylabel('firing rate [Hz]', fontsize=label_size)
  110. maxFR = np.ceil(df.firing_rate.max()/200)*200
  111. ax.set_ylim(0, maxFR)
  112. ax.set_yticks(np.arange(0, maxFR+1, 200))
  113. ax.set_yticks(np.arange(0, maxFR+1, 100), minor=True)
  114. #axes[1].spines['left'].set_bounds(0, maxFR)
  115. ax.set_xlim([0, 1.0])
  116. ax.set_xticks(np.arange(0.0, 1.01, 0.2))
  117. ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
  118. ax.set_xticklabels([])
  119. ax.yaxis.set_label_coords(-0.15, 0.5)
  120. # CV of the interspike interval
  121. y_values = df.cv.values
  122. ax = axes[2]
  123. plotMeanProperty(ax, posBins, xPositions, y_values)
  124. plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="CV")
  125. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  126. ax.set_ylabel(r'CV$_{ISI}$', fontsize=label_size)
  127. ax.set_ylim(0, 1.5)
  128. ax.set_yticks(np.arange(0, 1.51, 0.5))
  129. ax.set_yticks(np.arange(0, 1.51, 0.25), minor=True)
  130. ax.set_xlim([0.0, 1.0])
  131. ax.set_xticks(np.arange(0.0, 1.01, 0.2))
  132. ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
  133. ax.set_xticklabels([])
  134. ax.yaxis.set_label_coords(-0.15, 0.5)
  135. # burst fraction
  136. y_values = df.burst_fraction.values
  137. ax = axes[3]
  138. remove_bounds(xPositions.copy(), y_values.copy(), label="burst fraction")
  139. plotMeanProperty(ax, posBins, xPositions, y_values)
  140. plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="burst fraction")
  141. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  142. ax.set_ylabel('burst fraction', fontsize=label_size)
  143. ax.set_ylim(0, 1.0)
  144. ax.set_yticks(np.arange(0, 1.01, 0.5))
  145. ax.set_yticks(np.arange(0, 1.01, 0.25), minor=True)
  146. ax.set_xlim([0, 1.0])
  147. ax.set_xticks(np.arange(0.0, 1.01, 0.2))
  148. ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
  149. ax.set_xticklabels([])
  150. ax.yaxis.set_label_coords(-0.15, 0.5)
  151. # Vector strength
  152. y_values = df.vector_strength.values
  153. ax = axes[4]
  154. plotMeanProperty(ax, posBins, xPositions, y_values)
  155. plotLinregress(ax, xPositions, y_values, bonferroni_factor=correctionFactor, feature="vector strength")
  156. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  157. ax.set_ylabel('vector strength', fontsize=label_size)
  158. ax.yaxis.set_label_coords(-0.15, 0.5)
  159. ax.set_ylim(0.6, 1.0)
  160. ax.set_yticks(np.arange(0.6, 1.01, 0.2))
  161. ax.set_yticks(np.arange(0.6, 1.01, 0.1), minor=True)
  162. ax.set_xlim([0, 1.0])
  163. ax.set_xticks(np.arange(0.0, 1.01, 0.2))
  164. ax.set_xticks(np.arange(0.0, 1.0, 0.1), minor=True)
  165. ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
  166. ax.xaxis.set_label_coords(0.5, -0.35)
  167. def plot_driven_properties(axes, df, markersize, correctionFactor):
  168. xPositions = df.receptor_pos_relative.values
  169. posBins = np.linspace(0, 1, 10)
  170. # response modulation
  171. y_values = df.response_modulation.values
  172. ax = axes[0]
  173. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  174. plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature="modulation")
  175. plotMeanProperty(ax, posBins, xPositions, y_values)
  176. ax.set_xticks(np.arange(0, 1.1, 0.2))
  177. ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
  178. ax.set_xticklabels([])
  179. ax.set_ylim([0, 300])
  180. ax.set_yticks(np.arange(0, 301, 100))
  181. ax.set_yticks(np.arange(0, 300, 50), minor=True)
  182. ax.set_ylabel("modulation [Hz]", fontsize=label_size)
  183. ax.yaxis.set_label_coords(-0.15, 0.5)
  184. # response variability
  185. y_values = df.response_variability.values
  186. ax = axes[1]
  187. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  188. plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature="variability")
  189. plotMeanProperty(ax, posBins, xPositions, y_values)
  190. ax.set_xticks(np.arange(0, 1.1, 0.2))
  191. ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
  192. ax.set_xticklabels([])
  193. ax.set_ylim([0, 300])
  194. ax.set_yticks(np.arange(0, 301, 100))
  195. ax.set_yticks(np.arange(0, 300, 50), minor=True)
  196. ax.set_ylabel("variability [Hz]", fontsize=label_size)
  197. ax.yaxis.set_label_coords(-0.15, 0.5)
  198. # lower and upper cutoff
  199. ax = axes[2]
  200. ax.scatter(xPositions, df.gain_cutoff_lower.values, color="tab:blue", s=markersize)
  201. plotLinregress(ax, xPositions, df.gain_cutoff_lower.values, correctionFactor, color="tab:blue", label="lower", feature="lower cutoff")
  202. plotMeanProperty(ax, posBins, xPositions, df.gain_cutoff_lower.values)
  203. ax.scatter(xPositions, df.gain_cutoff_upper.values, color="tab:red", s=markersize)
  204. plotLinregress(ax, xPositions, df.gain_cutoff_upper.values, correctionFactor, color="tab:red", label="upper", feature="upper cutoff")
  205. plotMeanProperty(ax, posBins, xPositions, df.gain_cutoff_upper.values)
  206. ax.set_xticks(np.arange(0, 1.1, 0.2))
  207. ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
  208. ax.set_xticklabels([])
  209. ax.set_ylim([0, 325])
  210. ax.set_yticks(np.arange(0, 301, 100))
  211. ax.set_yticks(np.arange(0, 300, 50), minor=True)
  212. ax.set_ylabel("cutoff [Hz]", fontsize=label_size)
  213. ax.yaxis.set_label_coords(-0.15, 0.5)
  214. # maximum gain
  215. y_values = df.gain_maximum.values / 1000 # convert ot kHz/mV
  216. ax = axes[3]
  217. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  218. plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature="gain")
  219. plotMeanProperty(ax, posBins, xPositions, y_values)
  220. ax.set_xticks(np.arange(0, 1.1, 0.2))
  221. ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
  222. ax.set_xticklabels([])
  223. ax.set_ylim([0, 4])
  224. ax.set_yticks(np.arange(0, 4.1, 2))
  225. ax.set_yticks(np.arange(0, 4.1, 1), minor=True)
  226. ax.set_ylabel("gain [kHz/mV]", fontsize=label_size)
  227. ax.yaxis.set_label_coords(-0.15, 0.5)
  228. ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
  229. # mututal information
  230. y_values = df.mi.values
  231. ax = axes[4]
  232. ax.scatter(xPositions, y_values, color="tab:blue", s=markersize)
  233. plotLinregress(ax, xPositions, y_values, correctionFactor, color="tab:red", feature="mutual info.")
  234. plotMeanProperty(ax, posBins, xPositions, y_values)
  235. ax.set_xticks(np.arange(0, 1.1, 0.2))
  236. ax.set_xticks(np.arange(0, 1.01, 0.1), minor=True)
  237. ax.set_ylim([0, 600])
  238. ax.set_yticks(np.arange(0, 601, 200))
  239. ax.set_yticks(np.arange(0, 601, 100), minor=True)
  240. ax.set_ylabel("mutual info. [bit/s]", fontsize=label_size)
  241. ax.set_xlabel("receptor position [rel.]", fontsize=label_size+1)
  242. ax.yaxis.set_label_coords(-0.15, 0.5)
  243. ax.yaxis.set_major_formatter(FormatStrFormatter('%i'))
  244. ax.xaxis.set_label_coords(0.5, -0.35)
  245. def layout_figure():
  246. baseline_subfig_labels = ["A", "B", "C", "D", "E"]
  247. driven_subfig_labels = ["F", "G", "H", "I", "J"]
  248. subplots_per_col = max([len(baseline_subfig_labels), len(driven_subfig_labels)])
  249. driven_axes = list(range(len(baseline_subfig_labels)))
  250. baseline_axes = list(range(len(driven_subfig_labels)))
  251. fig = plt.figure(figsize=(5.1, 5.75))
  252. subfigs = fig.subfigures(1, 2, wspace=0.05)
  253. gr = gridspec.GridSpec(subplots_per_col, 1, hspace=0.5, left=0.225)
  254. for i in range(subplots_per_col):
  255. baseline_label = baseline_subfig_labels[i]
  256. driven_label = driven_subfig_labels[i]
  257. baseline_axes[i] = subfigs[0].add_subplot(gr[i, 0])
  258. driven_axes[i] = subfigs[1].add_subplot(gr[i, 0])
  259. baseline_axes[i].text(-0.25, 1.05, baseline_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  260. transform=baseline_axes[i].transAxes, ha="center", va="bottom")
  261. driven_axes[i].text(-0.25, 1.05, driven_label, fontsize=subfig_labelsize, fontweight=subfig_labelweight,
  262. transform=driven_axes[i].transAxes, ha="center", va="bottom")
  263. despine(baseline_axes[i], ["top", "right"], False)
  264. despine(driven_axes[i], ["top", "right"], False)
  265. pic_ax = subfigs[0].add_axes((0.23, 0.88, 0.75, 0.15))
  266. if os.path.exists("figures/fishsketch.png"):
  267. img = mpimg.imread("figures/fishsketch.png")
  268. pic_ax.imshow(img)
  269. despine(pic_ax, ["top", "bottom", "left", "right"], True)
  270. else:
  271. raise FileNotFoundError("Missing fishsketch.png file.")
  272. return fig, baseline_axes, driven_axes
  273. def plot_position_correlations(args):
  274. if not os.path.exists(args.driven_frame) or not os.path.exists(args.baseline_frame):
  275. raise ValueError(f"Results data frame(s) not found! ({args.driven_frame} or {args.baseline_frame})")
  276. driven_df = pd.read_csv(args.driven_frame, sep=";", index_col=0)
  277. baseline_df = pd.read_csv(args.baseline_frame, sep=";", index_col=0)
  278. markersize = 0.5
  279. fig, baseline_axes, driven_axes = layout_figure()
  280. print("Baseline response:")
  281. plot_baseline_properties(baseline_axes, baseline_df, markersize, correctionFactor=4)
  282. print("Driven response:")
  283. plot_driven_properties(driven_axes, driven_df, markersize, correctionFactor=5)
  284. fig.subplots_adjust(left=0.2, right=0.975, top=0.95, bottom=0.075)
  285. if args.nosave:
  286. plt.show()
  287. else:
  288. fig.savefig(args.outfile, dpi=500)
  289. plt.close()
  290. plt.close()
  291. def command_line_parser(subparsers):
  292. default_bf = os.path.join("derived_data", "figure2_baseline_properties.csv")
  293. default_df = os.path.join("derived_data", "figure2_driven_properties.csv")
  294. parser = subparsers.add_parser("property_correlations", help=fig2_help)
  295. parser.add_argument("-df", "--driven_frame", default=default_df,
  296. help=f"Full file name of a CSV table readable with pandas that contains the position data and coding properties ({default_df})")
  297. parser.add_argument("-bf", "--baseline_frame", default=default_bf,
  298. help=f"Full file name of a CSV table readable with pandas that holds the baseline properties and positions (defaults to {default_bf}).")
  299. parser.add_argument("-o", "--outfile", type=str, default=os.path.join("figures", "property_position_relation.pdf"),
  300. help="The filename of the figure")
  301. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  302. parser.set_defaults(func=plot_position_correlations)
  303. return parser