Skip to content

Commit 1b0f54c

Browse files
committed
Make open stream methods in backends more explicit
1 parent fdefc2a commit 1b0f54c

File tree

2 files changed

+27
-32
lines changed

2 files changed

+27
-32
lines changed

httpx/concurrency/asyncio.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,16 @@ async def open_tcp_stream(
263263
ssl_context: typing.Optional[ssl.SSLContext],
264264
timeout: TimeoutConfig,
265265
) -> SocketStream:
266-
return await self._open_stream(
267-
asyncio.open_connection(hostname, port, ssl=ssl_context), timeout
266+
try:
267+
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
268+
asyncio.open_connection(hostname, port, ssl=ssl_context),
269+
timeout.connect_timeout,
270+
)
271+
except asyncio.TimeoutError:
272+
raise ConnectTimeout()
273+
274+
return SocketStream(
275+
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
268276
)
269277

270278
async def open_uds_stream(
@@ -275,23 +283,13 @@ async def open_uds_stream(
275283
timeout: TimeoutConfig,
276284
) -> SocketStream:
277285
server_hostname = hostname if ssl_context else None
278-
return await self._open_stream(
279-
asyncio.open_unix_connection(
280-
path, ssl=ssl_context, server_hostname=server_hostname
281-
),
282-
timeout,
283-
)
284286

285-
async def _open_stream(
286-
self,
287-
socket_stream: typing.Awaitable[
288-
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]
289-
],
290-
timeout: TimeoutConfig,
291-
) -> SocketStream:
292287
try:
293288
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
294-
socket_stream, timeout.connect_timeout,
289+
asyncio.open_unix_connection(
290+
path, ssl=ssl_context, server_hostname=server_hostname
291+
),
292+
timeout.connect_timeout,
295293
)
296294
except asyncio.TimeoutError:
297295
raise ConnectTimeout()

httpx/concurrency/trio.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -178,33 +178,30 @@ async def open_tcp_stream(
178178
ssl_context: typing.Optional[ssl.SSLContext],
179179
timeout: TimeoutConfig,
180180
) -> SocketStream:
181-
return await self._open_stream(
182-
trio.open_tcp_stream(hostname, port), hostname, ssl_context, timeout
183-
)
181+
connect_timeout = _or_inf(timeout.connect_timeout)
182+
183+
with trio.move_on_after(connect_timeout) as cancel_scope:
184+
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
185+
if ssl_context is not None:
186+
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
187+
await stream.do_handshake()
188+
189+
if cancel_scope.cancelled_caught:
190+
raise ConnectTimeout()
191+
192+
return SocketStream(stream=stream, timeout=timeout)
184193

185194
async def open_uds_stream(
186195
self,
187196
path: str,
188197
hostname: typing.Optional[str],
189198
ssl_context: typing.Optional[ssl.SSLContext],
190199
timeout: TimeoutConfig,
191-
) -> SocketStream:
192-
hostname = hostname if ssl_context else None
193-
return await self._open_stream(
194-
trio.open_unix_socket(path), hostname, ssl_context, timeout
195-
)
196-
197-
async def _open_stream(
198-
self,
199-
socket_stream: typing.Awaitable[trio.SocketStream],
200-
hostname: typing.Optional[str],
201-
ssl_context: typing.Optional[ssl.SSLContext],
202-
timeout: TimeoutConfig,
203200
) -> SocketStream:
204201
connect_timeout = _or_inf(timeout.connect_timeout)
205202

206203
with trio.move_on_after(connect_timeout) as cancel_scope:
207-
stream: trio.SocketStream = await socket_stream
204+
stream: trio.SocketStream = await trio.open_unix_socket(path)
208205
if ssl_context is not None:
209206
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
210207
await stream.do_handshake()

0 commit comments

Comments
 (0)