diff --git a/runners/google-cloud-dataflow-java/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java index 5bff46c9317f..fa7067e5bcd4 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java @@ -81,7 +81,6 @@ import com.google.cloud.dataflow.sdk.values.TypedPValue; import com.google.common.base.Preconditions; import com.google.common.base.Strings; -import com.google.common.collect.Lists; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -730,12 +729,8 @@ private void addOutput(String name, PValue value, Coder valueCoder) { } private void addDisplayData(String name, DisplayData displayData) { - List> serializedItems = Lists.newArrayList(); - for (DisplayData.Item item : displayData.items()) { - serializedItems.add(MAPPER.convertValue(item, Map.class)); - } - - addList(getProperties(), name, serializedItems); + List> list = MAPPER.convertValue(displayData, List.class); + addList(getProperties(), name, list); } @Override diff --git a/runners/google-cloud-dataflow-java/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java index 1b32b73ddb83..dd1b3c8d3ee9 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java @@ -818,8 +818,8 @@ public void processElement(ProcessContext c) throws Exception { @Override public void populateDisplayData(DisplayData.Builder builder) { builder - .add("foo", "bar") - .add("foo2", DataflowPipelineTranslatorTest.class) + .add("foo", "bar") + .add("foo2", DataflowPipelineTranslatorTest.class) .withLabel("Test Class") .withLinkUrl("http://www.google.com"); } @@ -833,7 +833,7 @@ public void processElement(ProcessContext c) throws Exception { @Override public void populateDisplayData(DisplayData.Builder builder) { - builder.add("foo3", "barge"); + builder.add("foo3", 1234); } }; @@ -876,11 +876,11 @@ public void populateDisplayData(DisplayData.Builder builder) { ); ImmutableList expectedFn2DisplayData = ImmutableList.of( - ImmutableMap.builder() + ImmutableMap.builder() .put("namespace", fn2.getClass().getName()) .put("key", "foo3") - .put("type", "STRING") - .put("value", "barge") + .put("type", "INTEGER") + .put("value", 1234L) .build() ); diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java index d23fc0b797d4..d9098ba8d954 100644 --- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java +++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayData.java @@ -29,6 +29,7 @@ import com.fasterxml.jackson.annotation.JsonGetter; import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonValue; import org.apache.avro.reflect.Nullable; import org.joda.time.Duration; @@ -100,6 +101,7 @@ public static Type inferType(@Nullable Object value) { return Type.tryInferFrom(value); } + @JsonValue public Collection items() { return entries.values(); } @@ -175,6 +177,13 @@ public interface Builder { */ ItemBuilder add(String key, long value); + /** + * Register the given numeric display data if the value is not null. + * + * @see DisplayData.Builder#add(String, long) + */ + ItemBuilder addIfNotNull(String key, @Nullable Long value); + /** * Register the given numeric display data if the value is different than the specified default. * @@ -189,6 +198,13 @@ public interface Builder { */ ItemBuilder add(String key, double value); + /** + * Register the given floating point display data if the value is not null. + * + * @see DisplayData.Builder#add(String, double) + */ + ItemBuilder addIfNotNull(String key, @Nullable Double value); + /** * Register the given floating point display data if the value is different than the specified * default. @@ -204,6 +220,13 @@ public interface Builder { */ ItemBuilder add(String key, boolean value); + /** + * Register the given boolean display data if the value is not null. + * + * @see DisplayData.Builder#add(String, boolean) + */ + ItemBuilder addIfNotNull(String key, @Nullable Boolean value); + /** * Register the given boolean display data if the value is different than the specified default. * @@ -286,6 +309,7 @@ ItemBuilder addIfNotDefault( * transform or component. * * @throws ClassCastException if the value cannot be safely cast to the specified type. + * * @see DisplayData#inferType(Object) */ ItemBuilder add(String key, Type type, Object value); @@ -332,8 +356,8 @@ public static class Item { private final String key; private final String ns; private final Type type; - private final String value; - private final String shortValue; + private final Object value; + private final Object shortValue; private final String label; private final String url; @@ -348,8 +372,8 @@ private Item( String namespace, String key, Type type, - String value, - String shortValue, + Object value, + Object shortValue, String url, String label) { this.ns = namespace; @@ -384,7 +408,7 @@ public Type getType() { * Retrieve the value of the metadata item. */ @JsonGetter("value") - public String getValue() { + public Object getValue() { return value; } @@ -398,7 +422,7 @@ public String getValue() { @JsonGetter("shortValue") @JsonInclude(JsonInclude.Include.NON_NULL) @Nullable - public String getShortValue() { + public Object getShortValue() { return shortValue; } @@ -540,48 +564,65 @@ public enum Type { STRING { @Override FormattedItemValue format(Object value) { - return new FormattedItemValue(value.toString()); + return new FormattedItemValue(checkType(value, String.class, STRING)); } }, INTEGER { @Override FormattedItemValue format(Object value) { - Number number = (Number) value; - return new FormattedItemValue(Long.toString(number.longValue())); + if (value instanceof Integer) { + long l = ((Integer) value).longValue(); + return format(l); + } + + return new FormattedItemValue(checkType(value, Long.class, INTEGER)); } }, FLOAT { @Override FormattedItemValue format(Object value) { - return new FormattedItemValue(Double.toString((Double) value)); + return new FormattedItemValue(checkType(value, Number.class, FLOAT)); } }, BOOLEAN() { @Override FormattedItemValue format(Object value) { - return new FormattedItemValue(Boolean.toString((boolean) value)); + return new FormattedItemValue(checkType(value, Boolean.class, BOOLEAN)); } }, TIMESTAMP() { @Override FormattedItemValue format(Object value) { - return new FormattedItemValue((TIMESTAMP_FORMATTER.print((Instant) value))); + Instant instant = checkType(value, Instant.class, TIMESTAMP); + return new FormattedItemValue((TIMESTAMP_FORMATTER.print(instant))); } }, DURATION { @Override FormattedItemValue format(Object value) { - return new FormattedItemValue(Long.toString(((Duration) value).getMillis())); + Duration duration = checkType(value, Duration.class, DURATION); + return new FormattedItemValue(duration.getMillis()); } }, JAVA_CLASS { @Override FormattedItemValue format(Object value) { - Class clazz = (Class) value; + Class clazz = checkType(value, Class.class, JAVA_CLASS); return new FormattedItemValue(clazz.getName(), clazz.getSimpleName()); } }; + private static T checkType(Object value, Class clazz, DisplayData.Type expectedType) { + if (!clazz.isAssignableFrom(value.getClass())) { + throw new ClassCastException(String.format( + "Value is not valid for DisplayData type %s: %s", expectedType, value)); + } + + @SuppressWarnings("unchecked") // type checked above. + T typedValue = (T) value; + return typedValue; + } + /** * Format the display metadata value into a long string representation, and optionally * a shorter representation for display. @@ -592,7 +633,6 @@ FormattedItemValue format(Object value) { @Nullable private static Type tryInferFrom(@Nullable Object value) { - Type type; if (value instanceof Integer || value instanceof Long) { return INTEGER; } else if (value instanceof Double || value instanceof Float) { @@ -614,23 +654,23 @@ private static Type tryInferFrom(@Nullable Object value) { } static class FormattedItemValue { - private final String shortValue; - private final String longValue; + private final Object shortValue; + private final Object longValue; - private FormattedItemValue(String longValue) { + private FormattedItemValue(Object longValue) { this(longValue, null); } - private FormattedItemValue(String longValue, String shortValue) { + private FormattedItemValue(Object longValue, Object shortValue) { this.longValue = longValue; this.shortValue = shortValue; } - String getLongValue() { + Object getLongValue() { return this.longValue; } - String getShortValue() { + Object getShortValue() { return this.shortValue; } } @@ -700,9 +740,14 @@ public ItemBuilder add(String key, long value) { return addItemIf(true, key, Type.INTEGER, value); } + @Override + public ItemBuilder addIfNotNull(String key, @Nullable Long value) { + return addItemIf(value != null, key, Type.INTEGER, value); + } + @Override public ItemBuilder addIfNotDefault(String key, long value, long defaultValue) { - return addItemIf(value != defaultValue, key, Type.INTEGER, value); + return addItemIf(!Objects.equals(value, defaultValue), key, Type.INTEGER, value); } @Override @@ -710,9 +755,14 @@ public ItemBuilder add(String key, double value) { return addItemIf(true, key, Type.FLOAT, value); } + @Override + public ItemBuilder addIfNotNull(String key, @Nullable Double value) { + return addItemIf(value != null, key, Type.FLOAT, value); + } + @Override public ItemBuilder addIfNotDefault(String key, double value, double defaultValue) { - return addItemIf(value != defaultValue, key, Type.FLOAT, value); + return addItemIf(!Objects.equals(value, defaultValue), key, Type.FLOAT, value); } @Override @@ -720,9 +770,14 @@ public ItemBuilder add(String key, boolean value) { return addItemIf(true, key, Type.BOOLEAN, value); } + @Override + public ItemBuilder addIfNotNull(String key, @Nullable Boolean value) { + return addItemIf(value != null, key, Type.BOOLEAN, value); + } + @Override public ItemBuilder addIfNotDefault(String key, boolean value, boolean defaultValue) { - return addItemIf(value != defaultValue, key, Type.BOOLEAN, value); + return addItemIf(!Objects.equals(value, defaultValue), key, Type.BOOLEAN, value); } @Override diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java index 2832414256ca..7a06ab6456df 100644 --- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java +++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataMatchers.java @@ -301,7 +301,7 @@ protected DisplayData.Type featureValueOf(DisplayData.Item actual) { * value. */ - public static Matcher hasValue(String value) { + public static Matcher hasValue(Object value) { return hasValue(Matchers.is(value)); } @@ -309,12 +309,12 @@ public static Matcher hasValue(String value) { * Creates a matcher that matches if the examined {@link DisplayData.Item} contains a value * matching the specified value matcher. */ - public static Matcher hasValue(Matcher valueMatcher) { - return new FeatureMatcher( + public static Matcher hasValue(Matcher valueMatcher) { + return new FeatureMatcher( valueMatcher, "with value", "value") { @Override - protected String featureValueOf(DisplayData.Item actual) { - return actual.getValue(); + protected T featureValueOf(DisplayData.Item actual) { + return (T) actual.getValue(); } }; } diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java index 9f8d5097efc6..5e102e0834ef 100644 --- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java +++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/transforms/display/DisplayDataTest.java @@ -33,14 +33,21 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.startsWith; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Builder; import com.google.cloud.dataflow.sdk.transforms.display.DisplayData.Item; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; import com.google.common.testing.EqualsTester; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import org.hamcrest.CustomTypeSafeMatcher; import org.hamcrest.FeatureMatcher; @@ -68,6 +75,7 @@ public class DisplayDataTest { @Rule public ExpectedException thrown = ExpectedException.none(); private static final DateTimeFormatter ISO_FORMATTER = ISODateTimeFormat.dateTime(); + private static final ObjectMapper MAPPER = new ObjectMapper(); @Test public void testTypicalUsage() { @@ -206,19 +214,29 @@ public void populateDisplayData(DisplayData.Builder builder) { @Test public void testAddIfNotDefault() { - final int defaultValue = 10; - DisplayData data = DisplayData.from(new HasDisplayData() { @Override public void populateDisplayData(Builder builder) { builder - .addIfNotDefault("isDefault", defaultValue, defaultValue) - .addIfNotDefault("notDefault", defaultValue + 1, defaultValue); + .addIfNotDefault("defaultString", "foo", "foo") + .addIfNotDefault("notDefaultString", "foo", "notFoo") + .addIfNotDefault("defaultInteger", 1, 1) + .addIfNotDefault("notDefaultInteger", 1, 2) + .addIfNotDefault("defaultDouble", 123.4, 123.4) + .addIfNotDefault("notDefaultDouble", 123.4, 234.5) + .addIfNotDefault("defaultBoolean", true, true) + .addIfNotDefault("notDefaultBoolean", true, false) + .addIfNotDefault("defaultInstant", new Instant(0), new Instant(0)) + .addIfNotDefault("notDefaultInstant", new Instant(0), Instant.now()) + .addIfNotDefault("defaultDuration", Duration.ZERO, Duration.ZERO) + .addIfNotDefault("notDefaultDuration", Duration.millis(1234), Duration.ZERO) + .addIfNotDefault("defaultClass", DisplayDataTest.class, DisplayDataTest.class) + .addIfNotDefault("notDefaultClass", DisplayDataTest.class, null); } }); - assertThat(data, not(hasDisplayItem(hasKey("isDefault")))); - assertThat(data, hasDisplayItem("notDefault", defaultValue + 1)); + assertThat(data.items(), hasSize(7)); + assertThat(data.items(), everyItem(hasKey(startsWith("notDefault")))); } @Test @@ -227,13 +245,25 @@ public void testAddIfNotNull() { @Override public void populateDisplayData(Builder builder) { builder - .addIfNotNull("isNull", (Class) null) - .addIfNotNull("notNull", DisplayDataTest.class); + .addIfNotNull("nullString", (String) null) + .addIfNotNull("notNullString", "foo") + .addIfNotNull("nullLong", (Long) null) + .addIfNotNull("notNullLong", 1234L) + .addIfNotNull("nullDouble", (Double) null) + .addIfNotNull("notNullDouble", 123.4) + .addIfNotNull("nullBoolean", (Boolean) null) + .addIfNotNull("notNullBoolean", true) + .addIfNotNull("nullInstant", (Instant) null) + .addIfNotNull("notNullInstant", Instant.now()) + .addIfNotNull("nullDuration", (Duration) null) + .addIfNotNull("notNullDuration", Duration.ZERO) + .addIfNotNull("nullClass", (Class) null) + .addIfNotNull("notNullClass", DisplayDataTest.class); } }); - assertThat(data, not(hasDisplayItem(hasKey("isNull")))); - assertThat(data, hasDisplayItem(hasKey("notNull"))); + assertThat(data.items(), hasSize(7)); + assertThat(data.items(), everyItem(hasKey(startsWith("notNull")))); } @Test @@ -563,15 +593,66 @@ public void testExplicitItemType() { @Override public void populateDisplayData(Builder builder) { builder - .add("integer", DisplayData.Type.INTEGER, 1234) + .add("integer", DisplayData.Type.INTEGER, 1234L) .add("string", DisplayData.Type.STRING, "foobar"); } }); - assertThat(data, hasDisplayItem("integer", 1234)); + assertThat(data, hasDisplayItem("integer", 1234L)); assertThat(data, hasDisplayItem("string", "foobar")); } + @Test + public void testFormatIncompatibleTypes() { + Map invalidPairs = ImmutableMap.builder() + .put(DisplayData.Type.STRING, 1234) + .put(DisplayData.Type.INTEGER, "string value") + .put(DisplayData.Type.FLOAT, "string value") + .put(DisplayData.Type.BOOLEAN, "string value") + .put(DisplayData.Type.TIMESTAMP, "string value") + .put(DisplayData.Type.DURATION, "string value") + .put(DisplayData.Type.JAVA_CLASS, "string value") + .build(); + + for (Map.Entry pair : invalidPairs.entrySet()) { + try { + DisplayData.Type type = pair.getKey(); + Object invalidValue = pair.getValue(); + + type.format(invalidValue); + fail(String.format( + "Expected exception not thrown for invalid %s value: %s", type, invalidValue)); + } catch (ClassCastException e) { + // Expected + } + } + } + + @Test + public void testFormatCompatibleTypes() { + Multimap validPairs = ImmutableMultimap + .builder() + .put(DisplayData.Type.INTEGER, 1234) + .put(DisplayData.Type.INTEGER, 1234L) + .put(DisplayData.Type.FLOAT, 123.4f) + .put(DisplayData.Type.FLOAT, 123.4) + .put(DisplayData.Type.FLOAT, 1234) + .put(DisplayData.Type.FLOAT, 1234L) + .build(); + + for (Map.Entry pair : validPairs.entries()) { + DisplayData.Type type = pair.getKey(); + Object value = pair.getValue(); + + try { + type.format(value); + } catch (ClassCastException e) { + fail(String.format("Failed to format %s for DisplayData.%s", + value.getClass().getSimpleName(), type)); + } + } + } + @Test public void testInvalidExplicitItemType() { HasDisplayData component = new HasDisplayData() { @@ -743,6 +824,7 @@ public void populateDisplayData(Builder builder) { }); } + @Test public void testAcceptsNullOptionalValues() { DisplayData.from( new HasDisplayData() { @@ -750,14 +832,95 @@ public void testAcceptsNullOptionalValues() { public void populateDisplayData(Builder builder) { builder.add("key", "value") .withLabel(null) - .withLinkUrl(null) - .withNamespace(null); + .withLinkUrl(null); } }); // Should not throw } + @Test + public void testJsonSerialization() throws IOException { + final String stringValue = "foobar"; + final int intValue = 1234; + final double floatValue = 123.4; + final boolean boolValue = true; + final int durationMillis = 1234; + + HasDisplayData component = new HasDisplayData() { + @Override + public void populateDisplayData(Builder builder) { + builder + .add("string", stringValue) + .add("long", intValue) + .add("double", floatValue) + .add("boolean", boolValue) + .add("instant", new Instant(0)) + .add("duration", Duration.millis(durationMillis)) + .add("class", DisplayDataTest.class) + .withLinkUrl("http://abc") + .withLabel("baz") + ; + } + }; + DisplayData data = DisplayData.from(component); + + JsonNode json = MAPPER.readTree(MAPPER.writeValueAsBytes(data)); + assertThat(json, hasExpectedJson(component, "STRING", "string", quoted(stringValue))); + assertThat(json, hasExpectedJson(component, "INTEGER", "long", intValue)); + assertThat(json, hasExpectedJson(component, "FLOAT", "double", floatValue)); + assertThat(json, hasExpectedJson(component, "BOOLEAN", "boolean", boolValue)); + assertThat(json, hasExpectedJson(component, "DURATION", "duration", durationMillis)); + assertThat(json, hasExpectedJson( + component, "TIMESTAMP", "instant", quoted("1970-01-01T00:00:00.000Z"))); + assertThat(json, hasExpectedJson( + component, "JAVA_CLASS", "class", quoted(DisplayDataTest.class.getName()), + quoted("DisplayDataTest"), "baz", "http://abc")); + } + + private String quoted(Object obj) { + return String.format("\"%s\"", obj); + } + + private Matcher> hasExpectedJson( + HasDisplayData component, String type, String key, Object value) + throws IOException { + return hasExpectedJson(component, type, key, value, null, null, null); + } + + private Matcher> hasExpectedJson( + HasDisplayData component, + String type, + String key, + Object value, + Object shortValue, + String label, + String linkUrl) throws IOException { + Class nsClass = component.getClass(); + + StringBuilder builder = new StringBuilder(); + builder.append("{"); + builder.append(String.format("\"namespace\":\"%s\",", nsClass.getName())); + builder.append(String.format("\"type\":\"%s\",", type)); + builder.append(String.format("\"key\":\"%s\",", key)); + builder.append(String.format("\"value\":%s", value)); + + if (shortValue != null) { + builder.append(String.format(",\"shortValue\":%s", shortValue)); + } + if (label != null) { + builder.append(String.format(",\"label\":\"%s\"", label)); + } + if (linkUrl != null) { + builder.append(String.format(",\"linkUrl\":\"%s\"", linkUrl)); + } + + builder.append("}"); + + JsonNode jsonNode = MAPPER.readTree(builder.toString()); + return hasItem(jsonNode); + } + private static Matcher hasLabel(Matcher labelMatcher) { return new FeatureMatcher( labelMatcher, "display item with label", "label") { @@ -778,12 +941,14 @@ protected String featureValueOf(DisplayData.Item actual) { }; } - private static Matcher hasShortValue(Matcher valueStringMatcher) { - return new FeatureMatcher( + private static Matcher hasShortValue(Matcher valueStringMatcher) { + return new FeatureMatcher( valueStringMatcher, "display item with short value", "short value") { @Override - protected String featureValueOf(DisplayData.Item actual) { - return actual.getShortValue(); + protected T featureValueOf(DisplayData.Item actual) { + @SuppressWarnings("unchecked") + T shortValue = (T) actual.getShortValue(); + return shortValue; } }; }