supp_figure4.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. ###############################################################################
  2. ## correlations among the baseline parameters and the mutual information ##
  3. import os
  4. import pandas as pd
  5. import numpy as np
  6. import scipy as sp
  7. import matplotlib.pyplot as plt
  8. import matplotlib as mplt
  9. from matplotlib import cm
  10. from sklearn import linear_model
  11. from sklearn.metrics import r2_score
  12. from .property_correlations import get_feature_values
  13. def plot_correlation_heatmap(rs, ps, labels, axis=None, textsize=6, labelsize=6):
  14. """plots a heatmap of correlation coefficients and the respective p-values
  15. Args:
  16. rs (np.ndarray): matrix of correlation coefficients
  17. ps (np.ndarray): matrix of corresponding p-values
  18. labels (list): list of labels
  19. axis (matplotlib axis): matplotlib axis object, default=None
  20. """
  21. if axis is None:
  22. fig = plt.figure()
  23. axis = fig.add_subplot(111)
  24. cmap = cm.get_cmap("RdYlBu", 512)
  25. newcolors = cmap(np.linspace(0, 1, 512))
  26. black = np.array([0, 0, 0, 1])
  27. newcolors[-1, :] = black
  28. newcmp = mplt.colors.ListedColormap(newcolors)
  29. im = axis.imshow(rs, cmap=newcmp, vmin=-1.0, vmax=1.0)
  30. cb = plt.gcf().colorbar(im, ax=axis, orientation='vertical', label='correlation coefficient')
  31. cb.ax.tick_params(labelsize=labelsize)
  32. axis.set_xticks(np.arange(len(labels)))
  33. axis.set_yticks(np.arange(len(labels)))
  34. axis.set_xticks(np.arange(0.5, len(labels)), minor=True)
  35. axis.set_yticks(np.arange(0.5, len(labels)), minor=True)
  36. axis.tick_params(axis='both', which='minor', colors='white')
  37. axis.grid(color="white", axis="both", which="minor", lw=0.6)
  38. axis.set_xticklabels(labels, fontsize=labelsize)
  39. axis.set_yticklabels(labels, fontsize=labelsize)
  40. plt.setp(axis.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
  41. plt.setp(axis.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
  42. bonferroni = (len(labels)**2 - len(labels)) / 2
  43. for i in range(len(labels)):
  44. for j in range(len(labels)):
  45. if i == j:
  46. continue
  47. c = "k" if rs[i, j] > -.5 and rs[i,j] < 0.65 else "w"
  48. alpha = ps[i,j] * bonferroni
  49. p = "n.s." if alpha >= 0.05 else (r"p<%.1i%%" % (5 if alpha > .01 else 1))
  50. axis.text(j, i, "r: %.2f\n%s"% (rs[i, j], p), ha="center",
  51. va="center", color=c, fontsize=textsize)
  52. return cb
  53. def correlate(data_frame, columns):
  54. """calculate pairwise pearsonr correlations for all combinations of columns from the passed DataFrame
  55. Args:
  56. data_frame (pandas DataFrame): the data frame containing the data
  57. columns (list): list of column names
  58. Returns:
  59. numpy.ndarray: matrix containing the correlation coefficients of shape(len(columns), len(columns))
  60. numpy.ndarray: matrix containing the p-values of shape(len(columns), len(columns))
  61. """
  62. # do a z-transform
  63. temp = np.zeros((len(data_frame), len(columns)))
  64. for i, c in enumerate(columns):
  65. temp[:, i] = (data_frame[c].values - np.mean(data_frame[c].values))/np.std(data_frame[c].values)
  66. ps = np.zeros((len(columns), len(columns)))
  67. rs = np.ones(ps.shape)
  68. rscores = np.zeros(len(columns))
  69. reg = linear_model.LinearRegression()
  70. for i,_ in enumerate(columns):
  71. pattern = np.ones(len(columns), dtype=bool)
  72. pattern[i] = 0
  73. reg.fit(temp[:, pattern], temp[:, i])
  74. rs[i, pattern] = reg.coef_
  75. y_pred = reg.predict(list(temp[:, pattern]))
  76. rscores[i] = r2_score(temp[:, i], y_pred)
  77. ps = np.zeros((len(columns), len(columns)))
  78. rs = np.ones(ps.shape)
  79. for i in range(len(columns)):
  80. for j in range(len(columns)):
  81. r, p = sp.stats.pearsonr(data_frame[columns[i]].values, data_frame[columns[j]].values)
  82. ps[i,j] = p
  83. rs[i,j] = r
  84. return rs, ps
  85. def layout_figure():
  86. pass
  87. fig, axis = plt.subplots(1, 1, figsize=(3.5, 0.8 * 3.5))
  88. fig.subplots_adjust(left=0.1, top=0.95, right=0.925, bottom=0.175)
  89. return fig, axis
  90. def baseline_correlations(args):
  91. """
  92. plots the mutual information estimated from the stimulus response coherence yielded for population sizes of 1 and
  93. plots pairwise correlation coefficients for various parameters.
  94. Args:
  95. hom (pandas DataFrame): results from homogeneous populations
  96. """
  97. df = pd.read_csv(args.inputfile, sep=";", index_col=0)
  98. feats = ["population_rate", "cv", "rate_modulation", "lower_cutoff", "upper_cutoff", "mi"]
  99. labels = ["firing rate", r"$CV_{ISI}$", "rate mod.", r"$\omega_{lower}$",
  100. r"$\omega_{upper}$", "mutual info."]
  101. features,_ = get_feature_values(df, feats)
  102. selection = pd.DataFrame(features)
  103. fig, axis = layout_figure()
  104. rs, ps = correlate(selection, list(features.keys()))
  105. cb = plot_correlation_heatmap(rs, ps, labels, axis=axis)
  106. cb_pos = list(cb.ax.get_position().bounds)
  107. cb_pos[0] = cb_pos[0] + 0.005
  108. cb.ax.set_position(cb_pos)
  109. if args.nosave:
  110. plt.show()
  111. else:
  112. fig.savefig(args.outfile)
  113. plt.close()
  114. def command_line_parser(subparsers):
  115. parser = subparsers.add_parser("supfig4", help="Supplementary figure 4: Plots correlations of various baseline features.")
  116. parser.add_argument("-i", "--inputfile", default=os.path.join("derived_data", "homogeneous_populationcoding.csv"))
  117. parser.add_argument("-o", "--outfile", default=os.path.join("figures", "coding_correlations.pdf"))
  118. parser.add_argument("-n", "--nosave", action='store_true', help="no saving of the figure, just showing")
  119. parser.set_defaults(func=baseline_correlations)