diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java new file mode 100644 index 000000000..4b07f04c7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java @@ -0,0 +1,399 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameUtil; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; + +/** + * {@link DuplexConnection#receive()} is a single stream on which the following type of frames + * arrive: + * + * + * + *

The only way to differentiate these two frames is determining whether the stream Id is odd or + * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the + * client. + */ +class ClientServerInputMultiplexer implements CoreSubscriber, Closeable { + + private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); + private static final InitializingInterceptorRegistry emptyInterceptorRegistry = + new InitializingInterceptorRegistry(); + + private final InternalDuplexConnection setupReceiver; + private final InternalDuplexConnection serverReceiver; + private final InternalDuplexConnection clientReceiver; + private final DuplexConnection setupConnection; + private final DuplexConnection serverConnection; + private final DuplexConnection clientConnection; + private final DuplexConnection source; + private final boolean isClient; + + private Subscription s; + private boolean setupReceived; + + private Throwable t; + + private volatile int state; + private static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ClientServerInputMultiplexer.class, "state"); + + public ClientServerInputMultiplexer(DuplexConnection source) { + this(source, emptyInterceptorRegistry, false); + } + + public ClientServerInputMultiplexer( + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { + this.source = source; + this.isClient = isClient; + source = registry.initConnection(Type.SOURCE, source); + + if (!isClient) { + setupReceiver = new InternalDuplexConnection(this, source); + setupConnection = registry.initConnection(Type.SETUP, setupReceiver); + } else { + setupReceiver = null; + setupConnection = null; + } + serverReceiver = new InternalDuplexConnection(this, source); + clientReceiver = new InternalDuplexConnection(this, source); + serverConnection = registry.initConnection(Type.SERVER, serverReceiver); + clientConnection = registry.initConnection(Type.CLIENT, clientReceiver); + } + + public DuplexConnection asClientServerConnection() { + return source; + } + + public DuplexConnection asServerConnection() { + return serverConnection; + } + + public DuplexConnection asClientConnection() { + return clientConnection; + } + + public DuplexConnection asSetupConnection() { + return setupConnection; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + if (isClient) { + s.request(Long.MAX_VALUE); + } else { + // request first SetupFrame + s.request(1); + } + } + } + + @Override + public void onNext(ByteBuf frame) { + int streamId = FrameHeaderCodec.streamId(frame); + final Type type; + if (streamId == 0) { + switch (FrameHeaderCodec.frameType(frame)) { + case SETUP: + case RESUME: + case RESUME_OK: + type = Type.SETUP; + setupReceived = true; + break; + case LEASE: + case KEEPALIVE: + case ERROR: + type = isClient ? Type.CLIENT : Type.SERVER; + break; + default: + type = isClient ? Type.SERVER : Type.CLIENT; + } + } else if ((streamId & 0b1) == 0) { + type = Type.SERVER; + } else { + type = Type.CLIENT; + } + if (!isClient && type != Type.SETUP && !setupReceived) { + final IllegalStateException error = + new IllegalStateException("SETUP or LEASE frame must be received before any others."); + this.s.cancel(); + onError(error); + } + + switch (type) { + case SETUP: + final InternalDuplexConnection setupReceiver = this.setupReceiver; + setupReceiver.onNext(frame); + setupReceiver.onComplete(); + break; + case CLIENT: + clientReceiver.onNext(frame); + break; + case SERVER: + serverReceiver.onNext(frame); + break; + } + } + + @Override + public void onComplete() { + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (!isClient) { + if (!setupReceived) { + setupReceiver.onComplete(); + } + + if (previousState == 1) { + return; + } + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onComplete(); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + this.t = t; + + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (!isClient) { + if (!setupReceived) { + setupReceiver.onError(t); + } + + if (previousState == 1) { + return; + } + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onError(t); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onError(t); + } + } + + boolean notifyRequested() { + final int currentState = incrementAndGetCheckingState(); + if (currentState == Integer.MIN_VALUE) { + return false; + } + + if (isClient) { + if (currentState == 2) { + source.receive().subscribe(this); + } + } else { + if (currentState == 1) { + source.receive().subscribe(this); + } else if (currentState == 3) { + // means setup was consumed and we got request from client and server multiplexers + s.request(Long.MAX_VALUE); + } + } + + return true; + } + + int incrementAndGetCheckingState() { + int prev, next; + for (; ; ) { + prev = this.state; + + if (prev == Integer.MIN_VALUE) { + return prev; + } + + next = prev + 1; + if (STATE.compareAndSet(this, prev, next)) { + return next; + } + } + } + + private static class InternalDuplexConnection extends Flux + implements Subscription, DuplexConnection { + private final ClientServerInputMultiplexer clientServerInputMultiplexer; + private final DuplexConnection source; + private final boolean debugEnabled; + + private volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(InternalDuplexConnection.class, "state"); + + CoreSubscriber actual; + + public InternalDuplexConnection( + ClientServerInputMultiplexer clientServerInputMultiplexer, DuplexConnection source) { + this.clientServerInputMultiplexer = clientServerInputMultiplexer; + this.source = source; + this.debugEnabled = LOGGER.isDebugEnabled(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (this.state == 0 && STATE.compareAndSet(this, 0, 1)) { + this.actual = actual; + actual.onSubscribe(this); + } else { + Operators.error( + actual, + new IllegalStateException("InternalDuplexConnection allows only single subscription")); + } + } + + @Override + public void request(long n) { + if (this.state == 1 && STATE.compareAndSet(this, 1, 2)) { + final ClientServerInputMultiplexer multiplexer = clientServerInputMultiplexer; + if (!multiplexer.notifyRequested()) { + final Throwable t = multiplexer.t; + if (t != null) { + this.actual.onError(t); + } else { + this.actual.onComplete(); + } + } + } + } + + @Override + public void cancel() { + // no ops + } + + void onNext(ByteBuf frame) { + this.actual.onNext(frame); + } + + void onComplete() { + this.actual.onComplete(); + } + + void onError(Throwable t) { + this.actual.onError(t); + } + + @Override + public Mono send(Publisher frame) { + if (debugEnabled) { + return Flux.from(frame) + .doOnNext(f -> LOGGER.debug("sending -> " + FrameUtil.toString(f))) + .as(source::send); + } + + return source.send(frame); + } + + @Override + public Mono sendOne(ByteBuf frame) { + if (debugEnabled) { + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); + } + + return source.sendOne(frame); + } + + @Override + public Flux receive() { + if (debugEnabled) { + return this.doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); + } else { + return this; + } + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + public boolean isSubscribed() { + return this.state != 0; + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public double availability() { + return source.availability(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index 5664eace3..c7caba946 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -29,7 +29,6 @@ import io.rsocket.SocketAcceptor; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.lease.LeaseStats; import io.rsocket.lease.Leases; diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index f7fbbe9bd..5a411e464 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -33,7 +33,6 @@ import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.lease.Leases; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.lease.ResponderLeaseHandler; diff --git a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java index 337d17c64..eb86c6734 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -25,7 +25,6 @@ import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.SetupFrameCodec; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.resume.*; import java.time.Duration; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index 48ae62906..179a7a757 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -43,7 +43,11 @@ *

The only way to differentiate these two frames is determining whether the stream Id is odd or * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the * client. + * + * @deprecated since 1.1.0-M1 in favor of package-private {@link + * io.rsocket.core.ClientServerInputMultiplexer} */ +@Deprecated public class ClientServerInputMultiplexer implements Closeable { private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); private static final InitializingInterceptorRegistry emptyInterceptorRegistry = diff --git a/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java new file mode 100644 index 000000000..d065f3d71 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.presentation.StandardRepresentation; + +public final class ByteBufRepresentation extends StandardRepresentation { + + @Override + protected String fallbackToStringOf(Object object) { + if (object instanceof ByteBuf) { + try { + String normalBufferString = object.toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 256) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } + } catch (IllegalReferenceCountException e) { + // noops + } + } + + return super.fallbackToStringOf(object); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..141ed4385 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,186 @@ +package io.rsocket.test; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import java.time.Duration; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + final Duration awaitZeroRefCntDuration; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration) { + this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + Assertions.assertThat(tracker) + .allSatisfy( + buf -> + Assertions.assertThat(buf) + .matches( + bb -> { + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!awaitZeroRefCntDuration.isZero()) { + long end = + awaitZeroRefCntDuration.plusNanos(System.nanoTime()).toNanos(); + while (bb.refCnt() != 0) { + if (System.nanoTime() >= end) { + break; + } + parkNanos(100); + } + } + return bb.refCnt() == 0; + }, + "buffer should be released")); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index d71c2ee21..e322ad292 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -16,10 +16,13 @@ package io.rsocket.test; +import static java.util.concurrent.locks.LockSupport.parkNanos; + import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.util.ByteBufPayload; -import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -28,6 +31,9 @@ public class TestRSocket implements RSocket { private final String data; private final String metadata; + private final AtomicLong observedInteractions = new AtomicLong(); + private final AtomicLong activeInteractions = new AtomicLong(); + public TestRSocket(String data, String metadata) { this.data = data; this.metadata = metadata; @@ -35,30 +41,69 @@ public TestRSocket(String data, String metadata) { @Override public Mono requestResponse(Payload payload) { + activeInteractions.getAndIncrement(); payload.release(); - return Mono.just(ByteBufPayload.create(data, metadata)); + observedInteractions.getAndIncrement(); + return Mono.just(ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Flux requestStream(Payload payload) { + activeInteractions.getAndIncrement(); payload.release(); - return Flux.range(1, 10_000).flatMap(l -> requestResponse(EmptyPayload.INSTANCE)); + observedInteractions.getAndIncrement(); + return Flux.range(1, 10_000) + .map(l -> ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Mono metadataPush(Payload payload) { + activeInteractions.getAndIncrement(); payload.release(); - return Mono.empty(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Mono fireAndForget(Payload payload) { + activeInteractions.getAndIncrement(); payload.release(); - return Mono.empty(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads); + activeInteractions.getAndIncrement(); + observedInteractions.getAndIncrement(); + return Flux.from(payloads).doFinally(__ -> activeInteractions.getAndDecrement()); + } + + public boolean awaitAllInteractionTermination(Duration duration) { + long end = duration.plusNanos(System.nanoTime()).toNanos(); + long activeNow; + while ((activeNow = activeInteractions.get()) > 0) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return activeNow == 0; + } + + public boolean awaitUntilObserved(int interactions, Duration duration) { + long end = duration.plusNanos(System.nanoTime()).toNanos(); + long observed; + while ((observed = observedInteractions.get()) < interactions) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return observed >= interactions; } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index d30d64100..870038691 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -16,21 +16,30 @@ package io.rsocket.test; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.core.RSocketConnector; import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import java.io.BufferedReader; import java.io.InputStreamReader; import java.time.Duration; import java.util.concurrent.CancellationException; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; -import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; @@ -39,15 +48,25 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.junit.platform.commons.logging.Logger; +import org.junit.platform.commons.logging.LoggerFactory; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; import reactor.core.Disposable; +import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; public interface TransportTest { + Logger logger = LoggerFactory.getLogger(TransportTest.class); + String MOCK_DATA = "test-data"; String MOCK_METADATA = "metadata"; String LARGE_DATA = read("words.shakespeare.txt.gz"); @@ -74,7 +93,9 @@ default void setUp() { @AfterEach default void close() { + getTransportPair().responder.awaitAllInteractionTermination(getTimeout()); getTransportPair().dispose(); + getTransportPair().byteBufAllocator.assertHasNoLeaks(); Hooks.resetOnOperatorDebug(); } @@ -94,7 +115,7 @@ default Payload createTestPayload(int metadataPresent) { } String metadata = metadata1; - return DefaultPayload.create(MOCK_DATA, metadata); + return ByteBufPayload.create(MOCK_DATA, metadata); } @DisplayName("makes 10 fireAndForget requests") @@ -103,9 +124,10 @@ default void fireAndForget10() { Flux.range(1, 10) .flatMap(i -> getClient().fireAndForget(createTestPayload(i))) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } @DisplayName("makes 10 fireAndForget with Large Payload in Requests") @@ -114,9 +136,10 @@ default void largePayloadFireAndForget10() { Flux.range(1, 10) .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD)) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } default RSocket getClient() { @@ -131,22 +154,24 @@ default RSocket getClient() { @Test default void metadataPush10() { Flux.range(1, 10) - .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", "test-metadata"))) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", "test-metadata"))) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } @DisplayName("makes 10 metadataPush with Large Metadata in requests") @Test default void largePayloadMetadataPush10() { Flux.range(1, 10) - .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", LARGE_DATA))) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", LARGE_DATA))) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } @DisplayName("makes 1 requestChannel request with 0 payloads") @@ -155,7 +180,6 @@ default void requestChannel0() { getClient() .requestChannel(Flux.empty()) .as(StepVerifier::create) - .expectNextCount(0) .expectErrorSatisfies( t -> Assertions.assertThat(t) @@ -169,6 +193,7 @@ default void requestChannel0() { default void requestChannel1() { getClient() .requestChannel(Mono.just(createTestPayload(0))) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(1) .expectComplete() @@ -182,6 +207,7 @@ default void requestChannel200_000() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(200_000) .expectComplete() @@ -195,6 +221,7 @@ default void largePayloadRequestChannel200() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(200) .expectComplete() @@ -209,6 +236,7 @@ default void requestChannel20_000() { getClient() .requestChannel(payloads) .doOnNext(this::assertChannelPayload) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(20_000) .expectComplete() @@ -222,6 +250,7 @@ default void requestChannel2_000_000() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(2_000_000) .expectComplete() @@ -237,6 +266,7 @@ default void requestChannel3() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(publisher -> StepVerifier.create(publisher, 3)) .expectNextCount(3) .expectComplete() @@ -249,16 +279,17 @@ default void requestChannel3() { @Test default void requestChannel512() { Flux payloads = Flux.range(0, 512).map(this::createTestPayload); + final Scheduler scheduler = Schedulers.fromExecutorService(Executors.newFixedThreadPool(13)); Flux.range(0, 1024) - .flatMap( - v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(Schedulers.elastic()), 12) + .flatMap(v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(scheduler), 12) .blockLast(); } default void check(Flux payloads) { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(512) .as("expected 512 items") @@ -272,6 +303,7 @@ default void requestResponse1() { getClient() .requestResponse(createTestPayload(1)) .doOnNext(this::assertPayload) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(1) .expectComplete() @@ -284,6 +316,7 @@ default void requestResponse10() { Flux.range(1, 10) .flatMap( i -> getClient().requestResponse(createTestPayload(i)).doOnNext(v -> assertPayload(v))) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(10) .expectComplete() @@ -294,7 +327,8 @@ default void requestResponse10() { @Test default void requestResponse100() { Flux.range(1, 100) - .flatMap(i -> getClient().requestResponse(createTestPayload(i)).map(Payload::getDataUtf8)) + .flatMap(i -> getClient().requestResponse(createTestPayload(i))) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(100) .expectComplete() @@ -305,7 +339,8 @@ default void requestResponse100() { @Test default void largePayloadRequestResponse100() { Flux.range(1, 100) - .flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD).map(Payload::getDataUtf8)) + .flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD)) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(100) .expectComplete() @@ -316,7 +351,8 @@ default void largePayloadRequestResponse100() { @Test default void requestResponse10_000() { Flux.range(1, 10_000) - .flatMap(i -> getClient().requestResponse(createTestPayload(i)).map(Payload::getDataUtf8)) + .flatMap(i -> getClient().requestResponse(createTestPayload(i))) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(10_000) .expectComplete() @@ -329,6 +365,7 @@ default void requestStream10_000() { getClient() .requestStream(createTestPayload(3)) .doOnNext(this::assertPayload) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(10_000) .expectComplete() @@ -341,6 +378,7 @@ default void requestStream5() { getClient() .requestStream(createTestPayload(3)) .doOnNext(this::assertPayload) + .doOnNext(Payload::release) .take(5) .as(StepVerifier::create) .expectNextCount(5) @@ -354,6 +392,7 @@ default void requestStreamDelayedRequestN() { getClient() .requestStream(createTestPayload(3)) .take(10) + .doOnNext(Payload::release) .as(StepVerifier::create) .thenRequest(5) .expectNextCount(5) @@ -381,24 +420,81 @@ final class TransportPair implements Disposable { private static final String data = "hello world"; private static final String metadata = "metadata"; + private final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT, Duration.ofMinutes(1)); + + private final TestRSocket responder; + private final RSocket client; private final S server; public TransportPair( Supplier addressSupplier, - BiFunction clientTransportSupplier, - Function> serverTransportSupplier) { + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier) { T address = addressSupplier.get(); + final boolean runClientWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + final boolean runServerWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + + ByteBufAllocator allocatorToSupply; + if (ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.ADVANCED + || ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.PARANOID) { + logger.info(() -> "Using LeakTrackingByteBufAllocator"); + allocatorToSupply = byteBufAllocator; + } else { + allocatorToSupply = ByteBufAllocator.DEFAULT; + } + responder = new TestRSocket(TransportPair.data, metadata); server = - RSocketServer.create((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .bind(serverTransportSupplier.apply(address)) + RSocketServer.create((setup, sendingSocket) -> Mono.just(responder)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors( + registry -> { + if (runServerWithAsyncInterceptors) { + logger.info( + () -> + "Perform Integration Test with Async Interceptors Enabled For Server"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection)) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + }) + .bind(serverTransportSupplier.apply(address, allocatorToSupply)) .block(); client = - RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMillis(Integer.MAX_VALUE), Duration.ofMillis(Integer.MAX_VALUE)) + .interceptors( + registry -> { + if (runClientWithAsyncInterceptors) { + logger.info( + () -> + "Perform Integration Test with Async Interceptors Enabled For Client"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection)) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + }) + .connect(clientTransportSupplier.apply(address, server, allocatorToSupply)) .doOnError(Throwable::printStackTrace) .block(); } @@ -406,6 +502,7 @@ public TransportPair( @Override public void dispose() { server.dispose(); + client.dispose(); } RSocket getClient() { @@ -419,5 +516,118 @@ public String expectedPayloadData() { public String expectedPayloadMetadata() { return metadata; } + + private static class AsyncDuplexConnection implements DuplexConnection { + + private final DuplexConnection duplexConnection; + + public AsyncDuplexConnection(DuplexConnection duplexConnection) { + this.duplexConnection = duplexConnection; + } + + @Override + public Mono send(Publisher frames) { + return duplexConnection.send(frames); + } + + @Override + public Flux receive() { + return duplexConnection + .receive() + .subscribeOn(Schedulers.parallel()) + .doOnNext(ByteBuf::retain) + .publishOn(Schedulers.parallel(), Integer.MAX_VALUE) + .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::safeRelease) + .transform( + Operators.lift( + (__, actual) -> new ByteBufReleaserOperator(actual))); + } + + @Override + public ByteBufAllocator alloc() { + return duplexConnection.alloc(); + } + + @Override + public Mono onClose() { + return duplexConnection.onClose(); + } + + @Override + public void dispose() { + duplexConnection.dispose(); + } + } + + private static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + + Subscription s; + + public ByteBufReleaserOperator(CoreSubscriber actual) { + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + actual.onNext(buf); + buf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java new file mode 100644 index 000000000..87a1d4dbf --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java @@ -0,0 +1,6 @@ +package io.rsocket.test; + +@FunctionalInterface +public interface TriFunction { + R apply(T1 t1, T2 t2, T3 t3); +} diff --git a/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation new file mode 100644 index 000000000..0c33b5ff7 --- /dev/null +++ b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation @@ -0,0 +1,16 @@ +# +# Copyright 2015-2018 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +io.rsocket.test.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java index ffc9ccb3a..e9c137255 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java @@ -25,8 +25,8 @@ final class LocalTransportTest implements TransportTest { private final TransportPair transportPair = new TransportPair<>( () -> "test-" + UUID.randomUUID(), - (address, server) -> LocalClientTransport.create(address), - LocalServerTransport::create); + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address)); @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java index 95bebd6aa..85481924a 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -1,5 +1,6 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -17,20 +18,22 @@ public class TcpSecureTransportTest implements TransportTest { private final TransportPair transportPair = new TransportPair<>( () -> new InetSocketAddress("localhost", 0), - (address, server) -> + (address, server, allocator) -> TcpClientTransport.create( TcpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) .remoteAddress(server::address) .secure( ssl -> ssl.sslContext( SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE)))), - address -> { + (address, allocator) -> { try { SelfSignedCertificate ssc = new SelfSignedCertificate(); TcpServer server = TcpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) .bindAddress(() -> address) .secure( ssl -> diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java index 182be1d91..c474f9b0b 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java @@ -16,19 +16,30 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.rsocket.test.TransportTest; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; final class TcpTransportTest implements TransportTest { private final TransportPair transportPair = new TransportPair<>( () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server) -> TcpClientTransport.create(server.address()), - TcpServerTransport::create); + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> + TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator))); @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java index 15f9ae3df..9777c8bfa 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -16,6 +16,7 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -34,9 +35,10 @@ final class WebsocketSecureTransportTest implements TransportTest { private final TransportPair transportPair = new TransportPair<>( () -> new InetSocketAddress("localhost", 0), - (address, server) -> + (address, server, allocator) -> WebsocketClientTransport.create( HttpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) .remoteAddress(server::address) .secure( ssl -> @@ -46,11 +48,12 @@ final class WebsocketSecureTransportTest implements TransportTest { String.format( "https://%s:%d/", server.address().getHostName(), server.address().getPort())), - address -> { + (address, allocator) -> { try { SelfSignedCertificate ssc = new SelfSignedCertificate(); HttpServer server = HttpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) .bindAddress(() -> address) .secure( ssl -> diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java index 10d27daeb..93d7bdb2f 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java @@ -16,19 +16,33 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.rsocket.test.TransportTest; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.server.WebsocketServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; final class WebsocketTransportTest implements TransportTest { private final TransportPair transportPair = new TransportPair<>( () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server) -> WebsocketClientTransport.create(server.address()), - address -> WebsocketServerTransport.create(address.getHostName(), address.getPort())); + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> + WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator))); @Override public Duration getTimeout() {