datashader_utils.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import mpl_scatter_density
  2. import matplotlib
  3. import matplotlib.pyplot as plt
  4. import datashader as ds
  5. import datashader.transfer_functions as tf
  6. def plot_occupancy_projections(df):
  7. """Plots X, Y, Z columns of 'df' in a three-subplot figure. Made for large datasets."""
  8. cmap = 'Blues'
  9. norm = matplotlib.colors.PowerNorm(gamma=.3)
  10. fig = plt.figure(dpi=400)
  11. ax = fig.add_subplot(2, 2, 1, projection='scatter_density')
  12. img = ax.scatter_density(df.X, df.Z, cmap=cmap, norm=norm)
  13. ax.set(
  14. xlim=[-80, 80],
  15. ylim=[-50, 50],
  16. xlabel='X (cm)',
  17. ylabel='Z (cm)',
  18. )
  19. ax = fig.add_subplot(2, 2, 3, projection='scatter_density')
  20. img = ax.scatter_density(df.X, df.Y, cmap=cmap, norm=norm)
  21. ax.set(
  22. xlim=[-80, 80],
  23. ylim=[0., 50],
  24. xlabel='X (cm)',
  25. ylabel='Y (cm)'
  26. )
  27. ax = fig.add_subplot(2, 2, 4, projection='scatter_density')
  28. img = ax.scatter_density(df.Z, df.Y, cmap=cmap, norm=norm)
  29. ax.set(
  30. xlim=[-50, 50],
  31. ylim=[0., 50],
  32. xlabel='Z (cm)',
  33. ylabel='Y (cm)',
  34. )
  35. plt.suptitle('Distribution of Position Data', y=1.05)
  36. fig.tight_layout()
  37. return fig
  38. def imshow(data, x, y, z=None, width=1000, height=None, cmap=['lightblue', 'darkblue'], how='eq_hist',
  39. ax=None, **mpl_kwargs):
  40. """Returns an MxNx4 image array, colored by datashader. Useful for plotting."""
  41. w, h = data[x].max() - data[x].min(), data[y].max() - data[y].min()
  42. height = height if height else int(width * h / w) # Square pixels by default
  43. cvs = ds.Canvas(plot_height=height, plot_width=width)
  44. agg = ds.mean(z) if z else None
  45. agg = cvs.points(source=data, x=x, y=y, agg=agg)
  46. cmap = getattr(ds.colors, cmap) if isinstance(cmap, str) else cmap
  47. img = tf.shade(agg, cmap=cmap, how=how)
  48. img = img.values.view('uint8').reshape(img.values.shape[0], -1, 4)
  49. extent = (data[x].min(), data[x].max(), data[y].min(), data[y].max())
  50. if ax:
  51. out = ax.imshow(img, origin='lower', extent=extent)
  52. ax.set(xlabel=x, ylabel=y)
  53. else:
  54. out = plt.imshow(img, origin='lower', extent=extent)
  55. plt.gca().set(xlabel=x, ylabel=y)
  56. return out