mirror of
https://github.com/samsonjs/vdirsyncer.git
synced 2026-04-27 14:57:41 +00:00
Implement digest auth
This commit is contained in:
parent
8550475548
commit
611b8667a3
1 changed files with 91 additions and 17 deletions
|
|
@ -1,9 +1,14 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from base64 import b64encode
|
||||||
from ssl import create_default_context
|
from ssl import create_default_context
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import requests.auth
|
||||||
|
from requests.utils import parse_dict_header
|
||||||
|
|
||||||
from . import DOCS_HOME
|
from . import DOCS_HOME
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
|
@ -36,25 +41,77 @@ _detect_faulty_requests()
|
||||||
del _detect_faulty_requests
|
del _detect_faulty_requests
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMethod(ABC):
|
||||||
|
def __init__(self, username, password):
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def handle_401(self, response):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_auth_header(self, method, url):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class BasicAuthMethod(AuthMethod):
|
||||||
|
def handle_401(self, _response):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_auth_header(self, _method, _url):
|
||||||
|
auth_str = f"{self.username}:{self.password}"
|
||||||
|
return "Basic " + b64encode(auth_str.encode('utf-8')).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class DigestAuthMethod(AuthMethod):
|
||||||
|
# make class var to 'cache' the state, which is more efficient because otherwise
|
||||||
|
# each request would first require another 'initialization' request.
|
||||||
|
_auth_helpers = {}
|
||||||
|
|
||||||
|
def __init__(self, username, password):
|
||||||
|
super().__init__(username, password)
|
||||||
|
|
||||||
|
self._auth_helper = self._auth_helpers.get(
|
||||||
|
(username, password),
|
||||||
|
requests.auth.HTTPDigestAuth(username, password)
|
||||||
|
)
|
||||||
|
self._auth_helpers[(username, password)] = self._auth_helper
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_helper_vars(self):
|
||||||
|
return self._auth_helper._thread_local
|
||||||
|
|
||||||
|
def handle_401(self, response):
|
||||||
|
s_auth = response.headers.get("www-authenticate", "")
|
||||||
|
|
||||||
|
if "digest" in s_auth.lower():
|
||||||
|
# Original source:
|
||||||
|
# https://github.com/psf/requests/blob/f12ccbef6d6b95564da8d22e280d28c39d53f0e9/src/requests/auth.py#L262-L263
|
||||||
|
pat = re.compile(r"digest ", flags=re.IGNORECASE)
|
||||||
|
self.auth_helper_vars.chal = parse_dict_header(pat.sub("", s_auth, count=1))
|
||||||
|
|
||||||
|
def get_auth_header(self, method, url):
|
||||||
|
self._auth_helper.init_per_thread_state()
|
||||||
|
|
||||||
|
if not self.auth_helper_vars.chal:
|
||||||
|
# Need to do init request first
|
||||||
|
return ''
|
||||||
|
|
||||||
|
return self._auth_helper.build_digest_header(method, url)
|
||||||
|
|
||||||
|
|
||||||
def prepare_auth(auth, username, password):
|
def prepare_auth(auth, username, password):
|
||||||
if username and password:
|
if username and password:
|
||||||
if auth == "basic" or auth is None:
|
if auth == "basic" or auth is None:
|
||||||
return aiohttp.BasicAuth(username, password)
|
return BasicAuthMethod(username, password)
|
||||||
elif auth == "digest":
|
elif auth == "digest":
|
||||||
from requests.auth import HTTPDigestAuth
|
return DigestAuthMethod(username, password)
|
||||||
|
|
||||||
return HTTPDigestAuth(username, password)
|
|
||||||
elif auth == "guess":
|
elif auth == "guess":
|
||||||
try:
|
raise exceptions.UserError(f"'Guess' authentication is not supported in this version of vdirsyncer. \n"
|
||||||
from requests_toolbelt.auth.guess import GuessAuth
|
f"Please explicitly specify either 'basic' or 'digest' auth instead. \n"
|
||||||
except ImportError:
|
f"See the following issue for more information: "
|
||||||
raise exceptions.UserError(
|
f"https://github.com/pimutils/vdirsyncer/issues/1015")
|
||||||
"Your version of requests_toolbelt is too "
|
|
||||||
"old for `guess` authentication. At least "
|
|
||||||
"version 0.4.0 is required."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return GuessAuth(username, password)
|
|
||||||
else:
|
else:
|
||||||
raise exceptions.UserError(f"Unknown authentication method: {auth}")
|
raise exceptions.UserError(f"Unknown authentication method: {auth}")
|
||||||
elif auth:
|
elif auth:
|
||||||
|
|
@ -97,14 +154,17 @@ async def request(
|
||||||
method,
|
method,
|
||||||
url,
|
url,
|
||||||
session,
|
session,
|
||||||
|
auth,
|
||||||
latin1_fallback=True,
|
latin1_fallback=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Wrapper method for requests, to ease logging and mocking.
|
"""Wrapper method for requests, to ease logging and mocking as well as to
|
||||||
|
support auth methods currently unsupported by aiohttp.
|
||||||
|
|
||||||
Parameters should be the same as for ``aiohttp.request``, as well as:
|
Parameters should be the same as for ``aiohttp.request``, except:
|
||||||
|
|
||||||
:param session: A requests session object to use.
|
:param session: A requests session object to use.
|
||||||
|
:param auth: The HTTP ``AuthMethod`` to use for authentication.
|
||||||
:param verify_fingerprint: Optional. SHA256 of the expected server certificate.
|
:param verify_fingerprint: Optional. SHA256 of the expected server certificate.
|
||||||
:param latin1_fallback: RFC-2616 specifies the default Content-Type of
|
:param latin1_fallback: RFC-2616 specifies the default Content-Type of
|
||||||
text/* to be latin1, which is not always correct, but exactly what
|
text/* to be latin1, which is not always correct, but exactly what
|
||||||
|
|
@ -134,7 +194,21 @@ async def request(
|
||||||
ssl_context.load_cert_chain(*cert)
|
ssl_context.load_cert_chain(*cert)
|
||||||
kwargs["ssl"] = ssl_context
|
kwargs["ssl"] = ssl_context
|
||||||
|
|
||||||
response = await session.request(method, url, **kwargs)
|
headers = kwargs.pop("headers", {})
|
||||||
|
num_401 = 0
|
||||||
|
while num_401 < 2:
|
||||||
|
headers["Authorization"] = auth.get_auth_header(method, url)
|
||||||
|
response = await session.request(method, url, headers=headers, **kwargs)
|
||||||
|
|
||||||
|
if response.ok:
|
||||||
|
break
|
||||||
|
|
||||||
|
if response.status == 401:
|
||||||
|
num_401 += 1
|
||||||
|
auth.handle_401(response)
|
||||||
|
else:
|
||||||
|
# some other error, will be handled later on
|
||||||
|
break
|
||||||
|
|
||||||
# See https://github.com/kennethreitz/requests/issues/2042
|
# See https://github.com/kennethreitz/requests/issues/2042
|
||||||
content_type = response.headers.get("Content-Type", "")
|
content_type = response.headers.get("Content-Type", "")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue