diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java index 511f839b823e..f4ab8bb73e64 100644 --- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java +++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java @@ -156,16 +156,20 @@ public interface OutputChunkConsumer { } /** - * An adapter which converts an {@link InputStream} to an {@link Iterator} of {@code T} values - * using the specified {@link Coder}. + * An adapter which converts an {@link InputStream} to a {@link PrefetchableIterator} of {@code T} + * values using the specified {@link Coder}. * *

Note that this adapter follows the Beam Fn API specification for forcing values that decode * consuming zero bytes to consuming exactly one byte. * *

Note that access to the underlying {@link InputStream} is lazy and will only be invoked on - * first access to {@link #next()} or {@link #hasNext()}. + * first access to {@link #next}, {@link #hasNext}, {@link #isReady}, and {@link #prefetch}. + * + *

Note that {@link #isReady} and {@link #prefetch} rely on non-empty {@link ByteString}s being + * returned via the underlying {@link PrefetchableIterator} otherwise the {@link #prefetch} will + * seemingly make zero progress yet will actually advance through the empty pages. */ - public static class DataStreamDecoder implements Iterator { + public static class DataStreamDecoder implements PrefetchableIterator { private enum State { READ_REQUIRED, @@ -173,19 +177,44 @@ private enum State { EOF } - private final Iterator inputByteStrings; + private final PrefetchableIterator inputByteStrings; private final Inbound inbound; private final Coder coder; private State currentState; private T next; - public DataStreamDecoder(Coder coder, Iterator inputStream) { + public DataStreamDecoder(Coder coder, PrefetchableIterator inputStream) { this.currentState = State.READ_REQUIRED; this.coder = coder; this.inputByteStrings = inputStream; this.inbound = new Inbound(); } + @Override + public boolean isReady() { + switch (currentState) { + case EOF: + return true; + case READ_REQUIRED: + try { + return inbound.isReady(); + } catch (IOException e) { + throw new RuntimeException(e); + } + case HAS_NEXT: + return true; + default: + throw new IllegalStateException(String.format("Unknown state %s", currentState)); + } + } + + @Override + public void prefetch() { + if (!isReady()) { + inputByteStrings.prefetch(); + } + } + @Override public boolean hasNext() { switch (currentState) { @@ -232,8 +261,8 @@ public void remove() { private static final InputStream EMPTY_STREAM = ByteString.EMPTY.newInput(); /** - * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the first - * {@link Iterator} on first access of this input stream. + * An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the {@link + * Iterator} on first access of this input stream. * *

Closing this input stream has no effect. */ @@ -245,6 +274,22 @@ public Inbound() { this.currentStream = EMPTY_STREAM; } + public boolean isReady() throws IOException { + // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString + // minus the number of bytes that have been read so far and can be reliably used to tell + // us whether we are at the end of the stream. + while (currentStream.available() == 0) { + if (!inputByteStrings.isReady()) { + return false; + } + if (!inputByteStrings.hasNext()) { + return true; + } + currentStream = inputByteStrings.next().newInput(); + } + return true; + } + public boolean isEof() throws IOException { // Note that ByteString#newInput is guaranteed to return the length of the entire ByteString // minus the number of bytes that have been read so far and can be reliably used to tell diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java index 9dd5ee496d72..a8b48e844b62 100644 --- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java +++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/DataStreamsTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; import java.io.IOException; @@ -106,7 +107,7 @@ public void testNonEmptyInputStream() throws Exception { } @Test - public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception { + public void testNonEmptyInputStreamWithZeroLengthEncoding() throws Exception { CountingOutputStream countingOutputStream = new CountingOutputStream(ByteStreams.nullOutputStream()); GlobalWindow.Coder.INSTANCE.encode(GlobalWindow.INSTANCE, countingOutputStream); @@ -115,6 +116,55 @@ public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception { testDecoderWith(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, GlobalWindow.INSTANCE); } + @Test + public void testPrefetch() throws Exception { + List encodings = new ArrayList<>(); + { + ByteString.Output encoding = ByteString.newOutput(); + StringUtf8Coder.of().encode("A", encoding); + StringUtf8Coder.of().encode("BC", encoding); + encodings.add(encoding.toByteString()); + } + encodings.add(ByteString.EMPTY); + { + ByteString.Output encoding = ByteString.newOutput(); + StringUtf8Coder.of().encode("DEF", encoding); + StringUtf8Coder.of().encode("GHIJ", encoding); + encodings.add(encoding.toByteString()); + } + + PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext iterator = + new PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<>(encodings.iterator()); + PrefetchableIterator decoder = + new DataStreamDecoder<>(StringUtf8Coder.of(), iterator); + assertFalse(decoder.isReady()); + decoder.prefetch(); + assertTrue(decoder.isReady()); + assertEquals(1, iterator.getNumPrefetchCalls()); + + decoder.next(); + // Now we will have moved off of the empty byte array that we start with so prefetch will + // do nothing since we are ready + assertTrue(decoder.isReady()); + decoder.prefetch(); + assertEquals(1, iterator.getNumPrefetchCalls()); + + decoder.next(); + // Now we are at the end of the first ByteString so we expect a prefetch to pass through + assertFalse(decoder.isReady()); + decoder.prefetch(); + assertEquals(2, iterator.getNumPrefetchCalls()); + // We also expect the decoder to not be ready since the next byte string is empty which + // would require us to move to the next page. This typically wouldn't happen in practice + // though because we expect non empty pages. + assertFalse(decoder.isReady()); + + // Prefetching will allow us to move to the third ByteString + decoder.prefetch(); + assertEquals(3, iterator.getNumPrefetchCalls()); + assertTrue(decoder.isReady()); + } + private void testDecoderWith(Coder coder, T... expected) throws IOException { ByteString.Output output = ByteString.newOutput(); for (T value : expected) { @@ -131,7 +181,9 @@ private void testDecoderWith(Coder coder, T... expected) throws IOExcepti } private void testDecoderWith(Coder coder, T[] expected, List encoded) { - Iterator decoder = new DataStreamDecoder<>(coder, encoded.iterator()); + Iterator decoder = + new DataStreamDecoder<>( + coder, PrefetchableIterators.maybePrefetchable(encoded.iterator())); Object[] actual = Iterators.toArray(decoder, Object.class); assertArrayEquals(expected, actual); diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java index 6131634ac523..9ada1759bad1 100644 --- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java +++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/stream/PrefetchableIteratorsTest.java @@ -120,10 +120,14 @@ public void testConcat() { "F"); } - private static class NeverReady implements PrefetchableIterator { - PrefetchableIterator delegate = PrefetchableIterators.fromArray("A", "B"); + public static class NeverReady implements PrefetchableIterator { + private final Iterator delegate; int prefetchCalled; + public NeverReady(Iterator delegate) { + this.delegate = delegate; + } + @Override public boolean isReady() { return false; @@ -140,74 +144,117 @@ public boolean hasNext() { } @Override - public String next() { + public T next() { return delegate.next(); } + + public int getNumPrefetchCalls() { + return prefetchCalled; + } } - private static class ReadyAfterPrefetch extends NeverReady { + public static class ReadyAfterPrefetch extends NeverReady { + + public ReadyAfterPrefetch(Iterator delegate) { + super(delegate); + } + @Override public boolean isReady() { return prefetchCalled > 0; } } + public static class ReadyAfterPrefetchUntilNext extends ReadyAfterPrefetch { + boolean advancedSincePrefetch; + + public ReadyAfterPrefetchUntilNext(Iterator delegate) { + super(delegate); + } + + @Override + public boolean isReady() { + return !advancedSincePrefetch && super.isReady(); + } + + @Override + public void prefetch() { + advancedSincePrefetch = false; + super.prefetch(); + } + + @Override + public T next() { + advancedSincePrefetch = true; + return super.next(); + } + + @Override + public boolean hasNext() { + advancedSincePrefetch = true; + return super.hasNext(); + } + } + @Test public void testConcatIsReadyAdvancesToNextIteratorWhenAble() { - NeverReady readyAfterPrefetch1 = new NeverReady(); - ReadyAfterPrefetch readyAfterPrefetch2 = new ReadyAfterPrefetch(); - ReadyAfterPrefetch readyAfterPrefetch3 = new ReadyAfterPrefetch(); + NeverReady readyAfterPrefetch1 = + new NeverReady<>(PrefetchableIterators.fromArray("A", "B")); + ReadyAfterPrefetch readyAfterPrefetch2 = + new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B")); + ReadyAfterPrefetch readyAfterPrefetch3 = + new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B")); PrefetchableIterator iterator = PrefetchableIterators.concat(readyAfterPrefetch1, readyAfterPrefetch2, readyAfterPrefetch3); // Expect no prefetches yet - assertEquals(0, readyAfterPrefetch1.prefetchCalled); - assertEquals(0, readyAfterPrefetch2.prefetchCalled); - assertEquals(0, readyAfterPrefetch3.prefetchCalled); + assertEquals(0, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls()); // We expect to attempt to prefetch for the first time. iterator.prefetch(); - assertEquals(1, readyAfterPrefetch1.prefetchCalled); - assertEquals(0, readyAfterPrefetch2.prefetchCalled); - assertEquals(0, readyAfterPrefetch3.prefetchCalled); + assertEquals(1, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls()); iterator.next(); // We expect to attempt to prefetch again since we aren't ready. iterator.prefetch(); - assertEquals(2, readyAfterPrefetch1.prefetchCalled); - assertEquals(0, readyAfterPrefetch2.prefetchCalled); - assertEquals(0, readyAfterPrefetch3.prefetchCalled); + assertEquals(2, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls()); iterator.next(); // The current iterator is done but is never ready so we can't advance to the next one and // expect another prefetch to go to the current iterator. iterator.prefetch(); - assertEquals(3, readyAfterPrefetch1.prefetchCalled); - assertEquals(0, readyAfterPrefetch2.prefetchCalled); - assertEquals(0, readyAfterPrefetch3.prefetchCalled); + assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls()); iterator.next(); // Now that we know the last iterator is done and have advanced to the next one we expect // prefetch to go through iterator.prefetch(); - assertEquals(3, readyAfterPrefetch1.prefetchCalled); - assertEquals(1, readyAfterPrefetch2.prefetchCalled); - assertEquals(0, readyAfterPrefetch3.prefetchCalled); + assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls()); iterator.next(); // The last iterator is done so we should be able to prefetch the next one before advancing iterator.prefetch(); - assertEquals(3, readyAfterPrefetch1.prefetchCalled); - assertEquals(1, readyAfterPrefetch2.prefetchCalled); - assertEquals(1, readyAfterPrefetch3.prefetchCalled); + assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls()); iterator.next(); // The current iterator is ready so no additional prefetch is necessary iterator.prefetch(); - assertEquals(3, readyAfterPrefetch1.prefetchCalled); - assertEquals(1, readyAfterPrefetch2.prefetchCalled); - assertEquals(1, readyAfterPrefetch3.prefetchCalled); + assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls()); + assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls()); + assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls()); iterator.next(); } diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle index 3c859aea8294..6337cd47f308 100644 --- a/sdks/java/harness/build.gradle +++ b/sdks/java/harness/build.gradle @@ -72,6 +72,7 @@ dependencies { testCompile library.java.mockito_core testCompile project(path: ":sdks:java:core", configuration: "shadowTest") testCompile project(":runners:core-construction-java") + testCompile project(path: ":sdks:java:fn-execution", configuration: "testRuntime") shadowTestRuntimeClasspath library.java.slf4j_jdk14 jmhCompile project(path: ":sdks:java:harness", configuration: "shadowTest") jmhRuntime library.java.slf4j_jdk14 diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java index 5ddf0ae5270c..777036ab4c82 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java @@ -26,6 +26,8 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.fn.stream.PrefetchableIterable; +import org.apache.beam.sdk.fn.stream.PrefetchableIterables; import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -49,7 +51,7 @@ public class BagUserState { private final BeamFnStateClient beamFnStateClient; private final StateRequest request; private final Coder valueCoder; - private Iterable oldValues; + private PrefetchableIterable oldValues; private ArrayList newValues; private boolean isClosed; @@ -80,19 +82,19 @@ public BagUserState( this.newValues = new ArrayList<>(); } - public Iterable get() { + public PrefetchableIterable get() { checkState( !isClosed, "Bag user state is no longer usable because it is closed for %s", request.getStateKey()); if (oldValues == null) { // If we were cleared we should disregard old values. - return Iterables.limit(Collections.unmodifiableList(newValues), newValues.size()); + return PrefetchableIterables.limit(Collections.unmodifiableList(newValues), newValues.size()); } else if (newValues.isEmpty()) { // If we have no new values then just return the old values. return oldValues; } - return Iterables.concat( + return PrefetchableIterables.concat( oldValues, Iterables.limit(Collections.unmodifiableList(newValues), newValues.size())); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java index 2f2789ed4566..5a931c5910a2 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java @@ -38,7 +38,6 @@ import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.ReadableState; -import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.StateBinder; import org.apache.beam.sdk.state.StateContext; @@ -264,7 +263,7 @@ public T read() { @Override public ValueState readLater() { - // TODO(BEAM-12802): Support prefetching. + impl.get().iterator().prefetch(); return this; } }; @@ -310,7 +309,7 @@ public Iterable read() { @Override public BagState readLater() { - // TODO(BEAM-12802): Support prefetching. + impl.get().iterator().prefetch(); return this; } @@ -391,6 +390,7 @@ public AccumT mergeAccumulators(Iterable accumulators) { @Override public CombiningState readLater() { + impl.get().iterator().prefetch(); return this; } @@ -412,7 +412,18 @@ public void add(ElementT value) { @Override public ReadableState isEmpty() { - return ReadableStates.immediate(!impl.get().iterator().hasNext()); + return new ReadableState() { + @Override + public @Nullable Boolean read() { + return !impl.get().iterator().hasNext(); + } + + @Override + public ReadableState readLater() { + impl.get().iterator().prefetch(); + return this; + } + }; } @Override diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java index cfc76cf1a726..7828f93ba027 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java @@ -21,6 +21,9 @@ import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; +import java.util.Objects; +import org.apache.beam.sdk.fn.stream.PrefetchableIterable; +import org.apache.beam.sdk.fn.stream.PrefetchableIterator; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.checkerframework.checker.nullness.qual.Nullable; @@ -28,29 +31,41 @@ * Converts an iterator to an iterable lazily loading values from the underlying iterator and * caching them to support reiteration. */ -@SuppressWarnings({ - "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) -}) -class LazyCachingIteratorToIterable implements Iterable { +class LazyCachingIteratorToIterable implements PrefetchableIterable { private final List cachedElements; - private final Iterator iterator; + private final PrefetchableIterator iterator; - public LazyCachingIteratorToIterable(Iterator iterator) { + public LazyCachingIteratorToIterable(PrefetchableIterator iterator) { this.cachedElements = new ArrayList<>(); this.iterator = iterator; } @Override - public Iterator iterator() { + public PrefetchableIterator iterator() { return new CachingIterator(); } /** An {@link Iterator} which adds and fetched values into the cached elements list. */ - private class CachingIterator implements Iterator { + private class CachingIterator implements PrefetchableIterator { private int position = 0; private CachingIterator() {} + @Override + public boolean isReady() { + if (position < cachedElements.size()) { + return true; + } + return iterator.isReady(); + } + + @Override + public void prefetch() { + if (!isReady()) { + iterator.prefetch(); + } + } + @Override public boolean hasNext() { // The order of the short circuit is important below. @@ -76,7 +91,7 @@ public T next() { @Override public int hashCode() { - return iterator.hasNext() ? iterator.next().hashCode() : -1789023489; + return iterator.hasNext() ? Objects.hashCode(iterator.next()) : -1789023489; } @Override diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 22be3060d42a..1026ba590466 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -17,20 +17,21 @@ */ package org.apache.beam.fn.harness.state; -import java.util.Collections; import java.util.Iterator; import java.util.NoSuchElementException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.fn.stream.DataStreams; +import org.apache.beam.sdk.fn.stream.DataStreams.DataStreamDecoder; +import org.apache.beam.sdk.fn.stream.PrefetchableIterable; +import org.apache.beam.sdk.fn.stream.PrefetchableIterator; +import org.apache.beam.sdk.fn.stream.PrefetchableIterators; import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; /** * Adapters which convert a a logical series of chunks using continuation tokens over the Beam Fn @@ -54,7 +55,7 @@ private StateFetchingIterators() {} * only) chunk of a state stream. This state request will be populated with a continuation * token to request further chunks of the stream if required. */ - public static Iterator readAllStartingFrom( + public static PrefetchableIterator readAllStartingFrom( BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) { return new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk); } @@ -74,94 +75,142 @@ public static Iterator readAllStartingFrom( * token to request further chunks of the stream if required. * @param valueCoder A coder for decoding the state stream. */ - public static Iterable readAllAndDecodeStartingFrom( + public static PrefetchableIterable readAllAndDecodeStartingFrom( BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk, Coder valueCoder) { - FirstPageAndRemainder firstPageAndRemainder = - new FirstPageAndRemainder(beamFnStateClient, stateRequestForFirstChunk); - return Iterables.concat( - new LazyCachingIteratorToIterable( - new DataStreams.DataStreamDecoder<>( - valueCoder, new LazySingletonIterator<>(firstPageAndRemainder::firstPage))), - () -> new DataStreams.DataStreamDecoder<>(valueCoder, firstPageAndRemainder.remainder())); - } - - /** A iterable that contains a single element, provided by a Supplier which is invoked lazily. */ - static class LazySingletonIterator implements Iterator { - - private final Supplier supplier; - private boolean hasNext; - - private LazySingletonIterator(Supplier supplier) { - this.supplier = supplier; - hasNext = true; - } - - @Override - public boolean hasNext() { - return hasNext; - } - - @Override - public T next() { - hasNext = false; - return supplier.get(); - } + return new FirstPageAndRemainder<>(beamFnStateClient, stateRequestForFirstChunk, valueCoder); } /** - * An helper class that (lazily) gives the first page of a paginated state request separately from + * A helper class that (lazily) gives the first page of a paginated state request separately from * all the remaining pages. */ - static class FirstPageAndRemainder { + @VisibleForTesting + static class FirstPageAndRemainder implements PrefetchableIterable { private final BeamFnStateClient beamFnStateClient; private final StateRequest stateRequestForFirstChunk; - private ByteString firstPage = null; + private final Coder valueCoder; + private LazyCachingIteratorToIterable firstPage; + private CompletableFuture firstPageResponseFuture; private ByteString continuationToken; - private FirstPageAndRemainder( - BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) { + FirstPageAndRemainder( + BeamFnStateClient beamFnStateClient, + StateRequest stateRequestForFirstChunk, + Coder valueCoder) { this.beamFnStateClient = beamFnStateClient; this.stateRequestForFirstChunk = stateRequestForFirstChunk; + this.valueCoder = valueCoder; } - public ByteString firstPage() { - if (firstPage == null) { - CompletableFuture stateResponseFuture = new CompletableFuture<>(); - beamFnStateClient.handle( - stateRequestForFirstChunk.toBuilder().setGet(stateRequestForFirstChunk.getGet()), - stateResponseFuture); - StateResponse stateResponse; - try { - stateResponse = stateResponseFuture.get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException(e); - } catch (ExecutionException e) { - if (e.getCause() == null) { - throw new IllegalStateException(e); + @Override + public PrefetchableIterator iterator() { + return new PrefetchableIterator() { + PrefetchableIterator delegate; + + private void ensureDelegateExists() { + if (delegate == null) { + // Fetch the first page if necessary + prefetchFirstPage(); + if (firstPage == null) { + StateResponse stateResponse; + try { + stateResponse = firstPageResponseFuture.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(e); + } catch (ExecutionException e) { + if (e.getCause() == null) { + throw new IllegalStateException(e); + } + Throwables.throwIfUnchecked(e.getCause()); + throw new IllegalStateException(e.getCause()); + } + continuationToken = stateResponse.getGet().getContinuationToken(); + firstPage = + new LazyCachingIteratorToIterable<>( + new DataStreamDecoder<>( + valueCoder, + PrefetchableIterators.fromArray(stateResponse.getGet().getData()))); + } + + if (ByteString.EMPTY.equals((continuationToken))) { + delegate = firstPage.iterator(); + } else { + delegate = + PrefetchableIterators.concat( + firstPage.iterator(), + new DataStreamDecoder<>( + valueCoder, + new LazyBlockingStateFetchingIterator( + beamFnStateClient, + stateRequestForFirstChunk + .toBuilder() + .setGet( + StateGetRequest.newBuilder() + .setContinuationToken(continuationToken)) + .build()))); + } } - Throwables.throwIfUnchecked(e.getCause()); - throw new IllegalStateException(e.getCause()); } - continuationToken = stateResponse.getGet().getContinuationToken(); - firstPage = stateResponse.getGet().getData(); - } - return firstPage; + + @Override + public boolean isReady() { + if (delegate == null) { + if (firstPageResponseFuture != null) { + return firstPageResponseFuture.isDone(); + } + return false; + } + return delegate.isReady(); + } + + @Override + public void prefetch() { + if (firstPageResponseFuture == null) { + prefetchFirstPage(); + } else if (delegate != null && !delegate.isReady()) { + delegate.prefetch(); + } + } + + @Override + public boolean hasNext() { + if (delegate == null) { + // Ensure that we prefetch the second page after the first has been accessed. + // Prefetching subsequent pages after the first will be handled by the + // LazyBlockingStateFetchingIterator + ensureDelegateExists(); + boolean rval = delegate.hasNext(); + delegate.prefetch(); + return rval; + } + return delegate.hasNext(); + } + + @Override + public T next() { + if (delegate == null) { + // Ensure that we prefetch the second page after the first has been accessed. + // Prefetching subsequent pages after the first will be handled by the + // LazyBlockingStateFetchingIterator + ensureDelegateExists(); + T rval = delegate.next(); + delegate.prefetch(); + return rval; + } + return delegate.next(); + } + }; } - public Iterator remainder() { - firstPage(); - if (ByteString.EMPTY.equals(continuationToken)) { - return Collections.emptyIterator(); - } else { - return new LazyBlockingStateFetchingIterator( - beamFnStateClient, - stateRequestForFirstChunk - .toBuilder() - .setGet(StateGetRequest.newBuilder().setContinuationToken(continuationToken)) - .build()); + private void prefetchFirstPage() { + if (firstPageResponseFuture == null) { + firstPageResponseFuture = new CompletableFuture<>(); + beamFnStateClient.handle( + stateRequestForFirstChunk.toBuilder().setGet(stateRequestForFirstChunk.getGet()), + firstPageResponseFuture); } } } @@ -169,10 +218,11 @@ public Iterator remainder() { /** * An {@link Iterator} which fetches {@link ByteString} chunks using the State API. * - *

This iterator will only request a chunk on first access. Subsiquently it eagerly pre-fetches - * one future chunks at a time. + *

This iterator will only request a chunk on first access. Subsequently it eagerly pre-fetches + * one future chunk at a time. */ - static class LazyBlockingStateFetchingIterator implements Iterator { + @VisibleForTesting + static class LazyBlockingStateFetchingIterator implements PrefetchableIterator { private enum State { READ_REQUIRED, @@ -195,8 +245,17 @@ private enum State { this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken(); } - private void prefetch() { - if (prefetchedResponse == null && currentState == State.READ_REQUIRED) { + @Override + public boolean isReady() { + if (prefetchedResponse == null) { + return currentState != State.READ_REQUIRED; + } + return prefetchedResponse.isDone(); + } + + @Override + public void prefetch() { + if (currentState == State.READ_REQUIRED && prefetchedResponse == null) { prefetchedResponse = new CompletableFuture<>(); beamFnStateClient.handle( stateRequestForFirstChunk diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java index 7597128dcfad..0914b017e703 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java @@ -25,8 +25,11 @@ import java.util.Iterator; import java.util.NoSuchElementException; +import org.apache.beam.sdk.fn.stream.PrefetchableIterable; +import org.apache.beam.sdk.fn.stream.PrefetchableIterator; +import org.apache.beam.sdk.fn.stream.PrefetchableIterators; +import org.apache.beam.sdk.fn.stream.PrefetchableIteratorsTest.ReadyAfterPrefetch; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -40,7 +43,8 @@ public class LazyCachingIteratorToIterableTest { @Test public void testEmptyIterator() { - Iterable iterable = new LazyCachingIteratorToIterable<>(Iterators.forArray()); + Iterable iterable = + new LazyCachingIteratorToIterable<>(PrefetchableIterators.emptyIterator()); assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class)); // iterate multiple times assertArrayEquals(new Object[0], Iterables.toArray(iterable, Object.class)); @@ -52,7 +56,7 @@ public void testEmptyIterator() { @Test public void testInterleavedIteration() { Iterable iterable = - new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C")); + new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C")); Iterator iterator1 = iterable.iterator(); assertTrue(iterator1.hasNext()); @@ -77,14 +81,45 @@ public void testInterleavedIteration() { @Test public void testEqualsAndHashCode() { - Iterable iterA = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C")); - Iterable iterB = new LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C")); - Iterable iterC = new LazyCachingIteratorToIterable<>(Iterators.forArray()); - Iterable iterD = new LazyCachingIteratorToIterable<>(Iterators.forArray()); + Iterable iterA = + new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C")); + Iterable iterB = + new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray("A", "B", "C")); + Iterable iterC = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray()); + Iterable iterD = new LazyCachingIteratorToIterable<>(PrefetchableIterators.fromArray()); assertEquals(iterA, iterB); assertEquals(iterC, iterD); assertNotEquals(iterA, iterC); assertEquals(iterA.hashCode(), iterB.hashCode()); assertEquals(iterC.hashCode(), iterD.hashCode()); } + + @Test + public void testPrefetch() { + ReadyAfterPrefetch underlying = + new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B", "C")); + PrefetchableIterable iterable = new LazyCachingIteratorToIterable<>(underlying); + PrefetchableIterator iterator1 = iterable.iterator(); + PrefetchableIterator iterator2 = iterable.iterator(); + + // Check that the lazy iterable doesn't do any prefetch/access on instantiation + assertFalse(underlying.isReady()); + assertFalse(iterator1.isReady()); + assertFalse(iterator2.isReady()); + + // Check that if both iterators prefetch there is only one prefetch for the underlying iterator + // iterator. + iterator1.prefetch(); + assertEquals(1, underlying.getNumPrefetchCalls()); + iterator2.prefetch(); + assertEquals(1, underlying.getNumPrefetchCalls()); + + // Check that if that one iterator has advanced, the second doesn't perform any prefetch since + // the element is now cached. + iterator1.next(); + iterator1.next(); + iterator2.next(); + iterator2.prefetch(); + assertEquals(1, underlying.getNumPrefetchCalls()); + } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java index fc729cc5a763..384d2df69219 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java @@ -19,12 +19,16 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import java.util.ArrayList; import java.util.Arrays; -import java.util.Iterator; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import org.apache.beam.fn.harness.state.StateFetchingIterators.FirstPageAndRemainder; import org.apache.beam.fn.harness.state.StateFetchingIterators.LazyBlockingStateFetchingIterator; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; @@ -32,16 +36,56 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.fn.stream.PrefetchableIterable; +import org.apache.beam.sdk.fn.stream.PrefetchableIterator; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; import org.junit.Test; +import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Tests for {@link StateFetchingIterators}. */ +@RunWith(Enclosed.class) public class StateFetchingIteratorsTest { + + private static BeamFnStateClient fakeStateClient( + AtomicInteger callCount, ByteString... expected) { + return (requestBuilder, response) -> { + callCount.incrementAndGet(); + if (expected.length == 0) { + response.complete( + StateResponse.newBuilder() + .setId(requestBuilder.getId()) + .setGet(StateGetResponse.newBuilder()) + .build()); + return; + } + + ByteString continuationToken = requestBuilder.getGet().getContinuationToken(); + + int requestedPosition = 0; // Default position is 0 + if (!ByteString.EMPTY.equals(continuationToken)) { + requestedPosition = Integer.parseInt(continuationToken.toStringUtf8()); + } + + // Compute the new continuation token + ByteString newContinuationToken = ByteString.EMPTY; + if (requestedPosition != expected.length - 1) { + newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1)); + } + response.complete( + StateResponse.newBuilder() + .setId(requestBuilder.getId()) + .setGet( + StateGetResponse.newBuilder() + .setData(expected[requestedPosition]) + .setContinuationToken(newContinuationToken)) + .build()); + }; + } + /** Tests for {@link StateFetchingIterators.LazyBlockingStateFetchingIterator}. */ @RunWith(JUnit4.class) public static class LazyBlockingStateFetchingIteratorTest { @@ -77,49 +121,55 @@ public void testMultiWithEmptyByteStrings() throws Exception { ByteString.EMPTY); } - private BeamFnStateClient fakeStateClient(AtomicInteger callCount, ByteString... expected) { - return (requestBuilder, response) -> { - callCount.incrementAndGet(); - if (expected.length == 0) { - response.complete( - StateResponse.newBuilder() - .setId(requestBuilder.getId()) - .setGet(StateGetResponse.newBuilder()) - .build()); - return; - } - - ByteString continuationToken = requestBuilder.getGet().getContinuationToken(); - - int requestedPosition = 0; // Default position is 0 - if (!ByteString.EMPTY.equals(continuationToken)) { - requestedPosition = Integer.parseInt(continuationToken.toStringUtf8()); - } - - // Compute the new continuation token - ByteString newContinuationToken = ByteString.EMPTY; - if (requestedPosition != expected.length - 1) { - newContinuationToken = ByteString.copyFromUtf8(Integer.toString(requestedPosition + 1)); - } - response.complete( - StateResponse.newBuilder() - .setId(requestBuilder.getId()) - .setGet( - StateGetResponse.newBuilder() - .setData(expected[requestedPosition]) - .setContinuationToken(newContinuationToken)) - .build()); - }; + @Test + public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception { + AtomicInteger callCount = new AtomicInteger(); + BeamFnStateClient fakeStateClient = + new BeamFnStateClient() { + @Override + public void handle( + StateRequest.Builder requestBuilder, CompletableFuture response) { + callCount.incrementAndGet(); + } + }; + PrefetchableIterator byteStrings = + new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance()); + assertEquals(0, callCount.get()); + byteStrings.prefetch(); + assertEquals(1, callCount.get()); // first prefetch + byteStrings.prefetch(); + assertEquals(1, callCount.get()); // subsequent is ignored } private void testFetch(ByteString... expected) { AtomicInteger callCount = new AtomicInteger(); BeamFnStateClient fakeStateClient = fakeStateClient(callCount, expected); - Iterator byteStrings = + PrefetchableIterator byteStrings = new LazyBlockingStateFetchingIterator(fakeStateClient, StateRequest.getDefaultInstance()); assertEquals(0, callCount.get()); // Ensure it's fully lazy. - assertArrayEquals(expected, Iterators.toArray(byteStrings, Object.class)); + assertFalse(byteStrings.isReady()); + + // Prefetch every second element in the iterator capturing the results + List results = new ArrayList<>(); + for (int i = 0; i < expected.length; ++i) { + if (i % 2 == 0) { + // Ensure that prefetch performs the call + byteStrings.prefetch(); + assertEquals(i + 1, callCount.get()); + assertTrue(byteStrings.isReady()); + } + assertTrue(byteStrings.hasNext()); + results.add(byteStrings.next()); + } + assertFalse(byteStrings.hasNext()); + assertTrue(byteStrings.isReady()); + + assertEquals(Arrays.asList(expected), results); } + } + + @RunWith(JUnit4.class) + public static class FirstPageAndRemainderTest { @Test public void testEmptyValues() throws Exception { @@ -133,7 +183,7 @@ public void testOneValue() throws Exception { @Test public void testManyValues() throws Exception { - testFetchValues(VarIntCoder.of(), 11, 37, 389, 5077); + testFetchValues(VarIntCoder.of(), 1, 22, 333, 4444, 55555, 666666); } private void testFetchValues(Coder coder, T... expected) { @@ -153,35 +203,42 @@ private void testFetchValues(Coder coder, T... expected) { AtomicInteger callCount = new AtomicInteger(); BeamFnStateClient fakeStateClient = fakeStateClient(callCount, Iterables.toArray(byteStrings, ByteString.class)); - Iterable values = - StateFetchingIterators.readAllAndDecodeStartingFrom( - fakeStateClient, StateRequest.getDefaultInstance(), coder); + PrefetchableIterable values = + new FirstPageAndRemainder<>(fakeStateClient, StateRequest.getDefaultInstance(), coder); // Ensure it's fully lazy. assertEquals(0, callCount.get()); - Iterator valuesIter = values.iterator(); + PrefetchableIterator valuesIter = values.iterator(); + assertFalse(valuesIter.isReady()); assertEquals(0, callCount.get()); - // No more is read than necessary. - if (valuesIter.hasNext()) { - valuesIter.next(); - } + // Ensure that the first page result is cached across multiple iterators and subsequent + // iterators are ready and prefetch does nothing + valuesIter.prefetch(); + assertTrue(valuesIter.isReady()); assertEquals(1, callCount.get()); - // The first page is cached. - Iterator valuesIter2 = values.iterator(); - assertEquals(1, callCount.get()); - if (valuesIter2.hasNext()) { - valuesIter2.next(); - } + PrefetchableIterator valuesIter2 = values.iterator(); + assertTrue(valuesIter2.isReady()); + valuesIter2.prefetch(); assertEquals(1, callCount.get()); - if (valuesIter.hasNext()) { - valuesIter.next(); - // Subsequent pages are pre-fetched, so after accessing the second page, - // the third should be requested. - assertEquals(3, callCount.get()); + // Prefetch every second element in the iterator capturing the results + List results = new ArrayList<>(); + for (int i = 0; i < expected.length; ++i) { + if (i % 2 == 1) { + // Ensure that prefetch performs the call + valuesIter2.prefetch(); + assertTrue(valuesIter2.isReady()); + // Note that this is i+2 because we expect to prefetch the page after the current one + // We also have to bound it to the max number of pages + assertEquals(Math.min(i + 2, expected.length), callCount.get()); + } + assertTrue(valuesIter2.hasNext()); + results.add(valuesIter2.next()); } + assertFalse(valuesIter2.hasNext()); + assertTrue(valuesIter2.isReady()); // The contents agree. assertArrayEquals(expected, Iterables.toArray(values, Object.class));