diff --git a/src/vws/transports.py b/src/vws/transports.py index 80aa989f..74fadb9f 100644 --- a/src/vws/transports.py +++ b/src/vws/transports.py @@ -149,21 +149,22 @@ def __call__( Returns: A Response populated from the httpx response. """ - if isinstance(request_timeout, tuple): - connect_timeout, read_timeout = request_timeout - httpx_timeout = httpx.Timeout( - connect=connect_timeout, - read=read_timeout, - write=None, - pool=None, - ) - else: - httpx_timeout = httpx.Timeout( - connect=request_timeout, - read=request_timeout, - write=None, - pool=None, - ) + match request_timeout: + case tuple() as timeout: + connect_timeout, read_timeout = timeout + httpx_timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=None, + pool=None, + ) + case timeout: + httpx_timeout = httpx.Timeout( + connect=timeout, + read=timeout, + write=None, + pool=None, + ) httpx_response = self._client.request( method=method, @@ -272,21 +273,22 @@ async def __call__( Returns: A Response populated from the httpx response. """ - if isinstance(request_timeout, tuple): - connect_timeout, read_timeout = request_timeout - httpx_timeout = httpx.Timeout( - connect=connect_timeout, - read=read_timeout, - write=None, - pool=None, - ) - else: - httpx_timeout = httpx.Timeout( - connect=request_timeout, - read=request_timeout, - write=None, - pool=None, - ) + match request_timeout: + case tuple() as timeout: + connect_timeout, read_timeout = timeout + httpx_timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=None, + pool=None, + ) + case timeout: + httpx_timeout = httpx.Timeout( + connect=timeout, + read=timeout, + write=None, + pool=None, + ) httpx_response = await self._client.request( method=method, diff --git a/tests/test_transports.py b/tests/test_transports.py index 4130859d..752b0ef2 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -61,6 +61,28 @@ def test_tuple_timeout() -> None: assert isinstance(response, Response) assert response.status_code == HTTPStatus.OK + @staticmethod + @respx.mock + def test_int_timeout() -> None: + """``HTTPXTransport`` works with an int timeout.""" + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + transport = HTTPXTransport() + response = transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + request_timeout=30, + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK + @staticmethod @respx.mock def test_context_manager() -> None: @@ -137,6 +159,29 @@ async def test_tuple_timeout() -> None: assert isinstance(response, Response) assert response.status_code == HTTPStatus.OK + @staticmethod + @pytest.mark.asyncio + @respx.mock + async def test_int_timeout() -> None: + """``AsyncHTTPXTransport`` works with an int timeout.""" + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + transport = AsyncHTTPXTransport() + response = await transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + request_timeout=30, + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK + @staticmethod @pytest.mark.asyncio @respx.mock