From 611b8667a3a9839f0ee2b5e5183fbef0e82ac922 Mon Sep 17 00:00:00 2001 From: "Mike A." Date: Tue, 10 Sep 2024 23:34:26 +0200 Subject: [PATCH] Implement digest auth --- vdirsyncer/http.py | 108 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 91 insertions(+), 17 deletions(-) diff --git a/vdirsyncer/http.py b/vdirsyncer/http.py index 69ad845..2f89082 100644 --- a/vdirsyncer/http.py +++ b/vdirsyncer/http.py @@ -1,9 +1,14 @@ from __future__ import annotations import logging +import re +from abc import ABC, abstractmethod +from base64 import b64encode from ssl import create_default_context import aiohttp +import requests.auth +from requests.utils import parse_dict_header from . import DOCS_HOME from . import __version__ @@ -36,25 +41,77 @@ _detect_faulty_requests() del _detect_faulty_requests +class AuthMethod(ABC): + def __init__(self, username, password): + self.username = username + self.password = password + + @abstractmethod + def handle_401(self, response): + raise NotImplementedError + + @abstractmethod + def get_auth_header(self, method, url): + raise NotImplementedError + + +class BasicAuthMethod(AuthMethod): + def handle_401(self, _response): + pass + + def get_auth_header(self, _method, _url): + auth_str = f"{self.username}:{self.password}" + return "Basic " + b64encode(auth_str.encode('utf-8')).decode("utf-8") + + +class DigestAuthMethod(AuthMethod): + # make class var to 'cache' the state, which is more efficient because otherwise + # each request would first require another 'initialization' request. + _auth_helpers = {} + + def __init__(self, username, password): + super().__init__(username, password) + + self._auth_helper = self._auth_helpers.get( + (username, password), + requests.auth.HTTPDigestAuth(username, password) + ) + self._auth_helpers[(username, password)] = self._auth_helper + + @property + def auth_helper_vars(self): + return self._auth_helper._thread_local + + def handle_401(self, response): + s_auth = response.headers.get("www-authenticate", "") + + if "digest" in s_auth.lower(): + # Original source: + # https://github.com/psf/requests/blob/f12ccbef6d6b95564da8d22e280d28c39d53f0e9/src/requests/auth.py#L262-L263 + pat = re.compile(r"digest ", flags=re.IGNORECASE) + self.auth_helper_vars.chal = parse_dict_header(pat.sub("", s_auth, count=1)) + + def get_auth_header(self, method, url): + self._auth_helper.init_per_thread_state() + + if not self.auth_helper_vars.chal: + # Need to do init request first + return '' + + return self._auth_helper.build_digest_header(method, url) + + def prepare_auth(auth, username, password): if username and password: if auth == "basic" or auth is None: - return aiohttp.BasicAuth(username, password) + return BasicAuthMethod(username, password) elif auth == "digest": - from requests.auth import HTTPDigestAuth - - return HTTPDigestAuth(username, password) + return DigestAuthMethod(username, password) elif auth == "guess": - try: - from requests_toolbelt.auth.guess import GuessAuth - except ImportError: - raise exceptions.UserError( - "Your version of requests_toolbelt is too " - "old for `guess` authentication. At least " - "version 0.4.0 is required." - ) - else: - return GuessAuth(username, password) + raise exceptions.UserError(f"'Guess' authentication is not supported in this version of vdirsyncer. \n" + f"Please explicitly specify either 'basic' or 'digest' auth instead. \n" + f"See the following issue for more information: " + f"https://github.com/pimutils/vdirsyncer/issues/1015") else: raise exceptions.UserError(f"Unknown authentication method: {auth}") elif auth: @@ -97,14 +154,17 @@ async def request( method, url, session, + auth, latin1_fallback=True, **kwargs, ): - """Wrapper method for requests, to ease logging and mocking. + """Wrapper method for requests, to ease logging and mocking as well as to + support auth methods currently unsupported by aiohttp. - Parameters should be the same as for ``aiohttp.request``, as well as: + Parameters should be the same as for ``aiohttp.request``, except: :param session: A requests session object to use. + :param auth: The HTTP ``AuthMethod`` to use for authentication. :param verify_fingerprint: Optional. SHA256 of the expected server certificate. :param latin1_fallback: RFC-2616 specifies the default Content-Type of text/* to be latin1, which is not always correct, but exactly what @@ -134,7 +194,21 @@ async def request( ssl_context.load_cert_chain(*cert) kwargs["ssl"] = ssl_context - response = await session.request(method, url, **kwargs) + headers = kwargs.pop("headers", {}) + num_401 = 0 + while num_401 < 2: + headers["Authorization"] = auth.get_auth_header(method, url) + response = await session.request(method, url, headers=headers, **kwargs) + + if response.ok: + break + + if response.status == 401: + num_401 += 1 + auth.handle_401(response) + else: + # some other error, will be handled later on + break # See https://github.com/kennethreitz/requests/issues/2042 content_type = response.headers.get("Content-Type", "")