Skip to content

Commit 8d31f0b

Browse files
Graceful handling for HTTP/2 GoAway frames. (#733)
* Refactor HTTP/2 stream events * Refactor HTTP/2 stream events * Graceful handling of HTTP/2 GoAway frames * Add to CHANGELOG * Remove unneccessary getattr * Conditional fix * Conditional fix
1 parent aacdbb9 commit 8d31f0b

File tree

7 files changed

+357
-48
lines changed

7 files changed

+357
-48
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
77
## unreleased
88

99
- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699)
10+
- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730)
1011
- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717)
1112

1213
## 0.17.2 (May 23th, 2023)

httpcore/_async/http2.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,26 @@ def __init__(
6161
self._sent_connection_init = False
6262
self._used_all_stream_ids = False
6363
self._connection_error = False
64-
self._events: typing.Dict[int, h2.events.Event] = {}
64+
65+
# Mapping from stream ID to response stream events.
66+
self._events: typing.Dict[
67+
int,
68+
typing.Union[
69+
h2.events.ResponseReceived,
70+
h2.events.DataReceived,
71+
h2.events.StreamEnded,
72+
h2.events.StreamReset,
73+
],
74+
] = {}
75+
76+
# Connection terminated events are stored as state since
77+
# we need to handle them for all streams.
78+
self._connection_terminated: typing.Optional[
79+
h2.events.ConnectionTerminated
80+
] = None
81+
6582
self._read_exception: typing.Optional[Exception] = None
6683
self._write_exception: typing.Optional[Exception] = None
67-
self._connection_error_event: typing.Optional[h2.events.Event] = None
6884

6985
async def handle_async_request(self, request: Request) -> Response:
7086
if not self.can_handle_request(request.url.origin):
@@ -111,6 +127,7 @@ async def handle_async_request(self, request: Request) -> Response:
111127
self._events[stream_id] = []
112128
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
113129
self._used_all_stream_ids = True
130+
self._request_count -= 1
114131
raise ConnectionNotAvailable()
115132

116133
try:
@@ -152,8 +169,8 @@ async def handle_async_request(self, request: Request) -> Response:
152169
#
153170
# In this case we'll have stored the event, and should raise
154171
# it as a RemoteProtocolError.
155-
if self._connection_error_event:
156-
raise RemoteProtocolError(self._connection_error_event)
172+
if self._connection_terminated: # pragma: nocover
173+
raise RemoteProtocolError(self._connection_terminated)
157174
# If h2 raises a protocol error in some other state then we
158175
# must somehow have made a protocol violation.
159176
raise LocalProtocolError(exc) # pragma: nocover
@@ -292,12 +309,14 @@ async def _receive_response_body(
292309
self._h2_state.acknowledge_received_data(amount, stream_id)
293310
await self._write_outgoing_data(request)
294311
yield event.data
295-
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
312+
elif isinstance(event, h2.events.StreamEnded):
296313
break
297314

298315
async def _receive_stream_event(
299316
self, request: Request, stream_id: int
300-
) -> h2.events.Event:
317+
) -> typing.Union[
318+
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
319+
]:
301320
"""
302321
Return the next available event for a given stream ID.
303322
@@ -306,8 +325,7 @@ async def _receive_stream_event(
306325
while not self._events.get(stream_id):
307326
await self._receive_events(request, stream_id)
308327
event = self._events[stream_id].pop(0)
309-
# The StreamReset event applies to a single stream.
310-
if hasattr(event, "error_code"):
328+
if isinstance(event, h2.events.StreamReset):
311329
raise RemoteProtocolError(event)
312330
return event
313331

@@ -319,8 +337,12 @@ async def _receive_events(
319337
for a given stream ID.
320338
"""
321339
async with self._read_lock:
322-
if self._connection_error_event is not None: # pragma: nocover
323-
raise RemoteProtocolError(self._connection_error_event)
340+
if self._connection_terminated is not None:
341+
last_stream_id = self._connection_terminated.last_stream_id
342+
if stream_id and last_stream_id and stream_id > last_stream_id:
343+
self._request_count -= 1
344+
raise ConnectionNotAvailable()
345+
raise RemoteProtocolError(self._connection_terminated)
324346

325347
# This conditional is a bit icky. We don't want to block reading if we've
326348
# actually got an event to return for a given stream. We need to do that
@@ -338,16 +360,20 @@ async def _receive_events(
338360
await self._receive_remote_settings_change(event)
339361
trace.return_value = event
340362

341-
event_stream_id = getattr(event, "stream_id", 0)
342-
343-
# The ConnectionTerminatedEvent applies to the entire connection,
344-
# and should be saved so it can be raised on all streams.
345-
if hasattr(event, "error_code") and event_stream_id == 0:
346-
self._connection_error_event = event
347-
raise RemoteProtocolError(event)
348-
349-
if event_stream_id in self._events:
350-
self._events[event_stream_id].append(event)
363+
elif isinstance(
364+
event,
365+
(
366+
h2.events.ResponseReceived,
367+
h2.events.DataReceived,
368+
h2.events.StreamEnded,
369+
h2.events.StreamReset,
370+
),
371+
):
372+
if event.stream_id in self._events:
373+
self._events[event.stream_id].append(event)
374+
375+
elif isinstance(event, h2.events.ConnectionTerminated):
376+
self._connection_terminated = event
351377

352378
await self._write_outgoing_data(request)
353379

@@ -372,7 +398,10 @@ async def _response_closed(self, stream_id: int) -> None:
372398
await self._max_streams_semaphore.release()
373399
del self._events[stream_id]
374400
async with self._state_lock:
375-
if self._state == HTTPConnectionState.ACTIVE and not self._events:
401+
if self._connection_terminated and not self._events:
402+
await self.aclose()
403+
404+
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
376405
self._state = HTTPConnectionState.IDLE
377406
if self._keepalive_expiry is not None:
378407
now = time.monotonic()

httpcore/_sync/http2.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,26 @@ def __init__(
6161
self._sent_connection_init = False
6262
self._used_all_stream_ids = False
6363
self._connection_error = False
64-
self._events: typing.Dict[int, h2.events.Event] = {}
64+
65+
# Mapping from stream ID to response stream events.
66+
self._events: typing.Dict[
67+
int,
68+
typing.Union[
69+
h2.events.ResponseReceived,
70+
h2.events.DataReceived,
71+
h2.events.StreamEnded,
72+
h2.events.StreamReset,
73+
],
74+
] = {}
75+
76+
# Connection terminated events are stored as state since
77+
# we need to handle them for all streams.
78+
self._connection_terminated: typing.Optional[
79+
h2.events.ConnectionTerminated
80+
] = None
81+
6582
self._read_exception: typing.Optional[Exception] = None
6683
self._write_exception: typing.Optional[Exception] = None
67-
self._connection_error_event: typing.Optional[h2.events.Event] = None
6884

6985
def handle_request(self, request: Request) -> Response:
7086
if not self.can_handle_request(request.url.origin):
@@ -111,6 +127,7 @@ def handle_request(self, request: Request) -> Response:
111127
self._events[stream_id] = []
112128
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
113129
self._used_all_stream_ids = True
130+
self._request_count -= 1
114131
raise ConnectionNotAvailable()
115132

116133
try:
@@ -152,8 +169,8 @@ def handle_request(self, request: Request) -> Response:
152169
#
153170
# In this case we'll have stored the event, and should raise
154171
# it as a RemoteProtocolError.
155-
if self._connection_error_event:
156-
raise RemoteProtocolError(self._connection_error_event)
172+
if self._connection_terminated: # pragma: nocover
173+
raise RemoteProtocolError(self._connection_terminated)
157174
# If h2 raises a protocol error in some other state then we
158175
# must somehow have made a protocol violation.
159176
raise LocalProtocolError(exc) # pragma: nocover
@@ -292,12 +309,14 @@ def _receive_response_body(
292309
self._h2_state.acknowledge_received_data(amount, stream_id)
293310
self._write_outgoing_data(request)
294311
yield event.data
295-
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
312+
elif isinstance(event, h2.events.StreamEnded):
296313
break
297314

298315
def _receive_stream_event(
299316
self, request: Request, stream_id: int
300-
) -> h2.events.Event:
317+
) -> typing.Union[
318+
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
319+
]:
301320
"""
302321
Return the next available event for a given stream ID.
303322
@@ -306,8 +325,7 @@ def _receive_stream_event(
306325
while not self._events.get(stream_id):
307326
self._receive_events(request, stream_id)
308327
event = self._events[stream_id].pop(0)
309-
# The StreamReset event applies to a single stream.
310-
if hasattr(event, "error_code"):
328+
if isinstance(event, h2.events.StreamReset):
311329
raise RemoteProtocolError(event)
312330
return event
313331

@@ -319,8 +337,12 @@ def _receive_events(
319337
for a given stream ID.
320338
"""
321339
with self._read_lock:
322-
if self._connection_error_event is not None: # pragma: nocover
323-
raise RemoteProtocolError(self._connection_error_event)
340+
if self._connection_terminated is not None:
341+
last_stream_id = self._connection_terminated.last_stream_id
342+
if stream_id and last_stream_id and stream_id > last_stream_id:
343+
self._request_count -= 1
344+
raise ConnectionNotAvailable()
345+
raise RemoteProtocolError(self._connection_terminated)
324346

325347
# This conditional is a bit icky. We don't want to block reading if we've
326348
# actually got an event to return for a given stream. We need to do that
@@ -338,16 +360,20 @@ def _receive_events(
338360
self._receive_remote_settings_change(event)
339361
trace.return_value = event
340362

341-
event_stream_id = getattr(event, "stream_id", 0)
342-
343-
# The ConnectionTerminatedEvent applies to the entire connection,
344-
# and should be saved so it can be raised on all streams.
345-
if hasattr(event, "error_code") and event_stream_id == 0:
346-
self._connection_error_event = event
347-
raise RemoteProtocolError(event)
348-
349-
if event_stream_id in self._events:
350-
self._events[event_stream_id].append(event)
363+
elif isinstance(
364+
event,
365+
(
366+
h2.events.ResponseReceived,
367+
h2.events.DataReceived,
368+
h2.events.StreamEnded,
369+
h2.events.StreamReset,
370+
),
371+
):
372+
if event.stream_id in self._events:
373+
self._events[event.stream_id].append(event)
374+
375+
elif isinstance(event, h2.events.ConnectionTerminated):
376+
self._connection_terminated = event
351377

352378
self._write_outgoing_data(request)
353379

@@ -372,7 +398,10 @@ def _response_closed(self, stream_id: int) -> None:
372398
self._max_streams_semaphore.release()
373399
del self._events[stream_id]
374400
with self._state_lock:
375-
if self._state == HTTPConnectionState.ACTIVE and not self._events:
401+
if self._connection_terminated and not self._events:
402+
self.close()
403+
404+
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
376405
self._state = HTTPConnectionState.IDLE
377406
if self._keepalive_expiry is not None:
378407
now = time.monotonic()

0 commit comments

Comments
 (0)