diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index aeef2511a..2d55a7d6a 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1289,51 +1289,49 @@ def _add_colorbar( length=ticklen * ticklenratio, width=tickwidth * tickwidthratio, ) # noqa: E501 - if label is not None: - # Note for some reason axis.set_label does not work here. We need to use set_x/ylabel explicitly - match loc: - case "top" | "bottom": - if labelloc in (None, "top", "bottom"): - obj.set_label(label) - elif labelloc in ("left", "right"): - obj.ax.set_ylabel(label) - else: - raise ValueError("Could not determine position") - case "left" | "right": - if labelloc in (None, "left", "right"): - obj.set_label(label) - elif labelloc in ("top", "bottom"): - obj.ax.set_xlabel(label) - else: - raise ValueError("Could not determine position") - case "fill": - if labelloc in ("left", "right"): - obj.ax.set_ylabel(label) - elif labelloc in ("top", "bottom"): - obj.ax.set_xlabel(label) - elif labelloc is None: - obj.set_label(label) - else: - raise ValueError("Could not determine position") - # Default to setting label on long axis - case _: - obj.set_label(label) + + if _is_horizontal_loc(loc): + if labelloc is None or _is_horizontal_label(labelloc): + obj.set_label(label) + elif _is_vertical_label(labelloc): + obj.ax.set_ylabel(label) + else: + raise ValueError("Could not determine position") + + elif _is_vertical_loc(loc): + if labelloc is None or _is_vertical_label(labelloc): + obj.set_label(label) + elif _is_horizontal_label(labelloc): + obj.ax.set_xlabel(label) + else: + raise ValueError("Could not determine position") + + elif loc == "fill": + if labelloc is None: + obj.set_label(label) + elif _is_vertical_label(labelloc): + obj.ax.set_ylabel(label) + elif _is_horizontal_label(labelloc): + obj.ax.set_xlabel(label) + else: + raise ValueError("Could not determine position") + + else: + # Default to setting label on long axis + obj.set_label(label) + + # Set axis properties if labelloc is specified if labelloc is not None: - # Temporarily modify the axis to set the label and its properties - match loc: - case "top" | "bottom": - if labelloc in ("left", "right"): - axis = obj._short_axis() - case "left" | "right": - if labelloc in ("top", "bottom"): - axis = obj._short_axis() - case "fill": - if labelloc in ("top", "bottom"): - axis = obj._long_axis() - elif labelloc in ("left", "right"): - axis = obj._short_axis() - case _: - raise ValueError("Location not understood.") + if _is_horizontal_loc(loc) and _is_vertical_label(labelloc): + axis = obj._short_axis() + elif _is_vertical_loc(loc) and _is_horizontal_label(labelloc): + axis = obj._short_axis() + elif loc == "fill": + if _is_horizontal_label(labelloc): + axis = obj._long_axis() + elif _is_vertical_label(labelloc): + axis = obj._short_axis() + axis.set_label_position(labelloc) labelrotation = _not_none(labelrotation, rc["colorbar.labelrotation"]) if labelrotation == "auto": @@ -3631,3 +3629,23 @@ def _get_pos_from_locator( case "lower left" | "lower right" | "lower center": y = y_pad return (x, y) + + +def _is_horizontal_loc(loc): + """Check if location is horizontally oriented.""" + return any(keyword in loc for keyword in ["top", "bottom", "upper", "lower"]) + + +def _is_vertical_loc(loc): + """Check if location is vertically oriented.""" + return loc in ("left", "right") + + +def _is_horizontal_label(labelloc): + """Check if label location is horizontal.""" + return labelloc in ("top", "bottom") + + +def _is_vertical_label(labelloc): + """Check if label location is vertical.""" + return labelloc in ("left", "right") diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index 191925966..058dd7189 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -5,6 +5,7 @@ import numpy as np import pytest import ultraplot as uplt +from itertools import product @pytest.mark.mpl_image_compare @@ -353,3 +354,133 @@ def test_label_placement_fig_colorbar2(): fig, axs = uplt.subplots(nrows=1, ncols=2) fig.colorbar(cmap, loc="bottom", label="My Label", labelloc="right") return fig + + +@pytest.mark.parametrize( + ("labelloc", "cbarloc"), + product( + ["bottom", "top", "left", "right"], + [ + "top", + "bottom", + "left", + "right", + "upper right", + "upper left", + "lower left", + "lower right", + ], + ), +) +def test_colorbar_label_placement(labelloc, cbarloc): + """ + Ensure that colorbar label placement works correctly. + """ + cmap = uplt.Colormap("plasma_r") + title = "My Label" + fig, ax = uplt.subplots() + + cbar = ax.colorbar(cmap, loc=cbarloc, labelloc=labelloc, title=title) + + x_label = cbar.ax.xaxis.label.get_text() + y_label = cbar.ax.yaxis.label.get_text() + + assert title in (x_label, y_label), ( + f"Expected label '{title}' not found. " + f"xaxis label: '{x_label}', yaxis label: '{y_label}', " + f"labelloc='{labelloc}', cbarloc='{cbarloc}'" + ) + + uplt.close(fig) + + +@pytest.mark.parametrize( + ("cbarloc", "invalid_labelloc"), + product( + ["top", "bottom", "upper left", "lower right"], + ["invalid", "diagonal", "center", "middle", 123, "unknown"], + ), +) +def test_colorbar_invalid_horizontal_label(cbarloc, invalid_labelloc): + """ + Test error conditions and edge cases for colorbar label placement. + """ + cmap = uplt.Colormap("plasma_r") + title = "Test Label" + fig, ax = uplt.subplots() + + # Test ValueError cases - invalid labelloc for different colorbar locations + + # Horizontal colorbar location with invalid labelloc + with pytest.raises(ValueError, match="Could not determine position"): + ax.colorbar(cmap, loc=cbarloc, labelloc=invalid_labelloc, label=title) + uplt.close(fig) + + +@pytest.mark.parametrize( + ("cbarloc", "invalid_labelloc"), + product( + ["left", "right", "ll", "ul", "ur", "lr"], + [ + "invalid", + "diagonal", + "center", + "middle", + 123, + "unknown", + ], + ), +) +def test_colorbar_invalid_vertical_label(cbarloc, invalid_labelloc): + # Vertical colorbar location with invalid labelloc + cmap = uplt.Colormap("plasma_r") + title = "Test Label" + fig, ax = uplt.subplots() + with pytest.raises(ValueError, match="Could not determine position"): + ax.colorbar(cmap, loc=cbarloc, labelloc=invalid_labelloc, label=title) + uplt.close(fig) + + +@pytest.mark.parametrize( + "invalid_labelloc", ["fill", "unknown", "custom", "weird_location", 123] +) +def test_colorbar_invalid_fill_label_placement(invalid_labelloc): + # Fill location with invalid labelloc + cmap = uplt.Colormap("plasma_r") + title = "Test Label" + fig, ax = uplt.subplots() + with pytest.raises(ValueError, match="Could not determine position"): + ax.colorbar(cmap, loc="fill", labelloc=invalid_labelloc, label=title) + + +@pytest.mark.parametrize("unknown_loc", ["unknown", "custom", "weird_location", 123]) +def test_colorbar_wrong_label_placement_should_raise_error(unknown_loc): + # Unknown locs should raise errors + cmap = uplt.Colormap("plasma_r") + title = "Test Label" + fig, ax = uplt.subplots() + with pytest.raises(KeyError): + cbar = ax.colorbar(cmap, loc=unknown_loc, label=title) + + +@pytest.mark.parametrize("loc", ["top", "bottom", "left", "right", "fill"]) +def test_colorbar_label_no_labelloc(loc): + cmap = uplt.Colormap("plasma_r") + title = "Test Label" + fig, ax = uplt.subplots() + # None labelloc should always work without error + cbar = ax.colorbar(cmap, loc=loc, labelloc=None, label=title) + + # Should have the label set somewhere + label_found = ( + cbar.ax.get_title() == title + or ( + hasattr(cbar.ax.xaxis.label, "get_text") + and cbar.ax.xaxis.label.get_text() == title + ) + or ( + hasattr(cbar.ax.yaxis.label, "get_text") + and cbar.ax.yaxis.label.get_text() == title + ) + ) + assert label_found, f"Label not found for loc='{loc}' with labelloc=None" diff --git a/ultraplot/ticker.py b/ultraplot/ticker.py index 9bfaf9084..dd3d96c8c 100644 --- a/ultraplot/ticker.py +++ b/ultraplot/ticker.py @@ -892,6 +892,7 @@ def __call__(self, x, pos=None): adjusted_lon = x - self.lon0 # Normalize to -180 to 180 range adjusted_lon = ((adjusted_lon + 180) % 360) - 180 + print(x) # Use the original formatter with the adjusted longitude return super().__call__(adjusted_lon, pos)