Skip to content

Commit 49b936d

Browse files
committed
add boolean option for ASGITransport streaming
1 parent a0b2cc7 commit 49b936d

File tree

2 files changed

+114
-50
lines changed

2 files changed

+114
-50
lines changed

httpx/_transports/asgi.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import typing
45

56
from .._models import Request, Response
@@ -101,11 +102,9 @@ def __init__(
101102
self,
102103
ignore_body: bool,
103104
asgi_generator: typing.AsyncGenerator[_Message, None],
104-
disconnect_request_event: Event,
105105
) -> None:
106106
self._ignore_body = ignore_body
107107
self._asgi_generator = asgi_generator
108-
self._disconnect_request_event = disconnect_request_event
109108

110109
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
111110
more_body = True
@@ -118,13 +117,10 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
118117
more_body = message.get("more_body", False)
119118
if chunk and not self._ignore_body:
120119
yield chunk
121-
if not more_body:
122-
self._disconnect_request_event.set()
123120
finally:
124121
await self.aclose()
125122

126123
async def aclose(self) -> None:
127-
self._disconnect_request_event.set()
128124
await self._asgi_generator.aclose()
129125

130126

@@ -149,6 +145,9 @@ class ASGITransport(AsyncBaseTransport):
149145
such as testing the content of a client 500 response.
150146
* `root_path` - The root path on which the ASGI application should be mounted.
151147
* `client` - A two-tuple indicating the client IP and port of incoming requests.
148+
* `streaming` - Set to `True` to enable streaming of response content. Default to
149+
`False`, as activating this feature means that the ASGI `app` will run in a
150+
sub-task, which has observable side effects for context variables.
152151
```
153152
"""
154153

@@ -158,18 +157,20 @@ def __init__(
158157
raise_app_exceptions: bool = True,
159158
root_path: str = "",
160159
client: tuple[str, int] = ("127.0.0.1", 123),
160+
*,
161+
streaming: bool = False,
161162
) -> None:
162163
self.app = app
163164
self.raise_app_exceptions = raise_app_exceptions
164165
self.root_path = root_path
165166
self.client = client
167+
self.streaming = streaming
166168

167169
async def handle_async_request(
168170
self,
169171
request: Request,
170172
) -> Response:
171-
disconnect_request_event = create_event()
172-
asgi_generator = self._stream_asgi_messages(request, disconnect_request_event)
173+
asgi_generator = self._stream_asgi_messages(request)
173174

174175
async for message in asgi_generator:
175176
if message["type"] == "http.response.start":
@@ -179,15 +180,13 @@ async def handle_async_request(
179180
stream=ASGIResponseStream(
180181
ignore_body=request.method == "HEAD",
181182
asgi_generator=asgi_generator,
182-
disconnect_request_event=disconnect_request_event,
183183
),
184184
)
185185
else:
186-
disconnect_request_event.set()
187186
return Response(status_code=500, headers=[])
188187

189188
async def _stream_asgi_messages(
190-
self, request: Request, disconnect_request_event: Event
189+
self, request: Request
191190
) -> typing.AsyncGenerator[typing.MutableMapping[str, typing.Any]]:
192191
assert isinstance(request.stream, AsyncByteStream)
193192

@@ -211,9 +210,13 @@ async def _stream_asgi_messages(
211210
request_body_chunks = request.stream.__aiter__()
212211
request_complete = False
213212

213+
# Response.
214+
response_complete = create_event()
215+
214216
# ASGI response messages stream
217+
stream_size = 0 if self.streaming else float("inf")
215218
response_message_send_stream, response_message_recv_stream = (
216-
create_memory_object_stream(0)
219+
create_memory_object_stream(stream_size)
217220
)
218221

219222
# ASGI app exception
@@ -225,7 +228,7 @@ async def receive() -> _Message:
225228
nonlocal request_complete
226229

227230
if request_complete:
228-
await disconnect_request_event.wait()
231+
await response_complete.wait()
229232
return {"type": "http.disconnect"}
230233

231234
try:
@@ -235,17 +238,29 @@ async def receive() -> _Message:
235238
return {"type": "http.request", "body": b"", "more_body": False}
236239
return {"type": "http.request", "body": body, "more_body": True}
237240

241+
async def send(message: _Message) -> None:
242+
await response_message_send_stream.send(message)
243+
if message["type"] == "http.response.body" and not message.get(
244+
"more_body", False
245+
):
246+
response_complete.set()
247+
238248
async def run_app() -> None:
239249
nonlocal app_exception
240250
try:
241-
await self.app(scope, receive, response_message_send_stream.send)
251+
await self.app(scope, receive, send)
242252
except Exception as ex:
243253
app_exception = ex
244254
finally:
245255
await response_message_send_stream.aclose()
246256

247-
async with create_task_group() as task_group:
248-
task_group.start_soon(run_app)
257+
async with contextlib.AsyncExitStack() as exit_stack:
258+
exit_stack.callback(response_complete.set)
259+
if self.streaming:
260+
task_group = await exit_stack.enter_async_context(create_task_group())
261+
task_group.start_soon(run_app)
262+
else:
263+
await run_app()
249264

250265
async with response_message_recv_stream:
251266
try:

0 commit comments

Comments
 (0)