Initial async support

Add asyncio support to the storage backends and most of the codebase. A
lot of it merely uses asyncio APIs, but still doesn't actually run
several things concurrently internally. Further improvements will be
added on top of these changes.

Thanks to Thomas Grainger (@graingert) for a few useful pointers
related to asyncio.
This commit is contained in:
Hugo Osvaldo Barrera 2021-06-17 23:38:18 +02:00
parent 7c9170c677
commit 1a1f6f0788
44 changed files with 1383 additions and 935 deletions

View file

@ -19,6 +19,10 @@ packages:
- python-pytest-cov
- python-pytest-httpserver
- python-trustme
- python-pytest-asyncio
- python-aiohttp
- python-aiostream
- python-aioresponses
sources:
- https://github.com/pimutils/vdirsyncer
environment:

View file

@ -8,6 +8,7 @@ addopts =
--cov=vdirsyncer
--cov-report=term-missing
--no-cov-on-fail
# filterwarnings=error
[flake8]
application-import-names = tests,vdirsyncer

View file

@ -13,14 +13,14 @@ requirements = [
# https://github.com/mitsuhiko/click/issues/200
"click>=5.0,<9.0",
"click-log>=0.3.0, <0.4.0",
# https://github.com/pimutils/vdirsyncer/issues/478
"click-threading>=0.5",
"requests >=2.20.0",
# https://github.com/sigmavirus24/requests-toolbelt/pull/28
# And https://github.com/sigmavirus24/requests-toolbelt/issues/54
"requests_toolbelt >=0.4.0",
# https://github.com/untitaker/python-atomicwrites/commit/4d12f23227b6a944ab1d99c507a69fdbc7c9ed6d # noqa
"atomicwrites>=0.1.7",
"aiohttp>=3.7.1,<4.0.0",
"aiostream>=0.4.3,<0.5.0",
]

View file

@ -3,3 +3,5 @@ pytest
pytest-cov
pytest-httpserver
trustme
pytest-asyncio
aioresponses

View file

@ -4,6 +4,7 @@ General-purpose fixtures for vdirsyncer's testsuite.
import logging
import os
import aiohttp
import click_log
import pytest
from hypothesis import HealthCheck
@ -52,3 +53,18 @@ elif os.environ.get("CI", "false").lower() == "true":
settings.load_profile("ci")
else:
settings.load_profile("dev")
@pytest.fixture
async def aio_session(event_loop):
async with aiohttp.ClientSession() as session:
yield session
@pytest.fixture
async def aio_connector(event_loop):
conn = aiohttp.TCPConnector(limit_per_host=16)
try:
yield conn
finally:
await conn.close()

View file

@ -4,6 +4,7 @@ import uuid
from urllib.parse import quote as urlquote
from urllib.parse import unquote as urlunquote
import aiostream
import pytest
from .. import assert_item_equals
@ -49,8 +50,9 @@ class StorageTests:
raise NotImplementedError()
@pytest.fixture
def s(self, get_storage_args):
return self.storage_class(**get_storage_args())
async def s(self, get_storage_args):
rv = self.storage_class(**await get_storage_args())
return rv
@pytest.fixture
def get_item(self, item_type):
@ -72,176 +74,211 @@ class StorageTests:
if not self.supports_metadata:
pytest.skip("This storage does not support metadata.")
def test_generic(self, s, get_item):
@pytest.mark.asyncio
async def test_generic(self, s, get_item):
items = [get_item() for i in range(1, 10)]
hrefs = []
for item in items:
href, etag = s.upload(item)
href, etag = await s.upload(item)
if etag is None:
_, etag = s.get(href)
_, etag = await s.get(href)
hrefs.append((href, etag))
hrefs.sort()
assert hrefs == sorted(s.list())
assert hrefs == sorted(await aiostream.stream.list(s.list()))
for href, etag in hrefs:
assert isinstance(href, (str, bytes))
assert isinstance(etag, (str, bytes))
assert s.has(href)
item, etag2 = s.get(href)
assert await s.has(href)
item, etag2 = await s.get(href)
assert etag == etag2
def test_empty_get_multi(self, s):
assert list(s.get_multi([])) == []
@pytest.mark.asyncio
async def test_empty_get_multi(self, s):
assert await aiostream.stream.list(s.get_multi([])) == []
def test_get_multi_duplicates(self, s, get_item):
href, etag = s.upload(get_item())
@pytest.mark.asyncio
async def test_get_multi_duplicates(self, s, get_item):
href, etag = await s.upload(get_item())
if etag is None:
_, etag = s.get(href)
((href2, item, etag2),) = s.get_multi([href] * 2)
_, etag = await s.get(href)
((href2, item, etag2),) = await aiostream.stream.list(s.get_multi([href] * 2))
assert href2 == href
assert etag2 == etag
def test_upload_already_existing(self, s, get_item):
@pytest.mark.asyncio
async def test_upload_already_existing(self, s, get_item):
item = get_item()
s.upload(item)
await s.upload(item)
with pytest.raises(exceptions.PreconditionFailed):
s.upload(item)
await s.upload(item)
def test_upload(self, s, get_item):
@pytest.mark.asyncio
async def test_upload(self, s, get_item):
item = get_item()
href, etag = s.upload(item)
assert_item_equals(s.get(href)[0], item)
href, etag = await s.upload(item)
assert_item_equals((await s.get(href))[0], item)
def test_update(self, s, get_item):
@pytest.mark.asyncio
async def test_update(self, s, get_item):
item = get_item()
href, etag = s.upload(item)
href, etag = await s.upload(item)
if etag is None:
_, etag = s.get(href)
assert_item_equals(s.get(href)[0], item)
_, etag = await s.get(href)
assert_item_equals((await s.get(href))[0], item)
new_item = get_item(uid=item.uid)
new_etag = s.update(href, new_item, etag)
new_etag = await s.update(href, new_item, etag)
if new_etag is None:
_, new_etag = s.get(href)
_, new_etag = await s.get(href)
# See https://github.com/pimutils/vdirsyncer/issues/48
assert isinstance(new_etag, (bytes, str))
assert_item_equals(s.get(href)[0], new_item)
assert_item_equals((await s.get(href))[0], new_item)
def test_update_nonexisting(self, s, get_item):
@pytest.mark.asyncio
async def test_update_nonexisting(self, s, get_item):
item = get_item()
with pytest.raises(exceptions.PreconditionFailed):
s.update("huehue", item, '"123"')
await s.update("huehue", item, '"123"')
def test_wrong_etag(self, s, get_item):
@pytest.mark.asyncio
async def test_wrong_etag(self, s, get_item):
item = get_item()
href, etag = s.upload(item)
href, etag = await s.upload(item)
with pytest.raises(exceptions.PreconditionFailed):
s.update(href, item, '"lolnope"')
await s.update(href, item, '"lolnope"')
with pytest.raises(exceptions.PreconditionFailed):
s.delete(href, '"lolnope"')
await s.delete(href, '"lolnope"')
def test_delete(self, s, get_item):
href, etag = s.upload(get_item())
s.delete(href, etag)
assert not list(s.list())
@pytest.mark.asyncio
async def test_delete(self, s, get_item):
href, etag = await s.upload(get_item())
await s.delete(href, etag)
assert not await aiostream.stream.list(s.list())
def test_delete_nonexisting(self, s, get_item):
@pytest.mark.asyncio
async def test_delete_nonexisting(self, s, get_item):
with pytest.raises(exceptions.PreconditionFailed):
s.delete("1", '"123"')
await s.delete("1", '"123"')
def test_list(self, s, get_item):
assert not list(s.list())
href, etag = s.upload(get_item())
@pytest.mark.asyncio
async def test_list(self, s, get_item):
assert not await aiostream.stream.list(s.list())
href, etag = await s.upload(get_item())
if etag is None:
_, etag = s.get(href)
assert list(s.list()) == [(href, etag)]
_, etag = await s.get(href)
assert await aiostream.stream.list(s.list()) == [(href, etag)]
def test_has(self, s, get_item):
assert not s.has("asd")
href, etag = s.upload(get_item())
assert s.has(href)
assert not s.has("asd")
s.delete(href, etag)
assert not s.has(href)
@pytest.mark.asyncio
async def test_has(self, s, get_item):
assert not await s.has("asd")
href, etag = await s.upload(get_item())
assert await s.has(href)
assert not await s.has("asd")
await s.delete(href, etag)
assert not await s.has(href)
def test_update_others_stay_the_same(self, s, get_item):
@pytest.mark.asyncio
async def test_update_others_stay_the_same(self, s, get_item):
info = {}
for _ in range(4):
href, etag = s.upload(get_item())
href, etag = await s.upload(get_item())
if etag is None:
_, etag = s.get(href)
_, etag = await s.get(href)
info[href] = etag
assert {
href: etag
for href, item, etag in s.get_multi(href for href, etag in info.items())
} == info
items = await aiostream.stream.list(
s.get_multi(href for href, etag in info.items())
)
assert {href: etag for href, item, etag in items} == info
def test_repr(self, s, get_storage_args):
@pytest.mark.asyncio
def test_repr(self, s, get_storage_args): # XXX: unused param
assert self.storage_class.__name__ in repr(s)
assert s.instance_name is None
def test_discover(self, requires_collections, get_storage_args, get_item):
@pytest.mark.asyncio
async def test_discover(
self,
requires_collections,
get_storage_args,
get_item,
aio_connector,
):
collections = set()
for i in range(1, 5):
collection = f"test{i}"
s = self.storage_class(**get_storage_args(collection=collection))
assert not list(s.list())
s.upload(get_item())
s = self.storage_class(**await get_storage_args(collection=collection))
assert not await aiostream.stream.list(s.list())
await s.upload(get_item())
collections.add(s.collection)
actual = {
c["collection"]
for c in self.storage_class.discover(**get_storage_args(collection=None))
}
discovered = await aiostream.stream.list(
self.storage_class.discover(**await get_storage_args(collection=None))
)
actual = {c["collection"] for c in discovered}
assert actual >= collections
def test_create_collection(self, requires_collections, get_storage_args, get_item):
@pytest.mark.asyncio
async def test_create_collection(
self,
requires_collections,
get_storage_args,
get_item,
):
if getattr(self, "dav_server", "") in ("icloud", "fastmail", "davical"):
pytest.skip("Manual cleanup would be necessary.")
if getattr(self, "dav_server", "") == "radicale":
pytest.skip("Radicale does not support collection creation")
args = get_storage_args(collection=None)
args = await get_storage_args(collection=None)
args["collection"] = "test"
s = self.storage_class(**self.storage_class.create_collection(**args))
s = self.storage_class(**await self.storage_class.create_collection(**args))
href = s.upload(get_item())[0]
assert href in (href for href, etag in s.list())
href = (await s.upload(get_item()))[0]
assert href in await aiostream.stream.list(
(href async for href, etag in s.list())
)
def test_discover_collection_arg(self, requires_collections, get_storage_args):
args = get_storage_args(collection="test2")
@pytest.mark.asyncio
async def test_discover_collection_arg(
self, requires_collections, get_storage_args
):
args = await get_storage_args(collection="test2")
with pytest.raises(TypeError) as excinfo:
list(self.storage_class.discover(**args))
await aiostream.stream.list(self.storage_class.discover(**args))
assert "collection argument must not be given" in str(excinfo.value)
def test_collection_arg(self, get_storage_args):
@pytest.mark.asyncio
async def test_collection_arg(self, get_storage_args):
if self.storage_class.storage_name.startswith("etesync"):
pytest.skip("etesync uses UUIDs.")
if self.supports_collections:
s = self.storage_class(**get_storage_args(collection="test2"))
s = self.storage_class(**await get_storage_args(collection="test2"))
# Can't do stronger assertion because of radicale, which needs a
# fileextension to guess the collection type.
assert "test2" in s.collection
else:
with pytest.raises(ValueError):
self.storage_class(collection="ayy", **get_storage_args())
self.storage_class(collection="ayy", **await get_storage_args())
def test_case_sensitive_uids(self, s, get_item):
@pytest.mark.asyncio
async def test_case_sensitive_uids(self, s, get_item):
if s.storage_name == "filesystem":
pytest.skip("Behavior depends on the filesystem.")
uid = str(uuid.uuid4())
s.upload(get_item(uid=uid.upper()))
s.upload(get_item(uid=uid.lower()))
items = [href for href, etag in s.list()]
await s.upload(get_item(uid=uid.upper()))
await s.upload(get_item(uid=uid.lower()))
items = [href async for href, etag in s.list()]
assert len(items) == 2
assert len(set(items)) == 2
def test_specialchars(
@pytest.mark.asyncio
async def test_specialchars(
self, monkeypatch, requires_collections, get_storage_args, get_item
):
if getattr(self, "dav_server", "") == "radicale":
@ -254,16 +291,16 @@ class StorageTests:
uid = "test @ foo ät bar град сатану"
collection = "test @ foo ät bar"
s = self.storage_class(**get_storage_args(collection=collection))
s = self.storage_class(**await get_storage_args(collection=collection))
item = get_item(uid=uid)
href, etag = s.upload(item)
item2, etag2 = s.get(href)
href, etag = await s.upload(item)
item2, etag2 = await s.get(href)
if etag is not None:
assert etag2 == etag
assert_item_equals(item2, item)
((_, etag3),) = s.list()
((_, etag3),) = await aiostream.stream.list(s.list())
assert etag2 == etag3
# etesync uses UUIDs for collection names
@ -274,22 +311,23 @@ class StorageTests:
if self.storage_class.storage_name.endswith("dav"):
assert urlquote(uid, "/@:") in href
def test_metadata(self, requires_metadata, s):
@pytest.mark.asyncio
async def test_metadata(self, requires_metadata, s):
if not getattr(self, "dav_server", ""):
assert not s.get_meta("color")
assert not s.get_meta("displayname")
assert not await s.get_meta("color")
assert not await s.get_meta("displayname")
try:
s.set_meta("color", None)
assert not s.get_meta("color")
s.set_meta("color", "#ff0000")
assert s.get_meta("color") == "#ff0000"
await s.set_meta("color", None)
assert not await s.get_meta("color")
await s.set_meta("color", "#ff0000")
assert await s.get_meta("color") == "#ff0000"
except exceptions.UnsupportedMetadataError:
pass
for x in ("hello world", "hello wörld"):
s.set_meta("displayname", x)
rv = s.get_meta("displayname")
await s.set_meta("displayname", x)
rv = await s.get_meta("displayname")
assert rv == x
assert isinstance(rv, str)
@ -306,16 +344,18 @@ class StorageTests:
"فلسطين",
],
)
def test_metadata_normalization(self, requires_metadata, s, value):
x = s.get_meta("displayname")
@pytest.mark.asyncio
async def test_metadata_normalization(self, requires_metadata, s, value):
x = await s.get_meta("displayname")
assert x == normalize_meta_value(x)
if not getattr(self, "dav_server", None):
# ownCloud replaces "" with "unnamed"
s.set_meta("displayname", value)
assert s.get_meta("displayname") == normalize_meta_value(value)
await s.set_meta("displayname", value)
assert await s.get_meta("displayname") == normalize_meta_value(value)
def test_recurring_events(self, s, item_type):
@pytest.mark.asyncio
async def test_recurring_events(self, s, item_type):
if item_type != "VEVENT":
pytest.skip("This storage instance doesn't support iCalendar.")
@ -362,7 +402,7 @@ class StorageTests:
).strip()
)
href, etag = s.upload(item)
href, etag = await s.upload(item)
item2, etag2 = s.get(href)
item2, etag2 = await s.get(href)
assert normalize_item(item) == normalize_item(item2)

View file

@ -3,6 +3,7 @@ import subprocess
import time
import uuid
import aiostream
import pytest
import requests
@ -80,31 +81,31 @@ def xandikos_server():
@pytest.fixture
def slow_create_collection(request):
async def slow_create_collection(request, aio_connector):
# We need to properly clean up because otherwise we might run into
# storage limits.
to_delete = []
def delete_collections():
async def delete_collections():
for s in to_delete:
s.session.request("DELETE", "")
await s.session.request("DELETE", "")
request.addfinalizer(delete_collections)
def inner(cls, args, collection):
async def inner(cls, args, collection):
assert collection.startswith("test")
collection += "-vdirsyncer-ci-" + str(uuid.uuid4())
args = cls.create_collection(collection, **args)
args = await cls.create_collection(collection, **args)
s = cls(**args)
_clear_collection(s)
assert not list(s.list())
await _clear_collection(s)
assert not await aiostream.stream.list(s.list())
to_delete.append(s)
return args
return inner
yield inner
await delete_collections()
def _clear_collection(s):
for href, etag in s.list():
async def _clear_collection(s):
async for href, etag in s.list():
s.delete(href, etag)

View file

@ -1,8 +1,9 @@
import os
import uuid
import aiohttp
import aiostream
import pytest
import requests.exceptions
from .. import get_server_mixin
from .. import StorageTests
@ -19,30 +20,33 @@ class DAVStorageTests(ServerMixin, StorageTests):
dav_server = dav_server
@pytest.mark.skipif(dav_server == "radicale", reason="Radicale is very tolerant.")
def test_dav_broken_item(self, s):
@pytest.mark.asyncio
async def test_dav_broken_item(self, s):
item = Item("HAHA:YES")
with pytest.raises((exceptions.Error, requests.exceptions.HTTPError)):
s.upload(item)
assert not list(s.list())
with pytest.raises((exceptions.Error, aiohttp.ClientResponseError)):
await s.upload(item)
assert not await aiostream.stream.list(s.list())
def test_dav_empty_get_multi_performance(self, s, monkeypatch):
@pytest.mark.asyncio
async def test_dav_empty_get_multi_performance(self, s, monkeypatch):
def breakdown(*a, **kw):
raise AssertionError("Expected not to be called.")
monkeypatch.setattr("requests.sessions.Session.request", breakdown)
try:
assert list(s.get_multi([])) == []
assert list(await aiostream.stream.list(s.get_multi([]))) == []
finally:
# Make sure monkeypatch doesn't interfere with DAV server teardown
monkeypatch.undo()
def test_dav_unicode_href(self, s, get_item, monkeypatch):
@pytest.mark.asyncio
async def test_dav_unicode_href(self, s, get_item, monkeypatch):
if self.dav_server == "radicale":
pytest.skip("Radicale is unable to deal with unicode hrefs")
monkeypatch.setattr(s, "_get_href", lambda item: item.ident + s.fileext)
item = get_item(uid="град сатану" + str(uuid.uuid4()))
href, etag = s.upload(item)
item2, etag2 = s.get(href)
href, etag = await s.upload(item)
item2, etag2 = await s.get(href)
assert_item_equals(item, item2)

View file

@ -1,8 +1,10 @@
import datetime
from textwrap import dedent
import aiohttp
import aiostream
import pytest
import requests.exceptions
from aioresponses import aioresponses
from . import dav_server
from . import DAVStorageTests
@ -21,15 +23,17 @@ class TestCalDAVStorage(DAVStorageTests):
def item_type(self, request):
return request.param
@pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.")
def test_doesnt_accept_vcard(self, item_type, get_storage_args):
s = self.storage_class(item_types=(item_type,), **get_storage_args())
@pytest.mark.asyncio
async def test_doesnt_accept_vcard(self, item_type, get_storage_args):
s = self.storage_class(item_types=(item_type,), **await get_storage_args())
try:
s.upload(format_item(VCARD_TEMPLATE))
except (exceptions.Error, requests.exceptions.HTTPError):
await s.upload(format_item(VCARD_TEMPLATE))
except (exceptions.Error, aiohttp.ClientResponseError):
# Most storages hard-fail, but xandikos doesn't.
pass
assert not list(s.list())
assert not await aiostream.stream.list(s.list())
# The `arg` param is not named `item_types` because that would hit
# https://bitbucket.org/pytest-dev/pytest/issue/745/
@ -44,10 +48,11 @@ class TestCalDAVStorage(DAVStorageTests):
],
)
@pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.")
def test_item_types_performance(
@pytest.mark.asyncio
async def test_item_types_performance(
self, get_storage_args, arg, calls_num, monkeypatch
):
s = self.storage_class(item_types=arg, **get_storage_args())
s = self.storage_class(item_types=arg, **await get_storage_args())
old_parse = s._parse_prop_responses
calls = []
@ -56,17 +61,18 @@ class TestCalDAVStorage(DAVStorageTests):
return old_parse(*a, **kw)
monkeypatch.setattr(s, "_parse_prop_responses", new_parse)
list(s.list())
await aiostream.stream.list(s.list())
assert len(calls) == calls_num
@pytest.mark.xfail(
dav_server == "radicale", reason="Radicale doesn't support timeranges."
)
def test_timerange_correctness(self, get_storage_args):
@pytest.mark.asyncio
async def test_timerange_correctness(self, get_storage_args):
start_date = datetime.datetime(2013, 9, 10)
end_date = datetime.datetime(2013, 9, 13)
s = self.storage_class(
start_date=start_date, end_date=end_date, **get_storage_args()
start_date=start_date, end_date=end_date, **await get_storage_args()
)
too_old_item = format_item(
@ -123,50 +129,44 @@ class TestCalDAVStorage(DAVStorageTests):
).strip()
)
s.upload(too_old_item)
s.upload(too_new_item)
expected_href, _ = s.upload(good_item)
await s.upload(too_old_item)
await s.upload(too_new_item)
expected_href, _ = await s.upload(good_item)
((actual_href, _),) = s.list()
((actual_href, _),) = await aiostream.stream.list(s.list())
assert actual_href == expected_href
def test_invalid_resource(self, monkeypatch, get_storage_args):
calls = []
args = get_storage_args(collection=None)
@pytest.mark.asyncio
async def test_invalid_resource(self, monkeypatch, get_storage_args):
args = await get_storage_args(collection=None)
def request(session, method, url, **kwargs):
assert url == args["url"]
calls.append(None)
with aioresponses() as m:
m.add(args["url"], method="PROPFIND", status=200, body="Hello world")
r = requests.Response()
r.status_code = 200
r._content = b"Hello World."
return r
with pytest.raises(ValueError):
s = self.storage_class(**args)
await aiostream.stream.list(s.list())
monkeypatch.setattr("requests.sessions.Session.request", request)
with pytest.raises(ValueError):
s = self.storage_class(**args)
list(s.list())
assert len(calls) == 1
assert len(m.requests) == 1
@pytest.mark.skipif(dav_server == "icloud", reason="iCloud only accepts VEVENT")
@pytest.mark.skipif(
dav_server == "fastmail", reason="Fastmail has non-standard hadling of VTODOs."
)
@pytest.mark.xfail(dav_server == "baikal", reason="Baikal returns 500.")
def test_item_types_general(self, s):
event = s.upload(format_item(EVENT_TEMPLATE))[0]
task = s.upload(format_item(TASK_TEMPLATE))[0]
@pytest.mark.asyncio
async def test_item_types_general(self, s):
event = (await s.upload(format_item(EVENT_TEMPLATE)))[0]
task = (await s.upload(format_item(TASK_TEMPLATE)))[0]
s.item_types = ("VTODO", "VEVENT")
def hrefs():
return {href for href, etag in s.list()}
async def hrefs():
return {href async for href, etag in s.list()}
assert hrefs() == {event, task}
assert await hrefs() == {event, task}
s.item_types = ("VTODO",)
assert hrefs() == {task}
assert await hrefs() == {task}
s.item_types = ("VEVENT",)
assert hrefs() == {event}
assert await hrefs() == {event}
s.item_types = ()
assert hrefs() == {event, task}
assert await hrefs() == {event, task}

View file

@ -58,7 +58,7 @@ class EtesyncTests(StorageTests):
)
assert r.status_code == 200
def inner(collection="test"):
async def inner(collection="test"):
rv = {
"email": "test@localhost",
"db_path": str(tmpdir.join("etesync.db")),

View file

@ -3,13 +3,21 @@ import pytest
class ServerMixin:
@pytest.fixture
def get_storage_args(self, request, tmpdir, slow_create_collection, baikal_server):
def inner(collection="test"):
def get_storage_args(
self,
request,
tmpdir,
slow_create_collection,
baikal_server,
aio_connector,
):
async def inner(collection="test"):
base_url = "http://127.0.0.1:8002/"
args = {
"url": base_url,
"username": "baikal",
"password": "baikal",
"connector": aio_connector,
}
if self.storage_class.fileext == ".vcf":
@ -18,7 +26,11 @@ class ServerMixin:
args["url"] = base_url + "cal.php/"
if collection is not None:
args = slow_create_collection(self.storage_class, args, collection)
args = await slow_create_collection(
self.storage_class,
args,
collection,
)
return args
return inner

View file

@ -27,7 +27,7 @@ class ServerMixin:
@pytest.fixture
def get_storage_args(self, davical_args, request):
def inner(collection="test"):
async def inner(collection="test"):
if collection is None:
return davical_args

View file

@ -11,7 +11,7 @@ class ServerMixin:
# See https://github.com/pimutils/vdirsyncer/issues/824
pytest.skip("Fastmail has non-standard VTODO support.")
def inner(collection="test"):
async def inner(collection="test"):
args = {
"username": os.environ["FASTMAIL_USERNAME"],
"password": os.environ["FASTMAIL_PASSWORD"],

View file

@ -11,7 +11,7 @@ class ServerMixin:
# See https://github.com/pimutils/vdirsyncer/pull/593#issuecomment-285941615 # noqa
pytest.skip("iCloud doesn't support anything else than VEVENT")
def inner(collection="test"):
async def inner(collection="test"):
args = {
"username": os.environ["ICLOUD_USERNAME"],
"password": os.environ["ICLOUD_PASSWORD"],

View file

@ -9,17 +9,23 @@ class ServerMixin:
tmpdir,
slow_create_collection,
radicale_server,
aio_connector,
):
def inner(collection="test"):
async def inner(collection="test"):
url = "http://127.0.0.1:8001/"
args = {
"url": url,
"username": "radicale",
"password": "radicale",
"connector": aio_connector,
}
if collection is not None:
args = slow_create_collection(self.storage_class, args, collection)
args = await slow_create_collection(
self.storage_class,
args,
collection,
)
return args
return inner

View file

@ -9,13 +9,19 @@ class ServerMixin:
tmpdir,
slow_create_collection,
xandikos_server,
aio_connector,
):
def inner(collection="test"):
async def inner(collection="test"):
url = "http://127.0.0.1:8000/"
args = {"url": url}
args = {"url": url, "connector": aio_connector}
if collection is not None:
args = slow_create_collection(self.storage_class, args, collection)
args = await slow_create_collection(
self.storage_class,
args,
collection,
)
return args
return inner

View file

@ -1,5 +1,6 @@
import subprocess
import aiostream
import pytest
from . import StorageTests
@ -12,10 +13,10 @@ class TestFilesystemStorage(StorageTests):
@pytest.fixture
def get_storage_args(self, tmpdir):
def inner(collection="test"):
async def inner(collection="test"):
rv = {"path": str(tmpdir), "fileext": ".txt", "collection": collection}
if collection is not None:
rv = self.storage_class.create_collection(**rv)
rv = await self.storage_class.create_collection(**rv)
return rv
return inner
@ -26,7 +27,8 @@ class TestFilesystemStorage(StorageTests):
f.write("stub")
self.storage_class(str(tmpdir) + "/hue", ".txt")
def test_broken_data(self, tmpdir):
@pytest.mark.asyncio
async def test_broken_data(self, tmpdir):
s = self.storage_class(str(tmpdir), ".txt")
class BrokenItem:
@ -35,64 +37,70 @@ class TestFilesystemStorage(StorageTests):
ident = uid
with pytest.raises(TypeError):
s.upload(BrokenItem)
await s.upload(BrokenItem)
assert not tmpdir.listdir()
def test_ident_with_slash(self, tmpdir):
@pytest.mark.asyncio
async def test_ident_with_slash(self, tmpdir):
s = self.storage_class(str(tmpdir), ".txt")
s.upload(Item("UID:a/b/c"))
await s.upload(Item("UID:a/b/c"))
(item_file,) = tmpdir.listdir()
assert "/" not in item_file.basename and item_file.isfile()
def test_ignore_tmp_files(self, tmpdir):
@pytest.mark.asyncio
async def test_ignore_tmp_files(self, tmpdir):
"""Test that files with .tmp suffix beside .ics files are ignored."""
s = self.storage_class(str(tmpdir), ".ics")
s.upload(Item("UID:xyzxyz"))
await s.upload(Item("UID:xyzxyz"))
(item_file,) = tmpdir.listdir()
item_file.copy(item_file.new(ext="tmp"))
assert len(tmpdir.listdir()) == 2
assert len(list(s.list())) == 1
assert len(await aiostream.stream.list(s.list())) == 1
def test_ignore_tmp_files_empty_fileext(self, tmpdir):
@pytest.mark.asyncio
async def test_ignore_tmp_files_empty_fileext(self, tmpdir):
"""Test that files with .tmp suffix are ignored with empty fileext."""
s = self.storage_class(str(tmpdir), "")
s.upload(Item("UID:xyzxyz"))
await s.upload(Item("UID:xyzxyz"))
(item_file,) = tmpdir.listdir()
item_file.copy(item_file.new(ext="tmp"))
assert len(tmpdir.listdir()) == 2
# assert False, tmpdir.listdir() # enable to see the created filename
assert len(list(s.list())) == 1
assert len(await aiostream.stream.list(s.list())) == 1
def test_ignore_files_typical_backup(self, tmpdir):
@pytest.mark.asyncio
async def test_ignore_files_typical_backup(self, tmpdir):
"""Test file-name ignorance with typical backup ending ~."""
ignorext = "~" # without dot
storage = self.storage_class(str(tmpdir), "", fileignoreext=ignorext)
storage.upload(Item("UID:xyzxyz"))
await storage.upload(Item("UID:xyzxyz"))
(item_file,) = tmpdir.listdir()
item_file.copy(item_file.new(basename=item_file.basename + ignorext))
assert len(tmpdir.listdir()) == 2
assert len(list(storage.list())) == 1
assert len(await aiostream.stream.list(storage.list())) == 1
def test_too_long_uid(self, tmpdir):
@pytest.mark.asyncio
async def test_too_long_uid(self, tmpdir):
storage = self.storage_class(str(tmpdir), ".txt")
item = Item("UID:" + "hue" * 600)
href, etag = storage.upload(item)
href, etag = await storage.upload(item)
assert item.uid not in href
def test_post_hook_inactive(self, tmpdir, monkeypatch):
@pytest.mark.asyncio
async def test_post_hook_inactive(self, tmpdir, monkeypatch):
def check_call_mock(*args, **kwargs):
raise AssertionError()
monkeypatch.setattr(subprocess, "call", check_call_mock)
s = self.storage_class(str(tmpdir), ".txt", post_hook=None)
s.upload(Item("UID:a/b/c"))
def test_post_hook_active(self, tmpdir, monkeypatch):
await s.upload(Item("UID:a/b/c"))
@pytest.mark.asyncio
async def test_post_hook_active(self, tmpdir, monkeypatch):
calls = []
exe = "foo"
@ -104,14 +112,17 @@ class TestFilesystemStorage(StorageTests):
monkeypatch.setattr(subprocess, "call", check_call_mock)
s = self.storage_class(str(tmpdir), ".txt", post_hook=exe)
s.upload(Item("UID:a/b/c"))
await s.upload(Item("UID:a/b/c"))
assert calls
def test_ignore_git_dirs(self, tmpdir):
@pytest.mark.asyncio
async def test_ignore_git_dirs(self, tmpdir):
tmpdir.mkdir(".git").mkdir("foo")
tmpdir.mkdir("a")
tmpdir.mkdir("b")
assert {c["collection"] for c in self.storage_class.discover(str(tmpdir))} == {
"a",
"b",
expected = {"a", "b"}
actual = {
c["collection"] async for c in self.storage_class.discover(str(tmpdir))
}
assert actual == expected

View file

@ -1,5 +1,6 @@
import pytest
from requests import Response
from aioresponses import aioresponses
from aioresponses import CallbackResult
from tests import normalize_item
from vdirsyncer.exceptions import UserError
@ -7,7 +8,8 @@ from vdirsyncer.storage.http import HttpStorage
from vdirsyncer.storage.http import prepare_auth
def test_list(monkeypatch):
@pytest.mark.asyncio
async def test_list(aio_connector):
collection_url = "http://127.0.0.1/calendar/collection.ics"
items = [
@ -34,50 +36,53 @@ def test_list(monkeypatch):
responses = ["\n".join(["BEGIN:VCALENDAR"] + items + ["END:VCALENDAR"])] * 2
def get(self, method, url, *a, **kw):
assert method == "GET"
assert url == collection_url
r = Response()
r.status_code = 200
def callback(url, headers, **kwargs):
assert headers["User-Agent"].startswith("vdirsyncer/")
assert responses
r._content = responses.pop().encode("utf-8")
r.headers["Content-Type"] = "text/calendar"
r.encoding = "ISO-8859-1"
return r
monkeypatch.setattr("requests.sessions.Session.request", get)
return CallbackResult(
status=200,
body=responses.pop().encode("utf-8"),
headers={"Content-Type": "text/calendar; charset=iso-8859-1"},
)
s = HttpStorage(url=collection_url)
with aioresponses() as m:
m.get(collection_url, callback=callback, repeat=True)
found_items = {}
s = HttpStorage(url=collection_url, connector=aio_connector)
for href, etag in s.list():
item, etag2 = s.get(href)
assert item.uid is not None
assert etag2 == etag
found_items[normalize_item(item)] = href
found_items = {}
expected = {
normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items
}
async for href, etag in s.list():
item, etag2 = await s.get(href)
assert item.uid is not None
assert etag2 == etag
found_items[normalize_item(item)] = href
assert set(found_items) == expected
expected = {
normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items
}
for href, etag in s.list():
item, etag2 = s.get(href)
assert item.uid is not None
assert etag2 == etag
assert found_items[normalize_item(item)] == href
assert set(found_items) == expected
async for href, etag in s.list():
item, etag2 = await s.get(href)
assert item.uid is not None
assert etag2 == etag
assert found_items[normalize_item(item)] == href
def test_readonly_param():
def test_readonly_param(aio_connector):
"""The ``readonly`` param cannot be ``False``."""
url = "http://example.com/"
with pytest.raises(ValueError):
HttpStorage(url=url, read_only=False)
HttpStorage(url=url, read_only=False, connector=aio_connector)
a = HttpStorage(url=url, read_only=True).read_only
b = HttpStorage(url=url, read_only=None).read_only
assert a is b is True
a = HttpStorage(url=url, read_only=True, connector=aio_connector)
b = HttpStorage(url=url, read_only=None, connector=aio_connector)
assert a.read_only is b.read_only is True
def test_prepare_auth():
@ -115,9 +120,9 @@ def test_prepare_auth_guess(monkeypatch):
assert "requests_toolbelt is too old" in str(excinfo.value).lower()
def test_verify_false_disallowed():
def test_verify_false_disallowed(aio_connector):
with pytest.raises(ValueError) as excinfo:
HttpStorage(url="http://example.com", verify=False)
HttpStorage(url="http://example.com", verify=False, connector=aio_connector)
assert "forbidden" in str(excinfo.value).lower()
assert "consider setting verify_fingerprint" in str(excinfo.value).lower()

View file

@ -1,5 +1,7 @@
import aiostream
import pytest
from requests import Response
from aioresponses import aioresponses
from aioresponses import CallbackResult
import vdirsyncer.storage.http
from . import StorageTests
@ -14,32 +16,33 @@ class CombinedStorage(Storage):
_repr_attributes = ("url", "path")
storage_name = "http_and_singlefile"
def __init__(self, url, path, **kwargs):
def __init__(self, url, path, *, connector, **kwargs):
if kwargs.get("collection", None) is not None:
raise ValueError()
super().__init__(**kwargs)
self.url = url
self.path = path
self._reader = vdirsyncer.storage.http.HttpStorage(url=url)
self._reader = vdirsyncer.storage.http.HttpStorage(url=url, connector=connector)
self._reader._ignore_uids = False
self._writer = SingleFileStorage(path=path)
def list(self, *a, **kw):
return self._reader.list(*a, **kw)
async def list(self, *a, **kw):
async for item in self._reader.list(*a, **kw):
yield item
def get(self, *a, **kw):
self.list()
return self._reader.get(*a, **kw)
async def get(self, *a, **kw):
await aiostream.stream.list(self.list())
return await self._reader.get(*a, **kw)
def upload(self, *a, **kw):
return self._writer.upload(*a, **kw)
async def upload(self, *a, **kw):
return await self._writer.upload(*a, **kw)
def update(self, *a, **kw):
return self._writer.update(*a, **kw)
async def update(self, *a, **kw):
return await self._writer.update(*a, **kw)
def delete(self, *a, **kw):
return self._writer.delete(*a, **kw)
async def delete(self, *a, **kw):
return await self._writer.delete(*a, **kw)
class TestHttpStorage(StorageTests):
@ -51,28 +54,37 @@ class TestHttpStorage(StorageTests):
def setup_tmpdir(self, tmpdir, monkeypatch):
self.tmpfile = str(tmpdir.ensure("collection.txt"))
def _request(method, url, *args, **kwargs):
assert method == "GET"
assert url == "http://localhost:123/collection.txt"
assert "vdirsyncer" in kwargs["headers"]["User-Agent"]
r = Response()
r.status_code = 200
try:
with open(self.tmpfile, "rb") as f:
r._content = f.read()
except OSError:
r._content = b""
def callback(url, headers, **kwargs):
"""Read our tmpfile at request time.
r.headers["Content-Type"] = "text/calendar"
r.encoding = "utf-8"
return r
We can't just read this during test setup since the file get written to
during test execution.
monkeypatch.setattr(vdirsyncer.storage.http, "request", _request)
It might make sense to actually run a server serving the local file.
"""
assert headers["User-Agent"].startswith("vdirsyncer/")
with open(self.tmpfile, "r") as f:
body = f.read()
return CallbackResult(
status=200,
body=body,
headers={"Content-Type": "text/calendar; charset=utf-8"},
)
with aioresponses() as m:
m.get("http://localhost:123/collection.txt", callback=callback, repeat=True)
yield
@pytest.fixture
def get_storage_args(self):
def inner(collection=None):
def get_storage_args(self, aio_connector):
async def inner(collection=None):
assert collection is None
return {"url": "http://localhost:123/collection.txt", "path": self.tmpfile}
return {
"url": "http://localhost:123/collection.txt",
"path": self.tmpfile,
"connector": aio_connector,
}
return inner

View file

@ -11,4 +11,7 @@ class TestMemoryStorage(StorageTests):
@pytest.fixture
def get_storage_args(self):
return lambda **kw: kw
async def inner(**args):
return args
return inner

View file

@ -11,10 +11,10 @@ class TestSingleFileStorage(StorageTests):
@pytest.fixture
def get_storage_args(self, tmpdir):
def inner(collection="test"):
async def inner(collection="test"):
rv = {"path": str(tmpdir.join("%s.txt")), "collection": collection}
if collection is not None:
rv = self.storage_class.create_collection(**rv)
rv = await self.storage_class.create_collection(**rv)
return rv
return inner

View file

@ -1,3 +1,5 @@
import pytest
from vdirsyncer import exceptions
from vdirsyncer.cli.utils import handle_cli_error
from vdirsyncer.cli.utils import storage_instance_from_config
@ -15,11 +17,13 @@ def test_handle_cli_error(capsys):
assert "ayy lmao" in err
def test_storage_instance_from_config(monkeypatch):
def lol(**kw):
assert kw == {"foo": "bar", "baz": 1}
return "OK"
@pytest.mark.asyncio
async def test_storage_instance_from_config(monkeypatch, aio_connector):
class Dummy:
def __init__(self, **kw):
assert kw == {"foo": "bar", "baz": 1}
monkeypatch.setitem(storage_names._storages, "lol", lol)
monkeypatch.setitem(storage_names._storages, "lol", Dummy)
config = {"type": "lol", "foo": "bar", "baz": 1}
assert storage_instance_from_config(config) == "OK"
storage = await storage_instance_from_config(config, connector=aio_connector)
assert isinstance(storage, Dummy)

View file

@ -1,9 +1,9 @@
import logging
import sys
import aiohttp
import click_log
import pytest
import requests
from cryptography import x509
from cryptography.hazmat.primitives import hashes
@ -25,74 +25,77 @@ def test_get_storage_init_args():
assert not required
def test_request_ssl():
with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request("GET", "https://self-signed.badssl.com/")
assert "certificate verify failed" in str(excinfo.value)
@pytest.mark.asyncio
async def test_request_ssl():
async with aiohttp.ClientSession() as session:
with pytest.raises(aiohttp.ClientConnectorCertificateError) as excinfo:
await http.request(
"GET",
"https://self-signed.badssl.com/",
session=session,
)
assert "certificate verify failed" in str(excinfo.value)
http.request("GET", "https://self-signed.badssl.com/", verify=False)
await http.request(
"GET",
"https://self-signed.badssl.com/",
verify=False,
session=session,
)
def _fingerprints_broken():
from pkg_resources import parse_version as ver
broken_urllib3 = ver(requests.__version__) <= ver("2.5.1")
return broken_urllib3
def fingerprint_of_cert(cert, hash=hashes.SHA256):
def fingerprint_of_cert(cert, hash=hashes.SHA256) -> str:
return x509.load_pem_x509_certificate(cert.bytes()).fingerprint(hash()).hex()
@pytest.mark.skipif(
_fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529"
)
@pytest.mark.parametrize("hash_algorithm", [hashes.MD5, hashes.SHA256])
def test_request_ssl_leaf_fingerprint(httpserver, localhost_cert, hash_algorithm):
@pytest.mark.parametrize("hash_algorithm", [hashes.SHA256])
@pytest.mark.asyncio
async def test_request_ssl_leaf_fingerprint(
httpserver,
localhost_cert,
hash_algorithm,
aio_session,
):
fingerprint = fingerprint_of_cert(localhost_cert.cert_chain_pems[0], hash_algorithm)
bogus = "".join(reversed(fingerprint))
# We have to serve something:
httpserver.expect_request("/").respond_with_data("OK")
url = f"https://{httpserver.host}:{httpserver.port}/"
url = f"https://127.0.0.1:{httpserver.port}/"
http.request("GET", url, verify=False, verify_fingerprint=fingerprint)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request("GET", url, verify_fingerprint=fingerprint)
await http.request("GET", url, verify_fingerprint=fingerprint, session=aio_session)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request(
"GET",
url,
verify=False,
verify_fingerprint="".join(reversed(fingerprint)),
)
assert "Fingerprints did not match" in str(excinfo.value)
with pytest.raises(aiohttp.ServerFingerprintMismatch):
await http.request("GET", url, verify_fingerprint=bogus, session=aio_session)
@pytest.mark.skipif(
_fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529"
)
@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("hash_algorithm", [hashes.MD5, hashes.SHA256])
def test_request_ssl_ca_fingerprint(httpserver, ca, hash_algorithm):
@pytest.mark.parametrize("hash_algorithm", [hashes.SHA256])
@pytest.mark.asyncio
async def test_request_ssl_ca_fingerprints(httpserver, ca, hash_algorithm, aio_session):
fingerprint = fingerprint_of_cert(ca.cert_pem)
bogus = "".join(reversed(fingerprint))
# We have to serve something:
httpserver.expect_request("/").respond_with_data("OK")
url = f"https://{httpserver.host}:{httpserver.port}/"
url = f"https://127.0.0.1:{httpserver.port}/"
http.request("GET", url, verify=False, verify_fingerprint=fingerprint)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request("GET", url, verify_fingerprint=fingerprint)
await http.request(
"GET",
url,
verify=False,
verify_fingerprint=fingerprint,
session=aio_session,
)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
with pytest.raises(aiohttp.ServerFingerprintMismatch):
http.request(
"GET",
url,
verify=False,
verify_fingerprint="".join(reversed(fingerprint)),
verify_fingerprint=bogus,
session=aio_session,
)
assert "Fingerprints did not match" in str(excinfo.value)
def test_open_graphical_browser(monkeypatch):

View file

@ -1,3 +1,4 @@
import aiostream
import pytest
from vdirsyncer.cli.discover import expand_collections
@ -132,34 +133,40 @@ missing = object()
),
],
)
def test_expand_collections(shortcuts, expected):
@pytest.mark.asyncio
async def test_expand_collections(shortcuts, expected):
config_a = {"type": "fooboo", "storage_side": "a"}
config_b = {"type": "fooboo", "storage_side": "b"}
def get_discovered_a():
async def get_discovered_a():
return {
"c1": {"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
"c2": {"type": "fooboo", "custom_arg": "a2", "collection": "c2"},
"a3": {"type": "fooboo", "custom_arg": "a3", "collection": "a3"},
}
def get_discovered_b():
async def get_discovered_b():
return {
"c1": {"type": "fooboo", "custom_arg": "b1", "collection": "c1"},
"c2": {"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
"b3": {"type": "fooboo", "custom_arg": "b3", "collection": "b3"},
}
async def handle_not_found(config, collection):
return missing
assert (
sorted(
expand_collections(
shortcuts,
config_a,
config_b,
get_discovered_a,
get_discovered_b,
lambda config, collection: missing,
await aiostream.stream.list(
expand_collections(
shortcuts,
config_a,
config_b,
get_discovered_a,
get_discovered_b,
handle_not_found,
)
)
)
== sorted(expected)

View file

@ -1,5 +1,7 @@
import asyncio
from copy import deepcopy
import aiostream
import hypothesis.strategies as st
import pytest
from hypothesis import assume
@ -21,10 +23,10 @@ from vdirsyncer.sync.status import SqliteStatus
from vdirsyncer.vobject import Item
def sync(a, b, status, *args, **kwargs):
async def sync(a, b, status, *args, **kwargs):
new_status = SqliteStatus(":memory:")
new_status.load_legacy_status(status)
rv = _sync(a, b, new_status, *args, **kwargs)
rv = await _sync(a, b, new_status, *args, **kwargs)
status.clear()
status.update(new_status.to_legacy_status())
return rv
@ -38,45 +40,49 @@ def items(s):
return {x[1].raw for x in s.items.values()}
def test_irrelevant_status():
@pytest.mark.asyncio
async def test_irrelevant_status():
a = MemoryStorage()
b = MemoryStorage()
status = {"1": ("1", 1234, "1.ics", 2345)}
sync(a, b, status)
await sync(a, b, status)
assert not status
assert not items(a)
assert not items(b)
def test_missing_status():
@pytest.mark.asyncio
async def test_missing_status():
a = MemoryStorage()
b = MemoryStorage()
status = {}
item = Item("asdf")
a.upload(item)
b.upload(item)
sync(a, b, status)
await a.upload(item)
await b.upload(item)
await sync(a, b, status)
assert len(status) == 1
assert items(a) == items(b) == {item.raw}
def test_missing_status_and_different_items():
@pytest.mark.asyncio
async def test_missing_status_and_different_items():
a = MemoryStorage()
b = MemoryStorage()
status = {}
item1 = Item("UID:1\nhaha")
item2 = Item("UID:1\nhoho")
a.upload(item1)
b.upload(item2)
await a.upload(item1)
await b.upload(item2)
with pytest.raises(SyncConflict):
sync(a, b, status)
await sync(a, b, status)
assert not status
sync(a, b, status, conflict_resolution="a wins")
await sync(a, b, status, conflict_resolution="a wins")
assert items(a) == items(b) == {item1.raw}
def test_read_only_and_prefetch():
@pytest.mark.asyncio
async def test_read_only_and_prefetch():
a = MemoryStorage()
b = MemoryStorage()
b.read_only = True
@ -84,147 +90,154 @@ def test_read_only_and_prefetch():
status = {}
item1 = Item("UID:1\nhaha")
item2 = Item("UID:2\nhoho")
a.upload(item1)
a.upload(item2)
await a.upload(item1)
await a.upload(item2)
sync(a, b, status, force_delete=True)
sync(a, b, status, force_delete=True)
await sync(a, b, status, force_delete=True)
await sync(a, b, status, force_delete=True)
assert not items(a) and not items(b)
def test_partial_sync_error():
@pytest.mark.asyncio
async def test_partial_sync_error():
a = MemoryStorage()
b = MemoryStorage()
status = {}
a.upload(Item("UID:0"))
await a.upload(Item("UID:0"))
b.read_only = True
with pytest.raises(PartialSync):
sync(a, b, status, partial_sync="error")
await sync(a, b, status, partial_sync="error")
def test_partial_sync_ignore():
@pytest.mark.asyncio
async def test_partial_sync_ignore():
a = MemoryStorage()
b = MemoryStorage()
status = {}
item0 = Item("UID:0\nhehe")
a.upload(item0)
b.upload(item0)
await a.upload(item0)
await b.upload(item0)
b.read_only = True
item1 = Item("UID:1\nhaha")
a.upload(item1)
await a.upload(item1)
sync(a, b, status, partial_sync="ignore")
sync(a, b, status, partial_sync="ignore")
await sync(a, b, status, partial_sync="ignore")
await sync(a, b, status, partial_sync="ignore")
assert items(a) == {item0.raw, item1.raw}
assert items(b) == {item0.raw}
def test_partial_sync_ignore2():
@pytest.mark.asyncio
async def test_partial_sync_ignore2():
a = MemoryStorage()
b = MemoryStorage()
status = {}
href, etag = a.upload(Item("UID:0"))
href, etag = await a.upload(Item("UID:0"))
a.read_only = True
sync(a, b, status, partial_sync="ignore", force_delete=True)
await sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(b) == items(a) == {"UID:0"}
b.items.clear()
sync(a, b, status, partial_sync="ignore", force_delete=True)
sync(a, b, status, partial_sync="ignore", force_delete=True)
await sync(a, b, status, partial_sync="ignore", force_delete=True)
await sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(a) == {"UID:0"}
assert not b.items
a.read_only = False
a.update(href, Item("UID:0\nupdated"), etag)
await a.update(href, Item("UID:0\nupdated"), etag)
a.read_only = True
sync(a, b, status, partial_sync="ignore", force_delete=True)
await sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(b) == items(a) == {"UID:0\nupdated"}
def test_upload_and_update():
@pytest.mark.asyncio
async def test_upload_and_update():
a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext=".b")
status = {}
item = Item("UID:1") # new item 1 in a
a.upload(item)
sync(a, b, status)
await a.upload(item)
await sync(a, b, status)
assert items(b) == items(a) == {item.raw}
item = Item("UID:1\nASDF:YES") # update of item 1 in b
b.update("1.b", item, b.get("1.b")[1])
sync(a, b, status)
await b.update("1.b", item, (await b.get("1.b"))[1])
await sync(a, b, status)
assert items(b) == items(a) == {item.raw}
item2 = Item("UID:2") # new item 2 in b
b.upload(item2)
sync(a, b, status)
await b.upload(item2)
await sync(a, b, status)
assert items(b) == items(a) == {item.raw, item2.raw}
item2 = Item("UID:2\nASDF:YES") # update of item 2 in a
a.update("2.a", item2, a.get("2.a")[1])
sync(a, b, status)
await a.update("2.a", item2, (await a.get("2.a"))[1])
await sync(a, b, status)
assert items(b) == items(a) == {item.raw, item2.raw}
def test_deletion():
@pytest.mark.asyncio
async def test_deletion():
a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext=".b")
status = {}
item = Item("UID:1")
a.upload(item)
await a.upload(item)
item2 = Item("UID:2")
a.upload(item2)
sync(a, b, status)
b.delete("1.b", b.get("1.b")[1])
sync(a, b, status)
await a.upload(item2)
await sync(a, b, status)
await b.delete("1.b", (await b.get("1.b"))[1])
await sync(a, b, status)
assert items(a) == items(b) == {item2.raw}
a.upload(item)
sync(a, b, status)
await a.upload(item)
await sync(a, b, status)
assert items(a) == items(b) == {item.raw, item2.raw}
a.delete("1.a", a.get("1.a")[1])
sync(a, b, status)
await a.delete("1.a", (await a.get("1.a"))[1])
await sync(a, b, status)
assert items(a) == items(b) == {item2.raw}
def test_insert_hash():
@pytest.mark.asyncio
async def test_insert_hash():
a = MemoryStorage()
b = MemoryStorage()
status = {}
item = Item("UID:1")
href, etag = a.upload(item)
sync(a, b, status)
href, etag = await a.upload(item)
await sync(a, b, status)
for d in status["1"]:
del d["hash"]
a.update(href, Item("UID:1\nHAHA:YES"), etag)
sync(a, b, status)
await a.update(href, Item("UID:1\nHAHA:YES"), etag)
await sync(a, b, status)
assert "hash" in status["1"][0] and "hash" in status["1"][1]
def test_already_synced():
@pytest.mark.asyncio
async def test_already_synced():
a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext=".b")
item = Item("UID:1")
a.upload(item)
b.upload(item)
await a.upload(item)
await b.upload(item)
status = {
"1": (
{"href": "1.a", "hash": item.hash, "etag": a.get("1.a")[1]},
{"href": "1.b", "hash": item.hash, "etag": b.get("1.b")[1]},
{"href": "1.a", "hash": item.hash, "etag": (await a.get("1.a"))[1]},
{"href": "1.b", "hash": item.hash, "etag": (await b.get("1.b"))[1]},
)
}
old_status = deepcopy(status)
@ -233,69 +246,73 @@ def test_already_synced():
)
for _ in (1, 2):
sync(a, b, status)
await sync(a, b, status)
assert status == old_status
assert items(a) == items(b) == {item.raw}
@pytest.mark.parametrize("winning_storage", "ab")
def test_conflict_resolution_both_etags_new(winning_storage):
@pytest.mark.asyncio
async def test_conflict_resolution_both_etags_new(winning_storage):
a = MemoryStorage()
b = MemoryStorage()
item = Item("UID:1")
href_a, etag_a = a.upload(item)
href_b, etag_b = b.upload(item)
href_a, etag_a = await a.upload(item)
href_b, etag_b = await b.upload(item)
status = {}
sync(a, b, status)
await sync(a, b, status)
assert status
item_a = Item("UID:1\nitem a")
item_b = Item("UID:1\nitem b")
a.update(href_a, item_a, etag_a)
b.update(href_b, item_b, etag_b)
await a.update(href_a, item_a, etag_a)
await b.update(href_b, item_b, etag_b)
with pytest.raises(SyncConflict):
sync(a, b, status)
sync(a, b, status, conflict_resolution=f"{winning_storage} wins")
await sync(a, b, status)
await sync(a, b, status, conflict_resolution=f"{winning_storage} wins")
assert (
items(a) == items(b) == {item_a.raw if winning_storage == "a" else item_b.raw}
)
def test_updated_and_deleted():
@pytest.mark.asyncio
async def test_updated_and_deleted():
a = MemoryStorage()
b = MemoryStorage()
href_a, etag_a = a.upload(Item("UID:1"))
href_a, etag_a = await a.upload(Item("UID:1"))
status = {}
sync(a, b, status, force_delete=True)
await sync(a, b, status, force_delete=True)
((href_b, etag_b),) = b.list()
b.delete(href_b, etag_b)
((href_b, etag_b),) = await aiostream.stream.list(b.list())
await b.delete(href_b, etag_b)
updated = Item("UID:1\nupdated")
a.update(href_a, updated, etag_a)
sync(a, b, status, force_delete=True)
await a.update(href_a, updated, etag_a)
await sync(a, b, status, force_delete=True)
assert items(a) == items(b) == {updated.raw}
def test_conflict_resolution_invalid_mode():
@pytest.mark.asyncio
async def test_conflict_resolution_invalid_mode():
a = MemoryStorage()
b = MemoryStorage()
item_a = Item("UID:1\nitem a")
item_b = Item("UID:1\nitem b")
a.upload(item_a)
b.upload(item_b)
await a.upload(item_a)
await b.upload(item_b)
with pytest.raises(ValueError):
sync(a, b, {}, conflict_resolution="yolo")
await sync(a, b, {}, conflict_resolution="yolo")
def test_conflict_resolution_new_etags_without_changes():
@pytest.mark.asyncio
async def test_conflict_resolution_new_etags_without_changes():
a = MemoryStorage()
b = MemoryStorage()
item = Item("UID:1")
href_a, etag_a = a.upload(item)
href_b, etag_b = b.upload(item)
href_a, etag_a = await a.upload(item)
href_b, etag_b = await b.upload(item)
status = {"1": (href_a, "BOGUS_a", href_b, "BOGUS_b")}
sync(a, b, status)
await sync(a, b, status)
((ident, (status_a, status_b)),) = status.items()
assert ident == "1"
@ -305,7 +322,8 @@ def test_conflict_resolution_new_etags_without_changes():
assert status_b["etag"] == etag_b
def test_uses_get_multi(monkeypatch):
@pytest.mark.asyncio
async def test_uses_get_multi(monkeypatch):
def breakdown(*a, **kw):
raise AssertionError("Expected use of get_multi")
@ -313,11 +331,11 @@ def test_uses_get_multi(monkeypatch):
old_get = MemoryStorage.get
def get_multi(self, hrefs):
async def get_multi(self, hrefs):
hrefs = list(hrefs)
get_multi_calls.append(hrefs)
for href in hrefs:
item, etag = old_get(self, href)
item, etag = await old_get(self, href)
yield href, item, etag
monkeypatch.setattr(MemoryStorage, "get", breakdown)
@ -326,72 +344,77 @@ def test_uses_get_multi(monkeypatch):
a = MemoryStorage()
b = MemoryStorage()
item = Item("UID:1")
expected_href, etag = a.upload(item)
expected_href, etag = await a.upload(item)
sync(a, b, {})
await sync(a, b, {})
assert get_multi_calls == [[expected_href]]
def test_empty_storage_dataloss():
@pytest.mark.asyncio
async def test_empty_storage_dataloss():
a = MemoryStorage()
b = MemoryStorage()
a.upload(Item("UID:1"))
a.upload(Item("UID:2"))
await a.upload(Item("UID:1"))
await a.upload(Item("UID:2"))
status = {}
sync(a, b, status)
await sync(a, b, status)
with pytest.raises(StorageEmpty):
sync(MemoryStorage(), b, status)
await sync(MemoryStorage(), b, status)
with pytest.raises(StorageEmpty):
sync(a, MemoryStorage(), status)
await sync(a, MemoryStorage(), status)
def test_no_uids():
@pytest.mark.asyncio
async def test_no_uids():
a = MemoryStorage()
b = MemoryStorage()
a.upload(Item("ASDF"))
b.upload(Item("FOOBAR"))
await a.upload(Item("ASDF"))
await b.upload(Item("FOOBAR"))
status = {}
sync(a, b, status)
await sync(a, b, status)
assert items(a) == items(b) == {"ASDF", "FOOBAR"}
def test_changed_uids():
@pytest.mark.asyncio
async def test_changed_uids():
a = MemoryStorage()
b = MemoryStorage()
href_a, etag_a = a.upload(Item("UID:A-ONE"))
href_b, etag_b = b.upload(Item("UID:B-ONE"))
href_a, etag_a = await a.upload(Item("UID:A-ONE"))
href_b, etag_b = await b.upload(Item("UID:B-ONE"))
status = {}
sync(a, b, status)
await sync(a, b, status)
a.update(href_a, Item("UID:A-TWO"), etag_a)
sync(a, b, status)
await a.update(href_a, Item("UID:A-TWO"), etag_a)
await sync(a, b, status)
def test_both_readonly():
@pytest.mark.asyncio
async def test_both_readonly():
a = MemoryStorage(read_only=True)
b = MemoryStorage(read_only=True)
assert a.read_only
assert b.read_only
status = {}
with pytest.raises(BothReadOnly):
sync(a, b, status)
await sync(a, b, status)
def test_partial_sync_revert():
@pytest.mark.asyncio
async def test_partial_sync_revert():
a = MemoryStorage(instance_name="a")
b = MemoryStorage(instance_name="b")
status = {}
a.upload(Item("UID:1"))
b.upload(Item("UID:2"))
await a.upload(Item("UID:1"))
await b.upload(Item("UID:2"))
b.read_only = True
sync(a, b, status, partial_sync="revert")
await sync(a, b, status, partial_sync="revert")
assert len(status) == 2
assert items(a) == {"UID:1", "UID:2"}
assert items(b) == {"UID:2"}
sync(a, b, status, partial_sync="revert")
await sync(a, b, status, partial_sync="revert")
assert len(status) == 1
assert items(a) == {"UID:2"}
assert items(b) == {"UID:2"}
@ -399,37 +422,39 @@ def test_partial_sync_revert():
# Check that updates get reverted
a.items[next(iter(a.items))] = ("foo", Item("UID:2\nupdated"))
assert items(a) == {"UID:2\nupdated"}
sync(a, b, status, partial_sync="revert")
await sync(a, b, status, partial_sync="revert")
assert len(status) == 1
assert items(a) == {"UID:2\nupdated"}
sync(a, b, status, partial_sync="revert")
await sync(a, b, status, partial_sync="revert")
assert items(a) == {"UID:2"}
# Check that deletions get reverted
a.items.clear()
sync(a, b, status, partial_sync="revert", force_delete=True)
sync(a, b, status, partial_sync="revert", force_delete=True)
await sync(a, b, status, partial_sync="revert", force_delete=True)
await sync(a, b, status, partial_sync="revert", force_delete=True)
assert items(a) == {"UID:2"}
@pytest.mark.parametrize("sync_inbetween", (True, False))
def test_ident_conflict(sync_inbetween):
@pytest.mark.asyncio
async def test_ident_conflict(sync_inbetween):
a = MemoryStorage()
b = MemoryStorage()
status = {}
href_a, etag_a = a.upload(Item("UID:aaa"))
href_b, etag_b = a.upload(Item("UID:bbb"))
href_a, etag_a = await a.upload(Item("UID:aaa"))
href_b, etag_b = await a.upload(Item("UID:bbb"))
if sync_inbetween:
sync(a, b, status)
await sync(a, b, status)
a.update(href_a, Item("UID:xxx"), etag_a)
a.update(href_b, Item("UID:xxx"), etag_b)
await a.update(href_a, Item("UID:xxx"), etag_a)
await a.update(href_b, Item("UID:xxx"), etag_b)
with pytest.raises(IdentConflict):
sync(a, b, status)
await sync(a, b, status)
def test_moved_href():
@pytest.mark.asyncio
async def test_moved_href():
"""
Concrete application: ppl_ stores contact aliases in filenames, which means
item's hrefs get changed. Vdirsyncer doesn't synchronize this data, but
@ -440,8 +465,8 @@ def test_moved_href():
a = MemoryStorage()
b = MemoryStorage()
status = {}
href, etag = a.upload(Item("UID:haha"))
sync(a, b, status)
href, etag = await a.upload(Item("UID:haha"))
await sync(a, b, status)
b.items["lol"] = b.items.pop("haha")
@ -451,7 +476,7 @@ def test_moved_href():
# No actual sync actions
a.delete = a.update = a.upload = b.delete = b.update = b.upload = blow_up
sync(a, b, status)
await sync(a, b, status)
assert len(status) == 1
assert items(a) == items(b) == {"UID:haha"}
assert status["haha"][1]["href"] == "lol"
@ -460,12 +485,13 @@ def test_moved_href():
# Further sync should be a noop. Not even prefetching should occur.
b.get_multi = blow_up
sync(a, b, status)
await sync(a, b, status)
assert old_status == status
assert items(a) == items(b) == {"UID:haha"}
def test_bogus_etag_change():
@pytest.mark.asyncio
async def test_bogus_etag_change():
"""Assert that sync algorithm is resilient against etag changes if content
didn\'t change.
@ -475,27 +501,33 @@ def test_bogus_etag_change():
a = MemoryStorage()
b = MemoryStorage()
status = {}
href_a, etag_a = a.upload(Item("UID:ASDASD"))
sync(a, b, status)
assert len(status) == len(list(a.list())) == len(list(b.list())) == 1
href_a, etag_a = await a.upload(Item("UID:ASDASD"))
await sync(a, b, status)
assert (
len(status)
== len(await aiostream.stream.list(a.list()))
== len(await aiostream.stream.list(b.list()))
== 1
)
((href_b, etag_b),) = b.list()
a.update(href_a, Item("UID:ASDASD"), etag_a)
b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b)
((href_b, etag_b),) = await aiostream.stream.list(b.list())
await a.update(href_a, Item("UID:ASDASD"), etag_a)
await b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b)
b.delete = b.update = b.upload = blow_up
sync(a, b, status)
await sync(a, b, status)
assert len(status) == 1
assert items(a) == items(b) == {"UID:ASDASD\nACTUALCHANGE:YES"}
def test_unicode_hrefs():
@pytest.mark.asyncio
async def test_unicode_hrefs():
a = MemoryStorage()
b = MemoryStorage()
status = {}
href, etag = a.upload(Item("UID:äää"))
sync(a, b, status)
href, etag = await a.upload(Item("UID:äää"))
await sync(a, b, status)
class ActionIntentionallyFailed(Exception):
@ -511,11 +543,12 @@ class SyncMachine(RuleBasedStateMachine):
Storage = Bundle("storage")
@rule(target=Storage, flaky_etags=st.booleans(), null_etag_on_upload=st.booleans())
@pytest.mark.asyncio
def newstorage(self, flaky_etags, null_etag_on_upload):
s = MemoryStorage()
if flaky_etags:
def get(href):
async def get(href):
old_etag, item = s.items[href]
etag = _random_string()
s.items[href] = etag, item
@ -526,8 +559,15 @@ class SyncMachine(RuleBasedStateMachine):
if null_etag_on_upload:
_old_upload = s.upload
_old_update = s.update
s.upload = lambda item: (_old_upload(item)[0], "NULL")
s.update = lambda h, i, e: _old_update(h, i, e) and "NULL"
async def upload(item):
return ((await _old_upload(item)))[0], "NULL"
async def update(href, item, etag):
return await _old_update(href, item, etag) and "NULL"
s.upload = upload
s.update = update
return s
@ -547,11 +587,11 @@ class SyncMachine(RuleBasedStateMachine):
_old_upload = s.upload
_old_update = s.update
def upload(item):
return _old_upload(item)[0], None
async def upload(item):
return (await _old_upload(item))[0], None
def update(href, item, etag):
_old_update(href, item, etag)
async def update(href, item, etag):
return await _old_update(href, item, etag)
s.upload = upload
s.update = update
@ -590,66 +630,73 @@ class SyncMachine(RuleBasedStateMachine):
with_error_callback,
partial_sync,
):
assume(a is not b)
old_items_a = items(a)
old_items_b = items(b)
async def inner():
assume(a is not b)
old_items_a = items(a)
old_items_b = items(b)
a.instance_name = "a"
b.instance_name = "b"
a.instance_name = "a"
b.instance_name = "b"
errors = []
errors = []
if with_error_callback:
error_callback = errors.append
else:
error_callback = None
if with_error_callback:
error_callback = errors.append
else:
error_callback = None
try:
# If one storage is read-only, double-sync because changes don't
# get reverted immediately.
for _ in range(2 if a.read_only or b.read_only else 1):
sync(
a,
b,
status,
force_delete=force_delete,
conflict_resolution=conflict_resolution,
error_callback=error_callback,
partial_sync=partial_sync,
try:
# If one storage is read-only, double-sync because changes don't
# get reverted immediately.
for _ in range(2 if a.read_only or b.read_only else 1):
await sync(
a,
b,
status,
force_delete=force_delete,
conflict_resolution=conflict_resolution,
error_callback=error_callback,
partial_sync=partial_sync,
)
for e in errors:
raise e
except PartialSync:
assert partial_sync == "error"
except ActionIntentionallyFailed:
pass
except BothReadOnly:
assert a.read_only and b.read_only
assume(False)
except StorageEmpty:
if force_delete:
raise
else:
not_a = not await aiostream.stream.list(a.list())
not_b = not await aiostream.stream.list(b.list())
assert not_a or not_b
else:
items_a = items(a)
items_b = items(b)
assert items_a == items_b or partial_sync == "ignore"
assert items_a == old_items_a or not a.read_only
assert items_b == old_items_b or not b.read_only
assert (
set(a.items) | set(b.items) == set(status)
or partial_sync == "ignore"
)
for e in errors:
raise e
except PartialSync:
assert partial_sync == "error"
except ActionIntentionallyFailed:
pass
except BothReadOnly:
assert a.read_only and b.read_only
assume(False)
except StorageEmpty:
if force_delete:
raise
else:
assert not list(a.list()) or not list(b.list())
else:
items_a = items(a)
items_b = items(b)
assert items_a == items_b or partial_sync == "ignore"
assert items_a == old_items_a or not a.read_only
assert items_b == old_items_b or not b.read_only
assert (
set(a.items) | set(b.items) == set(status) or partial_sync == "ignore"
)
asyncio.run(inner())
TestSyncMachine = SyncMachine.TestCase
@pytest.mark.parametrize("error_callback", [True, False])
def test_rollback(error_callback):
@pytest.mark.asyncio
async def test_rollback(error_callback):
a = MemoryStorage()
b = MemoryStorage()
status = {}
@ -662,7 +709,7 @@ def test_rollback(error_callback):
if error_callback:
errors = []
sync(
await sync(
a,
b,
status=status,
@ -677,16 +724,22 @@ def test_rollback(error_callback):
assert status["1"]
else:
with pytest.raises(ActionIntentionallyFailed):
sync(a, b, status=status, conflict_resolution="a wins")
await sync(a, b, status=status, conflict_resolution="a wins")
def test_duplicate_hrefs():
@pytest.mark.asyncio
async def test_duplicate_hrefs():
a = MemoryStorage()
b = MemoryStorage()
a.list = lambda: [("a", "a")] * 3
async def fake_list():
for item in [("a", "a")] * 3:
yield item
a.list = fake_list
a.items["a"] = ("a", Item("UID:a"))
status = {}
sync(a, b, status)
await sync(a, b, status)
with pytest.raises(AssertionError):
sync(a, b, status)
await sync(a, b, status)

View file

@ -12,105 +12,122 @@ from vdirsyncer.storage.base import normalize_meta_value
from vdirsyncer.storage.memory import MemoryStorage
def test_irrelevant_status():
@pytest.mark.asyncio
async def test_irrelevant_status():
a = MemoryStorage()
b = MemoryStorage()
status = {"foo": "bar"}
metasync(a, b, status, keys=())
await metasync(a, b, status, keys=())
assert not status
def test_basic(monkeypatch):
@pytest.mark.asyncio
async def test_basic(monkeypatch):
a = MemoryStorage()
b = MemoryStorage()
status = {}
a.set_meta("foo", "bar")
metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == "bar"
await a.set_meta("foo", "bar")
await metasync(a, b, status, keys=["foo"])
assert await a.get_meta("foo") == await b.get_meta("foo") == "bar"
a.set_meta("foo", "baz")
metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == "baz"
await a.set_meta("foo", "baz")
await metasync(a, b, status, keys=["foo"])
assert await a.get_meta("foo") == await b.get_meta("foo") == "baz"
monkeypatch.setattr(a, "set_meta", blow_up)
monkeypatch.setattr(b, "set_meta", blow_up)
metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == "baz"
await metasync(a, b, status, keys=["foo"])
assert await a.get_meta("foo") == await b.get_meta("foo") == "baz"
monkeypatch.undo()
monkeypatch.undo()
b.set_meta("foo", None)
metasync(a, b, status, keys=["foo"])
assert not a.get_meta("foo") and not b.get_meta("foo")
await b.set_meta("foo", None)
await metasync(a, b, status, keys=["foo"])
assert not await a.get_meta("foo") and not await b.get_meta("foo")
@pytest.fixture
def conflict_state(request):
@pytest.mark.asyncio
async def conflict_state(request, event_loop):
a = MemoryStorage()
b = MemoryStorage()
status = {}
a.set_meta("foo", "bar")
b.set_meta("foo", "baz")
await a.set_meta("foo", "bar")
await b.set_meta("foo", "baz")
def cleanup():
assert a.get_meta("foo") == "bar"
assert b.get_meta("foo") == "baz"
assert not status
async def do_cleanup():
assert await a.get_meta("foo") == "bar"
assert await b.get_meta("foo") == "baz"
assert not status
event_loop.run_until_complete(do_cleanup())
request.addfinalizer(cleanup)
return a, b, status
def test_conflict(conflict_state):
@pytest.mark.asyncio
async def test_conflict(conflict_state):
a, b, status = conflict_state
with pytest.raises(MetaSyncConflict):
metasync(a, b, status, keys=["foo"])
await metasync(a, b, status, keys=["foo"])
def test_invalid_conflict_resolution(conflict_state):
@pytest.mark.asyncio
async def test_invalid_conflict_resolution(conflict_state):
a, b, status = conflict_state
with pytest.raises(UserError) as excinfo:
metasync(a, b, status, keys=["foo"], conflict_resolution="foo")
await metasync(a, b, status, keys=["foo"], conflict_resolution="foo")
assert "Invalid conflict resolution setting" in str(excinfo.value)
def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch):
@pytest.mark.asyncio
async def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch):
a, b, status = conflict_state
warnings = []
monkeypatch.setattr(logger, "warning", warnings.append)
with pytest.raises(MetaSyncConflict):
metasync(a, b, status, keys=["foo"], conflict_resolution=lambda *a, **kw: None)
await metasync(
a,
b,
status,
keys=["foo"],
conflict_resolution=lambda *a, **kw: None,
)
assert warnings == ["Custom commands don't work on metasync."]
def test_conflict_same_content():
@pytest.mark.asyncio
async def test_conflict_same_content():
a = MemoryStorage()
b = MemoryStorage()
status = {}
a.set_meta("foo", "bar")
b.set_meta("foo", "bar")
await a.set_meta("foo", "bar")
await b.set_meta("foo", "bar")
metasync(a, b, status, keys=["foo"])
assert a.get_meta("foo") == b.get_meta("foo") == status["foo"] == "bar"
await metasync(a, b, status, keys=["foo"])
assert await a.get_meta("foo") == await b.get_meta("foo") == status["foo"] == "bar"
@pytest.mark.parametrize("wins", "ab")
def test_conflict_x_wins(wins):
@pytest.mark.asyncio
async def test_conflict_x_wins(wins):
a = MemoryStorage()
b = MemoryStorage()
status = {}
a.set_meta("foo", "bar")
b.set_meta("foo", "baz")
await a.set_meta("foo", "bar")
await b.set_meta("foo", "baz")
metasync(
await metasync(
a,
b,
status,
@ -119,8 +136,8 @@ def test_conflict_x_wins(wins):
)
assert (
a.get_meta("foo")
== b.get_meta("foo")
await a.get_meta("foo")
== await b.get_meta("foo")
== status["foo"]
== ("bar" if wins == "a" else "baz")
)
@ -148,7 +165,8 @@ metadata = st.dictionaries(keys, values)
keys={"0"},
conflict_resolution="a wins",
)
def test_fuzzing(a, b, status, keys, conflict_resolution):
@pytest.mark.asyncio
async def test_fuzzing(a, b, status, keys, conflict_resolution):
def _get_storage(m, instance_name):
s = MemoryStorage(instance_name=instance_name)
s.metadata = m
@ -159,13 +177,13 @@ def test_fuzzing(a, b, status, keys, conflict_resolution):
winning_storage = a if conflict_resolution == "a wins" else b
expected_values = {
key: winning_storage.get_meta(key) for key in keys if key not in status
key: await winning_storage.get_meta(key) for key in keys if key not in status
}
metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution)
await metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution)
for key in keys:
s = status.get(key, "")
assert a.get_meta(key) == b.get_meta(key) == s
assert await a.get_meta(key) == await b.get_meta(key) == s
if expected_values.get(key, "") and s:
assert s == expected_values[key]

View file

@ -1,3 +1,4 @@
import aiostream
import pytest
from hypothesis import given
from hypothesis import HealthCheck
@ -15,37 +16,42 @@ from vdirsyncer.vobject import Item
@given(uid=uid_strategy)
# Using the random module for UIDs:
@settings(suppress_health_check=HealthCheck.all())
def test_repair_uids(uid):
@pytest.mark.asyncio
async def test_repair_uids(uid):
s = MemoryStorage()
s.items = {
"one": ("asdf", Item(f"BEGIN:VCARD\nFN:Hans\nUID:{uid}\nEND:VCARD")),
"two": ("asdf", Item(f"BEGIN:VCARD\nFN:Peppi\nUID:{uid}\nEND:VCARD")),
}
uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()]
uid1, uid2 = [(await s.get(href))[0].uid async for href, etag in s.list()]
assert uid1 == uid2
repair_storage(s, repair_unsafe_uid=False)
await repair_storage(s, repair_unsafe_uid=False)
uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()]
uid1, uid2 = [
(await s.get(href))[0].uid
for href, etag in await aiostream.stream.list(s.list())
]
assert uid1 != uid2
@given(uid=uid_strategy.filter(lambda x: not href_safe(x)))
# Using the random module for UIDs:
@settings(suppress_health_check=HealthCheck.all())
def test_repair_unsafe_uids(uid):
@pytest.mark.asyncio
async def test_repair_unsafe_uids(uid):
s = MemoryStorage()
item = Item(f"BEGIN:VCARD\nUID:{uid}\nEND:VCARD")
href, etag = s.upload(item)
assert s.get(href)[0].uid == uid
href, etag = await s.upload(item)
assert (await s.get(href))[0].uid == uid
assert not href_safe(uid)
repair_storage(s, repair_unsafe_uid=True)
await repair_storage(s, repair_unsafe_uid=True)
new_href = list(s.list())[0][0]
new_href = (await aiostream.stream.list(s.list()))[0][0]
assert href_safe(new_href)
newuid = s.get(new_href)[0].uid
newuid = (await s.get(new_href))[0].uid
assert href_safe(newuid)

View file

@ -1,8 +1,10 @@
import asyncio
import functools
import json
import logging
import sys
import aiohttp
import click
import click_log
@ -124,17 +126,26 @@ def sync(ctx, collections, force_delete):
"""
from .tasks import prepare_pair, sync_collection
for pair_name, collections in collections:
for collection, config in prepare_pair(
pair_name=pair_name,
collections=collections,
config=ctx.config,
):
sync_collection(
collection=collection,
general=config,
force_delete=force_delete,
)
async def main(collections):
conn = aiohttp.TCPConnector(limit_per_host=16)
for pair_name, collections in collections:
async for collection, config in prepare_pair(
pair_name=pair_name,
collections=collections,
config=ctx.config,
connector=conn,
):
await sync_collection(
collection=collection,
general=config,
force_delete=force_delete,
connector=conn,
)
await conn.close()
asyncio.run(main(collections))
@app.command()
@ -149,13 +160,31 @@ def metasync(ctx, collections):
"""
from .tasks import prepare_pair, metasync_collection
for pair_name, collections in collections:
for collection, config in prepare_pair(
pair_name=pair_name,
collections=collections,
config=ctx.config,
):
metasync_collection(collection=collection, general=config)
async def main(collections):
conn = aiohttp.TCPConnector(limit_per_host=16)
for pair_name, collections in collections:
collections = prepare_pair(
pair_name=pair_name,
collections=collections,
config=ctx.config,
connector=conn,
)
await asyncio.gather(
*[
metasync_collection(
collection=collection,
general=config,
connector=conn,
)
async for collection, config in collections
]
)
await conn.close()
asyncio.run(main(collections))
@app.command()
@ -178,15 +207,23 @@ def discover(ctx, pairs, list):
config = ctx.config
for pair_name in pairs or config.pairs:
pair = config.get_pair(pair_name)
async def main():
conn = aiohttp.TCPConnector(limit_per_host=16)
discover_collections(
status_path=config.general["status_path"],
pair=pair,
from_cache=False,
list_collections=list,
)
for pair_name in pairs or config.pairs:
pair = config.get_pair(pair_name)
await discover_collections(
status_path=config.general["status_path"],
pair=pair,
from_cache=False,
list_collections=list,
connector=conn,
)
await conn.close()
asyncio.run(main())
@app.command()
@ -225,7 +262,18 @@ def repair(ctx, collection, repair_unsafe_uid):
"turn off other client's synchronization features."
)
click.confirm("Do you want to continue?", abort=True)
repair_collection(ctx.config, collection, repair_unsafe_uid=repair_unsafe_uid)
async def main():
conn = aiohttp.TCPConnector(limit_per_host=16)
await repair_collection(
ctx.config,
collection,
repair_unsafe_uid=repair_unsafe_uid,
connector=conn,
)
await conn.close()
asyncio.run(main())
@app.command()

View file

@ -1,10 +1,13 @@
import asyncio
import hashlib
import json
import logging
import sys
import aiohttp
import aiostream
from .. import exceptions
from ..utils import cached_property
from .utils import handle_collection_not_found
from .utils import handle_storage_init_error
from .utils import load_status
@ -35,7 +38,14 @@ def _get_collections_cache_key(pair):
return m.hexdigest()
def collections_for_pair(status_path, pair, from_cache=True, list_collections=False):
async def collections_for_pair(
status_path,
pair,
from_cache=True,
list_collections=False,
*,
connector: aiohttp.TCPConnector,
):
"""Determine all configured collections for a given pair. Takes care of
shortcut expansion and result caching.
@ -67,16 +77,24 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa
logger.info("Discovering collections for pair {}".format(pair.name))
a_discovered = _DiscoverResult(pair.config_a)
b_discovered = _DiscoverResult(pair.config_b)
a_discovered = _DiscoverResult(pair.config_a, connector=connector)
b_discovered = _DiscoverResult(pair.config_b, connector=connector)
if list_collections:
_print_collections(pair.config_a["instance_name"], a_discovered.get_self)
_print_collections(pair.config_b["instance_name"], b_discovered.get_self)
await _print_collections(
pair.config_a["instance_name"],
a_discovered.get_self,
connector=connector,
)
await _print_collections(
pair.config_b["instance_name"],
b_discovered.get_self,
connector=connector,
)
# We have to use a list here because the special None/null value would get
# mangled to string (because JSON objects always have string keys).
rv = list(
rv = await aiostream.stream.list(
expand_collections(
shortcuts=pair.collections,
config_a=pair.config_a,
@ -87,7 +105,7 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa
)
)
_sanity_check_collections(rv)
await _sanity_check_collections(rv, connector=connector)
save_status(
status_path,
@ -103,10 +121,14 @@ def collections_for_pair(status_path, pair, from_cache=True, list_collections=Fa
return rv
def _sanity_check_collections(collections):
async def _sanity_check_collections(collections, *, connector):
tasks = []
for _, (a_args, b_args) in collections:
storage_instance_from_config(a_args)
storage_instance_from_config(b_args)
tasks.append(storage_instance_from_config(a_args, connector=connector))
tasks.append(storage_instance_from_config(b_args, connector=connector))
await asyncio.gather(*tasks)
def _compress_collections_cache(collections, config_a, config_b):
@ -134,17 +156,28 @@ def _expand_collections_cache(collections, config_a, config_b):
class _DiscoverResult:
def __init__(self, config):
def __init__(self, config, *, connector):
self._cls, _ = storage_class_from_config(config)
self._config = config
def get_self(self):
if self._cls.__name__ in [
"CardDAVStorage",
"CalDAVStorage",
"GoogleCalendarStorage",
]:
assert connector is not None
config["connector"] = connector
self._config = config
self._discovered = None
async def get_self(self):
if self._discovered is None:
self._discovered = await self._discover()
return self._discovered
@cached_property
def _discovered(self):
async def _discover(self):
try:
discovered = list(self._cls.discover(**self._config))
discovered = await aiostream.stream.list(self._cls.discover(**self._config))
except NotImplementedError:
return {}
except Exception:
@ -158,7 +191,7 @@ class _DiscoverResult:
return rv
def expand_collections(
async def expand_collections(
shortcuts,
config_a,
config_b,
@ -173,9 +206,9 @@ def expand_collections(
for shortcut in shortcuts:
if shortcut == "from a":
collections = get_a_discovered()
collections = await get_a_discovered()
elif shortcut == "from b":
collections = get_b_discovered()
collections = await get_b_discovered()
else:
collections = [shortcut]
@ -189,17 +222,23 @@ def expand_collections(
continue
handled_collections.add(collection)
a_args = _collection_from_discovered(
get_a_discovered, collection_a, config_a, _handle_collection_not_found
a_args = await _collection_from_discovered(
get_a_discovered,
collection_a,
config_a,
_handle_collection_not_found,
)
b_args = _collection_from_discovered(
get_b_discovered, collection_b, config_b, _handle_collection_not_found
b_args = await _collection_from_discovered(
get_b_discovered,
collection_b,
config_b,
_handle_collection_not_found,
)
yield collection, (a_args, b_args)
def _collection_from_discovered(
async def _collection_from_discovered(
get_discovered, collection, config, _handle_collection_not_found
):
if collection is None:
@ -208,14 +247,19 @@ def _collection_from_discovered(
return args
try:
return get_discovered()[collection]
return (await get_discovered())[collection]
except KeyError:
return _handle_collection_not_found(config, collection)
return await _handle_collection_not_found(config, collection)
def _print_collections(instance_name, get_discovered):
async def _print_collections(
instance_name: str,
get_discovered,
*,
connector: aiohttp.TCPConnector,
):
try:
discovered = get_discovered()
discovered = await get_discovered()
except exceptions.UserError:
raise
except Exception:
@ -238,8 +282,12 @@ def _print_collections(instance_name, get_discovered):
args["instance_name"] = instance_name
try:
storage = storage_instance_from_config(args, create=False)
displayname = storage.get_meta("displayname")
storage = await storage_instance_from_config(
args,
create=False,
connector=connector,
)
displayname = await storage.get_meta("displayname")
except Exception:
displayname = ""

View file

@ -1,5 +1,7 @@
import json
import aiohttp
from .. import exceptions
from .. import sync
from .config import CollectionConfig
@ -15,11 +17,15 @@ from .utils import manage_sync_status
from .utils import save_status
def prepare_pair(pair_name, collections, config):
async def prepare_pair(pair_name, collections, config, *, connector):
pair = config.get_pair(pair_name)
all_collections = dict(
collections_for_pair(status_path=config.general["status_path"], pair=pair)
await collections_for_pair(
status_path=config.general["status_path"],
pair=pair,
connector=connector,
)
)
for collection_name in collections or all_collections:
@ -37,15 +43,21 @@ def prepare_pair(pair_name, collections, config):
yield collection, config.general
def sync_collection(collection, general, force_delete):
async def sync_collection(
collection,
general,
force_delete,
*,
connector: aiohttp.TCPConnector,
):
pair = collection.pair
status_name = get_status_name(pair.name, collection.name)
try:
cli_logger.info(f"Syncing {status_name}")
a = storage_instance_from_config(collection.config_a)
b = storage_instance_from_config(collection.config_b)
a = await storage_instance_from_config(collection.config_a, connector=connector)
b = await storage_instance_from_config(collection.config_b, connector=connector)
sync_failed = False
@ -57,7 +69,7 @@ def sync_collection(collection, general, force_delete):
with manage_sync_status(
general["status_path"], pair.name, collection.name
) as status:
sync.sync(
await sync.sync(
a,
b,
status,
@ -76,9 +88,9 @@ def sync_collection(collection, general, force_delete):
raise JobFailed()
def discover_collections(pair, **kwargs):
rv = collections_for_pair(pair=pair, **kwargs)
collections = list(c for c, (a, b) in rv)
async def discover_collections(pair, **kwargs):
rv = await collections_for_pair(pair=pair, **kwargs)
collections = [c for c, (a, b) in rv]
if collections == [None]:
collections = None
cli_logger.info(
@ -86,7 +98,13 @@ def discover_collections(pair, **kwargs):
)
def repair_collection(config, collection, repair_unsafe_uid):
async def repair_collection(
config,
collection,
repair_unsafe_uid,
*,
connector: aiohttp.TCPConnector,
):
from ..repair import repair_storage
storage_name, collection = collection, None
@ -99,7 +117,7 @@ def repair_collection(config, collection, repair_unsafe_uid):
if collection is not None:
cli_logger.info("Discovering collections (skipping cache).")
cls, config = storage_class_from_config(config)
for config in cls.discover(**config):
async for config in cls.discover(**config):
if config["collection"] == collection:
break
else:
@ -110,14 +128,14 @@ def repair_collection(config, collection, repair_unsafe_uid):
)
config["type"] = storage_type
storage = storage_instance_from_config(config)
storage = await storage_instance_from_config(config, connector=connector)
cli_logger.info(f"Repairing {storage_name}/{collection}")
cli_logger.warning("Make sure no other program is talking to the server.")
repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid)
await repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid)
def metasync_collection(collection, general):
async def metasync_collection(collection, general, *, connector: aiohttp.TCPConnector):
from ..metasync import metasync
pair = collection.pair
@ -133,10 +151,10 @@ def metasync_collection(collection, general):
or {}
)
a = storage_instance_from_config(collection.config_a)
b = storage_instance_from_config(collection.config_b)
a = await storage_instance_from_config(collection.config_a, connector=connector)
b = await storage_instance_from_config(collection.config_b, connector=connector)
metasync(
await metasync(
a,
b,
status,

View file

@ -5,6 +5,7 @@ import json
import os
import sys
import aiohttp
import click
from atomicwrites import atomic_write
@ -252,22 +253,37 @@ def storage_class_from_config(config):
return cls, config
def storage_instance_from_config(config, create=True):
async def storage_instance_from_config(
config,
create=True,
*,
connector: aiohttp.TCPConnector,
):
"""
:param config: A configuration dictionary to pass as kwargs to the class
corresponding to config['type']
"""
from vdirsyncer.storage.dav import DAVStorage
from vdirsyncer.storage.http import HttpStorage
cls, new_config = storage_class_from_config(config)
if issubclass(cls, DAVStorage) or issubclass(cls, HttpStorage):
assert connector is not None # FIXME: hack?
new_config["connector"] = connector
try:
return cls(**new_config)
except exceptions.CollectionNotFound as e:
if create:
config = handle_collection_not_found(
config = await handle_collection_not_found(
config, config.get("collection", None), e=str(e)
)
return storage_instance_from_config(config, create=False)
return await storage_instance_from_config(
config,
create=False,
connector=connector,
)
else:
raise
except Exception:
@ -319,7 +335,7 @@ def assert_permissions(path, wanted):
os.chmod(path, wanted)
def handle_collection_not_found(config, collection, e=None):
async def handle_collection_not_found(config, collection, e=None):
storage_name = config.get("instance_name", None)
cli_logger.warning(
@ -333,7 +349,7 @@ def handle_collection_not_found(config, collection, e=None):
cls, config = storage_class_from_config(config)
config["collection"] = collection
try:
args = cls.create_collection(**config)
args = await cls.create_collection(**config)
args["type"] = storage_type
return args
except NotImplementedError as e:

View file

@ -1,6 +1,6 @@
import logging
import requests
import aiohttp
from . import __version__
from . import DOCS_HOME
@ -99,23 +99,8 @@ def prepare_client_cert(cert):
return cert
def _install_fingerprint_adapter(session, fingerprint):
prefix = "https://"
try:
from requests_toolbelt.adapters.fingerprint import FingerprintAdapter
except ImportError:
raise RuntimeError(
"`verify_fingerprint` can only be used with "
"requests-toolbelt versions >= 0.4.0"
)
if not isinstance(session.adapters[prefix], FingerprintAdapter):
fingerprint_adapter = FingerprintAdapter(fingerprint)
session.mount(prefix, fingerprint_adapter)
def request(
method, url, session=None, latin1_fallback=True, verify_fingerprint=None, **kwargs
async def request(
method, url, session, latin1_fallback=True, verify_fingerprint=None, **kwargs
):
"""
Wrapper method for requests, to ease logging and mocking. Parameters should
@ -132,16 +117,20 @@ def request(
https://github.com/kennethreitz/requests/issues/2042
"""
if session is None:
session = requests.Session()
if verify_fingerprint is not None:
_install_fingerprint_adapter(session, verify_fingerprint)
ssl = aiohttp.Fingerprint(bytes.fromhex(verify_fingerprint.replace(":", "")))
kwargs.pop("verify", None)
elif kwargs.pop("verify", None) is False:
ssl = False
else:
ssl = None # TODO XXX: Check all possible values for this
session.hooks = {"response": _fix_redirects}
func = session.request
# TODO: rewrite using
# https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing
logger.debug("=" * 20)
logger.debug(f"{method} {url}")
logger.debug(kwargs.get("headers", {}))
@ -150,7 +139,14 @@ def request(
assert isinstance(kwargs.get("data", b""), bytes)
r = func(method, url, **kwargs)
kwargs.pop("cert", None) # TODO XXX FIXME!
# Hacks to translate API
if auth := kwargs.pop("auth", None):
kwargs["auth"] = aiohttp.BasicAuth(*auth)
r = func(method, url, ssl=ssl, **kwargs)
r = await r
# See https://github.com/kennethreitz/requests/issues/2042
content_type = r.headers.get("Content-Type", "")
@ -162,13 +158,13 @@ def request(
logger.debug("Removing latin1 fallback")
r.encoding = None
logger.debug(r.status_code)
logger.debug(r.status)
logger.debug(r.headers)
logger.debug(r.content)
if r.status_code == 412:
if r.status == 412:
raise exceptions.PreconditionFailed(r.reason)
if r.status_code in (404, 410):
if r.status in (404, 410):
raise exceptions.NotFoundError(r.reason)
r.raise_for_status()

View file

@ -14,24 +14,24 @@ class MetaSyncConflict(MetaSyncError):
key = None
def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
def _a_to_b():
async def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
async def _a_to_b():
logger.info(f"Copying {key} to {storage_b}")
storage_b.set_meta(key, a)
await storage_b.set_meta(key, a)
status[key] = a
def _b_to_a():
async def _b_to_a():
logger.info(f"Copying {key} to {storage_a}")
storage_a.set_meta(key, b)
await storage_a.set_meta(key, b)
status[key] = b
def _resolve_conflict():
async def _resolve_conflict():
if a == b:
status[key] = a
elif conflict_resolution == "a wins":
_a_to_b()
await _a_to_b()
elif conflict_resolution == "b wins":
_b_to_a()
await _b_to_a()
else:
if callable(conflict_resolution):
logger.warning("Custom commands don't work on metasync.")
@ -40,8 +40,8 @@ def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
raise MetaSyncConflict(key)
for key in keys:
a = storage_a.get_meta(key)
b = storage_b.get_meta(key)
a = await storage_a.get_meta(key)
b = await storage_b.get_meta(key)
s = normalize_meta_value(status.get(key))
logger.debug(f"Key: {key}")
logger.debug(f"A: {a}")
@ -49,11 +49,11 @@ def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
logger.debug(f"S: {s}")
if a != s and b != s:
_resolve_conflict()
await _resolve_conflict()
elif a != s and b == s:
_a_to_b()
await _a_to_b()
elif a == s and b != s:
_b_to_a()
await _b_to_a()
else:
assert a == b

View file

@ -1,6 +1,8 @@
import logging
from os.path import basename
import aiostream
from .utils import generate_href
from .utils import href_safe
@ -11,11 +13,11 @@ class IrreparableItem(Exception):
pass
def repair_storage(storage, repair_unsafe_uid):
async def repair_storage(storage, repair_unsafe_uid):
seen_uids = set()
all_hrefs = list(storage.list())
all_hrefs = await aiostream.stream.list(storage.list())
for i, (href, _) in enumerate(all_hrefs):
item, etag = storage.get(href)
item, etag = await storage.get(href)
logger.info("[{}/{}] Processing {}".format(i, len(all_hrefs), href))
try:
@ -32,10 +34,10 @@ def repair_storage(storage, repair_unsafe_uid):
seen_uids.add(new_item.uid)
if new_item.raw != item.raw:
if new_item.uid != item.uid:
storage.upload(new_item)
storage.delete(href, etag)
await storage.upload(new_item)
await storage.delete(href, etag)
else:
storage.update(href, new_item, etag)
await storage.update(href, new_item, etag)
def repair_item(href, item, seen_uids, repair_unsafe_uid):

View file

@ -7,10 +7,10 @@ from ..utils import uniq
def mutating_storage_method(f):
@functools.wraps(f)
def inner(self, *args, **kwargs):
async def inner(self, *args, **kwargs):
if self.read_only:
raise exceptions.ReadOnlyError("This storage is read-only.")
return f(self, *args, **kwargs)
return await f(self, *args, **kwargs)
return inner
@ -77,7 +77,7 @@ class Storage(metaclass=StorageMeta):
self.collection = collection
@classmethod
def discover(cls, **kwargs):
async def discover(cls, **kwargs):
"""Discover collections given a basepath or -URL to many collections.
:param **kwargs: Keyword arguments to additionally pass to the storage
@ -92,10 +92,12 @@ class Storage(metaclass=StorageMeta):
from the last segment of a URL or filesystem path.
"""
if False:
yield # Needs to be an async generator
raise NotImplementedError()
@classmethod
def create_collection(cls, collection, **kwargs):
async def create_collection(cls, collection, **kwargs):
"""
Create the specified collection and return the new arguments.
@ -118,13 +120,13 @@ class Storage(metaclass=StorageMeta):
{x: getattr(self, x) for x in self._repr_attributes},
)
def list(self):
async def list(self):
"""
:returns: list of (href, etag)
"""
raise NotImplementedError()
def get(self, href):
async def get(self, href):
"""Fetch a single item.
:param href: href to fetch
@ -134,7 +136,7 @@ class Storage(metaclass=StorageMeta):
"""
raise NotImplementedError()
def get_multi(self, hrefs):
async def get_multi(self, hrefs):
"""Fetch multiple items. Duplicate hrefs must be ignored.
Functionally similar to :py:meth:`get`, but might bring performance
@ -146,22 +148,22 @@ class Storage(metaclass=StorageMeta):
:returns: iterable of (href, item, etag)
"""
for href in uniq(hrefs):
item, etag = self.get(href)
item, etag = await self.get(href)
yield href, item, etag
def has(self, href):
async def has(self, href):
"""Check if an item exists by its href.
:returns: True or False
"""
try:
self.get(href)
await self.get(href)
except exceptions.PreconditionFailed:
return False
else:
return True
def upload(self, item):
async def upload(self, item):
"""Upload a new item.
In cases where the new etag cannot be atomically determined (i.e. in
@ -176,7 +178,7 @@ class Storage(metaclass=StorageMeta):
"""
raise NotImplementedError()
def update(self, href, item, etag):
async def update(self, href, item, etag):
"""Update an item.
The etag may be none in some cases, see `upload`.
@ -189,7 +191,7 @@ class Storage(metaclass=StorageMeta):
"""
raise NotImplementedError()
def delete(self, href, etag):
async def delete(self, href, etag):
"""Delete an item by href.
:raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` when item has
@ -197,8 +199,8 @@ class Storage(metaclass=StorageMeta):
"""
raise NotImplementedError()
@contextlib.contextmanager
def at_once(self):
@contextlib.asynccontextmanager
async def at_once(self):
"""A contextmanager that buffers all writes.
Essentially, this::
@ -217,7 +219,7 @@ class Storage(metaclass=StorageMeta):
"""
yield
def get_meta(self, key):
async def get_meta(self, key):
"""Get metadata value for collection/storage.
See the vdir specification for the keys that *have* to be accepted.
@ -228,7 +230,7 @@ class Storage(metaclass=StorageMeta):
raise NotImplementedError("This storage does not support metadata.")
def set_meta(self, key, value):
async def set_meta(self, key, value):
"""Get metadata value for collection/storage.
:param key: The metadata key.

View file

@ -5,8 +5,8 @@ import xml.etree.ElementTree as etree
from inspect import getfullargspec
from inspect import signature
import requests
from requests.exceptions import HTTPError
import aiohttp
import aiostream
from .. import exceptions
from .. import http
@ -18,6 +18,7 @@ from ..http import USERAGENT
from ..vobject import Item
from .base import normalize_meta_value
from .base import Storage
from vdirsyncer.exceptions import Error
dav_logger = logging.getLogger(__name__)
@ -44,10 +45,10 @@ def _contains_quoted_reserved_chars(x):
return False
def _assert_multistatus_success(r):
async def _assert_multistatus_success(r):
# Xandikos returns a multistatus on PUT.
try:
root = _parse_xml(r.content)
root = _parse_xml(await r.content.read())
except InvalidXMLResponse:
return
for status in root.findall(".//{DAV:}status"):
@ -57,7 +58,7 @@ def _assert_multistatus_success(r):
except (ValueError, IndexError):
continue
if st < 200 or st >= 400:
raise HTTPError(f"Server error: {st}")
raise Error(f"Server error: {st}")
def _normalize_href(base, href):
@ -169,14 +170,14 @@ class Discover:
_, collection = url.rstrip("/").rsplit("/", 1)
return urlparse.unquote(collection)
def find_principal(self):
async def find_principal(self):
try:
return self._find_principal_impl("")
except (HTTPError, exceptions.Error):
return await self._find_principal_impl("")
except (aiohttp.ClientResponseError, exceptions.Error):
dav_logger.debug("Trying out well-known URI")
return self._find_principal_impl(self._well_known_uri)
return await self._find_principal_impl(self._well_known_uri)
def _find_principal_impl(self, url):
async def _find_principal_impl(self, url):
headers = self.session.get_default_headers()
headers["Depth"] = "0"
body = b"""
@ -187,9 +188,14 @@ class Discover:
</propfind>
"""
response = self.session.request("PROPFIND", url, headers=headers, data=body)
response = await self.session.request(
"PROPFIND",
url,
headers=headers,
data=body,
)
root = _parse_xml(response.content)
root = _parse_xml(await response.content.read())
rv = root.find(".//{DAV:}current-user-principal/{DAV:}href")
if rv is None:
# This is for servers that don't support current-user-principal
@ -201,34 +207,37 @@ class Discover:
)
)
return response.url
return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/"
return urlparse.urljoin(str(response.url), rv.text).rstrip("/") + "/"
def find_home(self):
url = self.find_principal()
async def find_home(self):
url = await self.find_principal()
headers = self.session.get_default_headers()
headers["Depth"] = "0"
response = self.session.request(
response = await self.session.request(
"PROPFIND", url, headers=headers, data=self._homeset_xml
)
root = etree.fromstring(response.content)
root = etree.fromstring(await response.content.read())
# Better don't do string formatting here, because of XML namespaces
rv = root.find(".//" + self._homeset_tag + "/{DAV:}href")
if rv is None:
raise InvalidXMLResponse("Couldn't find home-set.")
return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/"
return urlparse.urljoin(str(response.url), rv.text).rstrip("/") + "/"
def find_collections(self):
async def find_collections(self):
rv = None
try:
rv = list(self._find_collections_impl(""))
except (HTTPError, exceptions.Error):
rv = await aiostream.stream.list(self._find_collections_impl(""))
except (aiohttp.ClientResponseError, exceptions.Error):
pass
if rv:
return rv
dav_logger.debug("Given URL is not a homeset URL")
return self._find_collections_impl(self.find_home())
return await aiostream.stream.list(
self._find_collections_impl(await self.find_home())
)
def _check_collection_resource_type(self, response):
if self._resourcetype is None:
@ -245,13 +254,13 @@ class Discover:
return False
return True
def _find_collections_impl(self, url):
async def _find_collections_impl(self, url):
headers = self.session.get_default_headers()
headers["Depth"] = "1"
r = self.session.request(
r = await self.session.request(
"PROPFIND", url, headers=headers, data=self._collection_xml
)
root = _parse_xml(r.content)
root = _parse_xml(await r.content.read())
done = set()
for response in root.findall("{DAV:}response"):
if not self._check_collection_resource_type(response):
@ -260,33 +269,33 @@ class Discover:
href = response.find("{DAV:}href")
if href is None:
raise InvalidXMLResponse("Missing href tag for collection " "props.")
href = urlparse.urljoin(r.url, href.text)
href = urlparse.urljoin(str(r.url), href.text)
if href not in done:
done.add(href)
yield {"href": href}
def discover(self):
for c in self.find_collections():
async def discover(self):
for c in await self.find_collections():
url = c["href"]
collection = self._get_collection_from_url(url)
storage_args = dict(self.kwargs)
storage_args.update({"url": url, "collection": collection})
yield storage_args
def create(self, collection):
async def create(self, collection):
if collection is None:
collection = self._get_collection_from_url(self.kwargs["url"])
for c in self.discover():
async for c in self.discover():
if c["collection"] == collection:
return c
home = self.find_home()
home = await self.find_home()
url = urlparse.urljoin(home, urlparse.quote(collection, "/@"))
try:
url = self._create_collection_impl(url)
except HTTPError as e:
url = await self._create_collection_impl(url)
except (aiohttp.ClientResponseError, Error) as e:
raise NotImplementedError(e)
else:
rv = dict(self.kwargs)
@ -294,7 +303,7 @@ class Discover:
rv["url"] = url
return rv
def _create_collection_impl(self, url):
async def _create_collection_impl(self, url):
data = """<?xml version="1.0" encoding="utf-8" ?>
<mkcol xmlns="DAV:">
<set>
@ -312,13 +321,13 @@ class Discover:
"utf-8"
)
response = self.session.request(
response = await self.session.request(
"MKCOL",
url,
data=data,
headers=self.session.get_default_headers(),
)
return response.url
return str(response.url)
class CalDiscover(Discover):
@ -350,14 +359,18 @@ class CardDiscover(Discover):
class DAVSession:
"""
A helper class to connect to DAV servers.
"""
"""A helper class to connect to DAV servers."""
connector: aiohttp.BaseConnector
@classmethod
def init_and_remaining_args(cls, **kwargs):
def is_arg(k):
"""Return true if ``k`` is an argument of ``cls.__init__``."""
return k in argspec.args or k in argspec.kwonlyargs
argspec = getfullargspec(cls.__init__)
self_args, remainder = utils.split_dict(kwargs, argspec.args.__contains__)
self_args, remainder = utils.split_dict(kwargs, is_arg)
return cls(**self_args), remainder
@ -371,6 +384,8 @@ class DAVSession:
useragent=USERAGENT,
verify_fingerprint=None,
auth_cert=None,
*,
connector: aiohttp.BaseConnector,
):
self._settings = {
"cert": prepare_client_cert(auth_cert),
@ -380,21 +395,28 @@ class DAVSession:
self.useragent = useragent
self.url = url.rstrip("/") + "/"
self._session = requests.session()
self.connector = connector
@utils.cached_property
def parsed_url(self):
return urlparse.urlparse(self.url)
def request(self, method, path, **kwargs):
async def request(self, method, path, **kwargs):
url = self.url
if path:
url = urlparse.urljoin(self.url, path)
url = urlparse.urljoin(str(self.url), path)
more = dict(self._settings)
more.update(kwargs)
return http.request(method, url, session=self._session, **more)
# XXX: This is a temporary hack to pin-point bad refactoring.
assert self.connector is not None
async with aiohttp.ClientSession(
connector=self.connector,
connector_owner=False,
# TODO use `raise_for_status=true`, though this needs traces first,
) as session:
return await http.request(method, url, session=session, **more)
def get_default_headers(self):
return {
@ -417,33 +439,41 @@ class DAVStorage(Storage):
# The DAVSession class to use
session_class = DAVSession
connector: aiohttp.TCPConnector
_repr_attributes = ("username", "url")
_property_table = {
"displayname": ("displayname", "DAV:"),
}
def __init__(self, **kwargs):
def __init__(self, *, connector, **kwargs):
# defined for _repr_attributes
self.username = kwargs.get("username")
self.url = kwargs.get("url")
self.connector = connector
self.session, kwargs = self.session_class.init_and_remaining_args(**kwargs)
self.session, kwargs = self.session_class.init_and_remaining_args(
connector=connector,
**kwargs,
)
super().__init__(**kwargs)
__init__.__signature__ = signature(session_class.__init__)
@classmethod
def discover(cls, **kwargs):
async def discover(cls, **kwargs):
session, _ = cls.session_class.init_and_remaining_args(**kwargs)
d = cls.discovery_class(session, kwargs)
return d.discover()
async for collection in d.discover():
yield collection
@classmethod
def create_collection(cls, collection, **kwargs):
async def create_collection(cls, collection, **kwargs):
session, _ = cls.session_class.init_and_remaining_args(**kwargs)
d = cls.discovery_class(session, kwargs)
return d.create(collection)
return await d.create(collection)
def _normalize_href(self, *args, **kwargs):
return _normalize_href(self.session.url, *args, **kwargs)
@ -455,57 +485,65 @@ class DAVStorage(Storage):
def _is_item_mimetype(self, mimetype):
return _fuzzy_matches_mimetype(self.item_mimetype, mimetype)
def get(self, href):
((actual_href, item, etag),) = self.get_multi([href])
async def get(self, href):
((actual_href, item, etag),) = await aiostream.stream.list(
self.get_multi([href])
)
assert href == actual_href
return item, etag
def get_multi(self, hrefs):
async def get_multi(self, hrefs):
hrefs = set(hrefs)
href_xml = []
for href in hrefs:
if href != self._normalize_href(href):
raise exceptions.NotFoundError(href)
href_xml.append(f"<href>{href}</href>")
if not href_xml:
return ()
if href_xml:
data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode(
"utf-8"
)
response = await self.session.request(
"REPORT", "", data=data, headers=self.session.get_default_headers()
)
root = _parse_xml(
await response.content.read()
) # etree only can handle bytes
rv = []
hrefs_left = set(hrefs)
for href, etag, prop in self._parse_prop_responses(root):
raw = prop.find(self.get_multi_data_query)
if raw is None:
dav_logger.warning(
"Skipping {}, the item content is missing.".format(href)
)
continue
data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode("utf-8")
response = self.session.request(
"REPORT", "", data=data, headers=self.session.get_default_headers()
)
root = _parse_xml(response.content) # etree only can handle bytes
rv = []
hrefs_left = set(hrefs)
for href, etag, prop in self._parse_prop_responses(root):
raw = prop.find(self.get_multi_data_query)
if raw is None:
dav_logger.warning(
"Skipping {}, the item content is missing.".format(href)
)
continue
raw = raw.text or ""
raw = raw.text or ""
if isinstance(raw, bytes):
raw = raw.decode(response.encoding)
if isinstance(etag, bytes):
etag = etag.decode(response.encoding)
if isinstance(raw, bytes):
raw = raw.decode(response.encoding)
if isinstance(etag, bytes):
etag = etag.decode(response.encoding)
try:
hrefs_left.remove(href)
except KeyError:
if href in hrefs:
dav_logger.warning("Server sent item twice: {}".format(href))
try:
hrefs_left.remove(href)
except KeyError:
if href in hrefs:
dav_logger.warning("Server sent item twice: {}".format(href))
else:
dav_logger.warning(
"Server sent unsolicited item: {}".format(href)
)
else:
dav_logger.warning("Server sent unsolicited item: {}".format(href))
else:
rv.append((href, Item(raw), etag))
for href in hrefs_left:
raise exceptions.NotFoundError(href)
return rv
rv.append((href, Item(raw), etag))
for href in hrefs_left:
raise exceptions.NotFoundError(href)
def _put(self, href, item, etag):
for href, item, etag in rv:
yield href, item, etag
async def _put(self, href, item, etag):
headers = self.session.get_default_headers()
headers["Content-Type"] = self.item_mimetype
if etag is None:
@ -513,11 +551,11 @@ class DAVStorage(Storage):
else:
headers["If-Match"] = etag
response = self.session.request(
response = await self.session.request(
"PUT", href, data=item.raw.encode("utf-8"), headers=headers
)
_assert_multistatus_success(response)
await _assert_multistatus_success(response)
# The server may not return an etag under certain conditions:
#
@ -534,25 +572,28 @@ class DAVStorage(Storage):
# In such cases we return a constant etag. The next synchronization
# will then detect an etag change and will download the new item.
etag = response.headers.get("etag", None)
href = self._normalize_href(response.url)
href = self._normalize_href(str(response.url))
return href, etag
def update(self, href, item, etag):
async def update(self, href, item, etag):
if etag is None:
raise ValueError("etag must be given and must not be None.")
href, etag = self._put(self._normalize_href(href), item, etag)
href, etag = await self._put(self._normalize_href(href), item, etag)
return etag
def upload(self, item):
async def upload(self, item):
href = self._get_href(item)
return self._put(href, item, None)
rv = await self._put(href, item, None)
return rv
def delete(self, href, etag):
async def delete(self, href, etag):
href = self._normalize_href(href)
headers = self.session.get_default_headers()
headers.update({"If-Match": etag})
if etag: # baikal doesn't give us an etag.
dav_logger.warning("Deleting an item with no etag.")
headers.update({"If-Match": etag})
self.session.request("DELETE", href, headers=headers)
await self.session.request("DELETE", href, headers=headers)
def _parse_prop_responses(self, root, handled_hrefs=None):
if handled_hrefs is None:
@ -604,7 +645,7 @@ class DAVStorage(Storage):
handled_hrefs.add(href)
yield href, etag, props
def list(self):
async def list(self):
headers = self.session.get_default_headers()
headers["Depth"] = "1"
@ -620,14 +661,19 @@ class DAVStorage(Storage):
# We use a PROPFIND request instead of addressbook-query due to issues
# with Zimbra. See https://github.com/pimutils/vdirsyncer/issues/83
response = self.session.request("PROPFIND", "", data=data, headers=headers)
root = _parse_xml(response.content)
response = await self.session.request(
"PROPFIND",
"",
data=data,
headers=headers,
)
root = _parse_xml(await response.content.read())
rv = self._parse_prop_responses(root)
for href, etag, _prop in rv:
yield href, etag
def get_meta(self, key):
async def get_meta(self, key):
try:
tagname, namespace = self._property_table[key]
except KeyError:
@ -649,9 +695,14 @@ class DAVStorage(Storage):
headers = self.session.get_default_headers()
headers["Depth"] = "0"
response = self.session.request("PROPFIND", "", data=data, headers=headers)
response = await self.session.request(
"PROPFIND",
"",
data=data,
headers=headers,
)
root = _parse_xml(response.content)
root = _parse_xml(await response.content.read())
for prop in root.findall(".//" + xpath):
text = normalize_meta_value(getattr(prop, "text", None))
@ -659,7 +710,7 @@ class DAVStorage(Storage):
return text
return ""
def set_meta(self, key, value):
async def set_meta(self, key, value):
try:
tagname, namespace = self._property_table[key]
except KeyError:
@ -683,8 +734,11 @@ class DAVStorage(Storage):
"utf-8"
)
self.session.request(
"PROPPATCH", "", data=data, headers=self.session.get_default_headers()
await self.session.request(
"PROPPATCH",
"",
data=data,
headers=self.session.get_default_headers(),
)
# XXX: Response content is currently ignored. Though exceptions are
@ -776,7 +830,7 @@ class CalDAVStorage(DAVStorage):
("VTODO", "VEVENT"), start, end
)
def list(self):
async def list(self):
caldavfilters = list(
self._get_list_filters(self.item_types, self.start_date, self.end_date)
)
@ -788,7 +842,8 @@ class CalDAVStorage(DAVStorage):
# instead?
#
# See https://github.com/dmfs/tasks/issues/118 for backstory.
yield from DAVStorage.list(self)
async for href, etag in DAVStorage.list(self):
yield href, etag
data = """<?xml version="1.0" encoding="utf-8" ?>
<C:calendar-query xmlns="DAV:"
@ -813,8 +868,13 @@ class CalDAVStorage(DAVStorage):
for caldavfilter in caldavfilters:
xml = data.format(caldavfilter=caldavfilter).encode("utf-8")
response = self.session.request("REPORT", "", data=xml, headers=headers)
root = _parse_xml(response.content)
response = await self.session.request(
"REPORT",
"",
data=xml,
headers=headers,
)
root = _parse_xml(await response.content.read())
rv = self._parse_prop_responses(root, handled_hrefs)
for href, etag, _prop in rv:
yield href, etag

View file

@ -30,13 +30,13 @@ logger = logging.getLogger(__name__)
def _writing_op(f):
@functools.wraps(f)
def inner(self, *args, **kwargs):
async def inner(self, *args, **kwargs):
if not self._at_once:
self._sync_journal()
rv = f(self, *args, **kwargs)
if not self._at_once:
self._sync_journal()
return rv
return await rv
return inner
@ -120,7 +120,14 @@ class EtesyncStorage(Storage):
self._session.etesync.sync_journal(self.collection)
@classmethod
def discover(cls, email, secrets_dir, server_url=None, db_path=None, **kwargs):
async def discover(
cls,
email,
secrets_dir,
server_url=None,
db_path=None,
**kwargs,
):
if kwargs.get("collection", None) is not None:
raise TypeError("collection argument must not be given.")
session = _Session(email, secrets_dir, server_url, db_path)
@ -139,7 +146,7 @@ class EtesyncStorage(Storage):
logger.debug(f"Skipping collection: {entry!r}")
@classmethod
def create_collection(
async def create_collection(
cls, collection, email, secrets_dir, server_url=None, db_path=None, **kwargs
):
session = _Session(email, secrets_dir, server_url, db_path)
@ -158,13 +165,13 @@ class EtesyncStorage(Storage):
**kwargs,
)
def list(self):
async def list(self):
self._sync_journal()
for entry in self._journal.collection.list():
item = Item(entry.content)
yield str(entry.uid), item.hash
def get(self, href):
async def get(self, href):
try:
item = Item(self._journal.collection.get(href).content)
except etesync.exceptions.DoesNotExist as e:
@ -172,7 +179,7 @@ class EtesyncStorage(Storage):
return item, item.hash
@_writing_op
def upload(self, item):
async def upload(self, item):
try:
entry = self._item_type.create(self._journal.collection, item.raw)
entry.save()
@ -183,7 +190,7 @@ class EtesyncStorage(Storage):
return item.uid, item.hash
@_writing_op
def update(self, href, item, etag):
async def update(self, href, item, etag):
try:
entry = self._journal.collection.get(href)
except etesync.exceptions.DoesNotExist as e:
@ -196,7 +203,7 @@ class EtesyncStorage(Storage):
return item.hash
@_writing_op
def delete(self, href, etag):
async def delete(self, href, etag):
try:
entry = self._journal.collection.get(href)
old_item = Item(entry.content)
@ -206,8 +213,8 @@ class EtesyncStorage(Storage):
except etesync.exceptions.DoesNotExist as e:
raise exceptions.NotFoundError(e)
@contextlib.contextmanager
def at_once(self):
@contextlib.asynccontextmanager
async def at_once(self):
self._sync_journal()
self._at_once = True
try:

View file

@ -41,7 +41,7 @@ class FilesystemStorage(Storage):
self.post_hook = post_hook
@classmethod
def discover(cls, path, **kwargs):
async def discover(cls, path, **kwargs):
if kwargs.pop("collection", None) is not None:
raise TypeError("collection argument must not be given.")
path = expand_path(path)
@ -67,7 +67,7 @@ class FilesystemStorage(Storage):
return True
@classmethod
def create_collection(cls, collection, **kwargs):
async def create_collection(cls, collection, **kwargs):
kwargs = dict(kwargs)
path = kwargs["path"]
@ -86,7 +86,7 @@ class FilesystemStorage(Storage):
def _get_href(self, ident):
return generate_href(ident) + self.fileext
def list(self):
async def list(self):
for fname in os.listdir(self.path):
fpath = os.path.join(self.path, fname)
if (
@ -96,7 +96,7 @@ class FilesystemStorage(Storage):
):
yield fname, get_etag_from_file(fpath)
def get(self, href):
async def get(self, href):
fpath = self._get_filepath(href)
try:
with open(fpath, "rb") as f:
@ -107,7 +107,7 @@ class FilesystemStorage(Storage):
else:
raise
def upload(self, item):
async def upload(self, item):
if not isinstance(item.raw, str):
raise TypeError("item.raw must be a unicode string.")
@ -139,7 +139,7 @@ class FilesystemStorage(Storage):
else:
raise
def update(self, href, item, etag):
async def update(self, href, item, etag):
fpath = self._get_filepath(href)
if not os.path.exists(fpath):
raise exceptions.NotFoundError(item.uid)
@ -158,7 +158,7 @@ class FilesystemStorage(Storage):
self._run_post_hook(fpath)
return etag
def delete(self, href, etag):
async def delete(self, href, etag):
fpath = self._get_filepath(href)
if not os.path.isfile(fpath):
raise exceptions.NotFoundError(href)
@ -176,7 +176,7 @@ class FilesystemStorage(Storage):
except OSError as e:
logger.warning("Error executing external hook: {}".format(str(e)))
def get_meta(self, key):
async def get_meta(self, key):
fpath = os.path.join(self.path, key)
try:
with open(fpath, "rb") as f:
@ -187,7 +187,7 @@ class FilesystemStorage(Storage):
else:
raise
def set_meta(self, key, value):
async def set_meta(self, key, value):
value = normalize_meta_value(value)
fpath = os.path.join(self.path, key)

View file

@ -3,6 +3,7 @@ import logging
import os
import urllib.parse as urlparse
import aiohttp
import click
from atomicwrites import atomic_write
@ -28,13 +29,21 @@ except ImportError:
class GoogleSession(dav.DAVSession):
def __init__(self, token_file, client_id, client_secret, url=None):
def __init__(
self,
token_file,
client_id,
client_secret,
url=None,
connector: aiohttp.BaseConnector = None,
):
# Required for discovering collections
if url is not None:
self.url = url
self.useragent = client_id
self._settings = {}
self.connector = connector
if not have_oauth2:
raise exceptions.UserError("requests-oauthlib not installed")

View file

@ -1,5 +1,7 @@
import urllib.parse as urlparse
import aiohttp
from .. import exceptions
from ..http import prepare_auth
from ..http import prepare_client_cert
@ -30,6 +32,8 @@ class HttpStorage(Storage):
useragent=USERAGENT,
verify_fingerprint=None,
auth_cert=None,
*,
connector,
**kwargs
):
super().__init__(**kwargs)
@ -43,6 +47,8 @@ class HttpStorage(Storage):
self.username, self.password = username, password
self.useragent = useragent
assert connector is not None
self.connector = connector
collection = kwargs.get("collection")
if collection is not None:
@ -53,22 +59,35 @@ class HttpStorage(Storage):
def _default_headers(self):
return {"User-Agent": self.useragent}
def list(self):
r = request("GET", self.url, headers=self._default_headers(), **self._settings)
async def list(self):
async with aiohttp.ClientSession(
connector=self.connector,
connector_owner=False,
# TODO use `raise_for_status=true`, though this needs traces first,
) as session:
r = await request(
"GET",
self.url,
headers=self._default_headers(),
session=session,
**self._settings,
)
self._items = {}
for item in split_collection(r.text):
for item in split_collection((await r.read()).decode("utf-8")):
item = Item(item)
if self._ignore_uids:
item = item.with_uid(item.hash)
self._items[item.ident] = item, item.hash
return ((href, etag) for href, (item, etag) in self._items.items())
for href, (_, etag) in self._items.items():
yield href, etag
def get(self, href):
async def get(self, href):
if self._items is None:
self.list()
async for _ in self.list():
pass
try:
return self._items[href]

View file

@ -28,18 +28,18 @@ class MemoryStorage(Storage):
def _get_href(self, item):
return item.ident + self.fileext
def list(self):
async def list(self):
for href, (etag, _item) in self.items.items():
yield href, etag
def get(self, href):
async def get(self, href):
etag, item = self.items[href]
return item, etag
def has(self, href):
async def has(self, href):
return href in self.items
def upload(self, item):
async def upload(self, item):
href = self._get_href(item)
if href in self.items:
raise exceptions.AlreadyExistingError(existing_href=href)
@ -47,7 +47,7 @@ class MemoryStorage(Storage):
self.items[href] = (etag, item)
return href, etag
def update(self, href, item, etag):
async def update(self, href, item, etag):
if href not in self.items:
raise exceptions.NotFoundError(href)
actual_etag, _ = self.items[href]
@ -58,15 +58,15 @@ class MemoryStorage(Storage):
self.items[href] = (new_etag, item)
return new_etag
def delete(self, href, etag):
if not self.has(href):
async def delete(self, href, etag):
if not await self.has(href):
raise exceptions.NotFoundError(href)
if etag != self.items[href][0]:
raise exceptions.WrongEtagError(etag)
del self.items[href]
def get_meta(self, key):
async def get_meta(self, key):
return normalize_meta_value(self.metadata.get(key))
def set_meta(self, key, value):
async def set_meta(self, key, value):
self.metadata[key] = normalize_meta_value(value)

View file

@ -21,10 +21,12 @@ logger = logging.getLogger(__name__)
def _writing_op(f):
@functools.wraps(f)
def inner(self, *args, **kwargs):
async def inner(self, *args, **kwargs):
if self._items is None or not self._at_once:
self.list()
rv = f(self, *args, **kwargs)
async for _ in self.list():
pass
assert self._items is not None
rv = await f(self, *args, **kwargs)
if not self._at_once:
self._write()
return rv
@ -53,7 +55,7 @@ class SingleFileStorage(Storage):
self._at_once = False
@classmethod
def discover(cls, path, **kwargs):
async def discover(cls, path, **kwargs):
if kwargs.pop("collection", None) is not None:
raise TypeError("collection argument must not be given.")
@ -81,7 +83,7 @@ class SingleFileStorage(Storage):
yield args
@classmethod
def create_collection(cls, collection, **kwargs):
async def create_collection(cls, collection, **kwargs):
path = os.path.abspath(expand_path(kwargs["path"]))
if collection is not None:
@ -97,7 +99,7 @@ class SingleFileStorage(Storage):
kwargs["collection"] = collection
return kwargs
def list(self):
async def list(self):
self._items = collections.OrderedDict()
try:
@ -111,19 +113,19 @@ class SingleFileStorage(Storage):
raise OSError(e)
text = None
if not text:
return ()
if text:
for item in split_collection(text):
item = Item(item)
etag = item.hash
href = item.ident
self._items[href] = item, etag
for item in split_collection(text):
item = Item(item)
etag = item.hash
self._items[item.ident] = item, etag
yield href, etag
return ((href, etag) for href, (item, etag) in self._items.items())
def get(self, href):
async def get(self, href):
if self._items is None or not self._at_once:
self.list()
async for _ in self.list():
pass
try:
return self._items[href]
@ -131,7 +133,7 @@ class SingleFileStorage(Storage):
raise exceptions.NotFoundError(href)
@_writing_op
def upload(self, item):
async def upload(self, item):
href = item.ident
if href in self._items:
raise exceptions.AlreadyExistingError(existing_href=href)
@ -140,7 +142,7 @@ class SingleFileStorage(Storage):
return href, item.hash
@_writing_op
def update(self, href, item, etag):
async def update(self, href, item, etag):
if href not in self._items:
raise exceptions.NotFoundError(href)
@ -152,7 +154,7 @@ class SingleFileStorage(Storage):
return item.hash
@_writing_op
def delete(self, href, etag):
async def delete(self, href, etag):
if href not in self._items:
raise exceptions.NotFoundError(href)
@ -181,8 +183,8 @@ class SingleFileStorage(Storage):
self._items = None
self._last_etag = None
@contextlib.contextmanager
def at_once(self):
@contextlib.asynccontextmanager
async def at_once(self):
self.list()
self._at_once = True
try:

View file

@ -35,7 +35,7 @@ class _StorageInfo:
self.status = status
self._item_cache = {}
def prepare_new_status(self):
async def prepare_new_status(self):
storage_nonempty = False
prefetch = []
@ -45,7 +45,7 @@ class _StorageInfo:
except IdentAlreadyExists as e:
raise e.to_ident_conflict(self.storage)
for href, etag in self.storage.list():
async for href, etag in self.storage.list():
storage_nonempty = True
ident, meta = self.status.get_by_href(href)
@ -58,9 +58,13 @@ class _StorageInfo:
_store_props(ident, meta)
# Prefetch items
for href, item, etag in self.storage.get_multi(prefetch) if prefetch else ():
_store_props(item.ident, ItemMetadata(href=href, hash=item.hash, etag=etag))
self.set_item_cache(item.ident, item)
if prefetch:
async for href, item, etag in self.storage.get_multi(prefetch):
_store_props(
item.ident,
ItemMetadata(href=href, hash=item.hash, etag=etag),
)
self.set_item_cache(item.ident, item)
return storage_nonempty
@ -86,7 +90,7 @@ class _StorageInfo:
return self._item_cache[ident]
def sync(
async def sync(
storage_a,
storage_b,
status,
@ -137,8 +141,8 @@ def sync(
a_info = _StorageInfo(storage_a, SubStatus(status, "a"))
b_info = _StorageInfo(storage_b, SubStatus(status, "b"))
a_nonempty = a_info.prepare_new_status()
b_nonempty = b_info.prepare_new_status()
a_nonempty = await a_info.prepare_new_status()
b_nonempty = await b_info.prepare_new_status()
if status_nonempty and not force_delete:
if a_nonempty and not b_nonempty:
@ -148,10 +152,10 @@ def sync(
actions = list(_get_actions(a_info, b_info))
with storage_a.at_once(), storage_b.at_once():
async with storage_a.at_once(), storage_b.at_once():
for action in actions:
try:
action.run(a_info, b_info, conflict_resolution, partial_sync)
await action.run(a_info, b_info, conflict_resolution, partial_sync)
except Exception as e:
if error_callback:
error_callback(e)
@ -160,10 +164,10 @@ def sync(
class Action:
def _run_impl(self, a, b): # pragma: no cover
async def _run_impl(self, a, b): # pragma: no cover
raise NotImplementedError()
def run(self, a, b, conflict_resolution, partial_sync):
async def run(self, a, b, conflict_resolution, partial_sync):
with self.auto_rollback(a, b):
if self.dest.storage.read_only:
if partial_sync == "error":
@ -174,7 +178,7 @@ class Action:
else:
assert partial_sync == "revert"
self._run_impl(a, b)
await self._run_impl(a, b)
@contextlib.contextmanager
def auto_rollback(self, a, b):
@ -194,7 +198,7 @@ class Upload(Action):
self.ident = item.ident
self.dest = dest
def _run_impl(self, a, b):
async def _run_impl(self, a, b):
if self.dest.storage.read_only:
href = etag = None
@ -204,7 +208,7 @@ class Upload(Action):
self.ident, self.dest.storage
)
)
href, etag = self.dest.storage.upload(self.item)
href, etag = await self.dest.storage.upload(self.item)
assert href is not None
self.dest.status.insert_ident(
@ -218,7 +222,7 @@ class Update(Action):
self.ident = item.ident
self.dest = dest
def _run_impl(self, a, b):
async def _run_impl(self, a, b):
if self.dest.storage.read_only:
meta = ItemMetadata(hash=self.item.hash)
else:
@ -226,7 +230,7 @@ class Update(Action):
"Copying (updating) item {} to {}".format(self.ident, self.dest.storage)
)
meta = self.dest.status.get_new(self.ident)
meta.etag = self.dest.storage.update(meta.href, self.item, meta.etag)
meta.etag = await self.dest.storage.update(meta.href, self.item, meta.etag)
self.dest.status.update_ident(self.ident, meta)
@ -236,13 +240,13 @@ class Delete(Action):
self.ident = ident
self.dest = dest
def _run_impl(self, a, b):
async def _run_impl(self, a, b):
meta = self.dest.status.get_new(self.ident)
if not self.dest.storage.read_only:
sync_logger.info(
"Deleting item {} from {}".format(self.ident, self.dest.storage)
)
self.dest.storage.delete(meta.href, meta.etag)
await self.dest.storage.delete(meta.href, meta.etag)
self.dest.status.remove_ident(self.ident)
@ -251,7 +255,7 @@ class ResolveConflict(Action):
def __init__(self, ident):
self.ident = ident
def run(self, a, b, conflict_resolution, partial_sync):
async def run(self, a, b, conflict_resolution, partial_sync):
with self.auto_rollback(a, b):
sync_logger.info(
"Doing conflict resolution for item {}...".format(self.ident)
@ -271,9 +275,19 @@ class ResolveConflict(Action):
item_b = b.get_item_cache(self.ident)
new_item = conflict_resolution(item_a, item_b)
if new_item.hash != meta_a.hash:
Update(new_item, a).run(a, b, conflict_resolution, partial_sync)
await Update(new_item, a).run(
a,
b,
conflict_resolution,
partial_sync,
)
if new_item.hash != meta_b.hash:
Update(new_item, b).run(a, b, conflict_resolution, partial_sync)
await Update(new_item, b).run(
a,
b,
conflict_resolution,
partial_sync,
)
else:
raise UserError(
"Invalid conflict resolution mode: {!r}".format(conflict_resolution)

View file

@ -26,22 +26,15 @@ def expand_path(p: str) -> str:
return p
def split_dict(d, f):
def split_dict(d: dict, f: callable):
"""Puts key into first dict if f(key), otherwise in second dict"""
a, b = split_sequence(d.items(), lambda item: f(item[0]))
return dict(a), dict(b)
def split_sequence(s, f):
"""Puts item into first list if f(item), else in second list"""
a = []
b = []
for item in s:
if f(item):
a.append(item)
a = {}
b = {}
for k, v in d.items():
if f(k):
a[k] = v
else:
b.append(item)
b[k] = v
return a, b