mirror of
https://github.com/samsonjs/vdirsyncer.git
synced 2026-03-25 08:55:50 +00:00
Implement digest auth
This commit is contained in:
parent
8550475548
commit
611b8667a3
1 changed files with 91 additions and 17 deletions
|
|
@ -1,9 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64encode
|
||||
from ssl import create_default_context
|
||||
|
||||
import aiohttp
|
||||
import requests.auth
|
||||
from requests.utils import parse_dict_header
|
||||
|
||||
from . import DOCS_HOME
|
||||
from . import __version__
|
||||
|
|
@ -36,25 +41,77 @@ _detect_faulty_requests()
|
|||
del _detect_faulty_requests
|
||||
|
||||
|
||||
class AuthMethod(ABC):
|
||||
def __init__(self, username, password):
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
@abstractmethod
|
||||
def handle_401(self, response):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_auth_header(self, method, url):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BasicAuthMethod(AuthMethod):
|
||||
def handle_401(self, _response):
|
||||
pass
|
||||
|
||||
def get_auth_header(self, _method, _url):
|
||||
auth_str = f"{self.username}:{self.password}"
|
||||
return "Basic " + b64encode(auth_str.encode('utf-8')).decode("utf-8")
|
||||
|
||||
|
||||
class DigestAuthMethod(AuthMethod):
|
||||
# make class var to 'cache' the state, which is more efficient because otherwise
|
||||
# each request would first require another 'initialization' request.
|
||||
_auth_helpers = {}
|
||||
|
||||
def __init__(self, username, password):
|
||||
super().__init__(username, password)
|
||||
|
||||
self._auth_helper = self._auth_helpers.get(
|
||||
(username, password),
|
||||
requests.auth.HTTPDigestAuth(username, password)
|
||||
)
|
||||
self._auth_helpers[(username, password)] = self._auth_helper
|
||||
|
||||
@property
|
||||
def auth_helper_vars(self):
|
||||
return self._auth_helper._thread_local
|
||||
|
||||
def handle_401(self, response):
|
||||
s_auth = response.headers.get("www-authenticate", "")
|
||||
|
||||
if "digest" in s_auth.lower():
|
||||
# Original source:
|
||||
# https://github.com/psf/requests/blob/f12ccbef6d6b95564da8d22e280d28c39d53f0e9/src/requests/auth.py#L262-L263
|
||||
pat = re.compile(r"digest ", flags=re.IGNORECASE)
|
||||
self.auth_helper_vars.chal = parse_dict_header(pat.sub("", s_auth, count=1))
|
||||
|
||||
def get_auth_header(self, method, url):
|
||||
self._auth_helper.init_per_thread_state()
|
||||
|
||||
if not self.auth_helper_vars.chal:
|
||||
# Need to do init request first
|
||||
return ''
|
||||
|
||||
return self._auth_helper.build_digest_header(method, url)
|
||||
|
||||
|
||||
def prepare_auth(auth, username, password):
|
||||
if username and password:
|
||||
if auth == "basic" or auth is None:
|
||||
return aiohttp.BasicAuth(username, password)
|
||||
return BasicAuthMethod(username, password)
|
||||
elif auth == "digest":
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
||||
return HTTPDigestAuth(username, password)
|
||||
return DigestAuthMethod(username, password)
|
||||
elif auth == "guess":
|
||||
try:
|
||||
from requests_toolbelt.auth.guess import GuessAuth
|
||||
except ImportError:
|
||||
raise exceptions.UserError(
|
||||
"Your version of requests_toolbelt is too "
|
||||
"old for `guess` authentication. At least "
|
||||
"version 0.4.0 is required."
|
||||
)
|
||||
else:
|
||||
return GuessAuth(username, password)
|
||||
raise exceptions.UserError(f"'Guess' authentication is not supported in this version of vdirsyncer. \n"
|
||||
f"Please explicitly specify either 'basic' or 'digest' auth instead. \n"
|
||||
f"See the following issue for more information: "
|
||||
f"https://github.com/pimutils/vdirsyncer/issues/1015")
|
||||
else:
|
||||
raise exceptions.UserError(f"Unknown authentication method: {auth}")
|
||||
elif auth:
|
||||
|
|
@ -97,14 +154,17 @@ async def request(
|
|||
method,
|
||||
url,
|
||||
session,
|
||||
auth,
|
||||
latin1_fallback=True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Wrapper method for requests, to ease logging and mocking.
|
||||
"""Wrapper method for requests, to ease logging and mocking as well as to
|
||||
support auth methods currently unsupported by aiohttp.
|
||||
|
||||
Parameters should be the same as for ``aiohttp.request``, as well as:
|
||||
Parameters should be the same as for ``aiohttp.request``, except:
|
||||
|
||||
:param session: A requests session object to use.
|
||||
:param auth: The HTTP ``AuthMethod`` to use for authentication.
|
||||
:param verify_fingerprint: Optional. SHA256 of the expected server certificate.
|
||||
:param latin1_fallback: RFC-2616 specifies the default Content-Type of
|
||||
text/* to be latin1, which is not always correct, but exactly what
|
||||
|
|
@ -134,7 +194,21 @@ async def request(
|
|||
ssl_context.load_cert_chain(*cert)
|
||||
kwargs["ssl"] = ssl_context
|
||||
|
||||
response = await session.request(method, url, **kwargs)
|
||||
headers = kwargs.pop("headers", {})
|
||||
num_401 = 0
|
||||
while num_401 < 2:
|
||||
headers["Authorization"] = auth.get_auth_header(method, url)
|
||||
response = await session.request(method, url, headers=headers, **kwargs)
|
||||
|
||||
if response.ok:
|
||||
break
|
||||
|
||||
if response.status == 401:
|
||||
num_401 += 1
|
||||
auth.handle_401(response)
|
||||
else:
|
||||
# some other error, will be handled later on
|
||||
break
|
||||
|
||||
# See https://github.com/kennethreitz/requests/issues/2042
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
|
|
|||
Loading…
Reference in a new issue