123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- import pandas as pd
- import numpy as np
- import argparse
- import seaborn as sns
- import matplotlib.pyplot as plt
- from mpl_toolkits.axes_grid1 import make_axes_locatable
- import sys
- from pathlib import Path
- sys.path.append(str(Path.cwd().parents[0]))
- sys.path.append(str(Path.cwd().parents[0] / 'scripts'))
- from utils import load_df, colormap, stack_columns
def _draw_pvalue_points(data, ax, x, order, alpha, lineplot):
    """Draw the p-value strip plot, plus an optional transparent line plot, on ``ax``."""
    sns.stripplot(x=x, y="pvalue", hue="matrix", data=data,
                  ax=ax, palette=colormap, order=order, alpha=alpha)
    if lineplot:
        # The lines are drawn fully transparent (alpha=0), so they never show up.
        # NOTE(review): presumably this is only here to set up the x axis the
        # way lineplot does for a non-categorical `x` - confirm.
        sns.lineplot(data=data, x=x, y='pvalue', hue='matrix',
                     ax=ax, palette=colormap, alpha=0)


def _remove_legend(ax):
    """Remove the legend of ``ax`` if one was drawn; do nothing otherwise."""
    legend = ax.get_legend()
    # Some seaborn versions draw no legend when `hue` duplicates the
    # categorical axis, in which case get_legend() returns None.
    if legend is not None:
        legend.remove()


def plot_pvalues(data, ax=None, x='matrix', order=('weights', 'correlations'),
                 border=0.1, alpha=0.5, lineplot=False):
    """Plot p-values on a split axis: linear above ``border``, log below it.

    The input is first reshaped with ``stack_columns`` (prefix ``pvalue``,
    suffixes ``weights`` and ``correlations``); the plots then read the
    ``pvalue`` and ``matrix`` columns of the result.
    NOTE(review): presumably the input therefore needs ``pvalue_weights`` and
    ``pvalue_correlations`` columns - confirm against ``stack_columns``.

    Parameters
    ----------
    data : DataFrame-like
        Table handed to ``stack_columns``.
    ax : matplotlib Axes, optional
        Axes for the upper (linear) part. A new figure is created if omitted.
    x : str
        Column used for the x axis.
    order : sequence of str or None
        Category order passed to ``sns.stripplot``.
    border : float
        y value at which the linear axis ends and the log axis begins.
    alpha : float
        Opacity of the strip plot markers.
    lineplot : bool
        If true, additionally draw a fully transparent ``sns.lineplot``.

    Returns
    -------
    matplotlib Axes
        The lower (log-scaled) axes appended below ``ax``.
    """
    # Hand seaborn a fresh list so the shared default is never exposed.
    if order is not None:
        order = list(order)
    if ax is None:
        _, ax = plt.subplots()

    data = stack_columns(data, new_column='matrix',
                         prefixes=['pvalue'],
                         suffixes=['weights', 'correlations'])

    # Upper part: linear scale from `border` up to just above 1.
    _draw_pvalue_points(data, ax, x, order, alpha, lineplot)
    ax.set_xlabel('')
    ax.set_yscale('linear')
    ax.set_ylim((border, 1.05))
    sns.despine(ax=ax, left=True, right=True, top=True, bottom=True)
    _remove_legend(ax)
    ax.axhline(0.1, color='0.8', linestyle=':', linewidth=4)

    # Lower part: log scale, attached directly below (pad=0) and sharing x.
    divider = make_axes_locatable(ax)
    axlog = divider.append_axes("bottom", size=3.5, pad=0, sharex=ax)
    axlog.set_yscale('log')

    # Smallest strictly positive p-value; zeros cannot be shown on a log axis.
    minp = data[data.pvalue > 0].pvalue.min()
    if np.isfinite(minp):
        lower = minp / 10
    else:
        # No positive p-value at all: min() is NaN, which set_ylim rejects.
        # Fall back to one decade below the border.
        lower = border / 10
    axlog.set_ylim((lower, border))

    _draw_pvalue_points(data, axlog, x, order, alpha, lineplot)
    axlog.set_ylabel('')
    sns.despine(ax=axlog, left=True, right=True, top=True, bottom=True)
    axlog.set_xlabel('')
    _remove_legend(axlog)
    return axlog
def _finish_side_panel(panel, ylabel, ticklabel):
    """Move the y axis of ``panel`` to the right, label it and hide all spines."""
    panel.yaxis.set_ticks_position('right')
    panel.yaxis.set_label_position('right')
    panel.set_ylabel(ylabel)
    panel.set_xlabel('')
    panel.set_xticklabels([ticklabel])
    sns.despine(ax=panel, left=True, right=True, top=True, bottom=True)


def plot_pvalue_overview(data):
    """Build a three-panel overview figure and return it.

    The wide left panel is drawn by ``plot_pvalues``. The two narrow panels
    on the right are strip plots of the ``pvalue_ratio`` column (log scale,
    with a dotted reference line at 1) and of the ``rate_correlation`` column.
    """
    point_alpha = 0.3
    layout = dict(wspace=.5, width_ratios=[1, .2, .2])
    fig, axes = plt.subplots(ncols=3, figsize=(12, 7), gridspec_kw=layout)
    main_ax, ratio_ax, rate_ax = axes

    # Left panel: weights and correlations.
    plot_pvalues(data, main_ax, alpha=point_alpha)

    # Middle panel: p-value ratio on a log axis.
    sns.stripplot(y="pvalue_ratio", data=data, ax=ratio_ax,
                  color=colormap['ratio'], alpha=point_alpha)
    ratio_ax.axhline(1, color='0.3', linestyle=':')
    ratio_ax.set_yscale('log')
    _finish_side_panel(ratio_ax, r'$p_c / p_w$', 'ratio')

    # Right panel: rate vector correlation.
    sns.stripplot(y="rate_correlation", data=data, ax=rate_ax,
                  color=colormap['rate_correlation'], alpha=point_alpha)
    _finish_side_panel(rate_ax, 'rate vector correlation', 'rates')

    return fig
if __name__ == '__main__':
    # Command-line entry point: load the results table, filter it by protocol,
    # draw the overview figure and save it to the requested output path.
    CLI = argparse.ArgumentParser()
    CLI.add_argument("--input", nargs='?', type=Path)
    CLI.add_argument("--output", nargs='?', type=Path)
    CLI.add_argument("--protocol", nargs='?', type=str)
    # Unknown extra arguments are tolerated and ignored.
    args, unknown = CLI.parse_known_args()

    df = load_df(args.input)

    # Keep rows whose protocol matches the requested one. Rows containing
    # 'redraw' are dropped unless the requested protocol itself contains it.
    data = df[df.protocol.str.contains(args.protocol)]
    if 'redraw' not in args.protocol:
        data = data[~data.protocol.str.contains('redraw')]

    sns.set(style='ticks', palette='deep', context='talk')

    fig = plot_pvalue_overview(data)

    # Title is the output file name without its '.png' extension.
    # str.strip(".png") would remove any of the characters '.', 'p', 'n', 'g'
    # from both ends ("pvalues_spring.png" -> "values_spri"), so cut the
    # suffix explicitly instead.
    title = Path(args.output).name
    if title.endswith('.png'):
        title = title[:-len('.png')]
    fig.suptitle(f'{title}', fontsize=17)

    # Merge conflict resolved in favour of the HEAD side, which sets dpi=300.
    plt.savefig(args.output, dpi=300, bbox_inches=None)
|