# SPDX-FileCopyrightText: 2020 Lukas Schrangl <>
# SPDX-License-Identifier: BSD-3-Clause

"""Plotting utilities

The :py:mod:`sdt.plot` module contains

- the :py:func:`density_scatter` function, which creates a scatter plot where
  data points are colored according to data point density.
- :py:class:`PanelLabel` for creating sub-panel labels for paper figures as
  well as :py:func:`align_panellabels` for aligning them.

Programming reference
---------------------
"""

from numbers import Number
from typing import Iterable, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.stats
import numpy as np

    import bokeh
    import bokeh.plotting

    bokeh_available = True
except ImportError:
    bokeh_available = False

def density_scatter(x, y, ax=None, cmap="viridis", **kwargs):
    """Make a scatter plot with points colored according to density

    Use a Gaussian kernel density estimate to calculate the density of data
    points and color them accordingly.

    Examples
    --------
    >>> x, y = numpy.random.normal(size=(2, 1000))  # create data
    >>> density_scatter(x, y)

    Parameters
    ----------
    x, y : array_like, shape(n, )
        Input data
    ax : None or matplotlib.axes.Axes or bokeh.plotting.Figure, optional
        Object to use for drawing. If `None`, use matplotlib's current axes
        (:py:func:`gca`).
    cmap : str or matplotlib.colors.Colormap, optional
        Name of colormap or `Colormap` instance to be used for mapping
        densities to colors. Defaults to "viridis".
    **kwargs
        Additional keyword arguments to be passed to `ax`'s `scatter` method.

    Returns
    -------
    Whatever `ax`'s `scatter` method returns.

    Raises
    ------
    ValueError
        If `ax` is neither a matplotlib `Axes` nor a bokeh figure.
    """
    if ax is None:
        ax = plt.gca()

    # Accept any array_like (e.g. plain lists); the fancy indexing below
    # requires NumPy arrays.
    x = np.asarray(x)
    y = np.asarray(y)

    if len(x) and len(y):
        kernel = scipy.stats.gaussian_kde([x, y])
        dens = kernel(kernel.dataset)

        # sort so that highest densities are the last (makes nicer plots)
        sort_idx = np.argsort(dens)
        dens = dens[sort_idx]
        x = x[sort_idx]
        y = y[sort_idx]
    else:
        # Nothing to estimate; keep `dens` defined so an empty plot is drawn
        dens = np.empty(0)

    if isinstance(ax, plt.Axes):
        kwargs["c"] = dens
        kwargs["cmap"] = cmap
    elif bokeh_available and isinstance(
            ax, getattr(bokeh.plotting, "Figure", bokeh.plotting.figure)):
        # bokeh <3 exposes the figure class as `Figure`; bokeh >=3 only
        # provides `figure` (which is then a class).
        # Resolve a colormap name (or pass through a Colormap instance).
        cmap_obj = plt.get_cmap(cmap)
        # Scale densities linearly onto [0, 1] for the colormap lookup
        if dens.size:
            lo = dens.min()
            span = dens.max() - lo
        else:
            lo = span = 0.0
        scaled = (dens - lo) / span if span > 0 else np.zeros_like(dens)
        rgba = cmap_obj(scaled)  # one RGBA row per point, floats in [0, 1]
        rgb = (rgba[:, :3] * 255).astype(int).tolist()
        alpha = rgba[:, 3].tolist()
        # bokeh's RGB takes 0-255 ints for r, g, b and a float alpha in [0, 1]
        kwargs["color"] = [bokeh.colors.RGB(r, g, b, a)
                           for (r, g, b), a in zip(rgb, alpha)]
    else:
        raise ValueError("Unsupported type for `ax`. Can be `None`, a "
                         "`matplotlib.axes.Axes` instance, or "
                         "a `bokeh.plotting.Figure` instance.")

    return ax.scatter(x, y, **kwargs)
class PanelLabel(mpl.text.Annotation):
    """(Sub-) panel label for figures

    Figures in scientific publications are frequently made up of several
    panels, each of which is marked by a label (a, b, c, …). This annotation
    draws such a label next to the axes it is added to. It has been tested
    with figures that use `constrained layout`.

    Examples
    --------
    >>> fig, ax = matplotlib.pyplot.subplots(2, 2, constrained_layout=True)
    >>> pls = []
    >>> for x, a in zip("abcd", ax.flatten()):
    ...     pl = plot.PanelLabel(x)
    ...     a.add_artist(pl)
    ...     pls.append(pl)
    >>> align_panellabels(pls)

    See also
    --------
    align_panellabels
    """

    def __init__(self, label: str, horizontalposition: str = "axislabel",
                 verticalposition: str = "top",
                 pad: Union[float, Iterable[float]] = 0., **kwargs):
        """Parameters
        ----------
        label
            Text of the label
        horizontalposition
            Horizontal reference. "axislabel" aligns the label with the y
            axis label, "frame" aligns it with the plot frame.
        verticalposition
            Vertical reference. "top" aligns the label with the top of the
            panel, "axislabel" with the top x axis label, and "frame" with
            the plot frame.
        pad
            Additional spacing between label and panel. A single number is
            used for both directions.
        **kwargs
            Further settings, see :py:class:`matplotlib.text.Annotation`.
            ``horizontalalignment`` defaults to ``"right"``; if this leaves
            too much horizontal space, try ``horizontalalignment="left"``.
        """
        offset = (pad, pad) if isinstance(pad, Number) else pad
        super().__init__(label, (0.0, 1.0), offset,
                         xycoords=self._get_bbox,
                         textcoords="offset points")
        # House style first, user settings second, so that the latter win.
        self.update({"fontsize": "x-large",
                     "fontweight": "bold",
                     "verticalalignment": "baseline",
                     "horizontalalignment": "right"})
        self.update(kwargs)
        self._pos = (horizontalposition, verticalposition)
        # Labels to align with; `align_panellabels` replaces these sets.
        self._align_x_grp = {self}
        self._align_y_grp = {self}
        self.set_clip_on(False)

    @staticmethod
    def _left_edge(ax, hpos: str,
                   renderer: mpl.backend_bases.RendererBase) -> float:
        """Left x coordinate (display units) of `ax` for alignment

        With ``hpos == "frame"`` the axes frame is used; otherwise the y axis
        (including its label), falling back to the frame if the axis has no
        tight bounding box.
        """
        if hpos == "frame":
            return ax.get_window_extent(renderer).xmin
        extent = (ax.yaxis.get_tightbbox(renderer) or
                  ax.get_window_extent(renderer))
        return extent.xmin

    @staticmethod
    def _top_edge(ax, vpos: str,
                  renderer: mpl.backend_bases.RendererBase) -> float:
        """Top y coordinate (display units) of `ax` for alignment

        "frame" uses the axes frame, "axislabel" the x axis (falling back to
        the frame), "top" the title's tight bounding box. Any other value
        uses the title's anchor position.
        """
        if vpos == "frame":
            return ax.get_window_extent(renderer).ymax
        if vpos == "axislabel":
            extent = (ax.xaxis.get_tightbbox(renderer) or
                      ax.get_window_extent(renderer))
            return extent.ymax
        title = ax.title
        if vpos == "top":
            return title.get_tightbbox(renderer).ymax
        return title.get_transform().transform(title.get_position())[1]

    def _get_bbox(self, renderer: mpl.backend_bases.RendererBase
                  ) -> mpl.transforms.BboxBase:
        """Compute the reference box the label is anchored to

        Used as ``xycoords`` callable of the underlying
        :py:class:`mpl.text.Annotation`. The box starts as this label's axes
        frame and is grown to contain the left edges of all axes in the
        horizontal alignment group and the top edges of all axes in the
        vertical alignment group.

        Parameters
        ----------
        renderer
            Renderer to use

        Returns
        -------
        Bounding box of label
        """
        hpos, vpos = self._pos
        box = self.axes.get_window_extent(renderer).frozen()
        corners = [[self._left_edge(member.axes, hpos, renderer), box.ymin]
                   for member in self._align_x_grp]
        corners.extend([box.xmin, self._top_edge(member.axes, vpos, renderer)]
                       for member in self._align_y_grp)
        box.update_from_data_xy(corners, ignore=False)
        return box
def align_panellabels(pls: Iterable[PanelLabel]):
    """Align panel labels row-wise and column-wise

    Labels in the same row will be aligned vertically, those in the same
    column horizontally. For this, all axes to which the panels have been
    added need to share the same GridSpec.

    Examples
    --------
    >>> fig, ax = matplotlib.pyplot.subplots(2, 2, constrained_layout=True)
    >>> pls = []
    >>> for x, a in zip("abcd", ax.flatten()):
    ...     pl = plot.PanelLabel(x)
    ...     a.add_artist(pl)
    ...     pls.append(pl)
    >>> align_panellabels(pls)

    Parameters
    ----------
    pls
        Panel labels to align. Any iterable is accepted, including one-shot
        iterators such as generators.

    See also
    --------
    PanelLabel
    """
    # Labels are compared pairwise below, which needs repeated iteration;
    # a generator would otherwise be exhausted by the inner loop.
    pls = list(pls)

    # (first row, first column) of the subplot each label's axes occupies
    starts = []
    for pl in pls:
        ss = pl.axes.get_subplotspec()
        starts.append((ss.rowspan.start, ss.colspan.start))

    for pl, (row0, col0) in zip(pls, starts):
        # Labels whose axes start in the same column share the horizontal
        # position, those starting in the same row share the vertical one.
        # The sets are consumed by `PanelLabel._get_bbox`.
        pl._align_x_grp = {other for other, (_, col) in zip(pls, starts)
                           if col == col0}
        pl._align_y_grp = {other for other, (row, _) in zip(pls, starts)
                           if row == row0}