bokeh_charts.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import pandas as pd
  2. from bokeh.plotting import figure
  3. from bokeh.palettes import d3
  4. from bokeh.models import Plot, Line, Legend, ColumnDataSource
  5. import numpy as np
  6. def get_bounds(arr):
  7. arr_min, arr_max = np.nanmin(arr), np.nanmax(arr)
  8. margin = 0.1 * (arr_max - arr_min)
  9. return arr_min - margin, arr_max + margin
  10. def add_lineplot(
  11. data: pd.DataFrame, x: str, y: str, hue: str, bokeh_plot: Plot,
  12. legend_location: str = "center right", legend_nrow: int = 1,
  13. legend_click_policy: str = "hide", legend_orientation: str = "vertical",
  14. white_filled_circle_marker=False, circle_marker_size=8
  15. ):
  16. group = data.groupby(hue)
  17. if len(group) < 3:
  18. palette = d3["Category10"][3][:len(group)]
  19. elif 3 <= len(group) <= 10:
  20. palette = d3["Category10"][len(group)]
  21. else:
  22. palette = d3["Category20"][len(group)]
  23. lines = {}
  24. for color, (hue, hue_df) in zip(palette, group):
  25. line = bokeh_plot.line(x=hue_df[x], y=hue_df[y], line_color=color, legend_label=hue)
  26. lines[hue] = line
  27. if white_filled_circle_marker:
  28. bokeh_plot.circle(x=hue_df[x], y=hue_df[y], fill_color="white", size=circle_marker_size)
  29. bokeh_plot.xaxis.axis_label = x
  30. bokeh_plot.yaxis.axis_label = y
  31. bokeh_plot.xaxis.bounds = get_bounds(data[x])
  32. bokeh_plot.yaxis.bounds = get_bounds(data[y])
  33. bokeh_plot.legend.click_policy = legend_click_policy
  34. bokeh_plot.legend.location = legend_location
  35. bokeh_plot.legend.orientation = legend_orientation
  36. return bokeh_plot