12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- import mpl_scatter_density
- import matplotlib
- import matplotlib.pyplot as plt
- import datashader as ds
- import datashader.transfer_functions as tf
def plot_occupancy_projections(df, *, cmap='Blues', norm=None, dpi=400):
    """Plot the X-Z, X-Y and Z-Y projections of the positions in ``df``.

    Made for large datasets: every panel uses the ``scatter_density``
    projection (registered by importing ``mpl_scatter_density``), which draws
    a density image instead of one marker per point.

    The panels sit on a 2x2 grid with the top-right slot left empty:
    X-Z top-left, X-Y bottom-left, Z-Y bottom-right. Axis limits are fixed
    and the axes are labelled in centimetres.

    Parameters
    ----------
    df : DataFrame-like
        Object exposing ``X``, ``Y`` and ``Z`` columns as attributes
        (e.g. a pandas DataFrame).
    cmap : str or Colormap, optional
        Colormap used for every panel. Defaults to ``'Blues'``.
    norm : matplotlib.colors.Normalize, optional
        Color normalization. Defaults to ``PowerNorm(gamma=0.3)``.
    dpi : float, optional
        Figure resolution. Defaults to 400.

    Returns
    -------
    matplotlib.figure.Figure
        The populated figure.
    """
    if norm is None:
        # gamma < 1 lifts low values, so sparsely populated regions stay visible.
        norm = matplotlib.colors.PowerNorm(gamma=.3)
    # NOTE(review): the same norm object is handed to all three panels, so any
    # colour limits it autoscales to may end up shared - confirm this is intended.

    # (subplot index on the 2x2 grid, horizontal column, vertical column,
    #  x limits, y limits)
    panels = (
        (1, 'X', 'Z', (-80, 80), (-50, 50)),
        (3, 'X', 'Y', (-80, 80), (0., 50)),
        (4, 'Z', 'Y', (-50, 50), (0., 50)),
    )

    fig = plt.figure(dpi=dpi)
    for index, h_col, v_col, xlim, ylim in panels:
        ax = fig.add_subplot(2, 2, index, projection='scatter_density')
        ax.scatter_density(getattr(df, h_col), getattr(df, v_col),
                           cmap=cmap, norm=norm)
        ax.set(
            xlim=xlim,
            ylim=ylim,
            xlabel=f'{h_col} (cm)',
            ylabel=f'{v_col} (cm)',
        )

    # Attach the title to this figure explicitly rather than via pyplot's
    # "current figure" state.
    fig.suptitle('Distribution of Position Data', y=1.05)
    fig.tight_layout()

    return fig
def imshow(data, x, y, z=None, width=1000, height=None, cmap=None, how='eq_hist',
           ax=None, **mpl_kwargs):
    """Rasterize ``data`` with datashader and display it with matplotlib's imshow.

    Parameters
    ----------
    data : DataFrame
        Source table passed to ``datashader.Canvas.points``; ``data[x]`` and
        ``data[y]`` must support ``min()`` / ``max()``.
    x, y : str
        Column names for the horizontal and vertical axes; also used as the
        axis labels.
    z : str, optional
        If given, each pixel shows the mean of this column. Otherwise
        datashader's default reduction is used.
    width : int, optional
        Canvas width in pixels. Defaults to 1000.
    height : int, optional
        Canvas height in pixels. If falsy, it is derived from ``width`` so
        that pixels cover equal data distances in x and y.
    cmap : str or list, optional
        A name looked up in ``datashader.colors``, or any colormap accepted by
        ``tf.shade``. Defaults to ``['lightblue', 'darkblue']``.
    how : str, optional
        Shading scale passed to ``tf.shade``. Defaults to ``'eq_hist'``.
    ax : matplotlib Axes, optional
        Axes to draw on. If omitted, pyplot's current axes are used.
    **mpl_kwargs
        Extra keyword arguments forwarded to ``imshow``. They may override
        the default ``origin`` and ``extent``.

    Returns
    -------
    matplotlib.image.AxesImage
        The image artist returned by ``imshow``.
    """
    if cmap is None:
        # Built per call instead of using a mutable default argument.
        cmap = ['lightblue', 'darkblue']

    x_min, x_max = data[x].min(), data[x].max()
    y_min, y_max = data[y].min(), data[y].max()

    if not height:
        x_span, y_span = x_max - x_min, y_max - y_min
        # Square pixels by default. A zero x span would divide by zero, so fall
        # back to a square canvas; never request fewer than 1 pixel of height.
        height = max(1, int(width * y_span / x_span)) if x_span else width

    canvas = ds.Canvas(plot_height=height, plot_width=width)
    reduction = ds.mean(z) if z else None
    agg = canvas.points(source=data, x=x, y=y, agg=reduction)

    cmap = getattr(ds.colors, cmap) if isinstance(cmap, str) else cmap
    shaded = tf.shade(agg, cmap=cmap, how=how)
    # Each shaded pixel is one packed 32-bit value; reinterpret the bytes as a
    # (rows, cols, 4) uint8 array that matplotlib can display.
    # NOTE(review): channel order depends on native byte order - verify on
    # big-endian hosts.
    rgba = shaded.values.view('uint8').reshape(shaded.values.shape[0], -1, 4)

    imshow_kwargs = {'origin': 'lower', 'extent': (x_min, x_max, y_min, y_max)}
    imshow_kwargs.update(mpl_kwargs)  # caller-supplied options take precedence

    if ax is not None:
        out = ax.imshow(rgba, **imshow_kwargs)
        ax.set(xlabel=x, ylabel=y)
    else:
        out = plt.imshow(rgba, **imshow_kwargs)
        plt.gca().set(xlabel=x, ylabel=y)
    return out
|