Implement digest auth

This commit is contained in:
Mike A. 2024-09-10 23:34:26 +02:00 committed by Hugo
parent 8550475548
commit 611b8667a3

View file

@ -1,9 +1,14 @@
from __future__ import annotations
import logging
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from ssl import create_default_context
import aiohttp
import requests.auth
from requests.utils import parse_dict_header
from . import DOCS_HOME
from . import __version__
@ -36,25 +41,77 @@ _detect_faulty_requests()
del _detect_faulty_requests
class AuthMethod(ABC):
def __init__(self, username, password):
self.username = username
self.password = password
@abstractmethod
def handle_401(self, response):
raise NotImplementedError
@abstractmethod
def get_auth_header(self, method, url):
raise NotImplementedError
class BasicAuthMethod(AuthMethod):
def handle_401(self, _response):
pass
def get_auth_header(self, _method, _url):
auth_str = f"{self.username}:{self.password}"
return "Basic " + b64encode(auth_str.encode('utf-8')).decode("utf-8")
class DigestAuthMethod(AuthMethod):
# make class var to 'cache' the state, which is more efficient because otherwise
# each request would first require another 'initialization' request.
_auth_helpers = {}
def __init__(self, username, password):
super().__init__(username, password)
self._auth_helper = self._auth_helpers.get(
(username, password),
requests.auth.HTTPDigestAuth(username, password)
)
self._auth_helpers[(username, password)] = self._auth_helper
@property
def auth_helper_vars(self):
return self._auth_helper._thread_local
def handle_401(self, response):
s_auth = response.headers.get("www-authenticate", "")
if "digest" in s_auth.lower():
# Original source:
# https://github.com/psf/requests/blob/f12ccbef6d6b95564da8d22e280d28c39d53f0e9/src/requests/auth.py#L262-L263
pat = re.compile(r"digest ", flags=re.IGNORECASE)
self.auth_helper_vars.chal = parse_dict_header(pat.sub("", s_auth, count=1))
def get_auth_header(self, method, url):
self._auth_helper.init_per_thread_state()
if not self.auth_helper_vars.chal:
# Need to do init request first
return ''
return self._auth_helper.build_digest_header(method, url)
def prepare_auth(auth, username, password):
if username and password:
if auth == "basic" or auth is None:
return aiohttp.BasicAuth(username, password)
return BasicAuthMethod(username, password)
elif auth == "digest":
from requests.auth import HTTPDigestAuth
return HTTPDigestAuth(username, password)
return DigestAuthMethod(username, password)
elif auth == "guess":
try:
from requests_toolbelt.auth.guess import GuessAuth
except ImportError:
raise exceptions.UserError(
"Your version of requests_toolbelt is too "
"old for `guess` authentication. At least "
"version 0.4.0 is required."
)
else:
return GuessAuth(username, password)
raise exceptions.UserError(f"'Guess' authentication is not supported in this version of vdirsyncer. \n"
f"Please explicitly specify either 'basic' or 'digest' auth instead. \n"
f"See the following issue for more information: "
f"https://github.com/pimutils/vdirsyncer/issues/1015")
else:
raise exceptions.UserError(f"Unknown authentication method: {auth}")
elif auth:
@ -97,14 +154,17 @@ async def request(
method,
url,
session,
auth,
latin1_fallback=True,
**kwargs,
):
"""Wrapper method for requests, to ease logging and mocking.
"""Wrapper method for requests, to ease logging and mocking as well as to
support auth methods currently unsupported by aiohttp.
Parameters should be the same as for ``aiohttp.request``, as well as:
Parameters should be the same as for ``aiohttp.request``, except:
:param session: A requests session object to use.
:param auth: The HTTP ``AuthMethod`` to use for authentication.
:param verify_fingerprint: Optional. SHA256 of the expected server certificate.
:param latin1_fallback: RFC-2616 specifies the default Content-Type of
text/* to be latin1, which is not always correct, but exactly what
@ -134,7 +194,21 @@ async def request(
ssl_context.load_cert_chain(*cert)
kwargs["ssl"] = ssl_context
response = await session.request(method, url, **kwargs)
headers = kwargs.pop("headers", {})
num_401 = 0
while num_401 < 2:
headers["Authorization"] = auth.get_auth_header(method, url)
response = await session.request(method, url, headers=headers, **kwargs)
if response.ok:
break
if response.status == 401:
num_401 += 1
auth.handle_401(response)
else:
# some other error, will be handled later on
break
# See https://github.com/kennethreitz/requests/issues/2042
content_type = response.headers.get("Content-Type", "")