| from __future__ import annotations |
|
|
| import hashlib |
| import os |
| import re |
| import time |
| import typing |
| from base64 import b64encode |
| from urllib.request import parse_http_list |
|
|
| from ._exceptions import ProtocolError |
| from ._models import Cookies, Request, Response |
| from ._utils import to_bytes, to_str, unquote |
|
|
| if typing.TYPE_CHECKING: |
| from hashlib import _Hash |
|
|
|
|
# Names exported by `from <module> import *`. FunctionAuth is not listed.
__all__ = [
    "Auth",
    "BasicAuth",
    "DigestAuth",
    "NetRCAuth",
]
|
|
|
|
class Auth:
    """
    Base class for all authentication schemes.

    A custom scheme is implemented by subclassing `Auth` and overriding
    `.auth_flow()`.

    Schemes that perform I/O (disk access, network calls) or that rely on
    synchronization primitives such as locks should override
    `.sync_auth_flow()` and/or `.async_auth_flow()` instead of `.auth_flow()`.
    Those specialized implementations are used by `Client` and `AsyncClient`
    respectively.
    """

    # When true, the request body is read before the flow is started.
    requires_request_body = False
    # When true, each response is read before it is sent back into the flow.
    requires_response_body = False

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        """
        Execute the authentication flow.

        Each request to be dispatched is yielded:

        ```
        yield request
        ```

        The client `.send()`s the resulting response back into the generator,
        so it can be captured like this:

        ```
        response = yield request
        ```

        Returning (or simply reaching the end of the generator) makes the
        client return the last response it received from the server.

        Any number of requests may be dispatched. The default implementation
        dispatches `request` once, unchanged.
        """
        yield request

    def sync_auth_flow(
        self, request: Request
    ) -> typing.Generator[Request, Response, None]:
        """
        Execute the authentication flow synchronously.

        Delegates every step to `.auth_flow()`, reading the request body up
        front and each response body in between when the corresponding
        `requires_*_body` flag is set. Override this when the scheme does
        I/O and/or uses concurrency primitives.
        """
        if self.requires_request_body:
            request.read()

        inner = self.auth_flow(request)
        outgoing = next(inner)

        while True:
            incoming = yield outgoing
            if self.requires_response_body:
                incoming.read()
            try:
                outgoing = inner.send(incoming)
            except StopIteration:
                # The wrapped flow is finished; end this generator too.
                return

    async def async_auth_flow(
        self, request: Request
    ) -> typing.AsyncGenerator[Request, Response]:
        """
        Execute the authentication flow asynchronously.

        Same contract as `.sync_auth_flow()`, but bodies are read with
        `await ....aread()`. Override this when the scheme does I/O and/or
        uses concurrency primitives.
        """
        if self.requires_request_body:
            await request.aread()

        inner = self.auth_flow(request)
        outgoing = next(inner)

        while True:
            incoming = yield outgoing
            if self.requires_response_body:
                await incoming.aread()
            try:
                outgoing = inner.send(incoming)
            except StopIteration:
                # The wrapped (synchronous) flow is finished.
                return
|
|
|
|
class FunctionAuth(Auth):
    """
    Adapter that lets the 'auth' argument be a plain callable.

    The callable receives the request and returns the (possibly modified)
    request that should actually be dispatched.
    """

    def __init__(self, func: typing.Callable[[Request], Request]) -> None:
        self._func = func

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        # Dispatch whatever the user-supplied callable hands back.
        prepared = self._func(request)
        yield prepared
|
|
|
|
class BasicAuth(Auth):
    """
    HTTP Basic authentication from a (username, password) pair.

    The `Authorization` header value is computed once, at construction time,
    and attached to every request passing through the flow.
    """

    def __init__(self, username: str | bytes, password: str | bytes) -> None:
        self._auth_header = self._build_auth_header(username, password)

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        headers = request.headers
        headers["Authorization"] = self._auth_header
        yield request

    def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
        """Return `Basic <base64(username:password)>` for the given credentials."""
        credentials = to_bytes(username) + b":" + to_bytes(password)
        encoded = b64encode(credentials).decode()
        return "Basic " + encoded
|
|
|
|
class NetRCAuth(Auth):
    """
    Look up Basic auth credentials for the request's host in a netrc file.

    Requests whose host has no netrc entry, or whose entry has an empty
    password, are dispatched without an `Authorization` header.
    """

    def __init__(self, file: str | None = None) -> None:
        # `netrc` is imported here rather than at module level, so it is only
        # loaded when this class is instantiated.
        import netrc

        # With file=None the stdlib reads the default netrc file.
        self._netrc_info = netrc.netrc(file)

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        # `authenticators()` yields (login, account, password) or None.
        entry = self._netrc_info.authenticators(request.url.host)
        has_password = entry is not None and bool(entry[2])
        if has_password:
            request.headers["Authorization"] = self._build_auth_header(
                username=entry[0], password=entry[2]
            )
        yield request

    def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
        """Return `Basic <base64(username:password)>` for the given credentials."""
        raw = to_bytes(username) + b":" + to_bytes(password)
        return "Basic " + b64encode(raw).decode()
|
|
|
|
class DigestAuth(Auth):
    """
    HTTP Digest access authentication.

    Supports the MD5, SHA (SHA-1), SHA-256 and SHA-512 algorithms and their
    ``-sess`` variants with ``qop="auth"``. When the challenge carries no
    ``qop`` it falls back to the RFC 2069 response computation.
    ``qop="auth-int"`` is not implemented.

    The request is first sent without credentials, unless a challenge cached
    from a previous exchange lets it be pre-authenticated. If the server then
    answers ``401`` with a ``WWW-Authenticate: Digest ...`` header, the
    challenge is parsed and cached. The request is then re-sent once with a
    computed ``Authorization`` header.
    """

    # Maps the upper-cased ``algorithm`` challenge parameter to the hash
    # constructor used for every digest. ``-SESS`` variants use the same hash;
    # they additionally fold the server and client nonces into HA1.
    _ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
        "MD5": hashlib.md5,
        "MD5-SESS": hashlib.md5,
        "SHA": hashlib.sha1,
        "SHA-SESS": hashlib.sha1,
        "SHA-256": hashlib.sha256,
        "SHA-256-SESS": hashlib.sha256,
        "SHA-512": hashlib.sha512,
        "SHA-512-SESS": hashlib.sha512,
    }

    def __init__(self, username: str | bytes, password: str | bytes) -> None:
        self._username = to_bytes(username)
        self._password = to_bytes(password)
        # Most recently accepted challenge. It is reused to pre-authenticate
        # subsequent requests made through this instance.
        self._last_challenge: _DigestAuthChallenge | None = None
        # ``nc`` value for the next Authorization header built.
        self._nonce_count = 1

    def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
        if self._last_challenge is not None:
            request.headers["Authorization"] = self._build_auth_header(
                request, self._last_challenge
            )

        response = yield request

        if response.status_code != 401 or "www-authenticate" not in response.headers:
            # Not challenged: nothing more to do, the client returns `response`.
            return

        for auth_header in response.headers.get_list("www-authenticate"):
            if auth_header.lower().startswith("digest "):
                break
        else:
            # Challenged, but not with a Digest scheme we can answer.
            return

        # Parse (and validate) before touching any state. A rejected challenge
        # then leaves the previously cached challenge and nonce count intact.
        self._last_challenge = self._parse_challenge(request, response, auth_header)
        self._nonce_count = 1

        request.headers["Authorization"] = self._build_auth_header(
            request, self._last_challenge
        )
        if response.cookies:
            # Carry cookies set by the 401 response over to the retried request.
            Cookies(response.cookies).set_cookie_header(request=request)
        yield request

    def _parse_challenge(
        self, request: Request, response: Response, auth_header: str
    ) -> _DigestAuthChallenge:
        """
        Returns a challenge from a Digest WWW-Authenticate header.
        These take the form of:
        `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`

        Parameter names are matched case-insensitively. Optional whitespace
        around ``=`` is ignored, and empty list elements (e.g. from a
        trailing comma) are skipped.

        Raises:
            ProtocolError: if a parameter has no ``=``, if ``realm`` or
                ``nonce`` is missing, or if ``algorithm`` names a hash this
                class does not support.
        """
        malformed = "Malformed Digest WWW-Authenticate header"
        scheme, _, fields = auth_header.partition(" ")

        # Internal invariant: auth_flow only passes headers starting with "digest ".
        assert scheme.lower() == "digest"

        header_dict: dict[str, str] = {}
        for field in parse_http_list(fields):
            field = field.strip()
            if not field:
                continue
            key, sep, value = field.partition("=")
            if not sep:
                raise ProtocolError(malformed, request=request)
            # RFC 7235 section 2.1: auth-param names are case-insensitive, and
            # optional whitespace may surround the "=".
            value = value.strip()
            header_dict[key.strip().lower()] = unquote(value) if value else value

        try:
            realm = header_dict["realm"].encode()
            nonce = header_dict["nonce"].encode()
        except KeyError as exc:
            raise ProtocolError(malformed, request=request) from exc

        algorithm = header_dict.get("algorithm", "MD5")
        if algorithm.upper() not in self._ALGORITHM_TO_HASH_FUNCTION:
            # Reject here, before auth_flow caches the challenge. Otherwise every
            # later request would fail while building its Authorization header.
            message = f'Unsupported Digest algorithm "{algorithm}"'
            raise ProtocolError(message, request=request)

        opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
        qop = header_dict["qop"].encode() if "qop" in header_dict else None
        return _DigestAuthChallenge(
            realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop
        )

    def _build_auth_header(
        self, request: Request, challenge: _DigestAuthChallenge
    ) -> str:
        """
        Compute the ``Authorization`` header value answering ``challenge``.

        Every call consumes one nonce count: ``nc`` is taken from
        ``self._nonce_count``, which is then incremented.

        Raises:
            NotImplementedError: if the challenge only offers ``qop=auth-int``.
            ProtocolError: if the challenge's ``qop`` offers neither ``auth``
                nor ``auth-int``.
        """
        hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]

        def digest(data: bytes) -> bytes:
            # Hex digest as bytes, so it can be joined with the other fields.
            return hash_func(data).hexdigest().encode()

        A1 = b":".join((self._username, challenge.realm, self._password))

        path = request.url.raw_path
        # Only qop "auth" is supported, so A2 never includes an entity-body
        # hash (that would be needed for "auth-int").
        A2 = b":".join((request.method.encode(), path))
        HA2 = digest(A2)

        nc_value = b"%08x" % self._nonce_count
        cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
        self._nonce_count += 1

        HA1 = digest(A1)
        if challenge.algorithm.lower().endswith("-sess"):
            # Session variants bind HA1 to this server nonce and client nonce.
            HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))

        qop = self._resolve_qop(challenge.qop, request=request)
        if qop is None:
            # RFC 2069 form: no nonce count or client nonce.
            digest_data = [HA1, challenge.nonce, HA2]
        else:
            # RFC 2617 / RFC 7616 form.
            digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]

        format_args = {
            "username": self._username,
            "realm": challenge.realm,
            "nonce": challenge.nonce,
            "uri": path,
            "response": digest(b":".join(digest_data)),
            "algorithm": challenge.algorithm.encode(),
        }
        if challenge.opaque:
            format_args["opaque"] = challenge.opaque
        if qop:
            format_args["qop"] = b"auth"
            format_args["nc"] = nc_value
            format_args["cnonce"] = cnonce

        return "Digest " + self._get_header_value(format_args)

    def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
        """
        Return a 16-hex-character client nonce.

        It is derived from the nonce count, the server nonce, the current time
        and 8 bytes from ``os.urandom``.
        """
        s = str(nonce_count).encode()
        s += nonce
        s += time.ctime().encode()
        s += os.urandom(8)

        return hashlib.sha1(s).hexdigest()[:16].encode()

    def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
        """
        Render ``header_fields`` as a comma-separated ``key=value`` list.

        Values are double-quoted, except for ``algorithm``, ``qop`` and ``nc``.
        """
        NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")

        rendered = []
        for field, value in header_fields.items():
            text = to_str(value)
            if field in NON_QUOTED_FIELDS:
                rendered.append(f"{field}={text}")
            else:
                rendered.append(f'{field}="{text}"')
        return ", ".join(rendered)

    def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
        """
        Pick the quality of protection to use from the challenge's ``qop`` list.

        Returns ``None`` when the challenge had no ``qop``, and ``b"auth"``
        when ``auth`` is offered.

        Raises:
            NotImplementedError: if ``auth-int`` is the only option offered.
            ProtocolError: if no known option is offered.
        """
        if qop is None:
            return None
        # The qop list is comma-separated; tolerate whitespace on either side
        # of each comma (e.g. "auth ,auth-int").
        qops = re.split(rb"\s*,\s*", qop.strip())
        if b"auth" in qops:
            return b"auth"

        if qops == [b"auth-int"]:
            raise NotImplementedError("Digest auth-int support is not yet implemented")

        message = f'Unexpected qop value "{qop!r}" in digest auth'
        raise ProtocolError(message, request=request)
|
|
|
|
class _DigestAuthChallenge(typing.NamedTuple):
    """
    Parameters parsed from a ``WWW-Authenticate: Digest ...`` challenge.
    """

    # ``realm`` parameter, encoded to bytes by DigestAuth._parse_challenge.
    realm: bytes
    # Server-supplied ``nonce``, encoded to bytes.
    nonce: bytes
    # ``algorithm`` name as sent by the server (DigestAuth._parse_challenge
    # substitutes "MD5" when absent); case is not normalised here.
    algorithm: str
    # ``opaque`` parameter, or None when the challenge did not include one.
    opaque: bytes | None
    # Raw ``qop`` parameter (possibly a comma-separated list), or None when absent.
    qop: bytes | None
|
|