Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -156,36 +156,65 @@ public interface OutputChunkConsumer<T> {
}

/**
* An adapter which converts an {@link InputStream} to an {@link Iterator} of {@code T} values
* using the specified {@link Coder}.
* An adapter which converts an {@link InputStream} to a {@link PrefetchableIterator} of {@code T}
* values using the specified {@link Coder}.
*
* <p>Note that this adapter follows the Beam Fn API specification for forcing values that decode
* consuming zero bytes to consuming exactly one byte.
*
* <p>Note that access to the underlying {@link InputStream} is lazy and will only be invoked on
* first access to {@link #next()} or {@link #hasNext()}.
* first access to {@link #next}, {@link #hasNext}, {@link #isReady}, or {@link #prefetch}.
*
* <p>Note that {@link #isReady} and {@link #prefetch} rely on the underlying {@link
* PrefetchableIterator} returning non-empty {@link ByteString}s; otherwise {@link #prefetch} will
* appear to make no progress even though it is actually advancing through the empty pages.
*/
public static class DataStreamDecoder<T> implements Iterator<T> {
public static class DataStreamDecoder<T> implements PrefetchableIterator<T> {

private enum State {
READ_REQUIRED,
HAS_NEXT,
EOF
}

private final Iterator<ByteString> inputByteStrings;
private final PrefetchableIterator<ByteString> inputByteStrings;
private final Inbound inbound;
private final Coder<T> coder;
private State currentState;
private T next;

public DataStreamDecoder(Coder<T> coder, Iterator<ByteString> inputStream) {
public DataStreamDecoder(Coder<T> coder, PrefetchableIterator<ByteString> inputStream) {
this.currentState = State.READ_REQUIRED;
this.coder = coder;
this.inputByteStrings = inputStream;
this.inbound = new Inbound();
}

/**
 * Reports whether the next call to {@link #hasNext} or {@link #next} can be answered from data
 * that is already available.
 *
 * <p>The decoder is trivially ready once an element has been decoded ({@code HAS_NEXT}) or the
 * end of input has been observed ({@code EOF}). While a read is still required, readiness is
 * whatever the wrapped {@code Inbound} stream reports.
 *
 * @return {@code true} if the decoder is ready, {@code false} otherwise
 * @throws RuntimeException wrapping any {@link IOException} raised by the inbound stream
 * @throws IllegalStateException if the decoder is in a state this method does not recognise
 */
@Override
public boolean isReady() {
  switch (currentState) {
    case HAS_NEXT:
    case EOF:
      // Nothing further has to be read to answer the caller.
      return true;
    case READ_REQUIRED:
      // Handled below: the answer depends on the inbound stream.
      break;
    default:
      throw new IllegalStateException(String.format("Unknown state %s", currentState));
  }

  try {
    return inbound.isReady();
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
}

/**
 * Asks the underlying iterator of {@code ByteString}s to prefetch, but only when this decoder is
 * not already ready according to {@link #isReady}.
 */
@Override
public void prefetch() {
  if (isReady()) {
    // Already ready, so there is nothing to request from the underlying iterator.
    return;
  }
  inputByteStrings.prefetch();
}

@Override
public boolean hasNext() {
switch (currentState) {
Expand Down Expand Up @@ -232,8 +261,8 @@ public void remove() {
private static final InputStream EMPTY_STREAM = ByteString.EMPTY.newInput();

/**
* An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the first
* {@link Iterator} on first access of this input stream.
* An input stream which concatenates multiple {@link ByteString}s. Lazily accesses the {@link
* Iterator} on first access of this input stream.
*
* <p>Closing this input stream has no effect.
*/
Expand All @@ -245,6 +274,22 @@ public Inbound() {
this.currentStream = EMPTY_STREAM;
}

/**
 * Reports whether this stream can answer the next read without waiting on the underlying
 * iterator.
 *
 * <p>While the current page is exhausted, this advances to the next page as long as the
 * underlying iterator reports that it is ready. It returns {@code false} as soon as the
 * underlying iterator is not ready, and {@code true} once unread bytes are available or the
 * underlying iterator has no more pages.
 *
 * @return {@code true} if bytes are available or the end of input was reached, {@code false} if
 *     the underlying iterator is not ready
 * @throws IOException if querying the current stream fails
 */
public boolean isReady() throws IOException {
  // A stream created by ByteString#newInput reports, via available(), the full length of the
  // ByteString minus the bytes already read, so a zero result reliably means that the current
  // page has been fully consumed.
  while (true) {
    if (currentStream.available() != 0) {
      return true;
    }
    if (!inputByteStrings.isReady()) {
      return false;
    }
    if (!inputByteStrings.hasNext()) {
      return true;
    }
    currentStream = inputByteStrings.next().newInput();
  }
}

public boolean isEof() throws IOException {
// Note that ByteString#newInput is guaranteed to return the length of the entire ByteString
// minus the number of bytes that have been read so far and can be reliably used to tell
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;

import java.io.IOException;
Expand Down Expand Up @@ -106,7 +107,7 @@ public void testNonEmptyInputStream() throws Exception {
}

@Test
public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception {
public void testNonEmptyInputStreamWithZeroLengthEncoding() throws Exception {
CountingOutputStream countingOutputStream =
new CountingOutputStream(ByteStreams.nullOutputStream());
GlobalWindow.Coder.INSTANCE.encode(GlobalWindow.INSTANCE, countingOutputStream);
Expand All @@ -115,6 +116,55 @@ public void testNonEmptyInputStreamWithZeroLengthCoder() throws Exception {
testDecoderWith(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE, GlobalWindow.INSTANCE);
}

/**
 * Exercises {@code isReady} and {@code prefetch} on a {@code DataStreamDecoder} that reads from
 * three pages: two encoded strings, an empty page, and two more encoded strings. The number of
 * prefetch calls observed by the underlying iterator is checked after each step.
 */
@Test
public void testPrefetch() throws Exception {
List<ByteString> encodings = new ArrayList<>();
{
// First page: "A" followed by "BC".
ByteString.Output encoding = ByteString.newOutput();
StringUtf8Coder.of().encode("A", encoding);
StringUtf8Coder.of().encode("BC", encoding);
encodings.add(encoding.toByteString());
}
// Second page: intentionally empty.
encodings.add(ByteString.EMPTY);
{
// Third page: "DEF" followed by "GHIJ".
ByteString.Output encoding = ByteString.newOutput();
StringUtf8Coder.of().encode("DEF", encoding);
StringUtf8Coder.of().encode("GHIJ", encoding);
encodings.add(encoding.toByteString());
}

PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<ByteString> iterator =
new PrefetchableIteratorsTest.ReadyAfterPrefetchUntilNext<>(encodings.iterator());
PrefetchableIterator<String> decoder =
new DataStreamDecoder<>(StringUtf8Coder.of(), iterator);
// Nothing has been prefetched yet, so the decoder must not report ready.
assertFalse(decoder.isReady());
decoder.prefetch();
assertTrue(decoder.isReady());
assertEquals(1, iterator.getNumPrefetchCalls());

decoder.next();
// Now we will have moved off of the empty byte array that we start with so prefetch will
// do nothing since we are ready
assertTrue(decoder.isReady());
decoder.prefetch();
assertEquals(1, iterator.getNumPrefetchCalls());

decoder.next();
// Now we are at the end of the first ByteString so we expect a prefetch to pass through
assertFalse(decoder.isReady());
decoder.prefetch();
assertEquals(2, iterator.getNumPrefetchCalls());
// We also expect the decoder to not be ready since the next byte string is empty which
// would require us to move to the next page. This typically wouldn't happen in practice
// though because we expect non-empty pages.
assertFalse(decoder.isReady());

// Prefetching will allow us to move to the third ByteString
decoder.prefetch();
assertEquals(3, iterator.getNumPrefetchCalls());
assertTrue(decoder.isReady());
}

private <T> void testDecoderWith(Coder<T> coder, T... expected) throws IOException {
ByteString.Output output = ByteString.newOutput();
for (T value : expected) {
Expand All @@ -131,7 +181,9 @@ private <T> void testDecoderWith(Coder<T> coder, T... expected) throws IOExcepti
}

private <T> void testDecoderWith(Coder<T> coder, T[] expected, List<ByteString> encoded) {
Iterator<T> decoder = new DataStreamDecoder<>(coder, encoded.iterator());
Iterator<T> decoder =
new DataStreamDecoder<>(
coder, PrefetchableIterators.maybePrefetchable(encoded.iterator()));

Object[] actual = Iterators.toArray(decoder, Object.class);
assertArrayEquals(expected, actual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,14 @@ public void testConcat() {
"F");
}

private static class NeverReady implements PrefetchableIterator<String> {
PrefetchableIterator<String> delegate = PrefetchableIterators.fromArray("A", "B");
public static class NeverReady<T> implements PrefetchableIterator<T> {
private final Iterator<T> delegate;
int prefetchCalled;

public NeverReady(Iterator<T> delegate) {
this.delegate = delegate;
}

@Override
public boolean isReady() {
return false;
Expand All @@ -140,74 +144,117 @@ public boolean hasNext() {
}

@Override
public String next() {
public T next() {
return delegate.next();
}

public int getNumPrefetchCalls() {
return prefetchCalled;
}
}

private static class ReadyAfterPrefetch extends NeverReady {
public static class ReadyAfterPrefetch<T> extends NeverReady<T> {

public ReadyAfterPrefetch(Iterator<T> delegate) {
super(delegate);
}

@Override
public boolean isReady() {
return prefetchCalled > 0;
}
}

/**
 * A test iterator that reports ready only after {@link #prefetch} has been called and stops
 * reporting ready as soon as {@link #next} or {@link #hasNext} is invoked, until the next
 * {@link #prefetch} call.
 */
public static class ReadyAfterPrefetchUntilNext<T> extends ReadyAfterPrefetch<T> {
  /** Set by {@link #next} and {@link #hasNext}; cleared by {@link #prefetch}. */
  boolean advancedSincePrefetch;

  public ReadyAfterPrefetchUntilNext(Iterator<T> delegate) {
    super(delegate);
  }

  /** Ready only if nothing was accessed since the last prefetch and the parent is ready. */
  @Override
  public boolean isReady() {
    if (advancedSincePrefetch) {
      return false;
    }
    return super.isReady();
  }

  /** Clears the "advanced" marker and then forwards the prefetch to the parent. */
  @Override
  public void prefetch() {
    this.advancedSincePrefetch = false;
    super.prefetch();
  }

  /** Marks the iterator as advanced before returning the parent's next element. */
  @Override
  public T next() {
    this.advancedSincePrefetch = true;
    return super.next();
  }

  /** Marks the iterator as advanced before delegating to the parent's {@code hasNext}. */
  @Override
  public boolean hasNext() {
    this.advancedSincePrefetch = true;
    return super.hasNext();
  }
}

@Test
public void testConcatIsReadyAdvancesToNextIteratorWhenAble() {
NeverReady readyAfterPrefetch1 = new NeverReady();
ReadyAfterPrefetch readyAfterPrefetch2 = new ReadyAfterPrefetch();
ReadyAfterPrefetch readyAfterPrefetch3 = new ReadyAfterPrefetch();
NeverReady<String> readyAfterPrefetch1 =
new NeverReady<>(PrefetchableIterators.fromArray("A", "B"));
ReadyAfterPrefetch<String> readyAfterPrefetch2 =
new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));
ReadyAfterPrefetch<String> readyAfterPrefetch3 =
new ReadyAfterPrefetch<>(PrefetchableIterators.fromArray("A", "B"));

PrefetchableIterator<String> iterator =
PrefetchableIterators.concat(readyAfterPrefetch1, readyAfterPrefetch2, readyAfterPrefetch3);

// Expect no prefetches yet
assertEquals(0, readyAfterPrefetch1.prefetchCalled);
assertEquals(0, readyAfterPrefetch2.prefetchCalled);
assertEquals(0, readyAfterPrefetch3.prefetchCalled);
assertEquals(0, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());

// We expect to attempt to prefetch for the first time.
iterator.prefetch();
assertEquals(1, readyAfterPrefetch1.prefetchCalled);
assertEquals(0, readyAfterPrefetch2.prefetchCalled);
assertEquals(0, readyAfterPrefetch3.prefetchCalled);
assertEquals(1, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();

// We expect to attempt to prefetch again since we aren't ready.
iterator.prefetch();
assertEquals(2, readyAfterPrefetch1.prefetchCalled);
assertEquals(0, readyAfterPrefetch2.prefetchCalled);
assertEquals(0, readyAfterPrefetch3.prefetchCalled);
assertEquals(2, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();

// The current iterator is done but is never ready so we can't advance to the next one and
// expect another prefetch to go to the current iterator.
iterator.prefetch();
assertEquals(3, readyAfterPrefetch1.prefetchCalled);
assertEquals(0, readyAfterPrefetch2.prefetchCalled);
assertEquals(0, readyAfterPrefetch3.prefetchCalled);
assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();

// Now that we know the last iterator is done and have advanced to the next one we expect
// prefetch to go through
iterator.prefetch();
assertEquals(3, readyAfterPrefetch1.prefetchCalled);
assertEquals(1, readyAfterPrefetch2.prefetchCalled);
assertEquals(0, readyAfterPrefetch3.prefetchCalled);
assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(0, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();

// The last iterator is done so we should be able to prefetch the next one before advancing
iterator.prefetch();
assertEquals(3, readyAfterPrefetch1.prefetchCalled);
assertEquals(1, readyAfterPrefetch2.prefetchCalled);
assertEquals(1, readyAfterPrefetch3.prefetchCalled);
assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();

// The current iterator is ready so no additional prefetch is necessary
iterator.prefetch();
assertEquals(3, readyAfterPrefetch1.prefetchCalled);
assertEquals(1, readyAfterPrefetch2.prefetchCalled);
assertEquals(1, readyAfterPrefetch3.prefetchCalled);
assertEquals(3, readyAfterPrefetch1.getNumPrefetchCalls());
assertEquals(1, readyAfterPrefetch2.getNumPrefetchCalls());
assertEquals(1, readyAfterPrefetch3.getNumPrefetchCalls());
iterator.next();
}

Expand Down
1 change: 1 addition & 0 deletions sdks/java/harness/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ dependencies {
testCompile library.java.mockito_core
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
testCompile project(":runners:core-construction-java")
testCompile project(path: ":sdks:java:fn-execution", configuration: "testRuntime")
shadowTestRuntimeClasspath library.java.slf4j_jdk14
jmhCompile project(path: ":sdks:java:harness", configuration: "shadowTest")
jmhRuntime library.java.slf4j_jdk14
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;

Expand All @@ -49,7 +51,7 @@ public class BagUserState<T> {
private final BeamFnStateClient beamFnStateClient;
private final StateRequest request;
private final Coder<T> valueCoder;
private Iterable<T> oldValues;
private PrefetchableIterable<T> oldValues;
private ArrayList<T> newValues;
private boolean isClosed;

Expand Down Expand Up @@ -80,19 +82,19 @@ public BagUserState(
this.newValues = new ArrayList<>();
}

public Iterable<T> get() {
public PrefetchableIterable<T> get() {
checkState(
!isClosed,
"Bag user state is no longer usable because it is closed for %s",
request.getStateKey());
if (oldValues == null) {
// If we were cleared we should disregard old values.
return Iterables.limit(Collections.unmodifiableList(newValues), newValues.size());
return PrefetchableIterables.limit(Collections.unmodifiableList(newValues), newValues.size());
} else if (newValues.isEmpty()) {
// If we have no new values then just return the old values.
return oldValues;
}
return Iterables.concat(
return PrefetchableIterables.concat(
oldValues, Iterables.limit(Collections.unmodifiableList(newValues), newValues.size()));
}

Expand Down
Loading