diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index cb8b5d63d..d2a438dfd 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -20,6 +20,7 @@ import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; import java.util.Objects; import java.util.Queue; +import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscriber; @@ -55,6 +56,8 @@ public final class UnboundedProcessor extends FluxProcessor volatile boolean cancelled; + volatile boolean terminated; + volatile int once; @SuppressWarnings("rawtypes") @@ -124,6 +127,9 @@ void drainRegular(Subscriber a) { } if (checkTerminated(d, empty, a)) { + if (!empty) { + release(t); + } return; } @@ -159,7 +165,9 @@ void drainFused(Subscriber a) { for (; ; ) { if (cancelled) { - this.clear(); + if (terminated) { + this.clear(); + } hasDownstream = false; return; } @@ -189,7 +197,7 @@ void drainFused(Subscriber a) { public void drain() { if (WIP.getAndIncrement(this) != 0) { - if (cancelled) { + if ((!outputFused && cancelled) || terminated) { this.clear(); } return; @@ -350,7 +358,9 @@ public void cancel() { cancelled = true; if (WIP.getAndIncrement(this) == 0) { - this.clear(); + if (!outputFused || terminated) { + this.clear(); + } hasDownstream = false; } } @@ -377,6 +387,7 @@ public boolean isEmpty() { @Override public void clear() { + terminated = true; if (DISCARD_GUARD.getAndIncrement(this) != 0) { return; } @@ -384,17 +395,12 @@ public void clear() { int missed = 1; for (; ; ) { - while (!queue.isEmpty()) { - T t = queue.poll(); - if (t != null) { - release(t); - } + T t; + while ((t = queue.poll()) != null) { + release(t); } - while (!priorityQueue.isEmpty()) { - T t = priorityQueue.poll(); - if (t != null) { - release(t); - } + while ((t = priorityQueue.poll()) != null) { + release(t); } missed = DISCARD_GUARD.addAndGet(this, -missed); @@ -415,7 +421,43 @@ public int requestFusion(int requestedMode) { @Override public void dispose() { - cancel(); + if (cancelled) { + return; + } + + error = new CancellationException("Disposed"); + done = true; + + boolean once = true; + if (WIP.getAndIncrement(this) == 0) { + cancelled = true; + int m = 1; + for (; ; ) { + final CoreSubscriber a = this.actual; + + if (!outputFused || terminated) { + clear(); + } + + if (a != null && once) { + try { + a.onError(error); + } catch (Throwable ignored) { + } + } + + cancelled = true; + once = false; + + int wip = this.wip; + if (wip == m) { + break; + } + m = wip; + } + + hasDownstream = false; + } } @Override diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java index 7bf975543..552afb70c 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-present the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,115 +16,173 @@ package io.rsocket.internal; -import io.rsocket.Payload; -import io.rsocket.util.ByteBufPayload; -import io.rsocket.util.EmptyPayload; -import java.util.concurrent.CountDownLatch; -import org.junit.Assert; -import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; public class UnboundedProcessorTest { - @Test - public void testOnNextBeforeSubscribe_10() { - testOnNextBeforeSubscribeN(10); - } - - @Test - public void testOnNextBeforeSubscribe_100() { - testOnNextBeforeSubscribeN(100); - } - @Test - public void testOnNextBeforeSubscribe_10_000() { - testOnNextBeforeSubscribeN(10_000); + @BeforeAll + public static void setup() { + Hooks.onErrorDropped(__ -> {}); } - @Test - public void testOnNextBeforeSubscribe_100_000() { - testOnNextBeforeSubscribeN(100_000); - } - - @Test - public void testOnNextBeforeSubscribe_1_000_000() { - testOnNextBeforeSubscribeN(1_000_000); - } - - @Test - public void testOnNextBeforeSubscribe_10_000_000() { - testOnNextBeforeSubscribeN(10_000_000); + public static void teardown() { + Hooks.resetOnErrorDropped(); } + @ParameterizedTest( + name = + "Test that emitting {0} onNext before subscribe and requestN should deliver all the signals once the subscriber is available") + @ValueSource(ints = {10, 100, 10_000, 100_000, 1_000_000, 10_000_000}) public void testOnNextBeforeSubscribeN(int n) { - UnboundedProcessor processor = new UnboundedProcessor<>(); + UnboundedProcessor processor = new UnboundedProcessor<>(); for (int i = 0; i < n; i++) { - processor.onNext(EmptyPayload.INSTANCE); + processor.onNext(Unpooled.EMPTY_BUFFER); } processor.onComplete(); - long count = processor.count().block(); - - Assert.assertEquals(n, count); - } - - @Test - public void testOnNextAfterSubscribe_10() throws Exception { - testOnNextAfterSubscribeN(10); - } - - @Test - public void testOnNextAfterSubscribe_100() throws Exception { - testOnNextAfterSubscribeN(100); + StepVerifier.create(processor.count()).expectNext(Long.valueOf(n)).verifyComplete(); } - @Test - public void testOnNextAfterSubscribe_1000() throws Exception { - testOnNextAfterSubscribeN(1000); - } + @ParameterizedTest( + name = + "Test that emitting {0} onNext after subscribe and requestN should deliver all the signals") + @ValueSource(ints = {10, 100, 10_000}) + public void testOnNextAfterSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor<>(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); - @Test - public void testPrioritizedSending() { - UnboundedProcessor processor = new UnboundedProcessor<>(); + processor.subscribe(assertSubscriber); - for (int i = 0; i < 1000; i++) { - processor.onNext(EmptyPayload.INSTANCE); + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); } - processor.onNextPrioritized(ByteBufPayload.create("test")); - - Payload closestPayload = processor.next().block(); - - Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + assertSubscriber.awaitAndAssertNextValueCount(n); } - @Test - public void testPrioritizedFused() { - UnboundedProcessor processor = new UnboundedProcessor<>(); + @ParameterizedTest( + name = + "Test that prioritized value sending deliver prioritized signals before the others mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void testPrioritizedSending(boolean fusedCase) { + UnboundedProcessor processor = new UnboundedProcessor<>(); for (int i = 0; i < 1000; i++) { - processor.onNext(EmptyPayload.INSTANCE); + processor.onNext(Unpooled.EMPTY_BUFFER); } - processor.onNextPrioritized(ByteBufPayload.create("test")); + processor.onNextPrioritized(Unpooled.copiedBuffer("test", CharsetUtil.UTF_8)); - Payload closestPayload = processor.poll(); - - Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + assertThat(fusedCase ? processor.poll() : processor.next().block()) + .isNotNull() + .extracting(bb -> bb.toString(CharsetUtil.UTF_8)) + .isEqualTo("test"); } - public void testOnNextAfterSubscribeN(int n) throws Exception { - CountDownLatch latch = new CountDownLatch(n); - UnboundedProcessor processor = new UnboundedProcessor<>(); - processor.log().doOnNext(integer -> latch.countDown()).subscribe(); - - for (int i = 0; i < n; i++) { - System.out.println("onNexting -> " + i); - processor.onNext(EmptyPayload.INSTANCE); + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensureUnboundedProcessorDisposesQueueProperly(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (int i = 0; i < 100000; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor<>(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + }, + unboundedProcessor::dispose, + Schedulers.elastic()), + assertSubscriber::cancel, + Schedulers.elastic()), + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + }, + Schedulers.elastic()); + + assertSubscriber.values().forEach(ReferenceCountUtil::safeRelease); + + allocator.assertHasNoLeaks(); } + } - processor.drain(); - - latch.await(); + @RepeatedTest( + name = + "Ensures that racing between onNext + dispose | downstream async drain should not cause any issues and leaks", + value = 100000) + @Timeout(60) + public void ensuresAsyncFusionAndDisposureHasNoDeadlock() { + // TODO: enable leaks tracking + // final LeaksTrackingByteBufAllocator allocator = + // LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor<>(); + + // final ByteBuf buffer1 = allocator.buffer(1); + // final ByteBuf buffer2 = allocator.buffer(2); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber<>(Operators.enableOnDiscard(null, ReferenceCountUtil::safeRelease)); + + unboundedProcessor.publishOn(Schedulers.parallel()).subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + // unboundedProcessor.onNext(buffer1); + // unboundedProcessor.onNext(buffer2); + unboundedProcessor.onNext(Unpooled.EMPTY_BUFFER); + unboundedProcessor.onNext(Unpooled.EMPTY_BUFFER); + unboundedProcessor.onNext(Unpooled.EMPTY_BUFFER); + unboundedProcessor.onNext(Unpooled.EMPTY_BUFFER); + unboundedProcessor.onNext(Unpooled.EMPTY_BUFFER); + unboundedProcessor.onNext(Unpooled.EMPTY_BUFFER); + unboundedProcessor.dispose(); + }, + unboundedProcessor::dispose); + + assertSubscriber + .await(Duration.ofSeconds(50)) + .values() + .forEach(ReferenceCountUtil::safeRelease); + + // allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java index 84a589a8d..83d420d90 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -26,6 +26,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.BooleanSupplier; @@ -86,6 +87,10 @@ public class AssertSubscriber implements CoreSubscriber, Subscription { private static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(AssertSubscriber.class, "wip"); + @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater NEXT_VALUES = AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); @@ -100,10 +105,14 @@ public class AssertSubscriber implements CoreSubscriber, Subscription { private final CountDownLatch cdl = new CountDownLatch(1); + volatile boolean done; + volatile Subscription s; volatile long requested; + volatile int wip; + volatile List values = new LinkedList<>(); /** The fusion mode to request. */ @@ -377,7 +386,7 @@ public final AssertSubscriber assertError(Class clazz) { } } if (s > 1) { - throw new AssertionError("Multiple errors: " + s, null); + throw new AssertionError("Multiple errors: " + errors, null); } return this; } @@ -854,6 +863,10 @@ public void cancel() { a = S.getAndSet(this, Operators.cancelledSubscription()); if (a != null && a != Operators.cancelledSubscription()) { a.cancel(); + + if (establishedFusionMode == Fuseable.ASYNC && WIP.getAndIncrement(this) == 0) { + qs.clear(); + } } } } @@ -868,37 +881,77 @@ public final boolean isTerminated() { @Override public void onComplete() { + done = true; completionCount++; + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + cdl.countDown(); } @Override public void onError(Throwable t) { + done = true; errors.add(t); + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + cdl.countDown(); } @Override public void onNext(T t) { if (establishedFusionMode == Fuseable.ASYNC) { - for (; ; ) { - t = qs.poll(); - if (t == null) { - break; - } - valueCount++; - if (valuesStorage) { - List nextValuesSnapshot; - for (; ; ) { - nextValuesSnapshot = values; - nextValuesSnapshot.add(t); - if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { - break; - } + drain(); + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; } } } - } else { + } + } + + void drain() { + if (this.wip != 0 || WIP.getAndIncrement(this) != 0) { + if (isCancelled()) { + qs.clear(); + } + return; + } + + T t; + int m = 1; + for (; ; ) { + if (isCancelled()) { + qs.clear(); + break; + } + boolean done = this.done; + t = qs.poll(); + if (t == null) { + if (done) { + cdl.countDown(); + return; + } + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + continue; + } valueCount++; if (valuesStorage) { List nextValuesSnapshot;