Make tests pass

This commit is contained in:
Mike A. 2024-09-10 23:54:02 +02:00 committed by Hugo
parent 89a01631fa
commit 67e1c0ded5
2 changed files with 14 additions and 9 deletions

View file

@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
import pytest import pytest
from aiohttp import BasicAuth
from aioresponses import CallbackResult from aioresponses import CallbackResult
from aioresponses import aioresponses from aioresponses import aioresponses
from tests import normalize_item from tests import normalize_item
from vdirsyncer.exceptions import UserError from vdirsyncer.exceptions import UserError
from vdirsyncer.http import BasicAuthMethod, DigestAuthMethod
from vdirsyncer.storage.http import HttpStorage from vdirsyncer.storage.http import HttpStorage
from vdirsyncer.storage.http import prepare_auth from vdirsyncer.storage.http import prepare_auth
@ -91,16 +91,14 @@ def test_readonly_param(aio_connector):
def test_prepare_auth(): def test_prepare_auth():
assert prepare_auth(None, "", "") is None assert prepare_auth(None, "", "") is None
assert prepare_auth(None, "user", "pwd") == BasicAuth("user", "pwd") assert prepare_auth(None, "user", "pwd") == BasicAuthMethod("user", "pwd")
assert prepare_auth("basic", "user", "pwd") == BasicAuth("user", "pwd") assert prepare_auth("basic", "user", "pwd") == BasicAuthMethod("user", "pwd")
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
assert prepare_auth("basic", "", "pwd") assert prepare_auth("basic", "", "pwd")
assert "you need to specify username and password" in str(excinfo.value).lower() assert "you need to specify username and password" in str(excinfo.value).lower()
from requests.auth import HTTPDigestAuth assert isinstance(prepare_auth("digest", "user", "pwd"), DigestAuthMethod)
assert isinstance(prepare_auth("digest", "user", "pwd"), HTTPDigestAuth)
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
prepare_auth("ladida", "user", "pwd") prepare_auth("ladida", "user", "pwd")

View file

@ -31,6 +31,11 @@ class AuthMethod(ABC):
def get_auth_header(self, method, url): def get_auth_header(self, method, url):
raise NotImplementedError raise NotImplementedError
def __eq__(self, other):
if not isinstance(other, AuthMethod):
return False
return self.__class__ == other.__class__ and self.username == other.username and self.password == other.password
class BasicAuthMethod(AuthMethod): class BasicAuthMethod(AuthMethod):
def handle_401(self, _response): def handle_401(self, _response):
@ -131,7 +136,7 @@ async def request(
method, method,
url, url,
session, session,
auth, auth=None,
latin1_fallback=True, latin1_fallback=True,
**kwargs, **kwargs,
): ):
@ -174,10 +179,12 @@ async def request(
headers = kwargs.pop("headers", {}) headers = kwargs.pop("headers", {})
num_401 = 0 num_401 = 0
while num_401 < 2: while num_401 < 2:
headers["Authorization"] = auth.get_auth_header(method, url) if auth:
headers["Authorization"] = auth.get_auth_header(method, url)
response = await session.request(method, url, headers=headers, **kwargs) response = await session.request(method, url, headers=headers, **kwargs)
if response.ok: if response.ok or not auth:
# we don't need to do the 401-loop if we don't do auth in the first place
break break
if response.status == 401: if response.status == 401: