diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 5072055ea..8fc9742e7 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -263,7 +263,7 @@ # Alt docstrings # NOTE: Used by SubplotGrid.altx _alt_descrip = """ -Add an axes locked to the same location with a +Add an axis locked to the same location with a distinct {x} axis. This is an alias and arguably more intuitive name for `~ultraplot.axes.CartesianAxes.twin{y}`, which generates @@ -276,7 +276,7 @@ # Twin docstrings # NOTE: Used by SubplotGrid.twinx _twin_descrip = """ -Add an axes locked to the same location with a +Add an axis locked to the same location with a distinct {x} axis. This builds upon `matplotlib.axes.Axes.twin{y}`. """ diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 164b34dbc..8a191be5b 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -12,14 +12,17 @@ import matplotlib.gridspec as mgridspec import matplotlib.transforms as mtransforms import numpy as np +from typing import List +from functools import wraps from . import axes as paxes from .config import rc from .internals import ic # noqa: F401 from .internals import _not_none, docstring, warnings from .utils import _fontsize_to_pt, units +from .internals import warnings -__all__ = ["GridSpec", "SubplotGrid", "SubplotsContainer"] # deprecated +__all__ = ["GridSpec", "SubplotGrid"] # Gridspec vector arguments @@ -105,6 +108,49 @@ def _dummy_method(*args): return _dummy_method +def _apply_to_all(func=None, *, doc_key=None): + def decorator(f): + @wraps(f) + def wrapper(self, *args, **kwargs): + objs = self._apply_command(f.__name__, *args, **kwargs) + return SubplotGrid(objs) + + # Note: we generate the doc string on the fly by + # updating the original docstring in the snippet manager + # and adding "for every axis" in grid to the # first sentence. + # Determine source docstring + if doc_key is not None and doc_key in docstring._snippet_manager: + doc = inspect.cleandoc(docstring._snippet_manager[doc_key]) + elif f.__doc__: + doc = inspect.cleandoc(f.__doc__) + else: + doc = "" + + # Inject "for every axes in the grid" into the first sentence + if doc: + dot = doc.find(".") + if dot != -1: + doc = doc[:dot] + " for every axes in the grid" + doc[dot:] + else: + doc += " for every axes in the grid." + + # Patch "Returns" section if present + doc = re.sub( + r"^(Returns\n-------\n)(.+)(\n\s+)(.+)", + r"\1SubplotGrid\2A grid of the resulting axes.", + doc, + flags=re.MULTILINE, + ) + + wrapper.__doc__ = doc + + return wrapper + + if func is not None: + return decorator(func) + return decorator + + class _SubplotSpec(mgridspec.SubplotSpec): """ A thin `~matplotlib.gridspec.SubplotSpec` subclass with a nice string @@ -1528,38 +1574,6 @@ def __setitem__(self, key, value): raise IndexError("Multi dimensional item assignment is not supported.") return super().__setitem__(key, value) # could be list[:] = [1, 2, 3] - @classmethod - def _add_command(cls, src, name): - """ - Add a `SubplotGrid` method that iterates through axes methods. - """ - - # Create the method - def _grid_command(self, *args, **kwargs): - objs = [] - for ax in self: - obj = getattr(ax, name)(*args, **kwargs) - objs.append(obj) - return SubplotGrid(objs) - - # Clean the docstring - cmd = getattr(src, name) - doc = inspect.cleandoc(cmd.__doc__) # dedents - dot = doc.find(".") - if dot != -1: - doc = doc[:dot] + " for every axes in the grid" + doc[dot:] - doc = re.sub( - r"^(Returns\n-------\n)(.+)(\n\s+)(.+)", - r"\1SubplotGrid\2A grid of the resulting axes.", - doc, - ) - - # Apply the method - _grid_command.__qualname__ = f"SubplotGrid.{name}" - _grid_command.__name__ = name - _grid_command.__doc__ = doc - setattr(cls, name, _grid_command) - def _validate_item(self, items, scalar=False): """ Validate assignments. Accept diverse iterable inputs. @@ -1671,23 +1685,125 @@ def shape(self): # a 2D array-like object it should definitely have a shape attribute. return self.gridspec.get_geometry() + def _apply_command( + self, name, *args, warn_on_skip=True, **kwargs + ) -> List[paxes.Axes]: + """ + Apply a command to all axes that support it. + + Parameters + ---------- + name : str + The method name to call on each axes. + warn_on_skip : bool, optional + Whether to warn if some axes do not support the command. Default True. -# Dynamically add commands to generate twin or inset axes -# TODO: Add commands that plot the input data for every -# axes in the grid along a third dimension. -for _src, _name in ( - (paxes.Axes, "panel"), - (paxes.Axes, "panel_axes"), - (paxes.Axes, "inset"), - (paxes.Axes, "inset_axes"), - (paxes.CartesianAxes, "altx"), - (paxes.CartesianAxes, "alty"), - (paxes.CartesianAxes, "dualx"), - (paxes.CartesianAxes, "dualy"), - (paxes.CartesianAxes, "twinx"), - (paxes.CartesianAxes, "twiny"), -): - SubplotGrid._add_command(_src, _name) - -# Deprecated -SubplotsContainer = warnings._rename_objs("0.8.0", SubplotsContainer=SubplotGrid) + Returns + ------- + list + List of results from axes where the command was applied. + """ + objs = [] + skipped_count = 0 + for ax in self: + if hasattr(ax, name) and callable(getattr(ax, name)): + obj = getattr(ax, name)(*args, **kwargs) + objs.append(obj) + else: + skipped_count += 1 + + if warn_on_skip and skipped_count > 0: + warnings._warn_ultraplot( + f"Skipped {skipped_count} axes that do not support method '{name}'.", + UserWarning, + stacklevel=2, + ) + return objs + + # Note we use a stub @_apply_to_all since the logic + # is the same everywhere. + # Furthermore, the return type is give by the wrapper @_apply_to_all. + @_apply_to_all(doc_key="axes.altx") + def altx(self, *args, **kwargs) -> "SubplotGrid": + """ + Call `altx()` for every axes in the grid. + + Returns + ------- + SubplotGrid + A grid of the resulting axes. + """ + ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.dualx") + def dualx(self, *args, **kwargs) -> "SubplotGrid": + """ + Call `dualx()` for every axes in the grid. + + Returns + ------- + SubplotGrid + A grid of the resulting axes. + """ + ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.twinx") + def twinx(self, *args, **kwargs) -> "SubplotGrid": + """ + Call `twinx()` for every axes in the grid. + + Returns + ------- + SubplotGrid + A grid of the resulting axes. + """ + ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.alty") + def alty(self, *args, **kwargs) -> "SubplotGrid": + """ + Call `alty()` for every axes in the grid. + + Returns + ------- + SubplotGrid + A grid of the resulting axes. + """ + ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.dualy") + def dualy(self, *args, **kwargs) -> "SubplotGrid": + """ + Call `dualy()` for every axes in the grid. + + Returns + ------- + SubplotGrid + A grid of the resulting axes. + """ + ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.twiny") + def twiny( + self, *args, **kwargs + ) -> "SubplotGrid": ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.panel") + def panel( + self, *args, **kwargs + ) -> "SubplotGrid": ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.panel_axes") + def panel_axes( + self, *args, **kwargs + ) -> "SubplotGrid": ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.inset") + def inset( + self, *args, **kwargs + ) -> "SubplotGrid": ... # implementation is provided by @_apply_to_all + + @_apply_to_all(doc_key="axes.inset_axes") + def inset_axes( + self, *args, **kwargs + ) -> "SubplotGrid": ... # implementation is provided by @_apply_to_all diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py new file mode 100644 index 000000000..e3890d7a3 --- /dev/null +++ b/ultraplot/tests/test_gridspec.py @@ -0,0 +1,74 @@ +import ultraplot as uplt +import pytest +from ultraplot.gridspec import SubplotGrid + + +def test_grid_has_dynamic_methods(): + """ + Check that we can apply the methods to a SubplotGrid object. + """ + fig, axs = uplt.subplots(nrows=1, ncols=2) + for method in ("altx", "dualx", "twinx", "panel"): + assert hasattr(axs, method) + assert callable(getattr(axs, method)) + args = [] + if method == "dualx": + # needs function argument + args = ["linear"] + subplotgrid = getattr(axs, method)(*args) + assert isinstance(subplotgrid, SubplotGrid) + assert len(subplotgrid) == 2 + + +def test_altx_calls_all_axes_methods(): + """ + Check the return types of newly added methods such as altx, dualx, and twinx. + """ + fig, axs = uplt.subplots(nrows=1, ncols=2) + result = axs.altx() + assert isinstance(result, SubplotGrid) + assert len(result) == 2 + for ax in result: + assert isinstance(ax, uplt.axes.Axes) + + +def test_missing_command_is_skipped_gracefully(): + """For missing commands, we should raise an error.""" + fig, axs = uplt.subplots(nrows=1, ncols=2) + # Pretend we have a method that doesn't exist on these axes + with pytest.raises(AttributeError): + axs.nonexistent() + + +def test_docstring_injection(): + """ + @_apply_to_all should inject the docstring + """ + fig, axs = uplt.subplots(nrows=1, ncols=2) + doc = axs.altx.__doc__ + assert "for every axes in the grid" in doc + assert "Returns" in doc + + +def test_subplot_repr(): + """ + Panels don't have a subplotspec, so they return "unknown" in their repr, but normal subplots should + """ + fig, ax = uplt.subplots() + panel = ax.panel("r") + assert panel.get_subplotspec().__repr__() == "SubplotSpec(unknown)" + assert ( + ax[0].get_subplotspec().__repr__() + == "SubplotSpec(nrows=1, ncols=1, index=(0, 0))" + ) + + +def test_tight_layout_disabled(): + """ + Some methods are disabled in gridspec, such as tight_layout. + This should raise a RuntimeErrror when called on a SubplotGrid. + """ + fig, ax = uplt.subplots() + gs = ax.get_subplotspec().get_gridspec() + with pytest.raises(RuntimeError): + gs.tight_layout(fig)