diff --git a/.builds/archlinux.yaml b/.builds/archlinux.yaml index b09623d..4c122dc 100644 --- a/.builds/archlinux.yaml +++ b/.builds/archlinux.yaml @@ -19,6 +19,10 @@ packages: - python-pytest-cov - python-pytest-httpserver - python-trustme + - python-pytest-asyncio + - python-aiohttp + - python-aiostream + - python-aioresponses sources: - https://github.com/pimutils/vdirsyncer environment: diff --git a/setup.cfg b/setup.cfg index 40e5673..e09d9db 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ addopts = --cov=vdirsyncer --cov-report=term-missing --no-cov-on-fail +# filterwarnings=error [flake8] application-import-names = tests,vdirsyncer diff --git a/setup.py b/setup.py index 0305896..ea71894 100644 --- a/setup.py +++ b/setup.py @@ -13,14 +13,14 @@ requirements = [ # https://github.com/mitsuhiko/click/issues/200 "click>=5.0,<9.0", "click-log>=0.3.0, <0.4.0", - # https://github.com/pimutils/vdirsyncer/issues/478 - "click-threading>=0.5", "requests >=2.20.0", # https://github.com/sigmavirus24/requests-toolbelt/pull/28 # And https://github.com/sigmavirus24/requests-toolbelt/issues/54 "requests_toolbelt >=0.4.0", # https://github.com/untitaker/python-atomicwrites/commit/4d12f23227b6a944ab1d99c507a69fdbc7c9ed6d # noqa "atomicwrites>=0.1.7", + "aiohttp>=3.7.1,<4.0.0", + "aiostream>=0.4.3,<0.5.0", ] diff --git a/test-requirements.txt b/test-requirements.txt index 3d1c6cb..bda5f7e 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,3 +3,5 @@ pytest pytest-cov pytest-httpserver trustme +pytest-asyncio +aioresponses diff --git a/tests/conftest.py b/tests/conftest.py index 5423e6d..5c174b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ General-purpose fixtures for vdirsyncer's testsuite. import logging import os +import aiohttp import click_log import pytest from hypothesis import HealthCheck @@ -52,3 +53,18 @@ elif os.environ.get("CI", "false").lower() == "true": settings.load_profile("ci") else: settings.load_profile("dev") + + +@pytest.fixture +async def aio_session(event_loop): + async with aiohttp.ClientSession() as session: + yield session + + +@pytest.fixture +async def aio_connector(event_loop): + conn = aiohttp.TCPConnector(limit_per_host=16) + try: + yield conn + finally: + await conn.close() diff --git a/tests/storage/__init__.py b/tests/storage/__init__.py index d68aacc..88d38e4 100644 --- a/tests/storage/__init__.py +++ b/tests/storage/__init__.py @@ -4,6 +4,7 @@ import uuid from urllib.parse import quote as urlquote from urllib.parse import unquote as urlunquote +import aiostream import pytest from .. import assert_item_equals @@ -49,8 +50,9 @@ class StorageTests: raise NotImplementedError() @pytest.fixture - def s(self, get_storage_args): - return self.storage_class(**get_storage_args()) + async def s(self, get_storage_args): + rv = self.storage_class(**await get_storage_args()) + return rv @pytest.fixture def get_item(self, item_type): @@ -72,176 +74,211 @@ class StorageTests: if not self.supports_metadata: pytest.skip("This storage does not support metadata.") - def test_generic(self, s, get_item): + @pytest.mark.asyncio + async def test_generic(self, s, get_item): items = [get_item() for i in range(1, 10)] hrefs = [] for item in items: - href, etag = s.upload(item) + href, etag = await s.upload(item) if etag is None: - _, etag = s.get(href) + _, etag = await s.get(href) hrefs.append((href, etag)) hrefs.sort() - assert hrefs == sorted(s.list()) + assert hrefs == sorted(await aiostream.stream.list(s.list())) for href, etag in hrefs: assert isinstance(href, (str, bytes)) assert isinstance(etag, (str, bytes)) - assert s.has(href) - item, etag2 = s.get(href) + assert await s.has(href) + item, etag2 = await s.get(href) assert etag == etag2 - def test_empty_get_multi(self, s): - assert list(s.get_multi([])) == [] + @pytest.mark.asyncio + async def test_empty_get_multi(self, s): + assert await aiostream.stream.list(s.get_multi([])) == [] - def test_get_multi_duplicates(self, s, get_item): - href, etag = s.upload(get_item()) + @pytest.mark.asyncio + async def test_get_multi_duplicates(self, s, get_item): + href, etag = await s.upload(get_item()) if etag is None: - _, etag = s.get(href) - ((href2, item, etag2),) = s.get_multi([href] * 2) + _, etag = await s.get(href) + ((href2, item, etag2),) = await aiostream.stream.list(s.get_multi([href] * 2)) assert href2 == href assert etag2 == etag - def test_upload_already_existing(self, s, get_item): + @pytest.mark.asyncio + async def test_upload_already_existing(self, s, get_item): item = get_item() - s.upload(item) + await s.upload(item) with pytest.raises(exceptions.PreconditionFailed): - s.upload(item) + await s.upload(item) - def test_upload(self, s, get_item): + @pytest.mark.asyncio + async def test_upload(self, s, get_item): item = get_item() - href, etag = s.upload(item) - assert_item_equals(s.get(href)[0], item) + href, etag = await s.upload(item) + assert_item_equals((await s.get(href))[0], item) - def test_update(self, s, get_item): + @pytest.mark.asyncio + async def test_update(self, s, get_item): item = get_item() - href, etag = s.upload(item) + href, etag = await s.upload(item) if etag is None: - _, etag = s.get(href) - assert_item_equals(s.get(href)[0], item) + _, etag = await s.get(href) + assert_item_equals((await s.get(href))[0], item) new_item = get_item(uid=item.uid) - new_etag = s.update(href, new_item, etag) + new_etag = await s.update(href, new_item, etag) if new_etag is None: - _, new_etag = s.get(href) + _, new_etag = await s.get(href) # See https://github.com/pimutils/vdirsyncer/issues/48 assert isinstance(new_etag, (bytes, str)) - assert_item_equals(s.get(href)[0], new_item) + assert_item_equals((await s.get(href))[0], new_item) - def test_update_nonexisting(self, s, get_item): + @pytest.mark.asyncio + async def test_update_nonexisting(self, s, get_item): item = get_item() with pytest.raises(exceptions.PreconditionFailed): - s.update("huehue", item, '"123"') + await s.update("huehue", item, '"123"') - def test_wrong_etag(self, s, get_item): + @pytest.mark.asyncio + async def test_wrong_etag(self, s, get_item): item = get_item() - href, etag = s.upload(item) + href, etag = await s.upload(item) with pytest.raises(exceptions.PreconditionFailed): - s.update(href, item, '"lolnope"') + await s.update(href, item, '"lolnope"') with pytest.raises(exceptions.PreconditionFailed): - s.delete(href, '"lolnope"') + await s.delete(href, '"lolnope"') - def test_delete(self, s, get_item): - href, etag = s.upload(get_item()) - s.delete(href, etag) - assert not list(s.list()) + @pytest.mark.asyncio + async def test_delete(self, s, get_item): + href, etag = await s.upload(get_item()) + await s.delete(href, etag) + assert not await aiostream.stream.list(s.list()) - def test_delete_nonexisting(self, s, get_item): + @pytest.mark.asyncio + async def test_delete_nonexisting(self, s, get_item): with pytest.raises(exceptions.PreconditionFailed): - s.delete("1", '"123"') + await s.delete("1", '"123"') - def test_list(self, s, get_item): - assert not list(s.list()) - href, etag = s.upload(get_item()) + @pytest.mark.asyncio + async def test_list(self, s, get_item): + assert not await aiostream.stream.list(s.list()) + href, etag = await s.upload(get_item()) if etag is None: - _, etag = s.get(href) - assert list(s.list()) == [(href, etag)] + _, etag = await s.get(href) + assert await aiostream.stream.list(s.list()) == [(href, etag)] - def test_has(self, s, get_item): - assert not s.has("asd") - href, etag = s.upload(get_item()) - assert s.has(href) - assert not s.has("asd") - s.delete(href, etag) - assert not s.has(href) + @pytest.mark.asyncio + async def test_has(self, s, get_item): + assert not await s.has("asd") + href, etag = await s.upload(get_item()) + assert await s.has(href) + assert not await s.has("asd") + await s.delete(href, etag) + assert not await s.has(href) - def test_update_others_stay_the_same(self, s, get_item): + @pytest.mark.asyncio + async def test_update_others_stay_the_same(self, s, get_item): info = {} for _ in range(4): - href, etag = s.upload(get_item()) + href, etag = await s.upload(get_item()) if etag is None: - _, etag = s.get(href) + _, etag = await s.get(href) info[href] = etag - assert { - href: etag - for href, item, etag in s.get_multi(href for href, etag in info.items()) - } == info + items = await aiostream.stream.list( + s.get_multi(href for href, etag in info.items()) + ) + assert {href: etag for href, item, etag in items} == info - def test_repr(self, s, get_storage_args): + @pytest.mark.asyncio + def test_repr(self, s, get_storage_args): # XXX: unused param assert self.storage_class.__name__ in repr(s) assert s.instance_name is None - def test_discover(self, requires_collections, get_storage_args, get_item): + @pytest.mark.asyncio + async def test_discover( + self, + requires_collections, + get_storage_args, + get_item, + aio_connector, + ): collections = set() for i in range(1, 5): collection = f"test{i}" - s = self.storage_class(**get_storage_args(collection=collection)) - assert not list(s.list()) - s.upload(get_item()) + s = self.storage_class(**await get_storage_args(collection=collection)) + assert not await aiostream.stream.list(s.list()) + await s.upload(get_item()) collections.add(s.collection) - actual = { - c["collection"] - for c in self.storage_class.discover(**get_storage_args(collection=None)) - } + discovered = await aiostream.stream.list( + self.storage_class.discover(**await get_storage_args(collection=None)) + ) + actual = {c["collection"] for c in discovered} assert actual >= collections - def test_create_collection(self, requires_collections, get_storage_args, get_item): + @pytest.mark.asyncio + async def test_create_collection( + self, + requires_collections, + get_storage_args, + get_item, + ): if getattr(self, "dav_server", "") in ("icloud", "fastmail", "davical"): pytest.skip("Manual cleanup would be necessary.") if getattr(self, "dav_server", "") == "radicale": pytest.skip("Radicale does not support collection creation") - args = get_storage_args(collection=None) + args = await get_storage_args(collection=None) args["collection"] = "test" - s = self.storage_class(**self.storage_class.create_collection(**args)) + s = self.storage_class(**await self.storage_class.create_collection(**args)) - href = s.upload(get_item())[0] - assert href in (href for href, etag in s.list()) + href = (await s.upload(get_item()))[0] + assert href in await aiostream.stream.list( + (href async for href, etag in s.list()) + ) - def test_discover_collection_arg(self, requires_collections, get_storage_args): - args = get_storage_args(collection="test2") + @pytest.mark.asyncio + async def test_discover_collection_arg( + self, requires_collections, get_storage_args + ): + args = await get_storage_args(collection="test2") with pytest.raises(TypeError) as excinfo: - list(self.storage_class.discover(**args)) + await aiostream.stream.list(self.storage_class.discover(**args)) assert "collection argument must not be given" in str(excinfo.value) - def test_collection_arg(self, get_storage_args): + @pytest.mark.asyncio + async def test_collection_arg(self, get_storage_args): if self.storage_class.storage_name.startswith("etesync"): pytest.skip("etesync uses UUIDs.") if self.supports_collections: - s = self.storage_class(**get_storage_args(collection="test2")) + s = self.storage_class(**await get_storage_args(collection="test2")) # Can't do stronger assertion because of radicale, which needs a # fileextension to guess the collection type. assert "test2" in s.collection else: with pytest.raises(ValueError): - self.storage_class(collection="ayy", **get_storage_args()) + self.storage_class(collection="ayy", **await get_storage_args()) - def test_case_sensitive_uids(self, s, get_item): + @pytest.mark.asyncio + async def test_case_sensitive_uids(self, s, get_item): if s.storage_name == "filesystem": pytest.skip("Behavior depends on the filesystem.") uid = str(uuid.uuid4()) - s.upload(get_item(uid=uid.upper())) - s.upload(get_item(uid=uid.lower())) - items = [href for href, etag in s.list()] + await s.upload(get_item(uid=uid.upper())) + await s.upload(get_item(uid=uid.lower())) + items = [href async for href, etag in s.list()] assert len(items) == 2 assert len(set(items)) == 2 - def test_specialchars( + @pytest.mark.asyncio + async def test_specialchars( self, monkeypatch, requires_collections, get_storage_args, get_item ): if getattr(self, "dav_server", "") == "radicale": @@ -254,16 +291,16 @@ class StorageTests: uid = "test @ foo ät bar град сатану" collection = "test @ foo ät bar" - s = self.storage_class(**get_storage_args(collection=collection)) + s = self.storage_class(**await get_storage_args(collection=collection)) item = get_item(uid=uid) - href, etag = s.upload(item) - item2, etag2 = s.get(href) + href, etag = await s.upload(item) + item2, etag2 = await s.get(href) if etag is not None: assert etag2 == etag assert_item_equals(item2, item) - ((_, etag3),) = s.list() + ((_, etag3),) = await aiostream.stream.list(s.list()) assert etag2 == etag3 # etesync uses UUIDs for collection names @@ -274,22 +311,23 @@ class StorageTests: if self.storage_class.storage_name.endswith("dav"): assert urlquote(uid, "/@:") in href - def test_metadata(self, requires_metadata, s): + @pytest.mark.asyncio + async def test_metadata(self, requires_metadata, s): if not getattr(self, "dav_server", ""): - assert not s.get_meta("color") - assert not s.get_meta("displayname") + assert not await s.get_meta("color") + assert not await s.get_meta("displayname") try: - s.set_meta("color", None) - assert not s.get_meta("color") - s.set_meta("color", "#ff0000") - assert s.get_meta("color") == "#ff0000" + await s.set_meta("color", None) + assert not await s.get_meta("color") + await s.set_meta("color", "#ff0000") + assert await s.get_meta("color") == "#ff0000" except exceptions.UnsupportedMetadataError: pass for x in ("hello world", "hello wörld"): - s.set_meta("displayname", x) - rv = s.get_meta("displayname") + await s.set_meta("displayname", x) + rv = await s.get_meta("displayname") assert rv == x assert isinstance(rv, str) @@ -306,16 +344,18 @@ class StorageTests: "فلسطين", ], ) - def test_metadata_normalization(self, requires_metadata, s, value): - x = s.get_meta("displayname") + @pytest.mark.asyncio + async def test_metadata_normalization(self, requires_metadata, s, value): + x = await s.get_meta("displayname") assert x == normalize_meta_value(x) if not getattr(self, "dav_server", None): # ownCloud replaces "" with "unnamed" - s.set_meta("displayname", value) - assert s.get_meta("displayname") == normalize_meta_value(value) + await s.set_meta("displayname", value) + assert await s.get_meta("displayname") == normalize_meta_value(value) - def test_recurring_events(self, s, item_type): + @pytest.mark.asyncio + async def test_recurring_events(self, s, item_type): if item_type != "VEVENT": pytest.skip("This storage instance doesn't support iCalendar.") @@ -362,7 +402,7 @@ class StorageTests: ).strip() ) - href, etag = s.upload(item) + href, etag = await s.upload(item) - item2, etag2 = s.get(href) + item2, etag2 = await s.get(href) assert normalize_item(item) == normalize_item(item2) diff --git a/tests/storage/conftest.py b/tests/storage/conftest.py index 74b8d04..f0c9a12 100644 --- a/tests/storage/conftest.py +++ b/tests/storage/conftest.py @@ -3,6 +3,7 @@ import subprocess import time import uuid +import aiostream import pytest import requests @@ -80,31 +81,31 @@ def xandikos_server(): @pytest.fixture -def slow_create_collection(request): +async def slow_create_collection(request, aio_connector): # We need to properly clean up because otherwise we might run into # storage limits. to_delete = [] - def delete_collections(): + async def delete_collections(): for s in to_delete: - s.session.request("DELETE", "") + await s.session.request("DELETE", "") - request.addfinalizer(delete_collections) - - def inner(cls, args, collection): + async def inner(cls, args, collection): assert collection.startswith("test") collection += "-vdirsyncer-ci-" + str(uuid.uuid4()) - args = cls.create_collection(collection, **args) + args = await cls.create_collection(collection, **args) s = cls(**args) - _clear_collection(s) - assert not list(s.list()) + await _clear_collection(s) + assert not await aiostream.stream.list(s.list()) to_delete.append(s) return args - return inner + yield inner + + await delete_collections() -def _clear_collection(s): - for href, etag in s.list(): +async def _clear_collection(s): + async for href, etag in s.list(): s.delete(href, etag) diff --git a/tests/storage/dav/__init__.py b/tests/storage/dav/__init__.py index c714745..d45e9d6 100644 --- a/tests/storage/dav/__init__.py +++ b/tests/storage/dav/__init__.py @@ -1,8 +1,9 @@ import os import uuid +import aiohttp +import aiostream import pytest -import requests.exceptions from .. import get_server_mixin from .. import StorageTests @@ -19,30 +20,33 @@ class DAVStorageTests(ServerMixin, StorageTests): dav_server = dav_server @pytest.mark.skipif(dav_server == "radicale", reason="Radicale is very tolerant.") - def test_dav_broken_item(self, s): + @pytest.mark.asyncio + async def test_dav_broken_item(self, s): item = Item("HAHA:YES") - with pytest.raises((exceptions.Error, requests.exceptions.HTTPError)): - s.upload(item) - assert not list(s.list()) + with pytest.raises((exceptions.Error, aiohttp.ClientResponseError)): + await s.upload(item) + assert not await aiostream.stream.list(s.list()) - def test_dav_empty_get_multi_performance(self, s, monkeypatch): + @pytest.mark.asyncio + async def test_dav_empty_get_multi_performance(self, s, monkeypatch): def breakdown(*a, **kw): raise AssertionError("Expected not to be called.") monkeypatch.setattr("requests.sessions.Session.request", breakdown) try: - assert list(s.get_multi([])) == [] + assert list(await aiostream.stream.list(s.get_multi([]))) == [] finally: # Make sure monkeypatch doesn't interfere with DAV server teardown monkeypatch.undo() - def test_dav_unicode_href(self, s, get_item, monkeypatch): + @pytest.mark.asyncio + async def test_dav_unicode_href(self, s, get_item, monkeypatch): if self.dav_server == "radicale": pytest.skip("Radicale is unable to deal with unicode hrefs") monkeypatch.setattr(s, "_get_href", lambda item: item.ident + s.fileext) item = get_item(uid="град сатану" + str(uuid.uuid4())) - href, etag = s.upload(item) - item2, etag2 = s.get(href) + href, etag = await s.upload(item) + item2, etag2 = await s.get(href) assert_item_equals(item, item2) diff --git a/tests/storage/dav/test_caldav.py b/tests/storage/dav/test_caldav.py index 4a59a9f..64df92c 100644 --- a/tests/storage/dav/test_caldav.py +++ b/tests/storage/dav/test_caldav.py @@ -1,8 +1,10 @@ import datetime from textwrap import dedent +import aiohttp +import aiostream import pytest -import requests.exceptions +from aioresponses import aioresponses from . import dav_server from . import DAVStorageTests @@ -21,15 +23,17 @@ class TestCalDAVStorage(DAVStorageTests): def item_type(self, request): return request.param - @pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.") - def test_doesnt_accept_vcard(self, item_type, get_storage_args): - s = self.storage_class(item_types=(item_type,), **get_storage_args()) + @pytest.mark.asyncio + async def test_doesnt_accept_vcard(self, item_type, get_storage_args): + s = self.storage_class(item_types=(item_type,), **await get_storage_args()) try: - s.upload(format_item(VCARD_TEMPLATE)) - except (exceptions.Error, requests.exceptions.HTTPError): + await s.upload(format_item(VCARD_TEMPLATE)) + except (exceptions.Error, aiohttp.ClientResponseError): + # Most storages hard-fail, but xandikos doesn't. pass - assert not list(s.list()) + + assert not await aiostream.stream.list(s.list()) # The `arg` param is not named `item_types` because that would hit # https://bitbucket.org/pytest-dev/pytest/issue/745/ @@ -44,10 +48,11 @@ class TestCalDAVStorage(DAVStorageTests): ], ) @pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.") - def test_item_types_performance( + @pytest.mark.asyncio + async def test_item_types_performance( self, get_storage_args, arg, calls_num, monkeypatch ): - s = self.storage_class(item_types=arg, **get_storage_args()) + s = self.storage_class(item_types=arg, **await get_storage_args()) old_parse = s._parse_prop_responses calls = [] @@ -56,17 +61,18 @@ class TestCalDAVStorage(DAVStorageTests): return old_parse(*a, **kw) monkeypatch.setattr(s, "_parse_prop_responses", new_parse) - list(s.list()) + await aiostream.stream.list(s.list()) assert len(calls) == calls_num @pytest.mark.xfail( dav_server == "radicale", reason="Radicale doesn't support timeranges." ) - def test_timerange_correctness(self, get_storage_args): + @pytest.mark.asyncio + async def test_timerange_correctness(self, get_storage_args): start_date = datetime.datetime(2013, 9, 10) end_date = datetime.datetime(2013, 9, 13) s = self.storage_class( - start_date=start_date, end_date=end_date, **get_storage_args() + start_date=start_date, end_date=end_date, **await get_storage_args() ) too_old_item = format_item( @@ -123,50 +129,44 @@ class TestCalDAVStorage(DAVStorageTests): ).strip() ) - s.upload(too_old_item) - s.upload(too_new_item) - expected_href, _ = s.upload(good_item) + await s.upload(too_old_item) + await s.upload(too_new_item) + expected_href, _ = await s.upload(good_item) - ((actual_href, _),) = s.list() + ((actual_href, _),) = await aiostream.stream.list(s.list()) assert actual_href == expected_href - def test_invalid_resource(self, monkeypatch, get_storage_args): - calls = [] - args = get_storage_args(collection=None) + @pytest.mark.asyncio + async def test_invalid_resource(self, monkeypatch, get_storage_args): + args = await get_storage_args(collection=None) - def request(session, method, url, **kwargs): - assert url == args["url"] - calls.append(None) + with aioresponses() as m: + m.add(args["url"], method="PROPFIND", status=200, body="Hello world") - r = requests.Response() - r.status_code = 200 - r._content = b"Hello World." - return r + with pytest.raises(ValueError): + s = self.storage_class(**args) + await aiostream.stream.list(s.list()) - monkeypatch.setattr("requests.sessions.Session.request", request) - - with pytest.raises(ValueError): - s = self.storage_class(**args) - list(s.list()) - assert len(calls) == 1 + assert len(m.requests) == 1 @pytest.mark.skipif(dav_server == "icloud", reason="iCloud only accepts VEVENT") @pytest.mark.skipif( dav_server == "fastmail", reason="Fastmail has non-standard hadling of VTODOs." ) @pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.") - def test_item_types_general(self, s): - event = s.upload(format_item(EVENT_TEMPLATE))[0] - task = s.upload(format_item(TASK_TEMPLATE))[0] + @pytest.mark.asyncio + async def test_item_types_general(self, s): + event = (await s.upload(format_item(EVENT_TEMPLATE)))[0] + task = (await s.upload(format_item(TASK_TEMPLATE)))[0] s.item_types = ("VTODO", "VEVENT") - def hrefs(): - return {href for href, etag in s.list()} + async def hrefs(): + return {href async for href, etag in s.list()} - assert hrefs() == {event, task} + assert await hrefs() == {event, task} s.item_types = ("VTODO",) - assert hrefs() == {task} + assert await hrefs() == {task} s.item_types = ("VEVENT",) - assert hrefs() == {event} + assert await hrefs() == {event} s.item_types = () - assert hrefs() == {event, task} + assert await hrefs() == {event, task} diff --git a/tests/storage/etesync/test_main.py b/tests/storage/etesync/test_main.py index dcb1d1f..442cb1a 100644 --- a/tests/storage/etesync/test_main.py +++ b/tests/storage/etesync/test_main.py @@ -58,7 +58,7 @@ class EtesyncTests(StorageTests): ) assert r.status_code == 200 - def inner(collection="test"): + async def inner(collection="test"): rv = { "email": "test@localhost", "db_path": str(tmpdir.join("etesync.db")), diff --git a/tests/storage/servers/baikal/__init__.py b/tests/storage/servers/baikal/__init__.py index 7547e07..ad27b39 100644 --- a/tests/storage/servers/baikal/__init__.py +++ b/tests/storage/servers/baikal/__init__.py @@ -3,13 +3,21 @@ import pytest class ServerMixin: @pytest.fixture - def get_storage_args(self, request, tmpdir, slow_create_collection, baikal_server): - def inner(collection="test"): + def get_storage_args( + self, + request, + tmpdir, + slow_create_collection, + baikal_server, + aio_connector, + ): + async def inner(collection="test"): base_url = "http://127.0.0.1:8002/" args = { "url": base_url, "username": "baikal", "password": "baikal", + "connector": aio_connector, } if self.storage_class.fileext == ".vcf": @@ -18,7 +26,11 @@ class ServerMixin: args["url"] = base_url + "cal.php/" if collection is not None: - args = slow_create_collection(self.storage_class, args, collection) + args = await slow_create_collection( + self.storage_class, + args, + collection, + ) return args return inner diff --git a/tests/storage/servers/davical/__init__.py b/tests/storage/servers/davical/__init__.py index 141b9f0..ce6a002 100644 --- a/tests/storage/servers/davical/__init__.py +++ b/tests/storage/servers/davical/__init__.py @@ -27,7 +27,7 @@ class ServerMixin: @pytest.fixture def get_storage_args(self, davical_args, request): - def inner(collection="test"): + async def inner(collection="test"): if collection is None: return davical_args diff --git a/tests/storage/servers/fastmail/__init__.py b/tests/storage/servers/fastmail/__init__.py index 1c21c6d..a847fdd 100644 --- a/tests/storage/servers/fastmail/__init__.py +++ b/tests/storage/servers/fastmail/__init__.py @@ -11,7 +11,7 @@ class ServerMixin: # See https://github.com/pimutils/vdirsyncer/issues/824 pytest.skip("Fastmail has non-standard VTODO support.") - def inner(collection="test"): + async def inner(collection="test"): args = { "username": os.environ["FASTMAIL_USERNAME"], "password": os.environ["FASTMAIL_PASSWORD"], diff --git a/tests/storage/servers/icloud/__init__.py b/tests/storage/servers/icloud/__init__.py index fa34e27..cb7285c 100644 --- a/tests/storage/servers/icloud/__init__.py +++ b/tests/storage/servers/icloud/__init__.py @@ -11,7 +11,7 @@ class ServerMixin: # See https://github.com/pimutils/vdirsyncer/pull/593#issuecomment-285941615 # noqa pytest.skip("iCloud doesn't support anything else than VEVENT") - def inner(collection="test"): + async def inner(collection="test"): args = { "username": os.environ["ICLOUD_USERNAME"], "password": os.environ["ICLOUD_PASSWORD"], diff --git a/tests/storage/servers/radicale/__init__.py b/tests/storage/servers/radicale/__init__.py index fb97ab2..f59cc81 100644 --- a/tests/storage/servers/radicale/__init__.py +++ b/tests/storage/servers/radicale/__init__.py @@ -9,17 +9,23 @@ class ServerMixin: tmpdir, slow_create_collection, radicale_server, + aio_connector, ): - def inner(collection="test"): + async def inner(collection="test"): url = "http://127.0.0.1:8001/" args = { "url": url, "username": "radicale", "password": "radicale", + "connector": aio_connector, } if collection is not None: - args = slow_create_collection(self.storage_class, args, collection) + args = await slow_create_collection( + self.storage_class, + args, + collection, + ) return args return inner diff --git a/tests/storage/servers/xandikos/__init__.py b/tests/storage/servers/xandikos/__init__.py index ece5de0..36b4eda 100644 --- a/tests/storage/servers/xandikos/__init__.py +++ b/tests/storage/servers/xandikos/__init__.py @@ -9,13 +9,19 @@ class ServerMixin: tmpdir, slow_create_collection, xandikos_server, + aio_connector, ): - def inner(collection="test"): + async def inner(collection="test"): url = "http://127.0.0.1:8000/" - args = {"url": url} + args = {"url": url, "connector": aio_connector} if collection is not None: - args = slow_create_collection(self.storage_class, args, collection) + args = await slow_create_collection( + self.storage_class, + args, + collection, + ) + return args return inner diff --git a/tests/storage/test_filesystem.py b/tests/storage/test_filesystem.py index d14999c..412a4fa 100644 --- a/tests/storage/test_filesystem.py +++ b/tests/storage/test_filesystem.py @@ -1,5 +1,6 @@ import subprocess +import aiostream import pytest from . import StorageTests @@ -12,10 +13,10 @@ class TestFilesystemStorage(StorageTests): @pytest.fixture def get_storage_args(self, tmpdir): - def inner(collection="test"): + async def inner(collection="test"): rv = {"path": str(tmpdir), "fileext": ".txt", "collection": collection} if collection is not None: - rv = self.storage_class.create_collection(**rv) + rv = await self.storage_class.create_collection(**rv) return rv return inner @@ -26,7 +27,8 @@ class TestFilesystemStorage(StorageTests): f.write("stub") self.storage_class(str(tmpdir) + "/hue", ".txt") - def test_broken_data(self, tmpdir): + @pytest.mark.asyncio + async def test_broken_data(self, tmpdir): s = self.storage_class(str(tmpdir), ".txt") class BrokenItem: @@ -35,64 +37,70 @@ class TestFilesystemStorage(StorageTests): ident = uid with pytest.raises(TypeError): - s.upload(BrokenItem) + await s.upload(BrokenItem) assert not tmpdir.listdir() - def test_ident_with_slash(self, tmpdir): + @pytest.mark.asyncio + async def test_ident_with_slash(self, tmpdir): s = self.storage_class(str(tmpdir), ".txt") - s.upload(Item("UID:a/b/c")) + await s.upload(Item("UID:a/b/c")) (item_file,) = tmpdir.listdir() assert "/" not in item_file.basename and item_file.isfile() - def test_ignore_tmp_files(self, tmpdir): + @pytest.mark.asyncio + async def test_ignore_tmp_files(self, tmpdir): """Test that files with .tmp suffix beside .ics files are ignored.""" s = self.storage_class(str(tmpdir), ".ics") - s.upload(Item("UID:xyzxyz")) + await s.upload(Item("UID:xyzxyz")) (item_file,) = tmpdir.listdir() item_file.copy(item_file.new(ext="tmp")) assert len(tmpdir.listdir()) == 2 - assert len(list(s.list())) == 1 + assert len(await aiostream.stream.list(s.list())) == 1 - def test_ignore_tmp_files_empty_fileext(self, tmpdir): + @pytest.mark.asyncio + async def test_ignore_tmp_files_empty_fileext(self, tmpdir): """Test that files with .tmp suffix are ignored with empty fileext.""" s = self.storage_class(str(tmpdir), "") - s.upload(Item("UID:xyzxyz")) + await s.upload(Item("UID:xyzxyz")) (item_file,) = tmpdir.listdir() item_file.copy(item_file.new(ext="tmp")) assert len(tmpdir.listdir()) == 2 # assert False, tmpdir.listdir() # enable to see the created filename - assert len(list(s.list())) == 1 + assert len(await aiostream.stream.list(s.list())) == 1 - def test_ignore_files_typical_backup(self, tmpdir): + @pytest.mark.asyncio + async def test_ignore_files_typical_backup(self, tmpdir): """Test file-name ignorance with typical backup ending ~.""" ignorext = "~" # without dot storage = self.storage_class(str(tmpdir), "", fileignoreext=ignorext) - storage.upload(Item("UID:xyzxyz")) + await storage.upload(Item("UID:xyzxyz")) (item_file,) = tmpdir.listdir() item_file.copy(item_file.new(basename=item_file.basename + ignorext)) assert len(tmpdir.listdir()) == 2 - assert len(list(storage.list())) == 1 + assert len(await aiostream.stream.list(storage.list())) == 1 - def test_too_long_uid(self, tmpdir): + @pytest.mark.asyncio + async def test_too_long_uid(self, tmpdir): storage = self.storage_class(str(tmpdir), ".txt") item = Item("UID:" + "hue" * 600) - href, etag = storage.upload(item) + href, etag = await storage.upload(item) assert item.uid not in href - def test_post_hook_inactive(self, tmpdir, monkeypatch): + @pytest.mark.asyncio + async def test_post_hook_inactive(self, tmpdir, monkeypatch): def check_call_mock(*args, **kwargs): raise AssertionError() monkeypatch.setattr(subprocess, "call", check_call_mock) s = self.storage_class(str(tmpdir), ".txt", post_hook=None) - s.upload(Item("UID:a/b/c")) - - def test_post_hook_active(self, tmpdir, monkeypatch): + await s.upload(Item("UID:a/b/c")) + @pytest.mark.asyncio + async def test_post_hook_active(self, tmpdir, monkeypatch): calls = [] exe = "foo" @@ -104,14 +112,17 @@ class TestFilesystemStorage(StorageTests): monkeypatch.setattr(subprocess, "call", check_call_mock) s = self.storage_class(str(tmpdir), ".txt", post_hook=exe) - s.upload(Item("UID:a/b/c")) + await s.upload(Item("UID:a/b/c")) assert calls - def test_ignore_git_dirs(self, tmpdir): + @pytest.mark.asyncio + async def test_ignore_git_dirs(self, tmpdir): tmpdir.mkdir(".git").mkdir("foo") tmpdir.mkdir("a") tmpdir.mkdir("b") - assert {c["collection"] for c in self.storage_class.discover(str(tmpdir))} == { - "a", - "b", + + expected = {"a", "b"} + actual = { + c["collection"] async for c in self.storage_class.discover(str(tmpdir)) } + assert actual == expected diff --git a/tests/storage/test_http.py b/tests/storage/test_http.py index 0bf69cd..b18beaa 100644 --- a/tests/storage/test_http.py +++ b/tests/storage/test_http.py @@ -1,5 +1,6 @@ import pytest -from requests import Response +from aioresponses import aioresponses +from aioresponses import CallbackResult from tests import normalize_item from vdirsyncer.exceptions import UserError @@ -7,7 +8,8 @@ from vdirsyncer.storage.http import HttpStorage from vdirsyncer.storage.http import prepare_auth -def test_list(monkeypatch): +@pytest.mark.asyncio +async def test_list(aio_connector): collection_url = "http://127.0.0.1/calendar/collection.ics" items = [ @@ -34,50 +36,53 @@ def test_list(monkeypatch): responses = ["\n".join(["BEGIN:VCALENDAR"] + items + ["END:VCALENDAR"])] * 2 - def get(self, method, url, *a, **kw): - assert method == "GET" - assert url == collection_url - r = Response() - r.status_code = 200 + def callback(url, headers, **kwargs): + assert headers["User-Agent"].startswith("vdirsyncer/") assert responses - r._content = responses.pop().encode("utf-8") - r.headers["Content-Type"] = "text/calendar" - r.encoding = "ISO-8859-1" - return r - monkeypatch.setattr("requests.sessions.Session.request", get) + return CallbackResult( + status=200, + body=responses.pop().encode("utf-8"), + headers={"Content-Type": "text/calendar; charset=iso-8859-1"}, + ) - s = HttpStorage(url=collection_url) + with aioresponses() as m: + m.get(collection_url, callback=callback, repeat=True) - found_items = {} + s = HttpStorage(url=collection_url, connector=aio_connector) - for href, etag in s.list(): - item, etag2 = s.get(href) - assert item.uid is not None - assert etag2 == etag - found_items[normalize_item(item)] = href + found_items = {} - expected = { - normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items - } + async for href, etag in s.list(): + item, etag2 = await s.get(href) + assert item.uid is not None + assert etag2 == etag + found_items[normalize_item(item)] = href - assert set(found_items) == expected + expected = { + normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items + } - for href, etag in s.list(): - item, etag2 = s.get(href) - assert item.uid is not None - assert etag2 == etag - assert found_items[normalize_item(item)] == href + assert set(found_items) == expected + + async for href, etag in s.list(): + item, etag2 = await s.get(href) + assert item.uid is not None + assert etag2 == etag + assert found_items[normalize_item(item)] == href -def test_readonly_param(): +def test_readonly_param(aio_connector): + """The ``readonly`` param cannot be ``False``.""" + url = "http://example.com/" with pytest.raises(ValueError): - HttpStorage(url=url, read_only=False) + HttpStorage(url=url, read_only=False, connector=aio_connector) - a = HttpStorage(url=url, read_only=True).read_only - b = HttpStorage(url=url, read_only=None).read_only - assert a is b is True + a = HttpStorage(url=url, read_only=True, connector=aio_connector) + b = HttpStorage(url=url, read_only=None, connector=aio_connector) + + assert a.read_only is b.read_only is True def test_prepare_auth(): @@ -115,9 +120,9 @@ def test_prepare_auth_guess(monkeypatch): assert "requests_toolbelt is too old" in str(excinfo.value).lower() -def test_verify_false_disallowed(): +def test_verify_false_disallowed(aio_connector): with pytest.raises(ValueError) as excinfo: - HttpStorage(url="http://example.com", verify=False) + HttpStorage(url="http://example.com", verify=False, connector=aio_connector) assert "forbidden" in str(excinfo.value).lower() assert "consider setting verify_fingerprint" in str(excinfo.value).lower() diff --git a/tests/storage/test_http_with_singlefile.py b/tests/storage/test_http_with_singlefile.py index acb1c39..1ff73ba 100644 --- a/tests/storage/test_http_with_singlefile.py +++ b/tests/storage/test_http_with_singlefile.py @@ -1,5 +1,7 @@ +import aiostream import pytest -from requests import Response +from aioresponses import aioresponses +from aioresponses import CallbackResult import vdirsyncer.storage.http from . import StorageTests @@ -14,32 +16,33 @@ class CombinedStorage(Storage): _repr_attributes = ("url", "path") storage_name = "http_and_singlefile" - def __init__(self, url, path, **kwargs): + def __init__(self, url, path, *, connector, **kwargs): if kwargs.get("collection", None) is not None: raise ValueError() super().__init__(**kwargs) self.url = url self.path = path - self._reader = vdirsyncer.storage.http.HttpStorage(url=url) + self._reader = vdirsyncer.storage.http.HttpStorage(url=url, connector=connector) self._reader._ignore_uids = False self._writer = SingleFileStorage(path=path) - def list(self, *a, **kw): - return self._reader.list(*a, **kw) + async def list(self, *a, **kw): + async for item in self._reader.list(*a, **kw): + yield item - def get(self, *a, **kw): - self.list() - return self._reader.get(*a, **kw) + async def get(self, *a, **kw): + await aiostream.stream.list(self.list()) + return await self._reader.get(*a, **kw) - def upload(self, *a, **kw): - return self._writer.upload(*a, **kw) + async def upload(self, *a, **kw): + return await self._writer.upload(*a, **kw) - def update(self, *a, **kw): - return self._writer.update(*a, **kw) + async def update(self, *a, **kw): + return await self._writer.update(*a, **kw) - def delete(self, *a, **kw): - return self._writer.delete(*a, **kw) + async def delete(self, *a, **kw): + return await self._writer.delete(*a, **kw) class TestHttpStorage(StorageTests): @@ -51,28 +54,37 @@ class TestHttpStorage(StorageTests): def setup_tmpdir(self, tmpdir, monkeypatch): self.tmpfile = str(tmpdir.ensure("collection.txt")) - def _request(method, url, *args, **kwargs): - assert method == "GET" - assert url == "http://localhost:123/collection.txt" - assert "vdirsyncer" in kwargs["headers"]["User-Agent"] - r = Response() - r.status_code = 200 - try: - with open(self.tmpfile, "rb") as f: - r._content = f.read() - except OSError: - r._content = b"" + def callback(url, headers, **kwargs): + """Read our tmpfile at request time. - r.headers["Content-Type"] = "text/calendar" - r.encoding = "utf-8" - return r + We can't just read this during test setup since the file get written to + during test execution. - monkeypatch.setattr(vdirsyncer.storage.http, "request", _request) + It might make sense to actually run a server serving the local file. + """ + assert headers["User-Agent"].startswith("vdirsyncer/") + + with open(self.tmpfile, "r") as f: + body = f.read() + + return CallbackResult( + status=200, + body=body, + headers={"Content-Type": "text/calendar; charset=utf-8"}, + ) + + with aioresponses() as m: + m.get("http://localhost:123/collection.txt", callback=callback, repeat=True) + yield @pytest.fixture - def get_storage_args(self): - def inner(collection=None): + def get_storage_args(self, aio_connector): + async def inner(collection=None): assert collection is None - return {"url": "http://localhost:123/collection.txt", "path": self.tmpfile} + return { + "url": "http://localhost:123/collection.txt", + "path": self.tmpfile, + "connector": aio_connector, + } return inner diff --git a/tests/storage/test_memory.py b/tests/storage/test_memory.py index 82db79c..41173c1 100644 --- a/tests/storage/test_memory.py +++ b/tests/storage/test_memory.py @@ -11,4 +11,7 @@ class TestMemoryStorage(StorageTests): @pytest.fixture def get_storage_args(self): - return lambda **kw: kw + async def inner(**args): + return args + + return inner diff --git a/tests/storage/test_singlefile.py b/tests/storage/test_singlefile.py index 58a95a5..d5449fd 100644 --- a/tests/storage/test_singlefile.py +++ b/tests/storage/test_singlefile.py @@ -11,10 +11,10 @@ class TestSingleFileStorage(StorageTests): @pytest.fixture def get_storage_args(self, tmpdir): - def inner(collection="test"): + async def inner(collection="test"): rv = {"path": str(tmpdir.join("%s.txt")), "collection": collection} if collection is not None: - rv = self.storage_class.create_collection(**rv) + rv = await self.storage_class.create_collection(**rv) return rv return inner diff --git a/tests/system/cli/test_utils.py b/tests/system/cli/test_utils.py index 20815cf..fe6722e 100644 --- a/tests/system/cli/test_utils.py +++ b/tests/system/cli/test_utils.py @@ -1,3 +1,5 @@ +import pytest + from vdirsyncer import exceptions from vdirsyncer.cli.utils import handle_cli_error from vdirsyncer.cli.utils import storage_instance_from_config @@ -15,11 +17,13 @@ def test_handle_cli_error(capsys): assert "ayy lmao" in err -def test_storage_instance_from_config(monkeypatch): - def lol(**kw): - assert kw == {"foo": "bar", "baz": 1} - return "OK" +@pytest.mark.asyncio +async def test_storage_instance_from_config(monkeypatch, aio_connector): + class Dummy: + def __init__(self, **kw): + assert kw == {"foo": "bar", "baz": 1} - monkeypatch.setitem(storage_names._storages, "lol", lol) + monkeypatch.setitem(storage_names._storages, "lol", Dummy) config = {"type": "lol", "foo": "bar", "baz": 1} - assert storage_instance_from_config(config) == "OK" + storage = await storage_instance_from_config(config, connector=aio_connector) + assert isinstance(storage, Dummy) diff --git a/tests/system/utils/test_main.py b/tests/system/utils/test_main.py index db3e70f..4403582 100644 --- a/tests/system/utils/test_main.py +++ b/tests/system/utils/test_main.py @@ -1,9 +1,9 @@ import logging import sys +import aiohttp import click_log import pytest -import requests from cryptography import x509 from cryptography.hazmat.primitives import hashes @@ -25,74 +25,77 @@ def test_get_storage_init_args(): assert not required -def test_request_ssl(): - with pytest.raises(requests.exceptions.ConnectionError) as excinfo: - http.request("GET", "https://self-signed.badssl.com/") - assert "certificate verify failed" in str(excinfo.value) +@pytest.mark.asyncio +async def test_request_ssl(): + async with aiohttp.ClientSession() as session: + with pytest.raises(aiohttp.ClientConnectorCertificateError) as excinfo: + await http.request( + "GET", + "https://self-signed.badssl.com/", + session=session, + ) + assert "certificate verify failed" in str(excinfo.value) - http.request("GET", "https://self-signed.badssl.com/", verify=False) + await http.request( + "GET", + "https://self-signed.badssl.com/", + verify=False, + session=session, + ) -def _fingerprints_broken(): - from pkg_resources import parse_version as ver - - broken_urllib3 = ver(requests.__version__) <= ver("2.5.1") - return broken_urllib3 - - -def fingerprint_of_cert(cert, hash=hashes.SHA256): +def fingerprint_of_cert(cert, hash=hashes.SHA256) -> str: return x509.load_pem_x509_certificate(cert.bytes()).fingerprint(hash()).hex() -@pytest.mark.skipif( - _fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529" -) -@pytest.mark.parametrize("hash_algorithm", [hashes.MD5, hashes.SHA256]) -def test_request_ssl_leaf_fingerprint(httpserver, localhost_cert, hash_algorithm): +@pytest.mark.parametrize("hash_algorithm", [hashes.SHA256]) +@pytest.mark.asyncio +async def test_request_ssl_leaf_fingerprint( + httpserver, + localhost_cert, + hash_algorithm, + aio_session, +): fingerprint = fingerprint_of_cert(localhost_cert.cert_chain_pems[0], hash_algorithm) + bogus = "".join(reversed(fingerprint)) # We have to serve something: httpserver.expect_request("/").respond_with_data("OK") - url = f"https://{httpserver.host}:{httpserver.port}/" + url = f"https://127.0.0.1:{httpserver.port}/" - http.request("GET", url, verify=False, verify_fingerprint=fingerprint) - with pytest.raises(requests.exceptions.ConnectionError) as excinfo: - http.request("GET", url, verify_fingerprint=fingerprint) + await http.request("GET", url, verify_fingerprint=fingerprint, session=aio_session) - with pytest.raises(requests.exceptions.ConnectionError) as excinfo: - http.request( - "GET", - url, - verify=False, - verify_fingerprint="".join(reversed(fingerprint)), - ) - assert "Fingerprints did not match" in str(excinfo.value) + with pytest.raises(aiohttp.ServerFingerprintMismatch): + await http.request("GET", url, verify_fingerprint=bogus, session=aio_session) -@pytest.mark.skipif( - _fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529" -) @pytest.mark.xfail(reason="Not implemented") -@pytest.mark.parametrize("hash_algorithm", [hashes.MD5, hashes.SHA256]) -def test_request_ssl_ca_fingerprint(httpserver, ca, hash_algorithm): +@pytest.mark.parametrize("hash_algorithm", [hashes.SHA256]) +@pytest.mark.asyncio +async def test_request_ssl_ca_fingerprints(httpserver, ca, hash_algorithm, aio_session): fingerprint = fingerprint_of_cert(ca.cert_pem) + bogus = "".join(reversed(fingerprint)) # We have to serve something: httpserver.expect_request("/").respond_with_data("OK") - url = f"https://{httpserver.host}:{httpserver.port}/" + url = f"https://127.0.0.1:{httpserver.port}/" - http.request("GET", url, verify=False, verify_fingerprint=fingerprint) - with pytest.raises(requests.exceptions.ConnectionError) as excinfo: - http.request("GET", url, verify_fingerprint=fingerprint) + await http.request( + "GET", + url, + verify=False, + verify_fingerprint=fingerprint, + session=aio_session, + ) - with pytest.raises(requests.exceptions.ConnectionError) as excinfo: + with pytest.raises(aiohttp.ServerFingerprintMismatch): http.request( "GET", url, verify=False, - verify_fingerprint="".join(reversed(fingerprint)), + verify_fingerprint=bogus, + session=aio_session, ) - assert "Fingerprints did not match" in str(excinfo.value) def test_open_graphical_browser(monkeypatch): diff --git a/tests/unit/cli/test_discover.py b/tests/unit/cli/test_discover.py index 983707c..b075315 100644 --- a/tests/unit/cli/test_discover.py +++ b/tests/unit/cli/test_discover.py @@ -1,3 +1,4 @@ +import aiostream import pytest from vdirsyncer.cli.discover import expand_collections @@ -132,34 +133,40 @@ missing = object() ), ], ) -def test_expand_collections(shortcuts, expected): +@pytest.mark.asyncio +async def test_expand_collections(shortcuts, expected): config_a = {"type": "fooboo", "storage_side": "a"} config_b = {"type": "fooboo", "storage_side": "b"} - def get_discovered_a(): + async def get_discovered_a(): return { "c1": {"type": "fooboo", "custom_arg": "a1", "collection": "c1"}, "c2": {"type": "fooboo", "custom_arg": "a2", "collection": "c2"}, "a3": {"type": "fooboo", "custom_arg": "a3", "collection": "a3"}, } - def get_discovered_b(): + async def get_discovered_b(): return { "c1": {"type": "fooboo", "custom_arg": "b1", "collection": "c1"}, "c2": {"type": "fooboo", "custom_arg": "b2", "collection": "c2"}, "b3": {"type": "fooboo", "custom_arg": "b3", "collection": "b3"}, } + async def handle_not_found(config, collection): + return missing + assert ( sorted( - expand_collections( - shortcuts, - config_a, - config_b, - get_discovered_a, - get_discovered_b, - lambda config, collection: missing, + await aiostream.stream.list( + expand_collections( + shortcuts, + config_a, + config_b, + get_discovered_a, + get_discovered_b, + handle_not_found, + ) ) ) == sorted(expected) diff --git a/tests/unit/sync/test_sync.py b/tests/unit/sync/test_sync.py index 7ee315a..e01535a 100644 --- a/tests/unit/sync/test_sync.py +++ b/tests/unit/sync/test_sync.py @@ -1,5 +1,7 @@ +import asyncio from copy import deepcopy +import aiostream import hypothesis.strategies as st import pytest from hypothesis import assume @@ -21,10 +23,10 @@ from vdirsyncer.sync.status import SqliteStatus from vdirsyncer.vobject import Item -def sync(a, b, status, *args, **kwargs): +async def sync(a, b, status, *args, **kwargs): new_status = SqliteStatus(":memory:") new_status.load_legacy_status(status) - rv = _sync(a, b, new_status, *args, **kwargs) + rv = await _sync(a, b, new_status, *args, **kwargs) status.clear() status.update(new_status.to_legacy_status()) return rv @@ -38,45 +40,49 @@ def items(s): return {x[1].raw for x in s.items.values()} -def test_irrelevant_status(): +@pytest.mark.asyncio +async def test_irrelevant_status(): a = MemoryStorage() b = MemoryStorage() status = {"1": ("1", 1234, "1.ics", 2345)} - sync(a, b, status) + await sync(a, b, status) assert not status assert not items(a) assert not items(b) -def test_missing_status(): +@pytest.mark.asyncio +async def test_missing_status(): a = MemoryStorage() b = MemoryStorage() status = {} item = Item("asdf") - a.upload(item) - b.upload(item) - sync(a, b, status) + await a.upload(item) + await b.upload(item) + await sync(a, b, status) assert len(status) == 1 assert items(a) == items(b) == {item.raw} -def test_missing_status_and_different_items(): +@pytest.mark.asyncio +async def test_missing_status_and_different_items(): a = MemoryStorage() b = MemoryStorage() status = {} item1 = Item("UID:1\nhaha") item2 = Item("UID:1\nhoho") - a.upload(item1) - b.upload(item2) + await a.upload(item1) + await b.upload(item2) with pytest.raises(SyncConflict): - sync(a, b, status) + await sync(a, b, status) assert not status - sync(a, b, status, conflict_resolution="a wins") + await sync(a, b, status, conflict_resolution="a wins") assert items(a) == items(b) == {item1.raw} -def test_read_only_and_prefetch(): +@pytest.mark.asyncio +async def test_read_only_and_prefetch(): a = MemoryStorage() b = MemoryStorage() b.read_only = True @@ -84,147 +90,154 @@ def test_read_only_and_prefetch(): status = {} item1 = Item("UID:1\nhaha") item2 = Item("UID:2\nhoho") - a.upload(item1) - a.upload(item2) + await a.upload(item1) + await a.upload(item2) - sync(a, b, status, force_delete=True) - sync(a, b, status, force_delete=True) + await sync(a, b, status, force_delete=True) + await sync(a, b, status, force_delete=True) assert not items(a) and not items(b) -def test_partial_sync_error(): +@pytest.mark.asyncio +async def test_partial_sync_error(): a = MemoryStorage() b = MemoryStorage() status = {} - a.upload(Item("UID:0")) + await a.upload(Item("UID:0")) b.read_only = True with pytest.raises(PartialSync): - sync(a, b, status, partial_sync="error") + await sync(a, b, status, partial_sync="error") -def test_partial_sync_ignore(): +@pytest.mark.asyncio +async def test_partial_sync_ignore(): a = MemoryStorage() b = MemoryStorage() status = {} item0 = Item("UID:0\nhehe") - a.upload(item0) - b.upload(item0) + await a.upload(item0) + await b.upload(item0) b.read_only = True item1 = Item("UID:1\nhaha") - a.upload(item1) + await a.upload(item1) - sync(a, b, status, partial_sync="ignore") - sync(a, b, status, partial_sync="ignore") + await sync(a, b, status, partial_sync="ignore") + await sync(a, b, status, partial_sync="ignore") assert items(a) == {item0.raw, item1.raw} assert items(b) == {item0.raw} -def test_partial_sync_ignore2(): +@pytest.mark.asyncio +async def test_partial_sync_ignore2(): a = MemoryStorage() b = MemoryStorage() status = {} - href, etag = a.upload(Item("UID:0")) + href, etag = await a.upload(Item("UID:0")) a.read_only = True - sync(a, b, status, partial_sync="ignore", force_delete=True) + await sync(a, b, status, partial_sync="ignore", force_delete=True) assert items(b) == items(a) == {"UID:0"} b.items.clear() - sync(a, b, status, partial_sync="ignore", force_delete=True) - sync(a, b, status, partial_sync="ignore", force_delete=True) + await sync(a, b, status, partial_sync="ignore", force_delete=True) + await sync(a, b, status, partial_sync="ignore", force_delete=True) assert items(a) == {"UID:0"} assert not b.items a.read_only = False - a.update(href, Item("UID:0\nupdated"), etag) + await a.update(href, Item("UID:0\nupdated"), etag) a.read_only = True - sync(a, b, status, partial_sync="ignore", force_delete=True) + await sync(a, b, status, partial_sync="ignore", force_delete=True) assert items(b) == items(a) == {"UID:0\nupdated"} -def test_upload_and_update(): +@pytest.mark.asyncio +async def test_upload_and_update(): a = MemoryStorage(fileext=".a") b = MemoryStorage(fileext=".b") status = {} item = Item("UID:1") # new item 1 in a - a.upload(item) - sync(a, b, status) + await a.upload(item) + await sync(a, b, status) assert items(b) == items(a) == {item.raw} item = Item("UID:1\nASDF:YES") # update of item 1 in b - b.update("1.b", item, b.get("1.b")[1]) - sync(a, b, status) + await b.update("1.b", item, (await b.get("1.b"))[1]) + await sync(a, b, status) assert items(b) == items(a) == {item.raw} item2 = Item("UID:2") # new item 2 in b - b.upload(item2) - sync(a, b, status) + await b.upload(item2) + await sync(a, b, status) assert items(b) == items(a) == {item.raw, item2.raw} item2 = Item("UID:2\nASDF:YES") # update of item 2 in a - a.update("2.a", item2, a.get("2.a")[1]) - sync(a, b, status) + await a.update("2.a", item2, (await a.get("2.a"))[1]) + await sync(a, b, status) assert items(b) == items(a) == {item.raw, item2.raw} -def test_deletion(): +@pytest.mark.asyncio +async def test_deletion(): a = MemoryStorage(fileext=".a") b = MemoryStorage(fileext=".b") status = {} item = Item("UID:1") - a.upload(item) + await a.upload(item) item2 = Item("UID:2") - a.upload(item2) - sync(a, b, status) - b.delete("1.b", b.get("1.b")[1]) - sync(a, b, status) + await a.upload(item2) + await sync(a, b, status) + await b.delete("1.b", (await b.get("1.b"))[1]) + await sync(a, b, status) assert items(a) == items(b) == {item2.raw} - a.upload(item) - sync(a, b, status) + await a.upload(item) + await sync(a, b, status) assert items(a) == items(b) == {item.raw, item2.raw} - a.delete("1.a", a.get("1.a")[1]) - sync(a, b, status) + await a.delete("1.a", (await a.get("1.a"))[1]) + await sync(a, b, status) assert items(a) == items(b) == {item2.raw} -def test_insert_hash(): +@pytest.mark.asyncio +async def test_insert_hash(): a = MemoryStorage() b = MemoryStorage() status = {} item = Item("UID:1") - href, etag = a.upload(item) - sync(a, b, status) + href, etag = await a.upload(item) + await sync(a, b, status) for d in status["1"]: del d["hash"] - a.update(href, Item("UID:1\nHAHA:YES"), etag) - sync(a, b, status) + await a.update(href, Item("UID:1\nHAHA:YES"), etag) + await sync(a, b, status) assert "hash" in status["1"][0] and "hash" in status["1"][1] -def test_already_synced(): +@pytest.mark.asyncio +async def test_already_synced(): a = MemoryStorage(fileext=".a") b = MemoryStorage(fileext=".b") item = Item("UID:1") - a.upload(item) - b.upload(item) + await a.upload(item) + await b.upload(item) status = { "1": ( - {"href": "1.a", "hash": item.hash, "etag": a.get("1.a")[1]}, - {"href": "1.b", "hash": item.hash, "etag": b.get("1.b")[1]}, + {"href": "1.a", "hash": item.hash, "etag": (await a.get("1.a"))[1]}, + {"href": "1.b", "hash": item.hash, "etag": (await b.get("1.b"))[1]}, ) } old_status = deepcopy(status) @@ -233,69 +246,73 @@ def test_already_synced(): ) for _ in (1, 2): - sync(a, b, status) + await sync(a, b, status) assert status == old_status assert items(a) == items(b) == {item.raw} @pytest.mark.parametrize("winning_storage", "ab") -def test_conflict_resolution_both_etags_new(winning_storage): +@pytest.mark.asyncio +async def test_conflict_resolution_both_etags_new(winning_storage): a = MemoryStorage() b = MemoryStorage() item = Item("UID:1") - href_a, etag_a = a.upload(item) - href_b, etag_b = b.upload(item) + href_a, etag_a = await a.upload(item) + href_b, etag_b = await b.upload(item) status = {} - sync(a, b, status) + await sync(a, b, status) assert status item_a = Item("UID:1\nitem a") item_b = Item("UID:1\nitem b") - a.update(href_a, item_a, etag_a) - b.update(href_b, item_b, etag_b) + await a.update(href_a, item_a, etag_a) + await b.update(href_b, item_b, etag_b) with pytest.raises(SyncConflict): - sync(a, b, status) - sync(a, b, status, conflict_resolution=f"{winning_storage} wins") + await sync(a, b, status) + await sync(a, b, status, conflict_resolution=f"{winning_storage} wins") assert ( items(a) == items(b) == {item_a.raw if winning_storage == "a" else item_b.raw} ) -def test_updated_and_deleted(): +@pytest.mark.asyncio +async def test_updated_and_deleted(): a = MemoryStorage() b = MemoryStorage() - href_a, etag_a = a.upload(Item("UID:1")) + href_a, etag_a = await a.upload(Item("UID:1")) status = {} - sync(a, b, status, force_delete=True) + await sync(a, b, status, force_delete=True) - ((href_b, etag_b),) = b.list() - b.delete(href_b, etag_b) + ((href_b, etag_b),) = await aiostream.stream.list(b.list()) + await b.delete(href_b, etag_b) updated = Item("UID:1\nupdated") - a.update(href_a, updated, etag_a) - sync(a, b, status, force_delete=True) + await a.update(href_a, updated, etag_a) + await sync(a, b, status, force_delete=True) assert items(a) == items(b) == {updated.raw} -def test_conflict_resolution_invalid_mode(): +@pytest.mark.asyncio +async def test_conflict_resolution_invalid_mode(): a = MemoryStorage() b = MemoryStorage() item_a = Item("UID:1\nitem a") item_b = Item("UID:1\nitem b") - a.upload(item_a) - b.upload(item_b) + await a.upload(item_a) + await b.upload(item_b) with pytest.raises(ValueError): - sync(a, b, {}, conflict_resolution="yolo") + await sync(a, b, {}, conflict_resolution="yolo") -def test_conflict_resolution_new_etags_without_changes(): +@pytest.mark.asyncio +async def test_conflict_resolution_new_etags_without_changes(): a = MemoryStorage() b = MemoryStorage() item = Item("UID:1") - href_a, etag_a = a.upload(item) - href_b, etag_b = b.upload(item) + href_a, etag_a = await a.upload(item) + href_b, etag_b = await b.upload(item) status = {"1": (href_a, "BOGUS_a", href_b, "BOGUS_b")} - sync(a, b, status) + await sync(a, b, status) ((ident, (status_a, status_b)),) = status.items() assert ident == "1" @@ -305,7 +322,8 @@ def test_conflict_resolution_new_etags_without_changes(): assert status_b["etag"] == etag_b -def test_uses_get_multi(monkeypatch): +@pytest.mark.asyncio +async def test_uses_get_multi(monkeypatch): def breakdown(*a, **kw): raise AssertionError("Expected use of get_multi") @@ -313,11 +331,11 @@ def test_uses_get_multi(monkeypatch): old_get = MemoryStorage.get - def get_multi(self, hrefs): + async def get_multi(self, hrefs): hrefs = list(hrefs) get_multi_calls.append(hrefs) for href in hrefs: - item, etag = old_get(self, href) + item, etag = await old_get(self, href) yield href, item, etag monkeypatch.setattr(MemoryStorage, "get", breakdown) @@ -326,72 +344,77 @@ def test_uses_get_multi(monkeypatch): a = MemoryStorage() b = MemoryStorage() item = Item("UID:1") - expected_href, etag = a.upload(item) + expected_href, etag = await a.upload(item) - sync(a, b, {}) + await sync(a, b, {}) assert get_multi_calls == [[expected_href]] -def test_empty_storage_dataloss(): +@pytest.mark.asyncio +async def test_empty_storage_dataloss(): a = MemoryStorage() b = MemoryStorage() - a.upload(Item("UID:1")) - a.upload(Item("UID:2")) + await a.upload(Item("UID:1")) + await a.upload(Item("UID:2")) status = {} - sync(a, b, status) + await sync(a, b, status) with pytest.raises(StorageEmpty): - sync(MemoryStorage(), b, status) + await sync(MemoryStorage(), b, status) with pytest.raises(StorageEmpty): - sync(a, MemoryStorage(), status) + await sync(a, MemoryStorage(), status) -def test_no_uids(): +@pytest.mark.asyncio +async def test_no_uids(): a = MemoryStorage() b = MemoryStorage() - a.upload(Item("ASDF")) - b.upload(Item("FOOBAR")) + await a.upload(Item("ASDF")) + await b.upload(Item("FOOBAR")) status = {} - sync(a, b, status) + await sync(a, b, status) assert items(a) == items(b) == {"ASDF", "FOOBAR"} -def test_changed_uids(): +@pytest.mark.asyncio +async def test_changed_uids(): a = MemoryStorage() b = MemoryStorage() - href_a, etag_a = a.upload(Item("UID:A-ONE")) - href_b, etag_b = b.upload(Item("UID:B-ONE")) + href_a, etag_a = await a.upload(Item("UID:A-ONE")) + href_b, etag_b = await b.upload(Item("UID:B-ONE")) status = {} - sync(a, b, status) + await sync(a, b, status) - a.update(href_a, Item("UID:A-TWO"), etag_a) - sync(a, b, status) + await a.update(href_a, Item("UID:A-TWO"), etag_a) + await sync(a, b, status) -def test_both_readonly(): +@pytest.mark.asyncio +async def test_both_readonly(): a = MemoryStorage(read_only=True) b = MemoryStorage(read_only=True) assert a.read_only assert b.read_only status = {} with pytest.raises(BothReadOnly): - sync(a, b, status) + await sync(a, b, status) -def test_partial_sync_revert(): +@pytest.mark.asyncio +async def test_partial_sync_revert(): a = MemoryStorage(instance_name="a") b = MemoryStorage(instance_name="b") status = {} - a.upload(Item("UID:1")) - b.upload(Item("UID:2")) + await a.upload(Item("UID:1")) + await b.upload(Item("UID:2")) b.read_only = True - sync(a, b, status, partial_sync="revert") + await sync(a, b, status, partial_sync="revert") assert len(status) == 2 assert items(a) == {"UID:1", "UID:2"} assert items(b) == {"UID:2"} - sync(a, b, status, partial_sync="revert") + await sync(a, b, status, partial_sync="revert") assert len(status) == 1 assert items(a) == {"UID:2"} assert items(b) == {"UID:2"} @@ -399,37 +422,39 @@ def test_partial_sync_revert(): # Check that updates get reverted a.items[next(iter(a.items))] = ("foo", Item("UID:2\nupdated")) assert items(a) == {"UID:2\nupdated"} - sync(a, b, status, partial_sync="revert") + await sync(a, b, status, partial_sync="revert") assert len(status) == 1 assert items(a) == {"UID:2\nupdated"} - sync(a, b, status, partial_sync="revert") + await sync(a, b, status, partial_sync="revert") assert items(a) == {"UID:2"} # Check that deletions get reverted a.items.clear() - sync(a, b, status, partial_sync="revert", force_delete=True) - sync(a, b, status, partial_sync="revert", force_delete=True) + await sync(a, b, status, partial_sync="revert", force_delete=True) + await sync(a, b, status, partial_sync="revert", force_delete=True) assert items(a) == {"UID:2"} @pytest.mark.parametrize("sync_inbetween", (True, False)) -def test_ident_conflict(sync_inbetween): +@pytest.mark.asyncio +async def test_ident_conflict(sync_inbetween): a = MemoryStorage() b = MemoryStorage() status = {} - href_a, etag_a = a.upload(Item("UID:aaa")) - href_b, etag_b = a.upload(Item("UID:bbb")) + href_a, etag_a = await a.upload(Item("UID:aaa")) + href_b, etag_b = await a.upload(Item("UID:bbb")) if sync_inbetween: - sync(a, b, status) + await sync(a, b, status) - a.update(href_a, Item("UID:xxx"), etag_a) - a.update(href_b, Item("UID:xxx"), etag_b) + await a.update(href_a, Item("UID:xxx"), etag_a) + await a.update(href_b, Item("UID:xxx"), etag_b) with pytest.raises(IdentConflict): - sync(a, b, status) + await sync(a, b, status) -def test_moved_href(): +@pytest.mark.asyncio +async def test_moved_href(): """ Concrete application: ppl_ stores contact aliases in filenames, which means item's hrefs get changed. Vdirsyncer doesn't synchronize this data, but @@ -440,8 +465,8 @@ def test_moved_href(): a = MemoryStorage() b = MemoryStorage() status = {} - href, etag = a.upload(Item("UID:haha")) - sync(a, b, status) + href, etag = await a.upload(Item("UID:haha")) + await sync(a, b, status) b.items["lol"] = b.items.pop("haha") @@ -451,7 +476,7 @@ def test_moved_href(): # No actual sync actions a.delete = a.update = a.upload = b.delete = b.update = b.upload = blow_up - sync(a, b, status) + await sync(a, b, status) assert len(status) == 1 assert items(a) == items(b) == {"UID:haha"} assert status["haha"][1]["href"] == "lol" @@ -460,12 +485,13 @@ def test_moved_href(): # Further sync should be a noop. Not even prefetching should occur. b.get_multi = blow_up - sync(a, b, status) + await sync(a, b, status) assert old_status == status assert items(a) == items(b) == {"UID:haha"} -def test_bogus_etag_change(): +@pytest.mark.asyncio +async def test_bogus_etag_change(): """Assert that sync algorithm is resilient against etag changes if content didn\'t change. @@ -475,27 +501,33 @@ def test_bogus_etag_change(): a = MemoryStorage() b = MemoryStorage() status = {} - href_a, etag_a = a.upload(Item("UID:ASDASD")) - sync(a, b, status) - assert len(status) == len(list(a.list())) == len(list(b.list())) == 1 + href_a, etag_a = await a.upload(Item("UID:ASDASD")) + await sync(a, b, status) + assert ( + len(status) + == len(await aiostream.stream.list(a.list())) + == len(await aiostream.stream.list(b.list())) + == 1 + ) - ((href_b, etag_b),) = b.list() - a.update(href_a, Item("UID:ASDASD"), etag_a) - b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b) + ((href_b, etag_b),) = await aiostream.stream.list(b.list()) + await a.update(href_a, Item("UID:ASDASD"), etag_a) + await b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b) b.delete = b.update = b.upload = blow_up - sync(a, b, status) + await sync(a, b, status) assert len(status) == 1 assert items(a) == items(b) == {"UID:ASDASD\nACTUALCHANGE:YES"} -def test_unicode_hrefs(): +@pytest.mark.asyncio +async def test_unicode_hrefs(): a = MemoryStorage() b = MemoryStorage() status = {} - href, etag = a.upload(Item("UID:äää")) - sync(a, b, status) + href, etag = await a.upload(Item("UID:äää")) + await sync(a, b, status) class ActionIntentionallyFailed(Exception): @@ -511,11 +543,12 @@ class SyncMachine(RuleBasedStateMachine): Storage = Bundle("storage") @rule(target=Storage, flaky_etags=st.booleans(), null_etag_on_upload=st.booleans()) + @pytest.mark.asyncio def newstorage(self, flaky_etags, null_etag_on_upload): s = MemoryStorage() if flaky_etags: - def get(href): + async def get(href): old_etag, item = s.items[href] etag = _random_string() s.items[href] = etag, item @@ -526,8 +559,15 @@ class SyncMachine(RuleBasedStateMachine): if null_etag_on_upload: _old_upload = s.upload _old_update = s.update - s.upload = lambda item: (_old_upload(item)[0], "NULL") - s.update = lambda h, i, e: _old_update(h, i, e) and "NULL" + + async def upload(item): + return ((await _old_upload(item)))[0], "NULL" + + async def update(href, item, etag): + return await _old_update(href, item, etag) and "NULL" + + s.upload = upload + s.update = update return s @@ -547,11 +587,11 @@ class SyncMachine(RuleBasedStateMachine): _old_upload = s.upload _old_update = s.update - def upload(item): - return _old_upload(item)[0], None + async def upload(item): + return (await _old_upload(item))[0], None - def update(href, item, etag): - _old_update(href, item, etag) + async def update(href, item, etag): + return await _old_update(href, item, etag) s.upload = upload s.update = update @@ -590,66 +630,73 @@ class SyncMachine(RuleBasedStateMachine): with_error_callback, partial_sync, ): - assume(a is not b) - old_items_a = items(a) - old_items_b = items(b) + async def inner(): + assume(a is not b) + old_items_a = items(a) + old_items_b = items(b) - a.instance_name = "a" - b.instance_name = "b" + a.instance_name = "a" + b.instance_name = "b" - errors = [] + errors = [] - if with_error_callback: - error_callback = errors.append - else: - error_callback = None + if with_error_callback: + error_callback = errors.append + else: + error_callback = None - try: - # If one storage is read-only, double-sync because changes don't - # get reverted immediately. - for _ in range(2 if a.read_only or b.read_only else 1): - sync( - a, - b, - status, - force_delete=force_delete, - conflict_resolution=conflict_resolution, - error_callback=error_callback, - partial_sync=partial_sync, + try: + # If one storage is read-only, double-sync because changes don't + # get reverted immediately. + for _ in range(2 if a.read_only or b.read_only else 1): + await sync( + a, + b, + status, + force_delete=force_delete, + conflict_resolution=conflict_resolution, + error_callback=error_callback, + partial_sync=partial_sync, + ) + + for e in errors: + raise e + except PartialSync: + assert partial_sync == "error" + except ActionIntentionallyFailed: + pass + except BothReadOnly: + assert a.read_only and b.read_only + assume(False) + except StorageEmpty: + if force_delete: + raise + else: + not_a = not await aiostream.stream.list(a.list()) + not_b = not await aiostream.stream.list(b.list()) + assert not_a or not_b + else: + items_a = items(a) + items_b = items(b) + + assert items_a == items_b or partial_sync == "ignore" + assert items_a == old_items_a or not a.read_only + assert items_b == old_items_b or not b.read_only + + assert ( + set(a.items) | set(b.items) == set(status) + or partial_sync == "ignore" ) - for e in errors: - raise e - except PartialSync: - assert partial_sync == "error" - except ActionIntentionallyFailed: - pass - except BothReadOnly: - assert a.read_only and b.read_only - assume(False) - except StorageEmpty: - if force_delete: - raise - else: - assert not list(a.list()) or not list(b.list()) - else: - items_a = items(a) - items_b = items(b) - - assert items_a == items_b or partial_sync == "ignore" - assert items_a == old_items_a or not a.read_only - assert items_b == old_items_b or not b.read_only - - assert ( - set(a.items) | set(b.items) == set(status) or partial_sync == "ignore" - ) + asyncio.run(inner()) TestSyncMachine = SyncMachine.TestCase @pytest.mark.parametrize("error_callback", [True, False]) -def test_rollback(error_callback): +@pytest.mark.asyncio +async def test_rollback(error_callback): a = MemoryStorage() b = MemoryStorage() status = {} @@ -662,7 +709,7 @@ def test_rollback(error_callback): if error_callback: errors = [] - sync( + await sync( a, b, status=status, @@ -677,16 +724,22 @@ def test_rollback(error_callback): assert status["1"] else: with pytest.raises(ActionIntentionallyFailed): - sync(a, b, status=status, conflict_resolution="a wins") + await sync(a, b, status=status, conflict_resolution="a wins") -def test_duplicate_hrefs(): +@pytest.mark.asyncio +async def test_duplicate_hrefs(): a = MemoryStorage() b = MemoryStorage() - a.list = lambda: [("a", "a")] * 3 + + async def fake_list(): + for item in [("a", "a")] * 3: + yield item + + a.list = fake_list a.items["a"] = ("a", Item("UID:a")) status = {} - sync(a, b, status) + await sync(a, b, status) with pytest.raises(AssertionError): - sync(a, b, status) + await sync(a, b, status) diff --git a/tests/unit/test_metasync.py b/tests/unit/test_metasync.py index c32fbe1..5eb84cd 100644 --- a/tests/unit/test_metasync.py +++ b/tests/unit/test_metasync.py @@ -12,105 +12,122 @@ from vdirsyncer.storage.base import normalize_meta_value from vdirsyncer.storage.memory import MemoryStorage -def test_irrelevant_status(): +@pytest.mark.asyncio +async def test_irrelevant_status(): a = MemoryStorage() b = MemoryStorage() status = {"foo": "bar"} - metasync(a, b, status, keys=()) + await metasync(a, b, status, keys=()) assert not status -def test_basic(monkeypatch): +@pytest.mark.asyncio +async def test_basic(monkeypatch): a = MemoryStorage() b = MemoryStorage() status = {} - a.set_meta("foo", "bar") - metasync(a, b, status, keys=["foo"]) - assert a.get_meta("foo") == b.get_meta("foo") == "bar" + await a.set_meta("foo", "bar") + await metasync(a, b, status, keys=["foo"]) + assert await a.get_meta("foo") == await b.get_meta("foo") == "bar" - a.set_meta("foo", "baz") - metasync(a, b, status, keys=["foo"]) - assert a.get_meta("foo") == b.get_meta("foo") == "baz" + await a.set_meta("foo", "baz") + await metasync(a, b, status, keys=["foo"]) + assert await a.get_meta("foo") == await b.get_meta("foo") == "baz" monkeypatch.setattr(a, "set_meta", blow_up) monkeypatch.setattr(b, "set_meta", blow_up) - metasync(a, b, status, keys=["foo"]) - assert a.get_meta("foo") == b.get_meta("foo") == "baz" + await metasync(a, b, status, keys=["foo"]) + assert await a.get_meta("foo") == await b.get_meta("foo") == "baz" monkeypatch.undo() monkeypatch.undo() - b.set_meta("foo", None) - metasync(a, b, status, keys=["foo"]) - assert not a.get_meta("foo") and not b.get_meta("foo") + await b.set_meta("foo", None) + await metasync(a, b, status, keys=["foo"]) + assert not await a.get_meta("foo") and not await b.get_meta("foo") @pytest.fixture -def conflict_state(request): +@pytest.mark.asyncio +async def conflict_state(request, event_loop): a = MemoryStorage() b = MemoryStorage() status = {} - a.set_meta("foo", "bar") - b.set_meta("foo", "baz") + await a.set_meta("foo", "bar") + await b.set_meta("foo", "baz") def cleanup(): - assert a.get_meta("foo") == "bar" - assert b.get_meta("foo") == "baz" - assert not status + async def do_cleanup(): + assert await a.get_meta("foo") == "bar" + assert await b.get_meta("foo") == "baz" + assert not status + + event_loop.run_until_complete(do_cleanup()) request.addfinalizer(cleanup) return a, b, status -def test_conflict(conflict_state): +@pytest.mark.asyncio +async def test_conflict(conflict_state): a, b, status = conflict_state with pytest.raises(MetaSyncConflict): - metasync(a, b, status, keys=["foo"]) + await metasync(a, b, status, keys=["foo"]) -def test_invalid_conflict_resolution(conflict_state): +@pytest.mark.asyncio +async def test_invalid_conflict_resolution(conflict_state): a, b, status = conflict_state with pytest.raises(UserError) as excinfo: - metasync(a, b, status, keys=["foo"], conflict_resolution="foo") + await metasync(a, b, status, keys=["foo"], conflict_resolution="foo") assert "Invalid conflict resolution setting" in str(excinfo.value) -def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch): +@pytest.mark.asyncio +async def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch): a, b, status = conflict_state warnings = [] monkeypatch.setattr(logger, "warning", warnings.append) with pytest.raises(MetaSyncConflict): - metasync(a, b, status, keys=["foo"], conflict_resolution=lambda *a, **kw: None) + await metasync( + a, + b, + status, + keys=["foo"], + conflict_resolution=lambda *a, **kw: None, + ) assert warnings == ["Custom commands don't work on metasync."] -def test_conflict_same_content(): +@pytest.mark.asyncio +async def test_conflict_same_content(): a = MemoryStorage() b = MemoryStorage() status = {} - a.set_meta("foo", "bar") - b.set_meta("foo", "bar") + await a.set_meta("foo", "bar") + await b.set_meta("foo", "bar") - metasync(a, b, status, keys=["foo"]) - assert a.get_meta("foo") == b.get_meta("foo") == status["foo"] == "bar" + await metasync(a, b, status, keys=["foo"]) + assert await a.get_meta("foo") == await b.get_meta("foo") == status["foo"] == "bar" @pytest.mark.parametrize("wins", "ab") -def test_conflict_x_wins(wins): +@pytest.mark.asyncio +async def test_conflict_x_wins(wins): a = MemoryStorage() b = MemoryStorage() status = {} - a.set_meta("foo", "bar") - b.set_meta("foo", "baz") + await a.set_meta("foo", "bar") + await b.set_meta("foo", "baz") - metasync( + await metasync( a, b, status, @@ -119,8 +136,8 @@ def test_conflict_x_wins(wins): ) assert ( - a.get_meta("foo") - == b.get_meta("foo") + await a.get_meta("foo") + == await b.get_meta("foo") == status["foo"] == ("bar" if wins == "a" else "baz") ) @@ -148,7 +165,8 @@ metadata = st.dictionaries(keys, values) keys={"0"}, conflict_resolution="a wins", ) -def test_fuzzing(a, b, status, keys, conflict_resolution): +@pytest.mark.asyncio +async def test_fuzzing(a, b, status, keys, conflict_resolution): def _get_storage(m, instance_name): s = MemoryStorage(instance_name=instance_name) s.metadata = m @@ -159,13 +177,13 @@ def test_fuzzing(a, b, status, keys, conflict_resolution): winning_storage = a if conflict_resolution == "a wins" else b expected_values = { - key: winning_storage.get_meta(key) for key in keys if key not in status + key: await winning_storage.get_meta(key) for key in keys if key not in status } - metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution) + await metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution) for key in keys: s = status.get(key, "") - assert a.get_meta(key) == b.get_meta(key) == s + assert await a.get_meta(key) == await b.get_meta(key) == s if expected_values.get(key, "") and s: assert s == expected_values[key] diff --git a/tests/unit/test_repair.py b/tests/unit/test_repair.py index 7d35ff9..31f51a8 100644 --- a/tests/unit/test_repair.py +++ b/tests/unit/test_repair.py @@ -1,3 +1,4 @@ +import aiostream import pytest from hypothesis import given from hypothesis import HealthCheck @@ -15,37 +16,42 @@ from vdirsyncer.vobject import Item @given(uid=uid_strategy) # Using the random module for UIDs: @settings(suppress_health_check=HealthCheck.all()) -def test_repair_uids(uid): +@pytest.mark.asyncio +async def test_repair_uids(uid): s = MemoryStorage() s.items = { "one": ("asdf", Item(f"BEGIN:VCARD\nFN:Hans\nUID:{uid}\nEND:VCARD")), "two": ("asdf", Item(f"BEGIN:VCARD\nFN:Peppi\nUID:{uid}\nEND:VCARD")), } - uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()] + uid1, uid2 = [(await s.get(href))[0].uid async for href, etag in s.list()] assert uid1 == uid2 - repair_storage(s, repair_unsafe_uid=False) + await repair_storage(s, repair_unsafe_uid=False) - uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()] + uid1, uid2 = [ + (await s.get(href))[0].uid + for href, etag in await aiostream.stream.list(s.list()) + ] assert uid1 != uid2 @given(uid=uid_strategy.filter(lambda x: not href_safe(x))) # Using the random module for UIDs: @settings(suppress_health_check=HealthCheck.all()) -def test_repair_unsafe_uids(uid): +@pytest.mark.asyncio +async def test_repair_unsafe_uids(uid): s = MemoryStorage() item = Item(f"BEGIN:VCARD\nUID:{uid}\nEND:VCARD") - href, etag = s.upload(item) - assert s.get(href)[0].uid == uid + href, etag = await s.upload(item) + assert (await s.get(href))[0].uid == uid assert not href_safe(uid) - repair_storage(s, repair_unsafe_uid=True) + await repair_storage(s, repair_unsafe_uid=True) - new_href = list(s.list())[0][0] + new_href = (await aiostream.stream.list(s.list()))[0][0] assert href_safe(new_href) - newuid = s.get(new_href)[0].uid + newuid = (await s.get(new_href))[0].uid assert href_safe(newuid) diff --git a/vdirsyncer/cli/__init__.py b/vdirsyncer/cli/__init__.py index a1a8191..2d914d3 100644 --- a/vdirsyncer/cli/__init__.py +++ b/vdirsyncer/cli/__init__.py @@ -1,8 +1,10 @@ +import asyncio import functools import json import logging import sys +import aiohttp import click import click_log @@ -124,17 +126,26 @@ def sync(ctx, collections, force_delete): """ from .tasks import prepare_pair, sync_collection - for pair_name, collections in collections: - for collection, config in prepare_pair( - pair_name=pair_name, - collections=collections, - config=ctx.config, - ): - sync_collection( - collection=collection, - general=config, - force_delete=force_delete, - ) + async def main(collections): + conn = aiohttp.TCPConnector(limit_per_host=16) + + for pair_name, collections in collections: + async for collection, config in prepare_pair( + pair_name=pair_name, + collections=collections, + config=ctx.config, + connector=conn, + ): + await sync_collection( + collection=collection, + general=config, + force_delete=force_delete, + connector=conn, + ) + + await conn.close() + + asyncio.run(main(collections)) @app.command() @@ -149,13 +160,31 @@ def metasync(ctx, collections): """ from .tasks import prepare_pair, metasync_collection - for pair_name, collections in collections: - for collection, config in prepare_pair( - pair_name=pair_name, - collections=collections, - config=ctx.config, - ): - metasync_collection(collection=collection, general=config) + async def main(collections): + conn = aiohttp.TCPConnector(limit_per_host=16) + + for pair_name, collections in collections: + collections = prepare_pair( + pair_name=pair_name, + collections=collections, + config=ctx.config, + connector=conn, + ) + + await asyncio.gather( + *[ + metasync_collection( + collection=collection, + general=config, + connector=conn, + ) + async for collection, config in collections + ] + ) + + await conn.close() + + asyncio.run(main(collections)) @app.command() @@ -178,15 +207,23 @@ def discover(ctx, pairs, list): config = ctx.config - for pair_name in pairs or config.pairs: - pair = config.get_pair(pair_name) + async def main(): + conn = aiohttp.TCPConnector(limit_per_host=16) - discover_collections( - status_path=config.general["status_path"], - pair=pair, - from_cache=False, - list_collections=list, - ) + for pair_name in pairs or config.pairs: + pair = config.get_pair(pair_name) + + await discover_collections( + status_path=config.general["status_path"], + pair=pair, + from_cache=False, + list_collections=list, + connector=conn, + ) + + await conn.close() + + asyncio.run(main()) @app.command() @@ -225,7 +262,18 @@ def repair(ctx, collection, repair_unsafe_uid): "turn off other client's synchronization features." ) click.confirm("Do you want to continue?", abort=True) - repair_collection(ctx.config, collection, repair_unsafe_uid=repair_unsafe_uid) + + async def main(): + conn = aiohttp.TCPConnector(limit_per_host=16) + await repair_collection( + ctx.config, + collection, + repair_unsafe_uid=repair_unsafe_uid, + connector=conn, + ) + await conn.close() + + asyncio.run(main()) @app.command() diff --git a/vdirsyncer/cli/discover.py b/vdirsyncer/cli/discover.py index e960217..877a10c 100644 --- a/vdirsyncer/cli/discover.py +++ b/vdirsyncer/cli/discover.py @@ -1,10 +1,13 @@ +import asyncio import hashlib import json import logging import sys +import aiohttp +import aiostream + from .. import exceptions -from ..utils import cached_property from .utils import handle_collection_not_found from .utils import handle_storage_init_error from .utils import load_status @@ -35,7 +38,14 @@ def _get_collections_cache_key(pair): return m.hexdigest() -def collections_for_pair(status_path, pair, from_cache=True, list_collections=False): +async def collections_for_pair( + status_path, + pair, + from_cache=True, + list_collections=False, + *, + connector: aiohttp.TCPConnector, +): """Determine all configured collections for a given pair. Takes care of shortcut expansion and result caching. @@ -67,16 +77,24 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa logger.info("Discovering collections for pair {}".format(pair.name)) - a_discovered = _DiscoverResult(pair.config_a) - b_discovered = _DiscoverResult(pair.config_b) + a_discovered = _DiscoverResult(pair.config_a, connector=connector) + b_discovered = _DiscoverResult(pair.config_b, connector=connector) if list_collections: - _print_collections(pair.config_a["instance_name"], a_discovered.get_self) - _print_collections(pair.config_b["instance_name"], b_discovered.get_self) + await _print_collections( + pair.config_a["instance_name"], + a_discovered.get_self, + connector=connector, + ) + await _print_collections( + pair.config_b["instance_name"], + b_discovered.get_self, + connector=connector, + ) # We have to use a list here because the special None/null value would get # mangled to string (because JSON objects always have string keys). - rv = list( + rv = await aiostream.stream.list( expand_collections( shortcuts=pair.collections, config_a=pair.config_a, @@ -87,7 +105,7 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa ) ) - _sanity_check_collections(rv) + await _sanity_check_collections(rv, connector=connector) save_status( status_path, @@ -103,10 +121,14 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa return rv -def _sanity_check_collections(collections): +async def _sanity_check_collections(collections, *, connector): + tasks = [] + for _, (a_args, b_args) in collections: - storage_instance_from_config(a_args) - storage_instance_from_config(b_args) + tasks.append(storage_instance_from_config(a_args, connector=connector)) + tasks.append(storage_instance_from_config(b_args, connector=connector)) + + await asyncio.gather(*tasks) def _compress_collections_cache(collections, config_a, config_b): @@ -134,17 +156,28 @@ def _expand_collections_cache(collections, config_a, config_b): class _DiscoverResult: - def __init__(self, config): + def __init__(self, config, *, connector): self._cls, _ = storage_class_from_config(config) - self._config = config - def get_self(self): + if self._cls.__name__ in [ + "CardDAVStorage", + "CalDAVStorage", + "GoogleCalendarStorage", + ]: + assert connector is not None + config["connector"] = connector + + self._config = config + self._discovered = None + + async def get_self(self): + if self._discovered is None: + self._discovered = await self._discover() return self._discovered - @cached_property - def _discovered(self): + async def _discover(self): try: - discovered = list(self._cls.discover(**self._config)) + discovered = await aiostream.stream.list(self._cls.discover(**self._config)) except NotImplementedError: return {} except Exception: @@ -158,7 +191,7 @@ class _DiscoverResult: return rv -def expand_collections( +async def expand_collections( shortcuts, config_a, config_b, @@ -173,9 +206,9 @@ def expand_collections( for shortcut in shortcuts: if shortcut == "from a": - collections = get_a_discovered() + collections = await get_a_discovered() elif shortcut == "from b": - collections = get_b_discovered() + collections = await get_b_discovered() else: collections = [shortcut] @@ -189,17 +222,23 @@ def expand_collections( continue handled_collections.add(collection) - a_args = _collection_from_discovered( - get_a_discovered, collection_a, config_a, _handle_collection_not_found + a_args = await _collection_from_discovered( + get_a_discovered, + collection_a, + config_a, + _handle_collection_not_found, ) - b_args = _collection_from_discovered( - get_b_discovered, collection_b, config_b, _handle_collection_not_found + b_args = await _collection_from_discovered( + get_b_discovered, + collection_b, + config_b, + _handle_collection_not_found, ) yield collection, (a_args, b_args) -def _collection_from_discovered( +async def _collection_from_discovered( get_discovered, collection, config, _handle_collection_not_found ): if collection is None: @@ -208,14 +247,19 @@ def _collection_from_discovered( return args try: - return get_discovered()[collection] + return (await get_discovered())[collection] except KeyError: - return _handle_collection_not_found(config, collection) + return await _handle_collection_not_found(config, collection) -def _print_collections(instance_name, get_discovered): +async def _print_collections( + instance_name: str, + get_discovered, + *, + connector: aiohttp.TCPConnector, +): try: - discovered = get_discovered() + discovered = await get_discovered() except exceptions.UserError: raise except Exception: @@ -238,8 +282,12 @@ def _print_collections(instance_name, get_discovered): args["instance_name"] = instance_name try: - storage = storage_instance_from_config(args, create=False) - displayname = storage.get_meta("displayname") + storage = await storage_instance_from_config( + args, + create=False, + connector=connector, + ) + displayname = await storage.get_meta("displayname") except Exception: displayname = "" diff --git a/vdirsyncer/cli/tasks.py b/vdirsyncer/cli/tasks.py index 3799dc1..f27c0d1 100644 --- a/vdirsyncer/cli/tasks.py +++ b/vdirsyncer/cli/tasks.py @@ -1,5 +1,7 @@ import json +import aiohttp + from .. import exceptions from .. import sync from .config import CollectionConfig @@ -15,11 +17,15 @@ from .utils import manage_sync_status from .utils import save_status -def prepare_pair(pair_name, collections, config): +async def prepare_pair(pair_name, collections, config, *, connector): pair = config.get_pair(pair_name) all_collections = dict( - collections_for_pair(status_path=config.general["status_path"], pair=pair) + await collections_for_pair( + status_path=config.general["status_path"], + pair=pair, + connector=connector, + ) ) for collection_name in collections or all_collections: @@ -37,15 +43,21 @@ def prepare_pair(pair_name, collections, config): yield collection, config.general -def sync_collection(collection, general, force_delete): +async def sync_collection( + collection, + general, + force_delete, + *, + connector: aiohttp.TCPConnector, +): pair = collection.pair status_name = get_status_name(pair.name, collection.name) try: cli_logger.info(f"Syncing {status_name}") - a = storage_instance_from_config(collection.config_a) - b = storage_instance_from_config(collection.config_b) + a = await storage_instance_from_config(collection.config_a, connector=connector) + b = await storage_instance_from_config(collection.config_b, connector=connector) sync_failed = False @@ -57,7 +69,7 @@ def sync_collection(collection, general, force_delete): with manage_sync_status( general["status_path"], pair.name, collection.name ) as status: - sync.sync( + await sync.sync( a, b, status, @@ -76,9 +88,9 @@ def sync_collection(collection, general, force_delete): raise JobFailed() -def discover_collections(pair, **kwargs): - rv = collections_for_pair(pair=pair, **kwargs) - collections = list(c for c, (a, b) in rv) +async def discover_collections(pair, **kwargs): + rv = await collections_for_pair(pair=pair, **kwargs) + collections = [c for c, (a, b) in rv] if collections == [None]: collections = None cli_logger.info( @@ -86,7 +98,13 @@ def discover_collections(pair, **kwargs): ) -def repair_collection(config, collection, repair_unsafe_uid): +async def repair_collection( + config, + collection, + repair_unsafe_uid, + *, + connector: aiohttp.TCPConnector, +): from ..repair import repair_storage storage_name, collection = collection, None @@ -99,7 +117,7 @@ def repair_collection(config, collection, repair_unsafe_uid): if collection is not None: cli_logger.info("Discovering collections (skipping cache).") cls, config = storage_class_from_config(config) - for config in cls.discover(**config): + async for config in cls.discover(**config): if config["collection"] == collection: break else: @@ -110,14 +128,14 @@ def repair_collection(config, collection, repair_unsafe_uid): ) config["type"] = storage_type - storage = storage_instance_from_config(config) + storage = await storage_instance_from_config(config, connector=connector) cli_logger.info(f"Repairing {storage_name}/{collection}") cli_logger.warning("Make sure no other program is talking to the server.") - repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid) + await repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid) -def metasync_collection(collection, general): +async def metasync_collection(collection, general, *, connector: aiohttp.TCPConnector): from ..metasync import metasync pair = collection.pair @@ -133,10 +151,10 @@ def metasync_collection(collection, general): or {} ) - a = storage_instance_from_config(collection.config_a) - b = storage_instance_from_config(collection.config_b) + a = await storage_instance_from_config(collection.config_a, connector=connector) + b = await storage_instance_from_config(collection.config_b, connector=connector) - metasync( + await metasync( a, b, status, diff --git a/vdirsyncer/cli/utils.py b/vdirsyncer/cli/utils.py index 3f7dd50..010871e 100644 --- a/vdirsyncer/cli/utils.py +++ b/vdirsyncer/cli/utils.py @@ -5,6 +5,7 @@ import json import os import sys +import aiohttp import click from atomicwrites import atomic_write @@ -252,22 +253,37 @@ def storage_class_from_config(config): return cls, config -def storage_instance_from_config(config, create=True): +async def storage_instance_from_config( + config, + create=True, + *, + connector: aiohttp.TCPConnector, +): """ :param config: A configuration dictionary to pass as kwargs to the class corresponding to config['type'] """ + from vdirsyncer.storage.dav import DAVStorage + from vdirsyncer.storage.http import HttpStorage cls, new_config = storage_class_from_config(config) + if issubclass(cls, DAVStorage) or issubclass(cls, HttpStorage): + assert connector is not None # FIXME: hack? + new_config["connector"] = connector + try: return cls(**new_config) except exceptions.CollectionNotFound as e: if create: - config = handle_collection_not_found( + config = await handle_collection_not_found( config, config.get("collection", None), e=str(e) ) - return storage_instance_from_config(config, create=False) + return await storage_instance_from_config( + config, + create=False, + connector=connector, + ) else: raise except Exception: @@ -319,7 +335,7 @@ def assert_permissions(path, wanted): os.chmod(path, wanted) -def handle_collection_not_found(config, collection, e=None): +async def handle_collection_not_found(config, collection, e=None): storage_name = config.get("instance_name", None) cli_logger.warning( @@ -333,7 +349,7 @@ def handle_collection_not_found(config, collection, e=None): cls, config = storage_class_from_config(config) config["collection"] = collection try: - args = cls.create_collection(**config) + args = await cls.create_collection(**config) args["type"] = storage_type return args except NotImplementedError as e: diff --git a/vdirsyncer/http.py b/vdirsyncer/http.py index f6c6a37..7d78dcd 100644 --- a/vdirsyncer/http.py +++ b/vdirsyncer/http.py @@ -1,6 +1,6 @@ import logging -import requests +import aiohttp from . import __version__ from . import DOCS_HOME @@ -99,23 +99,8 @@ def prepare_client_cert(cert): return cert -def _install_fingerprint_adapter(session, fingerprint): - prefix = "https://" - try: - from requests_toolbelt.adapters.fingerprint import FingerprintAdapter - except ImportError: - raise RuntimeError( - "`verify_fingerprint` can only be used with " - "requests-toolbelt versions >= 0.4.0" - ) - - if not isinstance(session.adapters[prefix], FingerprintAdapter): - fingerprint_adapter = FingerprintAdapter(fingerprint) - session.mount(prefix, fingerprint_adapter) - - -def request( - method, url, session=None, latin1_fallback=True, verify_fingerprint=None, **kwargs +async def request( + method, url, session, latin1_fallback=True, verify_fingerprint=None, **kwargs ): """ Wrapper method for requests, to ease logging and mocking. Parameters should @@ -132,16 +117,20 @@ def request( https://github.com/kennethreitz/requests/issues/2042 """ - if session is None: - session = requests.Session() - if verify_fingerprint is not None: - _install_fingerprint_adapter(session, verify_fingerprint) + ssl = aiohttp.Fingerprint(bytes.fromhex(verify_fingerprint.replace(":", ""))) + kwargs.pop("verify", None) + elif kwargs.pop("verify", None) is False: + ssl = False + else: + ssl = None # TODO XXX: Check all possible values for this session.hooks = {"response": _fix_redirects} func = session.request + # TODO: rewrite using + # https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing logger.debug("=" * 20) logger.debug(f"{method} {url}") logger.debug(kwargs.get("headers", {})) @@ -150,7 +139,14 @@ def request( assert isinstance(kwargs.get("data", b""), bytes) - r = func(method, url, **kwargs) + kwargs.pop("cert", None) # TODO XXX FIXME! + + # Hacks to translate API + if auth := kwargs.pop("auth", None): + kwargs["auth"] = aiohttp.BasicAuth(*auth) + + r = func(method, url, ssl=ssl, **kwargs) + r = await r # See https://github.com/kennethreitz/requests/issues/2042 content_type = r.headers.get("Content-Type", "") @@ -162,13 +158,13 @@ def request( logger.debug("Removing latin1 fallback") r.encoding = None - logger.debug(r.status_code) + logger.debug(r.status) logger.debug(r.headers) logger.debug(r.content) - if r.status_code == 412: + if r.status == 412: raise exceptions.PreconditionFailed(r.reason) - if r.status_code in (404, 410): + if r.status in (404, 410): raise exceptions.NotFoundError(r.reason) r.raise_for_status() diff --git a/vdirsyncer/metasync.py b/vdirsyncer/metasync.py index 82ec092..551e72c 100644 --- a/vdirsyncer/metasync.py +++ b/vdirsyncer/metasync.py @@ -14,24 +14,24 @@ class MetaSyncConflict(MetaSyncError): key = None -def metasync(storage_a, storage_b, status, keys, conflict_resolution=None): - def _a_to_b(): +async def metasync(storage_a, storage_b, status, keys, conflict_resolution=None): + async def _a_to_b(): logger.info(f"Copying {key} to {storage_b}") - storage_b.set_meta(key, a) + await storage_b.set_meta(key, a) status[key] = a - def _b_to_a(): + async def _b_to_a(): logger.info(f"Copying {key} to {storage_a}") - storage_a.set_meta(key, b) + await storage_a.set_meta(key, b) status[key] = b - def _resolve_conflict(): + async def _resolve_conflict(): if a == b: status[key] = a elif conflict_resolution == "a wins": - _a_to_b() + await _a_to_b() elif conflict_resolution == "b wins": - _b_to_a() + await _b_to_a() else: if callable(conflict_resolution): logger.warning("Custom commands don't work on metasync.") @@ -40,8 +40,8 @@ def metasync(storage_a, storage_b, status, keys, conflict_resolution=None): raise MetaSyncConflict(key) for key in keys: - a = storage_a.get_meta(key) - b = storage_b.get_meta(key) + a = await storage_a.get_meta(key) + b = await storage_b.get_meta(key) s = normalize_meta_value(status.get(key)) logger.debug(f"Key: {key}") logger.debug(f"A: {a}") @@ -49,11 +49,11 @@ def metasync(storage_a, storage_b, status, keys, conflict_resolution=None): logger.debug(f"S: {s}") if a != s and b != s: - _resolve_conflict() + await _resolve_conflict() elif a != s and b == s: - _a_to_b() + await _a_to_b() elif a == s and b != s: - _b_to_a() + await _b_to_a() else: assert a == b diff --git a/vdirsyncer/repair.py b/vdirsyncer/repair.py index ac76f7d..104b2ad 100644 --- a/vdirsyncer/repair.py +++ b/vdirsyncer/repair.py @@ -1,6 +1,8 @@ import logging from os.path import basename +import aiostream + from .utils import generate_href from .utils import href_safe @@ -11,11 +13,11 @@ class IrreparableItem(Exception): pass -def repair_storage(storage, repair_unsafe_uid): +async def repair_storage(storage, repair_unsafe_uid): seen_uids = set() - all_hrefs = list(storage.list()) + all_hrefs = await aiostream.stream.list(storage.list()) for i, (href, _) in enumerate(all_hrefs): - item, etag = storage.get(href) + item, etag = await storage.get(href) logger.info("[{}/{}] Processing {}".format(i, len(all_hrefs), href)) try: @@ -32,10 +34,10 @@ def repair_storage(storage, repair_unsafe_uid): seen_uids.add(new_item.uid) if new_item.raw != item.raw: if new_item.uid != item.uid: - storage.upload(new_item) - storage.delete(href, etag) + await storage.upload(new_item) + await storage.delete(href, etag) else: - storage.update(href, new_item, etag) + await storage.update(href, new_item, etag) def repair_item(href, item, seen_uids, repair_unsafe_uid): diff --git a/vdirsyncer/storage/base.py b/vdirsyncer/storage/base.py index 5231bea..6154355 100644 --- a/vdirsyncer/storage/base.py +++ b/vdirsyncer/storage/base.py @@ -7,10 +7,10 @@ from ..utils import uniq def mutating_storage_method(f): @functools.wraps(f) - def inner(self, *args, **kwargs): + async def inner(self, *args, **kwargs): if self.read_only: raise exceptions.ReadOnlyError("This storage is read-only.") - return f(self, *args, **kwargs) + return await f(self, *args, **kwargs) return inner @@ -77,7 +77,7 @@ class Storage(metaclass=StorageMeta): self.collection = collection @classmethod - def discover(cls, **kwargs): + async def discover(cls, **kwargs): """Discover collections given a basepath or -URL to many collections. :param **kwargs: Keyword arguments to additionally pass to the storage @@ -92,10 +92,12 @@ class Storage(metaclass=StorageMeta): from the last segment of a URL or filesystem path. """ + if False: + yield # Needs to be an async generator raise NotImplementedError() @classmethod - def create_collection(cls, collection, **kwargs): + async def create_collection(cls, collection, **kwargs): """ Create the specified collection and return the new arguments. @@ -118,13 +120,13 @@ class Storage(metaclass=StorageMeta): {x: getattr(self, x) for x in self._repr_attributes}, ) - def list(self): + async def list(self): """ :returns: list of (href, etag) """ raise NotImplementedError() - def get(self, href): + async def get(self, href): """Fetch a single item. :param href: href to fetch @@ -134,7 +136,7 @@ class Storage(metaclass=StorageMeta): """ raise NotImplementedError() - def get_multi(self, hrefs): + async def get_multi(self, hrefs): """Fetch multiple items. Duplicate hrefs must be ignored. Functionally similar to :py:meth:`get`, but might bring performance @@ -146,22 +148,22 @@ class Storage(metaclass=StorageMeta): :returns: iterable of (href, item, etag) """ for href in uniq(hrefs): - item, etag = self.get(href) + item, etag = await self.get(href) yield href, item, etag - def has(self, href): + async def has(self, href): """Check if an item exists by its href. :returns: True or False """ try: - self.get(href) + await self.get(href) except exceptions.PreconditionFailed: return False else: return True - def upload(self, item): + async def upload(self, item): """Upload a new item. In cases where the new etag cannot be atomically determined (i.e. in @@ -176,7 +178,7 @@ class Storage(metaclass=StorageMeta): """ raise NotImplementedError() - def update(self, href, item, etag): + async def update(self, href, item, etag): """Update an item. The etag may be none in some cases, see `upload`. @@ -189,7 +191,7 @@ class Storage(metaclass=StorageMeta): """ raise NotImplementedError() - def delete(self, href, etag): + async def delete(self, href, etag): """Delete an item by href. :raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` when item has @@ -197,8 +199,8 @@ class Storage(metaclass=StorageMeta): """ raise NotImplementedError() - @contextlib.contextmanager - def at_once(self): + @contextlib.asynccontextmanager + async def at_once(self): """A contextmanager that buffers all writes. Essentially, this:: @@ -217,7 +219,7 @@ class Storage(metaclass=StorageMeta): """ yield - def get_meta(self, key): + async def get_meta(self, key): """Get metadata value for collection/storage. See the vdir specification for the keys that *have* to be accepted. @@ -228,7 +230,7 @@ class Storage(metaclass=StorageMeta): raise NotImplementedError("This storage does not support metadata.") - def set_meta(self, key, value): + async def set_meta(self, key, value): """Get metadata value for collection/storage. :param key: The metadata key. diff --git a/vdirsyncer/storage/dav.py b/vdirsyncer/storage/dav.py index b9e4055..8dc4319 100644 --- a/vdirsyncer/storage/dav.py +++ b/vdirsyncer/storage/dav.py @@ -5,8 +5,8 @@ import xml.etree.ElementTree as etree from inspect import getfullargspec from inspect import signature -import requests -from requests.exceptions import HTTPError +import aiohttp +import aiostream from .. import exceptions from .. import http @@ -18,6 +18,7 @@ from ..http import USERAGENT from ..vobject import Item from .base import normalize_meta_value from .base import Storage +from vdirsyncer.exceptions import Error dav_logger = logging.getLogger(__name__) @@ -44,10 +45,10 @@ def _contains_quoted_reserved_chars(x): return False -def _assert_multistatus_success(r): +async def _assert_multistatus_success(r): # Xandikos returns a multistatus on PUT. try: - root = _parse_xml(r.content) + root = _parse_xml(await r.content.read()) except InvalidXMLResponse: return for status in root.findall(".//{DAV:}status"): @@ -57,7 +58,7 @@ def _assert_multistatus_success(r): except (ValueError, IndexError): continue if st < 200 or st >= 400: - raise HTTPError(f"Server error: {st}") + raise Error(f"Server error: {st}") def _normalize_href(base, href): @@ -169,14 +170,14 @@ class Discover: _, collection = url.rstrip("/").rsplit("/", 1) return urlparse.unquote(collection) - def find_principal(self): + async def find_principal(self): try: - return self._find_principal_impl("") - except (HTTPError, exceptions.Error): + return await self._find_principal_impl("") + except (aiohttp.ClientResponseError, exceptions.Error): dav_logger.debug("Trying out well-known URI") - return self._find_principal_impl(self._well_known_uri) + return await self._find_principal_impl(self._well_known_uri) - def _find_principal_impl(self, url): + async def _find_principal_impl(self, url): headers = self.session.get_default_headers() headers["Depth"] = "0" body = b""" @@ -187,9 +188,14 @@ class Discover: """ - response = self.session.request("PROPFIND", url, headers=headers, data=body) + response = await self.session.request( + "PROPFIND", + url, + headers=headers, + data=body, + ) - root = _parse_xml(response.content) + root = _parse_xml(await response.content.read()) rv = root.find(".//{DAV:}current-user-principal/{DAV:}href") if rv is None: # This is for servers that don't support current-user-principal @@ -201,34 +207,37 @@ class Discover: ) ) return response.url - return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/" + return urlparse.urljoin(str(response.url), rv.text).rstrip("/") + "/" - def find_home(self): - url = self.find_principal() + async def find_home(self): + url = await self.find_principal() headers = self.session.get_default_headers() headers["Depth"] = "0" - response = self.session.request( + response = await self.session.request( "PROPFIND", url, headers=headers, data=self._homeset_xml ) - root = etree.fromstring(response.content) + root = etree.fromstring(await response.content.read()) # Better don't do string formatting here, because of XML namespaces rv = root.find(".//" + self._homeset_tag + "/{DAV:}href") if rv is None: raise InvalidXMLResponse("Couldn't find home-set.") - return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/" + return urlparse.urljoin(str(response.url), rv.text).rstrip("/") + "/" - def find_collections(self): + async def find_collections(self): rv = None try: - rv = list(self._find_collections_impl("")) - except (HTTPError, exceptions.Error): + rv = await aiostream.stream.list(self._find_collections_impl("")) + except (aiohttp.ClientResponseError, exceptions.Error): pass if rv: return rv + dav_logger.debug("Given URL is not a homeset URL") - return self._find_collections_impl(self.find_home()) + return await aiostream.stream.list( + self._find_collections_impl(await self.find_home()) + ) def _check_collection_resource_type(self, response): if self._resourcetype is None: @@ -245,13 +254,13 @@ class Discover: return False return True - def _find_collections_impl(self, url): + async def _find_collections_impl(self, url): headers = self.session.get_default_headers() headers["Depth"] = "1" - r = self.session.request( + r = await self.session.request( "PROPFIND", url, headers=headers, data=self._collection_xml ) - root = _parse_xml(r.content) + root = _parse_xml(await r.content.read()) done = set() for response in root.findall("{DAV:}response"): if not self._check_collection_resource_type(response): @@ -260,33 +269,33 @@ class Discover: href = response.find("{DAV:}href") if href is None: raise InvalidXMLResponse("Missing href tag for collection " "props.") - href = urlparse.urljoin(r.url, href.text) + href = urlparse.urljoin(str(r.url), href.text) if href not in done: done.add(href) yield {"href": href} - def discover(self): - for c in self.find_collections(): + async def discover(self): + for c in await self.find_collections(): url = c["href"] collection = self._get_collection_from_url(url) storage_args = dict(self.kwargs) storage_args.update({"url": url, "collection": collection}) yield storage_args - def create(self, collection): + async def create(self, collection): if collection is None: collection = self._get_collection_from_url(self.kwargs["url"]) - for c in self.discover(): + async for c in self.discover(): if c["collection"] == collection: return c - home = self.find_home() + home = await self.find_home() url = urlparse.urljoin(home, urlparse.quote(collection, "/@")) try: - url = self._create_collection_impl(url) - except HTTPError as e: + url = await self._create_collection_impl(url) + except (aiohttp.ClientResponseError, Error) as e: raise NotImplementedError(e) else: rv = dict(self.kwargs) @@ -294,7 +303,7 @@ class Discover: rv["url"] = url return rv - def _create_collection_impl(self, url): + async def _create_collection_impl(self, url): data = """ @@ -312,13 +321,13 @@ class Discover: "utf-8" ) - response = self.session.request( + response = await self.session.request( "MKCOL", url, data=data, headers=self.session.get_default_headers(), ) - return response.url + return str(response.url) class CalDiscover(Discover): @@ -350,14 +359,18 @@ class CardDiscover(Discover): class DAVSession: - """ - A helper class to connect to DAV servers. - """ + """A helper class to connect to DAV servers.""" + + connector: aiohttp.BaseConnector @classmethod def init_and_remaining_args(cls, **kwargs): + def is_arg(k): + """Return true if ``k`` is an argument of ``cls.__init__``.""" + return k in argspec.args or k in argspec.kwonlyargs + argspec = getfullargspec(cls.__init__) - self_args, remainder = utils.split_dict(kwargs, argspec.args.__contains__) + self_args, remainder = utils.split_dict(kwargs, is_arg) return cls(**self_args), remainder @@ -371,6 +384,8 @@ class DAVSession: useragent=USERAGENT, verify_fingerprint=None, auth_cert=None, + *, + connector: aiohttp.BaseConnector, ): self._settings = { "cert": prepare_client_cert(auth_cert), @@ -380,21 +395,28 @@ class DAVSession: self.useragent = useragent self.url = url.rstrip("/") + "/" - - self._session = requests.session() + self.connector = connector @utils.cached_property def parsed_url(self): return urlparse.urlparse(self.url) - def request(self, method, path, **kwargs): + async def request(self, method, path, **kwargs): url = self.url if path: - url = urlparse.urljoin(self.url, path) + url = urlparse.urljoin(str(self.url), path) more = dict(self._settings) more.update(kwargs) - return http.request(method, url, session=self._session, **more) + + # XXX: This is a temporary hack to pin-point bad refactoring. + assert self.connector is not None + async with aiohttp.ClientSession( + connector=self.connector, + connector_owner=False, + # TODO use `raise_for_status=true`, though this needs traces first, + ) as session: + return await http.request(method, url, session=session, **more) def get_default_headers(self): return { @@ -417,33 +439,41 @@ class DAVStorage(Storage): # The DAVSession class to use session_class = DAVSession + connector: aiohttp.TCPConnector + _repr_attributes = ("username", "url") _property_table = { "displayname": ("displayname", "DAV:"), } - def __init__(self, **kwargs): + def __init__(self, *, connector, **kwargs): # defined for _repr_attributes self.username = kwargs.get("username") self.url = kwargs.get("url") + self.connector = connector - self.session, kwargs = self.session_class.init_and_remaining_args(**kwargs) + self.session, kwargs = self.session_class.init_and_remaining_args( + connector=connector, + **kwargs, + ) super().__init__(**kwargs) __init__.__signature__ = signature(session_class.__init__) @classmethod - def discover(cls, **kwargs): + async def discover(cls, **kwargs): session, _ = cls.session_class.init_and_remaining_args(**kwargs) d = cls.discovery_class(session, kwargs) - return d.discover() + + async for collection in d.discover(): + yield collection @classmethod - def create_collection(cls, collection, **kwargs): + async def create_collection(cls, collection, **kwargs): session, _ = cls.session_class.init_and_remaining_args(**kwargs) d = cls.discovery_class(session, kwargs) - return d.create(collection) + return await d.create(collection) def _normalize_href(self, *args, **kwargs): return _normalize_href(self.session.url, *args, **kwargs) @@ -455,57 +485,65 @@ class DAVStorage(Storage): def _is_item_mimetype(self, mimetype): return _fuzzy_matches_mimetype(self.item_mimetype, mimetype) - def get(self, href): - ((actual_href, item, etag),) = self.get_multi([href]) + async def get(self, href): + ((actual_href, item, etag),) = await aiostream.stream.list( + self.get_multi([href]) + ) assert href == actual_href return item, etag - def get_multi(self, hrefs): + async def get_multi(self, hrefs): hrefs = set(hrefs) href_xml = [] for href in hrefs: if href != self._normalize_href(href): raise exceptions.NotFoundError(href) href_xml.append(f"{href}") - if not href_xml: - return () + if href_xml: + data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode( + "utf-8" + ) + response = await self.session.request( + "REPORT", "", data=data, headers=self.session.get_default_headers() + ) + root = _parse_xml( + await response.content.read() + ) # etree only can handle bytes + rv = [] + hrefs_left = set(hrefs) + for href, etag, prop in self._parse_prop_responses(root): + raw = prop.find(self.get_multi_data_query) + if raw is None: + dav_logger.warning( + "Skipping {}, the item content is missing.".format(href) + ) + continue - data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode("utf-8") - response = self.session.request( - "REPORT", "", data=data, headers=self.session.get_default_headers() - ) - root = _parse_xml(response.content) # etree only can handle bytes - rv = [] - hrefs_left = set(hrefs) - for href, etag, prop in self._parse_prop_responses(root): - raw = prop.find(self.get_multi_data_query) - if raw is None: - dav_logger.warning( - "Skipping {}, the item content is missing.".format(href) - ) - continue + raw = raw.text or "" - raw = raw.text or "" + if isinstance(raw, bytes): + raw = raw.decode(response.encoding) + if isinstance(etag, bytes): + etag = etag.decode(response.encoding) - if isinstance(raw, bytes): - raw = raw.decode(response.encoding) - if isinstance(etag, bytes): - etag = etag.decode(response.encoding) - - try: - hrefs_left.remove(href) - except KeyError: - if href in hrefs: - dav_logger.warning("Server sent item twice: {}".format(href)) + try: + hrefs_left.remove(href) + except KeyError: + if href in hrefs: + dav_logger.warning("Server sent item twice: {}".format(href)) + else: + dav_logger.warning( + "Server sent unsolicited item: {}".format(href) + ) else: - dav_logger.warning("Server sent unsolicited item: {}".format(href)) - else: - rv.append((href, Item(raw), etag)) - for href in hrefs_left: - raise exceptions.NotFoundError(href) - return rv + rv.append((href, Item(raw), etag)) + for href in hrefs_left: + raise exceptions.NotFoundError(href) - def _put(self, href, item, etag): + for href, item, etag in rv: + yield href, item, etag + + async def _put(self, href, item, etag): headers = self.session.get_default_headers() headers["Content-Type"] = self.item_mimetype if etag is None: @@ -513,11 +551,11 @@ class DAVStorage(Storage): else: headers["If-Match"] = etag - response = self.session.request( + response = await self.session.request( "PUT", href, data=item.raw.encode("utf-8"), headers=headers ) - _assert_multistatus_success(response) + await _assert_multistatus_success(response) # The server may not return an etag under certain conditions: # @@ -534,25 +572,28 @@ class DAVStorage(Storage): # In such cases we return a constant etag. The next synchronization # will then detect an etag change and will download the new item. etag = response.headers.get("etag", None) - href = self._normalize_href(response.url) + href = self._normalize_href(str(response.url)) return href, etag - def update(self, href, item, etag): + async def update(self, href, item, etag): if etag is None: raise ValueError("etag must be given and must not be None.") - href, etag = self._put(self._normalize_href(href), item, etag) + href, etag = await self._put(self._normalize_href(href), item, etag) return etag - def upload(self, item): + async def upload(self, item): href = self._get_href(item) - return self._put(href, item, None) + rv = await self._put(href, item, None) + return rv - def delete(self, href, etag): + async def delete(self, href, etag): href = self._normalize_href(href) headers = self.session.get_default_headers() - headers.update({"If-Match": etag}) + if etag: # baikal doesn't give us an etag. + dav_logger.warning("Deleting an item with no etag.") + headers.update({"If-Match": etag}) - self.session.request("DELETE", href, headers=headers) + await self.session.request("DELETE", href, headers=headers) def _parse_prop_responses(self, root, handled_hrefs=None): if handled_hrefs is None: @@ -604,7 +645,7 @@ class DAVStorage(Storage): handled_hrefs.add(href) yield href, etag, props - def list(self): + async def list(self): headers = self.session.get_default_headers() headers["Depth"] = "1" @@ -620,14 +661,19 @@ class DAVStorage(Storage): # We use a PROPFIND request instead of addressbook-query due to issues # with Zimbra. See https://github.com/pimutils/vdirsyncer/issues/83 - response = self.session.request("PROPFIND", "", data=data, headers=headers) - root = _parse_xml(response.content) + response = await self.session.request( + "PROPFIND", + "", + data=data, + headers=headers, + ) + root = _parse_xml(await response.content.read()) rv = self._parse_prop_responses(root) for href, etag, _prop in rv: yield href, etag - def get_meta(self, key): + async def get_meta(self, key): try: tagname, namespace = self._property_table[key] except KeyError: @@ -649,9 +695,14 @@ class DAVStorage(Storage): headers = self.session.get_default_headers() headers["Depth"] = "0" - response = self.session.request("PROPFIND", "", data=data, headers=headers) + response = await self.session.request( + "PROPFIND", + "", + data=data, + headers=headers, + ) - root = _parse_xml(response.content) + root = _parse_xml(await response.content.read()) for prop in root.findall(".//" + xpath): text = normalize_meta_value(getattr(prop, "text", None)) @@ -659,7 +710,7 @@ class DAVStorage(Storage): return text return "" - def set_meta(self, key, value): + async def set_meta(self, key, value): try: tagname, namespace = self._property_table[key] except KeyError: @@ -683,8 +734,11 @@ class DAVStorage(Storage): "utf-8" ) - self.session.request( - "PROPPATCH", "", data=data, headers=self.session.get_default_headers() + await self.session.request( + "PROPPATCH", + "", + data=data, + headers=self.session.get_default_headers(), ) # XXX: Response content is currently ignored. Though exceptions are @@ -776,7 +830,7 @@ class CalDAVStorage(DAVStorage): ("VTODO", "VEVENT"), start, end ) - def list(self): + async def list(self): caldavfilters = list( self._get_list_filters(self.item_types, self.start_date, self.end_date) ) @@ -788,7 +842,8 @@ class CalDAVStorage(DAVStorage): # instead? # # See https://github.com/dmfs/tasks/issues/118 for backstory. - yield from DAVStorage.list(self) + async for href, etag in DAVStorage.list(self): + yield href, etag data = """ str: return p -def split_dict(d, f): +def split_dict(d: dict, f: callable): """Puts key into first dict if f(key), otherwise in second dict""" - a, b = split_sequence(d.items(), lambda item: f(item[0])) - return dict(a), dict(b) - - -def split_sequence(s, f): - """Puts item into first list if f(item), else in second list""" - a = [] - b = [] - for item in s: - if f(item): - a.append(item) + a = {} + b = {} + for k, v in d.items(): + if f(k): + a[k] = v else: - b.append(item) - + b[k] = v return a, b