diff --git a/tracing-jersey/src/main/java/com/palantir/tracing/jersey/TraceEnrichingFilter.java b/tracing-jersey/src/main/java/com/palantir/tracing/jersey/TraceEnrichingFilter.java index 5ceaa6eb9..5bd63dab1 100644 --- a/tracing-jersey/src/main/java/com/palantir/tracing/jersey/TraceEnrichingFilter.java +++ b/tracing-jersey/src/main/java/com/palantir/tracing/jersey/TraceEnrichingFilter.java @@ -87,6 +87,12 @@ public void filter(ContainerRequestContext requestContext, ContainerResponseCont if (maybeSpan.isPresent()) { Span span = maybeSpan.get(); headers.putSingle(TraceHttpHeaders.TRACE_ID, span.getTraceId()); + } else { + // When the filter is called twice (e.g. an exception is thrown in a streaming call), + // the current trace will be empty. To allow clients to still get the trace ID corresponding to + // the failure, we retrieve it from the requestContext. + Optional.ofNullable(requestContext.getProperty(TRACE_ID_PROPERTY_NAME)) + .ifPresent(s -> headers.putSingle(TraceHttpHeaders.TRACE_ID, s)); } } diff --git a/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java b/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java index f67495611..ffbc2bd94 100644 --- a/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java +++ b/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java @@ -49,6 +49,7 @@ import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import javax.ws.rs.core.StreamingOutput; import javax.ws.rs.core.UriInfo; import org.glassfish.jersey.client.JerseyClientBuilder; import org.junit.After; @@ -156,6 +157,36 @@ public void testTraceState_withoutRequestHeadersGeneratesValidTraceResponseHeade assertThat(spanCaptor.getValue().getOperation(), is("GET /trace")); } + @Test + public void testTraceState_withoutRequestHeadersGeneratesValidTraceResponseHeadersWhenFailing() { + Response response = target.path("/failing-trace").request().get(); + assertThat(response.getHeaderString(TraceHttpHeaders.TRACE_ID), not(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.PARENT_SPAN_ID), is(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.SPAN_ID), is(nullValue())); + verify(observer).consume(spanCaptor.capture()); + assertThat(spanCaptor.getValue().getOperation(), is("GET /failing-trace")); + } + + @Test + public void testTraceState_withoutRequestHeadersGeneratesValidTraceResponseHeadersWhenStreaming() { + Response response = target.path("/streaming-trace").request().get(); + assertThat(response.getHeaderString(TraceHttpHeaders.TRACE_ID), not(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.PARENT_SPAN_ID), is(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.SPAN_ID), is(nullValue())); + verify(observer).consume(spanCaptor.capture()); + assertThat(spanCaptor.getValue().getOperation(), is("GET /streaming-trace")); + } + + @Test + public void testTraceState_withoutRequestHeadersGeneratesValidTraceResponseHeadersWhenFailingToStream() { + Response response = target.path("/failing-streaming-trace").request().get(); + assertThat(response.getHeaderString(TraceHttpHeaders.TRACE_ID), not(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.PARENT_SPAN_ID), is(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.SPAN_ID), is(nullValue())); + verify(observer).consume(spanCaptor.capture()); + assertThat(spanCaptor.getValue().getOperation(), is("GET /failing-streaming-trace")); + } + @Test public void testTraceState_withSamplingHeaderWithoutTraceIdDoesNotUseTraceSampler() { target.path("/trace").request() @@ -216,13 +247,33 @@ public final void run(Configuration config, final Environment env) throws Except public static final class TracingTestResource implements TracingTestService { @Override - public void getTraceOperation() {} + public void getTraceOperation() { + throw new RuntimeException("FAIL"); + } @Override public void postTraceOperation() {} @Override public void getTraceWithPathParam() {} + + @Override + public void getFailingTraceOperation() { + throw new RuntimeException(); + } + + @Override + public StreamingOutput getFailingStreamingTraceOperation() { + return os -> { + throw new RuntimeException(); + }; + } + + @Override + public StreamingOutput getStreamingTraceOperation() { + return os -> { + }; + } } @Path("/") @@ -240,5 +291,19 @@ public interface TracingTestService { @GET @Path("/trace/{param}") void getTraceWithPathParam(); + + @GET + @Path("/failing-trace") + void getFailingTraceOperation(); + + @GET + @Path("/failing-streaming-trace") + @Produces(MediaType.APPLICATION_OCTET_STREAM) + StreamingOutput getFailingStreamingTraceOperation(); + + @GET + @Path("/streaming-trace") + @Produces(MediaType.APPLICATION_OCTET_STREAM) + StreamingOutput getStreamingTraceOperation(); } }