diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 2ba144c71..91d1137de 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -9,7 +9,8 @@ import re import sys from numbers import Integral, Number -from typing import Any, Iterable +from typing import Any +from collections.abc import Iterable import matplotlib.artist as martist import matplotlib.axes as maxes @@ -2486,18 +2487,16 @@ def _parse_cycle( resolved_cycle = None case True: resolved_cycle = constructor.Cycle(rc["axes.prop_cycle"]) + case constructor.Cycle(): + resolved_cycle = constructor.Cycle(cycle) case str() if cycle.lower() == "none": resolved_cycle = None - case str() | int(): + case str() | int() | Iterable(): resolved_cycle = constructor.Cycle(cycle, **cycle_kw) - case constructor.Cycle(): - resolved_cycle = constructor.Cycle(cycle) case _: resolved_cycle = None - # Ignore cycle for single-column plotting - resolved_cycle = None if ncycle == 1 else resolved_cycle - + # Ignore cycle for single-column plotting unless cycle is different if resolved_cycle and resolved_cycle != self._active_cycle: self.set_prop_cycle(resolved_cycle) diff --git a/ultraplot/constructor.py b/ultraplot/constructor.py index d5aae8bd3..a673668ea 100644 --- a/ultraplot/constructor.py +++ b/ultraplot/constructor.py @@ -938,6 +938,12 @@ def _build_cycler(self, dicts): mcycler = cycler.cycler(**props) super().__init__(mcycler) + def __eq__(self, other): + for a, b in zip(self, other): + if a != b: + return False + return True + def get_next(self): # Get the next set of properties if self._iterator is None: diff --git a/ultraplot/tests/test_1dplots.py b/ultraplot/tests/test_1dplots.py index 5510695c6..0d4dac394 100644 --- a/ultraplot/tests/test_1dplots.py +++ b/ultraplot/tests/test_1dplots.py @@ -458,3 +458,50 @@ def test_norm_not_modified(): assert norm.vmin == 0 assert norm.vmax == 1 return fig + + +@pytest.mark.mpl_image_compare +def test_line_plot_cyclers(): + # Sample data + M, N = 50, 10 + state = np.random.RandomState(51423) + data1 = (state.rand(M, N) - 0.48).cumsum(axis=1).cumsum(axis=0) + data2 = (state.rand(M, N) - 0.48).cumsum(axis=1).cumsum(axis=0) * 1.5 + data1 += state.rand(M, N) + data2 += state.rand(M, N) + data1 *= 2 + + cmaps = ("Blues", "Reds") + cycle = uplt.Cycle(*cmaps) + + # Use property cycle for columns of 2D input data + fig, ax = uplt.subplots(ncols=3, sharey=True) + + # Intention of subplots + ax[0].set_title("Property cycle") + ax[1].set_title("Joined cycle") + ax[2].set_title("Separate cycles") + + ax[0].plot( + data1 + data2, + cycle="black", # cycle from monochromatic colormap + cycle_kw={"ls": ("-", "--", "-.", ":")}, + ) + + # Plot all dat with both cyclers on + ax[1].plot( + (data1 + data2), + cycle=cycle, + ) + + # Test cyclers separately + cycle = uplt.Cycle(*cmaps) + for idx in range(0, N): + ax[2].plot( + (data1[..., idx] + data2[..., idx]), + cycle=cycle, + cycle_kw={"N": N, "left": 0.3}, + ) + + fig.format(xlabel="xlabel", ylabel="ylabel", suptitle="On-the-fly property cycles") + return fig diff --git a/ultraplot/tests/test_format.py b/ultraplot/tests/test_format.py index 9d8be905d..cd073e2ba 100644 --- a/ultraplot/tests/test_format.py +++ b/ultraplot/tests/test_format.py @@ -340,3 +340,58 @@ def test_label_settings(): ax.format(xlabel="xlabel", ylabel="ylabel") ax.format(labelcolor="red") return fig + + +def test_colormap_parsing(): + """Test colormaps merging""" + reds = uplt.colormaps.get_cmap("reds") + blues = uplt.colormaps.get_cmap("blues") + + # helper function to test specific values in the colormaps + # threshold is used due to rounding errors + def test_range( + a: uplt.Colormap, + b: uplt.Colormap, + threshold=1e-10, + ranges=[0.0, 1.0], + ): + for i in ranges: + if not np.allclose(a(i), b(i)): + raise ValueError(f"Colormaps differ !") + + # Test if the colormaps are the same + test_range(uplt.Colormap("blues"), blues) + test_range(uplt.Colormap("reds"), reds) + # For joint colormaps, the lower value should be the lower of the first cmap and the highest should be the highest of the second cmap + test_range(uplt.Colormap("blues", "reds"), reds, ranges=[1.0]) + # Note: the ranges should not match either of the original colormaps + with pytest.raises(ValueError): + test_range(uplt.Colormap("blues", "reds"), reds) + + +def test_input_parsing_cycle(): + """ + Test the potential inputs to cycle + """ + # The first argument is a string or an iterable of strings + with pytest.raises(ValueError): + cycle = uplt.Cycle(None) + + # Empty should also be handled + cycle = uplt.Cycle() + + # Test singular string + cycle = uplt.Cycle("Blues") + target = uplt.colormaps.get_cmap("blues") + first_color = cycle.get_next()["color"] + first_color = uplt.colors.to_rgba(first_color) + assert np.allclose(first_color, target(0)) + + # Test composition + cycle = uplt.Cycle("Blues", "Reds", N=2) + lower_half = uplt.colormaps.get_cmap("blues") + upper_half = uplt.colormaps.get_cmap("reds") + first_color = uplt.colors.to_rgba(cycle.get_next()["color"]) + last_color = uplt.colors.to_rgba(cycle.get_next()["color"]) + assert np.allclose(first_color, lower_half(0.0)) + assert np.allclose(last_color, upper_half(1.0))