Initial async support

Add asyncio to the storage backends and most of the codebase. A lot of
it merely uses asyncio APIs, but still doesn't actually run several
things concurrently internally. Further improvements will be added on
top of these changes.

Thanks to Thomas Grainger (@graingert) for a few useful pointers
related to asyncio.
This commit is contained in:
Hugo Osvaldo Barrera 2021-06-17 23:38:18 +02:00
parent 7c9170c677
commit 1a1f6f0788
44 changed files with 1383 additions and 935 deletions

View file

@ -19,6 +19,10 @@ packages:
- python-pytest-cov - python-pytest-cov
- python-pytest-httpserver - python-pytest-httpserver
- python-trustme - python-trustme
- python-pytest-asyncio
- python-aiohttp
- python-aiostream
- python-aioresponses
sources: sources:
- https://github.com/pimutils/vdirsyncer - https://github.com/pimutils/vdirsyncer
environment: environment:

View file

@ -8,6 +8,7 @@ addopts =
--cov=vdirsyncer --cov=vdirsyncer
--cov-report=term-missing --cov-report=term-missing
--no-cov-on-fail --no-cov-on-fail
# filterwarnings=error
[flake8] [flake8]
application-import-names = tests,vdirsyncer application-import-names = tests,vdirsyncer

View file

@ -13,14 +13,14 @@ requirements = [
# https://github.com/mitsuhiko/click/issues/200 # https://github.com/mitsuhiko/click/issues/200
"click>=5.0,<9.0", "click>=5.0,<9.0",
"click-log>=0.3.0, <0.4.0", "click-log>=0.3.0, <0.4.0",
# https://github.com/pimutils/vdirsyncer/issues/478
"click-threading>=0.5",
"requests >=2.20.0", "requests >=2.20.0",
# https://github.com/sigmavirus24/requests-toolbelt/pull/28 # https://github.com/sigmavirus24/requests-toolbelt/pull/28
# And https://github.com/sigmavirus24/requests-toolbelt/issues/54 # And https://github.com/sigmavirus24/requests-toolbelt/issues/54
"requests_toolbelt >=0.4.0", "requests_toolbelt >=0.4.0",
# https://github.com/untitaker/python-atomicwrites/commit/4d12f23227b6a944ab1d99c507a69fdbc7c9ed6d # noqa # https://github.com/untitaker/python-atomicwrites/commit/4d12f23227b6a944ab1d99c507a69fdbc7c9ed6d # noqa
"atomicwrites>=0.1.7", "atomicwrites>=0.1.7",
"aiohttp>=3.7.1,<4.0.0",
"aiostream>=0.4.3,<0.5.0",
] ]

View file

@ -3,3 +3,5 @@ pytest
pytest-cov pytest-cov
pytest-httpserver pytest-httpserver
trustme trustme
pytest-asyncio
aioresponses

View file

@ -4,6 +4,7 @@ General-purpose fixtures for vdirsyncer's testsuite.
import logging import logging
import os import os
import aiohttp
import click_log import click_log
import pytest import pytest
from hypothesis import HealthCheck from hypothesis import HealthCheck
@ -52,3 +53,18 @@ elif os.environ.get("CI", "false").lower() == "true":
settings.load_profile("ci") settings.load_profile("ci")
else: else:
settings.load_profile("dev") settings.load_profile("dev")
@pytest.fixture
async def aio_session(event_loop):
async with aiohttp.ClientSession() as session:
yield session
@pytest.fixture
async def aio_connector(event_loop):
conn = aiohttp.TCPConnector(limit_per_host=16)
try:
yield conn
finally:
await conn.close()

View file

@ -4,6 +4,7 @@ import uuid
from urllib.parse import quote as urlquote from urllib.parse import quote as urlquote
from urllib.parse import unquote as urlunquote from urllib.parse import unquote as urlunquote
import aiostream
import pytest import pytest
from .. import assert_item_equals from .. import assert_item_equals
@ -49,8 +50,9 @@ class StorageTests:
raise NotImplementedError() raise NotImplementedError()
@pytest.fixture @pytest.fixture
def s(self, get_storage_args): async def s(self, get_storage_args):
return self.storage_class(**get_storage_args()) rv = self.storage_class(**await get_storage_args())
return rv
@pytest.fixture @pytest.fixture
def get_item(self, item_type): def get_item(self, item_type):
@ -72,176 +74,211 @@ class StorageTests:
if not self.supports_metadata: if not self.supports_metadata:
pytest.skip("This storage does not support metadata.") pytest.skip("This storage does not support metadata.")
def test_generic(self, s, get_item): @pytest.mark.asyncio
async def test_generic(self, s, get_item):
items = [get_item() for i in range(1, 10)] items = [get_item() for i in range(1, 10)]
hrefs = [] hrefs = []
for item in items: for item in items:
href, etag = s.upload(item) href, etag = await s.upload(item)
if etag is None: if etag is None:
_, etag = s.get(href) _, etag = await s.get(href)
hrefs.append((href, etag)) hrefs.append((href, etag))
hrefs.sort() hrefs.sort()
assert hrefs == sorted(s.list()) assert hrefs == sorted(await aiostream.stream.list(s.list()))
for href, etag in hrefs: for href, etag in hrefs:
assert isinstance(href, (str, bytes)) assert isinstance(href, (str, bytes))
assert isinstance(etag, (str, bytes)) assert isinstance(etag, (str, bytes))
assert s.has(href) assert await s.has(href)
item, etag2 = s.get(href) item, etag2 = await s.get(href)
assert etag == etag2 assert etag == etag2
def test_empty_get_multi(self, s): @pytest.mark.asyncio
assert list(s.get_multi([])) == [] async def test_empty_get_multi(self, s):
assert await aiostream.stream.list(s.get_multi([])) == []
def test_get_multi_duplicates(self, s, get_item): @pytest.mark.asyncio
href, etag = s.upload(get_item()) async def test_get_multi_duplicates(self, s, get_item):
href, etag = await s.upload(get_item())
if etag is None: if etag is None:
_, etag = s.get(href) _, etag = await s.get(href)
((href2, item, etag2),) = s.get_multi([href] * 2) ((href2, item, etag2),) = await aiostream.stream.list(s.get_multi([href] * 2))
assert href2 == href assert href2 == href
assert etag2 == etag assert etag2 == etag
def test_upload_already_existing(self, s, get_item): @pytest.mark.asyncio
async def test_upload_already_existing(self, s, get_item):
item = get_item() item = get_item()
s.upload(item) await s.upload(item)
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.upload(item) await s.upload(item)
def test_upload(self, s, get_item): @pytest.mark.asyncio
async def test_upload(self, s, get_item):
item = get_item() item = get_item()
href, etag = s.upload(item) href, etag = await s.upload(item)
assert_item_equals(s.get(href)[0], item) assert_item_equals((await s.get(href))[0], item)
def test_update(self, s, get_item): @pytest.mark.asyncio
async def test_update(self, s, get_item):
item = get_item() item = get_item()
href, etag = s.upload(item) href, etag = await s.upload(item)
if etag is None: if etag is None:
_, etag = s.get(href) _, etag = await s.get(href)
assert_item_equals(s.get(href)[0], item) assert_item_equals((await s.get(href))[0], item)
new_item = get_item(uid=item.uid) new_item = get_item(uid=item.uid)
new_etag = s.update(href, new_item, etag) new_etag = await s.update(href, new_item, etag)
if new_etag is None: if new_etag is None:
_, new_etag = s.get(href) _, new_etag = await s.get(href)
# See https://github.com/pimutils/vdirsyncer/issues/48 # See https://github.com/pimutils/vdirsyncer/issues/48
assert isinstance(new_etag, (bytes, str)) assert isinstance(new_etag, (bytes, str))
assert_item_equals(s.get(href)[0], new_item) assert_item_equals((await s.get(href))[0], new_item)
def test_update_nonexisting(self, s, get_item): @pytest.mark.asyncio
async def test_update_nonexisting(self, s, get_item):
item = get_item() item = get_item()
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.update("huehue", item, '"123"') await s.update("huehue", item, '"123"')
def test_wrong_etag(self, s, get_item): @pytest.mark.asyncio
async def test_wrong_etag(self, s, get_item):
item = get_item() item = get_item()
href, etag = s.upload(item) href, etag = await s.upload(item)
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.update(href, item, '"lolnope"') await s.update(href, item, '"lolnope"')
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.delete(href, '"lolnope"') await s.delete(href, '"lolnope"')
def test_delete(self, s, get_item): @pytest.mark.asyncio
href, etag = s.upload(get_item()) async def test_delete(self, s, get_item):
s.delete(href, etag) href, etag = await s.upload(get_item())
assert not list(s.list()) await s.delete(href, etag)
assert not await aiostream.stream.list(s.list())
def test_delete_nonexisting(self, s, get_item): @pytest.mark.asyncio
async def test_delete_nonexisting(self, s, get_item):
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.delete("1", '"123"') await s.delete("1", '"123"')
def test_list(self, s, get_item): @pytest.mark.asyncio
assert not list(s.list()) async def test_list(self, s, get_item):
href, etag = s.upload(get_item()) assert not await aiostream.stream.list(s.list())
href, etag = await s.upload(get_item())
if etag is None: if etag is None:
_, etag = s.get(href) _, etag = await s.get(href)
assert list(s.list()) == [(href, etag)] assert await aiostream.stream.list(s.list()) == [(href, etag)]
def test_has(self, s, get_item): @pytest.mark.asyncio
assert not s.has("asd") async def test_has(self, s, get_item):
href, etag = s.upload(get_item()) assert not await s.has("asd")
assert s.has(href) href, etag = await s.upload(get_item())
assert not s.has("asd") assert await s.has(href)
s.delete(href, etag) assert not await s.has("asd")
assert not s.has(href) await s.delete(href, etag)
assert not await s.has(href)
def test_update_others_stay_the_same(self, s, get_item): @pytest.mark.asyncio
async def test_update_others_stay_the_same(self, s, get_item):
info = {} info = {}
for _ in range(4): for _ in range(4):
href, etag = s.upload(get_item()) href, etag = await s.upload(get_item())
if etag is None: if etag is None:
_, etag = s.get(href) _, etag = await s.get(href)
info[href] = etag info[href] = etag
assert { items = await aiostream.stream.list(
href: etag s.get_multi(href for href, etag in info.items())
for href, item, etag in s.get_multi(href for href, etag in info.items()) )
} == info assert {href: etag for href, item, etag in items} == info
def test_repr(self, s, get_storage_args): @pytest.mark.asyncio
def test_repr(self, s, get_storage_args): # XXX: unused param
assert self.storage_class.__name__ in repr(s) assert self.storage_class.__name__ in repr(s)
assert s.instance_name is None assert s.instance_name is None
def test_discover(self, requires_collections, get_storage_args, get_item): @pytest.mark.asyncio
async def test_discover(
self,
requires_collections,
get_storage_args,
get_item,
aio_connector,
):
collections = set() collections = set()
for i in range(1, 5): for i in range(1, 5):
collection = f"test{i}" collection = f"test{i}"
s = self.storage_class(**get_storage_args(collection=collection)) s = self.storage_class(**await get_storage_args(collection=collection))
assert not list(s.list()) assert not await aiostream.stream.list(s.list())
s.upload(get_item()) await s.upload(get_item())
collections.add(s.collection) collections.add(s.collection)
actual = { discovered = await aiostream.stream.list(
c["collection"] self.storage_class.discover(**await get_storage_args(collection=None))
for c in self.storage_class.discover(**get_storage_args(collection=None)) )
} actual = {c["collection"] for c in discovered}
assert actual >= collections assert actual >= collections
def test_create_collection(self, requires_collections, get_storage_args, get_item): @pytest.mark.asyncio
async def test_create_collection(
self,
requires_collections,
get_storage_args,
get_item,
):
if getattr(self, "dav_server", "") in ("icloud", "fastmail", "davical"): if getattr(self, "dav_server", "") in ("icloud", "fastmail", "davical"):
pytest.skip("Manual cleanup would be necessary.") pytest.skip("Manual cleanup would be necessary.")
if getattr(self, "dav_server", "") == "radicale": if getattr(self, "dav_server", "") == "radicale":
pytest.skip("Radicale does not support collection creation") pytest.skip("Radicale does not support collection creation")
args = get_storage_args(collection=None) args = await get_storage_args(collection=None)
args["collection"] = "test" args["collection"] = "test"
s = self.storage_class(**self.storage_class.create_collection(**args)) s = self.storage_class(**await self.storage_class.create_collection(**args))
href = s.upload(get_item())[0] href = (await s.upload(get_item()))[0]
assert href in (href for href, etag in s.list()) assert href in await aiostream.stream.list(
(href async for href, etag in s.list())
)
def test_discover_collection_arg(self, requires_collections, get_storage_args): @pytest.mark.asyncio
args = get_storage_args(collection="test2") async def test_discover_collection_arg(
self, requires_collections, get_storage_args
):
args = await get_storage_args(collection="test2")
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
list(self.storage_class.discover(**args)) await aiostream.stream.list(self.storage_class.discover(**args))
assert "collection argument must not be given" in str(excinfo.value) assert "collection argument must not be given" in str(excinfo.value)
def test_collection_arg(self, get_storage_args): @pytest.mark.asyncio
async def test_collection_arg(self, get_storage_args):
if self.storage_class.storage_name.startswith("etesync"): if self.storage_class.storage_name.startswith("etesync"):
pytest.skip("etesync uses UUIDs.") pytest.skip("etesync uses UUIDs.")
if self.supports_collections: if self.supports_collections:
s = self.storage_class(**get_storage_args(collection="test2")) s = self.storage_class(**await get_storage_args(collection="test2"))
# Can't do stronger assertion because of radicale, which needs a # Can't do stronger assertion because of radicale, which needs a
# fileextension to guess the collection type. # fileextension to guess the collection type.
assert "test2" in s.collection assert "test2" in s.collection
else: else:
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.storage_class(collection="ayy", **get_storage_args()) self.storage_class(collection="ayy", **await get_storage_args())
def test_case_sensitive_uids(self, s, get_item): @pytest.mark.asyncio
async def test_case_sensitive_uids(self, s, get_item):
if s.storage_name == "filesystem": if s.storage_name == "filesystem":
pytest.skip("Behavior depends on the filesystem.") pytest.skip("Behavior depends on the filesystem.")
uid = str(uuid.uuid4()) uid = str(uuid.uuid4())
s.upload(get_item(uid=uid.upper())) await s.upload(get_item(uid=uid.upper()))
s.upload(get_item(uid=uid.lower())) await s.upload(get_item(uid=uid.lower()))
items = [href for href, etag in s.list()] items = [href async for href, etag in s.list()]
assert len(items) == 2 assert len(items) == 2
assert len(set(items)) == 2 assert len(set(items)) == 2
def test_specialchars( @pytest.mark.asyncio
async def test_specialchars(
self, monkeypatch, requires_collections, get_storage_args, get_item self, monkeypatch, requires_collections, get_storage_args, get_item
): ):
if getattr(self, "dav_server", "") == "radicale": if getattr(self, "dav_server", "") == "radicale":
@ -254,16 +291,16 @@ class StorageTests:
uid = "test @ foo ät bar град сатану" uid = "test @ foo ät bar град сатану"
collection = "test @ foo ät bar" collection = "test @ foo ät bar"
s = self.storage_class(**get_storage_args(collection=collection)) s = self.storage_class(**await get_storage_args(collection=collection))
item = get_item(uid=uid) item = get_item(uid=uid)
href, etag = s.upload(item) href, etag = await s.upload(item)
item2, etag2 = s.get(href) item2, etag2 = await s.get(href)
if etag is not None: if etag is not None:
assert etag2 == etag assert etag2 == etag
assert_item_equals(item2, item) assert_item_equals(item2, item)
((_, etag3),) = s.list() ((_, etag3),) = await aiostream.stream.list(s.list())
assert etag2 == etag3 assert etag2 == etag3
# etesync uses UUIDs for collection names # etesync uses UUIDs for collection names
@ -274,22 +311,23 @@ class StorageTests:
if self.storage_class.storage_name.endswith("dav"): if self.storage_class.storage_name.endswith("dav"):
assert urlquote(uid, "/@:") in href assert urlquote(uid, "/@:") in href
def test_metadata(self, requires_metadata, s): @pytest.mark.asyncio
async def test_metadata(self, requires_metadata, s):
if not getattr(self, "dav_server", ""): if not getattr(self, "dav_server", ""):
assert not s.get_meta("color") assert not await s.get_meta("color")
assert not s.get_meta("displayname") assert not await s.get_meta("displayname")
try: try:
s.set_meta("color", None) await s.set_meta("color", None)
assert not s.get_meta("color") assert not await s.get_meta("color")
s.set_meta("color", "#ff0000") await s.set_meta("color", "#ff0000")
assert s.get_meta("color") == "#ff0000" assert await s.get_meta("color") == "#ff0000"
except exceptions.UnsupportedMetadataError: except exceptions.UnsupportedMetadataError:
pass pass
for x in ("hello world", "hello wörld"): for x in ("hello world", "hello wörld"):
s.set_meta("displayname", x) await s.set_meta("displayname", x)
rv = s.get_meta("displayname") rv = await s.get_meta("displayname")
assert rv == x assert rv == x
assert isinstance(rv, str) assert isinstance(rv, str)
@ -306,16 +344,18 @@ class StorageTests:
"فلسطين", "فلسطين",
], ],
) )
def test_metadata_normalization(self, requires_metadata, s, value): @pytest.mark.asyncio
x = s.get_meta("displayname") async def test_metadata_normalization(self, requires_metadata, s, value):
x = await s.get_meta("displayname")
assert x == normalize_meta_value(x) assert x == normalize_meta_value(x)
if not getattr(self, "dav_server", None): if not getattr(self, "dav_server", None):
# ownCloud replaces "" with "unnamed" # ownCloud replaces "" with "unnamed"
s.set_meta("displayname", value) await s.set_meta("displayname", value)
assert s.get_meta("displayname") == normalize_meta_value(value) assert await s.get_meta("displayname") == normalize_meta_value(value)
def test_recurring_events(self, s, item_type): @pytest.mark.asyncio
async def test_recurring_events(self, s, item_type):
if item_type != "VEVENT": if item_type != "VEVENT":
pytest.skip("This storage instance doesn't support iCalendar.") pytest.skip("This storage instance doesn't support iCalendar.")
@ -362,7 +402,7 @@ class StorageTests:
).strip() ).strip()
) )
href, etag = s.upload(item) href, etag = await s.upload(item)
item2, etag2 = s.get(href) item2, etag2 = await s.get(href)
assert normalize_item(item) == normalize_item(item2) assert normalize_item(item) == normalize_item(item2)

View file

@ -3,6 +3,7 @@ import subprocess
import time import time
import uuid import uuid
import aiostream
import pytest import pytest
import requests import requests
@ -80,31 +81,31 @@ def xandikos_server():
@pytest.fixture @pytest.fixture
def slow_create_collection(request): async def slow_create_collection(request, aio_connector):
# We need to properly clean up because otherwise we might run into # We need to properly clean up because otherwise we might run into
# storage limits. # storage limits.
to_delete = [] to_delete = []
def delete_collections(): async def delete_collections():
for s in to_delete: for s in to_delete:
s.session.request("DELETE", "") await s.session.request("DELETE", "")
request.addfinalizer(delete_collections) async def inner(cls, args, collection):
def inner(cls, args, collection):
assert collection.startswith("test") assert collection.startswith("test")
collection += "-vdirsyncer-ci-" + str(uuid.uuid4()) collection += "-vdirsyncer-ci-" + str(uuid.uuid4())
args = cls.create_collection(collection, **args) args = await cls.create_collection(collection, **args)
s = cls(**args) s = cls(**args)
_clear_collection(s) await _clear_collection(s)
assert not list(s.list()) assert not await aiostream.stream.list(s.list())
to_delete.append(s) to_delete.append(s)
return args return args
return inner yield inner
await delete_collections()
def _clear_collection(s): async def _clear_collection(s):
for href, etag in s.list(): async for href, etag in s.list():
s.delete(href, etag) s.delete(href, etag)

View file

@ -1,8 +1,9 @@
import os import os
import uuid import uuid
import aiohttp
import aiostream
import pytest import pytest
import requests.exceptions
from .. import get_server_mixin from .. import get_server_mixin
from .. import StorageTests from .. import StorageTests
@ -19,30 +20,33 @@ class DAVStorageTests(ServerMixin, StorageTests):
dav_server = dav_server dav_server = dav_server
@pytest.mark.skipif(dav_server == "radicale", reason="Radicale is very tolerant.") @pytest.mark.skipif(dav_server == "radicale", reason="Radicale is very tolerant.")
def test_dav_broken_item(self, s): @pytest.mark.asyncio
async def test_dav_broken_item(self, s):
item = Item("HAHA:YES") item = Item("HAHA:YES")
with pytest.raises((exceptions.Error, requests.exceptions.HTTPError)): with pytest.raises((exceptions.Error, aiohttp.ClientResponseError)):
s.upload(item) await s.upload(item)
assert not list(s.list()) assert not await aiostream.stream.list(s.list())
def test_dav_empty_get_multi_performance(self, s, monkeypatch): @pytest.mark.asyncio
async def test_dav_empty_get_multi_performance(self, s, monkeypatch):
def breakdown(*a, **kw): def breakdown(*a, **kw):
raise AssertionError("Expected not to be called.") raise AssertionError("Expected not to be called.")
monkeypatch.setattr("requests.sessions.Session.request", breakdown) monkeypatch.setattr("requests.sessions.Session.request", breakdown)
try: try:
assert list(s.get_multi([])) == [] assert list(await aiostream.stream.list(s.get_multi([]))) == []
finally: finally:
# Make sure monkeypatch doesn't interfere with DAV server teardown # Make sure monkeypatch doesn't interfere with DAV server teardown
monkeypatch.undo() monkeypatch.undo()
def test_dav_unicode_href(self, s, get_item, monkeypatch): @pytest.mark.asyncio
async def test_dav_unicode_href(self, s, get_item, monkeypatch):
if self.dav_server == "radicale": if self.dav_server == "radicale":
pytest.skip("Radicale is unable to deal with unicode hrefs") pytest.skip("Radicale is unable to deal with unicode hrefs")
monkeypatch.setattr(s, "_get_href", lambda item: item.ident + s.fileext) monkeypatch.setattr(s, "_get_href", lambda item: item.ident + s.fileext)
item = get_item(uid="град сатану" + str(uuid.uuid4())) item = get_item(uid="град сатану" + str(uuid.uuid4()))
href, etag = s.upload(item) href, etag = await s.upload(item)
item2, etag2 = s.get(href) item2, etag2 = await s.get(href)
assert_item_equals(item, item2) assert_item_equals(item, item2)

View file

@ -1,8 +1,10 @@
import datetime import datetime
from textwrap import dedent from textwrap import dedent
import aiohttp
import aiostream
import pytest import pytest
import requests.exceptions from aioresponses import aioresponses
from . import dav_server from . import dav_server
from . import DAVStorageTests from . import DAVStorageTests
@ -21,15 +23,17 @@ class TestCalDAVStorage(DAVStorageTests):
def item_type(self, request): def item_type(self, request):
return request.param return request.param
@pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.") @pytest.mark.asyncio
def test_doesnt_accept_vcard(self, item_type, get_storage_args): async def test_doesnt_accept_vcard(self, item_type, get_storage_args):
s = self.storage_class(item_types=(item_type,), **get_storage_args()) s = self.storage_class(item_types=(item_type,), **await get_storage_args())
try: try:
s.upload(format_item(VCARD_TEMPLATE)) await s.upload(format_item(VCARD_TEMPLATE))
except (exceptions.Error, requests.exceptions.HTTPError): except (exceptions.Error, aiohttp.ClientResponseError):
# Most storages hard-fail, but xandikos doesn't.
pass pass
assert not list(s.list())
assert not await aiostream.stream.list(s.list())
# The `arg` param is not named `item_types` because that would hit # The `arg` param is not named `item_types` because that would hit
# https://bitbucket.org/pytest-dev/pytest/issue/745/ # https://bitbucket.org/pytest-dev/pytest/issue/745/
@ -44,10 +48,11 @@ class TestCalDAVStorage(DAVStorageTests):
], ],
) )
@pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.") @pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.")
def test_item_types_performance( @pytest.mark.asyncio
async def test_item_types_performance(
self, get_storage_args, arg, calls_num, monkeypatch self, get_storage_args, arg, calls_num, monkeypatch
): ):
s = self.storage_class(item_types=arg, **get_storage_args()) s = self.storage_class(item_types=arg, **await get_storage_args())
old_parse = s._parse_prop_responses old_parse = s._parse_prop_responses
calls = [] calls = []
@ -56,17 +61,18 @@ class TestCalDAVStorage(DAVStorageTests):
return old_parse(*a, **kw) return old_parse(*a, **kw)
monkeypatch.setattr(s, "_parse_prop_responses", new_parse) monkeypatch.setattr(s, "_parse_prop_responses", new_parse)
list(s.list()) await aiostream.stream.list(s.list())
assert len(calls) == calls_num assert len(calls) == calls_num
@pytest.mark.xfail( @pytest.mark.xfail(
dav_server == "radicale", reason="Radicale doesn't support timeranges." dav_server == "radicale", reason="Radicale doesn't support timeranges."
) )
def test_timerange_correctness(self, get_storage_args): @pytest.mark.asyncio
async def test_timerange_correctness(self, get_storage_args):
start_date = datetime.datetime(2013, 9, 10) start_date = datetime.datetime(2013, 9, 10)
end_date = datetime.datetime(2013, 9, 13) end_date = datetime.datetime(2013, 9, 13)
s = self.storage_class( s = self.storage_class(
start_date=start_date, end_date=end_date, **get_storage_args() start_date=start_date, end_date=end_date, **await get_storage_args()
) )
too_old_item = format_item( too_old_item = format_item(
@ -123,50 +129,44 @@ class TestCalDAVStorage(DAVStorageTests):
).strip() ).strip()
) )
s.upload(too_old_item) await s.upload(too_old_item)
s.upload(too_new_item) await s.upload(too_new_item)
expected_href, _ = s.upload(good_item) expected_href, _ = await s.upload(good_item)
((actual_href, _),) = s.list() ((actual_href, _),) = await aiostream.stream.list(s.list())
assert actual_href == expected_href assert actual_href == expected_href
def test_invalid_resource(self, monkeypatch, get_storage_args): @pytest.mark.asyncio
calls = [] async def test_invalid_resource(self, monkeypatch, get_storage_args):
args = get_storage_args(collection=None) args = await get_storage_args(collection=None)
def request(session, method, url, **kwargs): with aioresponses() as m:
assert url == args["url"] m.add(args["url"], method="PROPFIND", status=200, body="Hello world")
calls.append(None)
r = requests.Response() with pytest.raises(ValueError):
r.status_code = 200 s = self.storage_class(**args)
r._content = b"Hello World." await aiostream.stream.list(s.list())
return r
monkeypatch.setattr("requests.sessions.Session.request", request) assert len(m.requests) == 1
with pytest.raises(ValueError):
s = self.storage_class(**args)
list(s.list())
assert len(calls) == 1
@pytest.mark.skipif(dav_server == "icloud", reason="iCloud only accepts VEVENT") @pytest.mark.skipif(dav_server == "icloud", reason="iCloud only accepts VEVENT")
@pytest.mark.skipif( @pytest.mark.skipif(
dav_server == "fastmail", reason="Fastmail has non-standard hadling of VTODOs." dav_server == "fastmail", reason="Fastmail has non-standard hadling of VTODOs."
) )
@pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.") @pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.")
def test_item_types_general(self, s): @pytest.mark.asyncio
event = s.upload(format_item(EVENT_TEMPLATE))[0] async def test_item_types_general(self, s):
task = s.upload(format_item(TASK_TEMPLATE))[0] event = (await s.upload(format_item(EVENT_TEMPLATE)))[0]
task = (await s.upload(format_item(TASK_TEMPLATE)))[0]
s.item_types = ("VTODO", "VEVENT") s.item_types = ("VTODO", "VEVENT")
def hrefs(): async def hrefs():
return {href for href, etag in s.list()} return {href async for href, etag in s.list()}
assert hrefs() == {event, task} assert await hrefs() == {event, task}
s.item_types = ("VTODO",) s.item_types = ("VTODO",)
assert hrefs() == {task} assert await hrefs() == {task}
s.item_types = ("VEVENT",) s.item_types = ("VEVENT",)
assert hrefs() == {event} assert await hrefs() == {event}
s.item_types = () s.item_types = ()
assert hrefs() == {event, task} assert await hrefs() == {event, task}

View file

@ -58,7 +58,7 @@ class EtesyncTests(StorageTests):
) )
assert r.status_code == 200 assert r.status_code == 200
def inner(collection="test"): async def inner(collection="test"):
rv = { rv = {
"email": "test@localhost", "email": "test@localhost",
"db_path": str(tmpdir.join("etesync.db")), "db_path": str(tmpdir.join("etesync.db")),

View file

@ -3,13 +3,21 @@ import pytest
class ServerMixin: class ServerMixin:
@pytest.fixture @pytest.fixture
def get_storage_args(self, request, tmpdir, slow_create_collection, baikal_server): def get_storage_args(
def inner(collection="test"): self,
request,
tmpdir,
slow_create_collection,
baikal_server,
aio_connector,
):
async def inner(collection="test"):
base_url = "http://127.0.0.1:8002/" base_url = "http://127.0.0.1:8002/"
args = { args = {
"url": base_url, "url": base_url,
"username": "baikal", "username": "baikal",
"password": "baikal", "password": "baikal",
"connector": aio_connector,
} }
if self.storage_class.fileext == ".vcf": if self.storage_class.fileext == ".vcf":
@ -18,7 +26,11 @@ class ServerMixin:
args["url"] = base_url + "cal.php/" args["url"] = base_url + "cal.php/"
if collection is not None: if collection is not None:
args = slow_create_collection(self.storage_class, args, collection) args = await slow_create_collection(
self.storage_class,
args,
collection,
)
return args return args
return inner return inner

View file

@ -27,7 +27,7 @@ class ServerMixin:
@pytest.fixture @pytest.fixture
def get_storage_args(self, davical_args, request): def get_storage_args(self, davical_args, request):
def inner(collection="test"): async def inner(collection="test"):
if collection is None: if collection is None:
return davical_args return davical_args

View file

@ -11,7 +11,7 @@ class ServerMixin:
# See https://github.com/pimutils/vdirsyncer/issues/824 # See https://github.com/pimutils/vdirsyncer/issues/824
pytest.skip("Fastmail has non-standard VTODO support.") pytest.skip("Fastmail has non-standard VTODO support.")
def inner(collection="test"): async def inner(collection="test"):
args = { args = {
"username": os.environ["FASTMAIL_USERNAME"], "username": os.environ["FASTMAIL_USERNAME"],
"password": os.environ["FASTMAIL_PASSWORD"], "password": os.environ["FASTMAIL_PASSWORD"],

View file

@ -11,7 +11,7 @@ class ServerMixin:
# See https://github.com/pimutils/vdirsyncer/pull/593#issuecomment-285941615 # noqa # See https://github.com/pimutils/vdirsyncer/pull/593#issuecomment-285941615 # noqa
pytest.skip("iCloud doesn't support anything else than VEVENT") pytest.skip("iCloud doesn't support anything else than VEVENT")
def inner(collection="test"): async def inner(collection="test"):
args = { args = {
"username": os.environ["ICLOUD_USERNAME"], "username": os.environ["ICLOUD_USERNAME"],
"password": os.environ["ICLOUD_PASSWORD"], "password": os.environ["ICLOUD_PASSWORD"],

View file

@ -9,17 +9,23 @@ class ServerMixin:
tmpdir, tmpdir,
slow_create_collection, slow_create_collection,
radicale_server, radicale_server,
aio_connector,
): ):
def inner(collection="test"): async def inner(collection="test"):
url = "http://127.0.0.1:8001/" url = "http://127.0.0.1:8001/"
args = { args = {
"url": url, "url": url,
"username": "radicale", "username": "radicale",
"password": "radicale", "password": "radicale",
"connector": aio_connector,
} }
if collection is not None: if collection is not None:
args = slow_create_collection(self.storage_class, args, collection) args = await slow_create_collection(
self.storage_class,
args,
collection,
)
return args return args
return inner return inner

View file

@ -9,13 +9,19 @@ class ServerMixin:
tmpdir, tmpdir,
slow_create_collection, slow_create_collection,
xandikos_server, xandikos_server,
aio_connector,
): ):
def inner(collection="test"): async def inner(collection="test"):
url = "http://127.0.0.1:8000/" url = "http://127.0.0.1:8000/"
args = {"url": url} args = {"url": url, "connector": aio_connector}
if collection is not None: if collection is not None:
args = slow_create_collection(self.storage_class, args, collection) args = await slow_create_collection(
self.storage_class,
args,
collection,
)
return args return args
return inner return inner

View file

@ -1,5 +1,6 @@
import subprocess import subprocess
import aiostream
import pytest import pytest
from . import StorageTests from . import StorageTests
@ -12,10 +13,10 @@ class TestFilesystemStorage(StorageTests):
@pytest.fixture @pytest.fixture
def get_storage_args(self, tmpdir): def get_storage_args(self, tmpdir):
def inner(collection="test"): async def inner(collection="test"):
rv = {"path": str(tmpdir), "fileext": ".txt", "collection": collection} rv = {"path": str(tmpdir), "fileext": ".txt", "collection": collection}
if collection is not None: if collection is not None:
rv = self.storage_class.create_collection(**rv) rv = await self.storage_class.create_collection(**rv)
return rv return rv
return inner return inner
@ -26,7 +27,8 @@ class TestFilesystemStorage(StorageTests):
f.write("stub") f.write("stub")
self.storage_class(str(tmpdir) + "/hue", ".txt") self.storage_class(str(tmpdir) + "/hue", ".txt")
def test_broken_data(self, tmpdir): @pytest.mark.asyncio
async def test_broken_data(self, tmpdir):
s = self.storage_class(str(tmpdir), ".txt") s = self.storage_class(str(tmpdir), ".txt")
class BrokenItem: class BrokenItem:
@ -35,64 +37,70 @@ class TestFilesystemStorage(StorageTests):
ident = uid ident = uid
with pytest.raises(TypeError): with pytest.raises(TypeError):
s.upload(BrokenItem) await s.upload(BrokenItem)
assert not tmpdir.listdir() assert not tmpdir.listdir()
def test_ident_with_slash(self, tmpdir): @pytest.mark.asyncio
async def test_ident_with_slash(self, tmpdir):
s = self.storage_class(str(tmpdir), ".txt") s = self.storage_class(str(tmpdir), ".txt")
s.upload(Item("UID:a/b/c")) await s.upload(Item("UID:a/b/c"))
(item_file,) = tmpdir.listdir() (item_file,) = tmpdir.listdir()
assert "/" not in item_file.basename and item_file.isfile() assert "/" not in item_file.basename and item_file.isfile()
def test_ignore_tmp_files(self, tmpdir): @pytest.mark.asyncio
async def test_ignore_tmp_files(self, tmpdir):
"""Test that files with .tmp suffix beside .ics files are ignored.""" """Test that files with .tmp suffix beside .ics files are ignored."""
s = self.storage_class(str(tmpdir), ".ics") s = self.storage_class(str(tmpdir), ".ics")
s.upload(Item("UID:xyzxyz")) await s.upload(Item("UID:xyzxyz"))
(item_file,) = tmpdir.listdir() (item_file,) = tmpdir.listdir()
item_file.copy(item_file.new(ext="tmp")) item_file.copy(item_file.new(ext="tmp"))
assert len(tmpdir.listdir()) == 2 assert len(tmpdir.listdir()) == 2
assert len(list(s.list())) == 1 assert len(await aiostream.stream.list(s.list())) == 1
def test_ignore_tmp_files_empty_fileext(self, tmpdir): @pytest.mark.asyncio
async def test_ignore_tmp_files_empty_fileext(self, tmpdir):
"""Test that files with .tmp suffix are ignored with empty fileext.""" """Test that files with .tmp suffix are ignored with empty fileext."""
s = self.storage_class(str(tmpdir), "") s = self.storage_class(str(tmpdir), "")
s.upload(Item("UID:xyzxyz")) await s.upload(Item("UID:xyzxyz"))
(item_file,) = tmpdir.listdir() (item_file,) = tmpdir.listdir()
item_file.copy(item_file.new(ext="tmp")) item_file.copy(item_file.new(ext="tmp"))
assert len(tmpdir.listdir()) == 2 assert len(tmpdir.listdir()) == 2
# assert False, tmpdir.listdir() # enable to see the created filename # assert False, tmpdir.listdir() # enable to see the created filename
assert len(list(s.list())) == 1 assert len(await aiostream.stream.list(s.list())) == 1
def test_ignore_files_typical_backup(self, tmpdir): @pytest.mark.asyncio
async def test_ignore_files_typical_backup(self, tmpdir):
"""Test file-name ignorance with typical backup ending ~.""" """Test file-name ignorance with typical backup ending ~."""
ignorext = "~" # without dot ignorext = "~" # without dot
storage = self.storage_class(str(tmpdir), "", fileignoreext=ignorext) storage = self.storage_class(str(tmpdir), "", fileignoreext=ignorext)
storage.upload(Item("UID:xyzxyz")) await storage.upload(Item("UID:xyzxyz"))
(item_file,) = tmpdir.listdir() (item_file,) = tmpdir.listdir()
item_file.copy(item_file.new(basename=item_file.basename + ignorext)) item_file.copy(item_file.new(basename=item_file.basename + ignorext))
assert len(tmpdir.listdir()) == 2 assert len(tmpdir.listdir()) == 2
assert len(list(storage.list())) == 1 assert len(await aiostream.stream.list(storage.list())) == 1
def test_too_long_uid(self, tmpdir): @pytest.mark.asyncio
async def test_too_long_uid(self, tmpdir):
storage = self.storage_class(str(tmpdir), ".txt") storage = self.storage_class(str(tmpdir), ".txt")
item = Item("UID:" + "hue" * 600) item = Item("UID:" + "hue" * 600)
href, etag = storage.upload(item) href, etag = await storage.upload(item)
assert item.uid not in href assert item.uid not in href
def test_post_hook_inactive(self, tmpdir, monkeypatch): @pytest.mark.asyncio
async def test_post_hook_inactive(self, tmpdir, monkeypatch):
def check_call_mock(*args, **kwargs): def check_call_mock(*args, **kwargs):
raise AssertionError() raise AssertionError()
monkeypatch.setattr(subprocess, "call", check_call_mock) monkeypatch.setattr(subprocess, "call", check_call_mock)
s = self.storage_class(str(tmpdir), ".txt", post_hook=None) s = self.storage_class(str(tmpdir), ".txt", post_hook=None)
s.upload(Item("UID:a/b/c")) await s.upload(Item("UID:a/b/c"))
def test_post_hook_active(self, tmpdir, monkeypatch):
@pytest.mark.asyncio
async def test_post_hook_active(self, tmpdir, monkeypatch):
calls = [] calls = []
exe = "foo" exe = "foo"
@ -104,14 +112,17 @@ class TestFilesystemStorage(StorageTests):
monkeypatch.setattr(subprocess, "call", check_call_mock) monkeypatch.setattr(subprocess, "call", check_call_mock)
s = self.storage_class(str(tmpdir), ".txt", post_hook=exe) s = self.storage_class(str(tmpdir), ".txt", post_hook=exe)
s.upload(Item("UID:a/b/c")) await s.upload(Item("UID:a/b/c"))
assert calls assert calls
def test_ignore_git_dirs(self, tmpdir): @pytest.mark.asyncio
async def test_ignore_git_dirs(self, tmpdir):
tmpdir.mkdir(".git").mkdir("foo") tmpdir.mkdir(".git").mkdir("foo")
tmpdir.mkdir("a") tmpdir.mkdir("a")
tmpdir.mkdir("b") tmpdir.mkdir("b")
assert {c["collection"] for c in self.storage_class.discover(str(tmpdir))} == {
"a", expected = {"a", "b"}
"b", actual = {
c["collection"] async for c in self.storage_class.discover(str(tmpdir))
} }
assert actual == expected

View file

@ -1,5 +1,6 @@
import pytest import pytest
from requests import Response from aioresponses import aioresponses
from aioresponses import CallbackResult
from tests import normalize_item from tests import normalize_item
from vdirsyncer.exceptions import UserError from vdirsyncer.exceptions import UserError
@ -7,7 +8,8 @@ from vdirsyncer.storage.http import HttpStorage
from vdirsyncer.storage.http import prepare_auth from vdirsyncer.storage.http import prepare_auth
def test_list(monkeypatch): @pytest.mark.asyncio
async def test_list(aio_connector):
collection_url = "http://127.0.0.1/calendar/collection.ics" collection_url = "http://127.0.0.1/calendar/collection.ics"
items = [ items = [
@ -34,50 +36,53 @@ def test_list(monkeypatch):
responses = ["\n".join(["BEGIN:VCALENDAR"] + items + ["END:VCALENDAR"])] * 2 responses = ["\n".join(["BEGIN:VCALENDAR"] + items + ["END:VCALENDAR"])] * 2
def get(self, method, url, *a, **kw): def callback(url, headers, **kwargs):
assert method == "GET" assert headers["User-Agent"].startswith("vdirsyncer/")
assert url == collection_url
r = Response()
r.status_code = 200
assert responses assert responses
r._content = responses.pop().encode("utf-8")
r.headers["Content-Type"] = "text/calendar"
r.encoding = "ISO-8859-1"
return r
monkeypatch.setattr("requests.sessions.Session.request", get) return CallbackResult(
status=200,
body=responses.pop().encode("utf-8"),
headers={"Content-Type": "text/calendar; charset=iso-8859-1"},
)
s = HttpStorage(url=collection_url) with aioresponses() as m:
m.get(collection_url, callback=callback, repeat=True)
found_items = {} s = HttpStorage(url=collection_url, connector=aio_connector)
for href, etag in s.list(): found_items = {}
item, etag2 = s.get(href)
assert item.uid is not None
assert etag2 == etag
found_items[normalize_item(item)] = href
expected = { async for href, etag in s.list():
normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items item, etag2 = await s.get(href)
} assert item.uid is not None
assert etag2 == etag
found_items[normalize_item(item)] = href
assert set(found_items) == expected expected = {
normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items
}
for href, etag in s.list(): assert set(found_items) == expected
item, etag2 = s.get(href)
assert item.uid is not None async for href, etag in s.list():
assert etag2 == etag item, etag2 = await s.get(href)
assert found_items[normalize_item(item)] == href assert item.uid is not None
assert etag2 == etag
assert found_items[normalize_item(item)] == href
def test_readonly_param(): def test_readonly_param(aio_connector):
"""The ``readonly`` param cannot be ``False``."""
url = "http://example.com/" url = "http://example.com/"
with pytest.raises(ValueError): with pytest.raises(ValueError):
HttpStorage(url=url, read_only=False) HttpStorage(url=url, read_only=False, connector=aio_connector)
a = HttpStorage(url=url, read_only=True).read_only a = HttpStorage(url=url, read_only=True, connector=aio_connector)
b = HttpStorage(url=url, read_only=None).read_only b = HttpStorage(url=url, read_only=None, connector=aio_connector)
assert a is b is True
assert a.read_only is b.read_only is True
def test_prepare_auth(): def test_prepare_auth():
@ -115,9 +120,9 @@ def test_prepare_auth_guess(monkeypatch):
assert "requests_toolbelt is too old" in str(excinfo.value).lower() assert "requests_toolbelt is too old" in str(excinfo.value).lower()
def test_verify_false_disallowed(): def test_verify_false_disallowed(aio_connector):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
HttpStorage(url="http://example.com", verify=False) HttpStorage(url="http://example.com", verify=False, connector=aio_connector)
assert "forbidden" in str(excinfo.value).lower() assert "forbidden" in str(excinfo.value).lower()
assert "consider setting verify_fingerprint" in str(excinfo.value).lower() assert "consider setting verify_fingerprint" in str(excinfo.value).lower()

View file

@ -1,5 +1,7 @@
import aiostream
import pytest import pytest
from requests import Response from aioresponses import aioresponses
from aioresponses import CallbackResult
import vdirsyncer.storage.http import vdirsyncer.storage.http
from . import StorageTests from . import StorageTests
@ -14,32 +16,33 @@ class CombinedStorage(Storage):
_repr_attributes = ("url", "path") _repr_attributes = ("url", "path")
storage_name = "http_and_singlefile" storage_name = "http_and_singlefile"
def __init__(self, url, path, **kwargs): def __init__(self, url, path, *, connector, **kwargs):
if kwargs.get("collection", None) is not None: if kwargs.get("collection", None) is not None:
raise ValueError() raise ValueError()
super().__init__(**kwargs) super().__init__(**kwargs)
self.url = url self.url = url
self.path = path self.path = path
self._reader = vdirsyncer.storage.http.HttpStorage(url=url) self._reader = vdirsyncer.storage.http.HttpStorage(url=url, connector=connector)
self._reader._ignore_uids = False self._reader._ignore_uids = False
self._writer = SingleFileStorage(path=path) self._writer = SingleFileStorage(path=path)
def list(self, *a, **kw): async def list(self, *a, **kw):
return self._reader.list(*a, **kw) async for item in self._reader.list(*a, **kw):
yield item
def get(self, *a, **kw): async def get(self, *a, **kw):
self.list() await aiostream.stream.list(self.list())
return self._reader.get(*a, **kw) return await self._reader.get(*a, **kw)
def upload(self, *a, **kw): async def upload(self, *a, **kw):
return self._writer.upload(*a, **kw) return await self._writer.upload(*a, **kw)
def update(self, *a, **kw): async def update(self, *a, **kw):
return self._writer.update(*a, **kw) return await self._writer.update(*a, **kw)
def delete(self, *a, **kw): async def delete(self, *a, **kw):
return self._writer.delete(*a, **kw) return await self._writer.delete(*a, **kw)
class TestHttpStorage(StorageTests): class TestHttpStorage(StorageTests):
@ -51,28 +54,37 @@ class TestHttpStorage(StorageTests):
def setup_tmpdir(self, tmpdir, monkeypatch): def setup_tmpdir(self, tmpdir, monkeypatch):
self.tmpfile = str(tmpdir.ensure("collection.txt")) self.tmpfile = str(tmpdir.ensure("collection.txt"))
def _request(method, url, *args, **kwargs): def callback(url, headers, **kwargs):
assert method == "GET" """Read our tmpfile at request time.
assert url == "http://localhost:123/collection.txt"
assert "vdirsyncer" in kwargs["headers"]["User-Agent"]
r = Response()
r.status_code = 200
try:
with open(self.tmpfile, "rb") as f:
r._content = f.read()
except OSError:
r._content = b""
r.headers["Content-Type"] = "text/calendar" We can't just read this during test setup since the file get written to
r.encoding = "utf-8" during test execution.
return r
monkeypatch.setattr(vdirsyncer.storage.http, "request", _request) It might make sense to actually run a server serving the local file.
"""
assert headers["User-Agent"].startswith("vdirsyncer/")
with open(self.tmpfile, "r") as f:
body = f.read()
return CallbackResult(
status=200,
body=body,
headers={"Content-Type": "text/calendar; charset=utf-8"},
)
with aioresponses() as m:
m.get("http://localhost:123/collection.txt", callback=callback, repeat=True)
yield
@pytest.fixture @pytest.fixture
def get_storage_args(self): def get_storage_args(self, aio_connector):
def inner(collection=None): async def inner(collection=None):
assert collection is None assert collection is None
return {"url": "http://localhost:123/collection.txt", "path": self.tmpfile} return {
"url": "http://localhost:123/collection.txt",
"path": self.tmpfile,
"connector": aio_connector,
}
return inner return inner

View file

@ -11,4 +11,7 @@ class TestMemoryStorage(StorageTests):
@pytest.fixture @pytest.fixture
def get_storage_args(self): def get_storage_args(self):
return lambda **kw: kw async def inner(**args):
return args
return inner

View file

@ -11,10 +11,10 @@ class TestSingleFileStorage(StorageTests):
@pytest.fixture @pytest.fixture
def get_storage_args(self, tmpdir): def get_storage_args(self, tmpdir):
def inner(collection="test"): async def inner(collection="test"):
rv = {"path": str(tmpdir.join("%s.txt")), "collection": collection} rv = {"path": str(tmpdir.join("%s.txt")), "collection": collection}
if collection is not None: if collection is not None:
rv = self.storage_class.create_collection(**rv) rv = await self.storage_class.create_collection(**rv)
return rv return rv
return inner return inner

View file

@ -1,3 +1,5 @@
import pytest
from vdirsyncer import exceptions from vdirsyncer import exceptions
from vdirsyncer.cli.utils import handle_cli_error from vdirsyncer.cli.utils import handle_cli_error
from vdirsyncer.cli.utils import storage_instance_from_config from vdirsyncer.cli.utils import storage_instance_from_config
@ -15,11 +17,13 @@ def test_handle_cli_error(capsys):
assert "ayy lmao" in err assert "ayy lmao" in err
def test_storage_instance_from_config(monkeypatch): @pytest.mark.asyncio
def lol(**kw): async def test_storage_instance_from_config(monkeypatch, aio_connector):
assert kw == {"foo": "bar", "baz": 1} class Dummy:
return "OK" def __init__(self, **kw):
assert kw == {"foo": "bar", "baz": 1}
monkeypatch.setitem(storage_names._storages, "lol", lol) monkeypatch.setitem(storage_names._storages, "lol", Dummy)
config = {"type": "lol", "foo": "bar", "baz": 1} config = {"type": "lol", "foo": "bar", "baz": 1}
assert storage_instance_from_config(config) == "OK" storage = await storage_instance_from_config(config, connector=aio_connector)
assert isinstance(storage, Dummy)

View file

@ -1,9 +1,9 @@
import logging import logging
import sys import sys
import aiohttp
import click_log import click_log
import pytest import pytest
import requests
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -25,74 +25,77 @@ def test_get_storage_init_args():
assert not required assert not required
def test_request_ssl(): @pytest.mark.asyncio
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: async def test_request_ssl():
http.request("GET", "https://self-signed.badssl.com/") async with aiohttp.ClientSession() as session:
assert "certificate verify failed" in str(excinfo.value) with pytest.raises(aiohttp.ClientConnectorCertificateError) as excinfo:
await http.request(
"GET",
"https://self-signed.badssl.com/",
session=session,
)
assert "certificate verify failed" in str(excinfo.value)
http.request("GET", "https://self-signed.badssl.com/", verify=False) await http.request(
"GET",
"https://self-signed.badssl.com/",
verify=False,
session=session,
)
def _fingerprints_broken(): def fingerprint_of_cert(cert, hash=hashes.SHA256) -> str:
from pkg_resources import parse_version as ver
broken_urllib3 = ver(requests.__version__) <= ver("2.5.1")
return broken_urllib3
def fingerprint_of_cert(cert, hash=hashes.SHA256):
return x509.load_pem_x509_certificate(cert.bytes()).fingerprint(hash()).hex() return x509.load_pem_x509_certificate(cert.bytes()).fingerprint(hash()).hex()
@pytest.mark.skipif( @pytest.mark.parametrize("hash_algorithm", [hashes.SHA256])
_fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529" @pytest.mark.asyncio
) async def test_request_ssl_leaf_fingerprint(
@pytest.mark.parametrize("hash_algorithm", [hashes.MD5, hashes.SHA256]) httpserver,
def test_request_ssl_leaf_fingerprint(httpserver, localhost_cert, hash_algorithm): localhost_cert,
hash_algorithm,
aio_session,
):
fingerprint = fingerprint_of_cert(localhost_cert.cert_chain_pems[0], hash_algorithm) fingerprint = fingerprint_of_cert(localhost_cert.cert_chain_pems[0], hash_algorithm)
bogus = "".join(reversed(fingerprint))
# We have to serve something: # We have to serve something:
httpserver.expect_request("/").respond_with_data("OK") httpserver.expect_request("/").respond_with_data("OK")
url = f"https://{httpserver.host}:{httpserver.port}/" url = f"https://127.0.0.1:{httpserver.port}/"
http.request("GET", url, verify=False, verify_fingerprint=fingerprint) await http.request("GET", url, verify_fingerprint=fingerprint, session=aio_session)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request("GET", url, verify_fingerprint=fingerprint)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: with pytest.raises(aiohttp.ServerFingerprintMismatch):
http.request( await http.request("GET", url, verify_fingerprint=bogus, session=aio_session)
"GET",
url,
verify=False,
verify_fingerprint="".join(reversed(fingerprint)),
)
assert "Fingerprints did not match" in str(excinfo.value)
@pytest.mark.skipif(
_fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529"
)
@pytest.mark.xfail(reason="Not implemented") @pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("hash_algorithm", [hashes.MD5, hashes.SHA256]) @pytest.mark.parametrize("hash_algorithm", [hashes.SHA256])
def test_request_ssl_ca_fingerprint(httpserver, ca, hash_algorithm): @pytest.mark.asyncio
async def test_request_ssl_ca_fingerprints(httpserver, ca, hash_algorithm, aio_session):
fingerprint = fingerprint_of_cert(ca.cert_pem) fingerprint = fingerprint_of_cert(ca.cert_pem)
bogus = "".join(reversed(fingerprint))
# We have to serve something: # We have to serve something:
httpserver.expect_request("/").respond_with_data("OK") httpserver.expect_request("/").respond_with_data("OK")
url = f"https://{httpserver.host}:{httpserver.port}/" url = f"https://127.0.0.1:{httpserver.port}/"
http.request("GET", url, verify=False, verify_fingerprint=fingerprint) await http.request(
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: "GET",
http.request("GET", url, verify_fingerprint=fingerprint) url,
verify=False,
verify_fingerprint=fingerprint,
session=aio_session,
)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: with pytest.raises(aiohttp.ServerFingerprintMismatch):
http.request( http.request(
"GET", "GET",
url, url,
verify=False, verify=False,
verify_fingerprint="".join(reversed(fingerprint)), verify_fingerprint=bogus,
session=aio_session,
) )
assert "Fingerprints did not match" in str(excinfo.value)
def test_open_graphical_browser(monkeypatch): def test_open_graphical_browser(monkeypatch):

View file

@ -1,3 +1,4 @@
import aiostream
import pytest import pytest
from vdirsyncer.cli.discover import expand_collections from vdirsyncer.cli.discover import expand_collections
@ -132,34 +133,40 @@ missing = object()
), ),
], ],
) )
def test_expand_collections(shortcuts, expected): @pytest.mark.asyncio
async def test_expand_collections(shortcuts, expected):
config_a = {"type": "fooboo", "storage_side": "a"} config_a = {"type": "fooboo", "storage_side": "a"}
config_b = {"type": "fooboo", "storage_side": "b"} config_b = {"type": "fooboo", "storage_side": "b"}
def get_discovered_a(): async def get_discovered_a():
return { return {
"c1": {"type": "fooboo", "custom_arg": "a1", "collection": "c1"}, "c1": {"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
"c2": {"type": "fooboo", "custom_arg": "a2", "collection": "c2"}, "c2": {"type": "fooboo", "custom_arg": "a2", "collection": "c2"},
"a3": {"type": "fooboo", "custom_arg": "a3", "collection": "a3"}, "a3": {"type": "fooboo", "custom_arg": "a3", "collection": "a3"},
} }
def get_discovered_b(): async def get_discovered_b():
return { return {
"c1": {"type": "fooboo", "custom_arg": "b1", "collection": "c1"}, "c1": {"type": "fooboo", "custom_arg": "b1", "collection": "c1"},
"c2": {"type": "fooboo", "custom_arg": "b2", "collection": "c2"}, "c2": {"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
"b3": {"type": "fooboo", "custom_arg": "b3", "collection": "b3"}, "b3": {"type": "fooboo", "custom_arg": "b3", "collection": "b3"},
} }
async def handle_not_found(config, collection):
return missing
assert ( assert (
sorted( sorted(
expand_collections( await aiostream.stream.list(
shortcuts, expand_collections(
config_a, shortcuts,
config_b, config_a,
get_discovered_a, config_b,
get_discovered_b, get_discovered_a,
lambda config, collection: missing, get_discovered_b,
handle_not_found,
)
) )
) )
== sorted(expected) == sorted(expected)

View file

@ -1,5 +1,7 @@
import asyncio
from copy import deepcopy from copy import deepcopy
import aiostream
import hypothesis.strategies as st import hypothesis.strategies as st
import pytest import pytest
from hypothesis import assume from hypothesis import assume
@ -21,10 +23,10 @@ from vdirsyncer.sync.status import SqliteStatus
from vdirsyncer.vobject import Item from vdirsyncer.vobject import Item
def sync(a, b, status, *args, **kwargs): async def sync(a, b, status, *args, **kwargs):
new_status = SqliteStatus(":memory:") new_status = SqliteStatus(":memory:")
new_status.load_legacy_status(status) new_status.load_legacy_status(status)
rv = _sync(a, b, new_status, *args, **kwargs) rv = await _sync(a, b, new_status, *args, **kwargs)
status.clear() status.clear()
status.update(new_status.to_legacy_status()) status.update(new_status.to_legacy_status())
return rv return rv
@ -38,45 +40,49 @@ def items(s):
return {x[1].raw for x in s.items.values()} return {x[1].raw for x in s.items.values()}
def test_irrelevant_status(): @pytest.mark.asyncio
async def test_irrelevant_status():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {"1": ("1", 1234, "1.ics", 2345)} status = {"1": ("1", 1234, "1.ics", 2345)}
sync(a, b, status) await sync(a, b, status)
assert not status assert not status
assert not items(a) assert not items(a)
assert not items(b) assert not items(b)
def test_missing_status(): @pytest.mark.asyncio
async def test_missing_status():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item = Item("asdf") item = Item("asdf")
a.upload(item) await a.upload(item)
b.upload(item) await b.upload(item)
sync(a, b, status) await sync(a, b, status)
assert len(status) == 1 assert len(status) == 1
assert items(a) == items(b) == {item.raw} assert items(a) == items(b) == {item.raw}
def test_missing_status_and_different_items(): @pytest.mark.asyncio
async def test_missing_status_and_different_items():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item1 = Item("UID:1\nhaha") item1 = Item("UID:1\nhaha")
item2 = Item("UID:1\nhoho") item2 = Item("UID:1\nhoho")
a.upload(item1) await a.upload(item1)
b.upload(item2) await b.upload(item2)
with pytest.raises(SyncConflict): with pytest.raises(SyncConflict):
sync(a, b, status) await sync(a, b, status)
assert not status assert not status
sync(a, b, status, conflict_resolution="a wins") await sync(a, b, status, conflict_resolution="a wins")
assert items(a) == items(b) == {item1.raw} assert items(a) == items(b) == {item1.raw}
def test_read_only_and_prefetch(): @pytest.mark.asyncio
async def test_read_only_and_prefetch():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
b.read_only = True b.read_only = True
@ -84,147 +90,154 @@ def test_read_only_and_prefetch():
status = {} status = {}
item1 = Item("UID:1\nhaha") item1 = Item("UID:1\nhaha")
item2 = Item("UID:2\nhoho") item2 = Item("UID:2\nhoho")
a.upload(item1) await a.upload(item1)
a.upload(item2) await a.upload(item2)
sync(a, b, status, force_delete=True) await sync(a, b, status, force_delete=True)
sync(a, b, status, force_delete=True) await sync(a, b, status, force_delete=True)
assert not items(a) and not items(b) assert not items(a) and not items(b)
def test_partial_sync_error(): @pytest.mark.asyncio
async def test_partial_sync_error():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.upload(Item("UID:0")) await a.upload(Item("UID:0"))
b.read_only = True b.read_only = True
with pytest.raises(PartialSync): with pytest.raises(PartialSync):
sync(a, b, status, partial_sync="error") await sync(a, b, status, partial_sync="error")
def test_partial_sync_ignore(): @pytest.mark.asyncio
async def test_partial_sync_ignore():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item0 = Item("UID:0\nhehe") item0 = Item("UID:0\nhehe")
a.upload(item0) await a.upload(item0)
b.upload(item0) await b.upload(item0)
b.read_only = True b.read_only = True
item1 = Item("UID:1\nhaha") item1 = Item("UID:1\nhaha")
a.upload(item1) await a.upload(item1)
sync(a, b, status, partial_sync="ignore") await sync(a, b, status, partial_sync="ignore")
sync(a, b, status, partial_sync="ignore") await sync(a, b, status, partial_sync="ignore")
assert items(a) == {item0.raw, item1.raw} assert items(a) == {item0.raw, item1.raw}
assert items(b) == {item0.raw} assert items(b) == {item0.raw}
def test_partial_sync_ignore2(): @pytest.mark.asyncio
async def test_partial_sync_ignore2():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href, etag = a.upload(Item("UID:0")) href, etag = await a.upload(Item("UID:0"))
a.read_only = True a.read_only = True
sync(a, b, status, partial_sync="ignore", force_delete=True) await sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(b) == items(a) == {"UID:0"} assert items(b) == items(a) == {"UID:0"}
b.items.clear() b.items.clear()
sync(a, b, status, partial_sync="ignore", force_delete=True) await sync(a, b, status, partial_sync="ignore", force_delete=True)
sync(a, b, status, partial_sync="ignore", force_delete=True) await sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(a) == {"UID:0"} assert items(a) == {"UID:0"}
assert not b.items assert not b.items
a.read_only = False a.read_only = False
a.update(href, Item("UID:0\nupdated"), etag) await a.update(href, Item("UID:0\nupdated"), etag)
a.read_only = True a.read_only = True
sync(a, b, status, partial_sync="ignore", force_delete=True) await sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(b) == items(a) == {"UID:0\nupdated"} assert items(b) == items(a) == {"UID:0\nupdated"}
def test_upload_and_update(): @pytest.mark.asyncio
async def test_upload_and_update():
a = MemoryStorage(fileext=".a") a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext=".b") b = MemoryStorage(fileext=".b")
status = {} status = {}
item = Item("UID:1") # new item 1 in a item = Item("UID:1") # new item 1 in a
a.upload(item) await a.upload(item)
sync(a, b, status) await sync(a, b, status)
assert items(b) == items(a) == {item.raw} assert items(b) == items(a) == {item.raw}
item = Item("UID:1\nASDF:YES") # update of item 1 in b item = Item("UID:1\nASDF:YES") # update of item 1 in b
b.update("1.b", item, b.get("1.b")[1]) await b.update("1.b", item, (await b.get("1.b"))[1])
sync(a, b, status) await sync(a, b, status)
assert items(b) == items(a) == {item.raw} assert items(b) == items(a) == {item.raw}
item2 = Item("UID:2") # new item 2 in b item2 = Item("UID:2") # new item 2 in b
b.upload(item2) await b.upload(item2)
sync(a, b, status) await sync(a, b, status)
assert items(b) == items(a) == {item.raw, item2.raw} assert items(b) == items(a) == {item.raw, item2.raw}
item2 = Item("UID:2\nASDF:YES") # update of item 2 in a item2 = Item("UID:2\nASDF:YES") # update of item 2 in a
a.update("2.a", item2, a.get("2.a")[1]) await a.update("2.a", item2, (await a.get("2.a"))[1])
sync(a, b, status) await sync(a, b, status)
assert items(b) == items(a) == {item.raw, item2.raw} assert items(b) == items(a) == {item.raw, item2.raw}
def test_deletion(): @pytest.mark.asyncio
async def test_deletion():
a = MemoryStorage(fileext=".a") a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext=".b") b = MemoryStorage(fileext=".b")
status = {} status = {}
item = Item("UID:1") item = Item("UID:1")
a.upload(item) await a.upload(item)
item2 = Item("UID:2") item2 = Item("UID:2")
a.upload(item2) await a.upload(item2)
sync(a, b, status) await sync(a, b, status)
b.delete("1.b", b.get("1.b")[1]) await b.delete("1.b", (await b.get("1.b"))[1])
sync(a, b, status) await sync(a, b, status)
assert items(a) == items(b) == {item2.raw} assert items(a) == items(b) == {item2.raw}
a.upload(item) await a.upload(item)
sync(a, b, status) await sync(a, b, status)
assert items(a) == items(b) == {item.raw, item2.raw} assert items(a) == items(b) == {item.raw, item2.raw}
a.delete("1.a", a.get("1.a")[1]) await a.delete("1.a", (await a.get("1.a"))[1])
sync(a, b, status) await sync(a, b, status)
assert items(a) == items(b) == {item2.raw} assert items(a) == items(b) == {item2.raw}
def test_insert_hash(): @pytest.mark.asyncio
async def test_insert_hash():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item = Item("UID:1") item = Item("UID:1")
href, etag = a.upload(item) href, etag = await a.upload(item)
sync(a, b, status) await sync(a, b, status)
for d in status["1"]: for d in status["1"]:
del d["hash"] del d["hash"]
a.update(href, Item("UID:1\nHAHA:YES"), etag) await a.update(href, Item("UID:1\nHAHA:YES"), etag)
sync(a, b, status) await sync(a, b, status)
assert "hash" in status["1"][0] and "hash" in status["1"][1] assert "hash" in status["1"][0] and "hash" in status["1"][1]
def test_already_synced(): @pytest.mark.asyncio
async def test_already_synced():
a = MemoryStorage(fileext=".a") a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext=".b") b = MemoryStorage(fileext=".b")
item = Item("UID:1") item = Item("UID:1")
a.upload(item) await a.upload(item)
b.upload(item) await b.upload(item)
status = { status = {
"1": ( "1": (
{"href": "1.a", "hash": item.hash, "etag": a.get("1.a")[1]}, {"href": "1.a", "hash": item.hash, "etag": (await a.get("1.a"))[1]},
{"href": "1.b", "hash": item.hash, "etag": b.get("1.b")[1]}, {"href": "1.b", "hash": item.hash, "etag": (await b.get("1.b"))[1]},
) )
} }
old_status = deepcopy(status) old_status = deepcopy(status)
@ -233,69 +246,73 @@ def test_already_synced():
) )
for _ in (1, 2): for _ in (1, 2):
sync(a, b, status) await sync(a, b, status)
assert status == old_status assert status == old_status
assert items(a) == items(b) == {item.raw} assert items(a) == items(b) == {item.raw}
@pytest.mark.parametrize("winning_storage", "ab") @pytest.mark.parametrize("winning_storage", "ab")
def test_conflict_resolution_both_etags_new(winning_storage): @pytest.mark.asyncio
async def test_conflict_resolution_both_etags_new(winning_storage):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item = Item("UID:1") item = Item("UID:1")
href_a, etag_a = a.upload(item) href_a, etag_a = await a.upload(item)
href_b, etag_b = b.upload(item) href_b, etag_b = await b.upload(item)
status = {} status = {}
sync(a, b, status) await sync(a, b, status)
assert status assert status
item_a = Item("UID:1\nitem a") item_a = Item("UID:1\nitem a")
item_b = Item("UID:1\nitem b") item_b = Item("UID:1\nitem b")
a.update(href_a, item_a, etag_a) await a.update(href_a, item_a, etag_a)
b.update(href_b, item_b, etag_b) await b.update(href_b, item_b, etag_b)
with pytest.raises(SyncConflict): with pytest.raises(SyncConflict):
sync(a, b, status) await sync(a, b, status)
sync(a, b, status, conflict_resolution=f"{winning_storage} wins") await sync(a, b, status, conflict_resolution=f"{winning_storage} wins")
assert ( assert (
items(a) == items(b) == {item_a.raw if winning_storage == "a" else item_b.raw} items(a) == items(b) == {item_a.raw if winning_storage == "a" else item_b.raw}
) )
def test_updated_and_deleted(): @pytest.mark.asyncio
async def test_updated_and_deleted():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
href_a, etag_a = a.upload(Item("UID:1")) href_a, etag_a = await a.upload(Item("UID:1"))
status = {} status = {}
sync(a, b, status, force_delete=True) await sync(a, b, status, force_delete=True)
((href_b, etag_b),) = b.list() ((href_b, etag_b),) = await aiostream.stream.list(b.list())
b.delete(href_b, etag_b) await b.delete(href_b, etag_b)
updated = Item("UID:1\nupdated") updated = Item("UID:1\nupdated")
a.update(href_a, updated, etag_a) await a.update(href_a, updated, etag_a)
sync(a, b, status, force_delete=True) await sync(a, b, status, force_delete=True)
assert items(a) == items(b) == {updated.raw} assert items(a) == items(b) == {updated.raw}
def test_conflict_resolution_invalid_mode(): @pytest.mark.asyncio
async def test_conflict_resolution_invalid_mode():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item_a = Item("UID:1\nitem a") item_a = Item("UID:1\nitem a")
item_b = Item("UID:1\nitem b") item_b = Item("UID:1\nitem b")
a.upload(item_a) await a.upload(item_a)
b.upload(item_b) await b.upload(item_b)
with pytest.raises(ValueError): with pytest.raises(ValueError):
sync(a, b, {}, conflict_resolution="yolo") await sync(a, b, {}, conflict_resolution="yolo")
def test_conflict_resolution_new_etags_without_changes(): @pytest.mark.asyncio
async def test_conflict_resolution_new_etags_without_changes():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item = Item("UID:1") item = Item("UID:1")
href_a, etag_a = a.upload(item) href_a, etag_a = await a.upload(item)
href_b, etag_b = b.upload(item) href_b, etag_b = await b.upload(item)
status = {"1": (href_a, "BOGUS_a", href_b, "BOGUS_b")} status = {"1": (href_a, "BOGUS_a", href_b, "BOGUS_b")}
sync(a, b, status) await sync(a, b, status)
((ident, (status_a, status_b)),) = status.items() ((ident, (status_a, status_b)),) = status.items()
assert ident == "1" assert ident == "1"
@ -305,7 +322,8 @@ def test_conflict_resolution_new_etags_without_changes():
assert status_b["etag"] == etag_b assert status_b["etag"] == etag_b
def test_uses_get_multi(monkeypatch): @pytest.mark.asyncio
async def test_uses_get_multi(monkeypatch):
def breakdown(*a, **kw): def breakdown(*a, **kw):
raise AssertionError("Expected use of get_multi") raise AssertionError("Expected use of get_multi")
@ -313,11 +331,11 @@ def test_uses_get_multi(monkeypatch):
old_get = MemoryStorage.get old_get = MemoryStorage.get
def get_multi(self, hrefs): async def get_multi(self, hrefs):
hrefs = list(hrefs) hrefs = list(hrefs)
get_multi_calls.append(hrefs) get_multi_calls.append(hrefs)
for href in hrefs: for href in hrefs:
item, etag = old_get(self, href) item, etag = await old_get(self, href)
yield href, item, etag yield href, item, etag
monkeypatch.setattr(MemoryStorage, "get", breakdown) monkeypatch.setattr(MemoryStorage, "get", breakdown)
@ -326,72 +344,77 @@ def test_uses_get_multi(monkeypatch):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item = Item("UID:1") item = Item("UID:1")
expected_href, etag = a.upload(item) expected_href, etag = await a.upload(item)
sync(a, b, {}) await sync(a, b, {})
assert get_multi_calls == [[expected_href]] assert get_multi_calls == [[expected_href]]
def test_empty_storage_dataloss(): @pytest.mark.asyncio
async def test_empty_storage_dataloss():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
a.upload(Item("UID:1")) await a.upload(Item("UID:1"))
a.upload(Item("UID:2")) await a.upload(Item("UID:2"))
status = {} status = {}
sync(a, b, status) await sync(a, b, status)
with pytest.raises(StorageEmpty): with pytest.raises(StorageEmpty):
sync(MemoryStorage(), b, status) await sync(MemoryStorage(), b, status)
with pytest.raises(StorageEmpty): with pytest.raises(StorageEmpty):
sync(a, MemoryStorage(), status) await sync(a, MemoryStorage(), status)
def test_no_uids(): @pytest.mark.asyncio
async def test_no_uids():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
a.upload(Item("ASDF")) await a.upload(Item("ASDF"))
b.upload(Item("FOOBAR")) await b.upload(Item("FOOBAR"))
status = {} status = {}
sync(a, b, status) await sync(a, b, status)
assert items(a) == items(b) == {"ASDF", "FOOBAR"} assert items(a) == items(b) == {"ASDF", "FOOBAR"}
def test_changed_uids(): @pytest.mark.asyncio
async def test_changed_uids():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
href_a, etag_a = a.upload(Item("UID:A-ONE")) href_a, etag_a = await a.upload(Item("UID:A-ONE"))
href_b, etag_b = b.upload(Item("UID:B-ONE")) href_b, etag_b = await b.upload(Item("UID:B-ONE"))
status = {} status = {}
sync(a, b, status) await sync(a, b, status)
a.update(href_a, Item("UID:A-TWO"), etag_a) await a.update(href_a, Item("UID:A-TWO"), etag_a)
sync(a, b, status) await sync(a, b, status)
def test_both_readonly(): @pytest.mark.asyncio
async def test_both_readonly():
a = MemoryStorage(read_only=True) a = MemoryStorage(read_only=True)
b = MemoryStorage(read_only=True) b = MemoryStorage(read_only=True)
assert a.read_only assert a.read_only
assert b.read_only assert b.read_only
status = {} status = {}
with pytest.raises(BothReadOnly): with pytest.raises(BothReadOnly):
sync(a, b, status) await sync(a, b, status)
def test_partial_sync_revert(): @pytest.mark.asyncio
async def test_partial_sync_revert():
a = MemoryStorage(instance_name="a") a = MemoryStorage(instance_name="a")
b = MemoryStorage(instance_name="b") b = MemoryStorage(instance_name="b")
status = {} status = {}
a.upload(Item("UID:1")) await a.upload(Item("UID:1"))
b.upload(Item("UID:2")) await b.upload(Item("UID:2"))
b.read_only = True b.read_only = True
sync(a, b, status, partial_sync="revert") await sync(a, b, status, partial_sync="revert")
assert len(status) == 2 assert len(status) == 2
assert items(a) == {"UID:1", "UID:2"} assert items(a) == {"UID:1", "UID:2"}
assert items(b) == {"UID:2"} assert items(b) == {"UID:2"}
sync(a, b, status, partial_sync="revert") await sync(a, b, status, partial_sync="revert")
assert len(status) == 1 assert len(status) == 1
assert items(a) == {"UID:2"} assert items(a) == {"UID:2"}
assert items(b) == {"UID:2"} assert items(b) == {"UID:2"}
@ -399,37 +422,39 @@ def test_partial_sync_revert():
# Check that updates get reverted # Check that updates get reverted
a.items[next(iter(a.items))] = ("foo", Item("UID:2\nupdated")) a.items[next(iter(a.items))] = ("foo", Item("UID:2\nupdated"))
assert items(a) == {"UID:2\nupdated"} assert items(a) == {"UID:2\nupdated"}
sync(a, b, status, partial_sync="revert") await sync(a, b, status, partial_sync="revert")
assert len(status) == 1 assert len(status) == 1
assert items(a) == {"UID:2\nupdated"} assert items(a) == {"UID:2\nupdated"}
sync(a, b, status, partial_sync="revert") await sync(a, b, status, partial_sync="revert")
assert items(a) == {"UID:2"} assert items(a) == {"UID:2"}
# Check that deletions get reverted # Check that deletions get reverted
a.items.clear() a.items.clear()
sync(a, b, status, partial_sync="revert", force_delete=True) await sync(a, b, status, partial_sync="revert", force_delete=True)
sync(a, b, status, partial_sync="revert", force_delete=True) await sync(a, b, status, partial_sync="revert", force_delete=True)
assert items(a) == {"UID:2"} assert items(a) == {"UID:2"}
@pytest.mark.parametrize("sync_inbetween", (True, False)) @pytest.mark.parametrize("sync_inbetween", (True, False))
def test_ident_conflict(sync_inbetween): @pytest.mark.asyncio
async def test_ident_conflict(sync_inbetween):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href_a, etag_a = a.upload(Item("UID:aaa")) href_a, etag_a = await a.upload(Item("UID:aaa"))
href_b, etag_b = a.upload(Item("UID:bbb")) href_b, etag_b = await a.upload(Item("UID:bbb"))
if sync_inbetween: if sync_inbetween:
sync(a, b, status) await sync(a, b, status)
a.update(href_a, Item("UID:xxx"), etag_a) await a.update(href_a, Item("UID:xxx"), etag_a)
a.update(href_b, Item("UID:xxx"), etag_b) await a.update(href_b, Item("UID:xxx"), etag_b)
with pytest.raises(IdentConflict): with pytest.raises(IdentConflict):
sync(a, b, status) await sync(a, b, status)
def test_moved_href(): @pytest.mark.asyncio
async def test_moved_href():
""" """
Concrete application: ppl_ stores contact aliases in filenames, which means Concrete application: ppl_ stores contact aliases in filenames, which means
item's hrefs get changed. Vdirsyncer doesn't synchronize this data, but item's hrefs get changed. Vdirsyncer doesn't synchronize this data, but
@ -440,8 +465,8 @@ def test_moved_href():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href, etag = a.upload(Item("UID:haha")) href, etag = await a.upload(Item("UID:haha"))
sync(a, b, status) await sync(a, b, status)
b.items["lol"] = b.items.pop("haha") b.items["lol"] = b.items.pop("haha")
@ -451,7 +476,7 @@ def test_moved_href():
# No actual sync actions # No actual sync actions
a.delete = a.update = a.upload = b.delete = b.update = b.upload = blow_up a.delete = a.update = a.upload = b.delete = b.update = b.upload = blow_up
sync(a, b, status) await sync(a, b, status)
assert len(status) == 1 assert len(status) == 1
assert items(a) == items(b) == {"UID:haha"} assert items(a) == items(b) == {"UID:haha"}
assert status["haha"][1]["href"] == "lol" assert status["haha"][1]["href"] == "lol"
@ -460,12 +485,13 @@ def test_moved_href():
# Further sync should be a noop. Not even prefetching should occur. # Further sync should be a noop. Not even prefetching should occur.
b.get_multi = blow_up b.get_multi = blow_up
sync(a, b, status) await sync(a, b, status)
assert old_status == status assert old_status == status
assert items(a) == items(b) == {"UID:haha"} assert items(a) == items(b) == {"UID:haha"}
def test_bogus_etag_change(): @pytest.mark.asyncio
async def test_bogus_etag_change():
"""Assert that sync algorithm is resilient against etag changes if content """Assert that sync algorithm is resilient against etag changes if content
didn\'t change. didn\'t change.
@ -475,27 +501,33 @@ def test_bogus_etag_change():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href_a, etag_a = a.upload(Item("UID:ASDASD")) href_a, etag_a = await a.upload(Item("UID:ASDASD"))
sync(a, b, status) await sync(a, b, status)
assert len(status) == len(list(a.list())) == len(list(b.list())) == 1 assert (
len(status)
== len(await aiostream.stream.list(a.list()))
== len(await aiostream.stream.list(b.list()))
== 1
)
((href_b, etag_b),) = b.list() ((href_b, etag_b),) = await aiostream.stream.list(b.list())
a.update(href_a, Item("UID:ASDASD"), etag_a) await a.update(href_a, Item("UID:ASDASD"), etag_a)
b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b) await b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b)
b.delete = b.update = b.upload = blow_up b.delete = b.update = b.upload = blow_up
sync(a, b, status) await sync(a, b, status)
assert len(status) == 1 assert len(status) == 1
assert items(a) == items(b) == {"UID:ASDASD\nACTUALCHANGE:YES"} assert items(a) == items(b) == {"UID:ASDASD\nACTUALCHANGE:YES"}
def test_unicode_hrefs(): @pytest.mark.asyncio
async def test_unicode_hrefs():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href, etag = a.upload(Item("UID:äää")) href, etag = await a.upload(Item("UID:äää"))
sync(a, b, status) await sync(a, b, status)
class ActionIntentionallyFailed(Exception): class ActionIntentionallyFailed(Exception):
@ -511,11 +543,12 @@ class SyncMachine(RuleBasedStateMachine):
Storage = Bundle("storage") Storage = Bundle("storage")
@rule(target=Storage, flaky_etags=st.booleans(), null_etag_on_upload=st.booleans()) @rule(target=Storage, flaky_etags=st.booleans(), null_etag_on_upload=st.booleans())
@pytest.mark.asyncio
def newstorage(self, flaky_etags, null_etag_on_upload): def newstorage(self, flaky_etags, null_etag_on_upload):
s = MemoryStorage() s = MemoryStorage()
if flaky_etags: if flaky_etags:
def get(href): async def get(href):
old_etag, item = s.items[href] old_etag, item = s.items[href]
etag = _random_string() etag = _random_string()
s.items[href] = etag, item s.items[href] = etag, item
@ -526,8 +559,15 @@ class SyncMachine(RuleBasedStateMachine):
if null_etag_on_upload: if null_etag_on_upload:
_old_upload = s.upload _old_upload = s.upload
_old_update = s.update _old_update = s.update
s.upload = lambda item: (_old_upload(item)[0], "NULL")
s.update = lambda h, i, e: _old_update(h, i, e) and "NULL" async def upload(item):
return ((await _old_upload(item)))[0], "NULL"
async def update(href, item, etag):
return await _old_update(href, item, etag) and "NULL"
s.upload = upload
s.update = update
return s return s
@ -547,11 +587,11 @@ class SyncMachine(RuleBasedStateMachine):
_old_upload = s.upload _old_upload = s.upload
_old_update = s.update _old_update = s.update
def upload(item): async def upload(item):
return _old_upload(item)[0], None return (await _old_upload(item))[0], None
def update(href, item, etag): async def update(href, item, etag):
_old_update(href, item, etag) return await _old_update(href, item, etag)
s.upload = upload s.upload = upload
s.update = update s.update = update
@ -590,66 +630,73 @@ class SyncMachine(RuleBasedStateMachine):
with_error_callback, with_error_callback,
partial_sync, partial_sync,
): ):
assume(a is not b) async def inner():
old_items_a = items(a) assume(a is not b)
old_items_b = items(b) old_items_a = items(a)
old_items_b = items(b)
a.instance_name = "a" a.instance_name = "a"
b.instance_name = "b" b.instance_name = "b"
errors = [] errors = []
if with_error_callback: if with_error_callback:
error_callback = errors.append error_callback = errors.append
else: else:
error_callback = None error_callback = None
try: try:
# If one storage is read-only, double-sync because changes don't # If one storage is read-only, double-sync because changes don't
# get reverted immediately. # get reverted immediately.
for _ in range(2 if a.read_only or b.read_only else 1): for _ in range(2 if a.read_only or b.read_only else 1):
sync( await sync(
a, a,
b, b,
status, status,
force_delete=force_delete, force_delete=force_delete,
conflict_resolution=conflict_resolution, conflict_resolution=conflict_resolution,
error_callback=error_callback, error_callback=error_callback,
partial_sync=partial_sync, partial_sync=partial_sync,
)
for e in errors:
raise e
except PartialSync:
assert partial_sync == "error"
except ActionIntentionallyFailed:
pass
except BothReadOnly:
assert a.read_only and b.read_only
assume(False)
except StorageEmpty:
if force_delete:
raise
else:
not_a = not await aiostream.stream.list(a.list())
not_b = not await aiostream.stream.list(b.list())
assert not_a or not_b
else:
items_a = items(a)
items_b = items(b)
assert items_a == items_b or partial_sync == "ignore"
assert items_a == old_items_a or not a.read_only
assert items_b == old_items_b or not b.read_only
assert (
set(a.items) | set(b.items) == set(status)
or partial_sync == "ignore"
) )
for e in errors: asyncio.run(inner())
raise e
except PartialSync:
assert partial_sync == "error"
except ActionIntentionallyFailed:
pass
except BothReadOnly:
assert a.read_only and b.read_only
assume(False)
except StorageEmpty:
if force_delete:
raise
else:
assert not list(a.list()) or not list(b.list())
else:
items_a = items(a)
items_b = items(b)
assert items_a == items_b or partial_sync == "ignore"
assert items_a == old_items_a or not a.read_only
assert items_b == old_items_b or not b.read_only
assert (
set(a.items) | set(b.items) == set(status) or partial_sync == "ignore"
)
TestSyncMachine = SyncMachine.TestCase TestSyncMachine = SyncMachine.TestCase
@pytest.mark.parametrize("error_callback", [True, False]) @pytest.mark.parametrize("error_callback", [True, False])
def test_rollback(error_callback): @pytest.mark.asyncio
async def test_rollback(error_callback):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
@ -662,7 +709,7 @@ def test_rollback(error_callback):
if error_callback: if error_callback:
errors = [] errors = []
sync( await sync(
a, a,
b, b,
status=status, status=status,
@ -677,16 +724,22 @@ def test_rollback(error_callback):
assert status["1"] assert status["1"]
else: else:
with pytest.raises(ActionIntentionallyFailed): with pytest.raises(ActionIntentionallyFailed):
sync(a, b, status=status, conflict_resolution="a wins") await sync(a, b, status=status, conflict_resolution="a wins")
def test_duplicate_hrefs(): @pytest.mark.asyncio
async def test_duplicate_hrefs():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
a.list = lambda: [("a", "a")] * 3
async def fake_list():
for item in [("a", "a")] * 3:
yield item
a.list = fake_list
a.items["a"] = ("a", Item("UID:a")) a.items["a"] = ("a", Item("UID:a"))
status = {} status = {}
sync(a, b, status) await sync(a, b, status)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
sync(a, b, status) await sync(a, b, status)

View file

@ -12,105 +12,122 @@ from vdirsyncer.storage.base import normalize_meta_value
from vdirsyncer.storage.memory import MemoryStorage from vdirsyncer.storage.memory import MemoryStorage
def test_irrelevant_status(): @pytest.mark.asyncio
async def test_irrelevant_status():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {"foo": "bar"} status = {"foo": "bar"}
metasync(a, b, status, keys=()) await metasync(a, b, status, keys=())
assert not status assert not status
def test_basic(monkeypatch): @pytest.mark.asyncio
async def test_basic(monkeypatch):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta("foo", "bar") await a.set_meta("foo", "bar")
metasync(a, b, status, keys=["foo"]) await metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == "bar" assert await a.get_meta("foo") == await b.get_meta("foo") == "bar"
a.set_meta("foo", "baz") await a.set_meta("foo", "baz")
metasync(a, b, status, keys=["foo"]) await metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == "baz" assert await a.get_meta("foo") == await b.get_meta("foo") == "baz"
monkeypatch.setattr(a, "set_meta", blow_up) monkeypatch.setattr(a, "set_meta", blow_up)
monkeypatch.setattr(b, "set_meta", blow_up) monkeypatch.setattr(b, "set_meta", blow_up)
metasync(a, b, status, keys=["foo"]) await metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == "baz" assert await a.get_meta("foo") == await b.get_meta("foo") == "baz"
monkeypatch.undo() monkeypatch.undo()
monkeypatch.undo() monkeypatch.undo()
b.set_meta("foo", None) await b.set_meta("foo", None)
metasync(a, b, status, keys=["foo"]) await metasync(a, b, status, keys=["foo"])
assert not a.get_meta("foo") and not b.get_meta("foo") assert not await a.get_meta("foo") and not await b.get_meta("foo")
@pytest.fixture @pytest.fixture
def conflict_state(request): @pytest.mark.asyncio
async def conflict_state(request, event_loop):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta("foo", "bar") await a.set_meta("foo", "bar")
b.set_meta("foo", "baz") await b.set_meta("foo", "baz")
def cleanup(): def cleanup():
assert a.get_meta("foo") == "bar" async def do_cleanup():
assert b.get_meta("foo") == "baz" assert await a.get_meta("foo") == "bar"
assert not status assert await b.get_meta("foo") == "baz"
assert not status
event_loop.run_until_complete(do_cleanup())
request.addfinalizer(cleanup) request.addfinalizer(cleanup)
return a, b, status return a, b, status
def test_conflict(conflict_state): @pytest.mark.asyncio
async def test_conflict(conflict_state):
a, b, status = conflict_state a, b, status = conflict_state
with pytest.raises(MetaSyncConflict): with pytest.raises(MetaSyncConflict):
metasync(a, b, status, keys=["foo"]) await metasync(a, b, status, keys=["foo"])
def test_invalid_conflict_resolution(conflict_state): @pytest.mark.asyncio
async def test_invalid_conflict_resolution(conflict_state):
a, b, status = conflict_state a, b, status = conflict_state
with pytest.raises(UserError) as excinfo: with pytest.raises(UserError) as excinfo:
metasync(a, b, status, keys=["foo"], conflict_resolution="foo") await metasync(a, b, status, keys=["foo"], conflict_resolution="foo")
assert "Invalid conflict resolution setting" in str(excinfo.value) assert "Invalid conflict resolution setting" in str(excinfo.value)
def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch): @pytest.mark.asyncio
async def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch):
a, b, status = conflict_state a, b, status = conflict_state
warnings = [] warnings = []
monkeypatch.setattr(logger, "warning", warnings.append) monkeypatch.setattr(logger, "warning", warnings.append)
with pytest.raises(MetaSyncConflict): with pytest.raises(MetaSyncConflict):
metasync(a, b, status, keys=["foo"], conflict_resolution=lambda *a, **kw: None) await metasync(
a,
b,
status,
keys=["foo"],
conflict_resolution=lambda *a, **kw: None,
)
assert warnings == ["Custom commands don't work on metasync."] assert warnings == ["Custom commands don't work on metasync."]
def test_conflict_same_content(): @pytest.mark.asyncio
async def test_conflict_same_content():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta("foo", "bar") await a.set_meta("foo", "bar")
b.set_meta("foo", "bar") await b.set_meta("foo", "bar")
metasync(a, b, status, keys=["foo"]) await metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == status["foo"] == "bar" assert await a.get_meta("foo") == await b.get_meta("foo") == status["foo"] == "bar"
@pytest.mark.parametrize("wins", "ab") @pytest.mark.parametrize("wins", "ab")
def test_conflict_x_wins(wins): @pytest.mark.asyncio
async def test_conflict_x_wins(wins):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta("foo", "bar") await a.set_meta("foo", "bar")
b.set_meta("foo", "baz") await b.set_meta("foo", "baz")
metasync( await metasync(
a, a,
b, b,
status, status,
@ -119,8 +136,8 @@ def test_conflict_x_wins(wins):
) )
assert ( assert (
a.get_meta("foo") await a.get_meta("foo")
== b.get_meta("foo") == await b.get_meta("foo")
== status["foo"] == status["foo"]
== ("bar" if wins == "a" else "baz") == ("bar" if wins == "a" else "baz")
) )
@ -148,7 +165,8 @@ metadata = st.dictionaries(keys, values)
keys={"0"}, keys={"0"},
conflict_resolution="a wins", conflict_resolution="a wins",
) )
def test_fuzzing(a, b, status, keys, conflict_resolution): @pytest.mark.asyncio
async def test_fuzzing(a, b, status, keys, conflict_resolution):
def _get_storage(m, instance_name): def _get_storage(m, instance_name):
s = MemoryStorage(instance_name=instance_name) s = MemoryStorage(instance_name=instance_name)
s.metadata = m s.metadata = m
@ -159,13 +177,13 @@ def test_fuzzing(a, b, status, keys, conflict_resolution):
winning_storage = a if conflict_resolution == "a wins" else b winning_storage = a if conflict_resolution == "a wins" else b
expected_values = { expected_values = {
key: winning_storage.get_meta(key) for key in keys if key not in status key: await winning_storage.get_meta(key) for key in keys if key not in status
} }
metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution) await metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution)
for key in keys: for key in keys:
s = status.get(key, "") s = status.get(key, "")
assert a.get_meta(key) == b.get_meta(key) == s assert await a.get_meta(key) == await b.get_meta(key) == s
if expected_values.get(key, "") and s: if expected_values.get(key, "") and s:
assert s == expected_values[key] assert s == expected_values[key]

View file

@ -1,3 +1,4 @@
import aiostream
import pytest import pytest
from hypothesis import given from hypothesis import given
from hypothesis import HealthCheck from hypothesis import HealthCheck
@ -15,37 +16,42 @@ from vdirsyncer.vobject import Item
@given(uid=uid_strategy) @given(uid=uid_strategy)
# Using the random module for UIDs: # Using the random module for UIDs:
@settings(suppress_health_check=HealthCheck.all()) @settings(suppress_health_check=HealthCheck.all())
def test_repair_uids(uid): @pytest.mark.asyncio
async def test_repair_uids(uid):
s = MemoryStorage() s = MemoryStorage()
s.items = { s.items = {
"one": ("asdf", Item(f"BEGIN:VCARD\nFN:Hans\nUID:{uid}\nEND:VCARD")), "one": ("asdf", Item(f"BEGIN:VCARD\nFN:Hans\nUID:{uid}\nEND:VCARD")),
"two": ("asdf", Item(f"BEGIN:VCARD\nFN:Peppi\nUID:{uid}\nEND:VCARD")), "two": ("asdf", Item(f"BEGIN:VCARD\nFN:Peppi\nUID:{uid}\nEND:VCARD")),
} }
uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()] uid1, uid2 = [(await s.get(href))[0].uid async for href, etag in s.list()]
assert uid1 == uid2 assert uid1 == uid2
repair_storage(s, repair_unsafe_uid=False) await repair_storage(s, repair_unsafe_uid=False)
uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()] uid1, uid2 = [
(await s.get(href))[0].uid
for href, etag in await aiostream.stream.list(s.list())
]
assert uid1 != uid2 assert uid1 != uid2
@given(uid=uid_strategy.filter(lambda x: not href_safe(x))) @given(uid=uid_strategy.filter(lambda x: not href_safe(x)))
# Using the random module for UIDs: # Using the random module for UIDs:
@settings(suppress_health_check=HealthCheck.all()) @settings(suppress_health_check=HealthCheck.all())
def test_repair_unsafe_uids(uid): @pytest.mark.asyncio
async def test_repair_unsafe_uids(uid):
s = MemoryStorage() s = MemoryStorage()
item = Item(f"BEGIN:VCARD\nUID:{uid}\nEND:VCARD") item = Item(f"BEGIN:VCARD\nUID:{uid}\nEND:VCARD")
href, etag = s.upload(item) href, etag = await s.upload(item)
assert s.get(href)[0].uid == uid assert (await s.get(href))[0].uid == uid
assert not href_safe(uid) assert not href_safe(uid)
repair_storage(s, repair_unsafe_uid=True) await repair_storage(s, repair_unsafe_uid=True)
new_href = list(s.list())[0][0] new_href = (await aiostream.stream.list(s.list()))[0][0]
assert href_safe(new_href) assert href_safe(new_href)
newuid = s.get(new_href)[0].uid newuid = (await s.get(new_href))[0].uid
assert href_safe(newuid) assert href_safe(newuid)

View file

@ -1,8 +1,10 @@
import asyncio
import functools import functools
import json import json
import logging import logging
import sys import sys
import aiohttp
import click import click
import click_log import click_log
@ -124,17 +126,26 @@ def sync(ctx, collections, force_delete):
""" """
from .tasks import prepare_pair, sync_collection from .tasks import prepare_pair, sync_collection
for pair_name, collections in collections: async def main(collections):
for collection, config in prepare_pair( conn = aiohttp.TCPConnector(limit_per_host=16)
pair_name=pair_name,
collections=collections, for pair_name, collections in collections:
config=ctx.config, async for collection, config in prepare_pair(
): pair_name=pair_name,
sync_collection( collections=collections,
collection=collection, config=ctx.config,
general=config, connector=conn,
force_delete=force_delete, ):
) await sync_collection(
collection=collection,
general=config,
force_delete=force_delete,
connector=conn,
)
await conn.close()
asyncio.run(main(collections))
@app.command() @app.command()
@ -149,13 +160,31 @@ def metasync(ctx, collections):
""" """
from .tasks import prepare_pair, metasync_collection from .tasks import prepare_pair, metasync_collection
for pair_name, collections in collections: async def main(collections):
for collection, config in prepare_pair( conn = aiohttp.TCPConnector(limit_per_host=16)
pair_name=pair_name,
collections=collections, for pair_name, collections in collections:
config=ctx.config, collections = prepare_pair(
): pair_name=pair_name,
metasync_collection(collection=collection, general=config) collections=collections,
config=ctx.config,
connector=conn,
)
await asyncio.gather(
*[
metasync_collection(
collection=collection,
general=config,
connector=conn,
)
async for collection, config in collections
]
)
await conn.close()
asyncio.run(main(collections))
@app.command() @app.command()
@ -178,15 +207,23 @@ def discover(ctx, pairs, list):
config = ctx.config config = ctx.config
for pair_name in pairs or config.pairs: async def main():
pair = config.get_pair(pair_name) conn = aiohttp.TCPConnector(limit_per_host=16)
discover_collections( for pair_name in pairs or config.pairs:
status_path=config.general["status_path"], pair = config.get_pair(pair_name)
pair=pair,
from_cache=False, await discover_collections(
list_collections=list, status_path=config.general["status_path"],
) pair=pair,
from_cache=False,
list_collections=list,
connector=conn,
)
await conn.close()
asyncio.run(main())
@app.command() @app.command()
@ -225,7 +262,18 @@ def repair(ctx, collection, repair_unsafe_uid):
"turn off other client's synchronization features." "turn off other client's synchronization features."
) )
click.confirm("Do you want to continue?", abort=True) click.confirm("Do you want to continue?", abort=True)
repair_collection(ctx.config, collection, repair_unsafe_uid=repair_unsafe_uid)
async def main():
conn = aiohttp.TCPConnector(limit_per_host=16)
await repair_collection(
ctx.config,
collection,
repair_unsafe_uid=repair_unsafe_uid,
connector=conn,
)
await conn.close()
asyncio.run(main())
@app.command() @app.command()

View file

@ -1,10 +1,13 @@
import asyncio
import hashlib import hashlib
import json import json
import logging import logging
import sys import sys
import aiohttp
import aiostream
from .. import exceptions from .. import exceptions
from ..utils import cached_property
from .utils import handle_collection_not_found from .utils import handle_collection_not_found
from .utils import handle_storage_init_error from .utils import handle_storage_init_error
from .utils import load_status from .utils import load_status
@ -35,7 +38,14 @@ def _get_collections_cache_key(pair):
return m.hexdigest() return m.hexdigest()
def collections_for_pair(status_path, pair, from_cache=True, list_collections=False): async def collections_for_pair(
status_path,
pair,
from_cache=True,
list_collections=False,
*,
connector: aiohttp.TCPConnector,
):
"""Determine all configured collections for a given pair. Takes care of """Determine all configured collections for a given pair. Takes care of
shortcut expansion and result caching. shortcut expansion and result caching.
@ -67,16 +77,24 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa
logger.info("Discovering collections for pair {}".format(pair.name)) logger.info("Discovering collections for pair {}".format(pair.name))
a_discovered = _DiscoverResult(pair.config_a) a_discovered = _DiscoverResult(pair.config_a, connector=connector)
b_discovered = _DiscoverResult(pair.config_b) b_discovered = _DiscoverResult(pair.config_b, connector=connector)
if list_collections: if list_collections:
_print_collections(pair.config_a["instance_name"], a_discovered.get_self) await _print_collections(
_print_collections(pair.config_b["instance_name"], b_discovered.get_self) pair.config_a["instance_name"],
a_discovered.get_self,
connector=connector,
)
await _print_collections(
pair.config_b["instance_name"],
b_discovered.get_self,
connector=connector,
)
# We have to use a list here because the special None/null value would get # We have to use a list here because the special None/null value would get
# mangled to string (because JSON objects always have string keys). # mangled to string (because JSON objects always have string keys).
rv = list( rv = await aiostream.stream.list(
expand_collections( expand_collections(
shortcuts=pair.collections, shortcuts=pair.collections,
config_a=pair.config_a, config_a=pair.config_a,
@ -87,7 +105,7 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa
) )
) )
_sanity_check_collections(rv) await _sanity_check_collections(rv, connector=connector)
save_status( save_status(
status_path, status_path,
@ -103,10 +121,14 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa
return rv return rv
def _sanity_check_collections(collections): async def _sanity_check_collections(collections, *, connector):
tasks = []
for _, (a_args, b_args) in collections: for _, (a_args, b_args) in collections:
storage_instance_from_config(a_args) tasks.append(storage_instance_from_config(a_args, connector=connector))
storage_instance_from_config(b_args) tasks.append(storage_instance_from_config(b_args, connector=connector))
await asyncio.gather(*tasks)
def _compress_collections_cache(collections, config_a, config_b): def _compress_collections_cache(collections, config_a, config_b):
@ -134,17 +156,28 @@ def _expand_collections_cache(collections, config_a, config_b):
class _DiscoverResult: class _DiscoverResult:
def __init__(self, config): def __init__(self, config, *, connector):
self._cls, _ = storage_class_from_config(config) self._cls, _ = storage_class_from_config(config)
self._config = config
def get_self(self): if self._cls.__name__ in [
"CardDAVStorage",
"CalDAVStorage",
"GoogleCalendarStorage",
]:
assert connector is not None
config["connector"] = connector
self._config = config
self._discovered = None
async def get_self(self):
if self._discovered is None:
self._discovered = await self._discover()
return self._discovered return self._discovered
@cached_property async def _discover(self):
def _discovered(self):
try: try:
discovered = list(self._cls.discover(**self._config)) discovered = await aiostream.stream.list(self._cls.discover(**self._config))
except NotImplementedError: except NotImplementedError:
return {} return {}
except Exception: except Exception:
@ -158,7 +191,7 @@ class _DiscoverResult:
return rv return rv
def expand_collections( async def expand_collections(
shortcuts, shortcuts,
config_a, config_a,
config_b, config_b,
@ -173,9 +206,9 @@ def expand_collections(
for shortcut in shortcuts: for shortcut in shortcuts:
if shortcut == "from a": if shortcut == "from a":
collections = get_a_discovered() collections = await get_a_discovered()
elif shortcut == "from b": elif shortcut == "from b":
collections = get_b_discovered() collections = await get_b_discovered()
else: else:
collections = [shortcut] collections = [shortcut]
@ -189,17 +222,23 @@ def expand_collections(
continue continue
handled_collections.add(collection) handled_collections.add(collection)
a_args = _collection_from_discovered( a_args = await _collection_from_discovered(
get_a_discovered, collection_a, config_a, _handle_collection_not_found get_a_discovered,
collection_a,
config_a,
_handle_collection_not_found,
) )
b_args = _collection_from_discovered( b_args = await _collection_from_discovered(
get_b_discovered, collection_b, config_b, _handle_collection_not_found get_b_discovered,
collection_b,
config_b,
_handle_collection_not_found,
) )
yield collection, (a_args, b_args) yield collection, (a_args, b_args)
def _collection_from_discovered( async def _collection_from_discovered(
get_discovered, collection, config, _handle_collection_not_found get_discovered, collection, config, _handle_collection_not_found
): ):
if collection is None: if collection is None:
@ -208,14 +247,19 @@ def _collection_from_discovered(
return args return args
try: try:
return get_discovered()[collection] return (await get_discovered())[collection]
except KeyError: except KeyError:
return _handle_collection_not_found(config, collection) return await _handle_collection_not_found(config, collection)
def _print_collections(instance_name, get_discovered): async def _print_collections(
instance_name: str,
get_discovered,
*,
connector: aiohttp.TCPConnector,
):
try: try:
discovered = get_discovered() discovered = await get_discovered()
except exceptions.UserError: except exceptions.UserError:
raise raise
except Exception: except Exception:
@ -238,8 +282,12 @@ def _print_collections(instance_name, get_discovered):
args["instance_name"] = instance_name args["instance_name"] = instance_name
try: try:
storage = storage_instance_from_config(args, create=False) storage = await storage_instance_from_config(
displayname = storage.get_meta("displayname") args,
create=False,
connector=connector,
)
displayname = await storage.get_meta("displayname")
except Exception: except Exception:
displayname = "" displayname = ""

View file

@ -1,5 +1,7 @@
import json import json
import aiohttp
from .. import exceptions from .. import exceptions
from .. import sync from .. import sync
from .config import CollectionConfig from .config import CollectionConfig
@ -15,11 +17,15 @@ from .utils import manage_sync_status
from .utils import save_status from .utils import save_status
def prepare_pair(pair_name, collections, config): async def prepare_pair(pair_name, collections, config, *, connector):
pair = config.get_pair(pair_name) pair = config.get_pair(pair_name)
all_collections = dict( all_collections = dict(
collections_for_pair(status_path=config.general["status_path"], pair=pair) await collections_for_pair(
status_path=config.general["status_path"],
pair=pair,
connector=connector,
)
) )
for collection_name in collections or all_collections: for collection_name in collections or all_collections:
@ -37,15 +43,21 @@ def prepare_pair(pair_name, collections, config):
yield collection, config.general yield collection, config.general
def sync_collection(collection, general, force_delete): async def sync_collection(
collection,
general,
force_delete,
*,
connector: aiohttp.TCPConnector,
):
pair = collection.pair pair = collection.pair
status_name = get_status_name(pair.name, collection.name) status_name = get_status_name(pair.name, collection.name)
try: try:
cli_logger.info(f"Syncing {status_name}") cli_logger.info(f"Syncing {status_name}")
a = storage_instance_from_config(collection.config_a) a = await storage_instance_from_config(collection.config_a, connector=connector)
b = storage_instance_from_config(collection.config_b) b = await storage_instance_from_config(collection.config_b, connector=connector)
sync_failed = False sync_failed = False
@ -57,7 +69,7 @@ def sync_collection(collection, general, force_delete):
with manage_sync_status( with manage_sync_status(
general["status_path"], pair.name, collection.name general["status_path"], pair.name, collection.name
) as status: ) as status:
sync.sync( await sync.sync(
a, a,
b, b,
status, status,
@ -76,9 +88,9 @@ def sync_collection(collection, general, force_delete):
raise JobFailed() raise JobFailed()
def discover_collections(pair, **kwargs): async def discover_collections(pair, **kwargs):
rv = collections_for_pair(pair=pair, **kwargs) rv = await collections_for_pair(pair=pair, **kwargs)
collections = list(c for c, (a, b) in rv) collections = [c for c, (a, b) in rv]
if collections == [None]: if collections == [None]:
collections = None collections = None
cli_logger.info( cli_logger.info(
@ -86,7 +98,13 @@ def discover_collections(pair, **kwargs):
) )
def repair_collection(config, collection, repair_unsafe_uid): async def repair_collection(
config,
collection,
repair_unsafe_uid,
*,
connector: aiohttp.TCPConnector,
):
from ..repair import repair_storage from ..repair import repair_storage
storage_name, collection = collection, None storage_name, collection = collection, None
@ -99,7 +117,7 @@ def repair_collection(config, collection, repair_unsafe_uid):
if collection is not None: if collection is not None:
cli_logger.info("Discovering collections (skipping cache).") cli_logger.info("Discovering collections (skipping cache).")
cls, config = storage_class_from_config(config) cls, config = storage_class_from_config(config)
for config in cls.discover(**config): async for config in cls.discover(**config):
if config["collection"] == collection: if config["collection"] == collection:
break break
else: else:
@ -110,14 +128,14 @@ def repair_collection(config, collection, repair_unsafe_uid):
) )
config["type"] = storage_type config["type"] = storage_type
storage = storage_instance_from_config(config) storage = await storage_instance_from_config(config, connector=connector)
cli_logger.info(f"Repairing {storage_name}/{collection}") cli_logger.info(f"Repairing {storage_name}/{collection}")
cli_logger.warning("Make sure no other program is talking to the server.") cli_logger.warning("Make sure no other program is talking to the server.")
repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid) await repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid)
def metasync_collection(collection, general): async def metasync_collection(collection, general, *, connector: aiohttp.TCPConnector):
from ..metasync import metasync from ..metasync import metasync
pair = collection.pair pair = collection.pair
@ -133,10 +151,10 @@ def metasync_collection(collection, general):
or {} or {}
) )
a = storage_instance_from_config(collection.config_a) a = await storage_instance_from_config(collection.config_a, connector=connector)
b = storage_instance_from_config(collection.config_b) b = await storage_instance_from_config(collection.config_b, connector=connector)
metasync( await metasync(
a, a,
b, b,
status, status,

View file

@ -5,6 +5,7 @@ import json
import os import os
import sys import sys
import aiohttp
import click import click
from atomicwrites import atomic_write from atomicwrites import atomic_write
@ -252,22 +253,37 @@ def storage_class_from_config(config):
return cls, config return cls, config
def storage_instance_from_config(config, create=True): async def storage_instance_from_config(
config,
create=True,
*,
connector: aiohttp.TCPConnector,
):
""" """
:param config: A configuration dictionary to pass as kwargs to the class :param config: A configuration dictionary to pass as kwargs to the class
corresponding to config['type'] corresponding to config['type']
""" """
from vdirsyncer.storage.dav import DAVStorage
from vdirsyncer.storage.http import HttpStorage
cls, new_config = storage_class_from_config(config) cls, new_config = storage_class_from_config(config)
if issubclass(cls, DAVStorage) or issubclass(cls, HttpStorage):
assert connector is not None # FIXME: hack?
new_config["connector"] = connector
try: try:
return cls(**new_config) return cls(**new_config)
except exceptions.CollectionNotFound as e: except exceptions.CollectionNotFound as e:
if create: if create:
config = handle_collection_not_found( config = await handle_collection_not_found(
config, config.get("collection", None), e=str(e) config, config.get("collection", None), e=str(e)
) )
return storage_instance_from_config(config, create=False) return await storage_instance_from_config(
config,
create=False,
connector=connector,
)
else: else:
raise raise
except Exception: except Exception:
@ -319,7 +335,7 @@ def assert_permissions(path, wanted):
os.chmod(path, wanted) os.chmod(path, wanted)
def handle_collection_not_found(config, collection, e=None): async def handle_collection_not_found(config, collection, e=None):
storage_name = config.get("instance_name", None) storage_name = config.get("instance_name", None)
cli_logger.warning( cli_logger.warning(
@ -333,7 +349,7 @@ def handle_collection_not_found(config, collection, e=None):
cls, config = storage_class_from_config(config) cls, config = storage_class_from_config(config)
config["collection"] = collection config["collection"] = collection
try: try:
args = cls.create_collection(**config) args = await cls.create_collection(**config)
args["type"] = storage_type args["type"] = storage_type
return args return args
except NotImplementedError as e: except NotImplementedError as e:

View file

@ -1,6 +1,6 @@
import logging import logging
import requests import aiohttp
from . import __version__ from . import __version__
from . import DOCS_HOME from . import DOCS_HOME
@ -99,23 +99,8 @@ def prepare_client_cert(cert):
return cert return cert
def _install_fingerprint_adapter(session, fingerprint): async def request(
prefix = "https://" method, url, session, latin1_fallback=True, verify_fingerprint=None, **kwargs
try:
from requests_toolbelt.adapters.fingerprint import FingerprintAdapter
except ImportError:
raise RuntimeError(
"`verify_fingerprint` can only be used with "
"requests-toolbelt versions >= 0.4.0"
)
if not isinstance(session.adapters[prefix], FingerprintAdapter):
fingerprint_adapter = FingerprintAdapter(fingerprint)
session.mount(prefix, fingerprint_adapter)
def request(
method, url, session=None, latin1_fallback=True, verify_fingerprint=None, **kwargs
): ):
""" """
Wrapper method for requests, to ease logging and mocking. Parameters should Wrapper method for requests, to ease logging and mocking. Parameters should
@ -132,16 +117,20 @@ def request(
https://github.com/kennethreitz/requests/issues/2042 https://github.com/kennethreitz/requests/issues/2042
""" """
if session is None:
session = requests.Session()
if verify_fingerprint is not None: if verify_fingerprint is not None:
_install_fingerprint_adapter(session, verify_fingerprint) ssl = aiohttp.Fingerprint(bytes.fromhex(verify_fingerprint.replace(":", "")))
kwargs.pop("verify", None)
elif kwargs.pop("verify", None) is False:
ssl = False
else:
ssl = None # TODO XXX: Check all possible values for this
session.hooks = {"response": _fix_redirects} session.hooks = {"response": _fix_redirects}
func = session.request func = session.request
# TODO: rewrite using
# https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing
logger.debug("=" * 20) logger.debug("=" * 20)
logger.debug(f"{method} {url}") logger.debug(f"{method} {url}")
logger.debug(kwargs.get("headers", {})) logger.debug(kwargs.get("headers", {}))
@ -150,7 +139,14 @@ def request(
assert isinstance(kwargs.get("data", b""), bytes) assert isinstance(kwargs.get("data", b""), bytes)
r = func(method, url, **kwargs) kwargs.pop("cert", None) # TODO XXX FIXME!
# Hacks to translate API
if auth := kwargs.pop("auth", None):
kwargs["auth"] = aiohttp.BasicAuth(*auth)
r = func(method, url, ssl=ssl, **kwargs)
r = await r
# See https://github.com/kennethreitz/requests/issues/2042 # See https://github.com/kennethreitz/requests/issues/2042
content_type = r.headers.get("Content-Type", "") content_type = r.headers.get("Content-Type", "")
@ -162,13 +158,13 @@ def request(
logger.debug("Removing latin1 fallback") logger.debug("Removing latin1 fallback")
r.encoding = None r.encoding = None
logger.debug(r.status_code) logger.debug(r.status)
logger.debug(r.headers) logger.debug(r.headers)
logger.debug(r.content) logger.debug(r.content)
if r.status_code == 412: if r.status == 412:
raise exceptions.PreconditionFailed(r.reason) raise exceptions.PreconditionFailed(r.reason)
if r.status_code in (404, 410): if r.status in (404, 410):
raise exceptions.NotFoundError(r.reason) raise exceptions.NotFoundError(r.reason)
r.raise_for_status() r.raise_for_status()

View file

@ -14,24 +14,24 @@ class MetaSyncConflict(MetaSyncError):
key = None key = None
def metasync(storage_a, storage_b, status, keys, conflict_resolution=None): async def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
def _a_to_b(): async def _a_to_b():
logger.info(f"Copying {key} to {storage_b}") logger.info(f"Copying {key} to {storage_b}")
storage_b.set_meta(key, a) await storage_b.set_meta(key, a)
status[key] = a status[key] = a
def _b_to_a(): async def _b_to_a():
logger.info(f"Copying {key} to {storage_a}") logger.info(f"Copying {key} to {storage_a}")
storage_a.set_meta(key, b) await storage_a.set_meta(key, b)
status[key] = b status[key] = b
def _resolve_conflict(): async def _resolve_conflict():
if a == b: if a == b:
status[key] = a status[key] = a
elif conflict_resolution == "a wins": elif conflict_resolution == "a wins":
_a_to_b() await _a_to_b()
elif conflict_resolution == "b wins": elif conflict_resolution == "b wins":
_b_to_a() await _b_to_a()
else: else:
if callable(conflict_resolution): if callable(conflict_resolution):
logger.warning("Custom commands don't work on metasync.") logger.warning("Custom commands don't work on metasync.")
@ -40,8 +40,8 @@ def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
raise MetaSyncConflict(key) raise MetaSyncConflict(key)
for key in keys: for key in keys:
a = storage_a.get_meta(key) a = await storage_a.get_meta(key)
b = storage_b.get_meta(key) b = await storage_b.get_meta(key)
s = normalize_meta_value(status.get(key)) s = normalize_meta_value(status.get(key))
logger.debug(f"Key: {key}") logger.debug(f"Key: {key}")
logger.debug(f"A: {a}") logger.debug(f"A: {a}")
@ -49,11 +49,11 @@ def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
logger.debug(f"S: {s}") logger.debug(f"S: {s}")
if a != s and b != s: if a != s and b != s:
_resolve_conflict() await _resolve_conflict()
elif a != s and b == s: elif a != s and b == s:
_a_to_b() await _a_to_b()
elif a == s and b != s: elif a == s and b != s:
_b_to_a() await _b_to_a()
else: else:
assert a == b assert a == b

View file

@ -1,6 +1,8 @@
import logging import logging
from os.path import basename from os.path import basename
import aiostream
from .utils import generate_href from .utils import generate_href
from .utils import href_safe from .utils import href_safe
@ -11,11 +13,11 @@ class IrreparableItem(Exception):
pass pass
def repair_storage(storage, repair_unsafe_uid): async def repair_storage(storage, repair_unsafe_uid):
seen_uids = set() seen_uids = set()
all_hrefs = list(storage.list()) all_hrefs = await aiostream.stream.list(storage.list())
for i, (href, _) in enumerate(all_hrefs): for i, (href, _) in enumerate(all_hrefs):
item, etag = storage.get(href) item, etag = await storage.get(href)
logger.info("[{}/{}] Processing {}".format(i, len(all_hrefs), href)) logger.info("[{}/{}] Processing {}".format(i, len(all_hrefs), href))
try: try:
@ -32,10 +34,10 @@ def repair_storage(storage, repair_unsafe_uid):
seen_uids.add(new_item.uid) seen_uids.add(new_item.uid)
if new_item.raw != item.raw: if new_item.raw != item.raw:
if new_item.uid != item.uid: if new_item.uid != item.uid:
storage.upload(new_item) await storage.upload(new_item)
storage.delete(href, etag) await storage.delete(href, etag)
else: else:
storage.update(href, new_item, etag) await storage.update(href, new_item, etag)
def repair_item(href, item, seen_uids, repair_unsafe_uid): def repair_item(href, item, seen_uids, repair_unsafe_uid):

View file

@ -7,10 +7,10 @@ from ..utils import uniq
def mutating_storage_method(f): def mutating_storage_method(f):
@functools.wraps(f) @functools.wraps(f)
def inner(self, *args, **kwargs): async def inner(self, *args, **kwargs):
if self.read_only: if self.read_only:
raise exceptions.ReadOnlyError("This storage is read-only.") raise exceptions.ReadOnlyError("This storage is read-only.")
return f(self, *args, **kwargs) return await f(self, *args, **kwargs)
return inner return inner
@ -77,7 +77,7 @@ class Storage(metaclass=StorageMeta):
self.collection = collection self.collection = collection
@classmethod @classmethod
def discover(cls, **kwargs): async def discover(cls, **kwargs):
"""Discover collections given a basepath or -URL to many collections. """Discover collections given a basepath or -URL to many collections.
:param **kwargs: Keyword arguments to additionally pass to the storage :param **kwargs: Keyword arguments to additionally pass to the storage
@ -92,10 +92,12 @@ class Storage(metaclass=StorageMeta):
from the last segment of a URL or filesystem path. from the last segment of a URL or filesystem path.
""" """
if False:
yield # Needs to be an async generator
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): async def create_collection(cls, collection, **kwargs):
""" """
Create the specified collection and return the new arguments. Create the specified collection and return the new arguments.
@ -118,13 +120,13 @@ class Storage(metaclass=StorageMeta):
{x: getattr(self, x) for x in self._repr_attributes}, {x: getattr(self, x) for x in self._repr_attributes},
) )
def list(self): async def list(self):
""" """
:returns: list of (href, etag) :returns: list of (href, etag)
""" """
raise NotImplementedError() raise NotImplementedError()
def get(self, href): async def get(self, href):
"""Fetch a single item. """Fetch a single item.
:param href: href to fetch :param href: href to fetch
@ -134,7 +136,7 @@ class Storage(metaclass=StorageMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_multi(self, hrefs): async def get_multi(self, hrefs):
"""Fetch multiple items. Duplicate hrefs must be ignored. """Fetch multiple items. Duplicate hrefs must be ignored.
Functionally similar to :py:meth:`get`, but might bring performance Functionally similar to :py:meth:`get`, but might bring performance
@ -146,22 +148,22 @@ class Storage(metaclass=StorageMeta):
:returns: iterable of (href, item, etag) :returns: iterable of (href, item, etag)
""" """
for href in uniq(hrefs): for href in uniq(hrefs):
item, etag = self.get(href) item, etag = await self.get(href)
yield href, item, etag yield href, item, etag
def has(self, href): async def has(self, href):
"""Check if an item exists by its href. """Check if an item exists by its href.
:returns: True or False :returns: True or False
""" """
try: try:
self.get(href) await self.get(href)
except exceptions.PreconditionFailed: except exceptions.PreconditionFailed:
return False return False
else: else:
return True return True
def upload(self, item): async def upload(self, item):
"""Upload a new item. """Upload a new item.
In cases where the new etag cannot be atomically determined (i.e. in In cases where the new etag cannot be atomically determined (i.e. in
@ -176,7 +178,7 @@ class Storage(metaclass=StorageMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
def update(self, href, item, etag): async def update(self, href, item, etag):
"""Update an item. """Update an item.
The etag may be none in some cases, see `upload`. The etag may be none in some cases, see `upload`.
@ -189,7 +191,7 @@ class Storage(metaclass=StorageMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
def delete(self, href, etag): async def delete(self, href, etag):
"""Delete an item by href. """Delete an item by href.
:raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` when item has :raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` when item has
@ -197,8 +199,8 @@ class Storage(metaclass=StorageMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
@contextlib.contextmanager @contextlib.asynccontextmanager
def at_once(self): async def at_once(self):
"""A contextmanager that buffers all writes. """A contextmanager that buffers all writes.
Essentially, this:: Essentially, this::
@ -217,7 +219,7 @@ class Storage(metaclass=StorageMeta):
""" """
yield yield
def get_meta(self, key): async def get_meta(self, key):
"""Get metadata value for collection/storage. """Get metadata value for collection/storage.
See the vdir specification for the keys that *have* to be accepted. See the vdir specification for the keys that *have* to be accepted.
@ -228,7 +230,7 @@ class Storage(metaclass=StorageMeta):
raise NotImplementedError("This storage does not support metadata.") raise NotImplementedError("This storage does not support metadata.")
def set_meta(self, key, value): async def set_meta(self, key, value):
"""Get metadata value for collection/storage. """Get metadata value for collection/storage.
:param key: The metadata key. :param key: The metadata key.

View file

@ -5,8 +5,8 @@ import xml.etree.ElementTree as etree
from inspect import getfullargspec from inspect import getfullargspec
from inspect import signature from inspect import signature
import requests import aiohttp
from requests.exceptions import HTTPError import aiostream
from .. import exceptions from .. import exceptions
from .. import http from .. import http
@ -18,6 +18,7 @@ from ..http import USERAGENT
from ..vobject import Item from ..vobject import Item
from .base import normalize_meta_value from .base import normalize_meta_value
from .base import Storage from .base import Storage
from vdirsyncer.exceptions import Error
dav_logger = logging.getLogger(__name__) dav_logger = logging.getLogger(__name__)
@ -44,10 +45,10 @@ def _contains_quoted_reserved_chars(x):
return False return False
def _assert_multistatus_success(r): async def _assert_multistatus_success(r):
# Xandikos returns a multistatus on PUT. # Xandikos returns a multistatus on PUT.
try: try:
root = _parse_xml(r.content) root = _parse_xml(await r.content.read())
except InvalidXMLResponse: except InvalidXMLResponse:
return return
for status in root.findall(".//{DAV:}status"): for status in root.findall(".//{DAV:}status"):
@ -57,7 +58,7 @@ def _assert_multistatus_success(r):
except (ValueError, IndexError): except (ValueError, IndexError):
continue continue
if st < 200 or st >= 400: if st < 200 or st >= 400:
raise HTTPError(f"Server error: {st}") raise Error(f"Server error: {st}")
def _normalize_href(base, href): def _normalize_href(base, href):
@ -169,14 +170,14 @@ class Discover:
_, collection = url.rstrip("/").rsplit("/", 1) _, collection = url.rstrip("/").rsplit("/", 1)
return urlparse.unquote(collection) return urlparse.unquote(collection)
def find_principal(self): async def find_principal(self):
try: try:
return self._find_principal_impl("") return await self._find_principal_impl("")
except (HTTPError, exceptions.Error): except (aiohttp.ClientResponseError, exceptions.Error):
dav_logger.debug("Trying out well-known URI") dav_logger.debug("Trying out well-known URI")
return self._find_principal_impl(self._well_known_uri) return await self._find_principal_impl(self._well_known_uri)
def _find_principal_impl(self, url): async def _find_principal_impl(self, url):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers["Depth"] = "0" headers["Depth"] = "0"
body = b""" body = b"""
@ -187,9 +188,14 @@ class Discover:
</propfind> </propfind>
""" """
response = self.session.request("PROPFIND", url, headers=headers, data=body) response = await self.session.request(
"PROPFIND",
url,
headers=headers,
data=body,
)
root = _parse_xml(response.content) root = _parse_xml(await response.content.read())
rv = root.find(".//{DAV:}current-user-principal/{DAV:}href") rv = root.find(".//{DAV:}current-user-principal/{DAV:}href")
if rv is None: if rv is None:
# This is for servers that don't support current-user-principal # This is for servers that don't support current-user-principal
@ -201,34 +207,37 @@ class Discover:
) )
) )
return response.url return response.url
return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/" return urlparse.urljoin(str(response.url), rv.text).rstrip("/") + "/"
def find_home(self): async def find_home(self):
url = self.find_principal() url = await self.find_principal()
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers["Depth"] = "0" headers["Depth"] = "0"
response = self.session.request( response = await self.session.request(
"PROPFIND", url, headers=headers, data=self._homeset_xml "PROPFIND", url, headers=headers, data=self._homeset_xml
) )
root = etree.fromstring(response.content) root = etree.fromstring(await response.content.read())
# Better don't do string formatting here, because of XML namespaces # Better don't do string formatting here, because of XML namespaces
rv = root.find(".//" + self._homeset_tag + "/{DAV:}href") rv = root.find(".//" + self._homeset_tag + "/{DAV:}href")
if rv is None: if rv is None:
raise InvalidXMLResponse("Couldn't find home-set.") raise InvalidXMLResponse("Couldn't find home-set.")
return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/" return urlparse.urljoin(str(response.url), rv.text).rstrip("/") + "/"
def find_collections(self): async def find_collections(self):
rv = None rv = None
try: try:
rv = list(self._find_collections_impl("")) rv = await aiostream.stream.list(self._find_collections_impl(""))
except (HTTPError, exceptions.Error): except (aiohttp.ClientResponseError, exceptions.Error):
pass pass
if rv: if rv:
return rv return rv
dav_logger.debug("Given URL is not a homeset URL") dav_logger.debug("Given URL is not a homeset URL")
return self._find_collections_impl(self.find_home()) return await aiostream.stream.list(
self._find_collections_impl(await self.find_home())
)
def _check_collection_resource_type(self, response): def _check_collection_resource_type(self, response):
if self._resourcetype is None: if self._resourcetype is None:
@ -245,13 +254,13 @@ class Discover:
return False return False
return True return True
def _find_collections_impl(self, url): async def _find_collections_impl(self, url):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers["Depth"] = "1" headers["Depth"] = "1"
r = self.session.request( r = await self.session.request(
"PROPFIND", url, headers=headers, data=self._collection_xml "PROPFIND", url, headers=headers, data=self._collection_xml
) )
root = _parse_xml(r.content) root = _parse_xml(await r.content.read())
done = set() done = set()
for response in root.findall("{DAV:}response"): for response in root.findall("{DAV:}response"):
if not self._check_collection_resource_type(response): if not self._check_collection_resource_type(response):
@ -260,33 +269,33 @@ class Discover:
href = response.find("{DAV:}href") href = response.find("{DAV:}href")
if href is None: if href is None:
raise InvalidXMLResponse("Missing href tag for collection " "props.") raise InvalidXMLResponse("Missing href tag for collection " "props.")
href = urlparse.urljoin(r.url, href.text) href = urlparse.urljoin(str(r.url), href.text)
if href not in done: if href not in done:
done.add(href) done.add(href)
yield {"href": href} yield {"href": href}
def discover(self): async def discover(self):
for c in self.find_collections(): for c in await self.find_collections():
url = c["href"] url = c["href"]
collection = self._get_collection_from_url(url) collection = self._get_collection_from_url(url)
storage_args = dict(self.kwargs) storage_args = dict(self.kwargs)
storage_args.update({"url": url, "collection": collection}) storage_args.update({"url": url, "collection": collection})
yield storage_args yield storage_args
def create(self, collection): async def create(self, collection):
if collection is None: if collection is None:
collection = self._get_collection_from_url(self.kwargs["url"]) collection = self._get_collection_from_url(self.kwargs["url"])
for c in self.discover(): async for c in self.discover():
if c["collection"] == collection: if c["collection"] == collection:
return c return c
home = self.find_home() home = await self.find_home()
url = urlparse.urljoin(home, urlparse.quote(collection, "/@")) url = urlparse.urljoin(home, urlparse.quote(collection, "/@"))
try: try:
url = self._create_collection_impl(url) url = await self._create_collection_impl(url)
except HTTPError as e: except (aiohttp.ClientResponseError, Error) as e:
raise NotImplementedError(e) raise NotImplementedError(e)
else: else:
rv = dict(self.kwargs) rv = dict(self.kwargs)
@ -294,7 +303,7 @@ class Discover:
rv["url"] = url rv["url"] = url
return rv return rv
def _create_collection_impl(self, url): async def _create_collection_impl(self, url):
data = """<?xml version="1.0" encoding="utf-8" ?> data = """<?xml version="1.0" encoding="utf-8" ?>
<mkcol xmlns="DAV:"> <mkcol xmlns="DAV:">
<set> <set>
@ -312,13 +321,13 @@ class Discover:
"utf-8" "utf-8"
) )
response = self.session.request( response = await self.session.request(
"MKCOL", "MKCOL",
url, url,
data=data, data=data,
headers=self.session.get_default_headers(), headers=self.session.get_default_headers(),
) )
return response.url return str(response.url)
class CalDiscover(Discover): class CalDiscover(Discover):
@ -350,14 +359,18 @@ class CardDiscover(Discover):
class DAVSession: class DAVSession:
""" """A helper class to connect to DAV servers."""
A helper class to connect to DAV servers.
""" connector: aiohttp.BaseConnector
@classmethod @classmethod
def init_and_remaining_args(cls, **kwargs): def init_and_remaining_args(cls, **kwargs):
def is_arg(k):
"""Return true if ``k`` is an argument of ``cls.__init__``."""
return k in argspec.args or k in argspec.kwonlyargs
argspec = getfullargspec(cls.__init__) argspec = getfullargspec(cls.__init__)
self_args, remainder = utils.split_dict(kwargs, argspec.args.__contains__) self_args, remainder = utils.split_dict(kwargs, is_arg)
return cls(**self_args), remainder return cls(**self_args), remainder
@ -371,6 +384,8 @@ class DAVSession:
useragent=USERAGENT, useragent=USERAGENT,
verify_fingerprint=None, verify_fingerprint=None,
auth_cert=None, auth_cert=None,
*,
connector: aiohttp.BaseConnector,
): ):
self._settings = { self._settings = {
"cert": prepare_client_cert(auth_cert), "cert": prepare_client_cert(auth_cert),
@ -380,21 +395,28 @@ class DAVSession:
self.useragent = useragent self.useragent = useragent
self.url = url.rstrip("/") + "/" self.url = url.rstrip("/") + "/"
self.connector = connector
self._session = requests.session()
@utils.cached_property @utils.cached_property
def parsed_url(self): def parsed_url(self):
return urlparse.urlparse(self.url) return urlparse.urlparse(self.url)
def request(self, method, path, **kwargs): async def request(self, method, path, **kwargs):
url = self.url url = self.url
if path: if path:
url = urlparse.urljoin(self.url, path) url = urlparse.urljoin(str(self.url), path)
more = dict(self._settings) more = dict(self._settings)
more.update(kwargs) more.update(kwargs)
return http.request(method, url, session=self._session, **more)
# XXX: This is a temporary hack to pin-point bad refactoring.
assert self.connector is not None
async with aiohttp.ClientSession(
connector=self.connector,
connector_owner=False,
# TODO use `raise_for_status=true`, though this needs traces first,
) as session:
return await http.request(method, url, session=session, **more)
def get_default_headers(self): def get_default_headers(self):
return { return {
@ -417,33 +439,41 @@ class DAVStorage(Storage):
# The DAVSession class to use # The DAVSession class to use
session_class = DAVSession session_class = DAVSession
connector: aiohttp.TCPConnector
_repr_attributes = ("username", "url") _repr_attributes = ("username", "url")
_property_table = { _property_table = {
"displayname": ("displayname", "DAV:"), "displayname": ("displayname", "DAV:"),
} }
def __init__(self, **kwargs): def __init__(self, *, connector, **kwargs):
# defined for _repr_attributes # defined for _repr_attributes
self.username = kwargs.get("username") self.username = kwargs.get("username")
self.url = kwargs.get("url") self.url = kwargs.get("url")
self.connector = connector
self.session, kwargs = self.session_class.init_and_remaining_args(**kwargs) self.session, kwargs = self.session_class.init_and_remaining_args(
connector=connector,
**kwargs,
)
super().__init__(**kwargs) super().__init__(**kwargs)
__init__.__signature__ = signature(session_class.__init__) __init__.__signature__ = signature(session_class.__init__)
@classmethod @classmethod
def discover(cls, **kwargs): async def discover(cls, **kwargs):
session, _ = cls.session_class.init_and_remaining_args(**kwargs) session, _ = cls.session_class.init_and_remaining_args(**kwargs)
d = cls.discovery_class(session, kwargs) d = cls.discovery_class(session, kwargs)
return d.discover()
async for collection in d.discover():
yield collection
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): async def create_collection(cls, collection, **kwargs):
session, _ = cls.session_class.init_and_remaining_args(**kwargs) session, _ = cls.session_class.init_and_remaining_args(**kwargs)
d = cls.discovery_class(session, kwargs) d = cls.discovery_class(session, kwargs)
return d.create(collection) return await d.create(collection)
def _normalize_href(self, *args, **kwargs): def _normalize_href(self, *args, **kwargs):
return _normalize_href(self.session.url, *args, **kwargs) return _normalize_href(self.session.url, *args, **kwargs)
@ -455,57 +485,65 @@ class DAVStorage(Storage):
def _is_item_mimetype(self, mimetype): def _is_item_mimetype(self, mimetype):
return _fuzzy_matches_mimetype(self.item_mimetype, mimetype) return _fuzzy_matches_mimetype(self.item_mimetype, mimetype)
def get(self, href): async def get(self, href):
((actual_href, item, etag),) = self.get_multi([href]) ((actual_href, item, etag),) = await aiostream.stream.list(
self.get_multi([href])
)
assert href == actual_href assert href == actual_href
return item, etag return item, etag
def get_multi(self, hrefs): async def get_multi(self, hrefs):
hrefs = set(hrefs) hrefs = set(hrefs)
href_xml = [] href_xml = []
for href in hrefs: for href in hrefs:
if href != self._normalize_href(href): if href != self._normalize_href(href):
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
href_xml.append(f"<href>{href}</href>") href_xml.append(f"<href>{href}</href>")
if not href_xml: if href_xml:
return () data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode(
"utf-8"
)
response = await self.session.request(
"REPORT", "", data=data, headers=self.session.get_default_headers()
)
root = _parse_xml(
await response.content.read()
) # etree only can handle bytes
rv = []
hrefs_left = set(hrefs)
for href, etag, prop in self._parse_prop_responses(root):
raw = prop.find(self.get_multi_data_query)
if raw is None:
dav_logger.warning(
"Skipping {}, the item content is missing.".format(href)
)
continue
data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode("utf-8") raw = raw.text or ""
response = self.session.request(
"REPORT", "", data=data, headers=self.session.get_default_headers()
)
root = _parse_xml(response.content) # etree only can handle bytes
rv = []
hrefs_left = set(hrefs)
for href, etag, prop in self._parse_prop_responses(root):
raw = prop.find(self.get_multi_data_query)
if raw is None:
dav_logger.warning(
"Skipping {}, the item content is missing.".format(href)
)
continue
raw = raw.text or "" if isinstance(raw, bytes):
raw = raw.decode(response.encoding)
if isinstance(etag, bytes):
etag = etag.decode(response.encoding)
if isinstance(raw, bytes): try:
raw = raw.decode(response.encoding) hrefs_left.remove(href)
if isinstance(etag, bytes): except KeyError:
etag = etag.decode(response.encoding) if href in hrefs:
dav_logger.warning("Server sent item twice: {}".format(href))
try: else:
hrefs_left.remove(href) dav_logger.warning(
except KeyError: "Server sent unsolicited item: {}".format(href)
if href in hrefs: )
dav_logger.warning("Server sent item twice: {}".format(href))
else: else:
dav_logger.warning("Server sent unsolicited item: {}".format(href)) rv.append((href, Item(raw), etag))
else: for href in hrefs_left:
rv.append((href, Item(raw), etag)) raise exceptions.NotFoundError(href)
for href in hrefs_left:
raise exceptions.NotFoundError(href)
return rv
def _put(self, href, item, etag): for href, item, etag in rv:
yield href, item, etag
async def _put(self, href, item, etag):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers["Content-Type"] = self.item_mimetype headers["Content-Type"] = self.item_mimetype
if etag is None: if etag is None:
@ -513,11 +551,11 @@ class DAVStorage(Storage):
else: else:
headers["If-Match"] = etag headers["If-Match"] = etag
response = self.session.request( response = await self.session.request(
"PUT", href, data=item.raw.encode("utf-8"), headers=headers "PUT", href, data=item.raw.encode("utf-8"), headers=headers
) )
_assert_multistatus_success(response) await _assert_multistatus_success(response)
# The server may not return an etag under certain conditions: # The server may not return an etag under certain conditions:
# #
@ -534,25 +572,28 @@ class DAVStorage(Storage):
# In such cases we return a constant etag. The next synchronization # In such cases we return a constant etag. The next synchronization
# will then detect an etag change and will download the new item. # will then detect an etag change and will download the new item.
etag = response.headers.get("etag", None) etag = response.headers.get("etag", None)
href = self._normalize_href(response.url) href = self._normalize_href(str(response.url))
return href, etag return href, etag
def update(self, href, item, etag): async def update(self, href, item, etag):
if etag is None: if etag is None:
raise ValueError("etag must be given and must not be None.") raise ValueError("etag must be given and must not be None.")
href, etag = self._put(self._normalize_href(href), item, etag) href, etag = await self._put(self._normalize_href(href), item, etag)
return etag return etag
def upload(self, item): async def upload(self, item):
href = self._get_href(item) href = self._get_href(item)
return self._put(href, item, None) rv = await self._put(href, item, None)
return rv
def delete(self, href, etag): async def delete(self, href, etag):
href = self._normalize_href(href) href = self._normalize_href(href)
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers.update({"If-Match": etag}) if etag: # baikal doesn't give us an etag.
dav_logger.warning("Deleting an item with no etag.")
headers.update({"If-Match": etag})
self.session.request("DELETE", href, headers=headers) await self.session.request("DELETE", href, headers=headers)
def _parse_prop_responses(self, root, handled_hrefs=None): def _parse_prop_responses(self, root, handled_hrefs=None):
if handled_hrefs is None: if handled_hrefs is None:
@ -604,7 +645,7 @@ class DAVStorage(Storage):
handled_hrefs.add(href) handled_hrefs.add(href)
yield href, etag, props yield href, etag, props
def list(self): async def list(self):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers["Depth"] = "1" headers["Depth"] = "1"
@ -620,14 +661,19 @@ class DAVStorage(Storage):
# We use a PROPFIND request instead of addressbook-query due to issues # We use a PROPFIND request instead of addressbook-query due to issues
# with Zimbra. See https://github.com/pimutils/vdirsyncer/issues/83 # with Zimbra. See https://github.com/pimutils/vdirsyncer/issues/83
response = self.session.request("PROPFIND", "", data=data, headers=headers) response = await self.session.request(
root = _parse_xml(response.content) "PROPFIND",
"",
data=data,
headers=headers,
)
root = _parse_xml(await response.content.read())
rv = self._parse_prop_responses(root) rv = self._parse_prop_responses(root)
for href, etag, _prop in rv: for href, etag, _prop in rv:
yield href, etag yield href, etag
def get_meta(self, key): async def get_meta(self, key):
try: try:
tagname, namespace = self._property_table[key] tagname, namespace = self._property_table[key]
except KeyError: except KeyError:
@ -649,9 +695,14 @@ class DAVStorage(Storage):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers["Depth"] = "0" headers["Depth"] = "0"
response = self.session.request("PROPFIND", "", data=data, headers=headers) response = await self.session.request(
"PROPFIND",
"",
data=data,
headers=headers,
)
root = _parse_xml(response.content) root = _parse_xml(await response.content.read())
for prop in root.findall(".//" + xpath): for prop in root.findall(".//" + xpath):
text = normalize_meta_value(getattr(prop, "text", None)) text = normalize_meta_value(getattr(prop, "text", None))
@ -659,7 +710,7 @@ class DAVStorage(Storage):
return text return text
return "" return ""
def set_meta(self, key, value): async def set_meta(self, key, value):
try: try:
tagname, namespace = self._property_table[key] tagname, namespace = self._property_table[key]
except KeyError: except KeyError:
@ -683,8 +734,11 @@ class DAVStorage(Storage):
"utf-8" "utf-8"
) )
self.session.request( await self.session.request(
"PROPPATCH", "", data=data, headers=self.session.get_default_headers() "PROPPATCH",
"",
data=data,
headers=self.session.get_default_headers(),
) )
# XXX: Response content is currently ignored. Though exceptions are # XXX: Response content is currently ignored. Though exceptions are
@ -776,7 +830,7 @@ class CalDAVStorage(DAVStorage):
("VTODO", "VEVENT"), start, end ("VTODO", "VEVENT"), start, end
) )
def list(self): async def list(self):
caldavfilters = list( caldavfilters = list(
self._get_list_filters(self.item_types, self.start_date, self.end_date) self._get_list_filters(self.item_types, self.start_date, self.end_date)
) )
@ -788,7 +842,8 @@ class CalDAVStorage(DAVStorage):
# instead? # instead?
# #
# See https://github.com/dmfs/tasks/issues/118 for backstory. # See https://github.com/dmfs/tasks/issues/118 for backstory.
yield from DAVStorage.list(self) async for href, etag in DAVStorage.list(self):
yield href, etag
data = """<?xml version="1.0" encoding="utf-8" ?> data = """<?xml version="1.0" encoding="utf-8" ?>
<C:calendar-query xmlns="DAV:" <C:calendar-query xmlns="DAV:"
@ -813,8 +868,13 @@ class CalDAVStorage(DAVStorage):
for caldavfilter in caldavfilters: for caldavfilter in caldavfilters:
xml = data.format(caldavfilter=caldavfilter).encode("utf-8") xml = data.format(caldavfilter=caldavfilter).encode("utf-8")
response = self.session.request("REPORT", "", data=xml, headers=headers) response = await self.session.request(
root = _parse_xml(response.content) "REPORT",
"",
data=xml,
headers=headers,
)
root = _parse_xml(await response.content.read())
rv = self._parse_prop_responses(root, handled_hrefs) rv = self._parse_prop_responses(root, handled_hrefs)
for href, etag, _prop in rv: for href, etag, _prop in rv:
yield href, etag yield href, etag

View file

@ -30,13 +30,13 @@ logger = logging.getLogger(__name__)
def _writing_op(f): def _writing_op(f):
@functools.wraps(f) @functools.wraps(f)
def inner(self, *args, **kwargs): async def inner(self, *args, **kwargs):
if not self._at_once: if not self._at_once:
self._sync_journal() self._sync_journal()
rv = f(self, *args, **kwargs) rv = f(self, *args, **kwargs)
if not self._at_once: if not self._at_once:
self._sync_journal() self._sync_journal()
return rv return await rv
return inner return inner
@ -120,7 +120,14 @@ class EtesyncStorage(Storage):
self._session.etesync.sync_journal(self.collection) self._session.etesync.sync_journal(self.collection)
@classmethod @classmethod
def discover(cls, email, secrets_dir, server_url=None, db_path=None, **kwargs): async def discover(
cls,
email,
secrets_dir,
server_url=None,
db_path=None,
**kwargs,
):
if kwargs.get("collection", None) is not None: if kwargs.get("collection", None) is not None:
raise TypeError("collection argument must not be given.") raise TypeError("collection argument must not be given.")
session = _Session(email, secrets_dir, server_url, db_path) session = _Session(email, secrets_dir, server_url, db_path)
@ -139,7 +146,7 @@ class EtesyncStorage(Storage):
logger.debug(f"Skipping collection: {entry!r}") logger.debug(f"Skipping collection: {entry!r}")
@classmethod @classmethod
def create_collection( async def create_collection(
cls, collection, email, secrets_dir, server_url=None, db_path=None, **kwargs cls, collection, email, secrets_dir, server_url=None, db_path=None, **kwargs
): ):
session = _Session(email, secrets_dir, server_url, db_path) session = _Session(email, secrets_dir, server_url, db_path)
@ -158,13 +165,13 @@ class EtesyncStorage(Storage):
**kwargs, **kwargs,
) )
def list(self): async def list(self):
self._sync_journal() self._sync_journal()
for entry in self._journal.collection.list(): for entry in self._journal.collection.list():
item = Item(entry.content) item = Item(entry.content)
yield str(entry.uid), item.hash yield str(entry.uid), item.hash
def get(self, href): async def get(self, href):
try: try:
item = Item(self._journal.collection.get(href).content) item = Item(self._journal.collection.get(href).content)
except etesync.exceptions.DoesNotExist as e: except etesync.exceptions.DoesNotExist as e:
@ -172,7 +179,7 @@ class EtesyncStorage(Storage):
return item, item.hash return item, item.hash
@_writing_op @_writing_op
def upload(self, item): async def upload(self, item):
try: try:
entry = self._item_type.create(self._journal.collection, item.raw) entry = self._item_type.create(self._journal.collection, item.raw)
entry.save() entry.save()
@ -183,7 +190,7 @@ class EtesyncStorage(Storage):
return item.uid, item.hash return item.uid, item.hash
@_writing_op @_writing_op
def update(self, href, item, etag): async def update(self, href, item, etag):
try: try:
entry = self._journal.collection.get(href) entry = self._journal.collection.get(href)
except etesync.exceptions.DoesNotExist as e: except etesync.exceptions.DoesNotExist as e:
@ -196,7 +203,7 @@ class EtesyncStorage(Storage):
return item.hash return item.hash
@_writing_op @_writing_op
def delete(self, href, etag): async def delete(self, href, etag):
try: try:
entry = self._journal.collection.get(href) entry = self._journal.collection.get(href)
old_item = Item(entry.content) old_item = Item(entry.content)
@ -206,8 +213,8 @@ class EtesyncStorage(Storage):
except etesync.exceptions.DoesNotExist as e: except etesync.exceptions.DoesNotExist as e:
raise exceptions.NotFoundError(e) raise exceptions.NotFoundError(e)
@contextlib.contextmanager @contextlib.asynccontextmanager
def at_once(self): async def at_once(self):
self._sync_journal() self._sync_journal()
self._at_once = True self._at_once = True
try: try:

View file

@ -41,7 +41,7 @@ class FilesystemStorage(Storage):
self.post_hook = post_hook self.post_hook = post_hook
@classmethod @classmethod
def discover(cls, path, **kwargs): async def discover(cls, path, **kwargs):
if kwargs.pop("collection", None) is not None: if kwargs.pop("collection", None) is not None:
raise TypeError("collection argument must not be given.") raise TypeError("collection argument must not be given.")
path = expand_path(path) path = expand_path(path)
@ -67,7 +67,7 @@ class FilesystemStorage(Storage):
return True return True
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): async def create_collection(cls, collection, **kwargs):
kwargs = dict(kwargs) kwargs = dict(kwargs)
path = kwargs["path"] path = kwargs["path"]
@ -86,7 +86,7 @@ class FilesystemStorage(Storage):
def _get_href(self, ident): def _get_href(self, ident):
return generate_href(ident) + self.fileext return generate_href(ident) + self.fileext
def list(self): async def list(self):
for fname in os.listdir(self.path): for fname in os.listdir(self.path):
fpath = os.path.join(self.path, fname) fpath = os.path.join(self.path, fname)
if ( if (
@ -96,7 +96,7 @@ class FilesystemStorage(Storage):
): ):
yield fname, get_etag_from_file(fpath) yield fname, get_etag_from_file(fpath)
def get(self, href): async def get(self, href):
fpath = self._get_filepath(href) fpath = self._get_filepath(href)
try: try:
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
@ -107,7 +107,7 @@ class FilesystemStorage(Storage):
else: else:
raise raise
def upload(self, item): async def upload(self, item):
if not isinstance(item.raw, str): if not isinstance(item.raw, str):
raise TypeError("item.raw must be a unicode string.") raise TypeError("item.raw must be a unicode string.")
@ -139,7 +139,7 @@ class FilesystemStorage(Storage):
else: else:
raise raise
def update(self, href, item, etag): async def update(self, href, item, etag):
fpath = self._get_filepath(href) fpath = self._get_filepath(href)
if not os.path.exists(fpath): if not os.path.exists(fpath):
raise exceptions.NotFoundError(item.uid) raise exceptions.NotFoundError(item.uid)
@ -158,7 +158,7 @@ class FilesystemStorage(Storage):
self._run_post_hook(fpath) self._run_post_hook(fpath)
return etag return etag
def delete(self, href, etag): async def delete(self, href, etag):
fpath = self._get_filepath(href) fpath = self._get_filepath(href)
if not os.path.isfile(fpath): if not os.path.isfile(fpath):
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
@ -176,7 +176,7 @@ class FilesystemStorage(Storage):
except OSError as e: except OSError as e:
logger.warning("Error executing external hook: {}".format(str(e))) logger.warning("Error executing external hook: {}".format(str(e)))
def get_meta(self, key): async def get_meta(self, key):
fpath = os.path.join(self.path, key) fpath = os.path.join(self.path, key)
try: try:
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
@ -187,7 +187,7 @@ class FilesystemStorage(Storage):
else: else:
raise raise
def set_meta(self, key, value): async def set_meta(self, key, value):
value = normalize_meta_value(value) value = normalize_meta_value(value)
fpath = os.path.join(self.path, key) fpath = os.path.join(self.path, key)

View file

@ -3,6 +3,7 @@ import logging
import os import os
import urllib.parse as urlparse import urllib.parse as urlparse
import aiohttp
import click import click
from atomicwrites import atomic_write from atomicwrites import atomic_write
@ -28,13 +29,21 @@ except ImportError:
class GoogleSession(dav.DAVSession): class GoogleSession(dav.DAVSession):
def __init__(self, token_file, client_id, client_secret, url=None): def __init__(
self,
token_file,
client_id,
client_secret,
url=None,
connector: aiohttp.BaseConnector = None,
):
# Required for discovering collections # Required for discovering collections
if url is not None: if url is not None:
self.url = url self.url = url
self.useragent = client_id self.useragent = client_id
self._settings = {} self._settings = {}
self.connector = connector
if not have_oauth2: if not have_oauth2:
raise exceptions.UserError("requests-oauthlib not installed") raise exceptions.UserError("requests-oauthlib not installed")

View file

@ -1,5 +1,7 @@
import urllib.parse as urlparse import urllib.parse as urlparse
import aiohttp
from .. import exceptions from .. import exceptions
from ..http import prepare_auth from ..http import prepare_auth
from ..http import prepare_client_cert from ..http import prepare_client_cert
@ -30,6 +32,8 @@ class HttpStorage(Storage):
useragent=USERAGENT, useragent=USERAGENT,
verify_fingerprint=None, verify_fingerprint=None,
auth_cert=None, auth_cert=None,
*,
connector,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -43,6 +47,8 @@ class HttpStorage(Storage):
self.username, self.password = username, password self.username, self.password = username, password
self.useragent = useragent self.useragent = useragent
assert connector is not None
self.connector = connector
collection = kwargs.get("collection") collection = kwargs.get("collection")
if collection is not None: if collection is not None:
@ -53,22 +59,35 @@ class HttpStorage(Storage):
def _default_headers(self): def _default_headers(self):
return {"User-Agent": self.useragent} return {"User-Agent": self.useragent}
def list(self): async def list(self):
r = request("GET", self.url, headers=self._default_headers(), **self._settings) async with aiohttp.ClientSession(
connector=self.connector,
connector_owner=False,
# TODO use `raise_for_status=true`, though this needs traces first,
) as session:
r = await request(
"GET",
self.url,
headers=self._default_headers(),
session=session,
**self._settings,
)
self._items = {} self._items = {}
for item in split_collection(r.text): for item in split_collection((await r.read()).decode("utf-8")):
item = Item(item) item = Item(item)
if self._ignore_uids: if self._ignore_uids:
item = item.with_uid(item.hash) item = item.with_uid(item.hash)
self._items[item.ident] = item, item.hash self._items[item.ident] = item, item.hash
return ((href, etag) for href, (item, etag) in self._items.items()) for href, (_, etag) in self._items.items():
yield href, etag
def get(self, href): async def get(self, href):
if self._items is None: if self._items is None:
self.list() async for _ in self.list():
pass
try: try:
return self._items[href] return self._items[href]

View file

@ -28,18 +28,18 @@ class MemoryStorage(Storage):
def _get_href(self, item): def _get_href(self, item):
return item.ident + self.fileext return item.ident + self.fileext
def list(self): async def list(self):
for href, (etag, _item) in self.items.items(): for href, (etag, _item) in self.items.items():
yield href, etag yield href, etag
def get(self, href): async def get(self, href):
etag, item = self.items[href] etag, item = self.items[href]
return item, etag return item, etag
def has(self, href): async def has(self, href):
return href in self.items return href in self.items
def upload(self, item): async def upload(self, item):
href = self._get_href(item) href = self._get_href(item)
if href in self.items: if href in self.items:
raise exceptions.AlreadyExistingError(existing_href=href) raise exceptions.AlreadyExistingError(existing_href=href)
@ -47,7 +47,7 @@ class MemoryStorage(Storage):
self.items[href] = (etag, item) self.items[href] = (etag, item)
return href, etag return href, etag
def update(self, href, item, etag): async def update(self, href, item, etag):
if href not in self.items: if href not in self.items:
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
actual_etag, _ = self.items[href] actual_etag, _ = self.items[href]
@ -58,15 +58,15 @@ class MemoryStorage(Storage):
self.items[href] = (new_etag, item) self.items[href] = (new_etag, item)
return new_etag return new_etag
def delete(self, href, etag): async def delete(self, href, etag):
if not self.has(href): if not await self.has(href):
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
if etag != self.items[href][0]: if etag != self.items[href][0]:
raise exceptions.WrongEtagError(etag) raise exceptions.WrongEtagError(etag)
del self.items[href] del self.items[href]
def get_meta(self, key): async def get_meta(self, key):
return normalize_meta_value(self.metadata.get(key)) return normalize_meta_value(self.metadata.get(key))
def set_meta(self, key, value): async def set_meta(self, key, value):
self.metadata[key] = normalize_meta_value(value) self.metadata[key] = normalize_meta_value(value)

View file

@ -21,10 +21,12 @@ logger = logging.getLogger(__name__)
def _writing_op(f): def _writing_op(f):
@functools.wraps(f) @functools.wraps(f)
def inner(self, *args, **kwargs): async def inner(self, *args, **kwargs):
if self._items is None or not self._at_once: if self._items is None or not self._at_once:
self.list() async for _ in self.list():
rv = f(self, *args, **kwargs) pass
assert self._items is not None
rv = await f(self, *args, **kwargs)
if not self._at_once: if not self._at_once:
self._write() self._write()
return rv return rv
@ -53,7 +55,7 @@ class SingleFileStorage(Storage):
self._at_once = False self._at_once = False
@classmethod @classmethod
def discover(cls, path, **kwargs): async def discover(cls, path, **kwargs):
if kwargs.pop("collection", None) is not None: if kwargs.pop("collection", None) is not None:
raise TypeError("collection argument must not be given.") raise TypeError("collection argument must not be given.")
@ -81,7 +83,7 @@ class SingleFileStorage(Storage):
yield args yield args
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): async def create_collection(cls, collection, **kwargs):
path = os.path.abspath(expand_path(kwargs["path"])) path = os.path.abspath(expand_path(kwargs["path"]))
if collection is not None: if collection is not None:
@ -97,7 +99,7 @@ class SingleFileStorage(Storage):
kwargs["collection"] = collection kwargs["collection"] = collection
return kwargs return kwargs
def list(self): async def list(self):
self._items = collections.OrderedDict() self._items = collections.OrderedDict()
try: try:
@ -111,19 +113,19 @@ class SingleFileStorage(Storage):
raise OSError(e) raise OSError(e)
text = None text = None
if not text: if text:
return () for item in split_collection(text):
item = Item(item)
etag = item.hash
href = item.ident
self._items[href] = item, etag
for item in split_collection(text): yield href, etag
item = Item(item)
etag = item.hash
self._items[item.ident] = item, etag
return ((href, etag) for href, (item, etag) in self._items.items()) async def get(self, href):
def get(self, href):
if self._items is None or not self._at_once: if self._items is None or not self._at_once:
self.list() async for _ in self.list():
pass
try: try:
return self._items[href] return self._items[href]
@ -131,7 +133,7 @@ class SingleFileStorage(Storage):
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
@_writing_op @_writing_op
def upload(self, item): async def upload(self, item):
href = item.ident href = item.ident
if href in self._items: if href in self._items:
raise exceptions.AlreadyExistingError(existing_href=href) raise exceptions.AlreadyExistingError(existing_href=href)
@ -140,7 +142,7 @@ class SingleFileStorage(Storage):
return href, item.hash return href, item.hash
@_writing_op @_writing_op
def update(self, href, item, etag): async def update(self, href, item, etag):
if href not in self._items: if href not in self._items:
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
@ -152,7 +154,7 @@ class SingleFileStorage(Storage):
return item.hash return item.hash
@_writing_op @_writing_op
def delete(self, href, etag): async def delete(self, href, etag):
if href not in self._items: if href not in self._items:
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
@ -181,8 +183,8 @@ class SingleFileStorage(Storage):
self._items = None self._items = None
self._last_etag = None self._last_etag = None
@contextlib.contextmanager @contextlib.asynccontextmanager
def at_once(self): async def at_once(self):
self.list() self.list()
self._at_once = True self._at_once = True
try: try:

View file

@ -35,7 +35,7 @@ class _StorageInfo:
self.status = status self.status = status
self._item_cache = {} self._item_cache = {}
def prepare_new_status(self): async def prepare_new_status(self):
storage_nonempty = False storage_nonempty = False
prefetch = [] prefetch = []
@ -45,7 +45,7 @@ class _StorageInfo:
except IdentAlreadyExists as e: except IdentAlreadyExists as e:
raise e.to_ident_conflict(self.storage) raise e.to_ident_conflict(self.storage)
for href, etag in self.storage.list(): async for href, etag in self.storage.list():
storage_nonempty = True storage_nonempty = True
ident, meta = self.status.get_by_href(href) ident, meta = self.status.get_by_href(href)
@ -58,9 +58,13 @@ class _StorageInfo:
_store_props(ident, meta) _store_props(ident, meta)
# Prefetch items # Prefetch items
for href, item, etag in self.storage.get_multi(prefetch) if prefetch else (): if prefetch:
_store_props(item.ident, ItemMetadata(href=href, hash=item.hash, etag=etag)) async for href, item, etag in self.storage.get_multi(prefetch):
self.set_item_cache(item.ident, item) _store_props(
item.ident,
ItemMetadata(href=href, hash=item.hash, etag=etag),
)
self.set_item_cache(item.ident, item)
return storage_nonempty return storage_nonempty
@ -86,7 +90,7 @@ class _StorageInfo:
return self._item_cache[ident] return self._item_cache[ident]
def sync( async def sync(
storage_a, storage_a,
storage_b, storage_b,
status, status,
@ -137,8 +141,8 @@ def sync(
a_info = _StorageInfo(storage_a, SubStatus(status, "a")) a_info = _StorageInfo(storage_a, SubStatus(status, "a"))
b_info = _StorageInfo(storage_b, SubStatus(status, "b")) b_info = _StorageInfo(storage_b, SubStatus(status, "b"))
a_nonempty = a_info.prepare_new_status() a_nonempty = await a_info.prepare_new_status()
b_nonempty = b_info.prepare_new_status() b_nonempty = await b_info.prepare_new_status()
if status_nonempty and not force_delete: if status_nonempty and not force_delete:
if a_nonempty and not b_nonempty: if a_nonempty and not b_nonempty:
@ -148,10 +152,10 @@ def sync(
actions = list(_get_actions(a_info, b_info)) actions = list(_get_actions(a_info, b_info))
with storage_a.at_once(), storage_b.at_once(): async with storage_a.at_once(), storage_b.at_once():
for action in actions: for action in actions:
try: try:
action.run(a_info, b_info, conflict_resolution, partial_sync) await action.run(a_info, b_info, conflict_resolution, partial_sync)
except Exception as e: except Exception as e:
if error_callback: if error_callback:
error_callback(e) error_callback(e)
@ -160,10 +164,10 @@ def sync(
class Action: class Action:
def _run_impl(self, a, b): # pragma: no cover async def _run_impl(self, a, b): # pragma: no cover
raise NotImplementedError() raise NotImplementedError()
def run(self, a, b, conflict_resolution, partial_sync): async def run(self, a, b, conflict_resolution, partial_sync):
with self.auto_rollback(a, b): with self.auto_rollback(a, b):
if self.dest.storage.read_only: if self.dest.storage.read_only:
if partial_sync == "error": if partial_sync == "error":
@ -174,7 +178,7 @@ class Action:
else: else:
assert partial_sync == "revert" assert partial_sync == "revert"
self._run_impl(a, b) await self._run_impl(a, b)
@contextlib.contextmanager @contextlib.contextmanager
def auto_rollback(self, a, b): def auto_rollback(self, a, b):
@ -194,7 +198,7 @@ class Upload(Action):
self.ident = item.ident self.ident = item.ident
self.dest = dest self.dest = dest
def _run_impl(self, a, b): async def _run_impl(self, a, b):
if self.dest.storage.read_only: if self.dest.storage.read_only:
href = etag = None href = etag = None
@ -204,7 +208,7 @@ class Upload(Action):
self.ident, self.dest.storage self.ident, self.dest.storage
) )
) )
href, etag = self.dest.storage.upload(self.item) href, etag = await self.dest.storage.upload(self.item)
assert href is not None assert href is not None
self.dest.status.insert_ident( self.dest.status.insert_ident(
@ -218,7 +222,7 @@ class Update(Action):
self.ident = item.ident self.ident = item.ident
self.dest = dest self.dest = dest
def _run_impl(self, a, b): async def _run_impl(self, a, b):
if self.dest.storage.read_only: if self.dest.storage.read_only:
meta = ItemMetadata(hash=self.item.hash) meta = ItemMetadata(hash=self.item.hash)
else: else:
@ -226,7 +230,7 @@ class Update(Action):
"Copying (updating) item {} to {}".format(self.ident, self.dest.storage) "Copying (updating) item {} to {}".format(self.ident, self.dest.storage)
) )
meta = self.dest.status.get_new(self.ident) meta = self.dest.status.get_new(self.ident)
meta.etag = self.dest.storage.update(meta.href, self.item, meta.etag) meta.etag = await self.dest.storage.update(meta.href, self.item, meta.etag)
self.dest.status.update_ident(self.ident, meta) self.dest.status.update_ident(self.ident, meta)
@ -236,13 +240,13 @@ class Delete(Action):
self.ident = ident self.ident = ident
self.dest = dest self.dest = dest
def _run_impl(self, a, b): async def _run_impl(self, a, b):
meta = self.dest.status.get_new(self.ident) meta = self.dest.status.get_new(self.ident)
if not self.dest.storage.read_only: if not self.dest.storage.read_only:
sync_logger.info( sync_logger.info(
"Deleting item {} from {}".format(self.ident, self.dest.storage) "Deleting item {} from {}".format(self.ident, self.dest.storage)
) )
self.dest.storage.delete(meta.href, meta.etag) await self.dest.storage.delete(meta.href, meta.etag)
self.dest.status.remove_ident(self.ident) self.dest.status.remove_ident(self.ident)
@ -251,7 +255,7 @@ class ResolveConflict(Action):
def __init__(self, ident): def __init__(self, ident):
self.ident = ident self.ident = ident
def run(self, a, b, conflict_resolution, partial_sync): async def run(self, a, b, conflict_resolution, partial_sync):
with self.auto_rollback(a, b): with self.auto_rollback(a, b):
sync_logger.info( sync_logger.info(
"Doing conflict resolution for item {}...".format(self.ident) "Doing conflict resolution for item {}...".format(self.ident)
@ -271,9 +275,19 @@ class ResolveConflict(Action):
item_b = b.get_item_cache(self.ident) item_b = b.get_item_cache(self.ident)
new_item = conflict_resolution(item_a, item_b) new_item = conflict_resolution(item_a, item_b)
if new_item.hash != meta_a.hash: if new_item.hash != meta_a.hash:
Update(new_item, a).run(a, b, conflict_resolution, partial_sync) await Update(new_item, a).run(
a,
b,
conflict_resolution,
partial_sync,
)
if new_item.hash != meta_b.hash: if new_item.hash != meta_b.hash:
Update(new_item, b).run(a, b, conflict_resolution, partial_sync) await Update(new_item, b).run(
a,
b,
conflict_resolution,
partial_sync,
)
else: else:
raise UserError( raise UserError(
"Invalid conflict resolution mode: {!r}".format(conflict_resolution) "Invalid conflict resolution mode: {!r}".format(conflict_resolution)

View file

@ -26,22 +26,15 @@ def expand_path(p: str) -> str:
return p return p
def split_dict(d, f): def split_dict(d: dict, f: callable):
"""Puts key into first dict if f(key), otherwise in second dict""" """Puts key into first dict if f(key), otherwise in second dict"""
a, b = split_sequence(d.items(), lambda item: f(item[0])) a = {}
return dict(a), dict(b) b = {}
for k, v in d.items():
if f(k):
def split_sequence(s, f): a[k] = v
"""Puts item into first list if f(item), else in second list"""
a = []
b = []
for item in s:
if f(item):
a.append(item)
else: else:
b.append(item) b[k] = v
return a, b return a, b