@@ -61,10 +61,26 @@ def __init__(
6161 self ._sent_connection_init = False
6262 self ._used_all_stream_ids = False
6363 self ._connection_error = False
64- self ._events : typing .Dict [int , h2 .events .Event ] = {}
64+
65+ # Mapping from stream ID to response stream events.
66+ self ._events : typing .Dict [
67+ int ,
68+ typing .Union [
69+ h2 .events .ResponseReceived ,
70+ h2 .events .DataReceived ,
71+ h2 .events .StreamEnded ,
72+ h2 .events .StreamReset ,
73+ ],
74+ ] = {}
75+
76+ # Connection terminated events are stored as state since
77+ # we need to handle them for all streams.
78+ self ._connection_terminated : typing .Optional [
79+ h2 .events .ConnectionTerminated
80+ ] = None
81+
6582 self ._read_exception : typing .Optional [Exception ] = None
6683 self ._write_exception : typing .Optional [Exception ] = None
67- self ._connection_error_event : typing .Optional [h2 .events .Event ] = None
6884
6985 async def handle_async_request (self , request : Request ) -> Response :
7086 if not self .can_handle_request (request .url .origin ):
@@ -111,6 +127,7 @@ async def handle_async_request(self, request: Request) -> Response:
111127 self ._events [stream_id ] = []
112128 except h2 .exceptions .NoAvailableStreamIDError : # pragma: nocover
113129 self ._used_all_stream_ids = True
130+ self ._request_count -= 1
114131 raise ConnectionNotAvailable ()
115132
116133 try :
@@ -152,8 +169,8 @@ async def handle_async_request(self, request: Request) -> Response:
152169 #
153170 # In this case we'll have stored the event, and should raise
154171 # it as a RemoteProtocolError.
155- if self ._connection_error_event :
156- raise RemoteProtocolError (self ._connection_error_event )
172+ if self ._connection_terminated : # pragma: nocover
173+ raise RemoteProtocolError (self ._connection_terminated )
157174 # If h2 raises a protocol error in some other state then we
158175 # must somehow have made a protocol violation.
159176 raise LocalProtocolError (exc ) # pragma: nocover
@@ -292,12 +309,14 @@ async def _receive_response_body(
292309 self ._h2_state .acknowledge_received_data (amount , stream_id )
293310 await self ._write_outgoing_data (request )
294311 yield event .data
295- elif isinstance (event , ( h2 .events .StreamEnded , h2 . events . StreamReset ) ):
312+ elif isinstance (event , h2 .events .StreamEnded ):
296313 break
297314
298315 async def _receive_stream_event (
299316 self , request : Request , stream_id : int
300- ) -> h2 .events .Event :
317+ ) -> typing .Union [
318+ h2 .events .ResponseReceived , h2 .events .DataReceived , h2 .events .StreamEnded
319+ ]:
301320 """
302321 Return the next available event for a given stream ID.
303322
@@ -306,8 +325,7 @@ async def _receive_stream_event(
306325 while not self ._events .get (stream_id ):
307326 await self ._receive_events (request , stream_id )
308327 event = self ._events [stream_id ].pop (0 )
309- # The StreamReset event applies to a single stream.
310- if hasattr (event , "error_code" ):
328+ if isinstance (event , h2 .events .StreamReset ):
311329 raise RemoteProtocolError (event )
312330 return event
313331
@@ -319,8 +337,12 @@ async def _receive_events(
319337 for a given stream ID.
320338 """
321339 async with self ._read_lock :
322- if self ._connection_error_event is not None : # pragma: nocover
323- raise RemoteProtocolError (self ._connection_error_event )
340+ if self ._connection_terminated is not None :
341+ last_stream_id = self ._connection_terminated .last_stream_id
342+ if stream_id and last_stream_id and stream_id > last_stream_id :
343+ self ._request_count -= 1
344+ raise ConnectionNotAvailable ()
345+ raise RemoteProtocolError (self ._connection_terminated )
324346
325347 # This conditional is a bit icky. We don't want to block reading if we've
326348 # actually got an event to return for a given stream. We need to do that
@@ -338,16 +360,20 @@ async def _receive_events(
338360 await self ._receive_remote_settings_change (event )
339361 trace .return_value = event
340362
341- event_stream_id = getattr (event , "stream_id" , 0 )
342-
343- # The ConnectionTerminatedEvent applies to the entire connection,
344- # and should be saved so it can be raised on all streams.
345- if hasattr (event , "error_code" ) and event_stream_id == 0 :
346- self ._connection_error_event = event
347- raise RemoteProtocolError (event )
348-
349- if event_stream_id in self ._events :
350- self ._events [event_stream_id ].append (event )
363+ elif isinstance (
364+ event ,
365+ (
366+ h2 .events .ResponseReceived ,
367+ h2 .events .DataReceived ,
368+ h2 .events .StreamEnded ,
369+ h2 .events .StreamReset ,
370+ ),
371+ ):
372+ if event .stream_id in self ._events :
373+ self ._events [event .stream_id ].append (event )
374+
375+ elif isinstance (event , h2 .events .ConnectionTerminated ):
376+ self ._connection_terminated = event
351377
352378 await self ._write_outgoing_data (request )
353379
@@ -372,7 +398,10 @@ async def _response_closed(self, stream_id: int) -> None:
372398 await self ._max_streams_semaphore .release ()
373399 del self ._events [stream_id ]
374400 async with self ._state_lock :
375- if self ._state == HTTPConnectionState .ACTIVE and not self ._events :
401+ if self ._connection_terminated and not self ._events :
402+ await self .aclose ()
403+
404+ elif self ._state == HTTPConnectionState .ACTIVE and not self ._events :
376405 self ._state = HTTPConnectionState .IDLE
377406 if self ._keepalive_expiry is not None :
378407 now = time .monotonic ()
0 commit comments