11from __future__ import annotations
22
3+ import contextlib
34import typing
45
56from .._models import Request , Response
@@ -101,11 +102,9 @@ def __init__(
101102 self ,
102103 ignore_body : bool ,
103104 asgi_generator : typing .AsyncGenerator [_Message , None ],
104- disconnect_request_event : Event ,
105105 ) -> None :
106106 self ._ignore_body = ignore_body
107107 self ._asgi_generator = asgi_generator
108- self ._disconnect_request_event = disconnect_request_event
109108
110109 async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
111110 more_body = True
@@ -118,13 +117,10 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
118117 more_body = message .get ("more_body" , False )
119118 if chunk and not self ._ignore_body :
120119 yield chunk
121- if not more_body :
122- self ._disconnect_request_event .set ()
123120 finally :
124121 await self .aclose ()
125122
126123 async def aclose (self ) -> None :
127- self ._disconnect_request_event .set ()
128124 await self ._asgi_generator .aclose ()
129125
130126
@@ -149,6 +145,9 @@ class ASGITransport(AsyncBaseTransport):
149145 such as testing the content of a client 500 response.
150146 * `root_path` - The root path on which the ASGI application should be mounted.
151147 * `client` - A two-tuple indicating the client IP and port of incoming requests.
148+ * `streaming` - Set to `True` to enable streaming of response content. Default to
149+ `False`, as activating this feature means that the ASGI `app` will run in a
150+ sub-task, which has observable side effects for context variables.
152151 ```
153152 """
154153
@@ -158,18 +157,20 @@ def __init__(
158157 raise_app_exceptions : bool = True ,
159158 root_path : str = "" ,
160159 client : tuple [str , int ] = ("127.0.0.1" , 123 ),
160+ * ,
161+ streaming : bool = False ,
161162 ) -> None :
162163 self .app = app
163164 self .raise_app_exceptions = raise_app_exceptions
164165 self .root_path = root_path
165166 self .client = client
167+ self .streaming = streaming
166168
167169 async def handle_async_request (
168170 self ,
169171 request : Request ,
170172 ) -> Response :
171- disconnect_request_event = create_event ()
172- asgi_generator = self ._stream_asgi_messages (request , disconnect_request_event )
173+ asgi_generator = self ._stream_asgi_messages (request )
173174
174175 async for message in asgi_generator :
175176 if message ["type" ] == "http.response.start" :
@@ -179,15 +180,13 @@ async def handle_async_request(
179180 stream = ASGIResponseStream (
180181 ignore_body = request .method == "HEAD" ,
181182 asgi_generator = asgi_generator ,
182- disconnect_request_event = disconnect_request_event ,
183183 ),
184184 )
185185 else :
186- disconnect_request_event .set ()
187186 return Response (status_code = 500 , headers = [])
188187
189188 async def _stream_asgi_messages (
190- self , request : Request , disconnect_request_event : Event
189+ self , request : Request
191190 ) -> typing .AsyncGenerator [typing .MutableMapping [str , typing .Any ]]:
192191 assert isinstance (request .stream , AsyncByteStream )
193192
@@ -211,9 +210,13 @@ async def _stream_asgi_messages(
211210 request_body_chunks = request .stream .__aiter__ ()
212211 request_complete = False
213212
213+ # Response.
214+ response_complete = create_event ()
215+
214216 # ASGI response messages stream
217+ stream_size = 0 if self .streaming else float ("inf" )
215218 response_message_send_stream , response_message_recv_stream = (
216- create_memory_object_stream (0 )
219+ create_memory_object_stream (stream_size )
217220 )
218221
219222 # ASGI app exception
@@ -225,7 +228,7 @@ async def receive() -> _Message:
225228 nonlocal request_complete
226229
227230 if request_complete :
228- await disconnect_request_event .wait ()
231+ await response_complete .wait ()
229232 return {"type" : "http.disconnect" }
230233
231234 try :
@@ -235,17 +238,29 @@ async def receive() -> _Message:
235238 return {"type" : "http.request" , "body" : b"" , "more_body" : False }
236239 return {"type" : "http.request" , "body" : body , "more_body" : True }
237240
241+ async def send (message : _Message ) -> None :
242+ await response_message_send_stream .send (message )
243+ if message ["type" ] == "http.response.body" and not message .get (
244+ "more_body" , False
245+ ):
246+ response_complete .set ()
247+
238248 async def run_app () -> None :
239249 nonlocal app_exception
240250 try :
241- await self .app (scope , receive , response_message_send_stream . send )
251+ await self .app (scope , receive , send )
242252 except Exception as ex :
243253 app_exception = ex
244254 finally :
245255 await response_message_send_stream .aclose ()
246256
247- async with create_task_group () as task_group :
248- task_group .start_soon (run_app )
257+ async with contextlib .AsyncExitStack () as exit_stack :
258+ exit_stack .callback (response_complete .set )
259+ if self .streaming :
260+ task_group = await exit_stack .enter_async_context (create_task_group ())
261+ task_group .start_soon (run_app )
262+ else :
263+ await run_app ()
249264
250265 async with response_message_recv_stream :
251266 try :
0 commit comments