diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 03bea4fdc..a0e30f68b 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1456,6 +1456,11 @@ def _add_legend( titlefontcolor=None, handle_kw=None, handler_map=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, **kwargs, ): """ @@ -1493,7 +1498,18 @@ def _add_legend( # Generate and prepare the legend axes if loc in ("fill", "left", "right", "top", "bottom"): - lax = self._add_guide_panel(loc, align, width=width, space=space, pad=pad) + lax = self._add_guide_panel( + loc, + align, + width=width, + space=space, + pad=pad, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + ) kwargs.setdefault("borderaxespad", 0) if not frameon: kwargs.setdefault("borderpad", 0) @@ -3560,7 +3576,19 @@ def colorbar(self, mappable, values=None, loc=None, location=None, **kwargs): @docstring._concatenate_inherited # also obfuscates params @docstring._snippet_manager - def legend(self, handles=None, labels=None, loc=None, location=None, **kwargs): + def legend( + self, + handles=None, + labels=None, + loc=None, + location=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): """ Add an inset legend or outer legend along the edge of the axes. @@ -3622,7 +3650,18 @@ def legend(self, handles=None, labels=None, loc=None, location=None, **kwargs): if queue: self._register_guide("legend", (handles, labels), (loc, align), **kwargs) else: - return self._add_legend(handles, labels, loc=loc, align=align, **kwargs) + return self._add_legend( + handles, + labels, + loc=loc, + align=align, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, + ) @docstring._concatenate_inherited @docstring._snippet_manager diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 7c2cd454b..6b5b46c48 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -6,12 +6,13 @@ import inspect import os from numbers import Integral + from packaging import version try: - from typing import List, Optional, Union, Tuple + from typing import List, Optional, Tuple, Union except ImportError: - from typing_extensions import List, Optional, Union, Tuple + from typing_extensions import List, Optional, Tuple, Union import matplotlib.axes as maxes import matplotlib.figure as mfigure @@ -30,7 +31,6 @@ from . import constructor from . import gridspec as pgridspec from .config import rc, rc_matplotlib -from .internals import ic # noqa: F401 from .internals import ( _not_none, _pop_params, @@ -38,10 +38,11 @@ _translate_loc, context, docstring, + ic, # noqa: F401 labels, warnings, ) -from .utils import units, _get_subplot_layout, _Crawler +from .utils import _Crawler, units __all__ = [ "Figure", @@ -1385,12 +1386,12 @@ def _add_axes_panel( # Vertical panels: should use rows parameter, not cols if _not_none(cols, col) is not None and _not_none(rows, row) is None: raise ValueError( - f"For {side!r} colorbars (vertical), use 'rows=' or 'row=' " + f"For {side!r} panels (vertical), use 'rows=' or 'row=' " "to specify span, not 'cols=' or 'col='." ) if span is not None and _not_none(rows, row) is None: warnings._warn_ultraplot( - f"For {side!r} colorbars (vertical), prefer 'rows=' over 'span=' " + f"For {side!r} panels (vertical), prefer 'rows=' over 'span=' " "for clarity. Using 'span' as rows." ) span_override = _not_none(rows, row, span) @@ -1398,7 +1399,7 @@ def _add_axes_panel( # Horizontal panels: should use cols parameter, not rows if _not_none(rows, row) is not None and _not_none(cols, col, span) is None: raise ValueError( - f"For {side!r} colorbars (horizontal), use 'cols=' or 'span=' " + f"For {side!r} panels (horizontal), use 'cols=' or 'span=' " "to specify span, not 'rows=' or 'row='." ) span_override = _not_none(cols, col, span) @@ -2395,6 +2396,7 @@ def colorbar( if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): try: ax_single = next(iter(ax)) + except (TypeError, StopIteration): ax_single = ax else: @@ -2474,8 +2476,31 @@ def legend( ax = kwargs.pop("ax", None) # Axes panel legend if ax is not None: - leg = ax.legend( - handles, labels, space=space, pad=pad, width=width, **kwargs + # Check if span parameters are provided + has_span = _not_none(span, row, col, rows, cols) is not None + + # Extract a single axes from array if span is provided + # Otherwise, pass the array as-is for normal legend behavior + if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): + try: + ax_single = next(iter(ax)) + except (TypeError, StopIteration): + ax_single = ax + else: + ax_single = ax + leg = ax_single.legend( + handles, + labels, + loc=loc, + space=space, + pad=pad, + width=width, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, ) # Figure panel legend else: diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index dd23c5c18..48a40a678 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -318,3 +318,168 @@ def test_fill_between_included_in_legend(): labels = [t.get_text() for t in leg.get_texts()] assert "band" in labels uplt.close(fig) + + +def test_legend_span_bottom(): + """Test bottom legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend below row 1, spanning columns 1-2 + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify legend was created + assert leg is not None + + +def test_legend_span_top(): + """Test top legend with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Legend above row 2, spanning columns 2-3 + leg = fig.legend(ax=axs[1, :], cols=(2, 3), loc="top") + + assert leg is not None + + +def test_legend_span_right(): + """Test right legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend right of column 1, spanning rows 1-2 + leg = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right") + + assert leg is not None + + +def test_legend_span_left(): + """Test left legend with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Legend left of column 2, spanning rows 2-3 + leg = fig.legend(ax=axs[:, 1], rows=(2, 3), loc="left") + + assert leg is not None + + +def test_legend_span_validation_left_with_cols_error(): + """Test that LEFT legend raises error with cols parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="left.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="left") + + +def test_legend_span_validation_right_with_cols_error(): + """Test that RIGHT legend raises error with cols parameter.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="right.*vertical.*use 'rows='.*not 'cols='"): + fig.legend(ax=axs[0, 0], cols=(1, 2), loc="right") + + +def test_legend_span_validation_top_with_rows_error(): + """Test that TOP legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises(ValueError, match="top.*horizontal.*use 'cols='.*not 'rows='"): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="top") + + +def test_legend_span_validation_bottom_with_rows_error(): + """Test that BOTTOM legend raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + with pytest.raises( + ValueError, match="bottom.*horizontal.*use 'cols='.*not 'rows='" + ): + fig.legend(ax=axs[0, 0], rows=(1, 2), loc="bottom") + + +def test_legend_span_validation_left_with_span_warns(): + """Test that LEFT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="left.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="left") + assert leg is not None + + +def test_legend_span_validation_right_with_span_warns(): + """Test that RIGHT legend with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + with pytest.warns(match="right.*vertical.*prefer 'rows='"): + leg = fig.legend(ax=axs[0, 0], span=(1, 2), loc="right") + assert leg is not None + + +def test_legend_array_without_span(): + """Test that legend on array without span preserves original behavior.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should create legend for all axes in the array + leg = fig.legend(ax=axs[:], loc="right") + assert leg is not None + + +def test_legend_array_with_span(): + """Test that legend on array with span uses first axis + span extent.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should use first axis position with span extent + leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom") + assert leg is not None + + +def test_legend_row_without_span(): + """Test that legend on row without span spans entire row.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 columns + leg = fig.legend(ax=axs[0, :], loc="bottom") + assert leg is not None + + +def test_legend_column_without_span(): + """Test that legend on column without span spans entire column.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + axs[0, 0].plot([], [], label="test") + + # Should span all 3 rows + leg = fig.legend(ax=axs[:, 0], loc="right") + assert leg is not None + + +def test_legend_multiple_sides_with_span(): + """Test multiple legends on different sides with span control.""" + fig, axs = uplt.subplots(nrows=3, ncols=3) + axs[0, 0].plot([], [], label="test") + + # Create legends on all 4 sides with different spans + leg_bottom = fig.legend(ax=axs[0, 0], span=(1, 2), loc="bottom") + leg_top = fig.legend(ax=axs[1, 0], span=(2, 3), loc="top") + leg_right = fig.legend(ax=axs[0, 0], rows=(1, 2), loc="right") + leg_left = fig.legend(ax=axs[0, 1], rows=(2, 3), loc="left") + + assert leg_bottom is not None + assert leg_top is not None + assert leg_right is not None + assert leg_left is not None