Use black to auto-format the codebase

This commit is contained in:
Hugo Osvaldo Barrera 2021-05-06 19:28:54 +02:00
parent abf199f21e
commit d2d41e5df1
66 changed files with 2902 additions and 2497 deletions

View file

@ -13,6 +13,10 @@ repos:
hooks: hooks:
- id: flake8 - id: flake8
additional_dependencies: [flake8-import-order, flake8-bugbear] additional_dependencies: [flake8-import-order, flake8-bugbear]
- repo: https://github.com/psf/black
rev: "21.5b0"
hooks:
- id: black
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.5.0 rev: v2.5.0
hooks: hooks:

View file

@ -3,90 +3,104 @@ import os
from pkg_resources import get_distribution from pkg_resources import get_distribution
extensions = ['sphinx.ext.autodoc'] extensions = ["sphinx.ext.autodoc"]
templates_path = ['_templates'] templates_path = ["_templates"]
source_suffix = '.rst' source_suffix = ".rst"
master_doc = 'index' master_doc = "index"
project = 'vdirsyncer' project = "vdirsyncer"
copyright = ('2014-{}, Markus Unterwaditzer & contributors' copyright = "2014-{}, Markus Unterwaditzer & contributors".format(
.format(datetime.date.today().strftime('%Y'))) datetime.date.today().strftime("%Y")
)
release = get_distribution('vdirsyncer').version release = get_distribution("vdirsyncer").version
version = '.'.join(release.split('.')[:2]) # The short X.Y version. version = ".".join(release.split(".")[:2]) # The short X.Y version.
rst_epilog = '.. |vdirsyncer_version| replace:: %s' % release rst_epilog = ".. |vdirsyncer_version| replace:: %s" % release
exclude_patterns = ['_build'] exclude_patterns = ["_build"]
pygments_style = 'sphinx' pygments_style = "sphinx"
on_rtd = os.environ.get('READTHEDOCS', None) == 'True' on_rtd = os.environ.get("READTHEDOCS", None) == "True"
try: try:
import sphinx_rtd_theme import sphinx_rtd_theme
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
except ImportError: except ImportError:
html_theme = 'default' html_theme = "default"
if not on_rtd: if not on_rtd:
print('-' * 74) print("-" * 74)
print('Warning: sphinx-rtd-theme not installed, building with default ' print(
'theme.') "Warning: sphinx-rtd-theme not installed, building with default " "theme."
print('-' * 74) )
print("-" * 74)
html_static_path = ['_static'] html_static_path = ["_static"]
htmlhelp_basename = 'vdirsyncerdoc' htmlhelp_basename = "vdirsyncerdoc"
latex_elements = {} latex_elements = {}
latex_documents = [ latex_documents = [
('index', 'vdirsyncer.tex', 'vdirsyncer Documentation', (
'Markus Unterwaditzer', 'manual'), "index",
"vdirsyncer.tex",
"vdirsyncer Documentation",
"Markus Unterwaditzer",
"manual",
),
] ]
man_pages = [ man_pages = [
('index', 'vdirsyncer', 'vdirsyncer Documentation', ("index", "vdirsyncer", "vdirsyncer Documentation", ["Markus Unterwaditzer"], 1)
['Markus Unterwaditzer'], 1)
] ]
texinfo_documents = [ texinfo_documents = [
('index', 'vdirsyncer', 'vdirsyncer Documentation', (
'Markus Unterwaditzer', 'vdirsyncer', "index",
'Synchronize calendars and contacts.', 'Miscellaneous'), "vdirsyncer",
"vdirsyncer Documentation",
"Markus Unterwaditzer",
"vdirsyncer",
"Synchronize calendars and contacts.",
"Miscellaneous",
),
] ]
def github_issue_role(name, rawtext, text, lineno, inliner, def github_issue_role(name, rawtext, text, lineno, inliner, options=None, content=()):
options=None, content=()):
options = options or {} options = options or {}
try: try:
issue_num = int(text) issue_num = int(text)
if issue_num <= 0: if issue_num <= 0:
raise ValueError() raise ValueError()
except ValueError: except ValueError:
msg = inliner.reporter.error(f'Invalid GitHub issue: {text}', msg = inliner.reporter.error(f"Invalid GitHub issue: {text}", line=lineno)
line=lineno)
prb = inliner.problematic(rawtext, rawtext, msg) prb = inliner.problematic(rawtext, rawtext, msg)
return [prb], [msg] return [prb], [msg]
from docutils import nodes from docutils import nodes
PROJECT_HOME = 'https://github.com/pimutils/vdirsyncer' PROJECT_HOME = "https://github.com/pimutils/vdirsyncer"
link = '{}/{}/{}'.format(PROJECT_HOME, link = "{}/{}/{}".format(
'issues' if name == 'gh' else 'pull', PROJECT_HOME, "issues" if name == "gh" else "pull", issue_num
issue_num) )
linktext = ('issue #{}' if name == 'gh' linktext = ("issue #{}" if name == "gh" else "pull request #{}").format(issue_num)
else 'pull request #{}').format(issue_num) node = nodes.reference(rawtext, linktext, refuri=link, **options)
node = nodes.reference(rawtext, linktext, refuri=link,
**options)
return [node], [] return [node], []
def setup(app): def setup(app):
from sphinx.domains.python import PyObject from sphinx.domains.python import PyObject
app.add_object_type('storage', 'storage', 'pair: %s; storage',
doc_field_types=PyObject.doc_field_types) app.add_object_type(
app.add_role('gh', github_issue_role) "storage",
app.add_role('ghpr', github_issue_role) "storage",
"pair: %s; storage",
doc_field_types=PyObject.doc_field_types,
)
app.add_role("gh", github_issue_role)
app.add_role("ghpr", github_issue_role)

View file

@ -1,9 +1,9 @@
''' """
Vdirsyncer synchronizes calendars and contacts. Vdirsyncer synchronizes calendars and contacts.
Please refer to https://vdirsyncer.pimutils.org/en/stable/packaging.html for Please refer to https://vdirsyncer.pimutils.org/en/stable/packaging.html for
how to package vdirsyncer. how to package vdirsyncer.
''' """
from setuptools import Command from setuptools import Command
from setuptools import find_packages from setuptools import find_packages
from setuptools import setup from setuptools import setup
@ -11,25 +11,21 @@ from setuptools import setup
requirements = [ requirements = [
# https://github.com/mitsuhiko/click/issues/200 # https://github.com/mitsuhiko/click/issues/200
'click>=5.0', "click>=5.0",
'click-log>=0.3.0, <0.4.0', "click-log>=0.3.0, <0.4.0",
# https://github.com/pimutils/vdirsyncer/issues/478 # https://github.com/pimutils/vdirsyncer/issues/478
'click-threading>=0.2', "click-threading>=0.2",
"requests >=2.20.0",
'requests >=2.20.0',
# https://github.com/sigmavirus24/requests-toolbelt/pull/28 # https://github.com/sigmavirus24/requests-toolbelt/pull/28
# And https://github.com/sigmavirus24/requests-toolbelt/issues/54 # And https://github.com/sigmavirus24/requests-toolbelt/issues/54
'requests_toolbelt >=0.4.0', "requests_toolbelt >=0.4.0",
# https://github.com/untitaker/python-atomicwrites/commit/4d12f23227b6a944ab1d99c507a69fdbc7c9ed6d # noqa # https://github.com/untitaker/python-atomicwrites/commit/4d12f23227b6a944ab1d99c507a69fdbc7c9ed6d # noqa
'atomicwrites>=0.1.7' "atomicwrites>=0.1.7",
] ]
class PrintRequirements(Command): class PrintRequirements(Command):
description = 'Prints minimal requirements' description = "Prints minimal requirements"
user_options = [] user_options = []
def initialize_options(self): def initialize_options(self):
@ -43,54 +39,44 @@ class PrintRequirements(Command):
print(requirement.replace(">", "=").replace(" ", "")) print(requirement.replace(">", "=").replace(" ", ""))
with open('README.rst') as f: with open("README.rst") as f:
long_description = f.read() long_description = f.read()
setup( setup(
# General metadata # General metadata
name='vdirsyncer', name="vdirsyncer",
author='Markus Unterwaditzer', author="Markus Unterwaditzer",
author_email='markus@unterwaditzer.net', author_email="markus@unterwaditzer.net",
url='https://github.com/pimutils/vdirsyncer', url="https://github.com/pimutils/vdirsyncer",
description='Synchronize calendars and contacts', description="Synchronize calendars and contacts",
license='BSD', license="BSD",
long_description=long_description, long_description=long_description,
# Runtime dependencies # Runtime dependencies
install_requires=requirements, install_requires=requirements,
# Optional dependencies # Optional dependencies
extras_require={ extras_require={
'google': ['requests-oauthlib'], "google": ["requests-oauthlib"],
'etesync': ['etesync==0.5.2', 'django<2.0'] "etesync": ["etesync==0.5.2", "django<2.0"],
}, },
# Build dependencies # Build dependencies
setup_requires=['setuptools_scm != 1.12.0'], setup_requires=["setuptools_scm != 1.12.0"],
# Other # Other
packages=find_packages(exclude=['tests.*', 'tests']), packages=find_packages(exclude=["tests.*", "tests"]),
include_package_data=True, include_package_data=True,
cmdclass={ cmdclass={"minimal_requirements": PrintRequirements},
'minimal_requirements': PrintRequirements use_scm_version={"write_to": "vdirsyncer/version.py"},
}, entry_points={"console_scripts": ["vdirsyncer = vdirsyncer.cli:main"]},
use_scm_version={
'write_to': 'vdirsyncer/version.py'
},
entry_points={
'console_scripts': ['vdirsyncer = vdirsyncer.cli:main']
},
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', "Development Status :: 4 - Beta",
'Environment :: Console', "Environment :: Console",
'License :: OSI Approved :: BSD License', "License :: OSI Approved :: BSD License",
'Operating System :: POSIX', "Operating System :: POSIX",
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Programming Language :: Python :: 3.8', "Programming Language :: Python :: 3.8",
'Programming Language :: Python :: 3.9', "Programming Language :: Python :: 3.9",
'Topic :: Internet', "Topic :: Internet",
'Topic :: Utilities', "Topic :: Utilities",
], ],
) )

View file

@ -1,6 +1,6 @@
''' """
Test suite for vdirsyncer. Test suite for vdirsyncer.
''' """
import hypothesis.strategies as st import hypothesis.strategies as st
import urllib3.exceptions import urllib3.exceptions
@ -10,14 +10,14 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def blow_up(*a, **kw): def blow_up(*a, **kw):
raise AssertionError('Did not expect to be called.') raise AssertionError("Did not expect to be called.")
def assert_item_equals(a, b): def assert_item_equals(a, b):
assert normalize_item(a) == normalize_item(b) assert normalize_item(a) == normalize_item(b)
VCARD_TEMPLATE = '''BEGIN:VCARD VCARD_TEMPLATE = """BEGIN:VCARD
VERSION:3.0 VERSION:3.0
FN:Cyrus Daboo FN:Cyrus Daboo
N:Daboo;Cyrus;;; N:Daboo;Cyrus;;;
@ -31,9 +31,9 @@ TEL;TYPE=FAX:412 605 0705
URL;VALUE=URI:http://www.example.com URL;VALUE=URI:http://www.example.com
X-SOMETHING:{r} X-SOMETHING:{r}
UID:{uid} UID:{uid}
END:VCARD''' END:VCARD"""
TASK_TEMPLATE = '''BEGIN:VCALENDAR TASK_TEMPLATE = """BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:-//dmfs.org//mimedir.icalendar//EN PRODID:-//dmfs.org//mimedir.icalendar//EN
BEGIN:VTODO BEGIN:VTODO
@ -45,25 +45,30 @@ SUMMARY:Book: Kowlani - Tödlicher Staub
X-SOMETHING:{r} X-SOMETHING:{r}
UID:{uid} UID:{uid}
END:VTODO END:VTODO
END:VCALENDAR''' END:VCALENDAR"""
BARE_EVENT_TEMPLATE = '''BEGIN:VEVENT BARE_EVENT_TEMPLATE = """BEGIN:VEVENT
DTSTART:19970714T170000Z DTSTART:19970714T170000Z
DTEND:19970715T035959Z DTEND:19970715T035959Z
SUMMARY:Bastille Day Party SUMMARY:Bastille Day Party
X-SOMETHING:{r} X-SOMETHING:{r}
UID:{uid} UID:{uid}
END:VEVENT''' END:VEVENT"""
EVENT_TEMPLATE = '''BEGIN:VCALENDAR EVENT_TEMPLATE = (
"""BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:-//hacksw/handcal//NONSGML v1.0//EN PRODID:-//hacksw/handcal//NONSGML v1.0//EN
''' + BARE_EVENT_TEMPLATE + ''' """
END:VCALENDAR''' + BARE_EVENT_TEMPLATE
+ """
END:VCALENDAR"""
)
EVENT_WITH_TIMEZONE_TEMPLATE = '''BEGIN:VCALENDAR EVENT_WITH_TIMEZONE_TEMPLATE = (
"""BEGIN:VCALENDAR
BEGIN:VTIMEZONE BEGIN:VTIMEZONE
TZID:Europe/Rome TZID:Europe/Rome
X-LIC-LOCATION:Europe/Rome X-LIC-LOCATION:Europe/Rome
@ -82,26 +87,23 @@ DTSTART:19701025T030000
RRULE:FREQ=YEARLY;BYDAY=-1SU;BYMONTH=10 RRULE:FREQ=YEARLY;BYDAY=-1SU;BYMONTH=10
END:STANDARD END:STANDARD
END:VTIMEZONE END:VTIMEZONE
''' + BARE_EVENT_TEMPLATE + ''' """
END:VCALENDAR''' + BARE_EVENT_TEMPLATE
+ """
END:VCALENDAR"""
)
SIMPLE_TEMPLATE = '''BEGIN:FOO SIMPLE_TEMPLATE = """BEGIN:FOO
UID:{uid} UID:{uid}
X-SOMETHING:{r} X-SOMETHING:{r}
HAHA:YES HAHA:YES
END:FOO''' END:FOO"""
printable_characters_strategy = st.text( printable_characters_strategy = st.text(
st.characters(blacklist_categories=( st.characters(blacklist_categories=("Cc", "Cs"))
'Cc', 'Cs'
))
) )
uid_strategy = st.text( uid_strategy = st.text(
st.characters(blacklist_categories=( st.characters(blacklist_categories=("Zs", "Zl", "Zp", "Cc", "Cs")), min_size=1
'Zs', 'Zl', 'Zp',
'Cc', 'Cs'
)),
min_size=1
).filter(lambda x: x.strip() == x) ).filter(lambda x: x.strip() == x)

View file

@ -1,6 +1,6 @@
''' """
General-purpose fixtures for vdirsyncer's testsuite. General-purpose fixtures for vdirsyncer's testsuite.
''' """
import logging import logging
import os import os
@ -13,35 +13,42 @@ from hypothesis import Verbosity
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_logging(): def setup_logging():
click_log.basic_config('vdirsyncer').setLevel(logging.DEBUG) click_log.basic_config("vdirsyncer").setLevel(logging.DEBUG)
try: try:
import pytest_benchmark import pytest_benchmark
except ImportError: except ImportError:
@pytest.fixture @pytest.fixture
def benchmark(): def benchmark():
return lambda x: x() return lambda x: x()
else: else:
del pytest_benchmark del pytest_benchmark
settings.register_profile("ci", settings( settings.register_profile(
max_examples=1000, "ci",
verbosity=Verbosity.verbose, settings(
suppress_health_check=[HealthCheck.too_slow], max_examples=1000,
)) verbosity=Verbosity.verbose,
settings.register_profile("deterministic", settings( suppress_health_check=[HealthCheck.too_slow],
derandomize=True, ),
suppress_health_check=HealthCheck.all(), )
)) settings.register_profile(
settings.register_profile("dev", settings( "deterministic",
suppress_health_check=[HealthCheck.too_slow] settings(
)) derandomize=True,
suppress_health_check=HealthCheck.all(),
),
)
settings.register_profile("dev", settings(suppress_health_check=[HealthCheck.too_slow]))
if os.environ.get('DETERMINISTIC_TESTS', 'false').lower() == 'true': if os.environ.get("DETERMINISTIC_TESTS", "false").lower() == "true":
settings.load_profile("deterministic") settings.load_profile("deterministic")
elif os.environ.get('CI', 'false').lower() == 'true': elif os.environ.get("CI", "false").lower() == "true":
settings.load_profile("ci") settings.load_profile("ci")
else: else:
settings.load_profile("dev") settings.load_profile("dev")

View file

@ -20,7 +20,8 @@ from vdirsyncer.vobject import Item
def get_server_mixin(server_name): def get_server_mixin(server_name):
from . import __name__ as base from . import __name__ as base
x = __import__(f'{base}.servers.{server_name}', fromlist=[''])
x = __import__(f"{base}.servers.{server_name}", fromlist=[""])
return x.ServerMixin return x.ServerMixin
@ -35,18 +36,18 @@ class StorageTests:
supports_collections = True supports_collections = True
supports_metadata = True supports_metadata = True
@pytest.fixture(params=['VEVENT', 'VTODO', 'VCARD']) @pytest.fixture(params=["VEVENT", "VTODO", "VCARD"])
def item_type(self, request): def item_type(self, request):
'''Parametrize with all supported item types.''' """Parametrize with all supported item types."""
return request.param return request.param
@pytest.fixture @pytest.fixture
def get_storage_args(self): def get_storage_args(self):
''' """
Return a function with the following properties: Return a function with the following properties:
:param collection: The name of the collection to create and use. :param collection: The name of the collection to create and use.
''' """
raise NotImplementedError() raise NotImplementedError()
@pytest.fixture @pytest.fixture
@ -56,9 +57,9 @@ class StorageTests:
@pytest.fixture @pytest.fixture
def get_item(self, item_type): def get_item(self, item_type):
template = { template = {
'VEVENT': EVENT_TEMPLATE, "VEVENT": EVENT_TEMPLATE,
'VTODO': TASK_TEMPLATE, "VTODO": TASK_TEMPLATE,
'VCARD': VCARD_TEMPLATE, "VCARD": VCARD_TEMPLATE,
}[item_type] }[item_type]
return lambda **kw: format_item(template, **kw) return lambda **kw: format_item(template, **kw)
@ -66,12 +67,12 @@ class StorageTests:
@pytest.fixture @pytest.fixture
def requires_collections(self): def requires_collections(self):
if not self.supports_collections: if not self.supports_collections:
pytest.skip('This storage does not support collections.') pytest.skip("This storage does not support collections.")
@pytest.fixture @pytest.fixture
def requires_metadata(self): def requires_metadata(self):
if not self.supports_metadata: if not self.supports_metadata:
pytest.skip('This storage does not support metadata.') pytest.skip("This storage does not support metadata.")
def test_generic(self, s, get_item): def test_generic(self, s, get_item):
items = [get_item() for i in range(1, 10)] items = [get_item() for i in range(1, 10)]
@ -97,7 +98,7 @@ class StorageTests:
href, etag = s.upload(get_item()) href, etag = s.upload(get_item())
if etag is None: if etag is None:
_, etag = s.get(href) _, etag = s.get(href)
(href2, item, etag2), = s.get_multi([href] * 2) ((href2, item, etag2),) = s.get_multi([href] * 2)
assert href2 == href assert href2 == href
assert etag2 == etag assert etag2 == etag
@ -130,7 +131,7 @@ class StorageTests:
def test_update_nonexisting(self, s, get_item): def test_update_nonexisting(self, s, get_item):
item = get_item() item = get_item()
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.update('huehue', item, '"123"') s.update("huehue", item, '"123"')
def test_wrong_etag(self, s, get_item): def test_wrong_etag(self, s, get_item):
item = get_item() item = get_item()
@ -147,7 +148,7 @@ class StorageTests:
def test_delete_nonexisting(self, s, get_item): def test_delete_nonexisting(self, s, get_item):
with pytest.raises(exceptions.PreconditionFailed): with pytest.raises(exceptions.PreconditionFailed):
s.delete('1', '"123"') s.delete("1", '"123"')
def test_list(self, s, get_item): def test_list(self, s, get_item):
assert not list(s.list()) assert not list(s.list())
@ -157,10 +158,10 @@ class StorageTests:
assert list(s.list()) == [(href, etag)] assert list(s.list()) == [(href, etag)]
def test_has(self, s, get_item): def test_has(self, s, get_item):
assert not s.has('asd') assert not s.has("asd")
href, etag = s.upload(get_item()) href, etag = s.upload(get_item())
assert s.has(href) assert s.has(href)
assert not s.has('asd') assert not s.has("asd")
s.delete(href, etag) s.delete(href, etag)
assert not s.has(href) assert not s.has(href)
@ -173,8 +174,8 @@ class StorageTests:
info[href] = etag info[href] = etag
assert { assert {
href: etag for href, item, etag href: etag
in s.get_multi(href for href, etag in info.items()) for href, item, etag in s.get_multi(href for href, etag in info.items())
} == info } == info
def test_repr(self, s, get_storage_args): def test_repr(self, s, get_storage_args):
@ -184,61 +185,56 @@ class StorageTests:
def test_discover(self, requires_collections, get_storage_args, get_item): def test_discover(self, requires_collections, get_storage_args, get_item):
collections = set() collections = set()
for i in range(1, 5): for i in range(1, 5):
collection = f'test{i}' collection = f"test{i}"
s = self.storage_class(**get_storage_args(collection=collection)) s = self.storage_class(**get_storage_args(collection=collection))
assert not list(s.list()) assert not list(s.list())
s.upload(get_item()) s.upload(get_item())
collections.add(s.collection) collections.add(s.collection)
actual = { actual = {
c['collection'] for c in c["collection"]
self.storage_class.discover(**get_storage_args(collection=None)) for c in self.storage_class.discover(**get_storage_args(collection=None))
} }
assert actual >= collections assert actual >= collections
def test_create_collection(self, requires_collections, get_storage_args, def test_create_collection(self, requires_collections, get_storage_args, get_item):
get_item): if getattr(self, "dav_server", "") in ("icloud", "fastmail", "davical"):
if getattr(self, 'dav_server', '') in \ pytest.skip("Manual cleanup would be necessary.")
('icloud', 'fastmail', 'davical'): if getattr(self, "dav_server", "") == "radicale":
pytest.skip('Manual cleanup would be necessary.')
if getattr(self, 'dav_server', '') == "radicale":
pytest.skip("Radicale does not support collection creation") pytest.skip("Radicale does not support collection creation")
args = get_storage_args(collection=None) args = get_storage_args(collection=None)
args['collection'] = 'test' args["collection"] = "test"
s = self.storage_class( s = self.storage_class(**self.storage_class.create_collection(**args))
**self.storage_class.create_collection(**args)
)
href = s.upload(get_item())[0] href = s.upload(get_item())[0]
assert href in {href for href, etag in s.list()} assert href in {href for href, etag in s.list()}
def test_discover_collection_arg(self, requires_collections, def test_discover_collection_arg(self, requires_collections, get_storage_args):
get_storage_args): args = get_storage_args(collection="test2")
args = get_storage_args(collection='test2')
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
list(self.storage_class.discover(**args)) list(self.storage_class.discover(**args))
assert 'collection argument must not be given' in str(excinfo.value) assert "collection argument must not be given" in str(excinfo.value)
def test_collection_arg(self, get_storage_args): def test_collection_arg(self, get_storage_args):
if self.storage_class.storage_name.startswith('etesync'): if self.storage_class.storage_name.startswith("etesync"):
pytest.skip('etesync uses UUIDs.') pytest.skip("etesync uses UUIDs.")
if self.supports_collections: if self.supports_collections:
s = self.storage_class(**get_storage_args(collection='test2')) s = self.storage_class(**get_storage_args(collection="test2"))
# Can't do stronger assertion because of radicale, which needs a # Can't do stronger assertion because of radicale, which needs a
# fileextension to guess the collection type. # fileextension to guess the collection type.
assert 'test2' in s.collection assert "test2" in s.collection
else: else:
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.storage_class(collection='ayy', **get_storage_args()) self.storage_class(collection="ayy", **get_storage_args())
def test_case_sensitive_uids(self, s, get_item): def test_case_sensitive_uids(self, s, get_item):
if s.storage_name == 'filesystem': if s.storage_name == "filesystem":
pytest.skip('Behavior depends on the filesystem.') pytest.skip("Behavior depends on the filesystem.")
uid = str(uuid.uuid4()) uid = str(uuid.uuid4())
s.upload(get_item(uid=uid.upper())) s.upload(get_item(uid=uid.upper()))
@ -247,17 +243,18 @@ class StorageTests:
assert len(items) == 2 assert len(items) == 2
assert len(set(items)) == 2 assert len(set(items)) == 2
def test_specialchars(self, monkeypatch, requires_collections, def test_specialchars(
get_storage_args, get_item): self, monkeypatch, requires_collections, get_storage_args, get_item
if getattr(self, 'dav_server', '') == 'radicale': ):
pytest.skip('Radicale is fundamentally broken.') if getattr(self, "dav_server", "") == "radicale":
if getattr(self, 'dav_server', '') in ('icloud', 'fastmail'): pytest.skip("Radicale is fundamentally broken.")
pytest.skip('iCloud and FastMail reject this name.') if getattr(self, "dav_server", "") in ("icloud", "fastmail"):
pytest.skip("iCloud and FastMail reject this name.")
monkeypatch.setattr('vdirsyncer.utils.generate_href', lambda x: x) monkeypatch.setattr("vdirsyncer.utils.generate_href", lambda x: x)
uid = 'test @ foo ät bar град сатану' uid = "test @ foo ät bar град сатану"
collection = 'test @ foo ät bar' collection = "test @ foo ät bar"
s = self.storage_class(**get_storage_args(collection=collection)) s = self.storage_class(**get_storage_args(collection=collection))
item = get_item(uid=uid) item = get_item(uid=uid)
@ -268,33 +265,33 @@ class StorageTests:
assert etag2 == etag assert etag2 == etag
assert_item_equals(item2, item) assert_item_equals(item2, item)
(_, etag3), = s.list() ((_, etag3),) = s.list()
assert etag2 == etag3 assert etag2 == etag3
# etesync uses UUIDs for collection names # etesync uses UUIDs for collection names
if self.storage_class.storage_name.startswith('etesync'): if self.storage_class.storage_name.startswith("etesync"):
return return
assert collection in urlunquote(s.collection) assert collection in urlunquote(s.collection)
if self.storage_class.storage_name.endswith('dav'): if self.storage_class.storage_name.endswith("dav"):
assert urlquote(uid, '/@:') in href assert urlquote(uid, "/@:") in href
def test_metadata(self, requires_metadata, s): def test_metadata(self, requires_metadata, s):
if not getattr(self, 'dav_server', ''): if not getattr(self, "dav_server", ""):
assert not s.get_meta('color') assert not s.get_meta("color")
assert not s.get_meta('displayname') assert not s.get_meta("displayname")
try: try:
s.set_meta('color', None) s.set_meta("color", None)
assert not s.get_meta('color') assert not s.get_meta("color")
s.set_meta('color', '#ff0000') s.set_meta("color", "#ff0000")
assert s.get_meta('color') == '#ff0000' assert s.get_meta("color") == "#ff0000"
except exceptions.UnsupportedMetadataError: except exceptions.UnsupportedMetadataError:
pass pass
for x in ('hello world', 'hello wörld'): for x in ("hello world", "hello wörld"):
s.set_meta('displayname', x) s.set_meta("displayname", x)
rv = s.get_meta('displayname') rv = s.get_meta("displayname")
assert rv == x assert rv == x
assert isinstance(rv, str) assert isinstance(rv, str)
@ -307,20 +304,22 @@ class StorageTests:
], ],
) )
def test_metadata_normalization(self, requires_metadata, s, value): def test_metadata_normalization(self, requires_metadata, s, value):
x = s.get_meta('displayname') x = s.get_meta("displayname")
assert x == normalize_meta_value(x) assert x == normalize_meta_value(x)
if not getattr(self, 'dav_server', None): if not getattr(self, "dav_server", None):
# ownCloud replaces "" with "unnamed" # ownCloud replaces "" with "unnamed"
s.set_meta('displayname', value) s.set_meta("displayname", value)
assert s.get_meta('displayname') == normalize_meta_value(value) assert s.get_meta("displayname") == normalize_meta_value(value)
def test_recurring_events(self, s, item_type): def test_recurring_events(self, s, item_type):
if item_type != 'VEVENT': if item_type != "VEVENT":
pytest.skip('This storage instance doesn\'t support iCalendar.') pytest.skip("This storage instance doesn't support iCalendar.")
uid = str(uuid.uuid4()) uid = str(uuid.uuid4())
item = Item(textwrap.dedent(''' item = Item(
textwrap.dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
BEGIN:VEVENT BEGIN:VEVENT
@ -354,7 +353,11 @@ class StorageTests:
TRANSP:OPAQUE TRANSP:OPAQUE
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
'''.format(uid=uid)).strip()) """.format(
uid=uid
)
).strip()
)
href, etag = s.upload(item) href, etag = s.upload(item)

View file

@ -11,13 +11,13 @@ def slow_create_collection(request):
def delete_collections(): def delete_collections():
for s in to_delete: for s in to_delete:
s.session.request('DELETE', '') s.session.request("DELETE", "")
request.addfinalizer(delete_collections) request.addfinalizer(delete_collections)
def inner(cls, args, collection): def inner(cls, args, collection):
assert collection.startswith('test') assert collection.startswith("test")
collection += '-vdirsyncer-ci-' + str(uuid.uuid4()) collection += "-vdirsyncer-ci-" + str(uuid.uuid4())
args = cls.create_collection(collection, **args) args = cls.create_collection(collection, **args)
s = cls(**args) s = cls(**args)

View file

@ -11,26 +11,25 @@ from vdirsyncer import exceptions
from vdirsyncer.vobject import Item from vdirsyncer.vobject import Item
dav_server = os.environ.get('DAV_SERVER', 'skip') dav_server = os.environ.get("DAV_SERVER", "skip")
ServerMixin = get_server_mixin(dav_server) ServerMixin = get_server_mixin(dav_server)
class DAVStorageTests(ServerMixin, StorageTests): class DAVStorageTests(ServerMixin, StorageTests):
dav_server = dav_server dav_server = dav_server
@pytest.mark.skipif(dav_server == 'radicale', @pytest.mark.skipif(dav_server == "radicale", reason="Radicale is very tolerant.")
reason='Radicale is very tolerant.')
def test_dav_broken_item(self, s): def test_dav_broken_item(self, s):
item = Item('HAHA:YES') item = Item("HAHA:YES")
with pytest.raises((exceptions.Error, requests.exceptions.HTTPError)): with pytest.raises((exceptions.Error, requests.exceptions.HTTPError)):
s.upload(item) s.upload(item)
assert not list(s.list()) assert not list(s.list())
def test_dav_empty_get_multi_performance(self, s, monkeypatch): def test_dav_empty_get_multi_performance(self, s, monkeypatch):
def breakdown(*a, **kw): def breakdown(*a, **kw):
raise AssertionError('Expected not to be called.') raise AssertionError("Expected not to be called.")
monkeypatch.setattr('requests.sessions.Session.request', breakdown) monkeypatch.setattr("requests.sessions.Session.request", breakdown)
try: try:
assert list(s.get_multi([])) == [] assert list(s.get_multi([])) == []
@ -39,12 +38,11 @@ class DAVStorageTests(ServerMixin, StorageTests):
monkeypatch.undo() monkeypatch.undo()
def test_dav_unicode_href(self, s, get_item, monkeypatch): def test_dav_unicode_href(self, s, get_item, monkeypatch):
if self.dav_server == 'radicale': if self.dav_server == "radicale":
pytest.skip('Radicale is unable to deal with unicode hrefs') pytest.skip("Radicale is unable to deal with unicode hrefs")
monkeypatch.setattr(s, '_get_href', monkeypatch.setattr(s, "_get_href", lambda item: item.ident + s.fileext)
lambda item: item.ident + s.fileext) item = get_item(uid="град сатану" + str(uuid.uuid4()))
item = get_item(uid='град сатану' + str(uuid.uuid4()))
href, etag = s.upload(item) href, etag = s.upload(item)
item2, etag2 = s.get(href) item2, etag2 = s.get(href)
assert_item_equals(item, item2) assert_item_equals(item, item2)

View file

@ -17,7 +17,7 @@ from vdirsyncer.storage.dav import CalDAVStorage
class TestCalDAVStorage(DAVStorageTests): class TestCalDAVStorage(DAVStorageTests):
storage_class = CalDAVStorage storage_class = CalDAVStorage
@pytest.fixture(params=['VTODO', 'VEVENT']) @pytest.fixture(params=["VTODO", "VEVENT"])
def item_type(self, request): def item_type(self, request):
return request.param return request.param
@ -32,15 +32,19 @@ class TestCalDAVStorage(DAVStorageTests):
# The `arg` param is not named `item_types` because that would hit # The `arg` param is not named `item_types` because that would hit
# https://bitbucket.org/pytest-dev/pytest/issue/745/ # https://bitbucket.org/pytest-dev/pytest/issue/745/
@pytest.mark.parametrize('arg,calls_num', [ @pytest.mark.parametrize(
(('VTODO',), 1), "arg,calls_num",
(('VEVENT',), 1), [
(('VTODO', 'VEVENT'), 2), (("VTODO",), 1),
(('VTODO', 'VEVENT', 'VJOURNAL'), 3), (("VEVENT",), 1),
((), 1) (("VTODO", "VEVENT"), 2),
]) (("VTODO", "VEVENT", "VJOURNAL"), 3),
def test_item_types_performance(self, get_storage_args, arg, calls_num, ((), 1),
monkeypatch): ],
)
def test_item_types_performance(
self, get_storage_args, arg, calls_num, monkeypatch
):
s = self.storage_class(item_types=arg, **get_storage_args()) s = self.storage_class(item_types=arg, **get_storage_args())
old_parse = s._parse_prop_responses old_parse = s._parse_prop_responses
calls = [] calls = []
@ -49,19 +53,23 @@ class TestCalDAVStorage(DAVStorageTests):
calls.append(None) calls.append(None)
return old_parse(*a, **kw) return old_parse(*a, **kw)
monkeypatch.setattr(s, '_parse_prop_responses', new_parse) monkeypatch.setattr(s, "_parse_prop_responses", new_parse)
list(s.list()) list(s.list())
assert len(calls) == calls_num assert len(calls) == calls_num
@pytest.mark.xfail(dav_server == 'radicale', @pytest.mark.xfail(
reason='Radicale doesn\'t support timeranges.') dav_server == "radicale", reason="Radicale doesn't support timeranges."
)
def test_timerange_correctness(self, get_storage_args): def test_timerange_correctness(self, get_storage_args):
start_date = datetime.datetime(2013, 9, 10) start_date = datetime.datetime(2013, 9, 10)
end_date = datetime.datetime(2013, 9, 13) end_date = datetime.datetime(2013, 9, 13)
s = self.storage_class(start_date=start_date, end_date=end_date, s = self.storage_class(
**get_storage_args()) start_date=start_date, end_date=end_date, **get_storage_args()
)
too_old_item = format_item(dedent(''' too_old_item = format_item(
dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:-//hacksw/handcal//NONSGML v1.0//EN PRODID:-//hacksw/handcal//NONSGML v1.0//EN
@ -73,9 +81,13 @@ class TestCalDAVStorage(DAVStorageTests):
UID:{r} UID:{r}
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
''').strip()) """
).strip()
)
too_new_item = format_item(dedent(''' too_new_item = format_item(
dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:-//hacksw/handcal//NONSGML v1.0//EN PRODID:-//hacksw/handcal//NONSGML v1.0//EN
@ -87,9 +99,13 @@ class TestCalDAVStorage(DAVStorageTests):
UID:{r} UID:{r}
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
''').strip()) """
).strip()
)
good_item = format_item(dedent(''' good_item = format_item(
dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:-//hacksw/handcal//NONSGML v1.0//EN PRODID:-//hacksw/handcal//NONSGML v1.0//EN
@ -101,13 +117,15 @@ class TestCalDAVStorage(DAVStorageTests):
UID:{r} UID:{r}
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
''').strip()) """
).strip()
)
s.upload(too_old_item) s.upload(too_old_item)
s.upload(too_new_item) s.upload(too_new_item)
expected_href, _ = s.upload(good_item) expected_href, _ = s.upload(good_item)
(actual_href, _), = s.list() ((actual_href, _),) = s.list()
assert actual_href == expected_href assert actual_href == expected_href
def test_invalid_resource(self, monkeypatch, get_storage_args): def test_invalid_resource(self, monkeypatch, get_storage_args):
@ -115,37 +133,37 @@ class TestCalDAVStorage(DAVStorageTests):
args = get_storage_args(collection=None) args = get_storage_args(collection=None)
def request(session, method, url, **kwargs): def request(session, method, url, **kwargs):
assert url == args['url'] assert url == args["url"]
calls.append(None) calls.append(None)
r = requests.Response() r = requests.Response()
r.status_code = 200 r.status_code = 200
r._content = b'Hello World.' r._content = b"Hello World."
return r return r
monkeypatch.setattr('requests.sessions.Session.request', request) monkeypatch.setattr("requests.sessions.Session.request", request)
with pytest.raises(ValueError): with pytest.raises(ValueError):
s = self.storage_class(**args) s = self.storage_class(**args)
list(s.list()) list(s.list())
assert len(calls) == 1 assert len(calls) == 1
@pytest.mark.skipif(dav_server == 'icloud', @pytest.mark.skipif(dav_server == "icloud", reason="iCloud only accepts VEVENT")
reason='iCloud only accepts VEVENT') @pytest.mark.skipif(
@pytest.mark.skipif(dav_server == 'fastmail', dav_server == "fastmail", reason="Fastmail has non-standard hadling of VTODOs."
reason='Fastmail has non-standard hadling of VTODOs.') )
def test_item_types_general(self, s): def test_item_types_general(self, s):
event = s.upload(format_item(EVENT_TEMPLATE))[0] event = s.upload(format_item(EVENT_TEMPLATE))[0]
task = s.upload(format_item(TASK_TEMPLATE))[0] task = s.upload(format_item(TASK_TEMPLATE))[0]
s.item_types = ('VTODO', 'VEVENT') s.item_types = ("VTODO", "VEVENT")
def hrefs(): def hrefs():
return {href for href, etag in s.list()} return {href for href, etag in s.list()}
assert hrefs() == {event, task} assert hrefs() == {event, task}
s.item_types = ('VTODO',) s.item_types = ("VTODO",)
assert hrefs() == {task} assert hrefs() == {task}
s.item_types = ('VEVENT',) s.item_types = ("VEVENT",)
assert hrefs() == {event} assert hrefs() == {event}
s.item_types = () s.item_types = ()
assert hrefs() == {event, task} assert hrefs() == {event, task}

View file

@ -7,6 +7,6 @@ from vdirsyncer.storage.dav import CardDAVStorage
class TestCardDAVStorage(DAVStorageTests): class TestCardDAVStorage(DAVStorageTests):
storage_class = CardDAVStorage storage_class = CardDAVStorage
@pytest.fixture(params=['VCARD']) @pytest.fixture(params=["VCARD"])
def item_type(self, request): def item_type(self, request):
return request.param return request.param

View file

@ -6,7 +6,8 @@ from vdirsyncer.storage.dav import _parse_xml
def test_xml_utilities(): def test_xml_utilities():
x = _parse_xml(b'''<?xml version="1.0" encoding="UTF-8" ?> x = _parse_xml(
b"""<?xml version="1.0" encoding="UTF-8" ?>
<multistatus xmlns="DAV:"> <multistatus xmlns="DAV:">
<response> <response>
<propstat> <propstat>
@ -24,19 +25,22 @@ def test_xml_utilities():
</propstat> </propstat>
</response> </response>
</multistatus> </multistatus>
''') """
)
response = x.find('{DAV:}response') response = x.find("{DAV:}response")
props = _merge_xml(response.findall('{DAV:}propstat/{DAV:}prop')) props = _merge_xml(response.findall("{DAV:}propstat/{DAV:}prop"))
assert props.find('{DAV:}resourcetype/{DAV:}collection') is not None assert props.find("{DAV:}resourcetype/{DAV:}collection") is not None
assert props.find('{DAV:}getcontenttype') is not None assert props.find("{DAV:}getcontenttype") is not None
@pytest.mark.parametrize('char', range(32)) @pytest.mark.parametrize("char", range(32))
def test_xml_specialchars(char): def test_xml_specialchars(char):
x = _parse_xml('<?xml version="1.0" encoding="UTF-8" ?>' x = _parse_xml(
'<foo>ye{}s\r\n' '<?xml version="1.0" encoding="UTF-8" ?>'
'hello</foo>'.format(chr(char)).encode('ascii')) "<foo>ye{}s\r\n"
"hello</foo>".format(chr(char)).encode("ascii")
)
if char in _BAD_XML_CHARS: if char in _BAD_XML_CHARS:
assert x.text == 'yes\nhello' assert x.text == "yes\nhello"

View file

@ -19,7 +19,7 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# See https://docs.djangoproject.com/en/1.10/howto/deployment/checklist/ # See https://docs.djangoproject.com/en/1.10/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret! # SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'd7r(p-9=$3a@bbt%*+$p@4)cej13nzd0gmnt8+m0bitb=-umj#' SECRET_KEY = "d7r(p-9=$3a@bbt%*+$p@4)cej13nzd0gmnt8+m0bitb=-umj#"
# SECURITY WARNING: don't run with debug turned on in production! # SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True DEBUG = True
@ -30,56 +30,55 @@ ALLOWED_HOSTS = []
# Application definition # Application definition
INSTALLED_APPS = [ INSTALLED_APPS = [
'django.contrib.admin', "django.contrib.admin",
'django.contrib.auth', "django.contrib.auth",
'django.contrib.contenttypes', "django.contrib.contenttypes",
'django.contrib.sessions', "django.contrib.sessions",
'django.contrib.messages', "django.contrib.messages",
'django.contrib.staticfiles', "django.contrib.staticfiles",
'rest_framework', "rest_framework",
'rest_framework.authtoken', "rest_framework.authtoken",
'journal.apps.JournalConfig', "journal.apps.JournalConfig",
] ]
MIDDLEWARE = [ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware', "django.middleware.security.SecurityMiddleware",
'django.contrib.sessions.middleware.SessionMiddleware', "django.contrib.sessions.middleware.SessionMiddleware",
'django.middleware.common.CommonMiddleware', "django.middleware.common.CommonMiddleware",
'django.middleware.csrf.CsrfViewMiddleware', "django.middleware.csrf.CsrfViewMiddleware",
'django.contrib.auth.middleware.AuthenticationMiddleware', "django.contrib.auth.middleware.AuthenticationMiddleware",
'django.contrib.messages.middleware.MessageMiddleware', "django.contrib.messages.middleware.MessageMiddleware",
'django.middleware.clickjacking.XFrameOptionsMiddleware', "django.middleware.clickjacking.XFrameOptionsMiddleware",
] ]
ROOT_URLCONF = 'etesync_server.urls' ROOT_URLCONF = "etesync_server.urls"
TEMPLATES = [ TEMPLATES = [
{ {
'BACKEND': 'django.template.backends.django.DjangoTemplates', "BACKEND": "django.template.backends.django.DjangoTemplates",
'DIRS': [], "DIRS": [],
'APP_DIRS': True, "APP_DIRS": True,
'OPTIONS': { "OPTIONS": {
'context_processors': [ "context_processors": [
'django.template.context_processors.debug', "django.template.context_processors.debug",
'django.template.context_processors.request', "django.template.context_processors.request",
'django.contrib.auth.context_processors.auth', "django.contrib.auth.context_processors.auth",
'django.contrib.messages.context_processors.messages', "django.contrib.messages.context_processors.messages",
], ],
}, },
}, },
] ]
WSGI_APPLICATION = 'etesync_server.wsgi.application' WSGI_APPLICATION = "etesync_server.wsgi.application"
# Database # Database
# https://docs.djangoproject.com/en/1.10/ref/settings/#databases # https://docs.djangoproject.com/en/1.10/ref/settings/#databases
DATABASES = { DATABASES = {
'default': { "default": {
'ENGINE': 'django.db.backends.sqlite3', "ENGINE": "django.db.backends.sqlite3",
'NAME': os.environ.get('ETESYNC_DB_PATH', "NAME": os.environ.get("ETESYNC_DB_PATH", os.path.join(BASE_DIR, "db.sqlite3")),
os.path.join(BASE_DIR, 'db.sqlite3')),
} }
} }
@ -89,16 +88,16 @@ DATABASES = {
AUTH_PASSWORD_VALIDATORS = [ AUTH_PASSWORD_VALIDATORS = [
{ {
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', # noqa "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", # noqa
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', # noqa "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", # noqa
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', # noqa "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", # noqa
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', # noqa "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", # noqa
}, },
] ]
@ -106,9 +105,9 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/1.10/topics/i18n/ # https://docs.djangoproject.com/en/1.10/topics/i18n/
LANGUAGE_CODE = 'en-us' LANGUAGE_CODE = "en-us"
TIME_ZONE = 'UTC' TIME_ZONE = "UTC"
USE_I18N = True USE_I18N = True
@ -120,4 +119,4 @@ USE_TZ = True
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.10/howto/static-files/ # https://docs.djangoproject.com/en/1.10/howto/static-files/
STATIC_URL = '/static/' STATIC_URL = "/static/"

View file

@ -19,22 +19,19 @@ from journal import views
from rest_framework_nested import routers from rest_framework_nested import routers
router = routers.DefaultRouter() router = routers.DefaultRouter()
router.register(r'journals', views.JournalViewSet) router.register(r"journals", views.JournalViewSet)
router.register(r'journal/(?P<journal_uid>[^/]+)', views.EntryViewSet) router.register(r"journal/(?P<journal_uid>[^/]+)", views.EntryViewSet)
router.register(r'user', views.UserInfoViewSet) router.register(r"user", views.UserInfoViewSet)
journals_router = routers.NestedSimpleRouter(router, r'journals', journals_router = routers.NestedSimpleRouter(router, r"journals", lookup="journal")
lookup='journal') journals_router.register(r"members", views.MembersViewSet, base_name="journal-members")
journals_router.register(r'members', views.MembersViewSet, journals_router.register(r"entries", views.EntryViewSet, base_name="journal-entries")
base_name='journal-members')
journals_router.register(r'entries', views.EntryViewSet,
base_name='journal-entries')
urlpatterns = [ urlpatterns = [
url(r'^api/v1/', include(router.urls)), url(r"^api/v1/", include(router.urls)),
url(r'^api/v1/', include(journals_router.urls)), url(r"^api/v1/", include(journals_router.urls)),
] ]
# Adding this just for testing, this shouldn't be here normally # Adding this just for testing, this shouldn't be here normally
urlpatterns += url(r'^reset/$', views.reset, name='reset_debug'), urlpatterns += (url(r"^reset/$", views.reset, name="reset_debug"),)

View file

@ -10,24 +10,23 @@ from vdirsyncer.storage.etesync import EtesyncCalendars
from vdirsyncer.storage.etesync import EtesyncContacts from vdirsyncer.storage.etesync import EtesyncContacts
pytestmark = pytest.mark.skipif(os.getenv('ETESYNC_TESTS', '') != 'true', pytestmark = pytest.mark.skipif(
reason='etesync tests disabled') os.getenv("ETESYNC_TESTS", "") != "true", reason="etesync tests disabled"
)
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def etesync_app(tmpdir_factory): def etesync_app(tmpdir_factory):
sys.path.insert(0, os.path.join(os.path.dirname(__file__), sys.path.insert(0, os.path.join(os.path.dirname(__file__), "etesync_server"))
'etesync_server'))
db = tmpdir_factory.mktemp('etesync').join('etesync.sqlite') db = tmpdir_factory.mktemp("etesync").join("etesync.sqlite")
shutil.copy( shutil.copy(
os.path.join(os.path.dirname(__file__), 'etesync_server', os.path.join(os.path.dirname(__file__), "etesync_server", "db.sqlite3"), str(db)
'db.sqlite3'),
str(db)
) )
os.environ['ETESYNC_DB_PATH'] = str(db) os.environ["ETESYNC_DB_PATH"] = str(db)
from etesync_server.wsgi import application from etesync_server.wsgi import application
return application return application
@ -39,44 +38,44 @@ class EtesyncTests(StorageTests):
def get_storage_args(self, request, get_item, tmpdir, etesync_app): def get_storage_args(self, request, get_item, tmpdir, etesync_app):
import wsgi_intercept import wsgi_intercept
import wsgi_intercept.requests_intercept import wsgi_intercept.requests_intercept
wsgi_intercept.requests_intercept.install() wsgi_intercept.requests_intercept.install()
wsgi_intercept.add_wsgi_intercept('127.0.0.1', 8000, wsgi_intercept.add_wsgi_intercept("127.0.0.1", 8000, lambda: etesync_app)
lambda: etesync_app)
def teardown(): def teardown():
wsgi_intercept.remove_wsgi_intercept('127.0.0.1', 8000) wsgi_intercept.remove_wsgi_intercept("127.0.0.1", 8000)
wsgi_intercept.requests_intercept.uninstall() wsgi_intercept.requests_intercept.uninstall()
request.addfinalizer(teardown) request.addfinalizer(teardown)
with open(os.path.join(os.path.dirname(__file__), with open(
'test@localhost/auth_token')) as f: os.path.join(os.path.dirname(__file__), "test@localhost/auth_token")
) as f:
token = f.read().strip() token = f.read().strip()
headers = {'Authorization': 'Token ' + token} headers = {"Authorization": "Token " + token}
r = requests.post('http://127.0.0.1:8000/reset/', headers=headers, r = requests.post(
allow_redirects=False) "http://127.0.0.1:8000/reset/", headers=headers, allow_redirects=False
)
assert r.status_code == 200 assert r.status_code == 200
def inner(collection='test'): def inner(collection="test"):
rv = { rv = {
'email': 'test@localhost', "email": "test@localhost",
'db_path': str(tmpdir.join('etesync.db')), "db_path": str(tmpdir.join("etesync.db")),
'secrets_dir': os.path.dirname(__file__), "secrets_dir": os.path.dirname(__file__),
'server_url': 'http://127.0.0.1:8000/' "server_url": "http://127.0.0.1:8000/",
} }
if collection is not None: if collection is not None:
rv = self.storage_class.create_collection( rv = self.storage_class.create_collection(collection=collection, **rv)
collection=collection,
**rv
)
return rv return rv
return inner return inner
class TestContacts(EtesyncTests): class TestContacts(EtesyncTests):
storage_class = EtesyncContacts storage_class = EtesyncContacts
@pytest.fixture(params=['VCARD']) @pytest.fixture(params=["VCARD"])
def item_type(self, request): def item_type(self, request):
return request.param return request.param
@ -84,6 +83,6 @@ class TestContacts(EtesyncTests):
class TestCalendars(EtesyncTests): class TestCalendars(EtesyncTests):
storage_class = EtesyncCalendars storage_class = EtesyncCalendars
@pytest.fixture(params=['VEVENT']) @pytest.fixture(params=["VEVENT"])
def item_type(self, request): def item_type(self, request):
return request.param return request.param

View file

@ -12,10 +12,10 @@ class ServerMixin:
"password": "baikal", "password": "baikal",
} }
if self.storage_class.fileext == '.vcf': if self.storage_class.fileext == ".vcf":
args['url'] = base_url + "card.php/" args["url"] = base_url + "card.php/"
else: else:
args['url'] = base_url + "cal.php/" args["url"] = base_url + "cal.php/"
if collection is not None: if collection is not None:
args = slow_create_collection(self.storage_class, args, collection) args = slow_create_collection(self.storage_class, args, collection)

View file

@ -6,43 +6,42 @@ import pytest
try: try:
caldav_args = { caldav_args = {
# Those credentials are configured through the Travis UI # Those credentials are configured through the Travis UI
'username': os.environ['DAVICAL_USERNAME'].strip(), "username": os.environ["DAVICAL_USERNAME"].strip(),
'password': os.environ['DAVICAL_PASSWORD'].strip(), "password": os.environ["DAVICAL_PASSWORD"].strip(),
'url': 'https://brutus.lostpackets.de/davical-test/caldav.php/', "url": "https://brutus.lostpackets.de/davical-test/caldav.php/",
} }
except KeyError as e: except KeyError as e:
pytestmark = pytest.mark.skip('Missing envkey: {}'.format(str(e))) pytestmark = pytest.mark.skip("Missing envkey: {}".format(str(e)))
@pytest.mark.flaky(reruns=5) @pytest.mark.flaky(reruns=5)
class ServerMixin: class ServerMixin:
@pytest.fixture @pytest.fixture
def davical_args(self): def davical_args(self):
if self.storage_class.fileext == '.ics': if self.storage_class.fileext == ".ics":
return dict(caldav_args) return dict(caldav_args)
elif self.storage_class.fileext == '.vcf': elif self.storage_class.fileext == ".vcf":
pytest.skip('No carddav') pytest.skip("No carddav")
else: else:
raise RuntimeError() raise RuntimeError()
@pytest.fixture @pytest.fixture
def get_storage_args(self, davical_args, request): def get_storage_args(self, davical_args, request):
def inner(collection='test'): def inner(collection="test"):
if collection is None: if collection is None:
return davical_args return davical_args
assert collection.startswith('test') assert collection.startswith("test")
for _ in range(4): for _ in range(4):
args = self.storage_class.create_collection( args = self.storage_class.create_collection(
collection + str(uuid.uuid4()), collection + str(uuid.uuid4()), **davical_args
**davical_args
) )
s = self.storage_class(**args) s = self.storage_class(**args)
if not list(s.list()): if not list(s.list()):
request.addfinalizer( request.addfinalizer(lambda: s.session.request("DELETE", ""))
lambda: s.session.request('DELETE', ''))
return args return args
raise RuntimeError('Failed to find free collection.') raise RuntimeError("Failed to find free collection.")
return inner return inner

View file

@ -4,29 +4,28 @@ import pytest
class ServerMixin: class ServerMixin:
@pytest.fixture @pytest.fixture
def get_storage_args(self, item_type, slow_create_collection): def get_storage_args(self, item_type, slow_create_collection):
if item_type != 'VEVENT': if item_type != "VEVENT":
# iCloud collections can either be calendars or task lists. # iCloud collections can either be calendars or task lists.
# See https://github.com/pimutils/vdirsyncer/pull/593#issuecomment-285941615 # noqa # See https://github.com/pimutils/vdirsyncer/pull/593#issuecomment-285941615 # noqa
pytest.skip('iCloud doesn\'t support anything else than VEVENT') pytest.skip("iCloud doesn't support anything else than VEVENT")
def inner(collection='test'): def inner(collection="test"):
args = { args = {
'username': os.environ['ICLOUD_USERNAME'], "username": os.environ["ICLOUD_USERNAME"],
'password': os.environ['ICLOUD_PASSWORD'] "password": os.environ["ICLOUD_PASSWORD"],
} }
if self.storage_class.fileext == '.ics': if self.storage_class.fileext == ".ics":
args['url'] = 'https://caldav.icloud.com/' args["url"] = "https://caldav.icloud.com/"
elif self.storage_class.fileext == '.vcf': elif self.storage_class.fileext == ".vcf":
args['url'] = 'https://contacts.icloud.com/' args["url"] = "https://contacts.icloud.com/"
else: else:
raise RuntimeError() raise RuntimeError()
if collection is not None: if collection is not None:
args = slow_create_collection(self.storage_class, args, args = slow_create_collection(self.storage_class, args, collection)
collection)
return args return args
return inner return inner

View file

@ -7,17 +7,17 @@ import pytest
import requests import requests
testserver_repo = os.path.dirname(__file__) testserver_repo = os.path.dirname(__file__)
make_sh = os.path.abspath(os.path.join(testserver_repo, 'make.sh')) make_sh = os.path.abspath(os.path.join(testserver_repo, "make.sh"))
def wait(): def wait():
for i in range(100): for i in range(100):
try: try:
requests.get('http://127.0.0.1:6767/', verify=False) requests.get("http://127.0.0.1:6767/", verify=False)
except Exception as e: except Exception as e:
# Don't know exact exception class, don't care. # Don't know exact exception class, don't care.
# Also, https://github.com/kennethreitz/requests/issues/2192 # Also, https://github.com/kennethreitz/requests/issues/2192
if 'connection refused' not in str(e).lower(): if "connection refused" not in str(e).lower():
raise raise
time.sleep(2 ** i) time.sleep(2 ** i)
else: else:
@ -26,47 +26,54 @@ def wait():
class ServerMixin: class ServerMixin:
@pytest.fixture(scope='session') @pytest.fixture(scope="session")
def setup_mysteryshack_server(self, xprocess): def setup_mysteryshack_server(self, xprocess):
def preparefunc(cwd): def preparefunc(cwd):
return wait, ['sh', make_sh, 'testserver'] return wait, ["sh", make_sh, "testserver"]
subprocess.check_call(['sh', make_sh, 'testserver-config']) subprocess.check_call(["sh", make_sh, "testserver-config"])
xprocess.ensure('mysteryshack_server', preparefunc) xprocess.ensure("mysteryshack_server", preparefunc)
return subprocess.check_output([ return (
os.path.join( subprocess.check_output(
testserver_repo, [
'mysteryshack/target/debug/mysteryshack' os.path.join(
), testserver_repo, "mysteryshack/target/debug/mysteryshack"
'-c', '/tmp/mysteryshack/config', ),
'user', "-c",
'authorize', "/tmp/mysteryshack/config",
'testuser', "user",
'https://example.com', "authorize",
self.storage_class.scope + ':rw' "testuser",
]).strip().decode() "https://example.com",
self.storage_class.scope + ":rw",
]
)
.strip()
.decode()
)
@pytest.fixture @pytest.fixture
def get_storage_args(self, monkeypatch, setup_mysteryshack_server): def get_storage_args(self, monkeypatch, setup_mysteryshack_server):
from requests import Session from requests import Session
monkeypatch.setitem(os.environ, 'OAUTHLIB_INSECURE_TRANSPORT', 'true') monkeypatch.setitem(os.environ, "OAUTHLIB_INSECURE_TRANSPORT", "true")
old_request = Session.request old_request = Session.request
def request(self, method, url, **kw): def request(self, method, url, **kw):
url = url.replace('https://', 'http://') url = url.replace("https://", "http://")
return old_request(self, method, url, **kw) return old_request(self, method, url, **kw)
monkeypatch.setattr(Session, 'request', request) monkeypatch.setattr(Session, "request", request)
shutil.rmtree('/tmp/mysteryshack/testuser/data', ignore_errors=True) shutil.rmtree("/tmp/mysteryshack/testuser/data", ignore_errors=True)
shutil.rmtree('/tmp/mysteryshack/testuser/meta', ignore_errors=True) shutil.rmtree("/tmp/mysteryshack/testuser/meta", ignore_errors=True)
def inner(**kw): def inner(**kw):
kw['account'] = 'testuser@127.0.0.1:6767' kw["account"] = "testuser@127.0.0.1:6767"
kw['access_token'] = setup_mysteryshack_server kw["access_token"] = setup_mysteryshack_server
if self.storage_class.fileext == '.ics': if self.storage_class.fileext == ".ics":
kw.setdefault('collection', 'test') kw.setdefault("collection", "test")
return kw return kw
return inner return inner

View file

@ -2,7 +2,6 @@ import pytest
class ServerMixin: class ServerMixin:
@pytest.fixture @pytest.fixture
def get_storage_args(self): def get_storage_args(self):
pytest.skip('DAV tests disabled.') pytest.skip("DAV tests disabled.")

View file

@ -12,72 +12,74 @@ class TestFilesystemStorage(StorageTests):
@pytest.fixture @pytest.fixture
def get_storage_args(self, tmpdir): def get_storage_args(self, tmpdir):
def inner(collection='test'): def inner(collection="test"):
rv = {'path': str(tmpdir), 'fileext': '.txt', 'collection': rv = {"path": str(tmpdir), "fileext": ".txt", "collection": collection}
collection}
if collection is not None: if collection is not None:
rv = self.storage_class.create_collection(**rv) rv = self.storage_class.create_collection(**rv)
return rv return rv
return inner return inner
def test_is_not_directory(self, tmpdir): def test_is_not_directory(self, tmpdir):
with pytest.raises(OSError): with pytest.raises(OSError):
f = tmpdir.join('hue') f = tmpdir.join("hue")
f.write('stub') f.write("stub")
self.storage_class(str(tmpdir) + '/hue', '.txt') self.storage_class(str(tmpdir) + "/hue", ".txt")
def test_broken_data(self, tmpdir): def test_broken_data(self, tmpdir):
s = self.storage_class(str(tmpdir), '.txt') s = self.storage_class(str(tmpdir), ".txt")
class BrokenItem: class BrokenItem:
raw = 'Ц, Ш, Л, ж, Д, З, Ю'.encode() raw = "Ц, Ш, Л, ж, Д, З, Ю".encode()
uid = 'jeezus' uid = "jeezus"
ident = uid ident = uid
with pytest.raises(TypeError): with pytest.raises(TypeError):
s.upload(BrokenItem) s.upload(BrokenItem)
assert not tmpdir.listdir() assert not tmpdir.listdir()
def test_ident_with_slash(self, tmpdir): def test_ident_with_slash(self, tmpdir):
s = self.storage_class(str(tmpdir), '.txt') s = self.storage_class(str(tmpdir), ".txt")
s.upload(Item('UID:a/b/c')) s.upload(Item("UID:a/b/c"))
item_file, = tmpdir.listdir() (item_file,) = tmpdir.listdir()
assert '/' not in item_file.basename and item_file.isfile() assert "/" not in item_file.basename and item_file.isfile()
def test_too_long_uid(self, tmpdir): def test_too_long_uid(self, tmpdir):
s = self.storage_class(str(tmpdir), '.txt') s = self.storage_class(str(tmpdir), ".txt")
item = Item('UID:' + 'hue' * 600) item = Item("UID:" + "hue" * 600)
href, etag = s.upload(item) href, etag = s.upload(item)
assert item.uid not in href assert item.uid not in href
def test_post_hook_inactive(self, tmpdir, monkeypatch): def test_post_hook_inactive(self, tmpdir, monkeypatch):
def check_call_mock(*args, **kwargs): def check_call_mock(*args, **kwargs):
raise AssertionError() raise AssertionError()
monkeypatch.setattr(subprocess, 'call', check_call_mock) monkeypatch.setattr(subprocess, "call", check_call_mock)
s = self.storage_class(str(tmpdir), '.txt', post_hook=None) s = self.storage_class(str(tmpdir), ".txt", post_hook=None)
s.upload(Item('UID:a/b/c')) s.upload(Item("UID:a/b/c"))
def test_post_hook_active(self, tmpdir, monkeypatch): def test_post_hook_active(self, tmpdir, monkeypatch):
calls = [] calls = []
exe = 'foo' exe = "foo"
def check_call_mock(call, *args, **kwargs): def check_call_mock(call, *args, **kwargs):
calls.append(True) calls.append(True)
assert len(call) == 2 assert len(call) == 2
assert call[0] == exe assert call[0] == exe
monkeypatch.setattr(subprocess, 'call', check_call_mock) monkeypatch.setattr(subprocess, "call", check_call_mock)
s = self.storage_class(str(tmpdir), '.txt', post_hook=exe) s = self.storage_class(str(tmpdir), ".txt", post_hook=exe)
s.upload(Item('UID:a/b/c')) s.upload(Item("UID:a/b/c"))
assert calls assert calls
def test_ignore_git_dirs(self, tmpdir): def test_ignore_git_dirs(self, tmpdir):
tmpdir.mkdir('.git').mkdir('foo') tmpdir.mkdir(".git").mkdir("foo")
tmpdir.mkdir('a') tmpdir.mkdir("a")
tmpdir.mkdir('b') tmpdir.mkdir("b")
assert {c['collection'] for c assert {c["collection"] for c in self.storage_class.discover(str(tmpdir))} == {
in self.storage_class.discover(str(tmpdir))} == {'a', 'b'} "a",
"b",
}

View file

@ -8,42 +8,44 @@ from vdirsyncer.storage.http import prepare_auth
def test_list(monkeypatch): def test_list(monkeypatch):
collection_url = 'http://127.0.0.1/calendar/collection.ics' collection_url = "http://127.0.0.1/calendar/collection.ics"
items = [ items = [
('BEGIN:VEVENT\n' (
'SUMMARY:Eine Kurzinfo\n' "BEGIN:VEVENT\n"
'DESCRIPTION:Beschreibung des Termines\n' "SUMMARY:Eine Kurzinfo\n"
'END:VEVENT'), "DESCRIPTION:Beschreibung des Termines\n"
('BEGIN:VEVENT\n' "END:VEVENT"
'SUMMARY:Eine zweite Küèrzinfo\n' ),
'DESCRIPTION:Beschreibung des anderen Termines\n' (
'BEGIN:VALARM\n' "BEGIN:VEVENT\n"
'ACTION:AUDIO\n' "SUMMARY:Eine zweite Küèrzinfo\n"
'TRIGGER:19980403T120000\n' "DESCRIPTION:Beschreibung des anderen Termines\n"
'ATTACH;FMTTYPE=audio/basic:http://host.com/pub/ssbanner.aud\n' "BEGIN:VALARM\n"
'REPEAT:4\n' "ACTION:AUDIO\n"
'DURATION:PT1H\n' "TRIGGER:19980403T120000\n"
'END:VALARM\n' "ATTACH;FMTTYPE=audio/basic:http://host.com/pub/ssbanner.aud\n"
'END:VEVENT') "REPEAT:4\n"
"DURATION:PT1H\n"
"END:VALARM\n"
"END:VEVENT"
),
] ]
responses = [ responses = ["\n".join(["BEGIN:VCALENDAR"] + items + ["END:VCALENDAR"])] * 2
'\n'.join(['BEGIN:VCALENDAR'] + items + ['END:VCALENDAR'])
] * 2
def get(self, method, url, *a, **kw): def get(self, method, url, *a, **kw):
assert method == 'GET' assert method == "GET"
assert url == collection_url assert url == collection_url
r = Response() r = Response()
r.status_code = 200 r.status_code = 200
assert responses assert responses
r._content = responses.pop().encode('utf-8') r._content = responses.pop().encode("utf-8")
r.headers['Content-Type'] = 'text/calendar' r.headers["Content-Type"] = "text/calendar"
r.encoding = 'ISO-8859-1' r.encoding = "ISO-8859-1"
return r return r
monkeypatch.setattr('requests.sessions.Session.request', get) monkeypatch.setattr("requests.sessions.Session.request", get)
s = HttpStorage(url=collection_url) s = HttpStorage(url=collection_url)
@ -55,8 +57,9 @@ def test_list(monkeypatch):
assert etag2 == etag assert etag2 == etag
found_items[normalize_item(item)] = href found_items[normalize_item(item)] = href
expected = {normalize_item('BEGIN:VCALENDAR\n' + x + '\nEND:VCALENDAR') expected = {
for x in items} normalize_item("BEGIN:VCALENDAR\n" + x + "\nEND:VCALENDAR") for x in items
}
assert set(found_items) == expected assert set(found_items) == expected
@ -68,7 +71,7 @@ def test_list(monkeypatch):
def test_readonly_param(): def test_readonly_param():
url = 'http://example.com/' url = "http://example.com/"
with pytest.raises(ValueError): with pytest.raises(ValueError):
HttpStorage(url=url, read_only=False) HttpStorage(url=url, read_only=False)
@ -78,43 +81,43 @@ def test_readonly_param():
def test_prepare_auth(): def test_prepare_auth():
assert prepare_auth(None, '', '') is None assert prepare_auth(None, "", "") is None
assert prepare_auth(None, 'user', 'pwd') == ('user', 'pwd') assert prepare_auth(None, "user", "pwd") == ("user", "pwd")
assert prepare_auth('basic', 'user', 'pwd') == ('user', 'pwd') assert prepare_auth("basic", "user", "pwd") == ("user", "pwd")
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
assert prepare_auth('basic', '', 'pwd') assert prepare_auth("basic", "", "pwd")
assert 'you need to specify username and password' in \ assert "you need to specify username and password" in str(excinfo.value).lower()
str(excinfo.value).lower()
from requests.auth import HTTPDigestAuth from requests.auth import HTTPDigestAuth
assert isinstance(prepare_auth('digest', 'user', 'pwd'),
HTTPDigestAuth) assert isinstance(prepare_auth("digest", "user", "pwd"), HTTPDigestAuth)
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
prepare_auth('ladida', 'user', 'pwd') prepare_auth("ladida", "user", "pwd")
assert 'unknown authentication method' in str(excinfo.value).lower() assert "unknown authentication method" in str(excinfo.value).lower()
def test_prepare_auth_guess(monkeypatch): def test_prepare_auth_guess(monkeypatch):
import requests_toolbelt.auth.guess import requests_toolbelt.auth.guess
assert isinstance(prepare_auth('guess', 'user', 'pwd'), assert isinstance(
requests_toolbelt.auth.guess.GuessAuth) prepare_auth("guess", "user", "pwd"), requests_toolbelt.auth.guess.GuessAuth
)
monkeypatch.delattr(requests_toolbelt.auth.guess, 'GuessAuth') monkeypatch.delattr(requests_toolbelt.auth.guess, "GuessAuth")
with pytest.raises(UserError) as excinfo: with pytest.raises(UserError) as excinfo:
prepare_auth('guess', 'user', 'pwd') prepare_auth("guess", "user", "pwd")
assert 'requests_toolbelt is too old' in str(excinfo.value).lower() assert "requests_toolbelt is too old" in str(excinfo.value).lower()
def test_verify_false_disallowed(): def test_verify_false_disallowed():
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
HttpStorage(url='http://example.com', verify=False) HttpStorage(url="http://example.com", verify=False)
assert 'forbidden' in str(excinfo.value).lower() assert "forbidden" in str(excinfo.value).lower()
assert 'consider setting verify_fingerprint' in str(excinfo.value).lower() assert "consider setting verify_fingerprint" in str(excinfo.value).lower()

View file

@ -8,13 +8,14 @@ from vdirsyncer.storage.singlefile import SingleFileStorage
class CombinedStorage(Storage): class CombinedStorage(Storage):
'''A subclass of HttpStorage to make testing easier. It supports writes via """A subclass of HttpStorage to make testing easier. It supports writes via
SingleFileStorage.''' SingleFileStorage."""
_repr_attributes = ('url', 'path')
storage_name = 'http_and_singlefile' _repr_attributes = ("url", "path")
storage_name = "http_and_singlefile"
def __init__(self, url, path, **kwargs): def __init__(self, url, path, **kwargs):
if kwargs.get('collection', None) is not None: if kwargs.get("collection", None) is not None:
raise ValueError() raise ValueError()
super().__init__(**kwargs) super().__init__(**kwargs)
@ -48,30 +49,30 @@ class TestHttpStorage(StorageTests):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_tmpdir(self, tmpdir, monkeypatch): def setup_tmpdir(self, tmpdir, monkeypatch):
self.tmpfile = str(tmpdir.ensure('collection.txt')) self.tmpfile = str(tmpdir.ensure("collection.txt"))
def _request(method, url, *args, **kwargs): def _request(method, url, *args, **kwargs):
assert method == 'GET' assert method == "GET"
assert url == 'http://localhost:123/collection.txt' assert url == "http://localhost:123/collection.txt"
assert 'vdirsyncer' in kwargs['headers']['User-Agent'] assert "vdirsyncer" in kwargs["headers"]["User-Agent"]
r = Response() r = Response()
r.status_code = 200 r.status_code = 200
try: try:
with open(self.tmpfile, 'rb') as f: with open(self.tmpfile, "rb") as f:
r._content = f.read() r._content = f.read()
except OSError: except OSError:
r._content = b'' r._content = b""
r.headers['Content-Type'] = 'text/calendar' r.headers["Content-Type"] = "text/calendar"
r.encoding = 'utf-8' r.encoding = "utf-8"
return r return r
monkeypatch.setattr(vdirsyncer.storage.http, 'request', _request) monkeypatch.setattr(vdirsyncer.storage.http, "request", _request)
@pytest.fixture @pytest.fixture
def get_storage_args(self): def get_storage_args(self):
def inner(collection=None): def inner(collection=None):
assert collection is None assert collection is None
return {'url': 'http://localhost:123/collection.txt', return {"url": "http://localhost:123/collection.txt", "path": self.tmpfile}
'path': self.tmpfile}
return inner return inner

View file

@ -11,10 +11,10 @@ class TestSingleFileStorage(StorageTests):
@pytest.fixture @pytest.fixture
def get_storage_args(self, tmpdir): def get_storage_args(self, tmpdir):
def inner(collection='test'): def inner(collection="test"):
rv = {'path': str(tmpdir.join('%s.txt')), rv = {"path": str(tmpdir.join("%s.txt")), "collection": collection}
'collection': collection}
if collection is not None: if collection is not None:
rv = self.storage_class.create_collection(**rv) rv = self.storage_class.create_collection(**rv)
return rv return rv
return inner return inner

View file

@ -9,20 +9,24 @@ import vdirsyncer.cli as cli
class _CustomRunner: class _CustomRunner:
def __init__(self, tmpdir): def __init__(self, tmpdir):
self.tmpdir = tmpdir self.tmpdir = tmpdir
self.cfg = tmpdir.join('config') self.cfg = tmpdir.join("config")
self.runner = CliRunner() self.runner = CliRunner()
def invoke(self, args, env=None, **kwargs): def invoke(self, args, env=None, **kwargs):
env = env or {} env = env or {}
env.setdefault('VDIRSYNCER_CONFIG', str(self.cfg)) env.setdefault("VDIRSYNCER_CONFIG", str(self.cfg))
return self.runner.invoke(cli.app, args, env=env, **kwargs) return self.runner.invoke(cli.app, args, env=env, **kwargs)
def write_with_general(self, data): def write_with_general(self, data):
self.cfg.write(dedent(''' self.cfg.write(
dedent(
"""
[general] [general]
status_path = "{}/status/" status_path = "{}/status/"
''').format(str(self.tmpdir))) """
self.cfg.write(data, mode='a') ).format(str(self.tmpdir))
)
self.cfg.write(data, mode="a")
@pytest.fixture @pytest.fixture

View file

@ -15,16 +15,18 @@ invalid = object()
def read_config(tmpdir, monkeypatch): def read_config(tmpdir, monkeypatch):
def inner(cfg): def inner(cfg):
errors = [] errors = []
monkeypatch.setattr('vdirsyncer.cli.cli_logger.error', errors.append) monkeypatch.setattr("vdirsyncer.cli.cli_logger.error", errors.append)
f = io.StringIO(dedent(cfg.format(base=str(tmpdir)))) f = io.StringIO(dedent(cfg.format(base=str(tmpdir))))
rv = Config.from_fileobject(f) rv = Config.from_fileobject(f)
monkeypatch.undo() monkeypatch.undo()
return errors, rv return errors, rv
return inner return inner
def test_read_config(read_config): def test_read_config(read_config):
errors, c = read_config(''' errors, c = read_config(
"""
[general] [general]
status_path = "/tmp/status/" status_path = "/tmp/status/"
@ -42,25 +44,32 @@ def test_read_config(read_config):
[storage bob_b] [storage bob_b]
type = "carddav" type = "carddav"
''') """
)
assert c.general == {'status_path': '/tmp/status/'} assert c.general == {"status_path": "/tmp/status/"}
assert set(c.pairs) == {'bob'} assert set(c.pairs) == {"bob"}
bob = c.pairs['bob'] bob = c.pairs["bob"]
assert bob.collections is None assert bob.collections is None
assert c.storages == { assert c.storages == {
'bob_a': {'type': 'filesystem', 'path': '/tmp/contacts/', 'fileext': "bob_a": {
'.vcf', 'yesno': False, 'number': 42, "type": "filesystem",
'instance_name': 'bob_a'}, "path": "/tmp/contacts/",
'bob_b': {'type': 'carddav', 'instance_name': 'bob_b'} "fileext": ".vcf",
"yesno": False,
"number": 42,
"instance_name": "bob_a",
},
"bob_b": {"type": "carddav", "instance_name": "bob_b"},
} }
def test_missing_collections_param(read_config): def test_missing_collections_param(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[general] [general]
status_path = "/tmp/status/" status_path = "/tmp/status/"
@ -73,27 +82,31 @@ def test_missing_collections_param(read_config):
[storage bob_b] [storage bob_b]
type = "lmao" type = "lmao"
''') """
)
assert 'collections parameter missing' in str(excinfo.value) assert "collections parameter missing" in str(excinfo.value)
def test_invalid_section_type(read_config): def test_invalid_section_type(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[general] [general]
status_path = "/tmp/status/" status_path = "/tmp/status/"
[bogus] [bogus]
''') """
)
assert 'Unknown section' in str(excinfo.value) assert "Unknown section" in str(excinfo.value)
assert 'bogus' in str(excinfo.value) assert "bogus" in str(excinfo.value)
def test_missing_general_section(read_config): def test_missing_general_section(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[pair my_pair] [pair my_pair]
a = "my_a" a = "my_a"
b = "my_b" b = "my_b"
@ -108,40 +121,46 @@ def test_missing_general_section(read_config):
type = "filesystem" type = "filesystem"
path = "{base}/path_b/" path = "{base}/path_b/"
fileext = ".txt" fileext = ".txt"
''') """
)
assert 'Invalid general section.' in str(excinfo.value) assert "Invalid general section." in str(excinfo.value)
def test_wrong_general_section(read_config): def test_wrong_general_section(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[general] [general]
wrong = true wrong = true
''') """
)
assert 'Invalid general section.' in str(excinfo.value) assert "Invalid general section." in str(excinfo.value)
assert excinfo.value.problems == [ assert excinfo.value.problems == [
'general section doesn\'t take the parameters: wrong', "general section doesn't take the parameters: wrong",
'general section is missing the parameters: status_path' "general section is missing the parameters: status_path",
] ]
def test_invalid_storage_name(read_config): def test_invalid_storage_name(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[general] [general]
status_path = "{base}/status/" status_path = "{base}/status/"
[storage foo.bar] [storage foo.bar]
''') """
)
assert 'invalid characters' in str(excinfo.value).lower() assert "invalid characters" in str(excinfo.value).lower()
def test_invalid_collections_arg(read_config): def test_invalid_collections_arg(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[general] [general]
status_path = "/tmp/status/" status_path = "/tmp/status/"
@ -159,14 +178,16 @@ def test_invalid_collections_arg(read_config):
type = "filesystem" type = "filesystem"
path = "/tmp/bar/" path = "/tmp/bar/"
fileext = ".txt" fileext = ".txt"
''') """
)
assert 'Expected string' in str(excinfo.value) assert "Expected string" in str(excinfo.value)
def test_duplicate_sections(read_config): def test_duplicate_sections(read_config):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
read_config(''' read_config(
"""
[general] [general]
status_path = "/tmp/status/" status_path = "/tmp/status/"
@ -184,7 +205,8 @@ def test_duplicate_sections(read_config):
type = "filesystem" type = "filesystem"
path = "/tmp/bar/" path = "/tmp/bar/"
fileext = ".txt" fileext = ".txt"
''') """
)
assert 'Name "foobar" already used' in str(excinfo.value) assert 'Name "foobar" already used' in str(excinfo.value)

View file

@ -8,7 +8,9 @@ from vdirsyncer.storage.base import Storage
def test_discover_command(tmpdir, runner): def test_discover_command(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
path = "{0}/foo/" path = "{0}/foo/"
@ -23,50 +25,51 @@ def test_discover_command(tmpdir, runner):
a = "foo" a = "foo"
b = "bar" b = "bar"
collections = ["from a"] collections = ["from a"]
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir("foo")
bar = tmpdir.mkdir('bar') bar = tmpdir.mkdir("bar")
for x in 'abc': for x in "abc":
foo.mkdir(x) foo.mkdir(x)
bar.mkdir(x) bar.mkdir(x)
bar.mkdir('d') bar.mkdir("d")
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
foo.mkdir('d') foo.mkdir("d")
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
lines = result.output.splitlines() lines = result.output.splitlines()
assert 'Syncing foobar/a' in lines assert "Syncing foobar/a" in lines
assert 'Syncing foobar/b' in lines assert "Syncing foobar/b" in lines
assert 'Syncing foobar/c' in lines assert "Syncing foobar/c" in lines
assert 'Syncing foobar/d' not in result.output assert "Syncing foobar/d" not in result.output
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
assert 'Syncing foobar/a' in lines assert "Syncing foobar/a" in lines
assert 'Syncing foobar/b' in lines assert "Syncing foobar/b" in lines
assert 'Syncing foobar/c' in lines assert "Syncing foobar/c" in lines
assert 'Syncing foobar/d' in result.output assert "Syncing foobar/d" in result.output
# Check for redundant data that is already in the config. This avoids # Check for redundant data that is already in the config. This avoids
# copying passwords from the config too. # copying passwords from the config too.
assert 'fileext' not in tmpdir \ assert "fileext" not in tmpdir.join("status").join("foobar.collections").read()
.join('status') \
.join('foobar.collections') \
.read()
def test_discover_different_collection_names(tmpdir, runner): def test_discover_different_collection_names(tmpdir, runner):
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir("foo")
bar = tmpdir.mkdir('bar') bar = tmpdir.mkdir("bar")
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
fileext = ".txt" fileext = ".txt"
@ -84,35 +87,39 @@ def test_discover_different_collection_names(tmpdir, runner):
["coll1", "coll_a1", "coll_b1"], ["coll1", "coll_a1", "coll_b1"],
"coll2" "coll2"
] ]
''').format(foo=str(foo), bar=str(bar))) """
).format(foo=str(foo), bar=str(bar))
)
result = runner.invoke(['discover'], input='y\n' * 6) result = runner.invoke(["discover"], input="y\n" * 6)
assert not result.exception assert not result.exception
coll_a1 = foo.join('coll_a1') coll_a1 = foo.join("coll_a1")
coll_b1 = bar.join('coll_b1') coll_b1 = bar.join("coll_b1")
assert coll_a1.exists() assert coll_a1.exists()
assert coll_b1.exists() assert coll_b1.exists()
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
foo_txt = coll_a1.join('foo.txt') foo_txt = coll_a1.join("foo.txt")
foo_txt.write('BEGIN:VCALENDAR\nUID:foo\nEND:VCALENDAR') foo_txt.write("BEGIN:VCALENDAR\nUID:foo\nEND:VCALENDAR")
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
assert foo_txt.exists() assert foo_txt.exists()
assert coll_b1.join('foo.txt').exists() assert coll_b1.join("foo.txt").exists()
def test_discover_direct_path(tmpdir, runner): def test_discover_direct_path(tmpdir, runner):
foo = tmpdir.join('foo') foo = tmpdir.join("foo")
bar = tmpdir.join('bar') bar = tmpdir.join("bar")
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
fileext = ".txt" fileext = ".txt"
@ -127,12 +134,14 @@ def test_discover_direct_path(tmpdir, runner):
a = "foo" a = "foo"
b = "bar" b = "bar"
collections = null collections = null
''').format(foo=str(foo), bar=str(bar))) """
).format(foo=str(foo), bar=str(bar))
)
result = runner.invoke(['discover'], input='y\n' * 2) result = runner.invoke(["discover"], input="y\n" * 2)
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
assert foo.exists() assert foo.exists()
@ -140,7 +149,9 @@ def test_discover_direct_path(tmpdir, runner):
def test_null_collection_with_named_collection(tmpdir, runner): def test_null_collection_with_named_collection(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -154,25 +165,29 @@ def test_null_collection_with_named_collection(tmpdir, runner):
[storage bar] [storage bar]
type = "singlefile" type = "singlefile"
path = "{base}/bar.txt" path = "{base}/bar.txt"
'''.format(base=str(tmpdir)))) """.format(
base=str(tmpdir)
)
)
)
result = runner.invoke(['discover'], input='y\n' * 2) result = runner.invoke(["discover"], input="y\n" * 2)
assert not result.exception assert not result.exception
foo = tmpdir.join('foo') foo = tmpdir.join("foo")
foobaz = foo.join('baz') foobaz = foo.join("baz")
assert foo.exists() assert foo.exists()
assert foobaz.exists() assert foobaz.exists()
bar = tmpdir.join('bar.txt') bar = tmpdir.join("bar.txt")
assert bar.exists() assert bar.exists()
foobaz.join('lol.txt').write('BEGIN:VCARD\nUID:HAHA\nEND:VCARD') foobaz.join("lol.txt").write("BEGIN:VCARD\nUID:HAHA\nEND:VCARD")
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
assert 'HAHA' in bar.read() assert "HAHA" in bar.read()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -182,23 +197,24 @@ def test_null_collection_with_named_collection(tmpdir, runner):
(True, False), (True, False),
(False, True), (False, True),
(False, False), (False, False),
] ],
) )
def test_collection_required(a_requires, b_requires, tmpdir, runner, def test_collection_required(a_requires, b_requires, tmpdir, runner, monkeypatch):
monkeypatch):
class TestStorage(Storage): class TestStorage(Storage):
storage_name = 'test' storage_name = "test"
def __init__(self, require_collection, **kw): def __init__(self, require_collection, **kw):
if require_collection: if require_collection:
assert not kw.get('collection') assert not kw.get("collection")
raise exceptions.CollectionRequired() raise exceptions.CollectionRequired()
from vdirsyncer.cli.utils import storage_names from vdirsyncer.cli.utils import storage_names
monkeypatch.setitem(storage_names._storages, 'test', TestStorage)
runner.write_with_general(dedent(''' monkeypatch.setitem(storage_names._storages, "test", TestStorage)
runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -211,11 +227,15 @@ def test_collection_required(a_requires, b_requires, tmpdir, runner,
[storage bar] [storage bar]
type = "test" type = "test"
require_collection = {b} require_collection = {b}
'''.format(a=json.dumps(a_requires), b=json.dumps(b_requires)))) """.format(
a=json.dumps(a_requires), b=json.dumps(b_requires)
)
)
)
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
if a_requires or b_requires: if a_requires or b_requires:
assert result.exception assert result.exception
assert \ assert (
'One or more storages don\'t support `collections = null`.' in \ "One or more storages don't support `collections = null`." in result.output
result.output )

View file

@ -2,7 +2,9 @@ from textwrap import dedent
def test_get_password_from_command(tmpdir, runner): def test_get_password_from_command(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -17,26 +19,30 @@ def test_get_password_from_command(tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{base}/bar/" path = "{base}/bar/"
fileext.fetch = ["prompt", "Fileext for bar"] fileext.fetch = ["prompt", "Fileext for bar"]
'''.format(base=str(tmpdir)))) """.format(
base=str(tmpdir)
)
)
)
foo = tmpdir.ensure('foo', dir=True) foo = tmpdir.ensure("foo", dir=True)
foo.ensure('a', dir=True) foo.ensure("a", dir=True)
foo.ensure('b', dir=True) foo.ensure("b", dir=True)
foo.ensure('c', dir=True) foo.ensure("c", dir=True)
bar = tmpdir.ensure('bar', dir=True) bar = tmpdir.ensure("bar", dir=True)
bar.ensure('a', dir=True) bar.ensure("a", dir=True)
bar.ensure('b', dir=True) bar.ensure("b", dir=True)
bar.ensure('c', dir=True) bar.ensure("c", dir=True)
result = runner.invoke(['discover'], input='.asdf\n') result = runner.invoke(["discover"], input=".asdf\n")
assert not result.exception assert not result.exception
status = tmpdir.join('status').join('foobar.collections').read() status = tmpdir.join("status").join("foobar.collections").read()
assert 'foo' in status assert "foo" in status
assert 'bar' in status assert "bar" in status
assert 'asdf' not in status assert "asdf" not in status
assert 'txt' not in status assert "txt" not in status
foo.join('a').join('foo.txt').write('BEGIN:VCARD\nUID:foo\nEND:VCARD') foo.join("a").join("foo.txt").write("BEGIN:VCARD\nUID:foo\nEND:VCARD")
result = runner.invoke(['sync'], input='.asdf\n') result = runner.invoke(["sync"], input=".asdf\n")
assert not result.exception assert not result.exception
assert [x.basename for x in bar.join('a').listdir()] == ['foo.asdf'] assert [x.basename for x in bar.join("a").listdir()] == ["foo.asdf"]

View file

@ -5,67 +5,72 @@ import pytest
@pytest.fixture @pytest.fixture
def storage(tmpdir, runner): def storage(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
path = "{base}/foo/" path = "{base}/foo/"
fileext = ".txt" fileext = ".txt"
''').format(base=str(tmpdir))) """
).format(base=str(tmpdir))
)
return tmpdir.mkdir('foo') return tmpdir.mkdir("foo")
@pytest.mark.parametrize('collection', [None, "foocoll"]) @pytest.mark.parametrize("collection", [None, "foocoll"])
def test_basic(storage, runner, collection): def test_basic(storage, runner, collection):
if collection is not None: if collection is not None:
storage = storage.mkdir(collection) storage = storage.mkdir(collection)
collection_arg = f'foo/{collection}' collection_arg = f"foo/{collection}"
else: else:
collection_arg = 'foo' collection_arg = "foo"
argv = ['repair', collection_arg] argv = ["repair", collection_arg]
result = runner.invoke(argv, input='y') result = runner.invoke(argv, input="y")
assert not result.exception assert not result.exception
storage.join('item.txt').write('BEGIN:VCARD\nEND:VCARD') storage.join("item.txt").write("BEGIN:VCARD\nEND:VCARD")
storage.join('toobroken.txt').write('') storage.join("toobroken.txt").write("")
result = runner.invoke(argv, input='y') result = runner.invoke(argv, input="y")
assert not result.exception assert not result.exception
assert 'No UID' in result.output assert "No UID" in result.output
assert '\'toobroken.txt\' is malformed beyond repair' \ assert "'toobroken.txt' is malformed beyond repair" in result.output
in result.output (new_fname,) = [x for x in storage.listdir() if "toobroken" not in str(x)]
new_fname, = [x for x in storage.listdir() if 'toobroken' not in str(x)] assert "UID:" in new_fname.read()
assert 'UID:' in new_fname.read()
@pytest.mark.parametrize('repair_uids', [None, True, False]) @pytest.mark.parametrize("repair_uids", [None, True, False])
def test_repair_uids(storage, runner, repair_uids): def test_repair_uids(storage, runner, repair_uids):
f = storage.join('baduid.txt') f = storage.join("baduid.txt")
orig_f = 'BEGIN:VCARD\nUID:!!!!!\nEND:VCARD' orig_f = "BEGIN:VCARD\nUID:!!!!!\nEND:VCARD"
f.write(orig_f) f.write(orig_f)
if repair_uids is None: if repair_uids is None:
opt = [] opt = []
elif repair_uids: elif repair_uids:
opt = ['--repair-unsafe-uid'] opt = ["--repair-unsafe-uid"]
else: else:
opt = ['--no-repair-unsafe-uid'] opt = ["--no-repair-unsafe-uid"]
result = runner.invoke(['repair'] + opt + ['foo'], input='y') result = runner.invoke(["repair"] + opt + ["foo"], input="y")
assert not result.exception assert not result.exception
if repair_uids: if repair_uids:
assert 'UID or href is unsafe, assigning random UID' in result.output assert "UID or href is unsafe, assigning random UID" in result.output
assert not f.exists() assert not f.exists()
new_f, = storage.listdir() (new_f,) = storage.listdir()
s = new_f.read() s = new_f.read()
assert s.startswith('BEGIN:VCARD') assert s.startswith("BEGIN:VCARD")
assert s.endswith('END:VCARD') assert s.endswith("END:VCARD")
assert s != orig_f assert s != orig_f
else: else:
assert 'UID may cause problems, add --repair-unsafe-uid to repair.' \ assert (
"UID may cause problems, add --repair-unsafe-uid to repair."
in result.output in result.output
)
assert f.read() == orig_f assert f.read() == orig_f

View file

@ -8,7 +8,9 @@ from hypothesis import settings
def test_simple_run(tmpdir, runner): def test_simple_run(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair my_pair] [pair my_pair]
a = "my_a" a = "my_a"
b = "my_b" b = "my_b"
@ -23,33 +25,37 @@ def test_simple_run(tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{0}/path_b/" path = "{0}/path_b/"
fileext = ".txt" fileext = ".txt"
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
tmpdir.mkdir('path_a') tmpdir.mkdir("path_a")
tmpdir.mkdir('path_b') tmpdir.mkdir("path_b")
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
tmpdir.join('path_a/haha.txt').write('UID:haha') tmpdir.join("path_a/haha.txt").write("UID:haha")
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert 'Copying (uploading) item haha to my_b' in result.output assert "Copying (uploading) item haha to my_b" in result.output
assert tmpdir.join('path_b/haha.txt').read() == 'UID:haha' assert tmpdir.join("path_b/haha.txt").read() == "UID:haha"
def test_sync_inexistant_pair(tmpdir, runner): def test_sync_inexistant_pair(tmpdir, runner):
runner.write_with_general("") runner.write_with_general("")
result = runner.invoke(['sync', 'foo']) result = runner.invoke(["sync", "foo"])
assert result.exception assert result.exception
assert 'pair foo does not exist.' in result.output.lower() assert "pair foo does not exist." in result.output.lower()
def test_debug_connections(tmpdir, runner): def test_debug_connections(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair my_pair] [pair my_pair]
a = "my_a" a = "my_a"
b = "my_b" b = "my_b"
@ -64,23 +70,27 @@ def test_debug_connections(tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{0}/path_b/" path = "{0}/path_b/"
fileext = ".txt" fileext = ".txt"
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
tmpdir.mkdir('path_a') tmpdir.mkdir("path_a")
tmpdir.mkdir('path_b') tmpdir.mkdir("path_b")
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['-vdebug', 'sync', '--max-workers=3']) result = runner.invoke(["-vdebug", "sync", "--max-workers=3"])
assert 'using 3 maximal workers' in result.output.lower() assert "using 3 maximal workers" in result.output.lower()
result = runner.invoke(['-vdebug', 'sync']) result = runner.invoke(["-vdebug", "sync"])
assert 'using 1 maximal workers' in result.output.lower() assert "using 1 maximal workers" in result.output.lower()
def test_empty_storage(tmpdir, runner): def test_empty_storage(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair my_pair] [pair my_pair]
a = "my_a" a = "my_a"
b = "my_b" b = "my_b"
@ -95,32 +105,35 @@ def test_empty_storage(tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{0}/path_b/" path = "{0}/path_b/"
fileext = ".txt" fileext = ".txt"
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
tmpdir.mkdir('path_a') tmpdir.mkdir("path_a")
tmpdir.mkdir('path_b') tmpdir.mkdir("path_b")
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
tmpdir.join('path_a/haha.txt').write('UID:haha') tmpdir.join("path_a/haha.txt").write("UID:haha")
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
tmpdir.join('path_b/haha.txt').remove() tmpdir.join("path_b/haha.txt").remove()
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
lines = result.output.splitlines() lines = result.output.splitlines()
assert lines[0] == 'Syncing my_pair' assert lines[0] == "Syncing my_pair"
assert lines[1].startswith('error: my_pair: ' assert lines[1].startswith(
'Storage "my_b" was completely emptied.') "error: my_pair: " 'Storage "my_b" was completely emptied.'
)
assert result.exception assert result.exception
def test_verbosity(tmpdir, runner): def test_verbosity(tmpdir, runner):
runner.write_with_general('') runner.write_with_general("")
result = runner.invoke(['--verbosity=HAHA', 'sync']) result = runner.invoke(["--verbosity=HAHA", "sync"])
assert result.exception assert result.exception
assert ( assert (
'invalid value for "--verbosity"' in result.output.lower() 'invalid value for "--verbosity"' in result.output.lower()
@ -129,13 +142,15 @@ def test_verbosity(tmpdir, runner):
def test_collections_cache_invalidation(tmpdir, runner): def test_collections_cache_invalidation(tmpdir, runner):
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir("foo")
bar = tmpdir.mkdir('bar') bar = tmpdir.mkdir("bar")
for x in 'abc': for x in "abc":
foo.mkdir(x) foo.mkdir(x)
bar.mkdir(x) bar.mkdir(x)
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
path = "{0}/foo/" path = "{0}/foo/"
@ -150,22 +165,26 @@ def test_collections_cache_invalidation(tmpdir, runner):
a = "foo" a = "foo"
b = "bar" b = "bar"
collections = ["a", "b", "c"] collections = ["a", "b", "c"]
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
foo.join('a/itemone.txt').write('UID:itemone') foo.join("a/itemone.txt").write("UID:itemone")
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
assert 'detected change in config file' not in result.output.lower() assert "detected change in config file" not in result.output.lower()
rv = bar.join('a').listdir() rv = bar.join("a").listdir()
assert len(rv) == 1 assert len(rv) == 1
assert rv[0].basename == 'itemone.txt' assert rv[0].basename == "itemone.txt"
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
path = "{0}/foo/" path = "{0}/foo/"
@ -180,32 +199,36 @@ def test_collections_cache_invalidation(tmpdir, runner):
a = "foo" a = "foo"
b = "bar" b = "bar"
collections = ["a", "b", "c"] collections = ["a", "b", "c"]
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
for entry in tmpdir.join('status').listdir(): for entry in tmpdir.join("status").listdir():
if not str(entry).endswith('.collections'): if not str(entry).endswith(".collections"):
entry.remove() entry.remove()
bar2 = tmpdir.mkdir('bar2') bar2 = tmpdir.mkdir("bar2")
for x in 'abc': for x in "abc":
bar2.mkdir(x) bar2.mkdir(x)
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert 'detected change in config file' in result.output.lower() assert "detected change in config file" in result.output.lower()
assert result.exception assert result.exception
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
rv = bar.join('a').listdir() rv = bar.join("a").listdir()
rv2 = bar2.join('a').listdir() rv2 = bar2.join("a").listdir()
assert len(rv) == len(rv2) == 1 assert len(rv) == len(rv2) == 1
assert rv[0].basename == rv2[0].basename == 'itemone.txt' assert rv[0].basename == rv2[0].basename == "itemone.txt"
def test_invalid_pairs_as_cli_arg(tmpdir, runner): def test_invalid_pairs_as_cli_arg(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[storage foo] [storage foo]
type = "filesystem" type = "filesystem"
path = "{0}/foo/" path = "{0}/foo/"
@ -220,85 +243,92 @@ def test_invalid_pairs_as_cli_arg(tmpdir, runner):
a = "foo" a = "foo"
b = "bar" b = "bar"
collections = ["a", "b", "c"] collections = ["a", "b", "c"]
''').format(str(tmpdir))) """
).format(str(tmpdir))
)
for base in ('foo', 'bar'): for base in ("foo", "bar"):
base = tmpdir.mkdir(base) base = tmpdir.mkdir(base)
for c in 'abc': for c in "abc":
base.mkdir(c) base.mkdir(c)
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync', 'foobar/d']) result = runner.invoke(["sync", "foobar/d"])
assert result.exception assert result.exception
assert 'pair foobar: collection "d" not found' in result.output.lower() assert 'pair foobar: collection "d" not found' in result.output.lower()
def test_multiple_pairs(tmpdir, runner): def test_multiple_pairs(tmpdir, runner):
def get_cfg(): def get_cfg():
for name_a, name_b in ('foo', 'bar'), ('bam', 'baz'): for name_a, name_b in ("foo", "bar"), ("bam", "baz"):
yield dedent(''' yield dedent(
"""
[pair {a}{b}] [pair {a}{b}]
a = "{a}" a = "{a}"
b = "{b}" b = "{b}"
collections = null collections = null
''').format(a=name_a, b=name_b) """
).format(a=name_a, b=name_b)
for name in name_a, name_b: for name in name_a, name_b:
yield dedent(''' yield dedent(
"""
[storage {name}] [storage {name}]
type = "filesystem" type = "filesystem"
path = "{path}" path = "{path}"
fileext = ".txt" fileext = ".txt"
''').format(name=name, path=str(tmpdir.mkdir(name))) """
).format(name=name, path=str(tmpdir.mkdir(name)))
runner.write_with_general(''.join(get_cfg())) runner.write_with_general("".join(get_cfg()))
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
assert set(result.output.splitlines()) > { assert set(result.output.splitlines()) > {
'Discovering collections for pair bambaz', "Discovering collections for pair bambaz",
'Discovering collections for pair foobar' "Discovering collections for pair foobar",
} }
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert not result.exception assert not result.exception
assert set(result.output.splitlines()) == { assert set(result.output.splitlines()) == {
'Syncing bambaz', "Syncing bambaz",
'Syncing foobar', "Syncing foobar",
} }
collections_strategy = st.sets( collections_strategy = st.sets(
st.text( st.text(
st.characters( st.characters(
blacklist_characters=set( blacklist_characters=set("./\x00"), # Invalid chars on POSIX filesystems
'./\x00' # Invalid chars on POSIX filesystems
),
# Surrogates can't be encoded to utf-8 in Python # Surrogates can't be encoded to utf-8 in Python
blacklist_categories={'Cs'} blacklist_categories={"Cs"},
), ),
min_size=1, min_size=1,
max_size=50 max_size=50,
), ),
min_size=1 min_size=1,
) )
# XXX: https://github.com/pimutils/vdirsyncer/issues/617 # XXX: https://github.com/pimutils/vdirsyncer/issues/617
@pytest.mark.skipif(sys.platform == 'darwin', @pytest.mark.skipif(sys.platform == "darwin", reason="This test inexplicably fails")
reason='This test inexplicably fails')
@pytest.mark.parametrize( @pytest.mark.parametrize(
"collections", "collections",
[ [
('persönlich',), ("persönlich",),
('a', 'A',), (
('\ufffe',), "a",
] + [ "A",
),
("\ufffe",),
]
+ [
collections_strategy.example() collections_strategy.example()
for _ in range(settings.get_profile(settings._current_profile).max_examples) for _ in range(settings.get_profile(settings._current_profile).max_examples)
] ],
) )
def test_create_collections(collections, tmpdir, runner): def test_create_collections(collections, tmpdir, runner):
# Hypothesis calls this tests in a way that fixtures are not reset, to tmpdir is the # Hypothesis calls this tests in a way that fixtures are not reset, to tmpdir is the
@ -306,7 +336,9 @@ def test_create_collections(collections, tmpdir, runner):
# This horrible hack creates a new subdirectory on each run, effectively giving us a # This horrible hack creates a new subdirectory on each run, effectively giving us a
# new tmpdir each run. # new tmpdir each run.
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -321,25 +353,27 @@ def test_create_collections(collections, tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{base}/bar/" path = "{base}/bar/"
fileext = ".txt" fileext = ".txt"
'''.format(base=str(tmpdir), colls=json.dumps(list(collections))))) """.format(
base=str(tmpdir), colls=json.dumps(list(collections))
result = runner.invoke( )
['discover'], )
input='y\n' * 2 * (len(collections) + 1)
) )
result = runner.invoke(["discover"], input="y\n" * 2 * (len(collections) + 1))
assert not result.exception, result.output assert not result.exception, result.output
result = runner.invoke( result = runner.invoke(["sync"] + ["foobar/" + x for x in collections])
['sync'] + ['foobar/' + x for x in collections]
)
assert not result.exception, result.output assert not result.exception, result.output
assert {x.basename for x in tmpdir.join('foo').listdir()} == \ assert {x.basename for x in tmpdir.join("foo").listdir()} == {
{x.basename for x in tmpdir.join('bar').listdir()} x.basename for x in tmpdir.join("bar").listdir()
}
def test_ident_conflict(tmpdir, runner): def test_ident_conflict(tmpdir, runner):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -354,35 +388,51 @@ def test_ident_conflict(tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{base}/bar/" path = "{base}/bar/"
fileext = ".txt" fileext = ".txt"
'''.format(base=str(tmpdir)))) """.format(
base=str(tmpdir)
)
)
)
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir("foo")
tmpdir.mkdir('bar') tmpdir.mkdir("bar")
foo.join('one.txt').write('UID:1') foo.join("one.txt").write("UID:1")
foo.join('two.txt').write('UID:1') foo.join("two.txt").write("UID:1")
foo.join('three.txt').write('UID:1') foo.join("three.txt").write("UID:1")
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert not result.exception assert not result.exception
result = runner.invoke(['sync']) result = runner.invoke(["sync"])
assert result.exception assert result.exception
assert ('error: foobar: Storage "foo" contains multiple items with the ' assert (
'same UID or even content') in result.output 'error: foobar: Storage "foo" contains multiple items with the '
assert sorted([ "same UID or even content"
'one.txt' in result.output, ) in result.output
'two.txt' in result.output, assert (
'three.txt' in result.output, sorted(
]) == [False, True, True] [
"one.txt" in result.output,
"two.txt" in result.output,
"three.txt" in result.output,
]
)
== [False, True, True]
)
@pytest.mark.parametrize('existing,missing', [ @pytest.mark.parametrize(
('foo', 'bar'), "existing,missing",
('bar', 'foo'), [
]) ("foo", "bar"),
("bar", "foo"),
],
)
def test_unknown_storage(tmpdir, runner, existing, missing): def test_unknown_storage(tmpdir, runner, existing, missing):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -392,35 +442,42 @@ def test_unknown_storage(tmpdir, runner, existing, missing):
type = "filesystem" type = "filesystem"
path = "{base}/{existing}/" path = "{base}/{existing}/"
fileext = ".txt" fileext = ".txt"
'''.format(base=str(tmpdir), existing=existing))) """.format(
base=str(tmpdir), existing=existing
)
)
)
tmpdir.mkdir(existing) tmpdir.mkdir(existing)
result = runner.invoke(['discover']) result = runner.invoke(["discover"])
assert result.exception assert result.exception
assert ( assert (
"Storage '{missing}' not found. " "Storage '{missing}' not found. "
"These are the configured storages: ['{existing}']" "These are the configured storages: ['{existing}']".format(
.format(missing=missing, existing=existing) missing=missing, existing=existing
)
) in result.output ) in result.output
@pytest.mark.parametrize('cmd', ['sync', 'metasync']) @pytest.mark.parametrize("cmd", ["sync", "metasync"])
def test_no_configured_pairs(tmpdir, runner, cmd): def test_no_configured_pairs(tmpdir, runner, cmd):
runner.write_with_general('') runner.write_with_general("")
result = runner.invoke([cmd]) result = runner.invoke([cmd])
assert result.output == 'critical: Nothing to do.\n' assert result.output == "critical: Nothing to do.\n"
assert result.exception.code == 5 assert result.exception.code == 5
@pytest.mark.parametrize('resolution,expect_foo,expect_bar', [ @pytest.mark.parametrize(
(['command', 'cp'], 'UID:lol\nfööcontent', 'UID:lol\nfööcontent') "resolution,expect_foo,expect_bar",
]) [(["command", "cp"], "UID:lol\nfööcontent", "UID:lol\nfööcontent")],
def test_conflict_resolution(tmpdir, runner, resolution, expect_foo, )
expect_bar): def test_conflict_resolution(tmpdir, runner, resolution, expect_foo, expect_bar):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -436,28 +493,34 @@ def test_conflict_resolution(tmpdir, runner, resolution, expect_foo,
type = "filesystem" type = "filesystem"
fileext = ".txt" fileext = ".txt"
path = "{base}/bar" path = "{base}/bar"
'''.format(base=str(tmpdir), val=json.dumps(resolution)))) """.format(
base=str(tmpdir), val=json.dumps(resolution)
)
)
)
foo = tmpdir.join('foo') foo = tmpdir.join("foo")
bar = tmpdir.join('bar') bar = tmpdir.join("bar")
fooitem = foo.join('lol.txt').ensure() fooitem = foo.join("lol.txt").ensure()
fooitem.write('UID:lol\nfööcontent') fooitem.write("UID:lol\nfööcontent")
baritem = bar.join('lol.txt').ensure() baritem = bar.join("lol.txt").ensure()
baritem.write('UID:lol\nbööcontent') baritem.write("UID:lol\nbööcontent")
r = runner.invoke(['discover']) r = runner.invoke(["discover"])
assert not r.exception assert not r.exception
r = runner.invoke(['sync']) r = runner.invoke(["sync"])
assert not r.exception assert not r.exception
assert fooitem.read() == expect_foo assert fooitem.read() == expect_foo
assert baritem.read() == expect_bar assert baritem.read() == expect_bar
@pytest.mark.parametrize('partial_sync', ['error', 'ignore', 'revert', None]) @pytest.mark.parametrize("partial_sync", ["error", "ignore", "revert", None])
def test_partial_sync(tmpdir, runner, partial_sync): def test_partial_sync(tmpdir, runner, partial_sync):
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -474,58 +537,69 @@ def test_partial_sync(tmpdir, runner, partial_sync):
read_only = true read_only = true
fileext = ".txt" fileext = ".txt"
path = "{base}/bar" path = "{base}/bar"
'''.format( """.format(
partial_sync=(f'partial_sync = "{partial_sync}"\n' partial_sync=(
if partial_sync else ''), f'partial_sync = "{partial_sync}"\n' if partial_sync else ""
base=str(tmpdir) ),
))) base=str(tmpdir),
)
)
)
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir("foo")
bar = tmpdir.mkdir('bar') bar = tmpdir.mkdir("bar")
foo.join('other.txt').write('UID:other') foo.join("other.txt").write("UID:other")
bar.join('other.txt').write('UID:other') bar.join("other.txt").write("UID:other")
baritem = bar.join('lol.txt') baritem = bar.join("lol.txt")
baritem.write('UID:lol') baritem.write("UID:lol")
r = runner.invoke(['discover']) r = runner.invoke(["discover"])
assert not r.exception assert not r.exception
r = runner.invoke(['sync']) r = runner.invoke(["sync"])
assert not r.exception assert not r.exception
fooitem = foo.join('lol.txt') fooitem = foo.join("lol.txt")
fooitem.remove() fooitem.remove()
r = runner.invoke(['sync']) r = runner.invoke(["sync"])
if partial_sync == 'error': if partial_sync == "error":
assert r.exception assert r.exception
assert 'Attempted change' in r.output assert "Attempted change" in r.output
elif partial_sync == 'ignore': elif partial_sync == "ignore":
assert baritem.exists() assert baritem.exists()
r = runner.invoke(['sync']) r = runner.invoke(["sync"])
assert not r.exception assert not r.exception
assert baritem.exists() assert baritem.exists()
else: else:
assert baritem.exists() assert baritem.exists()
r = runner.invoke(['sync']) r = runner.invoke(["sync"])
assert not r.exception assert not r.exception
assert baritem.exists() assert baritem.exists()
assert fooitem.exists() assert fooitem.exists()
def test_fetch_only_necessary_params(tmpdir, runner): def test_fetch_only_necessary_params(tmpdir, runner):
fetched_file = tmpdir.join('fetched_flag') fetched_file = tmpdir.join("fetched_flag")
fetch_script = tmpdir.join('fetch_script') fetch_script = tmpdir.join("fetch_script")
fetch_script.write(dedent(''' fetch_script.write(
dedent(
"""
set -e set -e
touch "{}" touch "{}"
echo ".txt" echo ".txt"
'''.format(str(fetched_file)))) """.format(
str(fetched_file)
)
)
)
runner.write_with_general(dedent(''' runner.write_with_general(
dedent(
"""
[pair foobar] [pair foobar]
a = "foo" a = "foo"
b = "bar" b = "bar"
@ -550,7 +624,11 @@ def test_fetch_only_necessary_params(tmpdir, runner):
type = "filesystem" type = "filesystem"
path = "{path}" path = "{path}"
fileext.fetch = ["command", "sh", "{script}"] fileext.fetch = ["command", "sh", "{script}"]
'''.format(path=str(tmpdir.mkdir('bogus')), script=str(fetch_script)))) """.format(
path=str(tmpdir.mkdir("bogus")), script=str(fetch_script)
)
)
)
def fetched(): def fetched():
try: try:
@ -559,18 +637,18 @@ def test_fetch_only_necessary_params(tmpdir, runner):
except Exception: except Exception:
return False return False
r = runner.invoke(['discover']) r = runner.invoke(["discover"])
assert not r.exception assert not r.exception
assert fetched() assert fetched()
r = runner.invoke(['sync', 'foobar']) r = runner.invoke(["sync", "foobar"])
assert not r.exception assert not r.exception
assert not fetched() assert not fetched()
r = runner.invoke(['sync']) r = runner.invoke(["sync"])
assert not r.exception assert not r.exception
assert fetched() assert fetched()
r = runner.invoke(['sync', 'bambar']) r = runner.invoke(["sync", "bambar"])
assert not r.exception assert not r.exception
assert fetched() assert fetched()

View file

@ -6,20 +6,20 @@ from vdirsyncer.cli.utils import storage_names
def test_handle_cli_error(capsys): def test_handle_cli_error(capsys):
try: try:
raise exceptions.InvalidResponse('ayy lmao') raise exceptions.InvalidResponse("ayy lmao")
except BaseException: except BaseException:
handle_cli_error() handle_cli_error()
out, err = capsys.readouterr() out, err = capsys.readouterr()
assert 'returned something vdirsyncer doesn\'t understand' in err assert "returned something vdirsyncer doesn't understand" in err
assert 'ayy lmao' in err assert "ayy lmao" in err
def test_storage_instance_from_config(monkeypatch): def test_storage_instance_from_config(monkeypatch):
def lol(**kw): def lol(**kw):
assert kw == {'foo': 'bar', 'baz': 1} assert kw == {"foo": "bar", "baz": 1}
return 'OK' return "OK"
monkeypatch.setitem(storage_names._storages, 'lol', lol) monkeypatch.setitem(storage_names._storages, "lol", lol)
config = {'type': 'lol', 'foo': 'bar', 'baz': 1} config = {"type": "lol", "foo": "bar", "baz": 1}
assert storage_instance_from_config(config) == 'OK' assert storage_instance_from_config(config) == "OK"

View file

@ -11,7 +11,7 @@ from vdirsyncer import utils
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def no_debug_output(request): def no_debug_output(request):
logger = click_log.basic_config('vdirsyncer') logger = click_log.basic_config("vdirsyncer")
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
@ -19,47 +19,55 @@ def test_get_storage_init_args():
from vdirsyncer.storage.memory import MemoryStorage from vdirsyncer.storage.memory import MemoryStorage
all, required = utils.get_storage_init_args(MemoryStorage) all, required = utils.get_storage_init_args(MemoryStorage)
assert all == {'fileext', 'collection', 'read_only', 'instance_name'} assert all == {"fileext", "collection", "read_only", "instance_name"}
assert not required assert not required
def test_request_ssl(): def test_request_ssl():
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request('GET', "https://self-signed.badssl.com/") http.request("GET", "https://self-signed.badssl.com/")
assert 'certificate verify failed' in str(excinfo.value) assert "certificate verify failed" in str(excinfo.value)
http.request('GET', "https://self-signed.badssl.com/", verify=False) http.request("GET", "https://self-signed.badssl.com/", verify=False)
def _fingerprints_broken(): def _fingerprints_broken():
from pkg_resources import parse_version as ver from pkg_resources import parse_version as ver
broken_urllib3 = ver(requests.__version__) <= ver('2.5.1')
broken_urllib3 = ver(requests.__version__) <= ver("2.5.1")
return broken_urllib3 return broken_urllib3
@pytest.mark.skipif(_fingerprints_broken(), @pytest.mark.skipif(
reason='https://github.com/shazow/urllib3/issues/529') _fingerprints_broken(), reason="https://github.com/shazow/urllib3/issues/529"
@pytest.mark.parametrize('fingerprint', [ )
'94:FD:7A:CB:50:75:A4:69:82:0A:F8:23:DF:07:FC:69:3E:CD:90:CA', @pytest.mark.parametrize(
'19:90:F7:23:94:F2:EF:AB:2B:64:2D:57:3D:25:95:2D' "fingerprint",
]) [
"94:FD:7A:CB:50:75:A4:69:82:0A:F8:23:DF:07:FC:69:3E:CD:90:CA",
"19:90:F7:23:94:F2:EF:AB:2B:64:2D:57:3D:25:95:2D",
],
)
def test_request_ssl_fingerprints(httpsserver, fingerprint): def test_request_ssl_fingerprints(httpsserver, fingerprint):
httpsserver.serve_content('') # we need to serve something httpsserver.serve_content("") # we need to serve something
http.request('GET', httpsserver.url, verify=False, http.request("GET", httpsserver.url, verify=False, verify_fingerprint=fingerprint)
verify_fingerprint=fingerprint)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request('GET', httpsserver.url, http.request("GET", httpsserver.url, verify_fingerprint=fingerprint)
verify_fingerprint=fingerprint)
with pytest.raises(requests.exceptions.ConnectionError) as excinfo: with pytest.raises(requests.exceptions.ConnectionError) as excinfo:
http.request('GET', httpsserver.url, verify=False, http.request(
verify_fingerprint=''.join(reversed(fingerprint))) "GET",
assert 'Fingerprints did not match' in str(excinfo.value) httpsserver.url,
verify=False,
verify_fingerprint="".join(reversed(fingerprint)),
)
assert "Fingerprints did not match" in str(excinfo.value)
def test_open_graphical_browser(monkeypatch): def test_open_graphical_browser(monkeypatch):
import webbrowser import webbrowser
# Just assert that this internal attribute still exists and behaves the way # Just assert that this internal attribute still exists and behaves the way
# expected # expected
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
@ -67,9 +75,9 @@ def test_open_graphical_browser(monkeypatch):
else: else:
assert webbrowser._tryorder is None assert webbrowser._tryorder is None
monkeypatch.setattr('webbrowser._tryorder', []) monkeypatch.setattr("webbrowser._tryorder", [])
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
utils.open_graphical_browser('http://example.com') utils.open_graphical_browser("http://example.com")
assert 'No graphical browser found' in str(excinfo.value) assert "No graphical browser found" in str(excinfo.value)

View file

@ -7,18 +7,20 @@ from vdirsyncer.vobject import Item
def test_conflict_resolution_command(): def test_conflict_resolution_command():
def check_call(command): def check_call(command):
command, a_tmp, b_tmp = command command, a_tmp, b_tmp = command
assert command == os.path.expanduser('~/command') assert command == os.path.expanduser("~/command")
with open(a_tmp) as f: with open(a_tmp) as f:
assert f.read() == a.raw assert f.read() == a.raw
with open(b_tmp) as f: with open(b_tmp) as f:
assert f.read() == b.raw assert f.read() == b.raw
with open(b_tmp, 'w') as f: with open(b_tmp, "w") as f:
f.write(a.raw) f.write(a.raw)
a = Item('UID:AAAAAAA') a = Item("UID:AAAAAAA")
b = Item('UID:BBBBBBB') b = Item("UID:BBBBBBB")
assert _resolve_conflict_via_command( assert (
a, b, ['~/command'], 'a', 'b', _resolve_conflict_via_command(
_check_call=check_call a, b, ["~/command"], "a", "b", _check_call=check_call
).raw == a.raw ).raw
== a.raw
)

View file

@ -6,74 +6,161 @@ from vdirsyncer.cli.discover import expand_collections
missing = object() missing = object()
@pytest.mark.parametrize('shortcuts,expected', [ @pytest.mark.parametrize(
(['from a'], [ "shortcuts,expected",
('c1', ({'type': 'fooboo', 'custom_arg': 'a1', 'collection': 'c1'}, [
{'type': 'fooboo', 'custom_arg': 'b1', 'collection': 'c1'})), (
('c2', ({'type': 'fooboo', 'custom_arg': 'a2', 'collection': 'c2'}, ["from a"],
{'type': 'fooboo', 'custom_arg': 'b2', 'collection': 'c2'})), [
('a3', ({'type': 'fooboo', 'custom_arg': 'a3', 'collection': 'a3'}, (
missing)) "c1",
]), (
(['from b'], [ {"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
('c1', ({'type': 'fooboo', 'custom_arg': 'a1', 'collection': 'c1'}, {"type": "fooboo", "custom_arg": "b1", "collection": "c1"},
{'type': 'fooboo', 'custom_arg': 'b1', 'collection': 'c1'})), ),
('c2', ({'type': 'fooboo', 'custom_arg': 'a2', 'collection': 'c2'}, ),
{'type': 'fooboo', 'custom_arg': 'b2', 'collection': 'c2'})), (
('b3', (missing, "c2",
{'type': 'fooboo', 'custom_arg': 'b3', 'collection': 'b3'})) (
]), {"type": "fooboo", "custom_arg": "a2", "collection": "c2"},
(['from a', 'from b'], [ {"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
('c1', ({'type': 'fooboo', 'custom_arg': 'a1', 'collection': 'c1'}, ),
{'type': 'fooboo', 'custom_arg': 'b1', 'collection': 'c1'})), ),
('c2', ({'type': 'fooboo', 'custom_arg': 'a2', 'collection': 'c2'}, (
{'type': 'fooboo', 'custom_arg': 'b2', 'collection': 'c2'})), "a3",
('a3', ({'type': 'fooboo', 'custom_arg': 'a3', 'collection': 'a3'}, (
missing)), {"type": "fooboo", "custom_arg": "a3", "collection": "a3"},
('b3', (missing, missing,
{'type': 'fooboo', 'custom_arg': 'b3', 'collection': 'b3'})) ),
]), ),
([['c12', 'c1', 'c2']], [ ],
('c12', ({'type': 'fooboo', 'custom_arg': 'a1', 'collection': 'c1'}, ),
{'type': 'fooboo', 'custom_arg': 'b2', 'collection': 'c2'})), (
]), ["from b"],
(None, [ [
(None, ({'type': 'fooboo', 'storage_side': 'a', 'collection': None}, (
{'type': 'fooboo', 'storage_side': 'b', 'collection': None})) "c1",
]), (
([None], [ {"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
(None, ({'type': 'fooboo', 'storage_side': 'a', 'collection': None}, {"type": "fooboo", "custom_arg": "b1", "collection": "c1"},
{'type': 'fooboo', 'storage_side': 'b', 'collection': None})) ),
]), ),
]) (
"c2",
(
{"type": "fooboo", "custom_arg": "a2", "collection": "c2"},
{"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
),
),
(
"b3",
(
missing,
{"type": "fooboo", "custom_arg": "b3", "collection": "b3"},
),
),
],
),
(
["from a", "from b"],
[
(
"c1",
(
{"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
{"type": "fooboo", "custom_arg": "b1", "collection": "c1"},
),
),
(
"c2",
(
{"type": "fooboo", "custom_arg": "a2", "collection": "c2"},
{"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
),
),
(
"a3",
(
{"type": "fooboo", "custom_arg": "a3", "collection": "a3"},
missing,
),
),
(
"b3",
(
missing,
{"type": "fooboo", "custom_arg": "b3", "collection": "b3"},
),
),
],
),
(
[["c12", "c1", "c2"]],
[
(
"c12",
(
{"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
{"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
),
),
],
),
(
None,
[
(
None,
(
{"type": "fooboo", "storage_side": "a", "collection": None},
{"type": "fooboo", "storage_side": "b", "collection": None},
),
)
],
),
(
[None],
[
(
None,
(
{"type": "fooboo", "storage_side": "a", "collection": None},
{"type": "fooboo", "storage_side": "b", "collection": None},
),
)
],
),
],
)
def test_expand_collections(shortcuts, expected): def test_expand_collections(shortcuts, expected):
config_a = { config_a = {"type": "fooboo", "storage_side": "a"}
'type': 'fooboo',
'storage_side': 'a'
}
config_b = { config_b = {"type": "fooboo", "storage_side": "b"}
'type': 'fooboo',
'storage_side': 'b'
}
def get_discovered_a(): def get_discovered_a():
return { return {
'c1': {'type': 'fooboo', 'custom_arg': 'a1', 'collection': 'c1'}, "c1": {"type": "fooboo", "custom_arg": "a1", "collection": "c1"},
'c2': {'type': 'fooboo', 'custom_arg': 'a2', 'collection': 'c2'}, "c2": {"type": "fooboo", "custom_arg": "a2", "collection": "c2"},
'a3': {'type': 'fooboo', 'custom_arg': 'a3', 'collection': 'a3'} "a3": {"type": "fooboo", "custom_arg": "a3", "collection": "a3"},
} }
def get_discovered_b(): def get_discovered_b():
return { return {
'c1': {'type': 'fooboo', 'custom_arg': 'b1', 'collection': 'c1'}, "c1": {"type": "fooboo", "custom_arg": "b1", "collection": "c1"},
'c2': {'type': 'fooboo', 'custom_arg': 'b2', 'collection': 'c2'}, "c2": {"type": "fooboo", "custom_arg": "b2", "collection": "c2"},
'b3': {'type': 'fooboo', 'custom_arg': 'b3', 'collection': 'b3'} "b3": {"type": "fooboo", "custom_arg": "b3", "collection": "b3"},
} }
assert sorted(expand_collections( assert (
shortcuts, sorted(
config_a, config_b, expand_collections(
get_discovered_a, get_discovered_b, shortcuts,
lambda config, collection: missing config_a,
)) == sorted(expected) config_b,
get_discovered_a,
get_discovered_b,
lambda config, collection: missing,
)
)
== sorted(expected)
)

View file

@ -15,8 +15,9 @@ def mystrategy(monkeypatch):
def strategy(x): def strategy(x):
calls.append(x) calls.append(x)
return x return x
calls = [] calls = []
monkeypatch.setitem(STRATEGIES, 'mystrategy', strategy) monkeypatch.setitem(STRATEGIES, "mystrategy", strategy)
return calls return calls
@ -44,18 +45,15 @@ def value_cache(monkeypatch):
def get_context(*a, **kw): def get_context(*a, **kw):
return FakeContext() return FakeContext()
monkeypatch.setattr('click.get_current_context', get_context) monkeypatch.setattr("click.get_current_context", get_context)
return _cache return _cache
def test_key_conflict(monkeypatch, mystrategy): def test_key_conflict(monkeypatch, mystrategy):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
expand_fetch_params({ expand_fetch_params({"foo": "bar", "foo.fetch": ["mystrategy", "baz"]})
'foo': 'bar',
'foo.fetch': ['mystrategy', 'baz']
})
assert 'Can\'t set foo.fetch and foo.' in str(excinfo.value) assert "Can't set foo.fetch and foo." in str(excinfo.value)
@given(s=st.text(), t=st.text(min_size=1)) @given(s=st.text(), t=st.text(min_size=1))
@ -66,47 +64,40 @@ def test_fuzzing(s, t):
assert config[s] == t assert config[s] == t
@pytest.mark.parametrize('value', [ @pytest.mark.parametrize("value", [[], "lol", 42])
[],
'lol',
42
])
def test_invalid_fetch_value(mystrategy, value): def test_invalid_fetch_value(mystrategy, value):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
expand_fetch_params({ expand_fetch_params({"foo.fetch": value})
'foo.fetch': value
})
assert 'Expected a list' in str(excinfo.value) or \ assert "Expected a list" in str(
'Expected list of length > 0' in str(excinfo.value) excinfo.value
) or "Expected list of length > 0" in str(excinfo.value)
def test_unknown_strategy(): def test_unknown_strategy():
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
expand_fetch_params({ expand_fetch_params({"foo.fetch": ["unreal", "asdf"]})
'foo.fetch': ['unreal', 'asdf']
})
assert 'Unknown strategy' in str(excinfo.value) assert "Unknown strategy" in str(excinfo.value)
def test_caching(monkeypatch, mystrategy, value_cache): def test_caching(monkeypatch, mystrategy, value_cache):
orig_cfg = {'foo.fetch': ['mystrategy', 'asdf']} orig_cfg = {"foo.fetch": ["mystrategy", "asdf"]}
rv = expand_fetch_params(orig_cfg) rv = expand_fetch_params(orig_cfg)
assert rv['foo'] == 'asdf' assert rv["foo"] == "asdf"
assert mystrategy == ['asdf'] assert mystrategy == ["asdf"]
assert len(value_cache) == 1 assert len(value_cache) == 1
rv = expand_fetch_params(orig_cfg) rv = expand_fetch_params(orig_cfg)
assert rv['foo'] == 'asdf' assert rv["foo"] == "asdf"
assert mystrategy == ['asdf'] assert mystrategy == ["asdf"]
assert len(value_cache) == 1 assert len(value_cache) == 1
value_cache.clear() value_cache.clear()
rv = expand_fetch_params(orig_cfg) rv = expand_fetch_params(orig_cfg)
assert rv['foo'] == 'asdf' assert rv["foo"] == "asdf"
assert mystrategy == ['asdf'] * 2 assert mystrategy == ["asdf"] * 2
assert len(value_cache) == 1 assert len(value_cache) == 1
@ -117,9 +108,9 @@ def test_failed_strategy(monkeypatch, value_cache):
calls.append(x) calls.append(x)
raise KeyboardInterrupt() raise KeyboardInterrupt()
monkeypatch.setitem(STRATEGIES, 'mystrategy', strategy) monkeypatch.setitem(STRATEGIES, "mystrategy", strategy)
orig_cfg = {'foo.fetch': ['mystrategy', 'asdf']} orig_cfg = {"foo.fetch": ["mystrategy", "asdf"]}
for _ in range(2): for _ in range(2):
with pytest.raises(KeyboardInterrupt): with pytest.raises(KeyboardInterrupt):
@ -131,9 +122,8 @@ def test_failed_strategy(monkeypatch, value_cache):
def test_empty_value(monkeypatch, mystrategy): def test_empty_value(monkeypatch, mystrategy):
with pytest.raises(exceptions.UserError) as excinfo: with pytest.raises(exceptions.UserError) as excinfo:
expand_fetch_params({ expand_fetch_params({"foo.fetch": ["mystrategy", ""]})
'foo.fetch': ['mystrategy', '']
})
assert 'Empty value for foo.fetch, this most likely indicates an error' \ assert "Empty value for foo.fetch, this most likely indicates an error" in str(
in str(excinfo.value) excinfo.value
)

View file

@ -7,28 +7,29 @@ from vdirsyncer.sync.status import SqliteStatus
status_dict_strategy = st.dictionaries( status_dict_strategy = st.dictionaries(
st.text(), st.text(),
st.tuples(*( st.tuples(
st.fixed_dictionaries({ *(
'href': st.text(), st.fixed_dictionaries(
'hash': st.text(), {"href": st.text(), "hash": st.text(), "etag": st.text()}
'etag': st.text() )
}) for _ in range(2) for _ in range(2)
)) )
),
) )
@given(status_dict=status_dict_strategy) @given(status_dict=status_dict_strategy)
def test_legacy_status(status_dict): def test_legacy_status(status_dict):
hrefs_a = {meta_a['href'] for meta_a, meta_b in status_dict.values()} hrefs_a = {meta_a["href"] for meta_a, meta_b in status_dict.values()}
hrefs_b = {meta_b['href'] for meta_a, meta_b in status_dict.values()} hrefs_b = {meta_b["href"] for meta_a, meta_b in status_dict.values()}
assume(len(hrefs_a) == len(status_dict) == len(hrefs_b)) assume(len(hrefs_a) == len(status_dict) == len(hrefs_b))
status = SqliteStatus() status = SqliteStatus()
status.load_legacy_status(status_dict) status.load_legacy_status(status_dict)
assert dict(status.to_legacy_status()) == status_dict assert dict(status.to_legacy_status()) == status_dict
for ident, (meta_a, meta_b) in status_dict.items(): for ident, (meta_a, meta_b) in status_dict.items():
ident_a, meta2_a = status.get_by_href_a(meta_a['href']) ident_a, meta2_a = status.get_by_href_a(meta_a["href"])
ident_b, meta2_b = status.get_by_href_b(meta_b['href']) ident_b, meta2_b = status.get_by_href_b(meta_b["href"])
assert meta2_a.to_status() == meta_a assert meta2_a.to_status() == meta_a
assert meta2_b.to_status() == meta_b assert meta2_b.to_status() == meta_b
assert ident_a == ident_b == ident assert ident_a == ident_b == ident

View file

@ -22,7 +22,7 @@ from vdirsyncer.vobject import Item
def sync(a, b, status, *args, **kwargs): def sync(a, b, status, *args, **kwargs):
new_status = SqliteStatus(':memory:') new_status = SqliteStatus(":memory:")
new_status.load_legacy_status(status) new_status.load_legacy_status(status)
rv = _sync(a, b, new_status, *args, **kwargs) rv = _sync(a, b, new_status, *args, **kwargs)
status.clear() status.clear()
@ -41,7 +41,7 @@ def items(s):
def test_irrelevant_status(): def test_irrelevant_status():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {'1': ('1', 1234, '1.ics', 2345)} status = {"1": ("1", 1234, "1.ics", 2345)}
sync(a, b, status) sync(a, b, status)
assert not status assert not status
assert not items(a) assert not items(a)
@ -52,7 +52,7 @@ def test_missing_status():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item = Item('asdf') item = Item("asdf")
a.upload(item) a.upload(item)
b.upload(item) b.upload(item)
sync(a, b, status) sync(a, b, status)
@ -65,14 +65,14 @@ def test_missing_status_and_different_items():
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item1 = Item('UID:1\nhaha') item1 = Item("UID:1\nhaha")
item2 = Item('UID:1\nhoho') item2 = Item("UID:1\nhoho")
a.upload(item1) a.upload(item1)
b.upload(item2) b.upload(item2)
with pytest.raises(SyncConflict): with pytest.raises(SyncConflict):
sync(a, b, status) sync(a, b, status)
assert not status assert not status
sync(a, b, status, conflict_resolution='a wins') sync(a, b, status, conflict_resolution="a wins")
assert items(a) == items(b) == {item1.raw} assert items(a) == items(b) == {item1.raw}
@ -82,8 +82,8 @@ def test_read_only_and_prefetch():
b.read_only = True b.read_only = True
status = {} status = {}
item1 = Item('UID:1\nhaha') item1 = Item("UID:1\nhaha")
item2 = Item('UID:2\nhoho') item2 = Item("UID:2\nhoho")
a.upload(item1) a.upload(item1)
a.upload(item2) a.upload(item2)
@ -98,11 +98,11 @@ def test_partial_sync_error():
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.upload(Item('UID:0')) a.upload(Item("UID:0"))
b.read_only = True b.read_only = True
with pytest.raises(PartialSync): with pytest.raises(PartialSync):
sync(a, b, status, partial_sync='error') sync(a, b, status, partial_sync="error")
def test_partial_sync_ignore(): def test_partial_sync_ignore():
@ -110,17 +110,17 @@ def test_partial_sync_ignore():
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item0 = Item('UID:0\nhehe') item0 = Item("UID:0\nhehe")
a.upload(item0) a.upload(item0)
b.upload(item0) b.upload(item0)
b.read_only = True b.read_only = True
item1 = Item('UID:1\nhaha') item1 = Item("UID:1\nhaha")
a.upload(item1) a.upload(item1)
sync(a, b, status, partial_sync='ignore') sync(a, b, status, partial_sync="ignore")
sync(a, b, status, partial_sync='ignore') sync(a, b, status, partial_sync="ignore")
assert items(a) == {item0.raw, item1.raw} assert items(a) == {item0.raw, item1.raw}
assert items(b) == {item0.raw} assert items(b) == {item0.raw}
@ -131,69 +131,69 @@ def test_partial_sync_ignore2():
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href, etag = a.upload(Item('UID:0')) href, etag = a.upload(Item("UID:0"))
a.read_only = True a.read_only = True
sync(a, b, status, partial_sync='ignore', force_delete=True) sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(b) == items(a) == {'UID:0'} assert items(b) == items(a) == {"UID:0"}
b.items.clear() b.items.clear()
sync(a, b, status, partial_sync='ignore', force_delete=True) sync(a, b, status, partial_sync="ignore", force_delete=True)
sync(a, b, status, partial_sync='ignore', force_delete=True) sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(a) == {'UID:0'} assert items(a) == {"UID:0"}
assert not b.items assert not b.items
a.read_only = False a.read_only = False
a.update(href, Item('UID:0\nupdated'), etag) a.update(href, Item("UID:0\nupdated"), etag)
a.read_only = True a.read_only = True
sync(a, b, status, partial_sync='ignore', force_delete=True) sync(a, b, status, partial_sync="ignore", force_delete=True)
assert items(b) == items(a) == {'UID:0\nupdated'} assert items(b) == items(a) == {"UID:0\nupdated"}
def test_upload_and_update(): def test_upload_and_update():
a = MemoryStorage(fileext='.a') a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext='.b') b = MemoryStorage(fileext=".b")
status = {} status = {}
item = Item('UID:1') # new item 1 in a item = Item("UID:1") # new item 1 in a
a.upload(item) a.upload(item)
sync(a, b, status) sync(a, b, status)
assert items(b) == items(a) == {item.raw} assert items(b) == items(a) == {item.raw}
item = Item('UID:1\nASDF:YES') # update of item 1 in b item = Item("UID:1\nASDF:YES") # update of item 1 in b
b.update('1.b', item, b.get('1.b')[1]) b.update("1.b", item, b.get("1.b")[1])
sync(a, b, status) sync(a, b, status)
assert items(b) == items(a) == {item.raw} assert items(b) == items(a) == {item.raw}
item2 = Item('UID:2') # new item 2 in b item2 = Item("UID:2") # new item 2 in b
b.upload(item2) b.upload(item2)
sync(a, b, status) sync(a, b, status)
assert items(b) == items(a) == {item.raw, item2.raw} assert items(b) == items(a) == {item.raw, item2.raw}
item2 = Item('UID:2\nASDF:YES') # update of item 2 in a item2 = Item("UID:2\nASDF:YES") # update of item 2 in a
a.update('2.a', item2, a.get('2.a')[1]) a.update("2.a", item2, a.get("2.a")[1])
sync(a, b, status) sync(a, b, status)
assert items(b) == items(a) == {item.raw, item2.raw} assert items(b) == items(a) == {item.raw, item2.raw}
def test_deletion(): def test_deletion():
a = MemoryStorage(fileext='.a') a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext='.b') b = MemoryStorage(fileext=".b")
status = {} status = {}
item = Item('UID:1') item = Item("UID:1")
a.upload(item) a.upload(item)
item2 = Item('UID:2') item2 = Item("UID:2")
a.upload(item2) a.upload(item2)
sync(a, b, status) sync(a, b, status)
b.delete('1.b', b.get('1.b')[1]) b.delete("1.b", b.get("1.b")[1])
sync(a, b, status) sync(a, b, status)
assert items(a) == items(b) == {item2.raw} assert items(a) == items(b) == {item2.raw}
a.upload(item) a.upload(item)
sync(a, b, status) sync(a, b, status)
assert items(a) == items(b) == {item.raw, item2.raw} assert items(a) == items(b) == {item.raw, item2.raw}
a.delete('1.a', a.get('1.a')[1]) a.delete("1.a", a.get("1.a")[1])
sync(a, b, status) sync(a, b, status)
assert items(a) == items(b) == {item2.raw} assert items(a) == items(b) == {item2.raw}
@ -203,38 +203,34 @@ def test_insert_hash():
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
item = Item('UID:1') item = Item("UID:1")
href, etag = a.upload(item) href, etag = a.upload(item)
sync(a, b, status) sync(a, b, status)
for d in status['1']: for d in status["1"]:
del d['hash'] del d["hash"]
a.update(href, Item('UID:1\nHAHA:YES'), etag) a.update(href, Item("UID:1\nHAHA:YES"), etag)
sync(a, b, status) sync(a, b, status)
assert 'hash' in status['1'][0] and 'hash' in status['1'][1] assert "hash" in status["1"][0] and "hash" in status["1"][1]
def test_already_synced(): def test_already_synced():
a = MemoryStorage(fileext='.a') a = MemoryStorage(fileext=".a")
b = MemoryStorage(fileext='.b') b = MemoryStorage(fileext=".b")
item = Item('UID:1') item = Item("UID:1")
a.upload(item) a.upload(item)
b.upload(item) b.upload(item)
status = { status = {
'1': ({ "1": (
'href': '1.a', {"href": "1.a", "hash": item.hash, "etag": a.get("1.a")[1]},
'hash': item.hash, {"href": "1.b", "hash": item.hash, "etag": b.get("1.b")[1]},
'etag': a.get('1.a')[1] )
}, {
'href': '1.b',
'hash': item.hash,
'etag': b.get('1.b')[1]
})
} }
old_status = deepcopy(status) old_status = deepcopy(status)
a.update = b.update = a.upload = b.upload = \ a.update = b.update = a.upload = b.upload = lambda *a, **kw: pytest.fail(
lambda *a, **kw: pytest.fail('Method shouldn\'t have been called.') "Method shouldn't have been called."
)
for _ in (1, 2): for _ in (1, 2):
sync(a, b, status) sync(a, b, status)
@ -242,38 +238,38 @@ def test_already_synced():
assert items(a) == items(b) == {item.raw} assert items(a) == items(b) == {item.raw}
@pytest.mark.parametrize('winning_storage', 'ab') @pytest.mark.parametrize("winning_storage", "ab")
def test_conflict_resolution_both_etags_new(winning_storage): def test_conflict_resolution_both_etags_new(winning_storage):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item = Item('UID:1') item = Item("UID:1")
href_a, etag_a = a.upload(item) href_a, etag_a = a.upload(item)
href_b, etag_b = b.upload(item) href_b, etag_b = b.upload(item)
status = {} status = {}
sync(a, b, status) sync(a, b, status)
assert status assert status
item_a = Item('UID:1\nitem a') item_a = Item("UID:1\nitem a")
item_b = Item('UID:1\nitem b') item_b = Item("UID:1\nitem b")
a.update(href_a, item_a, etag_a) a.update(href_a, item_a, etag_a)
b.update(href_b, item_b, etag_b) b.update(href_b, item_b, etag_b)
with pytest.raises(SyncConflict): with pytest.raises(SyncConflict):
sync(a, b, status) sync(a, b, status)
sync(a, b, status, conflict_resolution=f'{winning_storage} wins') sync(a, b, status, conflict_resolution=f"{winning_storage} wins")
assert items(a) == items(b) == { assert (
item_a.raw if winning_storage == 'a' else item_b.raw items(a) == items(b) == {item_a.raw if winning_storage == "a" else item_b.raw}
} )
def test_updated_and_deleted(): def test_updated_and_deleted():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
href_a, etag_a = a.upload(Item('UID:1')) href_a, etag_a = a.upload(Item("UID:1"))
status = {} status = {}
sync(a, b, status, force_delete=True) sync(a, b, status, force_delete=True)
(href_b, etag_b), = b.list() ((href_b, etag_b),) = b.list()
b.delete(href_b, etag_b) b.delete(href_b, etag_b)
updated = Item('UID:1\nupdated') updated = Item("UID:1\nupdated")
a.update(href_a, updated, etag_a) a.update(href_a, updated, etag_a)
sync(a, b, status, force_delete=True) sync(a, b, status, force_delete=True)
@ -283,35 +279,35 @@ def test_updated_and_deleted():
def test_conflict_resolution_invalid_mode(): def test_conflict_resolution_invalid_mode():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item_a = Item('UID:1\nitem a') item_a = Item("UID:1\nitem a")
item_b = Item('UID:1\nitem b') item_b = Item("UID:1\nitem b")
a.upload(item_a) a.upload(item_a)
b.upload(item_b) b.upload(item_b)
with pytest.raises(ValueError): with pytest.raises(ValueError):
sync(a, b, {}, conflict_resolution='yolo') sync(a, b, {}, conflict_resolution="yolo")
def test_conflict_resolution_new_etags_without_changes(): def test_conflict_resolution_new_etags_without_changes():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item = Item('UID:1') item = Item("UID:1")
href_a, etag_a = a.upload(item) href_a, etag_a = a.upload(item)
href_b, etag_b = b.upload(item) href_b, etag_b = b.upload(item)
status = {'1': (href_a, 'BOGUS_a', href_b, 'BOGUS_b')} status = {"1": (href_a, "BOGUS_a", href_b, "BOGUS_b")}
sync(a, b, status) sync(a, b, status)
(ident, (status_a, status_b)), = status.items() ((ident, (status_a, status_b)),) = status.items()
assert ident == '1' assert ident == "1"
assert status_a['href'] == href_a assert status_a["href"] == href_a
assert status_a['etag'] == etag_a assert status_a["etag"] == etag_a
assert status_b['href'] == href_b assert status_b["href"] == href_b
assert status_b['etag'] == etag_b assert status_b["etag"] == etag_b
def test_uses_get_multi(monkeypatch): def test_uses_get_multi(monkeypatch):
def breakdown(*a, **kw): def breakdown(*a, **kw):
raise AssertionError('Expected use of get_multi') raise AssertionError("Expected use of get_multi")
get_multi_calls = [] get_multi_calls = []
@ -324,12 +320,12 @@ def test_uses_get_multi(monkeypatch):
item, etag = old_get(self, href) item, etag = old_get(self, href)
yield href, item, etag yield href, item, etag
monkeypatch.setattr(MemoryStorage, 'get', breakdown) monkeypatch.setattr(MemoryStorage, "get", breakdown)
monkeypatch.setattr(MemoryStorage, 'get_multi', get_multi) monkeypatch.setattr(MemoryStorage, "get_multi", get_multi)
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
item = Item('UID:1') item = Item("UID:1")
expected_href, etag = a.upload(item) expected_href, etag = a.upload(item)
sync(a, b, {}) sync(a, b, {})
@ -339,8 +335,8 @@ def test_uses_get_multi(monkeypatch):
def test_empty_storage_dataloss(): def test_empty_storage_dataloss():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
a.upload(Item('UID:1')) a.upload(Item("UID:1"))
a.upload(Item('UID:2')) a.upload(Item("UID:2"))
status = {} status = {}
sync(a, b, status) sync(a, b, status)
with pytest.raises(StorageEmpty): with pytest.raises(StorageEmpty):
@ -353,22 +349,22 @@ def test_empty_storage_dataloss():
def test_no_uids(): def test_no_uids():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
a.upload(Item('ASDF')) a.upload(Item("ASDF"))
b.upload(Item('FOOBAR')) b.upload(Item("FOOBAR"))
status = {} status = {}
sync(a, b, status) sync(a, b, status)
assert items(a) == items(b) == {'ASDF', 'FOOBAR'} assert items(a) == items(b) == {"ASDF", "FOOBAR"}
def test_changed_uids(): def test_changed_uids():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
href_a, etag_a = a.upload(Item('UID:A-ONE')) href_a, etag_a = a.upload(Item("UID:A-ONE"))
href_b, etag_b = b.upload(Item('UID:B-ONE')) href_b, etag_b = b.upload(Item("UID:B-ONE"))
status = {} status = {}
sync(a, b, status) sync(a, b, status)
a.update(href_a, Item('UID:A-TWO'), etag_a) a.update(href_a, Item("UID:A-TWO"), etag_a)
sync(a, b, status) sync(a, b, status)
@ -383,71 +379,71 @@ def test_both_readonly():
def test_partial_sync_revert(): def test_partial_sync_revert():
a = MemoryStorage(instance_name='a') a = MemoryStorage(instance_name="a")
b = MemoryStorage(instance_name='b') b = MemoryStorage(instance_name="b")
status = {} status = {}
a.upload(Item('UID:1')) a.upload(Item("UID:1"))
b.upload(Item('UID:2')) b.upload(Item("UID:2"))
b.read_only = True b.read_only = True
sync(a, b, status, partial_sync='revert') sync(a, b, status, partial_sync="revert")
assert len(status) == 2 assert len(status) == 2
assert items(a) == {'UID:1', 'UID:2'} assert items(a) == {"UID:1", "UID:2"}
assert items(b) == {'UID:2'} assert items(b) == {"UID:2"}
sync(a, b, status, partial_sync='revert') sync(a, b, status, partial_sync="revert")
assert len(status) == 1 assert len(status) == 1
assert items(a) == {'UID:2'} assert items(a) == {"UID:2"}
assert items(b) == {'UID:2'} assert items(b) == {"UID:2"}
# Check that updates get reverted # Check that updates get reverted
a.items[next(iter(a.items))] = ('foo', Item('UID:2\nupdated')) a.items[next(iter(a.items))] = ("foo", Item("UID:2\nupdated"))
assert items(a) == {'UID:2\nupdated'} assert items(a) == {"UID:2\nupdated"}
sync(a, b, status, partial_sync='revert') sync(a, b, status, partial_sync="revert")
assert len(status) == 1 assert len(status) == 1
assert items(a) == {'UID:2\nupdated'} assert items(a) == {"UID:2\nupdated"}
sync(a, b, status, partial_sync='revert') sync(a, b, status, partial_sync="revert")
assert items(a) == {'UID:2'} assert items(a) == {"UID:2"}
# Check that deletions get reverted # Check that deletions get reverted
a.items.clear() a.items.clear()
sync(a, b, status, partial_sync='revert', force_delete=True) sync(a, b, status, partial_sync="revert", force_delete=True)
sync(a, b, status, partial_sync='revert', force_delete=True) sync(a, b, status, partial_sync="revert", force_delete=True)
assert items(a) == {'UID:2'} assert items(a) == {"UID:2"}
@pytest.mark.parametrize('sync_inbetween', (True, False)) @pytest.mark.parametrize("sync_inbetween", (True, False))
def test_ident_conflict(sync_inbetween): def test_ident_conflict(sync_inbetween):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href_a, etag_a = a.upload(Item('UID:aaa')) href_a, etag_a = a.upload(Item("UID:aaa"))
href_b, etag_b = a.upload(Item('UID:bbb')) href_b, etag_b = a.upload(Item("UID:bbb"))
if sync_inbetween: if sync_inbetween:
sync(a, b, status) sync(a, b, status)
a.update(href_a, Item('UID:xxx'), etag_a) a.update(href_a, Item("UID:xxx"), etag_a)
a.update(href_b, Item('UID:xxx'), etag_b) a.update(href_b, Item("UID:xxx"), etag_b)
with pytest.raises(IdentConflict): with pytest.raises(IdentConflict):
sync(a, b, status) sync(a, b, status)
def test_moved_href(): def test_moved_href():
''' """
Concrete application: ppl_ stores contact aliases in filenames, which means Concrete application: ppl_ stores contact aliases in filenames, which means
item's hrefs get changed. Vdirsyncer doesn't synchronize this data, but item's hrefs get changed. Vdirsyncer doesn't synchronize this data, but
also shouldn't do things like deleting and re-uploading to the server. also shouldn't do things like deleting and re-uploading to the server.
.. _ppl: http://ppladdressbook.org/ .. _ppl: http://ppladdressbook.org/
''' """
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href, etag = a.upload(Item('UID:haha')) href, etag = a.upload(Item("UID:haha"))
sync(a, b, status) sync(a, b, status)
b.items['lol'] = b.items.pop('haha') b.items["lol"] = b.items.pop("haha")
# The sync algorithm should prefetch `lol`, see that it's the same ident # The sync algorithm should prefetch `lol`, see that it's the same ident
# and not do anything else. # and not do anything else.
@ -457,8 +453,8 @@ def test_moved_href():
sync(a, b, status) sync(a, b, status)
assert len(status) == 1 assert len(status) == 1
assert items(a) == items(b) == {'UID:haha'} assert items(a) == items(b) == {"UID:haha"}
assert status['haha'][1]['href'] == 'lol' assert status["haha"][1]["href"] == "lol"
old_status = deepcopy(status) old_status = deepcopy(status)
# Further sync should be a noop. Not even prefetching should occur. # Further sync should be a noop. Not even prefetching should occur.
@ -466,39 +462,39 @@ def test_moved_href():
sync(a, b, status) sync(a, b, status)
assert old_status == status assert old_status == status
assert items(a) == items(b) == {'UID:haha'} assert items(a) == items(b) == {"UID:haha"}
def test_bogus_etag_change(): def test_bogus_etag_change():
'''Assert that sync algorithm is resilient against etag changes if content """Assert that sync algorithm is resilient against etag changes if content
didn\'t change. didn\'t change.
In this particular case we test a scenario where both etags have been In this particular case we test a scenario where both etags have been
updated, but only one side actually changed its item content. updated, but only one side actually changed its item content.
''' """
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href_a, etag_a = a.upload(Item('UID:ASDASD')) href_a, etag_a = a.upload(Item("UID:ASDASD"))
sync(a, b, status) sync(a, b, status)
assert len(status) == len(list(a.list())) == len(list(b.list())) == 1 assert len(status) == len(list(a.list())) == len(list(b.list())) == 1
(href_b, etag_b), = b.list() ((href_b, etag_b),) = b.list()
a.update(href_a, Item('UID:ASDASD'), etag_a) a.update(href_a, Item("UID:ASDASD"), etag_a)
b.update(href_b, Item('UID:ASDASD\nACTUALCHANGE:YES'), etag_b) b.update(href_b, Item("UID:ASDASD\nACTUALCHANGE:YES"), etag_b)
b.delete = b.update = b.upload = blow_up b.delete = b.update = b.upload = blow_up
sync(a, b, status) sync(a, b, status)
assert len(status) == 1 assert len(status) == 1
assert items(a) == items(b) == {'UID:ASDASD\nACTUALCHANGE:YES'} assert items(a) == items(b) == {"UID:ASDASD\nACTUALCHANGE:YES"}
def test_unicode_hrefs(): def test_unicode_hrefs():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
href, etag = a.upload(Item('UID:äää')) href, etag = a.upload(Item("UID:äää"))
sync(a, b, status) sync(a, b, status)
@ -511,27 +507,27 @@ def action_failure(*a, **kw):
class SyncMachine(RuleBasedStateMachine): class SyncMachine(RuleBasedStateMachine):
Status = Bundle('status') Status = Bundle("status")
Storage = Bundle('storage') Storage = Bundle("storage")
@rule(target=Storage, @rule(target=Storage, flaky_etags=st.booleans(), null_etag_on_upload=st.booleans())
flaky_etags=st.booleans(),
null_etag_on_upload=st.booleans())
def newstorage(self, flaky_etags, null_etag_on_upload): def newstorage(self, flaky_etags, null_etag_on_upload):
s = MemoryStorage() s = MemoryStorage()
if flaky_etags: if flaky_etags:
def get(href): def get(href):
old_etag, item = s.items[href] old_etag, item = s.items[href]
etag = _random_string() etag = _random_string()
s.items[href] = etag, item s.items[href] = etag, item
return item, etag return item, etag
s.get = get s.get = get
if null_etag_on_upload: if null_etag_on_upload:
_old_upload = s.upload _old_upload = s.upload
_old_update = s.update _old_update = s.update
s.upload = lambda item: (_old_upload(item)[0], 'NULL') s.upload = lambda item: (_old_upload(item)[0], "NULL")
s.update = lambda h, i, e: _old_update(h, i, e) and 'NULL' s.update = lambda h, i, e: _old_update(h, i, e) and "NULL"
return s return s
@ -564,11 +560,9 @@ class SyncMachine(RuleBasedStateMachine):
def newstatus(self): def newstatus(self):
return {} return {}
@rule(storage=Storage, @rule(storage=Storage, uid=uid_strategy, etag=st.text())
uid=uid_strategy,
etag=st.text())
def upload(self, storage, uid, etag): def upload(self, storage, uid, etag):
item = Item(f'UID:{uid}') item = Item(f"UID:{uid}")
storage.items[uid] = (etag, item) storage.items[uid] = (etag, item)
@rule(storage=Storage, href=st.text()) @rule(storage=Storage, href=st.text())
@ -577,22 +571,31 @@ class SyncMachine(RuleBasedStateMachine):
@rule( @rule(
status=Status, status=Status,
a=Storage, b=Storage, a=Storage,
b=Storage,
force_delete=st.booleans(), force_delete=st.booleans(),
conflict_resolution=st.one_of((st.just('a wins'), st.just('b wins'))), conflict_resolution=st.one_of((st.just("a wins"), st.just("b wins"))),
with_error_callback=st.booleans(), with_error_callback=st.booleans(),
partial_sync=st.one_of(( partial_sync=st.one_of(
st.just('ignore'), st.just('revert'), st.just('error') (st.just("ignore"), st.just("revert"), st.just("error"))
)) ),
) )
def sync(self, status, a, b, force_delete, conflict_resolution, def sync(
with_error_callback, partial_sync): self,
status,
a,
b,
force_delete,
conflict_resolution,
with_error_callback,
partial_sync,
):
assume(a is not b) assume(a is not b)
old_items_a = items(a) old_items_a = items(a)
old_items_b = items(b) old_items_b = items(b)
a.instance_name = 'a' a.instance_name = "a"
b.instance_name = 'b' b.instance_name = "b"
errors = [] errors = []
@ -605,16 +608,20 @@ class SyncMachine(RuleBasedStateMachine):
# If one storage is read-only, double-sync because changes don't # If one storage is read-only, double-sync because changes don't
# get reverted immediately. # get reverted immediately.
for _ in range(2 if a.read_only or b.read_only else 1): for _ in range(2 if a.read_only or b.read_only else 1):
sync(a, b, status, sync(
force_delete=force_delete, a,
conflict_resolution=conflict_resolution, b,
error_callback=error_callback, status,
partial_sync=partial_sync) force_delete=force_delete,
conflict_resolution=conflict_resolution,
error_callback=error_callback,
partial_sync=partial_sync,
)
for e in errors: for e in errors:
raise e raise e
except PartialSync: except PartialSync:
assert partial_sync == 'error' assert partial_sync == "error"
except ActionIntentionallyFailed: except ActionIntentionallyFailed:
pass pass
except BothReadOnly: except BothReadOnly:
@ -629,49 +636,55 @@ class SyncMachine(RuleBasedStateMachine):
items_a = items(a) items_a = items(a)
items_b = items(b) items_b = items(b)
assert items_a == items_b or partial_sync == 'ignore' assert items_a == items_b or partial_sync == "ignore"
assert items_a == old_items_a or not a.read_only assert items_a == old_items_a or not a.read_only
assert items_b == old_items_b or not b.read_only assert items_b == old_items_b or not b.read_only
assert set(a.items) | set(b.items) == set(status) or \ assert (
partial_sync == 'ignore' set(a.items) | set(b.items) == set(status) or partial_sync == "ignore"
)
TestSyncMachine = SyncMachine.TestCase TestSyncMachine = SyncMachine.TestCase
@pytest.mark.parametrize('error_callback', [True, False]) @pytest.mark.parametrize("error_callback", [True, False])
def test_rollback(error_callback): def test_rollback(error_callback):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.items['0'] = ('', Item('UID:0')) a.items["0"] = ("", Item("UID:0"))
b.items['1'] = ('', Item('UID:1')) b.items["1"] = ("", Item("UID:1"))
b.upload = b.update = b.delete = action_failure b.upload = b.update = b.delete = action_failure
if error_callback: if error_callback:
errors = [] errors = []
sync(a, b, status=status, conflict_resolution='a wins', sync(
error_callback=errors.append) a,
b,
status=status,
conflict_resolution="a wins",
error_callback=errors.append,
)
assert len(errors) == 1 assert len(errors) == 1
assert isinstance(errors[0], ActionIntentionallyFailed) assert isinstance(errors[0], ActionIntentionallyFailed)
assert len(status) == 1 assert len(status) == 1
assert status['1'] assert status["1"]
else: else:
with pytest.raises(ActionIntentionallyFailed): with pytest.raises(ActionIntentionallyFailed):
sync(a, b, status=status, conflict_resolution='a wins') sync(a, b, status=status, conflict_resolution="a wins")
def test_duplicate_hrefs(): def test_duplicate_hrefs():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
a.list = lambda: [('a', 'a')] * 3 a.list = lambda: [("a", "a")] * 3
a.items['a'] = ('a', Item('UID:a')) a.items["a"] = ("a", Item("UID:a"))
status = {} status = {}
sync(a, b, status) sync(a, b, status)

View file

@ -2,13 +2,12 @@ from vdirsyncer import exceptions
def test_user_error_problems(): def test_user_error_problems():
e = exceptions.UserError('A few problems occurred', problems=[ e = exceptions.UserError(
'Problem one', "A few problems occurred",
'Problem two', problems=["Problem one", "Problem two", "Problem three"],
'Problem three' )
])
assert 'one' in str(e) assert "one" in str(e)
assert 'two' in str(e) assert "two" in str(e)
assert 'three' in str(e) assert "three" in str(e)
assert 'problems occurred' in str(e) assert "problems occurred" in str(e)

View file

@ -15,7 +15,7 @@ from vdirsyncer.storage.memory import MemoryStorage
def test_irrelevant_status(): def test_irrelevant_status():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {'foo': 'bar'} status = {"foo": "bar"}
metasync(a, b, status, keys=()) metasync(a, b, status, keys=())
assert not status assert not status
@ -26,24 +26,24 @@ def test_basic(monkeypatch):
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta('foo', 'bar') a.set_meta("foo", "bar")
metasync(a, b, status, keys=['foo']) metasync(a, b, status, keys=["foo"])
assert a.get_meta('foo') == b.get_meta('foo') == 'bar' assert a.get_meta("foo") == b.get_meta("foo") == "bar"
a.set_meta('foo', 'baz') a.set_meta("foo", "baz")
metasync(a, b, status, keys=['foo']) metasync(a, b, status, keys=["foo"])
assert a.get_meta('foo') == b.get_meta('foo') == 'baz' assert a.get_meta("foo") == b.get_meta("foo") == "baz"
monkeypatch.setattr(a, 'set_meta', blow_up) monkeypatch.setattr(a, "set_meta", blow_up)
monkeypatch.setattr(b, 'set_meta', blow_up) monkeypatch.setattr(b, "set_meta", blow_up)
metasync(a, b, status, keys=['foo']) metasync(a, b, status, keys=["foo"])
assert a.get_meta('foo') == b.get_meta('foo') == 'baz' assert a.get_meta("foo") == b.get_meta("foo") == "baz"
monkeypatch.undo() monkeypatch.undo()
monkeypatch.undo() monkeypatch.undo()
b.set_meta('foo', None) b.set_meta("foo", None)
metasync(a, b, status, keys=['foo']) metasync(a, b, status, keys=["foo"])
assert not a.get_meta('foo') and not b.get_meta('foo') assert not a.get_meta("foo") and not b.get_meta("foo")
@pytest.fixture @pytest.fixture
@ -51,12 +51,12 @@ def conflict_state(request):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta('foo', 'bar') a.set_meta("foo", "bar")
b.set_meta('foo', 'baz') b.set_meta("foo", "baz")
def cleanup(): def cleanup():
assert a.get_meta('foo') == 'bar' assert a.get_meta("foo") == "bar"
assert b.get_meta('foo') == 'baz' assert b.get_meta("foo") == "baz"
assert not status assert not status
request.addfinalizer(cleanup) request.addfinalizer(cleanup)
@ -68,54 +68,61 @@ def test_conflict(conflict_state):
a, b, status = conflict_state a, b, status = conflict_state
with pytest.raises(MetaSyncConflict): with pytest.raises(MetaSyncConflict):
metasync(a, b, status, keys=['foo']) metasync(a, b, status, keys=["foo"])
def test_invalid_conflict_resolution(conflict_state): def test_invalid_conflict_resolution(conflict_state):
a, b, status = conflict_state a, b, status = conflict_state
with pytest.raises(UserError) as excinfo: with pytest.raises(UserError) as excinfo:
metasync(a, b, status, keys=['foo'], conflict_resolution='foo') metasync(a, b, status, keys=["foo"], conflict_resolution="foo")
assert 'Invalid conflict resolution setting' in str(excinfo.value) assert "Invalid conflict resolution setting" in str(excinfo.value)
def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch): def test_warning_on_custom_conflict_commands(conflict_state, monkeypatch):
a, b, status = conflict_state a, b, status = conflict_state
warnings = [] warnings = []
monkeypatch.setattr(logger, 'warning', warnings.append) monkeypatch.setattr(logger, "warning", warnings.append)
with pytest.raises(MetaSyncConflict): with pytest.raises(MetaSyncConflict):
metasync(a, b, status, keys=['foo'], metasync(a, b, status, keys=["foo"], conflict_resolution=lambda *a, **kw: None)
conflict_resolution=lambda *a, **kw: None)
assert warnings == ['Custom commands don\'t work on metasync.'] assert warnings == ["Custom commands don't work on metasync."]
def test_conflict_same_content(): def test_conflict_same_content():
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta('foo', 'bar') a.set_meta("foo", "bar")
b.set_meta('foo', 'bar') b.set_meta("foo", "bar")
metasync(a, b, status, keys=['foo']) metasync(a, b, status, keys=["foo"])
assert a.get_meta('foo') == b.get_meta('foo') == status['foo'] == 'bar' assert a.get_meta("foo") == b.get_meta("foo") == status["foo"] == "bar"
@pytest.mark.parametrize('wins', 'ab') @pytest.mark.parametrize("wins", "ab")
def test_conflict_x_wins(wins): def test_conflict_x_wins(wins):
a = MemoryStorage() a = MemoryStorage()
b = MemoryStorage() b = MemoryStorage()
status = {} status = {}
a.set_meta('foo', 'bar') a.set_meta("foo", "bar")
b.set_meta('foo', 'baz') b.set_meta("foo", "baz")
metasync(a, b, status, keys=['foo'], metasync(
conflict_resolution='a wins' if wins == 'a' else 'b wins') a,
b,
status,
keys=["foo"],
conflict_resolution="a wins" if wins == "a" else "b wins",
)
assert a.get_meta('foo') == b.get_meta('foo') == status['foo'] == ( assert (
'bar' if wins == 'a' else 'baz' a.get_meta("foo")
== b.get_meta("foo")
== status["foo"]
== ("bar" if wins == "a" else "baz")
) )
@ -125,33 +132,40 @@ metadata = st.dictionaries(keys, values)
@given( @given(
a=metadata, b=metadata, a=metadata,
status=metadata, keys=st.sets(keys), b=metadata,
conflict_resolution=st.just('a wins') | st.just('b wins') status=metadata,
keys=st.sets(keys),
conflict_resolution=st.just("a wins") | st.just("b wins"),
)
@example(
a={"0": "0"}, b={}, status={"0": "0"}, keys={"0"}, conflict_resolution="a wins"
)
@example(
a={"0": "0"},
b={"0": "1"},
status={"0": "0"},
keys={"0"},
conflict_resolution="a wins",
) )
@example(a={'0': '0'}, b={}, status={'0': '0'}, keys={'0'},
conflict_resolution='a wins')
@example(a={'0': '0'}, b={'0': '1'}, status={'0': '0'}, keys={'0'},
conflict_resolution='a wins')
def test_fuzzing(a, b, status, keys, conflict_resolution): def test_fuzzing(a, b, status, keys, conflict_resolution):
def _get_storage(m, instance_name): def _get_storage(m, instance_name):
s = MemoryStorage(instance_name=instance_name) s = MemoryStorage(instance_name=instance_name)
s.metadata = m s.metadata = m
return s return s
a = _get_storage(a, 'A') a = _get_storage(a, "A")
b = _get_storage(b, 'B') b = _get_storage(b, "B")
winning_storage = (a if conflict_resolution == 'a wins' else b) winning_storage = a if conflict_resolution == "a wins" else b
expected_values = {key: winning_storage.get_meta(key) expected_values = {
for key in keys key: winning_storage.get_meta(key) for key in keys if key not in status
if key not in status} }
metasync(a, b, status, metasync(a, b, status, keys=keys, conflict_resolution=conflict_resolution)
keys=keys, conflict_resolution=conflict_resolution)
for key in keys: for key in keys:
s = status.get(key, '') s = status.get(key, "")
assert a.get_meta(key) == b.get_meta(key) == s assert a.get_meta(key) == b.get_meta(key) == s
if expected_values.get(key, '') and s: if expected_values.get(key, "") and s:
assert s == expected_values[key] assert s == expected_values[key]

View file

@ -18,14 +18,8 @@ from vdirsyncer.vobject import Item
def test_repair_uids(uid): def test_repair_uids(uid):
s = MemoryStorage() s = MemoryStorage()
s.items = { s.items = {
'one': ( "one": ("asdf", Item(f"BEGIN:VCARD\nFN:Hans\nUID:{uid}\nEND:VCARD")),
'asdf', "two": ("asdf", Item(f"BEGIN:VCARD\nFN:Peppi\nUID:{uid}\nEND:VCARD")),
Item(f'BEGIN:VCARD\nFN:Hans\nUID:{uid}\nEND:VCARD')
),
'two': (
'asdf',
Item(f'BEGIN:VCARD\nFN:Peppi\nUID:{uid}\nEND:VCARD')
)
} }
uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()] uid1, uid2 = [s.get(href)[0].uid for href, etag in s.list()]
@ -42,7 +36,7 @@ def test_repair_uids(uid):
@settings(suppress_health_check=HealthCheck.all()) @settings(suppress_health_check=HealthCheck.all())
def test_repair_unsafe_uids(uid): def test_repair_unsafe_uids(uid):
s = MemoryStorage() s = MemoryStorage()
item = Item(f'BEGIN:VCARD\nUID:{uid}\nEND:VCARD') item = Item(f"BEGIN:VCARD\nUID:{uid}\nEND:VCARD")
href, etag = s.upload(item) href, etag = s.upload(item)
assert s.get(href)[0].uid == uid assert s.get(href)[0].uid == uid
assert not href_safe(uid) assert not href_safe(uid)
@ -55,12 +49,11 @@ def test_repair_unsafe_uids(uid):
assert href_safe(newuid) assert href_safe(newuid)
@pytest.mark.parametrize('uid,href', [ @pytest.mark.parametrize(
('b@dh0mbr3', 'perfectly-fine'), "uid,href", [("b@dh0mbr3", "perfectly-fine"), ("perfectly-fine", "b@dh0mbr3")]
('perfectly-fine', 'b@dh0mbr3') )
])
def test_repair_unsafe_href(uid, href): def test_repair_unsafe_href(uid, href):
item = Item(f'BEGIN:VCARD\nUID:{uid}\nEND:VCARD') item = Item(f"BEGIN:VCARD\nUID:{uid}\nEND:VCARD")
new_item = repair_item(href, item, set(), True) new_item = repair_item(href, item, set(), True)
assert new_item.raw != item.raw assert new_item.raw != item.raw
assert new_item.uid != item.uid assert new_item.uid != item.uid
@ -68,18 +61,14 @@ def test_repair_unsafe_href(uid, href):
def test_repair_do_nothing(): def test_repair_do_nothing():
item = Item('BEGIN:VCARD\nUID:justfine\nEND:VCARD') item = Item("BEGIN:VCARD\nUID:justfine\nEND:VCARD")
assert repair_item('fine', item, set(), True) is item assert repair_item("fine", item, set(), True) is item
assert repair_item('@@@@/fine', item, set(), True) is item assert repair_item("@@@@/fine", item, set(), True) is item
@pytest.mark.parametrize('raw', [ @pytest.mark.parametrize(
'AYYY', "raw", ["AYYY", "", "@@@@", "BEGIN:VCARD", "BEGIN:FOO\nEND:FOO"]
'', )
'@@@@',
'BEGIN:VCARD',
'BEGIN:FOO\nEND:FOO'
])
def test_repair_irreparable(raw): def test_repair_irreparable(raw):
with pytest.raises(IrreparableItem): with pytest.raises(IrreparableItem):
repair_item('fine', Item(raw), set(), True) repair_item("fine", Item(raw), set(), True)

View file

@ -20,40 +20,35 @@ from tests import VCARD_TEMPLATE
_simple_split = [ _simple_split = [
VCARD_TEMPLATE.format(r=123, uid=123), VCARD_TEMPLATE.format(r=123, uid=123),
VCARD_TEMPLATE.format(r=345, uid=345), VCARD_TEMPLATE.format(r=345, uid=345),
VCARD_TEMPLATE.format(r=678, uid=678) VCARD_TEMPLATE.format(r=678, uid=678),
] ]
_simple_joined = '\r\n'.join( _simple_joined = "\r\n".join(
['BEGIN:VADDRESSBOOK'] ["BEGIN:VADDRESSBOOK"] + _simple_split + ["END:VADDRESSBOOK\r\n"]
+ _simple_split
+ ['END:VADDRESSBOOK\r\n']
) )
def test_split_collection_simple(benchmark): def test_split_collection_simple(benchmark):
given = benchmark(lambda: list(vobject.split_collection(_simple_joined))) given = benchmark(lambda: list(vobject.split_collection(_simple_joined)))
assert [normalize_item(item) for item in given] == \ assert [normalize_item(item) for item in given] == [
[normalize_item(item) for item in _simple_split] normalize_item(item) for item in _simple_split
]
assert [x.splitlines() for x in given] == \ assert [x.splitlines() for x in given] == [x.splitlines() for x in _simple_split]
[x.splitlines() for x in _simple_split]
def test_split_collection_multiple_wrappers(benchmark): def test_split_collection_multiple_wrappers(benchmark):
joined = '\r\n'.join( joined = "\r\n".join(
'BEGIN:VADDRESSBOOK\r\n' "BEGIN:VADDRESSBOOK\r\n" + x + "\r\nEND:VADDRESSBOOK\r\n" for x in _simple_split
+ x
+ '\r\nEND:VADDRESSBOOK\r\n'
for x in _simple_split
) )
given = benchmark(lambda: list(vobject.split_collection(joined))) given = benchmark(lambda: list(vobject.split_collection(joined)))
assert [normalize_item(item) for item in given] == \ assert [normalize_item(item) for item in given] == [
[normalize_item(item) for item in _simple_split] normalize_item(item) for item in _simple_split
]
assert [x.splitlines() for x in given] == \ assert [x.splitlines() for x in given] == [x.splitlines() for x in _simple_split]
[x.splitlines() for x in _simple_split]
def test_join_collection_simple(benchmark): def test_join_collection_simple(benchmark):
@ -63,8 +58,11 @@ def test_join_collection_simple(benchmark):
def test_join_collection_vevents(benchmark): def test_join_collection_vevents(benchmark):
actual = benchmark(lambda: vobject.join_collection([ actual = benchmark(
dedent(""" lambda: vobject.join_collection(
[
dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:HUEHUE PRODID:HUEHUE
@ -75,10 +73,15 @@ def test_join_collection_vevents(benchmark):
VALUE:Event {} VALUE:Event {}
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
""").format(i) for i in range(3) """
])) ).format(i)
for i in range(3)
]
)
)
expected = dedent(""" expected = dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
VERSION:2.0 VERSION:2.0
PRODID:HUEHUE PRODID:HUEHUE
@ -95,7 +98,8 @@ def test_join_collection_vevents(benchmark):
VALUE:Event 2 VALUE:Event 2
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
""").lstrip() """
).lstrip()
assert actual.splitlines() == expected.splitlines() assert actual.splitlines() == expected.splitlines()
@ -103,34 +107,29 @@ def test_join_collection_vevents(benchmark):
def test_split_collection_timezones(): def test_split_collection_timezones():
items = [ items = [
BARE_EVENT_TEMPLATE.format(r=123, uid=123), BARE_EVENT_TEMPLATE.format(r=123, uid=123),
BARE_EVENT_TEMPLATE.format(r=345, uid=345) BARE_EVENT_TEMPLATE.format(r=345, uid=345),
] ]
timezone = ( timezone = (
'BEGIN:VTIMEZONE\r\n' "BEGIN:VTIMEZONE\r\n"
'TZID:/mozilla.org/20070129_1/Asia/Tokyo\r\n' "TZID:/mozilla.org/20070129_1/Asia/Tokyo\r\n"
'X-LIC-LOCATION:Asia/Tokyo\r\n' "X-LIC-LOCATION:Asia/Tokyo\r\n"
'BEGIN:STANDARD\r\n' "BEGIN:STANDARD\r\n"
'TZOFFSETFROM:+0900\r\n' "TZOFFSETFROM:+0900\r\n"
'TZOFFSETTO:+0900\r\n' "TZOFFSETTO:+0900\r\n"
'TZNAME:JST\r\n' "TZNAME:JST\r\n"
'DTSTART:19700101T000000\r\n' "DTSTART:19700101T000000\r\n"
'END:STANDARD\r\n' "END:STANDARD\r\n"
'END:VTIMEZONE' "END:VTIMEZONE"
) )
full = '\r\n'.join( full = "\r\n".join(["BEGIN:VCALENDAR"] + items + [timezone, "END:VCALENDAR"])
['BEGIN:VCALENDAR']
+ items
+ [timezone, 'END:VCALENDAR']
)
given = {normalize_item(item) given = {normalize_item(item) for item in vobject.split_collection(full)}
for item in vobject.split_collection(full)}
expected = { expected = {
normalize_item('\r\n'.join(( normalize_item(
'BEGIN:VCALENDAR', item, timezone, 'END:VCALENDAR' "\r\n".join(("BEGIN:VCALENDAR", item, timezone, "END:VCALENDAR"))
))) )
for item in items for item in items
} }
@ -138,32 +137,28 @@ def test_split_collection_timezones():
def test_split_contacts(): def test_split_contacts():
bare = '\r\n'.join([VCARD_TEMPLATE.format(r=x, uid=x) for x in range(4)]) bare = "\r\n".join([VCARD_TEMPLATE.format(r=x, uid=x) for x in range(4)])
with_wrapper = 'BEGIN:VADDRESSBOOK\r\n' + bare + '\nEND:VADDRESSBOOK\r\n' with_wrapper = "BEGIN:VADDRESSBOOK\r\n" + bare + "\nEND:VADDRESSBOOK\r\n"
for _ in (bare, with_wrapper): for _ in (bare, with_wrapper):
split = list(vobject.split_collection(bare)) split = list(vobject.split_collection(bare))
assert len(split) == 4 assert len(split) == 4
assert vobject.join_collection(split).splitlines() == \ assert vobject.join_collection(split).splitlines() == with_wrapper.splitlines()
with_wrapper.splitlines()
def test_hash_item(): def test_hash_item():
a = EVENT_TEMPLATE.format(r=1, uid=1) a = EVENT_TEMPLATE.format(r=1, uid=1)
b = '\n'.join(line for line in a.splitlines() b = "\n".join(line for line in a.splitlines() if "PRODID" not in line)
if 'PRODID' not in line)
assert vobject.hash_item(a) == vobject.hash_item(b) assert vobject.hash_item(a) == vobject.hash_item(b)
def test_multiline_uid(benchmark): def test_multiline_uid(benchmark):
a = ('BEGIN:FOO\r\n' a = "BEGIN:FOO\r\n" "UID:123456789abcd\r\n" " efgh\r\n" "END:FOO\r\n"
'UID:123456789abcd\r\n' assert benchmark(lambda: vobject.Item(a).uid) == "123456789abcdefgh"
' efgh\r\n'
'END:FOO\r\n')
assert benchmark(lambda: vobject.Item(a).uid) == '123456789abcdefgh'
complex_uid_item = dedent(''' complex_uid_item = dedent(
"""
BEGIN:VCALENDAR BEGIN:VCALENDAR
BEGIN:VTIMEZONE BEGIN:VTIMEZONE
TZID:Europe/Rome TZID:Europe/Rome
@ -199,99 +194,102 @@ complex_uid_item = dedent('''
TRANSP:OPAQUE TRANSP:OPAQUE
END:VEVENT END:VEVENT
END:VCALENDAR END:VCALENDAR
''').strip() """
).strip()
def test_multiline_uid_complex(benchmark): def test_multiline_uid_complex(benchmark):
assert benchmark(lambda: vobject.Item(complex_uid_item).uid) == ( assert benchmark(lambda: vobject.Item(complex_uid_item).uid) == (
'040000008200E00074C5B7101A82E008000000005' "040000008200E00074C5B7101A82E008000000005"
'0AAABEEF50DCF001000000062548482FA830A46B9' "0AAABEEF50DCF001000000062548482FA830A46B9"
'EA62114AC9F0EF' "EA62114AC9F0EF"
) )
def test_replace_multiline_uid(benchmark): def test_replace_multiline_uid(benchmark):
def inner(): def inner():
return vobject.Item(complex_uid_item).with_uid('a').uid return vobject.Item(complex_uid_item).with_uid("a").uid
assert benchmark(inner) == 'a' assert benchmark(inner) == "a"
@pytest.mark.parametrize('template', [EVENT_TEMPLATE, @pytest.mark.parametrize(
EVENT_WITH_TIMEZONE_TEMPLATE, "template", [EVENT_TEMPLATE, EVENT_WITH_TIMEZONE_TEMPLATE, VCARD_TEMPLATE]
VCARD_TEMPLATE]) )
@given(uid=st.one_of(st.none(), uid_strategy)) @given(uid=st.one_of(st.none(), uid_strategy))
def test_replace_uid(template, uid): def test_replace_uid(template, uid):
item = vobject.Item(template.format(r=123, uid=123)).with_uid(uid) item = vobject.Item(template.format(r=123, uid=123)).with_uid(uid)
assert item.uid == uid assert item.uid == uid
if uid: if uid:
assert item.raw.count(f'\nUID:{uid}') == 1 assert item.raw.count(f"\nUID:{uid}") == 1
else: else:
assert '\nUID:' not in item.raw assert "\nUID:" not in item.raw
def test_broken_item(): def test_broken_item():
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
vobject._Component.parse('END:FOO') vobject._Component.parse("END:FOO")
assert 'Parsing error at line 1' in str(excinfo.value) assert "Parsing error at line 1" in str(excinfo.value)
item = vobject.Item('END:FOO') item = vobject.Item("END:FOO")
assert item.parsed is None assert item.parsed is None
def test_multiple_items(): def test_multiple_items():
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
vobject._Component.parse([ vobject._Component.parse(
'BEGIN:FOO', [
'END:FOO', "BEGIN:FOO",
'BEGIN:FOO', "END:FOO",
'END:FOO', "BEGIN:FOO",
]) "END:FOO",
]
)
assert 'Found 2 components, expected one' in str(excinfo.value) assert "Found 2 components, expected one" in str(excinfo.value)
c1, c2 = vobject._Component.parse([ c1, c2 = vobject._Component.parse(
'BEGIN:FOO', [
'END:FOO', "BEGIN:FOO",
'BEGIN:FOO', "END:FOO",
'END:FOO', "BEGIN:FOO",
], multiple=True) "END:FOO",
assert c1.name == c2.name == 'FOO' ],
multiple=True,
)
assert c1.name == c2.name == "FOO"
def test_input_types(): def test_input_types():
lines = ['BEGIN:FOO', 'FOO:BAR', 'END:FOO'] lines = ["BEGIN:FOO", "FOO:BAR", "END:FOO"]
for x in (lines, '\r\n'.join(lines), '\r\n'.join(lines).encode('ascii')): for x in (lines, "\r\n".join(lines), "\r\n".join(lines).encode("ascii")):
c = vobject._Component.parse(x) c = vobject._Component.parse(x)
assert c.name == 'FOO' assert c.name == "FOO"
assert c.props == ['FOO:BAR'] assert c.props == ["FOO:BAR"]
assert not c.subcomponents assert not c.subcomponents
value_strategy = st.text( value_strategy = st.text(
st.characters(blacklist_categories=( st.characters(
'Zs', 'Zl', 'Zp', blacklist_categories=("Zs", "Zl", "Zp", "Cc", "Cs"), blacklist_characters=":="
'Cc', 'Cs' ),
), blacklist_characters=':='), min_size=1,
min_size=1
).filter(lambda x: x.strip() == x) ).filter(lambda x: x.strip() == x)
class VobjectMachine(RuleBasedStateMachine): class VobjectMachine(RuleBasedStateMachine):
Unparsed = Bundle('unparsed') Unparsed = Bundle("unparsed")
Parsed = Bundle('parsed') Parsed = Bundle("parsed")
@rule(target=Unparsed, @rule(target=Unparsed, joined=st.booleans(), encoded=st.booleans())
joined=st.booleans(),
encoded=st.booleans())
def get_unparsed_lines(self, joined, encoded): def get_unparsed_lines(self, joined, encoded):
rv = ['BEGIN:FOO', 'FOO:YES', 'END:FOO'] rv = ["BEGIN:FOO", "FOO:YES", "END:FOO"]
if joined: if joined:
rv = '\r\n'.join(rv) rv = "\r\n".join(rv)
if encoded: if encoded:
rv = rv.encode('utf-8') rv = rv.encode("utf-8")
elif encoded: elif encoded:
assume(False) assume(False)
return rv return rv
@ -304,24 +302,24 @@ class VobjectMachine(RuleBasedStateMachine):
def serialize(self, parsed): def serialize(self, parsed):
return list(parsed.dump_lines()) return list(parsed.dump_lines())
@rule(c=Parsed, @rule(c=Parsed, key=uid_strategy, value=uid_strategy)
key=uid_strategy,
value=uid_strategy)
def add_prop(self, c, key, value): def add_prop(self, c, key, value):
c[key] = value c[key] = value
assert c[key] == value assert c[key] == value
assert key in c assert key in c
assert c.get(key) == value assert c.get(key) == value
dump = '\r\n'.join(c.dump_lines()) dump = "\r\n".join(c.dump_lines())
assert key in dump and value in dump assert key in dump and value in dump
@rule(c=Parsed, @rule(
key=uid_strategy, c=Parsed,
value=uid_strategy, key=uid_strategy,
params=st.lists(st.tuples(value_strategy, value_strategy))) value=uid_strategy,
params=st.lists(st.tuples(value_strategy, value_strategy)),
)
def add_prop_raw(self, c, key, value, params): def add_prop_raw(self, c, key, value, params):
params_str = ','.join(k + '=' + v for k, v in params) params_str = ",".join(k + "=" + v for k, v in params)
c.props.insert(0, f'{key};{params_str}:{value}') c.props.insert(0, f"{key};{params_str}:{value}")
assert c[key] == value assert c[key] == value
assert key in c assert key in c
assert c.get(key) == value assert c.get(key) == value
@ -330,7 +328,7 @@ class VobjectMachine(RuleBasedStateMachine):
def add_component(self, c, sub_c): def add_component(self, c, sub_c):
assume(sub_c is not c and sub_c not in c) assume(sub_c is not c and sub_c not in c)
c.subcomponents.append(sub_c) c.subcomponents.append(sub_c)
assert '\r\n'.join(sub_c.dump_lines()) in '\r\n'.join(c.dump_lines()) assert "\r\n".join(sub_c.dump_lines()) in "\r\n".join(c.dump_lines())
@rule(c=Parsed) @rule(c=Parsed)
def sanity_check(self, c): def sanity_check(self, c):
@ -342,14 +340,10 @@ TestVobjectMachine = VobjectMachine.TestCase
def test_component_contains(): def test_component_contains():
item = vobject._Component.parse([ item = vobject._Component.parse(["BEGIN:FOO", "FOO:YES", "END:FOO"])
'BEGIN:FOO',
'FOO:YES',
'END:FOO'
])
assert 'FOO' in item assert "FOO" in item
assert 'BAZ' not in item assert "BAZ" not in item
with pytest.raises(ValueError): with pytest.raises(ValueError):
42 in item # noqa: B015 42 in item # noqa: B015

View file

@ -1,26 +1,27 @@
''' """
Vdirsyncer synchronizes calendars and contacts. Vdirsyncer synchronizes calendars and contacts.
''' """
PROJECT_HOME = 'https://github.com/pimutils/vdirsyncer' PROJECT_HOME = "https://github.com/pimutils/vdirsyncer"
BUGTRACKER_HOME = PROJECT_HOME + '/issues' BUGTRACKER_HOME = PROJECT_HOME + "/issues"
DOCS_HOME = 'https://vdirsyncer.pimutils.org/en/stable' DOCS_HOME = "https://vdirsyncer.pimutils.org/en/stable"
try: try:
from .version import version as __version__ # noqa from .version import version as __version__ # noqa
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
raise ImportError( raise ImportError(
'Failed to find (autogenerated) version.py. ' "Failed to find (autogenerated) version.py. "
'This might be because you are installing from GitHub\'s tarballs, ' "This might be because you are installing from GitHub's tarballs, "
'use the PyPI ones.' "use the PyPI ones."
) )
def _check_python_version(): # pragma: no cover def _check_python_version(): # pragma: no cover
import sys import sys
if sys.version_info < (3, 7, 0): if sys.version_info < (3, 7, 0):
print('vdirsyncer requires at least Python 3.7.') print("vdirsyncer requires at least Python 3.7.")
sys.exit(1) sys.exit(1)

View file

@ -1,3 +1,4 @@
if __name__ == '__main__': if __name__ == "__main__":
from vdirsyncer.cli import app from vdirsyncer.cli import app
app() app()

View file

@ -10,7 +10,7 @@ from .. import BUGTRACKER_HOME
cli_logger = logging.getLogger(__name__) cli_logger = logging.getLogger(__name__)
click_log.basic_config('vdirsyncer') click_log.basic_config("vdirsyncer")
class AppContext: class AppContext:
@ -30,6 +30,7 @@ def catch_errors(f):
f(*a, **kw) f(*a, **kw)
except BaseException: except BaseException:
from .utils import handle_cli_error from .utils import handle_cli_error
handle_cli_error() handle_cli_error()
sys.exit(1) sys.exit(1)
@ -37,24 +38,26 @@ def catch_errors(f):
@click.group() @click.group()
@click_log.simple_verbosity_option('vdirsyncer') @click_log.simple_verbosity_option("vdirsyncer")
@click.version_option(version=__version__) @click.version_option(version=__version__)
@click.option('--config', '-c', metavar='FILE', help='Config file to use.') @click.option("--config", "-c", metavar="FILE", help="Config file to use.")
@pass_context @pass_context
@catch_errors @catch_errors
def app(ctx, config): def app(ctx, config):
''' """
Synchronize calendars and contacts Synchronize calendars and contacts
''' """
if sys.platform == 'win32': if sys.platform == "win32":
cli_logger.warning('Vdirsyncer currently does not support Windows. ' cli_logger.warning(
'You will likely encounter bugs. ' "Vdirsyncer currently does not support Windows. "
'See {}/535 for more information.' "You will likely encounter bugs. "
.format(BUGTRACKER_HOME)) "See {}/535 for more information.".format(BUGTRACKER_HOME)
)
if not ctx.config: if not ctx.config:
from .config import load_config from .config import load_config
ctx.config = load_config(config) ctx.config = load_config(config)
@ -62,40 +65,44 @@ main = app
def max_workers_callback(ctx, param, value): def max_workers_callback(ctx, param, value):
if value == 0 and logging.getLogger('vdirsyncer').level == logging.DEBUG: if value == 0 and logging.getLogger("vdirsyncer").level == logging.DEBUG:
value = 1 value = 1
cli_logger.debug(f'Using {value} maximal workers.') cli_logger.debug(f"Using {value} maximal workers.")
return value return value
def max_workers_option(default=0): def max_workers_option(default=0):
help = 'Use at most this many connections. ' help = "Use at most this many connections. "
if default == 0: if default == 0:
help += 'The default is 0, which means "as many as necessary". ' \ help += (
'With -vdebug enabled, the default is 1.' 'The default is 0, which means "as many as necessary". '
"With -vdebug enabled, the default is 1."
)
else: else:
help += f'The default is {default}.' help += f"The default is {default}."
return click.option( return click.option(
'--max-workers', default=default, type=click.IntRange(min=0, max=None), "--max-workers",
default=default,
type=click.IntRange(min=0, max=None),
callback=max_workers_callback, callback=max_workers_callback,
help=help help=help,
) )
def collections_arg_callback(ctx, param, value): def collections_arg_callback(ctx, param, value):
''' """
Expand the various CLI shortforms ("pair, pair/collection") to an iterable Expand the various CLI shortforms ("pair, pair/collection") to an iterable
of (pair, collections). of (pair, collections).
''' """
# XXX: Ugly! pass_context should work everywhere. # XXX: Ugly! pass_context should work everywhere.
config = ctx.find_object(AppContext).config config = ctx.find_object(AppContext).config
rv = {} rv = {}
for pair_and_collection in (value or config.pairs): for pair_and_collection in value or config.pairs:
pair, collection = pair_and_collection, None pair, collection = pair_and_collection, None
if '/' in pair: if "/" in pair:
pair, collection = pair.split('/') pair, collection = pair.split("/")
collections = rv.setdefault(pair, set()) collections = rv.setdefault(pair, set())
if collection: if collection:
@ -104,20 +111,25 @@ def collections_arg_callback(ctx, param, value):
return rv.items() return rv.items()
collections_arg = click.argument('collections', nargs=-1, collections_arg = click.argument(
callback=collections_arg_callback) "collections", nargs=-1, callback=collections_arg_callback
)
@app.command() @app.command()
@collections_arg @collections_arg
@click.option('--force-delete/--no-force-delete', @click.option(
help=('Do/Don\'t abort synchronization when all items are about ' "--force-delete/--no-force-delete",
'to be deleted from both sides.')) help=(
"Do/Don't abort synchronization when all items are about "
"to be deleted from both sides."
),
)
@max_workers_option() @max_workers_option()
@pass_context @pass_context
@catch_errors @catch_errors
def sync(ctx, collections, force_delete, max_workers): def sync(ctx, collections, force_delete, max_workers):
''' """
Synchronize the given collections or pairs. If no arguments are given, all Synchronize the given collections or pairs. If no arguments are given, all
will be synchronized. will be synchronized.
@ -136,7 +148,7 @@ def sync(ctx, collections, force_delete, max_workers):
\b \b
# Sync only "first_collection" from the pair "bob" # Sync only "first_collection" from the pair "bob"
vdirsyncer sync bob/first_collection vdirsyncer sync bob/first_collection
''' """
from .tasks import prepare_pair, sync_collection from .tasks import prepare_pair, sync_collection
from .utils import WorkerQueue from .utils import WorkerQueue
@ -144,11 +156,16 @@ def sync(ctx, collections, force_delete, max_workers):
with wq.join(): with wq.join():
for pair_name, collections in collections: for pair_name, collections in collections:
wq.put(functools.partial(prepare_pair, pair_name=pair_name, wq.put(
collections=collections, functools.partial(
config=ctx.config, prepare_pair,
force_delete=force_delete, pair_name=pair_name,
callback=sync_collection)) collections=collections,
config=ctx.config,
force_delete=force_delete,
callback=sync_collection,
)
)
wq.spawn_worker() wq.spawn_worker()
@ -158,11 +175,11 @@ def sync(ctx, collections, force_delete, max_workers):
@pass_context @pass_context
@catch_errors @catch_errors
def metasync(ctx, collections, max_workers): def metasync(ctx, collections, max_workers):
''' """
Synchronize metadata of the given collections or pairs. Synchronize metadata of the given collections or pairs.
See the `sync` command for usage. See the `sync` command for usage.
''' """
from .tasks import prepare_pair, metasync_collection from .tasks import prepare_pair, metasync_collection
from .utils import WorkerQueue from .utils import WorkerQueue
@ -170,59 +187,73 @@ def metasync(ctx, collections, max_workers):
with wq.join(): with wq.join():
for pair_name, collections in collections: for pair_name, collections in collections:
wq.put(functools.partial(prepare_pair, pair_name=pair_name, wq.put(
collections=collections, functools.partial(
config=ctx.config, prepare_pair,
callback=metasync_collection)) pair_name=pair_name,
collections=collections,
config=ctx.config,
callback=metasync_collection,
)
)
wq.spawn_worker() wq.spawn_worker()
@app.command() @app.command()
@click.argument('pairs', nargs=-1) @click.argument("pairs", nargs=-1)
@click.option( @click.option(
'--list/--no-list', default=True, "--list/--no-list",
default=True,
help=( help=(
'Whether to list all collections from both sides during discovery, ' "Whether to list all collections from both sides during discovery, "
'for debugging. This is slow and may crash for broken servers.' "for debugging. This is slow and may crash for broken servers."
) ),
) )
@max_workers_option(default=1) @max_workers_option(default=1)
@pass_context @pass_context
@catch_errors @catch_errors
def discover(ctx, pairs, max_workers, list): def discover(ctx, pairs, max_workers, list):
''' """
Refresh collection cache for the given pairs. Refresh collection cache for the given pairs.
''' """
from .tasks import discover_collections from .tasks import discover_collections
from .utils import WorkerQueue from .utils import WorkerQueue
config = ctx.config config = ctx.config
wq = WorkerQueue(max_workers) wq = WorkerQueue(max_workers)
with wq.join(): with wq.join():
for pair_name in (pairs or config.pairs): for pair_name in pairs or config.pairs:
pair = config.get_pair(pair_name) pair = config.get_pair(pair_name)
wq.put(functools.partial( wq.put(
discover_collections, functools.partial(
status_path=config.general['status_path'], discover_collections,
pair=pair, status_path=config.general["status_path"],
from_cache=False, pair=pair,
list_collections=list, from_cache=False,
)) list_collections=list,
)
)
wq.spawn_worker() wq.spawn_worker()
@app.command() @app.command()
@click.argument('collection') @click.argument("collection")
@click.option('--repair-unsafe-uid/--no-repair-unsafe-uid', default=False, @click.option(
help=('Some characters in item UIDs and URLs may cause problems ' "--repair-unsafe-uid/--no-repair-unsafe-uid",
'with buggy software. Adding this option will reassign ' default=False,
'new UIDs to those items. This is disabled by default, ' help=(
'which is equivalent to `--no-repair-unsafe-uid`.')) "Some characters in item UIDs and URLs may cause problems "
"with buggy software. Adding this option will reassign "
"new UIDs to those items. This is disabled by default, "
"which is equivalent to `--no-repair-unsafe-uid`."
),
)
@pass_context @pass_context
@catch_errors @catch_errors
def repair(ctx, collection, repair_unsafe_uid): def repair(ctx, collection, repair_unsafe_uid):
''' """
Repair a given collection. Repair a given collection.
Runs a few checks on the collection and applies some fixes to individual Runs a few checks on the collection and applies some fixes to individual
@ -234,12 +265,13 @@ def repair(ctx, collection, repair_unsafe_uid):
\b\bExamples: \b\bExamples:
# Repair the `foo` collection of the `calendars_local` storage # Repair the `foo` collection of the `calendars_local` storage
vdirsyncer repair calendars_local/foo vdirsyncer repair calendars_local/foo
''' """
from .tasks import repair_collection from .tasks import repair_collection
cli_logger.warning('This operation will take a very long time.') cli_logger.warning("This operation will take a very long time.")
cli_logger.warning('It\'s recommended to make a backup and ' cli_logger.warning(
'turn off other client\'s synchronization features.') "It's recommended to make a backup and "
click.confirm('Do you want to continue?', abort=True) "turn off other client's synchronization features."
repair_collection(ctx.config, collection, )
repair_unsafe_uid=repair_unsafe_uid) click.confirm("Do you want to continue?", abort=True)
repair_collection(ctx.config, collection, repair_unsafe_uid=repair_unsafe_uid)

View file

@ -14,19 +14,20 @@ from .fetchparams import expand_fetch_params
from .utils import storage_class_from_config from .utils import storage_class_from_config
GENERAL_ALL = frozenset(['status_path']) GENERAL_ALL = frozenset(["status_path"])
GENERAL_REQUIRED = frozenset(['status_path']) GENERAL_REQUIRED = frozenset(["status_path"])
SECTION_NAME_CHARS = frozenset(chain(string.ascii_letters, string.digits, '_')) SECTION_NAME_CHARS = frozenset(chain(string.ascii_letters, string.digits, "_"))
def validate_section_name(name, section_type): def validate_section_name(name, section_type):
invalid = set(name) - SECTION_NAME_CHARS invalid = set(name) - SECTION_NAME_CHARS
if invalid: if invalid:
chars_display = ''.join(sorted(SECTION_NAME_CHARS)) chars_display = "".join(sorted(SECTION_NAME_CHARS))
raise exceptions.UserError( raise exceptions.UserError(
'The {}-section "{}" contains invalid characters. Only ' 'The {}-section "{}" contains invalid characters. Only '
'the following characters are allowed for storage and ' "the following characters are allowed for storage and "
'pair names:\n{}'.format(section_type, name, chars_display)) "pair names:\n{}".format(section_type, name, chars_display)
)
def _validate_general_section(general_config): def _validate_general_section(general_config):
@ -35,18 +36,21 @@ def _validate_general_section(general_config):
problems = [] problems = []
if invalid: if invalid:
problems.append('general section doesn\'t take the parameters: {}' problems.append(
.format(', '.join(invalid))) "general section doesn't take the parameters: {}".format(", ".join(invalid))
)
if missing: if missing:
problems.append('general section is missing the parameters: {}' problems.append(
.format(', '.join(missing))) "general section is missing the parameters: {}".format(", ".join(missing))
)
if problems: if problems:
raise exceptions.UserError( raise exceptions.UserError(
'Invalid general section. Copy the example ' "Invalid general section. Copy the example "
'config from the repository and edit it: {}' "config from the repository and edit it: {}".format(PROJECT_HOME),
.format(PROJECT_HOME), problems=problems) problems=problems,
)
def _validate_collections_param(collections): def _validate_collections_param(collections):
@ -54,7 +58,7 @@ def _validate_collections_param(collections):
return return
if not isinstance(collections, list): if not isinstance(collections, list):
raise ValueError('`collections` parameter must be a list or `null`.') raise ValueError("`collections` parameter must be a list or `null`.")
collection_names = set() collection_names = set()
@ -64,7 +68,7 @@ def _validate_collections_param(collections):
collection_name = collection collection_name = collection
elif isinstance(collection, list): elif isinstance(collection, list):
e = ValueError( e = ValueError(
'Expected list of format ' "Expected list of format "
'["config_name", "storage_a_name", "storage_b_name"]' '["config_name", "storage_a_name", "storage_b_name"]'
) )
if len(collection) != 3: if len(collection) != 3:
@ -79,14 +83,15 @@ def _validate_collections_param(collections):
collection_name = collection[0] collection_name = collection[0]
else: else:
raise ValueError('Expected string or list of three strings.') raise ValueError("Expected string or list of three strings.")
if collection_name in collection_names: if collection_name in collection_names:
raise ValueError('Duplicate value.') raise ValueError("Duplicate value.")
collection_names.add(collection_name) collection_names.add(collection_name)
except ValueError as e: except ValueError as e:
raise ValueError('`collections` parameter, position {i}: {e}' raise ValueError(
.format(i=i, e=str(e))) "`collections` parameter, position {i}: {e}".format(i=i, e=str(e))
)
class _ConfigReader: class _ConfigReader:
@ -106,39 +111,38 @@ class _ConfigReader:
raise ValueError(f'Name "{name}" already used.') raise ValueError(f'Name "{name}" already used.')
self._seen_names.add(name) self._seen_names.add(name)
if section_type == 'general': if section_type == "general":
if self._general: if self._general:
raise ValueError('More than one general section.') raise ValueError("More than one general section.")
self._general = options self._general = options
elif section_type == 'storage': elif section_type == "storage":
self._storages[name] = options self._storages[name] = options
elif section_type == 'pair': elif section_type == "pair":
self._pairs[name] = options self._pairs[name] = options
else: else:
raise ValueError('Unknown section type.') raise ValueError("Unknown section type.")
def parse(self): def parse(self):
for section in self._parser.sections(): for section in self._parser.sections():
if ' ' in section: if " " in section:
section_type, name = section.split(' ', 1) section_type, name = section.split(" ", 1)
else: else:
section_type = name = section section_type = name = section
try: try:
self._parse_section( self._parse_section(
section_type, name, section_type,
dict(_parse_options(self._parser.items(section), name,
section=section)) dict(_parse_options(self._parser.items(section), section=section)),
) )
except ValueError as e: except ValueError as e:
raise exceptions.UserError( raise exceptions.UserError('Section "{}": {}'.format(section, str(e)))
'Section "{}": {}'.format(section, str(e)))
_validate_general_section(self._general) _validate_general_section(self._general)
if getattr(self._file, 'name', None): if getattr(self._file, "name", None):
self._general['status_path'] = os.path.join( self._general["status_path"] = os.path.join(
os.path.dirname(self._file.name), os.path.dirname(self._file.name),
expand_path(self._general['status_path']) expand_path(self._general["status_path"]),
) )
return self._general, self._pairs, self._storages return self._general, self._pairs, self._storages
@ -149,8 +153,7 @@ def _parse_options(items, section=None):
try: try:
yield key, json.loads(value) yield key, json.loads(value)
except ValueError as e: except ValueError as e:
raise ValueError('Section "{}", option "{}": {}' raise ValueError('Section "{}", option "{}": {}'.format(section, key, e))
.format(section, key, e))
class Config: class Config:
@ -158,14 +161,14 @@ class Config:
self.general = general self.general = general
self.storages = storages self.storages = storages
for name, options in storages.items(): for name, options in storages.items():
options['instance_name'] = name options["instance_name"] = name
self.pairs = {} self.pairs = {}
for name, options in pairs.items(): for name, options in pairs.items():
try: try:
self.pairs[name] = PairConfig(self, name, options) self.pairs[name] = PairConfig(self, name, options)
except ValueError as e: except ValueError as e:
raise exceptions.UserError(f'Pair {name}: {e}') raise exceptions.UserError(f"Pair {name}: {e}")
@classmethod @classmethod
def from_fileobject(cls, f): def from_fileobject(cls, f):
@ -175,21 +178,21 @@ class Config:
@classmethod @classmethod
def from_filename_or_environment(cls, fname=None): def from_filename_or_environment(cls, fname=None):
if fname is None: if fname is None:
fname = os.environ.get('VDIRSYNCER_CONFIG', None) fname = os.environ.get("VDIRSYNCER_CONFIG", None)
if fname is None: if fname is None:
fname = expand_path('~/.vdirsyncer/config') fname = expand_path("~/.vdirsyncer/config")
if not os.path.exists(fname): if not os.path.exists(fname):
xdg_config_dir = os.environ.get('XDG_CONFIG_HOME', xdg_config_dir = os.environ.get(
expand_path('~/.config/')) "XDG_CONFIG_HOME", expand_path("~/.config/")
fname = os.path.join(xdg_config_dir, 'vdirsyncer/config') )
fname = os.path.join(xdg_config_dir, "vdirsyncer/config")
try: try:
with open(fname) as f: with open(fname) as f:
return cls.from_fileobject(f) return cls.from_fileobject(f)
except Exception as e: except Exception as e:
raise exceptions.UserError( raise exceptions.UserError(
'Error during reading config {}: {}' "Error during reading config {}: {}".format(fname, e)
.format(fname, e)
) )
def get_storage_args(self, storage_name): def get_storage_args(self, storage_name):
@ -197,9 +200,10 @@ class Config:
args = self.storages[storage_name] args = self.storages[storage_name]
except KeyError: except KeyError:
raise exceptions.UserError( raise exceptions.UserError(
'Storage {!r} not found. ' "Storage {!r} not found. "
'These are the configured storages: {}' "These are the configured storages: {}".format(
.format(storage_name, list(self.storages)) storage_name, list(self.storages)
)
) )
else: else:
return expand_fetch_params(args) return expand_fetch_params(args)
@ -215,50 +219,53 @@ class PairConfig:
def __init__(self, full_config, name, options): def __init__(self, full_config, name, options):
self._config = full_config self._config = full_config
self.name = name self.name = name
self.name_a = options.pop('a') self.name_a = options.pop("a")
self.name_b = options.pop('b') self.name_b = options.pop("b")
self._partial_sync = options.pop('partial_sync', None) self._partial_sync = options.pop("partial_sync", None)
self.metadata = options.pop('metadata', None) or () self.metadata = options.pop("metadata", None) or ()
self.conflict_resolution = \ self.conflict_resolution = self._process_conflict_resolution_param(
self._process_conflict_resolution_param( options.pop("conflict_resolution", None)
options.pop('conflict_resolution', None)) )
try: try:
self.collections = options.pop('collections') self.collections = options.pop("collections")
except KeyError: except KeyError:
raise ValueError( raise ValueError(
'collections parameter missing.\n\n' "collections parameter missing.\n\n"
'As of 0.9.0 this parameter has no default anymore. ' "As of 0.9.0 this parameter has no default anymore. "
'Set `collections = null` explicitly in your pair config.' "Set `collections = null` explicitly in your pair config."
) )
else: else:
_validate_collections_param(self.collections) _validate_collections_param(self.collections)
if options: if options:
raise ValueError('Unknown options: {}'.format(', '.join(options))) raise ValueError("Unknown options: {}".format(", ".join(options)))
def _process_conflict_resolution_param(self, conflict_resolution): def _process_conflict_resolution_param(self, conflict_resolution):
if conflict_resolution in (None, 'a wins', 'b wins'): if conflict_resolution in (None, "a wins", "b wins"):
return conflict_resolution return conflict_resolution
elif isinstance(conflict_resolution, list) and \ elif (
len(conflict_resolution) > 1 and \ isinstance(conflict_resolution, list)
conflict_resolution[0] == 'command': and len(conflict_resolution) > 1
and conflict_resolution[0] == "command"
):
def resolve(a, b): def resolve(a, b):
a_name = self.config_a['instance_name'] a_name = self.config_a["instance_name"]
b_name = self.config_b['instance_name'] b_name = self.config_b["instance_name"]
command = conflict_resolution[1:] command = conflict_resolution[1:]
def inner(): def inner():
return _resolve_conflict_via_command(a, b, command, a_name, return _resolve_conflict_via_command(a, b, command, a_name, b_name)
b_name)
ui_worker = get_ui_worker() ui_worker = get_ui_worker()
return ui_worker.put(inner) return ui_worker.put(inner)
return resolve return resolve
else: else:
raise ValueError('Invalid value for `conflict_resolution`.') raise ValueError("Invalid value for `conflict_resolution`.")
# The following parameters are lazily evaluated because evaluating # The following parameters are lazily evaluated because evaluating
# self.config_a would expand all `x.fetch` parameters. This is costly and # self.config_a would expand all `x.fetch` parameters. This is costly and
@ -282,21 +289,23 @@ class PairConfig:
cls_a, _ = storage_class_from_config(self.config_a) cls_a, _ = storage_class_from_config(self.config_a)
cls_b, _ = storage_class_from_config(self.config_b) cls_b, _ = storage_class_from_config(self.config_b)
if not cls_a.read_only and \ if (
not self.config_a.get('read_only', False) and \ not cls_a.read_only
not cls_b.read_only and \ and not self.config_a.get("read_only", False)
not self.config_b.get('read_only', False): and not cls_b.read_only
and not self.config_b.get("read_only", False)
):
raise exceptions.UserError( raise exceptions.UserError(
'`partial_sync` is only effective if one storage is ' "`partial_sync` is only effective if one storage is "
'read-only. Use `read_only = true` in exactly one storage ' "read-only. Use `read_only = true` in exactly one storage "
'section.' "section."
) )
if partial_sync is None: if partial_sync is None:
partial_sync = 'revert' partial_sync = "revert"
if partial_sync not in ('ignore', 'revert', 'error'): if partial_sync not in ("ignore", "revert", "error"):
raise exceptions.UserError('Invalid value for `partial_sync`.') raise exceptions.UserError("Invalid value for `partial_sync`.")
return partial_sync return partial_sync
@ -314,8 +323,7 @@ class CollectionConfig:
load_config = Config.from_filename_or_environment load_config = Config.from_filename_or_environment
def _resolve_conflict_via_command(a, b, command, a_name, b_name, def _resolve_conflict_via_command(a, b, command, a_name, b_name, _check_call=None):
_check_call=None):
import tempfile import tempfile
import shutil import shutil
@ -324,14 +332,14 @@ def _resolve_conflict_via_command(a, b, command, a_name, b_name,
from ..vobject import Item from ..vobject import Item
dir = tempfile.mkdtemp(prefix='vdirsyncer-conflict.') dir = tempfile.mkdtemp(prefix="vdirsyncer-conflict.")
try: try:
a_tmp = os.path.join(dir, a_name) a_tmp = os.path.join(dir, a_name)
b_tmp = os.path.join(dir, b_name) b_tmp = os.path.join(dir, b_name)
with open(a_tmp, 'w') as f: with open(a_tmp, "w") as f:
f.write(a.raw) f.write(a.raw)
with open(b_tmp, 'w') as f: with open(b_tmp, "w") as f:
f.write(b.raw) f.write(b.raw)
command[0] = expand_path(command[0]) command[0] = expand_path(command[0])
@ -343,8 +351,7 @@ def _resolve_conflict_via_command(a, b, command, a_name, b_name,
new_b = f.read() new_b = f.read()
if new_a != new_b: if new_a != new_b:
raise exceptions.UserError('The two files are not completely ' raise exceptions.UserError("The two files are not completely " "equal.")
'equal.')
return Item(new_a) return Item(new_a)
finally: finally:
shutil.rmtree(dir) shutil.rmtree(dir)

View file

@ -22,19 +22,21 @@ logger = logging.getLogger(__name__)
def _get_collections_cache_key(pair): def _get_collections_cache_key(pair):
m = hashlib.sha256() m = hashlib.sha256()
j = json.dumps([ j = json.dumps(
DISCOVERY_CACHE_VERSION, [
pair.collections, DISCOVERY_CACHE_VERSION,
pair.config_a, pair.collections,
pair.config_b, pair.config_a,
], sort_keys=True) pair.config_b,
m.update(j.encode('utf-8')) ],
sort_keys=True,
)
m.update(j.encode("utf-8"))
return m.hexdigest() return m.hexdigest()
def collections_for_pair(status_path, pair, from_cache=True, def collections_for_pair(status_path, pair, from_cache=True, list_collections=False):
list_collections=False): """Determine all configured collections for a given pair. Takes care of
'''Determine all configured collections for a given pair. Takes care of
shortcut expansion and result caching. shortcut expansion and result caching.
:param status_path: The path to the status directory. :param status_path: The path to the status directory.
@ -42,55 +44,62 @@ def collections_for_pair(status_path, pair, from_cache=True,
discover and save to cache. discover and save to cache.
:returns: iterable of (collection, (a_args, b_args)) :returns: iterable of (collection, (a_args, b_args))
''' """
cache_key = _get_collections_cache_key(pair) cache_key = _get_collections_cache_key(pair)
if from_cache: if from_cache:
rv = load_status(status_path, pair.name, data_type='collections') rv = load_status(status_path, pair.name, data_type="collections")
if rv and rv.get('cache_key', None) == cache_key: if rv and rv.get("cache_key", None) == cache_key:
return list(_expand_collections_cache( return list(
rv['collections'], pair.config_a, pair.config_b _expand_collections_cache(
)) rv["collections"], pair.config_a, pair.config_b
)
)
elif rv: elif rv:
raise exceptions.UserError('Detected change in config file, ' raise exceptions.UserError(
'please run `vdirsyncer discover {}`.' "Detected change in config file, "
.format(pair.name)) "please run `vdirsyncer discover {}`.".format(pair.name)
)
else: else:
raise exceptions.UserError('Please run `vdirsyncer discover {}` ' raise exceptions.UserError(
' before synchronization.' "Please run `vdirsyncer discover {}` "
.format(pair.name)) " before synchronization.".format(pair.name)
)
logger.info('Discovering collections for pair {}' .format(pair.name)) logger.info("Discovering collections for pair {}".format(pair.name))
a_discovered = _DiscoverResult(pair.config_a) a_discovered = _DiscoverResult(pair.config_a)
b_discovered = _DiscoverResult(pair.config_b) b_discovered = _DiscoverResult(pair.config_b)
if list_collections: if list_collections:
_print_collections(pair.config_a['instance_name'], _print_collections(pair.config_a["instance_name"], a_discovered.get_self)
a_discovered.get_self) _print_collections(pair.config_b["instance_name"], b_discovered.get_self)
_print_collections(pair.config_b['instance_name'],
b_discovered.get_self)
# We have to use a list here because the special None/null value would get # We have to use a list here because the special None/null value would get
# mangled to string (because JSON objects always have string keys). # mangled to string (because JSON objects always have string keys).
rv = list(expand_collections( rv = list(
shortcuts=pair.collections, expand_collections(
config_a=pair.config_a, shortcuts=pair.collections,
config_b=pair.config_b, config_a=pair.config_a,
get_a_discovered=a_discovered.get_self, config_b=pair.config_b,
get_b_discovered=b_discovered.get_self, get_a_discovered=a_discovered.get_self,
_handle_collection_not_found=handle_collection_not_found get_b_discovered=b_discovered.get_self,
)) _handle_collection_not_found=handle_collection_not_found,
)
)
_sanity_check_collections(rv) _sanity_check_collections(rv)
save_status(status_path, pair.name, data_type='collections', save_status(
data={ status_path,
'collections': list( pair.name,
_compress_collections_cache(rv, pair.config_a, data_type="collections",
pair.config_b) data={
), "collections": list(
'cache_key': cache_key _compress_collections_cache(rv, pair.config_a, pair.config_b)
}) ),
"cache_key": cache_key,
},
)
return rv return rv
@ -141,25 +150,31 @@ class _DiscoverResult:
except Exception: except Exception:
return handle_storage_init_error(self._cls, self._config) return handle_storage_init_error(self._cls, self._config)
else: else:
storage_type = self._config['type'] storage_type = self._config["type"]
rv = {} rv = {}
for args in discovered: for args in discovered:
args['type'] = storage_type args["type"] = storage_type
rv[args['collection']] = args rv[args["collection"]] = args
return rv return rv
def expand_collections(shortcuts, config_a, config_b, get_a_discovered, def expand_collections(
get_b_discovered, _handle_collection_not_found): shortcuts,
config_a,
config_b,
get_a_discovered,
get_b_discovered,
_handle_collection_not_found,
):
handled_collections = set() handled_collections = set()
if shortcuts is None: if shortcuts is None:
shortcuts = [None] shortcuts = [None]
for shortcut in shortcuts: for shortcut in shortcuts:
if shortcut == 'from a': if shortcut == "from a":
collections = get_a_discovered() collections = get_a_discovered()
elif shortcut == 'from b': elif shortcut == "from b":
collections = get_b_discovered() collections = get_b_discovered()
else: else:
collections = [shortcut] collections = [shortcut]
@ -175,22 +190,21 @@ def expand_collections(shortcuts, config_a, config_b, get_a_discovered,
handled_collections.add(collection) handled_collections.add(collection)
a_args = _collection_from_discovered( a_args = _collection_from_discovered(
get_a_discovered, collection_a, config_a, get_a_discovered, collection_a, config_a, _handle_collection_not_found
_handle_collection_not_found
) )
b_args = _collection_from_discovered( b_args = _collection_from_discovered(
get_b_discovered, collection_b, config_b, get_b_discovered, collection_b, config_b, _handle_collection_not_found
_handle_collection_not_found
) )
yield collection, (a_args, b_args) yield collection, (a_args, b_args)
def _collection_from_discovered(get_discovered, collection, config, def _collection_from_discovered(
_handle_collection_not_found): get_discovered, collection, config, _handle_collection_not_found
):
if collection is None: if collection is None:
args = dict(config) args = dict(config)
args['collection'] = None args["collection"] = None
return args return args
try: try:
@ -209,26 +223,31 @@ def _print_collections(instance_name, get_discovered):
# UserError), we don't even know if the storage supports discovery # UserError), we don't even know if the storage supports discovery
# properly. So we can't abort. # properly. So we can't abort.
import traceback import traceback
logger.debug(''.join(traceback.format_tb(sys.exc_info()[2])))
logger.warning('Failed to discover collections for {}, use `-vdebug` ' logger.debug("".join(traceback.format_tb(sys.exc_info()[2])))
'to see the full traceback.'.format(instance_name)) logger.warning(
"Failed to discover collections for {}, use `-vdebug` "
"to see the full traceback.".format(instance_name)
)
return return
logger.info(f'{instance_name}:') logger.info(f"{instance_name}:")
for args in discovered.values(): for args in discovered.values():
collection = args['collection'] collection = args["collection"]
if collection is None: if collection is None:
continue continue
args['instance_name'] = instance_name args["instance_name"] = instance_name
try: try:
storage = storage_instance_from_config(args, create=False) storage = storage_instance_from_config(args, create=False)
displayname = storage.get_meta('displayname') displayname = storage.get_meta("displayname")
except Exception: except Exception:
displayname = '' displayname = ""
logger.info(' - {}{}'.format( logger.info(
json.dumps(collection), " - {}{}".format(
f' ("{displayname}")' json.dumps(collection),
if displayname and displayname != collection f' ("{displayname}")'
else '' if displayname and displayname != collection
)) else "",
)
)

View file

@ -7,7 +7,7 @@ from .. import exceptions
from ..utils import expand_path from ..utils import expand_path
from ..utils import synchronized from ..utils import synchronized
SUFFIX = '.fetch' SUFFIX = ".fetch"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,9 +18,9 @@ def expand_fetch_params(config):
if not key.endswith(SUFFIX): if not key.endswith(SUFFIX):
continue continue
newkey = key[:-len(SUFFIX)] newkey = key[: -len(SUFFIX)]
if newkey in config: if newkey in config:
raise ValueError(f'Can\'t set {key} and {newkey}.') raise ValueError(f"Can't set {key} and {newkey}.")
config[newkey] = _fetch_value(config[key], key) config[newkey] = _fetch_value(config[key], key)
del config[key] del config[key]
@ -30,10 +30,11 @@ def expand_fetch_params(config):
@synchronized() @synchronized()
def _fetch_value(opts, key): def _fetch_value(opts, key):
if not isinstance(opts, list): if not isinstance(opts, list):
raise ValueError('Invalid value for {}: Expected a list, found {!r}.' raise ValueError(
.format(key, opts)) "Invalid value for {}: Expected a list, found {!r}.".format(key, opts)
)
if not opts: if not opts:
raise ValueError('Expected list of length > 0.') raise ValueError("Expected list of length > 0.")
try: try:
ctx = click.get_current_context().find_object(AppContext) ctx = click.get_current_context().find_object(AppContext)
@ -46,7 +47,7 @@ def _fetch_value(opts, key):
cache_key = tuple(opts) cache_key = tuple(opts)
if cache_key in password_cache: if cache_key in password_cache:
rv = password_cache[cache_key] rv = password_cache[cache_key]
logger.debug(f'Found cached value for {opts!r}.') logger.debug(f"Found cached value for {opts!r}.")
if isinstance(rv, BaseException): if isinstance(rv, BaseException):
raise rv raise rv
return rv return rv
@ -55,10 +56,9 @@ def _fetch_value(opts, key):
try: try:
strategy_fn = STRATEGIES[strategy] strategy_fn = STRATEGIES[strategy]
except KeyError: except KeyError:
raise exceptions.UserError(f'Unknown strategy: {strategy}') raise exceptions.UserError(f"Unknown strategy: {strategy}")
logger.debug('Fetching value for {} with {} strategy.' logger.debug("Fetching value for {} with {} strategy.".format(key, strategy))
.format(key, strategy))
try: try:
rv = strategy_fn(*opts[1:]) rv = strategy_fn(*opts[1:])
except (click.Abort, KeyboardInterrupt) as e: except (click.Abort, KeyboardInterrupt) as e:
@ -66,22 +66,25 @@ def _fetch_value(opts, key):
raise raise
else: else:
if not rv: if not rv:
raise exceptions.UserError('Empty value for {}, this most likely ' raise exceptions.UserError(
'indicates an error.' "Empty value for {}, this most likely "
.format(key)) "indicates an error.".format(key)
)
password_cache[cache_key] = rv password_cache[cache_key] = rv
return rv return rv
def _strategy_command(*command): def _strategy_command(*command):
import subprocess import subprocess
command = (expand_path(command[0]),) + command[1:] command = (expand_path(command[0]),) + command[1:]
try: try:
stdout = subprocess.check_output(command, universal_newlines=True) stdout = subprocess.check_output(command, universal_newlines=True)
return stdout.strip('\n') return stdout.strip("\n")
except OSError as e: except OSError as e:
raise exceptions.UserError('Failed to execute command: {}\n{}' raise exceptions.UserError(
.format(' '.join(command), str(e))) "Failed to execute command: {}\n{}".format(" ".join(command), str(e))
)
def _strategy_prompt(text): def _strategy_prompt(text):
@ -89,6 +92,6 @@ def _strategy_prompt(text):
STRATEGIES = { STRATEGIES = {
'command': _strategy_command, "command": _strategy_command,
'prompt': _strategy_prompt, "prompt": _strategy_prompt,
} }

View file

@ -19,28 +19,30 @@ from .utils import save_status
def prepare_pair(wq, pair_name, collections, config, callback, **kwargs): def prepare_pair(wq, pair_name, collections, config, callback, **kwargs):
pair = config.get_pair(pair_name) pair = config.get_pair(pair_name)
all_collections = dict(collections_for_pair( all_collections = dict(
status_path=config.general['status_path'], pair=pair collections_for_pair(status_path=config.general["status_path"], pair=pair)
)) )
# spawn one worker less because we can reuse the current one # spawn one worker less because we can reuse the current one
new_workers = -1 new_workers = -1
for collection_name in (collections or all_collections): for collection_name in collections or all_collections:
try: try:
config_a, config_b = all_collections[collection_name] config_a, config_b = all_collections[collection_name]
except KeyError: except KeyError:
raise exceptions.UserError( raise exceptions.UserError(
'Pair {}: Collection {} not found. These are the ' "Pair {}: Collection {} not found. These are the "
'configured collections:\n{}' "configured collections:\n{}".format(
.format(pair_name, pair_name, json.dumps(collection_name), list(all_collections)
json.dumps(collection_name), )
list(all_collections))) )
new_workers += 1 new_workers += 1
collection = CollectionConfig(pair, collection_name, config_a, collection = CollectionConfig(pair, collection_name, config_a, config_b)
config_b) wq.put(
wq.put(functools.partial(callback, collection=collection, functools.partial(
general=config.general, **kwargs)) callback, collection=collection, general=config.general, **kwargs
)
)
for _ in range(new_workers): for _ in range(new_workers):
wq.spawn_worker() wq.spawn_worker()
@ -51,7 +53,7 @@ def sync_collection(wq, collection, general, force_delete):
status_name = get_status_name(pair.name, collection.name) status_name = get_status_name(pair.name, collection.name)
try: try:
cli_logger.info(f'Syncing {status_name}') cli_logger.info(f"Syncing {status_name}")
a = storage_instance_from_config(collection.config_a) a = storage_instance_from_config(collection.config_a)
b = storage_instance_from_config(collection.config_b) b = storage_instance_from_config(collection.config_b)
@ -63,14 +65,17 @@ def sync_collection(wq, collection, general, force_delete):
sync_failed = True sync_failed = True
handle_cli_error(status_name, e) handle_cli_error(status_name, e)
with manage_sync_status(general['status_path'], pair.name, with manage_sync_status(
collection.name) as status: general["status_path"], pair.name, collection.name
) as status:
sync.sync( sync.sync(
a, b, status, a,
b,
status,
conflict_resolution=pair.conflict_resolution, conflict_resolution=pair.conflict_resolution,
force_delete=force_delete, force_delete=force_delete,
error_callback=error_callback, error_callback=error_callback,
partial_sync=pair.partial_sync partial_sync=pair.partial_sync,
) )
if sync_failed: if sync_failed:
@ -87,62 +92,76 @@ def discover_collections(wq, pair, **kwargs):
collections = list(c for c, (a, b) in rv) collections = list(c for c, (a, b) in rv)
if collections == [None]: if collections == [None]:
collections = None collections = None
cli_logger.info('Saved for {}: collections = {}' cli_logger.info(
.format(pair.name, json.dumps(collections))) "Saved for {}: collections = {}".format(pair.name, json.dumps(collections))
)
def repair_collection(config, collection, repair_unsafe_uid): def repair_collection(config, collection, repair_unsafe_uid):
from ..repair import repair_storage from ..repair import repair_storage
storage_name, collection = collection, None storage_name, collection = collection, None
if '/' in storage_name: if "/" in storage_name:
storage_name, collection = storage_name.split('/') storage_name, collection = storage_name.split("/")
config = config.get_storage_args(storage_name) config = config.get_storage_args(storage_name)
storage_type = config['type'] storage_type = config["type"]
if collection is not None: if collection is not None:
cli_logger.info('Discovering collections (skipping cache).') cli_logger.info("Discovering collections (skipping cache).")
cls, config = storage_class_from_config(config) cls, config = storage_class_from_config(config)
for config in cls.discover(**config): for config in cls.discover(**config):
if config['collection'] == collection: if config["collection"] == collection:
break break
else: else:
raise exceptions.UserError( raise exceptions.UserError(
'Couldn\'t find collection {} for storage {}.' "Couldn't find collection {} for storage {}.".format(
.format(collection, storage_name) collection, storage_name
)
) )
config['type'] = storage_type config["type"] = storage_type
storage = storage_instance_from_config(config) storage = storage_instance_from_config(config)
cli_logger.info(f'Repairing {storage_name}/{collection}') cli_logger.info(f"Repairing {storage_name}/{collection}")
cli_logger.warning('Make sure no other program is talking to the server.') cli_logger.warning("Make sure no other program is talking to the server.")
repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid) repair_storage(storage, repair_unsafe_uid=repair_unsafe_uid)
def metasync_collection(wq, collection, general): def metasync_collection(wq, collection, general):
from ..metasync import metasync from ..metasync import metasync
pair = collection.pair pair = collection.pair
status_name = get_status_name(pair.name, collection.name) status_name = get_status_name(pair.name, collection.name)
try: try:
cli_logger.info(f'Metasyncing {status_name}') cli_logger.info(f"Metasyncing {status_name}")
status = load_status(general['status_path'], pair.name, status = (
collection.name, data_type='metadata') or {} load_status(
general["status_path"], pair.name, collection.name, data_type="metadata"
)
or {}
)
a = storage_instance_from_config(collection.config_a) a = storage_instance_from_config(collection.config_a)
b = storage_instance_from_config(collection.config_b) b = storage_instance_from_config(collection.config_b)
metasync( metasync(
a, b, status, a,
b,
status,
conflict_resolution=pair.conflict_resolution, conflict_resolution=pair.conflict_resolution,
keys=pair.metadata keys=pair.metadata,
) )
except BaseException: except BaseException:
handle_cli_error(status_name) handle_cli_error(status_name)
raise JobFailed() raise JobFailed()
save_status(general['status_path'], pair.name, collection.name, save_status(
data_type='metadata', data=status) general["status_path"],
pair.name,
collection.name,
data_type="metadata",
data=status,
)

View file

@ -31,15 +31,15 @@ STATUS_DIR_PERMISSIONS = 0o700
class _StorageIndex: class _StorageIndex:
def __init__(self): def __init__(self):
self._storages = dict( self._storages = dict(
caldav='vdirsyncer.storage.dav.CalDAVStorage', caldav="vdirsyncer.storage.dav.CalDAVStorage",
carddav='vdirsyncer.storage.dav.CardDAVStorage', carddav="vdirsyncer.storage.dav.CardDAVStorage",
filesystem='vdirsyncer.storage.filesystem.FilesystemStorage', filesystem="vdirsyncer.storage.filesystem.FilesystemStorage",
http='vdirsyncer.storage.http.HttpStorage', http="vdirsyncer.storage.http.HttpStorage",
singlefile='vdirsyncer.storage.singlefile.SingleFileStorage', singlefile="vdirsyncer.storage.singlefile.SingleFileStorage",
google_calendar='vdirsyncer.storage.google.GoogleCalendarStorage', google_calendar="vdirsyncer.storage.google.GoogleCalendarStorage",
google_contacts='vdirsyncer.storage.google.GoogleContactsStorage', google_contacts="vdirsyncer.storage.google.GoogleContactsStorage",
etesync_calendars='vdirsyncer.storage.etesync.EtesyncCalendars', etesync_calendars="vdirsyncer.storage.etesync.EtesyncCalendars",
etesync_contacts='vdirsyncer.storage.etesync.EtesyncContacts' etesync_contacts="vdirsyncer.storage.etesync.EtesyncContacts",
) )
def __getitem__(self, name): def __getitem__(self, name):
@ -47,7 +47,7 @@ class _StorageIndex:
if not isinstance(item, str): if not isinstance(item, str):
return item return item
modname, clsname = item.rsplit('.', 1) modname, clsname = item.rsplit(".", 1)
mod = importlib.import_module(modname) mod = importlib.import_module(modname)
self._storages[name] = rv = getattr(mod, clsname) self._storages[name] = rv = getattr(mod, clsname)
assert rv.storage_name == name assert rv.storage_name == name
@ -63,12 +63,12 @@ class JobFailed(RuntimeError):
def handle_cli_error(status_name=None, e=None): def handle_cli_error(status_name=None, e=None):
''' """
Print a useful error message for the current exception. Print a useful error message for the current exception.
This is supposed to catch all exceptions, and should never raise any This is supposed to catch all exceptions, and should never raise any
exceptions itself. exceptions itself.
''' """
try: try:
if e is not None: if e is not None:
@ -80,101 +80,104 @@ def handle_cli_error(status_name=None, e=None):
except StorageEmpty as e: except StorageEmpty as e:
cli_logger.error( cli_logger.error(
'{status_name}: Storage "{name}" was completely emptied. If you ' '{status_name}: Storage "{name}" was completely emptied. If you '
'want to delete ALL entries on BOTH sides, then use ' "want to delete ALL entries on BOTH sides, then use "
'`vdirsyncer sync --force-delete {status_name}`. ' "`vdirsyncer sync --force-delete {status_name}`. "
'Otherwise delete the files for {status_name} in your status ' "Otherwise delete the files for {status_name} in your status "
'directory.'.format( "directory.".format(
name=e.empty_storage.instance_name, name=e.empty_storage.instance_name, status_name=status_name
status_name=status_name
) )
) )
except PartialSync as e: except PartialSync as e:
cli_logger.error( cli_logger.error(
'{status_name}: Attempted change on {storage}, which is read-only' "{status_name}: Attempted change on {storage}, which is read-only"
'. Set `partial_sync` in your pair section to `ignore` to ignore ' ". Set `partial_sync` in your pair section to `ignore` to ignore "
'those changes, or `revert` to revert them on the other side.' "those changes, or `revert` to revert them on the other side.".format(
.format(status_name=status_name, storage=e.storage) status_name=status_name, storage=e.storage
)
) )
except SyncConflict as e: except SyncConflict as e:
cli_logger.error( cli_logger.error(
'{status_name}: One item changed on both sides. Resolve this ' "{status_name}: One item changed on both sides. Resolve this "
'conflict manually, or by setting the `conflict_resolution` ' "conflict manually, or by setting the `conflict_resolution` "
'parameter in your config file.\n' "parameter in your config file.\n"
'See also {docs}/config.html#pair-section\n' "See also {docs}/config.html#pair-section\n"
'Item ID: {e.ident}\n' "Item ID: {e.ident}\n"
'Item href on side A: {e.href_a}\n' "Item href on side A: {e.href_a}\n"
'Item href on side B: {e.href_b}\n' "Item href on side B: {e.href_b}\n".format(
.format(status_name=status_name, e=e, docs=DOCS_HOME) status_name=status_name, e=e, docs=DOCS_HOME
)
) )
except IdentConflict as e: except IdentConflict as e:
cli_logger.error( cli_logger.error(
'{status_name}: Storage "{storage.instance_name}" contains ' '{status_name}: Storage "{storage.instance_name}" contains '
'multiple items with the same UID or even content. Vdirsyncer ' "multiple items with the same UID or even content. Vdirsyncer "
'will now abort the synchronization of this collection, because ' "will now abort the synchronization of this collection, because "
'the fix for this is not clear; It could be the result of a badly ' "the fix for this is not clear; It could be the result of a badly "
'behaving server. You can try running:\n\n' "behaving server. You can try running:\n\n"
' vdirsyncer repair {storage.instance_name}\n\n' " vdirsyncer repair {storage.instance_name}\n\n"
'But make sure to have a backup of your data in some form. The ' "But make sure to have a backup of your data in some form. The "
'offending hrefs are:\n\n{href_list}\n' "offending hrefs are:\n\n{href_list}\n".format(
.format(status_name=status_name, status_name=status_name,
storage=e.storage, storage=e.storage,
href_list='\n'.join(map(repr, e.hrefs))) href_list="\n".join(map(repr, e.hrefs)),
)
) )
except (click.Abort, KeyboardInterrupt, JobFailed): except (click.Abort, KeyboardInterrupt, JobFailed):
pass pass
except exceptions.PairNotFound as e: except exceptions.PairNotFound as e:
cli_logger.error( cli_logger.error(
'Pair {pair_name} does not exist. Please check your ' "Pair {pair_name} does not exist. Please check your "
'configuration file and make sure you\'ve typed the pair name ' "configuration file and make sure you've typed the pair name "
'correctly'.format(pair_name=e.pair_name) "correctly".format(pair_name=e.pair_name)
) )
except exceptions.InvalidResponse as e: except exceptions.InvalidResponse as e:
cli_logger.error( cli_logger.error(
'The server returned something vdirsyncer doesn\'t understand. ' "The server returned something vdirsyncer doesn't understand. "
'Error message: {!r}\n' "Error message: {!r}\n"
'While this is most likely a serverside problem, the vdirsyncer ' "While this is most likely a serverside problem, the vdirsyncer "
'devs are generally interested in such bugs. Please report it in ' "devs are generally interested in such bugs. Please report it in "
'the issue tracker at {}' "the issue tracker at {}".format(e, BUGTRACKER_HOME)
.format(e, BUGTRACKER_HOME)
) )
except exceptions.CollectionRequired: except exceptions.CollectionRequired:
cli_logger.error( cli_logger.error(
'One or more storages don\'t support `collections = null`. ' "One or more storages don't support `collections = null`. "
'You probably want to set `collections = ["from a", "from b"]`.' 'You probably want to set `collections = ["from a", "from b"]`.'
) )
except Exception as e: except Exception as e:
tb = sys.exc_info()[2] tb = sys.exc_info()[2]
import traceback import traceback
tb = traceback.format_tb(tb) tb = traceback.format_tb(tb)
if status_name: if status_name:
msg = f'Unknown error occurred for {status_name}' msg = f"Unknown error occurred for {status_name}"
else: else:
msg = 'Unknown error occurred' msg = "Unknown error occurred"
msg += f': {e}\nUse `-vdebug` to see the full traceback.' msg += f": {e}\nUse `-vdebug` to see the full traceback."
cli_logger.error(msg) cli_logger.error(msg)
cli_logger.debug(''.join(tb)) cli_logger.debug("".join(tb))
def get_status_name(pair, collection): def get_status_name(pair, collection):
if collection is None: if collection is None:
return pair return pair
return pair + '/' + collection return pair + "/" + collection
def get_status_path(base_path, pair, collection=None, data_type=None): def get_status_path(base_path, pair, collection=None, data_type=None):
assert data_type is not None assert data_type is not None
status_name = get_status_name(pair, collection) status_name = get_status_name(pair, collection)
path = expand_path(os.path.join(base_path, status_name)) path = expand_path(os.path.join(base_path, status_name))
if os.path.isfile(path) and data_type == 'items': if os.path.isfile(path) and data_type == "items":
new_path = path + '.items' new_path = path + ".items"
# XXX: Legacy migration # XXX: Legacy migration
cli_logger.warning('Migrating statuses: Renaming {} to {}' cli_logger.warning(
.format(path, new_path)) "Migrating statuses: Renaming {} to {}".format(path, new_path)
)
os.rename(path, new_path) os.rename(path, new_path)
path += '.' + data_type path += "." + data_type
return path return path
@ -205,20 +208,20 @@ def prepare_status_path(path):
@contextlib.contextmanager @contextlib.contextmanager
def manage_sync_status(base_path, pair_name, collection_name): def manage_sync_status(base_path, pair_name, collection_name):
path = get_status_path(base_path, pair_name, collection_name, 'items') path = get_status_path(base_path, pair_name, collection_name, "items")
status = None status = None
legacy_status = None legacy_status = None
try: try:
# XXX: Legacy migration # XXX: Legacy migration
with open(path, 'rb') as f: with open(path, "rb") as f:
if f.read(1) == b'{': if f.read(1) == b"{":
f.seek(0) f.seek(0)
legacy_status = dict(json.load(f)) legacy_status = dict(json.load(f))
except (OSError, ValueError): except (OSError, ValueError):
pass pass
if legacy_status is not None: if legacy_status is not None:
cli_logger.warning('Migrating legacy status to sqlite') cli_logger.warning("Migrating legacy status to sqlite")
os.remove(path) os.remove(path)
status = SqliteStatus(path) status = SqliteStatus(path)
status.load_legacy_status(legacy_status) status.load_legacy_status(legacy_status)
@ -233,10 +236,10 @@ def save_status(base_path, pair, collection=None, data_type=None, data=None):
assert data_type is not None assert data_type is not None
assert data is not None assert data is not None
status_name = get_status_name(pair, collection) status_name = get_status_name(pair, collection)
path = expand_path(os.path.join(base_path, status_name)) + '.' + data_type path = expand_path(os.path.join(base_path, status_name)) + "." + data_type
prepare_status_path(path) prepare_status_path(path)
with atomic_write(path, mode='w', overwrite=True) as f: with atomic_write(path, mode="w", overwrite=True) as f:
json.dump(data, f) json.dump(data, f)
os.chmod(path, STATUS_PERMISSIONS) os.chmod(path, STATUS_PERMISSIONS)
@ -244,20 +247,19 @@ def save_status(base_path, pair, collection=None, data_type=None, data=None):
def storage_class_from_config(config): def storage_class_from_config(config):
config = dict(config) config = dict(config)
storage_name = config.pop('type') storage_name = config.pop("type")
try: try:
cls = storage_names[storage_name] cls = storage_names[storage_name]
except KeyError: except KeyError:
raise exceptions.UserError( raise exceptions.UserError(f"Unknown storage type: {storage_name}")
f'Unknown storage type: {storage_name}')
return cls, config return cls, config
def storage_instance_from_config(config, create=True): def storage_instance_from_config(config, create=True):
''' """
:param config: A configuration dictionary to pass as kwargs to the class :param config: A configuration dictionary to pass as kwargs to the class
corresponding to config['type'] corresponding to config['type']
''' """
cls, new_config = storage_class_from_config(config) cls, new_config = storage_class_from_config(config)
@ -266,7 +268,8 @@ def storage_instance_from_config(config, create=True):
except exceptions.CollectionNotFound as e: except exceptions.CollectionNotFound as e:
if create: if create:
config = handle_collection_not_found( config = handle_collection_not_found(
config, config.get('collection', None), e=str(e)) config, config.get("collection", None), e=str(e)
)
return storage_instance_from_config(config, create=False) return storage_instance_from_config(config, create=False)
else: else:
raise raise
@ -276,7 +279,7 @@ def storage_instance_from_config(config, create=True):
def handle_storage_init_error(cls, config): def handle_storage_init_error(cls, config):
e = sys.exc_info()[1] e = sys.exc_info()[1]
if not isinstance(e, TypeError) or '__init__' not in repr(e): if not isinstance(e, TypeError) or "__init__" not in repr(e):
raise raise
all, required = get_storage_init_args(cls) all, required = get_storage_init_args(cls)
@ -288,30 +291,34 @@ def handle_storage_init_error(cls, config):
if missing: if missing:
problems.append( problems.append(
'{} storage requires the parameters: {}' "{} storage requires the parameters: {}".format(
.format(cls.storage_name, ', '.join(missing))) cls.storage_name, ", ".join(missing)
)
)
if invalid: if invalid:
problems.append( problems.append(
'{} storage doesn\'t take the parameters: {}' "{} storage doesn't take the parameters: {}".format(
.format(cls.storage_name, ', '.join(invalid))) cls.storage_name, ", ".join(invalid)
)
)
if not problems: if not problems:
raise e raise e
raise exceptions.UserError( raise exceptions.UserError(
'Failed to initialize {}'.format(config['instance_name']), "Failed to initialize {}".format(config["instance_name"]), problems=problems
problems=problems
) )
class WorkerQueue: class WorkerQueue:
''' """
A simple worker-queue setup. A simple worker-queue setup.
Note that workers quit if queue is empty. That means you have to first put Note that workers quit if queue is empty. That means you have to first put
things into the queue before spawning the worker! things into the queue before spawning the worker!
''' """
def __init__(self, max_workers): def __init__(self, max_workers):
self._queue = queue.Queue() self._queue = queue.Queue()
self._workers = [] self._workers = []
@ -369,7 +376,7 @@ class WorkerQueue:
if not self._workers: if not self._workers:
# Ugly hack, needed because ui_worker is not running. # Ugly hack, needed because ui_worker is not running.
click.echo = _echo click.echo = _echo
cli_logger.critical('Nothing to do.') cli_logger.critical("Nothing to do.")
sys.exit(5) sys.exit(5)
ui_worker.run() ui_worker.run()
@ -381,8 +388,9 @@ class WorkerQueue:
tasks_done = next(self.num_done_tasks) tasks_done = next(self.num_done_tasks)
if tasks_failed > 0: if tasks_failed > 0:
cli_logger.error('{} out of {} tasks failed.' cli_logger.error(
.format(tasks_failed, tasks_done)) "{} out of {} tasks failed.".format(tasks_failed, tasks_done)
)
sys.exit(1) sys.exit(1)
def put(self, f): def put(self, f):
@ -392,25 +400,30 @@ class WorkerQueue:
def assert_permissions(path, wanted): def assert_permissions(path, wanted):
permissions = os.stat(path).st_mode & 0o777 permissions = os.stat(path).st_mode & 0o777
if permissions > wanted: if permissions > wanted:
cli_logger.warning('Correcting permissions of {} from {:o} to {:o}' cli_logger.warning(
.format(path, permissions, wanted)) "Correcting permissions of {} from {:o} to {:o}".format(
path, permissions, wanted
)
)
os.chmod(path, wanted) os.chmod(path, wanted)
def handle_collection_not_found(config, collection, e=None): def handle_collection_not_found(config, collection, e=None):
storage_name = config.get('instance_name', None) storage_name = config.get("instance_name", None)
cli_logger.warning('{}No collection {} found for storage {}.' cli_logger.warning(
.format(f'{e}\n' if e else '', "{}No collection {} found for storage {}.".format(
json.dumps(collection), storage_name)) f"{e}\n" if e else "", json.dumps(collection), storage_name
)
)
if click.confirm('Should vdirsyncer attempt to create it?'): if click.confirm("Should vdirsyncer attempt to create it?"):
storage_type = config['type'] storage_type = config["type"]
cls, config = storage_class_from_config(config) cls, config = storage_class_from_config(config)
config['collection'] = collection config["collection"] = collection
try: try:
args = cls.create_collection(**config) args = cls.create_collection(**config)
args['type'] = storage_type args["type"] = storage_type
return args return args
except NotImplementedError as e: except NotImplementedError as e:
cli_logger.error(e) cli_logger.error(e)
@ -418,5 +431,5 @@ def handle_collection_not_found(config, collection, e=None):
raise exceptions.UserError( raise exceptions.UserError(
'Unable to find or create collection "{collection}" for ' 'Unable to find or create collection "{collection}" for '
'storage "{storage}". Please create the collection ' 'storage "{storage}". Please create the collection '
'yourself.'.format(collection=collection, "yourself.".format(collection=collection, storage=storage_name)
storage=storage_name)) )

View file

@ -1,80 +1,81 @@
''' """
Contains exception classes used by vdirsyncer. Not all exceptions are here, Contains exception classes used by vdirsyncer. Not all exceptions are here,
only the most commonly used ones. only the most commonly used ones.
''' """
class Error(Exception): class Error(Exception):
'''Baseclass for all errors.''' """Baseclass for all errors."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
if getattr(self, key, object()) is not None: # pragma: no cover if getattr(self, key, object()) is not None: # pragma: no cover
raise TypeError(f'Invalid argument: {key}') raise TypeError(f"Invalid argument: {key}")
setattr(self, key, value) setattr(self, key, value)
super().__init__(*args) super().__init__(*args)
class UserError(Error, ValueError): class UserError(Error, ValueError):
'''Wrapper exception to be used to signify the traceback should not be """Wrapper exception to be used to signify the traceback should not be
shown to the user.''' shown to the user."""
problems = None problems = None
def __str__(self): def __str__(self):
msg = Error.__str__(self) msg = Error.__str__(self)
for problem in self.problems or (): for problem in self.problems or ():
msg += f'\n - {problem}' msg += f"\n - {problem}"
return msg return msg
class CollectionNotFound(Error): class CollectionNotFound(Error):
'''Collection not found''' """Collection not found"""
class PairNotFound(Error): class PairNotFound(Error):
'''Pair not found''' """Pair not found"""
pair_name = None pair_name = None
class PreconditionFailed(Error): class PreconditionFailed(Error):
''' """
- The item doesn't exist although it should - The item doesn't exist although it should
- The item exists although it shouldn't - The item exists although it shouldn't
- The etags don't match. - The etags don't match.
Due to CalDAV we can't actually say which error it is. Due to CalDAV we can't actually say which error it is.
This error may indicate race conditions. This error may indicate race conditions.
''' """
class NotFoundError(PreconditionFailed): class NotFoundError(PreconditionFailed):
'''Item not found''' """Item not found"""
class AlreadyExistingError(PreconditionFailed): class AlreadyExistingError(PreconditionFailed):
'''Item already exists.''' """Item already exists."""
existing_href = None existing_href = None
class WrongEtagError(PreconditionFailed): class WrongEtagError(PreconditionFailed):
'''Wrong etag''' """Wrong etag"""
class ReadOnlyError(Error): class ReadOnlyError(Error):
'''Storage is read-only.''' """Storage is read-only."""
class InvalidResponse(Error, ValueError): class InvalidResponse(Error, ValueError):
'''The backend returned an invalid result.''' """The backend returned an invalid result."""
class UnsupportedMetadataError(Error, NotImplementedError): class UnsupportedMetadataError(Error, NotImplementedError):
'''The storage doesn't support this type of metadata.''' """The storage doesn't support this type of metadata."""
class CollectionRequired(Error): class CollectionRequired(Error):
'''`collection = null` is not allowed.''' """`collection = null` is not allowed."""

View file

@ -9,22 +9,23 @@ from .utils import expand_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
USERAGENT = f'vdirsyncer/{__version__}' USERAGENT = f"vdirsyncer/{__version__}"
def _detect_faulty_requests(): # pragma: no cover def _detect_faulty_requests(): # pragma: no cover
text = ( text = (
'Error during import: {e}\n\n' "Error during import: {e}\n\n"
'If you have installed vdirsyncer from a distro package, please file ' "If you have installed vdirsyncer from a distro package, please file "
'a bug against that package, not vdirsyncer.\n\n' "a bug against that package, not vdirsyncer.\n\n"
'Consult {d}/problems.html#requests-related-importerrors' "Consult {d}/problems.html#requests-related-importerrors"
'-based-distributions on how to work around this.' "-based-distributions on how to work around this."
) )
try: try:
from requests_toolbelt.auth.guess import GuessAuth # noqa from requests_toolbelt.auth.guess import GuessAuth # noqa
except ImportError as e: except ImportError as e:
import sys import sys
print(text.format(e=str(e), d=DOCS_HOME), file=sys.stderr) print(text.format(e=str(e), d=DOCS_HOME), file=sys.stderr)
sys.exit(1) sys.exit(1)
@ -35,28 +36,30 @@ del _detect_faulty_requests
def prepare_auth(auth, username, password): def prepare_auth(auth, username, password):
if username and password: if username and password:
if auth == 'basic' or auth is None: if auth == "basic" or auth is None:
return (username, password) return (username, password)
elif auth == 'digest': elif auth == "digest":
from requests.auth import HTTPDigestAuth from requests.auth import HTTPDigestAuth
return HTTPDigestAuth(username, password) return HTTPDigestAuth(username, password)
elif auth == 'guess': elif auth == "guess":
try: try:
from requests_toolbelt.auth.guess import GuessAuth from requests_toolbelt.auth.guess import GuessAuth
except ImportError: except ImportError:
raise exceptions.UserError( raise exceptions.UserError(
'Your version of requests_toolbelt is too ' "Your version of requests_toolbelt is too "
'old for `guess` authentication. At least ' "old for `guess` authentication. At least "
'version 0.4.0 is required.' "version 0.4.0 is required."
) )
else: else:
return GuessAuth(username, password) return GuessAuth(username, password)
else: else:
raise exceptions.UserError('Unknown authentication method: {}' raise exceptions.UserError("Unknown authentication method: {}".format(auth))
.format(auth))
elif auth: elif auth:
raise exceptions.UserError('You need to specify username and password ' raise exceptions.UserError(
'for {} authentication.'.format(auth)) "You need to specify username and password "
"for {} authentication.".format(auth)
)
else: else:
return None return None
@ -65,24 +68,26 @@ def prepare_verify(verify, verify_fingerprint):
if isinstance(verify, (str, bytes)): if isinstance(verify, (str, bytes)):
verify = expand_path(verify) verify = expand_path(verify)
elif not isinstance(verify, bool): elif not isinstance(verify, bool):
raise exceptions.UserError('Invalid value for verify ({}), ' raise exceptions.UserError(
'must be a path to a PEM-file or boolean.' "Invalid value for verify ({}), "
.format(verify)) "must be a path to a PEM-file or boolean.".format(verify)
)
if verify_fingerprint is not None: if verify_fingerprint is not None:
if not isinstance(verify_fingerprint, (bytes, str)): if not isinstance(verify_fingerprint, (bytes, str)):
raise exceptions.UserError('Invalid value for verify_fingerprint ' raise exceptions.UserError(
'({}), must be a string or null.' "Invalid value for verify_fingerprint "
.format(verify_fingerprint)) "({}), must be a string or null.".format(verify_fingerprint)
)
elif not verify: elif not verify:
raise exceptions.UserError( raise exceptions.UserError(
'Disabling all SSL validation is forbidden. Consider setting ' "Disabling all SSL validation is forbidden. Consider setting "
'verify_fingerprint if you have a broken or self-signed cert.' "verify_fingerprint if you have a broken or self-signed cert."
) )
return { return {
'verify': verify, "verify": verify,
'verify_fingerprint': verify_fingerprint, "verify_fingerprint": verify_fingerprint,
} }
@ -95,22 +100,24 @@ def prepare_client_cert(cert):
def _install_fingerprint_adapter(session, fingerprint): def _install_fingerprint_adapter(session, fingerprint):
prefix = 'https://' prefix = "https://"
try: try:
from requests_toolbelt.adapters.fingerprint import \ from requests_toolbelt.adapters.fingerprint import FingerprintAdapter
FingerprintAdapter
except ImportError: except ImportError:
raise RuntimeError('`verify_fingerprint` can only be used with ' raise RuntimeError(
'requests-toolbelt versions >= 0.4.0') "`verify_fingerprint` can only be used with "
"requests-toolbelt versions >= 0.4.0"
)
if not isinstance(session.adapters[prefix], FingerprintAdapter): if not isinstance(session.adapters[prefix], FingerprintAdapter):
fingerprint_adapter = FingerprintAdapter(fingerprint) fingerprint_adapter = FingerprintAdapter(fingerprint)
session.mount(prefix, fingerprint_adapter) session.mount(prefix, fingerprint_adapter)
def request(method, url, session=None, latin1_fallback=True, def request(
verify_fingerprint=None, **kwargs): method, url, session=None, latin1_fallback=True, verify_fingerprint=None, **kwargs
''' ):
"""
Wrapper method for requests, to ease logging and mocking. Parameters should Wrapper method for requests, to ease logging and mocking. Parameters should
be the same as for ``requests.request``, except: be the same as for ``requests.request``, except:
@ -123,7 +130,7 @@ def request(method, url, session=None, latin1_fallback=True,
autodetection (usually ending up with utf8) instead of plainly falling autodetection (usually ending up with utf8) instead of plainly falling
back to this silly default. See back to this silly default. See
https://github.com/kennethreitz/requests/issues/2042 https://github.com/kennethreitz/requests/issues/2042
''' """
if session is None: if session is None:
session = requests.Session() session = requests.Session()
@ -135,21 +142,23 @@ def request(method, url, session=None, latin1_fallback=True,
func = session.request func = session.request
logger.debug(f'{method} {url}') logger.debug(f"{method} {url}")
logger.debug(kwargs.get('headers', {})) logger.debug(kwargs.get("headers", {}))
logger.debug(kwargs.get('data', None)) logger.debug(kwargs.get("data", None))
logger.debug('Sending request...') logger.debug("Sending request...")
assert isinstance(kwargs.get('data', b''), bytes) assert isinstance(kwargs.get("data", b""), bytes)
r = func(method, url, **kwargs) r = func(method, url, **kwargs)
# See https://github.com/kennethreitz/requests/issues/2042 # See https://github.com/kennethreitz/requests/issues/2042
content_type = r.headers.get('Content-Type', '') content_type = r.headers.get("Content-Type", "")
if not latin1_fallback and \ if (
'charset' not in content_type and \ not latin1_fallback
content_type.startswith('text/'): and "charset" not in content_type
logger.debug('Removing latin1 fallback') and content_type.startswith("text/")
):
logger.debug("Removing latin1 fallback")
r.encoding = None r.encoding = None
logger.debug(r.status_code) logger.debug(r.status_code)
@ -166,7 +175,7 @@ def request(method, url, session=None, latin1_fallback=True,
def _fix_redirects(r, *args, **kwargs): def _fix_redirects(r, *args, **kwargs):
''' """
Requests discards of the body content when it is following a redirect that Requests discards of the body content when it is following a redirect that
is not a 307 or 308. We never want that to happen. is not a 307 or 308. We never want that to happen.
@ -177,7 +186,7 @@ def _fix_redirects(r, *args, **kwargs):
FIXME: This solution isn't very nice. A new hook in requests would be FIXME: This solution isn't very nice. A new hook in requests would be
better. better.
''' """
if r.is_redirect: if r.is_redirect:
logger.debug('Rewriting status code from %s to 307', r.status_code) logger.debug("Rewriting status code from %s to 307", r.status_code)
r.status_code = 307 r.status_code = 307

View file

@ -16,39 +16,37 @@ class MetaSyncConflict(MetaSyncError):
def metasync(storage_a, storage_b, status, keys, conflict_resolution=None): def metasync(storage_a, storage_b, status, keys, conflict_resolution=None):
def _a_to_b(): def _a_to_b():
logger.info(f'Copying {key} to {storage_b}') logger.info(f"Copying {key} to {storage_b}")
storage_b.set_meta(key, a) storage_b.set_meta(key, a)
status[key] = a status[key] = a
def _b_to_a(): def _b_to_a():
logger.info(f'Copying {key} to {storage_a}') logger.info(f"Copying {key} to {storage_a}")
storage_a.set_meta(key, b) storage_a.set_meta(key, b)
status[key] = b status[key] = b
def _resolve_conflict(): def _resolve_conflict():
if a == b: if a == b:
status[key] = a status[key] = a
elif conflict_resolution == 'a wins': elif conflict_resolution == "a wins":
_a_to_b() _a_to_b()
elif conflict_resolution == 'b wins': elif conflict_resolution == "b wins":
_b_to_a() _b_to_a()
else: else:
if callable(conflict_resolution): if callable(conflict_resolution):
logger.warning('Custom commands don\'t work on metasync.') logger.warning("Custom commands don't work on metasync.")
elif conflict_resolution is not None: elif conflict_resolution is not None:
raise exceptions.UserError( raise exceptions.UserError("Invalid conflict resolution setting.")
'Invalid conflict resolution setting.'
)
raise MetaSyncConflict(key) raise MetaSyncConflict(key)
for key in keys: for key in keys:
a = storage_a.get_meta(key) a = storage_a.get_meta(key)
b = storage_b.get_meta(key) b = storage_b.get_meta(key)
s = normalize_meta_value(status.get(key)) s = normalize_meta_value(status.get(key))
logger.debug(f'Key: {key}') logger.debug(f"Key: {key}")
logger.debug(f'A: {a}') logger.debug(f"A: {a}")
logger.debug(f'B: {b}') logger.debug(f"B: {b}")
logger.debug(f'S: {s}') logger.debug(f"S: {s}")
if a != s and b != s: if a != s and b != s:
_resolve_conflict() _resolve_conflict()

View file

@ -16,17 +16,17 @@ def repair_storage(storage, repair_unsafe_uid):
all_hrefs = list(storage.list()) all_hrefs = list(storage.list())
for i, (href, _) in enumerate(all_hrefs): for i, (href, _) in enumerate(all_hrefs):
item, etag = storage.get(href) item, etag = storage.get(href)
logger.info('[{}/{}] Processing {}' logger.info("[{}/{}] Processing {}".format(i, len(all_hrefs), href))
.format(i, len(all_hrefs), href))
try: try:
new_item = repair_item(href, item, seen_uids, repair_unsafe_uid) new_item = repair_item(href, item, seen_uids, repair_unsafe_uid)
except IrreparableItem: except IrreparableItem:
logger.error('Item {!r} is malformed beyond repair. ' logger.error(
'The PRODID property may indicate which software ' "Item {!r} is malformed beyond repair. "
'created this item.' "The PRODID property may indicate which software "
.format(href)) "created this item.".format(href)
logger.error(f'Item content: {item.raw!r}') )
logger.error(f"Item content: {item.raw!r}")
continue continue
seen_uids.add(new_item.uid) seen_uids.add(new_item.uid)
@ -45,17 +45,18 @@ def repair_item(href, item, seen_uids, repair_unsafe_uid):
new_item = item new_item = item
if not item.uid: if not item.uid:
logger.warning('No UID, assigning random UID.') logger.warning("No UID, assigning random UID.")
new_item = item.with_uid(generate_href()) new_item = item.with_uid(generate_href())
elif item.uid in seen_uids: elif item.uid in seen_uids:
logger.warning('Duplicate UID, assigning random UID.') logger.warning("Duplicate UID, assigning random UID.")
new_item = item.with_uid(generate_href()) new_item = item.with_uid(generate_href())
elif not href_safe(item.uid) or not href_safe(basename(href)): elif not href_safe(item.uid) or not href_safe(basename(href)):
if not repair_unsafe_uid: if not repair_unsafe_uid:
logger.warning('UID may cause problems, add ' logger.warning(
'--repair-unsafe-uid to repair.') "UID may cause problems, add " "--repair-unsafe-uid to repair."
)
else: else:
logger.warning('UID or href is unsafe, assigning random UID.') logger.warning("UID or href is unsafe, assigning random UID.")
new_item = item.with_uid(generate_href()) new_item = item.with_uid(generate_href())
if not new_item.uid: if not new_item.uid:

View file

@ -1,6 +1,6 @@
''' """
There are storage classes which control the access to one vdir-collection and There are storage classes which control the access to one vdir-collection and
offer basic CRUD-ish methods for modifying those collections. The exact offer basic CRUD-ish methods for modifying those collections. The exact
interface is described in `vdirsyncer.storage.base`, the `Storage` class should interface is described in `vdirsyncer.storage.base`, the `Storage` class should
be a superclass of all storage classes. be a superclass of all storage classes.
''' """

View file

@ -9,21 +9,22 @@ def mutating_storage_method(f):
@functools.wraps(f) @functools.wraps(f)
def inner(self, *args, **kwargs): def inner(self, *args, **kwargs):
if self.read_only: if self.read_only:
raise exceptions.ReadOnlyError('This storage is read-only.') raise exceptions.ReadOnlyError("This storage is read-only.")
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
return inner return inner
class StorageMeta(type): class StorageMeta(type):
def __init__(cls, name, bases, d): def __init__(cls, name, bases, d):
for method in ('update', 'upload', 'delete'): for method in ("update", "upload", "delete"):
setattr(cls, method, mutating_storage_method(getattr(cls, method))) setattr(cls, method, mutating_storage_method(getattr(cls, method)))
return super().__init__(name, bases, d) return super().__init__(name, bases, d)
class Storage(metaclass=StorageMeta): class Storage(metaclass=StorageMeta):
'''Superclass of all storages, interface that all storages have to """Superclass of all storages, interface that all storages have to
implement. implement.
Terminology: Terminology:
@ -40,9 +41,9 @@ class Storage(metaclass=StorageMeta):
:param read_only: Whether the synchronization algorithm should avoid writes :param read_only: Whether the synchronization algorithm should avoid writes
to this storage. Some storages accept no value other than ``True``. to this storage. Some storages accept no value other than ``True``.
''' """
fileext = '.txt' fileext = ".txt"
# The string used in the config to denote the type of storage. Should be # The string used in the config to denote the type of storage. Should be
# overridden by subclasses. # overridden by subclasses.
@ -67,17 +68,17 @@ class Storage(metaclass=StorageMeta):
if read_only is None: if read_only is None:
read_only = self.read_only read_only = self.read_only
if self.read_only and not read_only: if self.read_only and not read_only:
raise exceptions.UserError('This storage can only be read-only.') raise exceptions.UserError("This storage can only be read-only.")
self.read_only = bool(read_only) self.read_only = bool(read_only)
if collection and instance_name: if collection and instance_name:
instance_name = f'{instance_name}/{collection}' instance_name = f"{instance_name}/{collection}"
self.instance_name = instance_name self.instance_name = instance_name
self.collection = collection self.collection = collection
@classmethod @classmethod
def discover(cls, **kwargs): def discover(cls, **kwargs):
'''Discover collections given a basepath or -URL to many collections. """Discover collections given a basepath or -URL to many collections.
:param **kwargs: Keyword arguments to additionally pass to the storage :param **kwargs: Keyword arguments to additionally pass to the storage
instances returned. You shouldn't pass `collection` here, otherwise instances returned. You shouldn't pass `collection` here, otherwise
@ -90,19 +91,19 @@ class Storage(metaclass=StorageMeta):
machine-readable identifier for the collection, usually obtained machine-readable identifier for the collection, usually obtained
from the last segment of a URL or filesystem path. from the last segment of a URL or filesystem path.
''' """
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): def create_collection(cls, collection, **kwargs):
''' """
Create the specified collection and return the new arguments. Create the specified collection and return the new arguments.
``collection=None`` means the arguments are already pointing to a ``collection=None`` means the arguments are already pointing to a
possible collection location. possible collection location.
The returned args should contain the collection name, for UI purposes. The returned args should contain the collection name, for UI purposes.
''' """
raise NotImplementedError() raise NotImplementedError()
def __repr__(self): def __repr__(self):
@ -112,29 +113,29 @@ class Storage(metaclass=StorageMeta):
except ValueError: except ValueError:
pass pass
return '<{}(**{})>'.format( return "<{}(**{})>".format(
self.__class__.__name__, self.__class__.__name__,
{x: getattr(self, x) for x in self._repr_attributes} {x: getattr(self, x) for x in self._repr_attributes},
) )
def list(self): def list(self):
''' """
:returns: list of (href, etag) :returns: list of (href, etag)
''' """
raise NotImplementedError() raise NotImplementedError()
def get(self, href): def get(self, href):
'''Fetch a single item. """Fetch a single item.
:param href: href to fetch :param href: href to fetch
:returns: (item, etag) :returns: (item, etag)
:raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` if item can't :raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` if item can't
be found. be found.
''' """
raise NotImplementedError() raise NotImplementedError()
def get_multi(self, hrefs): def get_multi(self, hrefs):
'''Fetch multiple items. Duplicate hrefs must be ignored. """Fetch multiple items. Duplicate hrefs must be ignored.
Functionally similar to :py:meth:`get`, but might bring performance Functionally similar to :py:meth:`get`, but might bring performance
benefits on some storages when used cleverly. benefits on some storages when used cleverly.
@ -143,16 +144,16 @@ class Storage(metaclass=StorageMeta):
:raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` if one of the :raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` if one of the
items couldn't be found. items couldn't be found.
:returns: iterable of (href, item, etag) :returns: iterable of (href, item, etag)
''' """
for href in uniq(hrefs): for href in uniq(hrefs):
item, etag = self.get(href) item, etag = self.get(href)
yield href, item, etag yield href, item, etag
def has(self, href): def has(self, href):
'''Check if an item exists by its href. """Check if an item exists by its href.
:returns: True or False :returns: True or False
''' """
try: try:
self.get(href) self.get(href)
except exceptions.PreconditionFailed: except exceptions.PreconditionFailed:
@ -161,7 +162,7 @@ class Storage(metaclass=StorageMeta):
return True return True
def upload(self, item): def upload(self, item):
'''Upload a new item. """Upload a new item.
In cases where the new etag cannot be atomically determined (i.e. in In cases where the new etag cannot be atomically determined (i.e. in
the same "transaction" as the upload itself), this method may return the same "transaction" as the upload itself), this method may return
@ -172,11 +173,11 @@ class Storage(metaclass=StorageMeta):
already an item with that href. already an item with that href.
:returns: (href, etag) :returns: (href, etag)
''' """
raise NotImplementedError() raise NotImplementedError()
def update(self, href, item, etag): def update(self, href, item, etag):
'''Update an item. """Update an item.
The etag may be none in some cases, see `upload`. The etag may be none in some cases, see `upload`.
@ -185,20 +186,20 @@ class Storage(metaclass=StorageMeta):
exist. exist.
:returns: etag :returns: etag
''' """
raise NotImplementedError() raise NotImplementedError()
def delete(self, href, etag): def delete(self, href, etag):
'''Delete an item by href. """Delete an item by href.
:raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` when item has :raises: :exc:`vdirsyncer.exceptions.PreconditionFailed` when item has
a different etag or doesn't exist. a different etag or doesn't exist.
''' """
raise NotImplementedError() raise NotImplementedError()
@contextlib.contextmanager @contextlib.contextmanager
def at_once(self): def at_once(self):
'''A contextmanager that buffers all writes. """A contextmanager that buffers all writes.
Essentially, this:: Essentially, this::
@ -213,34 +214,34 @@ class Storage(metaclass=StorageMeta):
Note that this removes guarantees about which exceptions are returned Note that this removes guarantees about which exceptions are returned
when. when.
''' """
yield yield
def get_meta(self, key): def get_meta(self, key):
'''Get metadata value for collection/storage. """Get metadata value for collection/storage.
See the vdir specification for the keys that *have* to be accepted. See the vdir specification for the keys that *have* to be accepted.
:param key: The metadata key. :param key: The metadata key.
:type key: unicode :type key: unicode
''' """
raise NotImplementedError('This storage does not support metadata.') raise NotImplementedError("This storage does not support metadata.")
def set_meta(self, key, value): def set_meta(self, key, value):
'''Get metadata value for collection/storage. """Get metadata value for collection/storage.
:param key: The metadata key. :param key: The metadata key.
:type key: unicode :type key: unicode
:param value: The value. :param value: The value.
:type value: unicode :type value: unicode
''' """
raise NotImplementedError('This storage does not support metadata.') raise NotImplementedError("This storage does not support metadata.")
def normalize_meta_value(value): def normalize_meta_value(value):
# `None` is returned by iCloud for empty properties. # `None` is returned by iCloud for empty properties.
if not value or value == 'None': if not value or value == "None":
value = '' value = ""
return value.strip() return value.strip()

View file

@ -22,12 +22,12 @@ from .base import Storage
dav_logger = logging.getLogger(__name__) dav_logger = logging.getLogger(__name__)
CALDAV_DT_FORMAT = '%Y%m%dT%H%M%SZ' CALDAV_DT_FORMAT = "%Y%m%dT%H%M%SZ"
def _generate_path_reserved_chars(): def _generate_path_reserved_chars():
for x in "/?#[]!$&'()*+,;": for x in "/?#[]!$&'()*+,;":
x = urlparse.quote(x, '') x = urlparse.quote(x, "")
yield x.upper() yield x.upper()
yield x.lower() yield x.lower()
@ -39,7 +39,7 @@ del _generate_path_reserved_chars
def _contains_quoted_reserved_chars(x): def _contains_quoted_reserved_chars(x):
for y in _path_reserved_chars: for y in _path_reserved_chars:
if y in x: if y in x:
dav_logger.debug(f'Unsafe character: {y!r}') dav_logger.debug(f"Unsafe character: {y!r}")
return True return True
return False return False
@ -50,19 +50,19 @@ def _assert_multistatus_success(r):
root = _parse_xml(r.content) root = _parse_xml(r.content)
except InvalidXMLResponse: except InvalidXMLResponse:
return return
for status in root.findall('.//{DAV:}status'): for status in root.findall(".//{DAV:}status"):
parts = status.text.strip().split() parts = status.text.strip().split()
try: try:
st = int(parts[1]) st = int(parts[1])
except (ValueError, IndexError): except (ValueError, IndexError):
continue continue
if st < 200 or st >= 400: if st < 200 or st >= 400:
raise HTTPError(f'Server error: {st}') raise HTTPError(f"Server error: {st}")
def _normalize_href(base, href): def _normalize_href(base, href):
'''Normalize the href to be a path only relative to hostname and """Normalize the href to be a path only relative to hostname and
schema.''' schema."""
orig_href = href orig_href = href
if not href: if not href:
raise ValueError(href) raise ValueError(href)
@ -80,13 +80,12 @@ def _normalize_href(base, href):
old_x = x old_x = x
x = urlparse.unquote(x) x = urlparse.unquote(x)
x = urlparse.quote(x, '/@%:') x = urlparse.quote(x, "/@%:")
if orig_href == x: if orig_href == x:
dav_logger.debug(f'Already normalized: {x!r}') dav_logger.debug(f"Already normalized: {x!r}")
else: else:
dav_logger.debug('Normalized URL from {!r} to {!r}' dav_logger.debug("Normalized URL from {!r} to {!r}".format(orig_href, x))
.format(orig_href, x))
return x return x
@ -96,8 +95,8 @@ class InvalidXMLResponse(exceptions.InvalidResponse):
_BAD_XML_CHARS = ( _BAD_XML_CHARS = (
b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f' b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f"
b'\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f' b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
) )
@ -105,8 +104,8 @@ def _clean_body(content, bad_chars=_BAD_XML_CHARS):
new_content = content.translate(None, bad_chars) new_content = content.translate(None, bad_chars)
if new_content != content: if new_content != content:
dav_logger.warning( dav_logger.warning(
'Your server incorrectly returned ASCII control characters in its ' "Your server incorrectly returned ASCII control characters in its "
'XML. Vdirsyncer ignores those, but this is a bug in your server.' "XML. Vdirsyncer ignores those, but this is a bug in your server."
) )
return new_content return new_content
@ -115,9 +114,10 @@ def _parse_xml(content):
try: try:
return etree.XML(_clean_body(content)) return etree.XML(_clean_body(content))
except etree.ParseError as e: except etree.ParseError as e:
raise InvalidXMLResponse('Invalid XML encountered: {}\n' raise InvalidXMLResponse(
'Double-check the URLs in your config.' "Invalid XML encountered: {}\n"
.format(e)) "Double-check the URLs in your config.".format(e)
)
def _merge_xml(items): def _merge_xml(items):
@ -137,7 +137,7 @@ def _fuzzy_matches_mimetype(strict, weak):
if strict is None or weak is None: if strict is None or weak is None:
return True return True
mediatype, subtype = strict.split('/') mediatype, subtype = strict.split("/")
if subtype in weak: if subtype in weak:
return True return True
return False return False
@ -158,27 +158,27 @@ class Discover:
""" """
def __init__(self, session, kwargs): def __init__(self, session, kwargs):
if kwargs.pop('collection', None) is not None: if kwargs.pop("collection", None) is not None:
raise TypeError('collection argument must not be given.') raise TypeError("collection argument must not be given.")
self.session = session self.session = session
self.kwargs = kwargs self.kwargs = kwargs
@staticmethod @staticmethod
def _get_collection_from_url(url): def _get_collection_from_url(url):
_, collection = url.rstrip('/').rsplit('/', 1) _, collection = url.rstrip("/").rsplit("/", 1)
return urlparse.unquote(collection) return urlparse.unquote(collection)
def find_principal(self): def find_principal(self):
try: try:
return self._find_principal_impl('') return self._find_principal_impl("")
except (HTTPError, exceptions.Error): except (HTTPError, exceptions.Error):
dav_logger.debug('Trying out well-known URI') dav_logger.debug("Trying out well-known URI")
return self._find_principal_impl(self._well_known_uri) return self._find_principal_impl(self._well_known_uri)
def _find_principal_impl(self, url): def _find_principal_impl(self, url):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers['Depth'] = '0' headers["Depth"] = "0"
body = b""" body = b"""
<propfind xmlns="DAV:"> <propfind xmlns="DAV:">
<prop> <prop>
@ -187,106 +187,102 @@ class Discover:
</propfind> </propfind>
""" """
response = self.session.request('PROPFIND', url, headers=headers, response = self.session.request("PROPFIND", url, headers=headers, data=body)
data=body)
root = _parse_xml(response.content) root = _parse_xml(response.content)
rv = root.find('.//{DAV:}current-user-principal/{DAV:}href') rv = root.find(".//{DAV:}current-user-principal/{DAV:}href")
if rv is None: if rv is None:
# This is for servers that don't support current-user-principal # This is for servers that don't support current-user-principal
# E.g. Synology NAS # E.g. Synology NAS
# See https://github.com/pimutils/vdirsyncer/issues/498 # See https://github.com/pimutils/vdirsyncer/issues/498
dav_logger.debug( dav_logger.debug(
'No current-user-principal returned, re-using URL {}' "No current-user-principal returned, re-using URL {}".format(
.format(response.url)) response.url
)
)
return response.url return response.url
return urlparse.urljoin(response.url, rv.text).rstrip('/') + '/' return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/"
def find_home(self): def find_home(self):
url = self.find_principal() url = self.find_principal()
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers['Depth'] = '0' headers["Depth"] = "0"
response = self.session.request('PROPFIND', url, response = self.session.request(
headers=headers, "PROPFIND", url, headers=headers, data=self._homeset_xml
data=self._homeset_xml) )
root = etree.fromstring(response.content) root = etree.fromstring(response.content)
# Better don't do string formatting here, because of XML namespaces # Better don't do string formatting here, because of XML namespaces
rv = root.find('.//' + self._homeset_tag + '/{DAV:}href') rv = root.find(".//" + self._homeset_tag + "/{DAV:}href")
if rv is None: if rv is None:
raise InvalidXMLResponse('Couldn\'t find home-set.') raise InvalidXMLResponse("Couldn't find home-set.")
return urlparse.urljoin(response.url, rv.text).rstrip('/') + '/' return urlparse.urljoin(response.url, rv.text).rstrip("/") + "/"
def find_collections(self): def find_collections(self):
rv = None rv = None
try: try:
rv = list(self._find_collections_impl('')) rv = list(self._find_collections_impl(""))
except (HTTPError, exceptions.Error): except (HTTPError, exceptions.Error):
pass pass
if rv: if rv:
return rv return rv
dav_logger.debug('Given URL is not a homeset URL') dav_logger.debug("Given URL is not a homeset URL")
return self._find_collections_impl(self.find_home()) return self._find_collections_impl(self.find_home())
def _check_collection_resource_type(self, response): def _check_collection_resource_type(self, response):
if self._resourcetype is None: if self._resourcetype is None:
return True return True
props = _merge_xml(response.findall( props = _merge_xml(response.findall("{DAV:}propstat/{DAV:}prop"))
'{DAV:}propstat/{DAV:}prop'
))
if props is None or not len(props): if props is None or not len(props):
dav_logger.debug('Skipping, missing <prop>: %s', response) dav_logger.debug("Skipping, missing <prop>: %s", response)
return False return False
if props.find('{DAV:}resourcetype/' + self._resourcetype) \ if props.find("{DAV:}resourcetype/" + self._resourcetype) is None:
is None: dav_logger.debug(
dav_logger.debug('Skipping, not of resource type %s: %s', "Skipping, not of resource type %s: %s", self._resourcetype, response
self._resourcetype, response) )
return False return False
return True return True
def _find_collections_impl(self, url): def _find_collections_impl(self, url):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers['Depth'] = '1' headers["Depth"] = "1"
r = self.session.request('PROPFIND', url, headers=headers, r = self.session.request(
data=self._collection_xml) "PROPFIND", url, headers=headers, data=self._collection_xml
)
root = _parse_xml(r.content) root = _parse_xml(r.content)
done = set() done = set()
for response in root.findall('{DAV:}response'): for response in root.findall("{DAV:}response"):
if not self._check_collection_resource_type(response): if not self._check_collection_resource_type(response):
continue continue
href = response.find('{DAV:}href') href = response.find("{DAV:}href")
if href is None: if href is None:
raise InvalidXMLResponse('Missing href tag for collection ' raise InvalidXMLResponse("Missing href tag for collection " "props.")
'props.')
href = urlparse.urljoin(r.url, href.text) href = urlparse.urljoin(r.url, href.text)
if href not in done: if href not in done:
done.add(href) done.add(href)
yield {'href': href} yield {"href": href}
def discover(self): def discover(self):
for c in self.find_collections(): for c in self.find_collections():
url = c['href'] url = c["href"]
collection = self._get_collection_from_url(url) collection = self._get_collection_from_url(url)
storage_args = dict(self.kwargs) storage_args = dict(self.kwargs)
storage_args.update({'url': url, 'collection': collection}) storage_args.update({"url": url, "collection": collection})
yield storage_args yield storage_args
def create(self, collection): def create(self, collection):
if collection is None: if collection is None:
collection = self._get_collection_from_url(self.kwargs['url']) collection = self._get_collection_from_url(self.kwargs["url"])
for c in self.discover(): for c in self.discover():
if c['collection'] == collection: if c["collection"] == collection:
return c return c
home = self.find_home() home = self.find_home()
url = urlparse.urljoin( url = urlparse.urljoin(home, urlparse.quote(collection, "/@"))
home,
urlparse.quote(collection, '/@')
)
try: try:
url = self._create_collection_impl(url) url = self._create_collection_impl(url)
@ -294,12 +290,12 @@ class Discover:
raise NotImplementedError(e) raise NotImplementedError(e)
else: else:
rv = dict(self.kwargs) rv = dict(self.kwargs)
rv['collection'] = collection rv["collection"] = collection
rv['url'] = url rv["url"] = url
return rv return rv
def _create_collection_impl(self, url): def _create_collection_impl(self, url):
data = '''<?xml version="1.0" encoding="utf-8" ?> data = """<?xml version="1.0" encoding="utf-8" ?>
<mkcol xmlns="DAV:"> <mkcol xmlns="DAV:">
<set> <set>
<prop> <prop>
@ -310,13 +306,14 @@ class Discover:
</prop> </prop>
</set> </set>
</mkcol> </mkcol>
'''.format( """.format(
etree.tostring(etree.Element(self._resourcetype), etree.tostring(etree.Element(self._resourcetype), encoding="unicode")
encoding='unicode') ).encode(
).encode('utf-8') "utf-8"
)
response = self.session.request( response = self.session.request(
'MKCOL', "MKCOL",
url, url,
data=data, data=data,
headers=self.session.get_default_headers(), headers=self.session.get_default_headers(),
@ -325,8 +322,8 @@ class Discover:
class CalDiscover(Discover): class CalDiscover(Discover):
_namespace = 'urn:ietf:params:xml:ns:caldav' _namespace = "urn:ietf:params:xml:ns:caldav"
_resourcetype = '{%s}calendar' % _namespace _resourcetype = "{%s}calendar" % _namespace
_homeset_xml = b""" _homeset_xml = b"""
<propfind xmlns="DAV:" xmlns:c="urn:ietf:params:xml:ns:caldav"> <propfind xmlns="DAV:" xmlns:c="urn:ietf:params:xml:ns:caldav">
<prop> <prop>
@ -334,13 +331,13 @@ class CalDiscover(Discover):
</prop> </prop>
</propfind> </propfind>
""" """
_homeset_tag = '{%s}calendar-home-set' % _namespace _homeset_tag = "{%s}calendar-home-set" % _namespace
_well_known_uri = '/.well-known/caldav' _well_known_uri = "/.well-known/caldav"
class CardDiscover(Discover): class CardDiscover(Discover):
_namespace = 'urn:ietf:params:xml:ns:carddav' _namespace = "urn:ietf:params:xml:ns:carddav"
_resourcetype = '{%s}addressbook' % _namespace _resourcetype = "{%s}addressbook" % _namespace
_homeset_xml = b""" _homeset_xml = b"""
<propfind xmlns="DAV:" xmlns:c="urn:ietf:params:xml:ns:carddav"> <propfind xmlns="DAV:" xmlns:c="urn:ietf:params:xml:ns:carddav">
<prop> <prop>
@ -348,34 +345,41 @@ class CardDiscover(Discover):
</prop> </prop>
</propfind> </propfind>
""" """
_homeset_tag = '{%s}addressbook-home-set' % _namespace _homeset_tag = "{%s}addressbook-home-set" % _namespace
_well_known_uri = '/.well-known/carddav' _well_known_uri = "/.well-known/carddav"
class DAVSession: class DAVSession:
''' """
A helper class to connect to DAV servers. A helper class to connect to DAV servers.
''' """
@classmethod @classmethod
def init_and_remaining_args(cls, **kwargs): def init_and_remaining_args(cls, **kwargs):
argspec = getfullargspec(cls.__init__) argspec = getfullargspec(cls.__init__)
self_args, remainder = \ self_args, remainder = utils.split_dict(kwargs, argspec.args.__contains__)
utils.split_dict(kwargs, argspec.args.__contains__)
return cls(**self_args), remainder return cls(**self_args), remainder
def __init__(self, url, username='', password='', verify=True, auth=None, def __init__(
useragent=USERAGENT, verify_fingerprint=None, self,
auth_cert=None): url,
username="",
password="",
verify=True,
auth=None,
useragent=USERAGENT,
verify_fingerprint=None,
auth_cert=None,
):
self._settings = { self._settings = {
'cert': prepare_client_cert(auth_cert), "cert": prepare_client_cert(auth_cert),
'auth': prepare_auth(auth, username, password) "auth": prepare_auth(auth, username, password),
} }
self._settings.update(prepare_verify(verify, verify_fingerprint)) self._settings.update(prepare_verify(verify, verify_fingerprint))
self.useragent = useragent self.useragent = useragent
self.url = url.rstrip('/') + '/' self.url = url.rstrip("/") + "/"
self._session = requests.session() self._session = requests.session()
@ -394,8 +398,8 @@ class DAVSession:
def get_default_headers(self): def get_default_headers(self):
return { return {
'User-Agent': self.useragent, "User-Agent": self.useragent,
'Content-Type': 'application/xml; charset=UTF-8' "Content-Type": "application/xml; charset=UTF-8",
} }
@ -413,19 +417,18 @@ class DAVStorage(Storage):
# The DAVSession class to use # The DAVSession class to use
session_class = DAVSession session_class = DAVSession
_repr_attributes = ('username', 'url') _repr_attributes = ("username", "url")
_property_table = { _property_table = {
'displayname': ('displayname', 'DAV:'), "displayname": ("displayname", "DAV:"),
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
# defined for _repr_attributes # defined for _repr_attributes
self.username = kwargs.get('username') self.username = kwargs.get("username")
self.url = kwargs.get('url') self.url = kwargs.get("url")
self.session, kwargs = \ self.session, kwargs = self.session_class.init_and_remaining_args(**kwargs)
self.session_class.init_and_remaining_args(**kwargs)
super().__init__(**kwargs) super().__init__(**kwargs)
__init__.__signature__ = signature(session_class.__init__) __init__.__signature__ = signature(session_class.__init__)
@ -463,17 +466,13 @@ class DAVStorage(Storage):
for href in hrefs: for href in hrefs:
if href != self._normalize_href(href): if href != self._normalize_href(href):
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
href_xml.append(f'<href>{href}</href>') href_xml.append(f"<href>{href}</href>")
if not href_xml: if not href_xml:
return () return ()
data = self.get_multi_template \ data = self.get_multi_template.format(hrefs="\n".join(href_xml)).encode("utf-8")
.format(hrefs='\n'.join(href_xml)).encode('utf-8')
response = self.session.request( response = self.session.request(
'REPORT', "REPORT", "", data=data, headers=self.session.get_default_headers()
'',
data=data,
headers=self.session.get_default_headers()
) )
root = _parse_xml(response.content) # etree only can handle bytes root = _parse_xml(response.content) # etree only can handle bytes
rv = [] rv = []
@ -481,11 +480,12 @@ class DAVStorage(Storage):
for href, etag, prop in self._parse_prop_responses(root): for href, etag, prop in self._parse_prop_responses(root):
raw = prop.find(self.get_multi_data_query) raw = prop.find(self.get_multi_data_query)
if raw is None: if raw is None:
dav_logger.warning('Skipping {}, the item content is missing.' dav_logger.warning(
.format(href)) "Skipping {}, the item content is missing.".format(href)
)
continue continue
raw = raw.text or '' raw = raw.text or ""
if isinstance(raw, bytes): if isinstance(raw, bytes):
raw = raw.decode(response.encoding) raw = raw.decode(response.encoding)
@ -496,11 +496,9 @@ class DAVStorage(Storage):
hrefs_left.remove(href) hrefs_left.remove(href)
except KeyError: except KeyError:
if href in hrefs: if href in hrefs:
dav_logger.warning('Server sent item twice: {}' dav_logger.warning("Server sent item twice: {}".format(href))
.format(href))
else: else:
dav_logger.warning('Server sent unsolicited item: {}' dav_logger.warning("Server sent unsolicited item: {}".format(href))
.format(href))
else: else:
rv.append((href, Item(raw), etag)) rv.append((href, Item(raw), etag))
for href in hrefs_left: for href in hrefs_left:
@ -509,17 +507,14 @@ class DAVStorage(Storage):
def _put(self, href, item, etag): def _put(self, href, item, etag):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers['Content-Type'] = self.item_mimetype headers["Content-Type"] = self.item_mimetype
if etag is None: if etag is None:
headers['If-None-Match'] = '*' headers["If-None-Match"] = "*"
else: else:
headers['If-Match'] = etag headers["If-Match"] = etag
response = self.session.request( response = self.session.request(
'PUT', "PUT", href, data=item.raw.encode("utf-8"), headers=headers
href,
data=item.raw.encode('utf-8'),
headers=headers
) )
_assert_multistatus_success(response) _assert_multistatus_success(response)
@ -538,13 +533,13 @@ class DAVStorage(Storage):
# #
# In such cases we return a constant etag. The next synchronization # In such cases we return a constant etag. The next synchronization
# will then detect an etag change and will download the new item. # will then detect an etag change and will download the new item.
etag = response.headers.get('etag', None) etag = response.headers.get("etag", None)
href = self._normalize_href(response.url) href = self._normalize_href(response.url)
return href, etag return href, etag
def update(self, href, item, etag): def update(self, href, item, etag):
if etag is None: if etag is None:
raise ValueError('etag must be given and must not be None.') raise ValueError("etag must be given and must not be None.")
href, etag = self._put(self._normalize_href(href), item, etag) href, etag = self._put(self._normalize_href(href), item, etag)
return etag return etag
@ -555,23 +550,17 @@ class DAVStorage(Storage):
def delete(self, href, etag): def delete(self, href, etag):
href = self._normalize_href(href) href = self._normalize_href(href)
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers.update({ headers.update({"If-Match": etag})
'If-Match': etag
})
self.session.request( self.session.request("DELETE", href, headers=headers)
'DELETE',
href,
headers=headers
)
def _parse_prop_responses(self, root, handled_hrefs=None): def _parse_prop_responses(self, root, handled_hrefs=None):
if handled_hrefs is None: if handled_hrefs is None:
handled_hrefs = set() handled_hrefs = set()
for response in root.iter('{DAV:}response'): for response in root.iter("{DAV:}response"):
href = response.find('{DAV:}href') href = response.find("{DAV:}href")
if href is None: if href is None:
dav_logger.error('Skipping response, href is missing.') dav_logger.error("Skipping response, href is missing.")
continue continue
href = self._normalize_href(href.text) href = self._normalize_href(href.text)
@ -582,34 +571,34 @@ class DAVStorage(Storage):
# https://github.com/pimutils/vdirsyncer/issues/88 # https://github.com/pimutils/vdirsyncer/issues/88
# - Davmail # - Davmail
# https://github.com/pimutils/vdirsyncer/issues/144 # https://github.com/pimutils/vdirsyncer/issues/144
dav_logger.warning('Skipping identical href: {!r}' dav_logger.warning("Skipping identical href: {!r}".format(href))
.format(href))
continue continue
props = response.findall('{DAV:}propstat/{DAV:}prop') props = response.findall("{DAV:}propstat/{DAV:}prop")
if props is None or not len(props): if props is None or not len(props):
dav_logger.debug('Skipping {!r}, properties are missing.' dav_logger.debug("Skipping {!r}, properties are missing.".format(href))
.format(href))
continue continue
else: else:
props = _merge_xml(props) props = _merge_xml(props)
if props.find('{DAV:}resourcetype/{DAV:}collection') is not None: if props.find("{DAV:}resourcetype/{DAV:}collection") is not None:
dav_logger.debug(f'Skipping {href!r}, is collection.') dav_logger.debug(f"Skipping {href!r}, is collection.")
continue continue
etag = getattr(props.find('{DAV:}getetag'), 'text', '') etag = getattr(props.find("{DAV:}getetag"), "text", "")
if not etag: if not etag:
dav_logger.debug('Skipping {!r}, etag property is missing.' dav_logger.debug(
.format(href)) "Skipping {!r}, etag property is missing.".format(href)
)
continue continue
contenttype = getattr(props.find('{DAV:}getcontenttype'), contenttype = getattr(props.find("{DAV:}getcontenttype"), "text", None)
'text', None)
if not self._is_item_mimetype(contenttype): if not self._is_item_mimetype(contenttype):
dav_logger.debug('Skipping {!r}, {!r} != {!r}.' dav_logger.debug(
.format(href, contenttype, "Skipping {!r}, {!r} != {!r}.".format(
self.item_mimetype)) href, contenttype, self.item_mimetype
)
)
continue continue
handled_hrefs.add(href) handled_hrefs.add(href)
@ -617,9 +606,9 @@ class DAVStorage(Storage):
def list(self): def list(self):
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers['Depth'] = '1' headers["Depth"] = "1"
data = b'''<?xml version="1.0" encoding="utf-8" ?> data = b"""<?xml version="1.0" encoding="utf-8" ?>
<propfind xmlns="DAV:"> <propfind xmlns="DAV:">
<prop> <prop>
<resourcetype/> <resourcetype/>
@ -627,12 +616,11 @@ class DAVStorage(Storage):
<getetag/> <getetag/>
</prop> </prop>
</propfind> </propfind>
''' """
# We use a PROPFIND request instead of addressbook-query due to issues # We use a PROPFIND request instead of addressbook-query due to issues
# with Zimbra. See https://github.com/pimutils/vdirsyncer/issues/83 # with Zimbra. See https://github.com/pimutils/vdirsyncer/issues/83
response = self.session.request('PROPFIND', '', data=data, response = self.session.request("PROPFIND", "", data=data, headers=headers)
headers=headers)
root = _parse_xml(response.content) root = _parse_xml(response.content)
rv = self._parse_prop_responses(root) rv = self._parse_prop_responses(root)
@ -645,32 +633,31 @@ class DAVStorage(Storage):
except KeyError: except KeyError:
raise exceptions.UnsupportedMetadataError() raise exceptions.UnsupportedMetadataError()
xpath = f'{{{namespace}}}{tagname}' xpath = f"{{{namespace}}}{tagname}"
data = '''<?xml version="1.0" encoding="utf-8" ?> data = """<?xml version="1.0" encoding="utf-8" ?>
<propfind xmlns="DAV:"> <propfind xmlns="DAV:">
<prop> <prop>
{} {}
</prop> </prop>
</propfind> </propfind>
'''.format( """.format(
etree.tostring(etree.Element(xpath), encoding='unicode') etree.tostring(etree.Element(xpath), encoding="unicode")
).encode('utf-8') ).encode(
"utf-8"
)
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
headers['Depth'] = '0' headers["Depth"] = "0"
response = self.session.request( response = self.session.request("PROPFIND", "", data=data, headers=headers)
'PROPFIND', '',
data=data, headers=headers
)
root = _parse_xml(response.content) root = _parse_xml(response.content)
for prop in root.findall('.//' + xpath): for prop in root.findall(".//" + xpath):
text = normalize_meta_value(getattr(prop, 'text', None)) text = normalize_meta_value(getattr(prop, "text", None))
if text: if text:
return text return text
return '' return ""
def set_meta(self, key, value): def set_meta(self, key, value):
try: try:
@ -678,11 +665,11 @@ class DAVStorage(Storage):
except KeyError: except KeyError:
raise exceptions.UnsupportedMetadataError() raise exceptions.UnsupportedMetadataError()
lxml_selector = f'{{{namespace}}}{tagname}' lxml_selector = f"{{{namespace}}}{tagname}"
element = etree.Element(lxml_selector) element = etree.Element(lxml_selector)
element.text = normalize_meta_value(value) element.text = normalize_meta_value(value)
data = '''<?xml version="1.0" encoding="utf-8" ?> data = """<?xml version="1.0" encoding="utf-8" ?>
<propertyupdate xmlns="DAV:"> <propertyupdate xmlns="DAV:">
<set> <set>
<prop> <prop>
@ -690,11 +677,14 @@ class DAVStorage(Storage):
</prop> </prop>
</set> </set>
</propertyupdate> </propertyupdate>
'''.format(etree.tostring(element, encoding='unicode')).encode('utf-8') """.format(
etree.tostring(element, encoding="unicode")
).encode(
"utf-8"
)
self.session.request( self.session.request(
'PROPPATCH', '', "PROPPATCH", "", data=data, headers=self.session.get_default_headers()
data=data, headers=self.session.get_default_headers()
) )
# XXX: Response content is currently ignored. Though exceptions are # XXX: Response content is currently ignored. Though exceptions are
@ -705,15 +695,15 @@ class DAVStorage(Storage):
class CalDAVStorage(DAVStorage): class CalDAVStorage(DAVStorage):
storage_name = 'caldav' storage_name = "caldav"
fileext = '.ics' fileext = ".ics"
item_mimetype = 'text/calendar' item_mimetype = "text/calendar"
discovery_class = CalDiscover discovery_class = CalDiscover
start_date = None start_date = None
end_date = None end_date = None
get_multi_template = '''<?xml version="1.0" encoding="utf-8" ?> get_multi_template = """<?xml version="1.0" encoding="utf-8" ?>
<C:calendar-multiget xmlns="DAV:" <C:calendar-multiget xmlns="DAV:"
xmlns:C="urn:ietf:params:xml:ns:caldav"> xmlns:C="urn:ietf:params:xml:ns:caldav">
<prop> <prop>
@ -721,70 +711,73 @@ class CalDAVStorage(DAVStorage):
<C:calendar-data/> <C:calendar-data/>
</prop> </prop>
{hrefs} {hrefs}
</C:calendar-multiget>''' </C:calendar-multiget>"""
get_multi_data_query = '{urn:ietf:params:xml:ns:caldav}calendar-data' get_multi_data_query = "{urn:ietf:params:xml:ns:caldav}calendar-data"
_property_table = dict(DAVStorage._property_table) _property_table = dict(DAVStorage._property_table)
_property_table.update({ _property_table.update(
'color': ('calendar-color', 'http://apple.com/ns/ical/'), {
}) "color": ("calendar-color", "http://apple.com/ns/ical/"),
}
)
def __init__(self, start_date=None, end_date=None, def __init__(self, start_date=None, end_date=None, item_types=(), **kwargs):
item_types=(), **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if not isinstance(item_types, (list, tuple)): if not isinstance(item_types, (list, tuple)):
raise exceptions.UserError('item_types must be a list.') raise exceptions.UserError("item_types must be a list.")
self.item_types = tuple(item_types) self.item_types = tuple(item_types)
if (start_date is None) != (end_date is None): if (start_date is None) != (end_date is None):
raise exceptions.UserError('If start_date is given, ' raise exceptions.UserError(
'end_date has to be given too.') "If start_date is given, " "end_date has to be given too."
)
elif start_date is not None and end_date is not None: elif start_date is not None and end_date is not None:
namespace = dict(datetime.__dict__) namespace = dict(datetime.__dict__)
namespace['start_date'] = self.start_date = \ namespace["start_date"] = self.start_date = (
(eval(start_date, namespace) eval(start_date, namespace)
if isinstance(start_date, (bytes, str)) if isinstance(start_date, (bytes, str))
else start_date) else start_date
self.end_date = \ )
(eval(end_date, namespace) self.end_date = (
if isinstance(end_date, (bytes, str)) eval(end_date, namespace)
else end_date) if isinstance(end_date, (bytes, str))
else end_date
)
@staticmethod @staticmethod
def _get_list_filters(components, start, end): def _get_list_filters(components, start, end):
if components: if components:
caldavfilter = ''' caldavfilter = """
<C:comp-filter name="VCALENDAR"> <C:comp-filter name="VCALENDAR">
<C:comp-filter name="{component}"> <C:comp-filter name="{component}">
{timefilter} {timefilter}
</C:comp-filter> </C:comp-filter>
</C:comp-filter> </C:comp-filter>
''' """
if start is not None and end is not None: if start is not None and end is not None:
start = start.strftime(CALDAV_DT_FORMAT) start = start.strftime(CALDAV_DT_FORMAT)
end = end.strftime(CALDAV_DT_FORMAT) end = end.strftime(CALDAV_DT_FORMAT)
timefilter = ('<C:time-range start="{start}" end="{end}"/>' timefilter = '<C:time-range start="{start}" end="{end}"/>'.format(
.format(start=start, end=end)) start=start, end=end
)
else: else:
timefilter = '' timefilter = ""
for component in components: for component in components:
yield caldavfilter.format(component=component, yield caldavfilter.format(component=component, timefilter=timefilter)
timefilter=timefilter)
else: else:
if start is not None and end is not None: if start is not None and end is not None:
yield from CalDAVStorage._get_list_filters(('VTODO', 'VEVENT'), yield from CalDAVStorage._get_list_filters(
start, end) ("VTODO", "VEVENT"), start, end
)
def list(self): def list(self):
caldavfilters = list(self._get_list_filters( caldavfilters = list(
self.item_types, self._get_list_filters(self.item_types, self.start_date, self.end_date)
self.start_date, )
self.end_date
))
if not caldavfilters: if not caldavfilters:
# If we don't have any filters (which is the default), taking the # If we don't have any filters (which is the default), taking the
# risk of sending a calendar-query is not necessary. There doesn't # risk of sending a calendar-query is not necessary. There doesn't
@ -795,7 +788,7 @@ class CalDAVStorage(DAVStorage):
# See https://github.com/dmfs/tasks/issues/118 for backstory. # See https://github.com/dmfs/tasks/issues/118 for backstory.
yield from DAVStorage.list(self) yield from DAVStorage.list(self)
data = '''<?xml version="1.0" encoding="utf-8" ?> data = """<?xml version="1.0" encoding="utf-8" ?>
<C:calendar-query xmlns="DAV:" <C:calendar-query xmlns="DAV:"
xmlns:C="urn:ietf:params:xml:ns:caldav"> xmlns:C="urn:ietf:params:xml:ns:caldav">
<prop> <prop>
@ -805,21 +798,20 @@ class CalDAVStorage(DAVStorage):
<C:filter> <C:filter>
{caldavfilter} {caldavfilter}
</C:filter> </C:filter>
</C:calendar-query>''' </C:calendar-query>"""
headers = self.session.get_default_headers() headers = self.session.get_default_headers()
# https://github.com/pimutils/vdirsyncer/issues/166 # https://github.com/pimutils/vdirsyncer/issues/166
# The default in CalDAV's calendar-queries is 0, but the examples use # The default in CalDAV's calendar-queries is 0, but the examples use
# an explicit value of 1 for querying items. it is extremely unclear in # an explicit value of 1 for querying items. it is extremely unclear in
# the spec which values from WebDAV are actually allowed. # the spec which values from WebDAV are actually allowed.
headers['Depth'] = '1' headers["Depth"] = "1"
handled_hrefs = set() handled_hrefs = set()
for caldavfilter in caldavfilters: for caldavfilter in caldavfilters:
xml = data.format(caldavfilter=caldavfilter).encode('utf-8') xml = data.format(caldavfilter=caldavfilter).encode("utf-8")
response = self.session.request('REPORT', '', data=xml, response = self.session.request("REPORT", "", data=xml, headers=headers)
headers=headers)
root = _parse_xml(response.content) root = _parse_xml(response.content)
rv = self._parse_prop_responses(root, handled_hrefs) rv = self._parse_prop_responses(root, handled_hrefs)
for href, etag, _prop in rv: for href, etag, _prop in rv:
@ -827,12 +819,12 @@ class CalDAVStorage(DAVStorage):
class CardDAVStorage(DAVStorage): class CardDAVStorage(DAVStorage):
storage_name = 'carddav' storage_name = "carddav"
fileext = '.vcf' fileext = ".vcf"
item_mimetype = 'text/vcard' item_mimetype = "text/vcard"
discovery_class = CardDiscover discovery_class = CardDiscover
get_multi_template = '''<?xml version="1.0" encoding="utf-8" ?> get_multi_template = """<?xml version="1.0" encoding="utf-8" ?>
<C:addressbook-multiget xmlns="DAV:" <C:addressbook-multiget xmlns="DAV:"
xmlns:C="urn:ietf:params:xml:ns:carddav"> xmlns:C="urn:ietf:params:xml:ns:carddav">
<prop> <prop>
@ -840,6 +832,6 @@ class CardDAVStorage(DAVStorage):
<C:address-data/> <C:address-data/>
</prop> </prop>
{hrefs} {hrefs}
</C:addressbook-multiget>''' </C:addressbook-multiget>"""
get_multi_data_query = '{urn:ietf:params:xml:ns:carddav}address-data' get_multi_data_query = "{urn:ietf:params:xml:ns:carddav}address-data"

View file

@ -11,6 +11,7 @@ try:
import etesync import etesync
import etesync.exceptions import etesync.exceptions
from etesync import AddressBook, Contact, Calendar, Event from etesync import AddressBook, Contact, Calendar, Event
has_etesync = True has_etesync = True
except ImportError: except ImportError:
has_etesync = False has_etesync = False
@ -36,37 +37,40 @@ def _writing_op(f):
if not self._at_once: if not self._at_once:
self._sync_journal() self._sync_journal()
return rv return rv
return inner return inner
class _Session: class _Session:
def __init__(self, email, secrets_dir, server_url=None, db_path=None): def __init__(self, email, secrets_dir, server_url=None, db_path=None):
if not has_etesync: if not has_etesync:
raise exceptions.UserError('Dependencies for etesync are not ' raise exceptions.UserError("Dependencies for etesync are not " "installed.")
'installed.')
server_url = server_url or etesync.API_URL server_url = server_url or etesync.API_URL
self.email = email self.email = email
self.secrets_dir = os.path.join(secrets_dir, email + '/') self.secrets_dir = os.path.join(secrets_dir, email + "/")
self._auth_token_path = os.path.join(self.secrets_dir, 'auth_token') self._auth_token_path = os.path.join(self.secrets_dir, "auth_token")
self._key_path = os.path.join(self.secrets_dir, 'key') self._key_path = os.path.join(self.secrets_dir, "key")
auth_token = self._get_auth_token() auth_token = self._get_auth_token()
if not auth_token: if not auth_token:
password = click.prompt('Enter service password for {}' password = click.prompt(
.format(self.email), hide_input=True) "Enter service password for {}".format(self.email), hide_input=True
auth_token = etesync.Authenticator(server_url) \ )
.get_auth_token(self.email, password) auth_token = etesync.Authenticator(server_url).get_auth_token(
self.email, password
)
self._set_auth_token(auth_token) self._set_auth_token(auth_token)
self._db_path = db_path or os.path.join(self.secrets_dir, 'db.sqlite') self._db_path = db_path or os.path.join(self.secrets_dir, "db.sqlite")
self.etesync = etesync.EteSync(email, auth_token, remote=server_url, self.etesync = etesync.EteSync(
db_path=self._db_path) email, auth_token, remote=server_url, db_path=self._db_path
)
key = self._get_key() key = self._get_key()
if not key: if not key:
password = click.prompt('Enter key password', hide_input=True) password = click.prompt("Enter key password", hide_input=True)
click.echo(f'Deriving key for {self.email}') click.echo(f"Deriving key for {self.email}")
self.etesync.derive_key(password) self.etesync.derive_key(password)
self._set_key(self.etesync.cipher_key) self._set_key(self.etesync.cipher_key)
else: else:
@ -87,14 +91,14 @@ class _Session:
def _get_key(self): def _get_key(self):
try: try:
with open(self._key_path, 'rb') as f: with open(self._key_path, "rb") as f:
return f.read() return f.read()
except OSError: except OSError:
pass pass
def _set_key(self, content): def _set_key(self, content):
checkdir(os.path.dirname(self._key_path), create=True) checkdir(os.path.dirname(self._key_path), create=True)
with atomicwrites.atomic_write(self._key_path, mode='wb') as f: with atomicwrites.atomic_write(self._key_path, mode="wb") as f:
f.write(content) f.write(content)
assert_permissions(self._key_path, 0o600) assert_permissions(self._key_path, 0o600)
@ -104,10 +108,9 @@ class EtesyncStorage(Storage):
_item_type = None _item_type = None
_at_once = False _at_once = False
def __init__(self, email, secrets_dir, server_url=None, db_path=None, def __init__(self, email, secrets_dir, server_url=None, db_path=None, **kwargs):
**kwargs): if kwargs.get("collection", None) is None:
if kwargs.get('collection', None) is None: raise ValueError("Collection argument required")
raise ValueError('Collection argument required')
self._session = _Session(email, secrets_dir, server_url, db_path) self._session = _Session(email, secrets_dir, server_url, db_path)
super().__init__(**kwargs) super().__init__(**kwargs)
@ -117,10 +120,9 @@ class EtesyncStorage(Storage):
self._session.etesync.sync_journal(self.collection) self._session.etesync.sync_journal(self.collection)
@classmethod @classmethod
def discover(cls, email, secrets_dir, server_url=None, db_path=None, def discover(cls, email, secrets_dir, server_url=None, db_path=None, **kwargs):
**kwargs): if kwargs.get("collection", None) is not None:
if kwargs.get('collection', None) is not None: raise TypeError("collection argument must not be given.")
raise TypeError('collection argument must not be given.')
session = _Session(email, secrets_dir, server_url, db_path) session = _Session(email, secrets_dir, server_url, db_path)
assert cls._collection_type assert cls._collection_type
session.etesync.sync_journal_list() session.etesync.sync_journal_list()
@ -131,20 +133,19 @@ class EtesyncStorage(Storage):
secrets_dir=secrets_dir, secrets_dir=secrets_dir,
db_path=db_path, db_path=db_path,
collection=entry.uid, collection=entry.uid,
**kwargs **kwargs,
) )
else: else:
logger.debug(f'Skipping collection: {entry!r}') logger.debug(f"Skipping collection: {entry!r}")
@classmethod @classmethod
def create_collection(cls, collection, email, secrets_dir, server_url=None, def create_collection(
db_path=None, **kwargs): cls, collection, email, secrets_dir, server_url=None, db_path=None, **kwargs
):
session = _Session(email, secrets_dir, server_url, db_path) session = _Session(email, secrets_dir, server_url, db_path)
content = {'displayName': collection} content = {"displayName": collection}
c = cls._collection_type.create( c = cls._collection_type.create(
session.etesync, session.etesync, binascii.hexlify(os.urandom(32)).decode(), content
binascii.hexlify(os.urandom(32)).decode(),
content
) )
c.save() c.save()
session.etesync.sync_journal_list() session.etesync.sync_journal_list()
@ -154,7 +155,7 @@ class EtesyncStorage(Storage):
secrets_dir=secrets_dir, secrets_dir=secrets_dir,
db_path=db_path, db_path=db_path,
server_url=server_url, server_url=server_url,
**kwargs **kwargs,
) )
def list(self): def list(self):
@ -219,10 +220,10 @@ class EtesyncStorage(Storage):
class EtesyncContacts(EtesyncStorage): class EtesyncContacts(EtesyncStorage):
_collection_type = AddressBook _collection_type = AddressBook
_item_type = Contact _item_type = Contact
storage_name = 'etesync_contacts' storage_name = "etesync_contacts"
class EtesyncCalendars(EtesyncStorage): class EtesyncCalendars(EtesyncStorage):
_collection_type = Calendar _collection_type = Calendar
_item_type = Event _item_type = Event
storage_name = 'etesync_calendars' storage_name = "etesync_calendars"

View file

@ -19,11 +19,10 @@ logger = logging.getLogger(__name__)
class FilesystemStorage(Storage): class FilesystemStorage(Storage):
storage_name = 'filesystem' storage_name = "filesystem"
_repr_attributes = ('path',) _repr_attributes = ("path",)
def __init__(self, path, fileext, encoding='utf-8', post_hook=None, def __init__(self, path, fileext, encoding="utf-8", post_hook=None, **kwargs):
**kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
path = expand_path(path) path = expand_path(path)
checkdir(path, create=False) checkdir(path, create=False)
@ -34,8 +33,8 @@ class FilesystemStorage(Storage):
@classmethod @classmethod
def discover(cls, path, **kwargs): def discover(cls, path, **kwargs):
if kwargs.pop('collection', None) is not None: if kwargs.pop("collection", None) is not None:
raise TypeError('collection argument must not be given.') raise TypeError("collection argument must not be given.")
path = expand_path(path) path = expand_path(path)
try: try:
collections = os.listdir(path) collections = os.listdir(path)
@ -47,30 +46,29 @@ class FilesystemStorage(Storage):
collection_path = os.path.join(path, collection) collection_path = os.path.join(path, collection)
if not cls._validate_collection(collection_path): if not cls._validate_collection(collection_path):
continue continue
args = dict(collection=collection, path=collection_path, args = dict(collection=collection, path=collection_path, **kwargs)
**kwargs)
yield args yield args
@classmethod @classmethod
def _validate_collection(cls, path): def _validate_collection(cls, path):
if not os.path.isdir(path): if not os.path.isdir(path):
return False return False
if os.path.basename(path).startswith('.'): if os.path.basename(path).startswith("."):
return False return False
return True return True
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): def create_collection(cls, collection, **kwargs):
kwargs = dict(kwargs) kwargs = dict(kwargs)
path = kwargs['path'] path = kwargs["path"]
if collection is not None: if collection is not None:
path = os.path.join(path, collection) path = os.path.join(path, collection)
checkdir(expand_path(path), create=True) checkdir(expand_path(path), create=True)
kwargs['path'] = path kwargs["path"] = path
kwargs['collection'] = collection kwargs["collection"] = collection
return kwargs return kwargs
def _get_filepath(self, href): def _get_filepath(self, href):
@ -88,9 +86,8 @@ class FilesystemStorage(Storage):
def get(self, href): def get(self, href):
fpath = self._get_filepath(href) fpath = self._get_filepath(href)
try: try:
with open(fpath, 'rb') as f: with open(fpath, "rb") as f:
return (Item(f.read().decode(self.encoding)), return (Item(f.read().decode(self.encoding)), get_etag_from_file(fpath))
get_etag_from_file(fpath))
except OSError as e: except OSError as e:
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
raise exceptions.NotFoundError(href) raise exceptions.NotFoundError(href)
@ -99,18 +96,14 @@ class FilesystemStorage(Storage):
def upload(self, item): def upload(self, item):
if not isinstance(item.raw, str): if not isinstance(item.raw, str):
raise TypeError('item.raw must be a unicode string.') raise TypeError("item.raw must be a unicode string.")
try: try:
href = self._get_href(item.ident) href = self._get_href(item.ident)
fpath, etag = self._upload_impl(item, href) fpath, etag = self._upload_impl(item, href)
except OSError as e: except OSError as e:
if e.errno in ( if e.errno in (errno.ENAMETOOLONG, errno.ENOENT): # Unix # Windows
errno.ENAMETOOLONG, # Unix logger.debug("UID as filename rejected, trying with random " "one.")
errno.ENOENT # Windows
):
logger.debug('UID as filename rejected, trying with random '
'one.')
# random href instead of UID-based # random href instead of UID-based
href = self._get_href(None) href = self._get_href(None)
fpath, etag = self._upload_impl(item, href) fpath, etag = self._upload_impl(item, href)
@ -124,7 +117,7 @@ class FilesystemStorage(Storage):
def _upload_impl(self, item, href): def _upload_impl(self, item, href):
fpath = self._get_filepath(href) fpath = self._get_filepath(href)
try: try:
with atomic_write(fpath, mode='wb', overwrite=False) as f: with atomic_write(fpath, mode="wb", overwrite=False) as f:
f.write(item.raw.encode(self.encoding)) f.write(item.raw.encode(self.encoding))
return fpath, get_etag_from_file(f) return fpath, get_etag_from_file(f)
except OSError as e: except OSError as e:
@ -142,9 +135,9 @@ class FilesystemStorage(Storage):
raise exceptions.WrongEtagError(etag, actual_etag) raise exceptions.WrongEtagError(etag, actual_etag)
if not isinstance(item.raw, str): if not isinstance(item.raw, str):
raise TypeError('item.raw must be a unicode string.') raise TypeError("item.raw must be a unicode string.")
with atomic_write(fpath, mode='wb', overwrite=True) as f: with atomic_write(fpath, mode="wb", overwrite=True) as f:
f.write(item.raw.encode(self.encoding)) f.write(item.raw.encode(self.encoding))
etag = get_etag_from_file(f) etag = get_etag_from_file(f)
@ -162,21 +155,22 @@ class FilesystemStorage(Storage):
os.remove(fpath) os.remove(fpath)
def _run_post_hook(self, fpath): def _run_post_hook(self, fpath):
logger.info('Calling post_hook={} with argument={}'.format( logger.info(
self.post_hook, fpath)) "Calling post_hook={} with argument={}".format(self.post_hook, fpath)
)
try: try:
subprocess.call([self.post_hook, fpath]) subprocess.call([self.post_hook, fpath])
except OSError as e: except OSError as e:
logger.warning('Error executing external hook: {}'.format(str(e))) logger.warning("Error executing external hook: {}".format(str(e)))
def get_meta(self, key): def get_meta(self, key):
fpath = os.path.join(self.path, key) fpath = os.path.join(self.path, key)
try: try:
with open(fpath, 'rb') as f: with open(fpath, "rb") as f:
return normalize_meta_value(f.read().decode(self.encoding)) return normalize_meta_value(f.read().decode(self.encoding))
except OSError as e: except OSError as e:
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
return '' return ""
else: else:
raise raise
@ -184,5 +178,5 @@ class FilesystemStorage(Storage):
value = normalize_meta_value(value) value = normalize_meta_value(value)
fpath = os.path.join(self.path, key) fpath = os.path.join(self.path, key)
with atomic_write(fpath, mode='wb', overwrite=True) as f: with atomic_write(fpath, mode="wb", overwrite=True) as f:
f.write(value.encode(self.encoding)) f.write(value.encode(self.encoding))

View file

@ -17,11 +17,12 @@ from ..utils import open_graphical_browser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TOKEN_URL = 'https://accounts.google.com/o/oauth2/v2/auth' TOKEN_URL = "https://accounts.google.com/o/oauth2/v2/auth"
REFRESH_URL = 'https://www.googleapis.com/oauth2/v4/token' REFRESH_URL = "https://www.googleapis.com/oauth2/v4/token"
try: try:
from requests_oauthlib import OAuth2Session from requests_oauthlib import OAuth2Session
have_oauth2 = True have_oauth2 = True
except ImportError: except ImportError:
have_oauth2 = False have_oauth2 = False
@ -37,7 +38,7 @@ class GoogleSession(dav.DAVSession):
self._settings = {} self._settings = {}
if not have_oauth2: if not have_oauth2:
raise exceptions.UserError('requests-oauthlib not installed') raise exceptions.UserError("requests-oauthlib not installed")
token_file = expand_path(token_file) token_file = expand_path(token_file)
ui_worker = get_ui_worker() ui_worker = get_ui_worker()
@ -53,26 +54,26 @@ class GoogleSession(dav.DAVSession):
pass pass
except ValueError as e: except ValueError as e:
raise exceptions.UserError( raise exceptions.UserError(
'Failed to load token file {}, try deleting it. ' "Failed to load token file {}, try deleting it. "
'Original error: {}'.format(token_file, e) "Original error: {}".format(token_file, e)
) )
def _save_token(token): def _save_token(token):
checkdir(expand_path(os.path.dirname(token_file)), create=True) checkdir(expand_path(os.path.dirname(token_file)), create=True)
with atomic_write(token_file, mode='w', overwrite=True) as f: with atomic_write(token_file, mode="w", overwrite=True) as f:
json.dump(token, f) json.dump(token, f)
self._session = OAuth2Session( self._session = OAuth2Session(
client_id=client_id, client_id=client_id,
token=token, token=token,
redirect_uri='urn:ietf:wg:oauth:2.0:oob', redirect_uri="urn:ietf:wg:oauth:2.0:oob",
scope=self.scope, scope=self.scope,
auto_refresh_url=REFRESH_URL, auto_refresh_url=REFRESH_URL,
auto_refresh_kwargs={ auto_refresh_kwargs={
'client_id': client_id, "client_id": client_id,
'client_secret': client_secret, "client_secret": client_secret,
}, },
token_updater=_save_token token_updater=_save_token,
) )
if not token: if not token:
@ -80,8 +81,10 @@ class GoogleSession(dav.DAVSession):
TOKEN_URL, TOKEN_URL,
# access_type and approval_prompt are Google specific # access_type and approval_prompt are Google specific
# extra parameters. # extra parameters.
access_type='offline', approval_prompt='force') access_type="offline",
click.echo(f'Opening {authorization_url} ...') approval_prompt="force",
)
click.echo(f"Opening {authorization_url} ...")
try: try:
open_graphical_browser(authorization_url) open_graphical_browser(authorization_url)
except Exception as e: except Exception as e:
@ -102,31 +105,42 @@ class GoogleSession(dav.DAVSession):
class GoogleCalendarStorage(dav.CalDAVStorage): class GoogleCalendarStorage(dav.CalDAVStorage):
class session_class(GoogleSession): class session_class(GoogleSession):
url = 'https://apidata.googleusercontent.com/caldav/v2/' url = "https://apidata.googleusercontent.com/caldav/v2/"
scope = ['https://www.googleapis.com/auth/calendar'] scope = ["https://www.googleapis.com/auth/calendar"]
class discovery_class(dav.CalDiscover): class discovery_class(dav.CalDiscover):
@staticmethod @staticmethod
def _get_collection_from_url(url): def _get_collection_from_url(url):
# Google CalDAV has collection URLs like: # Google CalDAV has collection URLs like:
# /user/foouser/calendars/foocalendar/events/ # /user/foouser/calendars/foocalendar/events/
parts = url.rstrip('/').split('/') parts = url.rstrip("/").split("/")
parts.pop() parts.pop()
collection = parts.pop() collection = parts.pop()
return urlparse.unquote(collection) return urlparse.unquote(collection)
storage_name = 'google_calendar' storage_name = "google_calendar"
def __init__(self, token_file, client_id, client_secret, start_date=None, def __init__(
end_date=None, item_types=(), **kwargs): self,
if not kwargs.get('collection'): token_file,
client_id,
client_secret,
start_date=None,
end_date=None,
item_types=(),
**kwargs,
):
if not kwargs.get("collection"):
raise exceptions.CollectionRequired() raise exceptions.CollectionRequired()
super().__init__( super().__init__(
token_file=token_file, client_id=client_id, token_file=token_file,
client_secret=client_secret, start_date=start_date, client_id=client_id,
end_date=end_date, item_types=item_types, client_secret=client_secret,
**kwargs start_date=start_date,
end_date=end_date,
item_types=item_types,
**kwargs,
) )
# This is ugly: We define/override the entire signature computed for the # This is ugly: We define/override the entire signature computed for the
@ -144,23 +158,24 @@ class GoogleContactsStorage(dav.CardDAVStorage):
# #
# So we configure the well-known URI here again, such that discovery # So we configure the well-known URI here again, such that discovery
# tries collection enumeration on it directly. That appears to work. # tries collection enumeration on it directly. That appears to work.
url = 'https://www.googleapis.com/.well-known/carddav' url = "https://www.googleapis.com/.well-known/carddav"
scope = ['https://www.googleapis.com/auth/carddav'] scope = ["https://www.googleapis.com/auth/carddav"]
class discovery_class(dav.CardDAVStorage.discovery_class): class discovery_class(dav.CardDAVStorage.discovery_class):
# Google CardDAV doesn't return any resourcetype prop. # Google CardDAV doesn't return any resourcetype prop.
_resourcetype = None _resourcetype = None
storage_name = 'google_contacts' storage_name = "google_contacts"
def __init__(self, token_file, client_id, client_secret, **kwargs): def __init__(self, token_file, client_id, client_secret, **kwargs):
if not kwargs.get('collection'): if not kwargs.get("collection"):
raise exceptions.CollectionRequired() raise exceptions.CollectionRequired()
super().__init__( super().__init__(
token_file=token_file, client_id=client_id, token_file=token_file,
client_id=client_id,
client_secret=client_secret, client_secret=client_secret,
**kwargs **kwargs,
) )
# This is ugly: We define/override the entire signature computed for the # This is ugly: We define/override the entire signature computed for the

View file

@ -12,41 +12,49 @@ from .base import Storage
class HttpStorage(Storage): class HttpStorage(Storage):
storage_name = 'http' storage_name = "http"
read_only = True read_only = True
_repr_attributes = ('username', 'url') _repr_attributes = ("username", "url")
_items = None _items = None
# Required for tests. # Required for tests.
_ignore_uids = True _ignore_uids = True
def __init__(self, url, username='', password='', verify=True, auth=None, def __init__(
useragent=USERAGENT, verify_fingerprint=None, auth_cert=None, self,
**kwargs): url,
username="",
password="",
verify=True,
auth=None,
useragent=USERAGENT,
verify_fingerprint=None,
auth_cert=None,
**kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self._settings = { self._settings = {
'auth': prepare_auth(auth, username, password), "auth": prepare_auth(auth, username, password),
'cert': prepare_client_cert(auth_cert), "cert": prepare_client_cert(auth_cert),
'latin1_fallback': False, "latin1_fallback": False,
} }
self._settings.update(prepare_verify(verify, verify_fingerprint)) self._settings.update(prepare_verify(verify, verify_fingerprint))
self.username, self.password = username, password self.username, self.password = username, password
self.useragent = useragent self.useragent = useragent
collection = kwargs.get('collection') collection = kwargs.get("collection")
if collection is not None: if collection is not None:
url = urlparse.urljoin(url, collection) url = urlparse.urljoin(url, collection)
self.url = url self.url = url
self.parsed_url = urlparse.urlparse(self.url) self.parsed_url = urlparse.urlparse(self.url)
def _default_headers(self): def _default_headers(self):
return {'User-Agent': self.useragent} return {"User-Agent": self.useragent}
def list(self): def list(self):
r = request('GET', self.url, headers=self._default_headers(), r = request("GET", self.url, headers=self._default_headers(), **self._settings)
**self._settings)
self._items = {} self._items = {}
for item in split_collection(r.text): for item in split_collection(r.text):

View file

@ -6,21 +6,20 @@ from .base import Storage
def _random_string(): def _random_string():
return f'{random.random():.9f}' return f"{random.random():.9f}"
class MemoryStorage(Storage): class MemoryStorage(Storage):
storage_name = 'memory' storage_name = "memory"
''' """
Saves data in RAM, only useful for testing. Saves data in RAM, only useful for testing.
''' """
def __init__(self, fileext='', **kwargs): def __init__(self, fileext="", **kwargs):
if kwargs.get('collection') is not None: if kwargs.get("collection") is not None:
raise exceptions.UserError('MemoryStorage does not support ' raise exceptions.UserError("MemoryStorage does not support " "collections.")
'collections.')
self.items = {} # href => (etag, item) self.items = {} # href => (etag, item)
self.metadata = {} self.metadata = {}
self.fileext = fileext self.fileext = fileext

View file

@ -28,21 +28,22 @@ def _writing_op(f):
if not self._at_once: if not self._at_once:
self._write() self._write()
return rv return rv
return inner return inner
class SingleFileStorage(Storage): class SingleFileStorage(Storage):
storage_name = 'singlefile' storage_name = "singlefile"
_repr_attributes = ('path',) _repr_attributes = ("path",)
_write_mode = 'wb' _write_mode = "wb"
_append_mode = 'ab' _append_mode = "ab"
_read_mode = 'rb' _read_mode = "rb"
_items = None _items = None
_last_etag = None _last_etag = None
def __init__(self, path, encoding='utf-8', **kwargs): def __init__(self, path, encoding="utf-8", **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
path = os.path.abspath(expand_path(path)) path = os.path.abspath(expand_path(path))
checkfile(path, create=False) checkfile(path, create=False)
@ -53,49 +54,47 @@ class SingleFileStorage(Storage):
@classmethod @classmethod
def discover(cls, path, **kwargs): def discover(cls, path, **kwargs):
if kwargs.pop('collection', None) is not None: if kwargs.pop("collection", None) is not None:
raise TypeError('collection argument must not be given.') raise TypeError("collection argument must not be given.")
path = os.path.abspath(expand_path(path)) path = os.path.abspath(expand_path(path))
try: try:
path_glob = path % '*' path_glob = path % "*"
except TypeError: except TypeError:
# If not exactly one '%s' is present, we cannot discover # If not exactly one '%s' is present, we cannot discover
# collections because we wouldn't know which name to assign. # collections because we wouldn't know which name to assign.
raise NotImplementedError() raise NotImplementedError()
placeholder_pos = path.index('%s') placeholder_pos = path.index("%s")
for subpath in glob.iglob(path_glob): for subpath in glob.iglob(path_glob):
if os.path.isfile(subpath): if os.path.isfile(subpath):
args = dict(kwargs) args = dict(kwargs)
args['path'] = subpath args["path"] = subpath
collection_end = ( collection_end = (
placeholder_pos placeholder_pos + 2 + len(subpath) - len(path) # length of '%s'
+ 2 # length of '%s'
+ len(subpath)
- len(path)
) )
collection = subpath[placeholder_pos:collection_end] collection = subpath[placeholder_pos:collection_end]
args['collection'] = collection args["collection"] = collection
yield args yield args
@classmethod @classmethod
def create_collection(cls, collection, **kwargs): def create_collection(cls, collection, **kwargs):
path = os.path.abspath(expand_path(kwargs['path'])) path = os.path.abspath(expand_path(kwargs["path"]))
if collection is not None: if collection is not None:
try: try:
path = path % (collection,) path = path % (collection,)
except TypeError: except TypeError:
raise ValueError('Exactly one %s required in path ' raise ValueError(
'if collection is not null.') "Exactly one %s required in path " "if collection is not null."
)
checkfile(path, create=True) checkfile(path, create=True)
kwargs['path'] = path kwargs["path"] = path
kwargs['collection'] = collection kwargs["collection"] = collection
return kwargs return kwargs
def list(self): def list(self):
@ -107,6 +106,7 @@ class SingleFileStorage(Storage):
text = f.read().decode(self.encoding) text = f.read().decode(self.encoding)
except OSError as e: except OSError as e:
import errno import errno
if e.errno != errno.ENOENT: # file not found if e.errno != errno.ENOENT: # file not found
raise OSError(e) raise OSError(e)
text = None text = None
@ -163,18 +163,19 @@ class SingleFileStorage(Storage):
del self._items[href] del self._items[href]
def _write(self): def _write(self):
if self._last_etag is not None and \ if self._last_etag is not None and self._last_etag != get_etag_from_file(
self._last_etag != get_etag_from_file(self.path): self.path
raise exceptions.PreconditionFailed(( ):
'Some other program modified the file {!r}. Re-run the ' raise exceptions.PreconditionFailed(
'synchronization and make sure absolutely no other program is ' (
'writing into the same file.' "Some other program modified the file {!r}. Re-run the "
).format(self.path)) "synchronization and make sure absolutely no other program is "
text = join_collection( "writing into the same file."
item.raw for item, etag in self._items.values() ).format(self.path)
) )
text = join_collection(item.raw for item, etag in self._items.values())
try: try:
with atomic_write(self.path, mode='wb', overwrite=True) as f: with atomic_write(self.path, mode="wb", overwrite=True) as f:
f.write(text.encode(self.encoding)) f.write(text.encode(self.encoding))
finally: finally:
self._items = None self._items = None

View file

@ -1,4 +1,4 @@
''' """
The `sync` function in `vdirsyncer.sync` can be called on two instances of The `sync` function in `vdirsyncer.sync` can be called on two instances of
`Storage` to synchronize them. Apart from the defined errors, this is the only `Storage` to synchronize them. Apart from the defined errors, this is the only
public API of this module. public API of this module.
@ -8,7 +8,7 @@ Yang: http://blog.ezyang.com/2012/08/how-offlineimap-works/
Some modifications to it are explained in Some modifications to it are explained in
https://unterwaditzer.net/2016/sync-algorithm.html https://unterwaditzer.net/2016/sync-algorithm.html
''' """
import contextlib import contextlib
import itertools import itertools
import logging import logging
@ -27,8 +27,9 @@ sync_logger = logging.getLogger(__name__)
class _StorageInfo: class _StorageInfo:
'''A wrapper class that holds prefetched items, the status and other """A wrapper class that holds prefetched items, the status and other
things.''' things."""
def __init__(self, storage, status): def __init__(self, storage, status):
self.storage = storage self.storage = storage
self.status = status self.status = status
@ -57,13 +58,8 @@ class _StorageInfo:
_store_props(ident, meta) _store_props(ident, meta)
# Prefetch items # Prefetch items
for href, item, etag in (self.storage.get_multi(prefetch) for href, item, etag in self.storage.get_multi(prefetch) if prefetch else ():
if prefetch else ()): _store_props(item.ident, ItemMetadata(href=href, hash=item.hash, etag=etag))
_store_props(item.ident, ItemMetadata(
href=href,
hash=item.hash,
etag=etag
))
self.set_item_cache(item.ident, item) self.set_item_cache(item.ident, item)
return storage_nonempty return storage_nonempty
@ -90,9 +86,16 @@ class _StorageInfo:
return self._item_cache[ident] return self._item_cache[ident]
def sync(storage_a, storage_b, status, conflict_resolution=None, def sync(
force_delete=False, error_callback=None, partial_sync='revert'): storage_a,
'''Synchronizes two storages. storage_b,
status,
conflict_resolution=None,
force_delete=False,
error_callback=None,
partial_sync="revert",
):
"""Synchronizes two storages.
:param storage_a: The first storage :param storage_a: The first storage
:type storage_a: :class:`vdirsyncer.storage.base.Storage` :type storage_a: :class:`vdirsyncer.storage.base.Storage`
@ -119,20 +122,20 @@ def sync(storage_a, storage_b, status, conflict_resolution=None,
- ``error``: Raise an error. - ``error``: Raise an error.
- ``ignore``: Those actions are simply skipped. - ``ignore``: Those actions are simply skipped.
- ``revert`` (default): Revert changes on other side. - ``revert`` (default): Revert changes on other side.
''' """
if storage_a.read_only and storage_b.read_only: if storage_a.read_only and storage_b.read_only:
raise BothReadOnly() raise BothReadOnly()
if conflict_resolution == 'a wins': if conflict_resolution == "a wins":
conflict_resolution = lambda a, b: a conflict_resolution = lambda a, b: a
elif conflict_resolution == 'b wins': elif conflict_resolution == "b wins":
conflict_resolution = lambda a, b: b conflict_resolution = lambda a, b: b
status_nonempty = bool(next(status.iter_old(), None)) status_nonempty = bool(next(status.iter_old(), None))
with status.transaction(): with status.transaction():
a_info = _StorageInfo(storage_a, SubStatus(status, 'a')) a_info = _StorageInfo(storage_a, SubStatus(status, "a"))
b_info = _StorageInfo(storage_b, SubStatus(status, 'b')) b_info = _StorageInfo(storage_b, SubStatus(status, "b"))
a_nonempty = a_info.prepare_new_status() a_nonempty = a_info.prepare_new_status()
b_nonempty = b_info.prepare_new_status() b_nonempty = b_info.prepare_new_status()
@ -148,12 +151,7 @@ def sync(storage_a, storage_b, status, conflict_resolution=None,
with storage_a.at_once(), storage_b.at_once(): with storage_a.at_once(), storage_b.at_once():
for action in actions: for action in actions:
try: try:
action.run( action.run(a_info, b_info, conflict_resolution, partial_sync)
a_info,
b_info,
conflict_resolution,
partial_sync
)
except Exception as e: except Exception as e:
if error_callback: if error_callback:
error_callback(e) error_callback(e)
@ -168,13 +166,13 @@ class Action:
def run(self, a, b, conflict_resolution, partial_sync): def run(self, a, b, conflict_resolution, partial_sync):
with self.auto_rollback(a, b): with self.auto_rollback(a, b):
if self.dest.storage.read_only: if self.dest.storage.read_only:
if partial_sync == 'error': if partial_sync == "error":
raise PartialSync(self.dest.storage) raise PartialSync(self.dest.storage)
elif partial_sync == 'ignore': elif partial_sync == "ignore":
self.rollback(a, b) self.rollback(a, b)
return return
else: else:
assert partial_sync == 'revert' assert partial_sync == "revert"
self._run_impl(a, b) self._run_impl(a, b)
@ -201,16 +199,17 @@ class Upload(Action):
if self.dest.storage.read_only: if self.dest.storage.read_only:
href = etag = None href = etag = None
else: else:
sync_logger.info('Copying (uploading) item {} to {}' sync_logger.info(
.format(self.ident, self.dest.storage)) "Copying (uploading) item {} to {}".format(
self.ident, self.dest.storage
)
)
href, etag = self.dest.storage.upload(self.item) href, etag = self.dest.storage.upload(self.item)
assert href is not None assert href is not None
self.dest.status.insert_ident(self.ident, ItemMetadata( self.dest.status.insert_ident(
href=href, self.ident, ItemMetadata(href=href, hash=self.item.hash, etag=etag)
hash=self.item.hash, )
etag=etag
))
class Update(Action): class Update(Action):
@ -223,11 +222,11 @@ class Update(Action):
if self.dest.storage.read_only: if self.dest.storage.read_only:
meta = ItemMetadata(hash=self.item.hash) meta = ItemMetadata(hash=self.item.hash)
else: else:
sync_logger.info('Copying (updating) item {} to {}' sync_logger.info(
.format(self.ident, self.dest.storage)) "Copying (updating) item {} to {}".format(self.ident, self.dest.storage)
)
meta = self.dest.status.get_new(self.ident) meta = self.dest.status.get_new(self.ident)
meta.etag = \ meta.etag = self.dest.storage.update(meta.href, self.item, meta.etag)
self.dest.storage.update(meta.href, self.item, meta.etag)
self.dest.status.update_ident(self.ident, meta) self.dest.status.update_ident(self.ident, meta)
@ -240,8 +239,9 @@ class Delete(Action):
def _run_impl(self, a, b): def _run_impl(self, a, b):
meta = self.dest.status.get_new(self.ident) meta = self.dest.status.get_new(self.ident)
if not self.dest.storage.read_only: if not self.dest.storage.read_only:
sync_logger.info('Deleting item {} from {}' sync_logger.info(
.format(self.ident, self.dest.storage)) "Deleting item {} from {}".format(self.ident, self.dest.storage)
)
self.dest.storage.delete(meta.href, meta.etag) self.dest.storage.delete(meta.href, meta.etag)
self.dest.status.remove_ident(self.ident) self.dest.status.remove_ident(self.ident)
@ -253,35 +253,39 @@ class ResolveConflict(Action):
def run(self, a, b, conflict_resolution, partial_sync): def run(self, a, b, conflict_resolution, partial_sync):
with self.auto_rollback(a, b): with self.auto_rollback(a, b):
sync_logger.info('Doing conflict resolution for item {}...' sync_logger.info(
.format(self.ident)) "Doing conflict resolution for item {}...".format(self.ident)
)
meta_a = a.status.get_new(self.ident) meta_a = a.status.get_new(self.ident)
meta_b = b.status.get_new(self.ident) meta_b = b.status.get_new(self.ident)
if meta_a.hash == meta_b.hash: if meta_a.hash == meta_b.hash:
sync_logger.info('...same content on both sides.') sync_logger.info("...same content on both sides.")
elif conflict_resolution is None: elif conflict_resolution is None:
raise SyncConflict(ident=self.ident, href_a=meta_a.href, raise SyncConflict(
href_b=meta_b.href) ident=self.ident, href_a=meta_a.href, href_b=meta_b.href
)
elif callable(conflict_resolution): elif callable(conflict_resolution):
item_a = a.get_item_cache(self.ident) item_a = a.get_item_cache(self.ident)
item_b = b.get_item_cache(self.ident) item_b = b.get_item_cache(self.ident)
new_item = conflict_resolution(item_a, item_b) new_item = conflict_resolution(item_a, item_b)
if new_item.hash != meta_a.hash: if new_item.hash != meta_a.hash:
Update(new_item, a).run(a, b, conflict_resolution, Update(new_item, a).run(a, b, conflict_resolution, partial_sync)
partial_sync)
if new_item.hash != meta_b.hash: if new_item.hash != meta_b.hash:
Update(new_item, b).run(a, b, conflict_resolution, Update(new_item, b).run(a, b, conflict_resolution, partial_sync)
partial_sync)
else: else:
raise UserError('Invalid conflict resolution mode: {!r}' raise UserError(
.format(conflict_resolution)) "Invalid conflict resolution mode: {!r}".format(conflict_resolution)
)
def _get_actions(a_info, b_info): def _get_actions(a_info, b_info):
for ident in uniq(itertools.chain(a_info.status.parent.iter_new(), for ident in uniq(
a_info.status.parent.iter_old())): itertools.chain(
a_info.status.parent.iter_new(), a_info.status.parent.iter_old()
)
):
a = a_info.status.get_new(ident) a = a_info.status.get_new(ident)
b = b_info.status.get_new(ident) b = b_info.status.get_new(ident)

View file

@ -2,18 +2,18 @@ from .. import exceptions
class SyncError(exceptions.Error): class SyncError(exceptions.Error):
'''Errors related to synchronization.''' """Errors related to synchronization."""
class SyncConflict(SyncError): class SyncConflict(SyncError):
''' """
Two items changed since the last sync, they now have different contents and Two items changed since the last sync, they now have different contents and
no conflict resolution method was given. no conflict resolution method was given.
:param ident: The ident of the item. :param ident: The ident of the item.
:param href_a: The item's href on side A. :param href_a: The item's href on side A.
:param href_b: The item's href on side B. :param href_b: The item's href on side B.
''' """
ident = None ident = None
href_a = None href_a = None
@ -21,12 +21,13 @@ class SyncConflict(SyncError):
class IdentConflict(SyncError): class IdentConflict(SyncError):
''' """
Multiple items on the same storage have the same UID. Multiple items on the same storage have the same UID.
:param storage: The affected storage. :param storage: The affected storage.
:param hrefs: List of affected hrefs on `storage`. :param hrefs: List of affected hrefs on `storage`.
''' """
storage = None storage = None
_hrefs = None _hrefs = None
@ -42,37 +43,38 @@ class IdentConflict(SyncError):
class StorageEmpty(SyncError): class StorageEmpty(SyncError):
''' """
One storage unexpectedly got completely empty between two synchronizations. One storage unexpectedly got completely empty between two synchronizations.
The first argument is the empty storage. The first argument is the empty storage.
:param empty_storage: The empty :param empty_storage: The empty
:py:class:`vdirsyncer.storage.base.Storage`. :py:class:`vdirsyncer.storage.base.Storage`.
''' """
empty_storage = None empty_storage = None
class BothReadOnly(SyncError): class BothReadOnly(SyncError):
''' """
Both storages are marked as read-only. Synchronization is therefore not Both storages are marked as read-only. Synchronization is therefore not
possible. possible.
''' """
class PartialSync(SyncError): class PartialSync(SyncError):
''' """
Attempted change on read-only storage. Attempted change on read-only storage.
''' """
storage = None storage = None
class IdentAlreadyExists(SyncError): class IdentAlreadyExists(SyncError):
'''Like IdentConflict, but for internal state. If this bubbles up, we don't """Like IdentConflict, but for internal state. If this bubbles up, we don't
have a data race, but a bug.''' have a data race, but a bug."""
old_href = None old_href = None
new_href = None new_href = None
def to_ident_conflict(self, storage): def to_ident_conflict(self, storage):
return IdentConflict(storage=storage, return IdentConflict(storage=storage, hrefs=[self.old_href, self.new_href])
hrefs=[self.old_href, self.new_href])

View file

@ -10,14 +10,14 @@ from .exceptions import IdentAlreadyExists
def _exclusive_transaction(conn): def _exclusive_transaction(conn):
c = None c = None
try: try:
c = conn.execute('BEGIN EXCLUSIVE TRANSACTION') c = conn.execute("BEGIN EXCLUSIVE TRANSACTION")
yield c yield c
c.execute('COMMIT') c.execute("COMMIT")
except BaseException: except BaseException:
if c is None: if c is None:
raise raise
_, e, tb = sys.exc_info() _, e, tb = sys.exc_info()
c.execute('ROLLBACK') c.execute("ROLLBACK")
raise e.with_traceback(tb) raise e.with_traceback(tb)
@ -27,14 +27,12 @@ class _StatusBase(metaclass=abc.ABCMeta):
for ident, metadata in status.items(): for ident, metadata in status.items():
if len(metadata) == 4: if len(metadata) == 4:
href_a, etag_a, href_b, etag_b = metadata href_a, etag_a, href_b, etag_b = metadata
props_a = ItemMetadata(href=href_a, hash='UNDEFINED', props_a = ItemMetadata(href=href_a, hash="UNDEFINED", etag=etag_a)
etag=etag_a) props_b = ItemMetadata(href=href_b, hash="UNDEFINED", etag=etag_b)
props_b = ItemMetadata(href=href_b, hash='UNDEFINED',
etag=etag_b)
else: else:
a, b = metadata a, b = metadata
a.setdefault('hash', 'UNDEFINED') a.setdefault("hash", "UNDEFINED")
b.setdefault('hash', 'UNDEFINED') b.setdefault("hash", "UNDEFINED")
props_a = ItemMetadata(**a) props_a = ItemMetadata(**a)
props_b = ItemMetadata(**b) props_b = ItemMetadata(**b)
@ -111,7 +109,7 @@ class _StatusBase(metaclass=abc.ABCMeta):
class SqliteStatus(_StatusBase): class SqliteStatus(_StatusBase):
SCHEMA_VERSION = 1 SCHEMA_VERSION = 1
def __init__(self, path=':memory:'): def __init__(self, path=":memory:"):
self._path = path self._path = path
self._c = sqlite3.connect(path) self._c = sqlite3.connect(path)
self._c.isolation_level = None # turn off idiocy of DB-API self._c.isolation_level = None # turn off idiocy of DB-API
@ -126,12 +124,12 @@ class SqliteStatus(_StatusBase):
# data. # data.
with _exclusive_transaction(self._c) as c: with _exclusive_transaction(self._c) as c:
c.execute('CREATE TABLE meta ( "version" INTEGER PRIMARY KEY )') c.execute('CREATE TABLE meta ( "version" INTEGER PRIMARY KEY )')
c.execute('INSERT INTO meta (version) VALUES (?)', c.execute("INSERT INTO meta (version) VALUES (?)", (self.SCHEMA_VERSION,))
(self.SCHEMA_VERSION,))
# I know that this is a bad schema, but right there is just too # I know that this is a bad schema, but right there is just too
# little gain in deduplicating the .._a and .._b columns. # little gain in deduplicating the .._a and .._b columns.
c.execute('''CREATE TABLE status ( c.execute(
"""CREATE TABLE status (
"ident" TEXT PRIMARY KEY NOT NULL, "ident" TEXT PRIMARY KEY NOT NULL,
"href_a" TEXT, "href_a" TEXT,
"href_b" TEXT, "href_b" TEXT,
@ -139,9 +137,10 @@ class SqliteStatus(_StatusBase):
"hash_b" TEXT NOT NULL, "hash_b" TEXT NOT NULL,
"etag_a" TEXT, "etag_a" TEXT,
"etag_b" TEXT "etag_b" TEXT
); ''') ); """
c.execute('CREATE UNIQUE INDEX by_href_a ON status(href_a)') )
c.execute('CREATE UNIQUE INDEX by_href_b ON status(href_b)') c.execute("CREATE UNIQUE INDEX by_href_a ON status(href_a)")
c.execute("CREATE UNIQUE INDEX by_href_b ON status(href_b)")
# We cannot add NOT NULL here because data is first fetched for the # We cannot add NOT NULL here because data is first fetched for the
# storage a, then storage b. Inbetween the `_b`-columns are filled # storage a, then storage b. Inbetween the `_b`-columns are filled
@ -156,7 +155,8 @@ class SqliteStatus(_StatusBase):
# transaction and reenable on end), it's a separate table now that # transaction and reenable on end), it's a separate table now that
# just gets copied over before we commit. That's a lot of copying, # just gets copied over before we commit. That's a lot of copying,
# sadly. # sadly.
c.execute('''CREATE TABLE new_status ( c.execute(
"""CREATE TABLE new_status (
"ident" TEXT PRIMARY KEY NOT NULL, "ident" TEXT PRIMARY KEY NOT NULL,
"href_a" TEXT, "href_a" TEXT,
"href_b" TEXT, "href_b" TEXT,
@ -164,14 +164,16 @@ class SqliteStatus(_StatusBase):
"hash_b" TEXT, "hash_b" TEXT,
"etag_a" TEXT, "etag_a" TEXT,
"etag_b" TEXT "etag_b" TEXT
); ''') ); """
)
def _is_latest_version(self): def _is_latest_version(self):
try: try:
return bool(self._c.execute( return bool(
'SELECT version FROM meta WHERE version = ?', self._c.execute(
(self.SCHEMA_VERSION,) "SELECT version FROM meta WHERE version = ?", (self.SCHEMA_VERSION,)
).fetchone()) ).fetchone()
)
except sqlite3.OperationalError: except sqlite3.OperationalError:
return False return False
@ -182,10 +184,9 @@ class SqliteStatus(_StatusBase):
with _exclusive_transaction(self._c) as new_c: with _exclusive_transaction(self._c) as new_c:
self._c = new_c self._c = new_c
yield yield
self._c.execute('DELETE FROM status') self._c.execute("DELETE FROM status")
self._c.execute('INSERT INTO status ' self._c.execute("INSERT INTO status " "SELECT * FROM new_status")
'SELECT * FROM new_status') self._c.execute("DELETE FROM new_status")
self._c.execute('DELETE FROM new_status')
finally: finally:
self._c = old_c self._c = old_c
@ -193,88 +194,99 @@ class SqliteStatus(_StatusBase):
# FIXME: Super inefficient # FIXME: Super inefficient
old_props = self.get_new_a(ident) old_props = self.get_new_a(ident)
if old_props is not None: if old_props is not None:
raise IdentAlreadyExists(old_href=old_props.href, raise IdentAlreadyExists(old_href=old_props.href, new_href=a_props.href)
new_href=a_props.href)
b_props = self.get_new_b(ident) or ItemMetadata() b_props = self.get_new_b(ident) or ItemMetadata()
self._c.execute( self._c.execute(
'INSERT OR REPLACE INTO new_status ' "INSERT OR REPLACE INTO new_status " "VALUES(?, ?, ?, ?, ?, ?, ?)",
'VALUES(?, ?, ?, ?, ?, ?, ?)', (
(ident, a_props.href, b_props.href, a_props.hash, b_props.hash, ident,
a_props.etag, b_props.etag) a_props.href,
b_props.href,
a_props.hash,
b_props.hash,
a_props.etag,
b_props.etag,
),
) )
def insert_ident_b(self, ident, b_props): def insert_ident_b(self, ident, b_props):
# FIXME: Super inefficient # FIXME: Super inefficient
old_props = self.get_new_b(ident) old_props = self.get_new_b(ident)
if old_props is not None: if old_props is not None:
raise IdentAlreadyExists(old_href=old_props.href, raise IdentAlreadyExists(old_href=old_props.href, new_href=b_props.href)
new_href=b_props.href)
a_props = self.get_new_a(ident) or ItemMetadata() a_props = self.get_new_a(ident) or ItemMetadata()
self._c.execute( self._c.execute(
'INSERT OR REPLACE INTO new_status ' "INSERT OR REPLACE INTO new_status " "VALUES(?, ?, ?, ?, ?, ?, ?)",
'VALUES(?, ?, ?, ?, ?, ?, ?)', (
(ident, a_props.href, b_props.href, a_props.hash, b_props.hash, ident,
a_props.etag, b_props.etag) a_props.href,
b_props.href,
a_props.hash,
b_props.hash,
a_props.etag,
b_props.etag,
),
) )
def update_ident_a(self, ident, props): def update_ident_a(self, ident, props):
self._c.execute( self._c.execute(
'UPDATE new_status' "UPDATE new_status" " SET href_a=?, hash_a=?, etag_a=?" " WHERE ident=?",
' SET href_a=?, hash_a=?, etag_a=?' (props.href, props.hash, props.etag, ident),
' WHERE ident=?',
(props.href, props.hash, props.etag, ident)
) )
assert self._c.rowcount > 0 assert self._c.rowcount > 0
def update_ident_b(self, ident, props): def update_ident_b(self, ident, props):
self._c.execute( self._c.execute(
'UPDATE new_status' "UPDATE new_status" " SET href_b=?, hash_b=?, etag_b=?" " WHERE ident=?",
' SET href_b=?, hash_b=?, etag_b=?' (props.href, props.hash, props.etag, ident),
' WHERE ident=?',
(props.href, props.hash, props.etag, ident)
) )
assert self._c.rowcount > 0 assert self._c.rowcount > 0
def remove_ident(self, ident): def remove_ident(self, ident):
self._c.execute('DELETE FROM new_status WHERE ident=?', (ident,)) self._c.execute("DELETE FROM new_status WHERE ident=?", (ident,))
def _get_impl(self, ident, side, table): def _get_impl(self, ident, side, table):
res = self._c.execute('SELECT href_{side} AS href,' res = self._c.execute(
' hash_{side} AS hash,' "SELECT href_{side} AS href,"
' etag_{side} AS etag ' " hash_{side} AS hash,"
'FROM {table} WHERE ident=?' " etag_{side} AS etag "
.format(side=side, table=table), "FROM {table} WHERE ident=?".format(side=side, table=table),
(ident,)).fetchone() (ident,),
).fetchone()
if res is None: if res is None:
return None return None
if res['hash'] is None: # FIXME: Implement as constraint in db if res["hash"] is None: # FIXME: Implement as constraint in db
assert res['href'] is None assert res["href"] is None
assert res['etag'] is None assert res["etag"] is None
return None return None
res = dict(res) res = dict(res)
return ItemMetadata(**res) return ItemMetadata(**res)
def get_a(self, ident): def get_a(self, ident):
return self._get_impl(ident, side='a', table='status') return self._get_impl(ident, side="a", table="status")
def get_b(self, ident): def get_b(self, ident):
return self._get_impl(ident, side='b', table='status') return self._get_impl(ident, side="b", table="status")
def get_new_a(self, ident): def get_new_a(self, ident):
return self._get_impl(ident, side='a', table='new_status') return self._get_impl(ident, side="a", table="new_status")
def get_new_b(self, ident): def get_new_b(self, ident):
return self._get_impl(ident, side='b', table='new_status') return self._get_impl(ident, side="b", table="new_status")
def iter_old(self): def iter_old(self):
return iter(res['ident'] for res in return iter(
self._c.execute('SELECT ident FROM status').fetchall()) res["ident"]
for res in self._c.execute("SELECT ident FROM status").fetchall()
)
def iter_new(self): def iter_new(self):
return iter(res['ident'] for res in return iter(
self._c.execute('SELECT ident FROM new_status').fetchall()) res["ident"]
for res in self._c.execute("SELECT ident FROM new_status").fetchall()
)
def rollback(self, ident): def rollback(self, ident):
a = self.get_a(ident) a = self.get_a(ident)
@ -286,41 +298,41 @@ class SqliteStatus(_StatusBase):
return return
self._c.execute( self._c.execute(
'INSERT OR REPLACE INTO new_status' "INSERT OR REPLACE INTO new_status" " VALUES (?, ?, ?, ?, ?, ?, ?)",
' VALUES (?, ?, ?, ?, ?, ?, ?)', (ident, a.href, b.href, a.hash, b.hash, a.etag, b.etag),
(ident, a.href, b.href, a.hash, b.hash, a.etag, b.etag)
) )
def _get_by_href_impl(self, href, default=(None, None), side=None): def _get_by_href_impl(self, href, default=(None, None), side=None):
res = self._c.execute( res = self._c.execute(
'SELECT ident, hash_{side} AS hash, etag_{side} AS etag ' "SELECT ident, hash_{side} AS hash, etag_{side} AS etag "
'FROM status WHERE href_{side}=?'.format(side=side), "FROM status WHERE href_{side}=?".format(side=side),
(href,)).fetchone() (href,),
).fetchone()
if not res: if not res:
return default return default
return res['ident'], ItemMetadata( return res["ident"], ItemMetadata(
href=href, href=href,
hash=res['hash'], hash=res["hash"],
etag=res['etag'], etag=res["etag"],
) )
def get_by_href_a(self, *a, **kw): def get_by_href_a(self, *a, **kw):
kw['side'] = 'a' kw["side"] = "a"
return self._get_by_href_impl(*a, **kw) return self._get_by_href_impl(*a, **kw)
def get_by_href_b(self, *a, **kw): def get_by_href_b(self, *a, **kw):
kw['side'] = 'b' kw["side"] = "b"
return self._get_by_href_impl(*a, **kw) return self._get_by_href_impl(*a, **kw)
class SubStatus: class SubStatus:
def __init__(self, parent, side): def __init__(self, parent, side):
self.parent = parent self.parent = parent
assert side in 'ab' assert side in "ab"
self.remove_ident = parent.remove_ident self.remove_ident = parent.remove_ident
if side == 'a': if side == "a":
self.insert_ident = parent.insert_ident_a self.insert_ident = parent.insert_ident_a
self.update_ident = parent.update_ident_a self.update_ident = parent.update_ident_a
self.get = parent.get_a self.get = parent.get_a
@ -345,8 +357,4 @@ class ItemMetadata:
setattr(self, k, v) setattr(self, k, v)
def to_status(self): def to_status(self):
return { return {"href": self.href, "etag": self.etag, "hash": self.hash}
'href': self.href,
'etag': self.etag,
'hash': self.hash
}

View file

@ -11,9 +11,9 @@ from . import exceptions
# not included, because there are some servers that (incorrectly) encode it to # not included, because there are some servers that (incorrectly) encode it to
# `%40` when it's part of a URL path, and reject or "repair" URLs that contain # `%40` when it's part of a URL path, and reject or "repair" URLs that contain
# `@` in the path. So it's better to just avoid it. # `@` in the path. So it's better to just avoid it.
SAFE_UID_CHARS = ('abcdefghijklmnopqrstuvwxyz' SAFE_UID_CHARS = (
'ABCDEFGHIJKLMNOPQRSTUVWXYZ' "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "0123456789_.-+"
'0123456789_.-+') )
_missing = object() _missing = object()
@ -26,13 +26,13 @@ def expand_path(p):
def split_dict(d, f): def split_dict(d, f):
'''Puts key into first dict if f(key), otherwise in second dict''' """Puts key into first dict if f(key), otherwise in second dict"""
a, b = split_sequence(d.items(), lambda item: f(item[0])) a, b = split_sequence(d.items(), lambda item: f(item[0]))
return dict(a), dict(b) return dict(a), dict(b)
def split_sequence(s, f): def split_sequence(s, f):
'''Puts item into first list if f(item), else in second list''' """Puts item into first list if f(item), else in second list"""
a = [] a = []
b = [] b = []
for item in s: for item in s:
@ -45,9 +45,9 @@ def split_sequence(s, f):
def uniq(s): def uniq(s):
'''Filter duplicates while preserving order. ``set`` can almost always be """Filter duplicates while preserving order. ``set`` can almost always be
used instead of this, but preserving order might prove useful for used instead of this, but preserving order might prove useful for
debugging.''' debugging."""
d = set() d = set()
for x in s: for x in s:
if x not in d: if x not in d:
@ -56,23 +56,23 @@ def uniq(s):
def get_etag_from_file(f): def get_etag_from_file(f):
'''Get etag from a filepath or file-like object. """Get etag from a filepath or file-like object.
This function will flush/sync the file as much as necessary to obtain a This function will flush/sync the file as much as necessary to obtain a
correct value. correct value.
''' """
if hasattr(f, 'read'): if hasattr(f, "read"):
f.flush() # Only this is necessary on Linux f.flush() # Only this is necessary on Linux
if sys.platform == 'win32': if sys.platform == "win32":
os.fsync(f.fileno()) # Apparently necessary on Windows os.fsync(f.fileno()) # Apparently necessary on Windows
stat = os.fstat(f.fileno()) stat = os.fstat(f.fileno())
else: else:
stat = os.stat(f) stat = os.stat(f)
mtime = getattr(stat, 'st_mtime_ns', None) mtime = getattr(stat, "st_mtime_ns", None)
if mtime is None: if mtime is None:
mtime = stat.st_mtime mtime = stat.st_mtime
return f'{mtime:.9f};{stat.st_ino}' return f"{mtime:.9f};{stat.st_ino}"
def get_storage_init_specs(cls, stop_at=object): def get_storage_init_specs(cls, stop_at=object):
@ -80,11 +80,12 @@ def get_storage_init_specs(cls, stop_at=object):
return () return ()
spec = getfullargspec(cls.__init__) spec = getfullargspec(cls.__init__)
traverse_superclass = getattr(cls.__init__, '_traverse_superclass', True) traverse_superclass = getattr(cls.__init__, "_traverse_superclass", True)
if traverse_superclass: if traverse_superclass:
if traverse_superclass is True: # noqa if traverse_superclass is True: # noqa
supercls = next(getattr(x.__init__, '__objclass__', x) supercls = next(
for x in cls.__mro__[1:]) getattr(x.__init__, "__objclass__", x) for x in cls.__mro__[1:]
)
else: else:
supercls = traverse_superclass supercls = traverse_superclass
superspecs = get_storage_init_specs(supercls, stop_at=stop_at) superspecs = get_storage_init_specs(supercls, stop_at=stop_at)
@ -95,7 +96,7 @@ def get_storage_init_specs(cls, stop_at=object):
def get_storage_init_args(cls, stop_at=object): def get_storage_init_args(cls, stop_at=object):
''' """
Get args which are taken during class initialization. Assumes that all Get args which are taken during class initialization. Assumes that all
classes' __init__ calls super().__init__ with the rest of the arguments. classes' __init__ calls super().__init__ with the rest of the arguments.
@ -103,7 +104,7 @@ def get_storage_init_args(cls, stop_at=object):
:returns: (all, required), where ``all`` is a set of all arguments the :returns: (all, required), where ``all`` is a set of all arguments the
class can take, and ``required`` is the subset of arguments the class class can take, and ``required`` is the subset of arguments the class
requires. requires.
''' """
all, required = set(), set() all, required = set(), set()
for spec in get_storage_init_specs(cls, stop_at=stop_at): for spec in get_storage_init_specs(cls, stop_at=stop_at):
all.update(spec.args[1:]) all.update(spec.args[1:])
@ -114,47 +115,48 @@ def get_storage_init_args(cls, stop_at=object):
def checkdir(path, create=False, mode=0o750): def checkdir(path, create=False, mode=0o750):
''' """
Check whether ``path`` is a directory. Check whether ``path`` is a directory.
:param create: Whether to create the directory (and all parent directories) :param create: Whether to create the directory (and all parent directories)
if it does not exist. if it does not exist.
:param mode: Mode to create missing directories with. :param mode: Mode to create missing directories with.
''' """
if not os.path.isdir(path): if not os.path.isdir(path):
if os.path.exists(path): if os.path.exists(path):
raise OSError(f'{path} is not a directory.') raise OSError(f"{path} is not a directory.")
if create: if create:
os.makedirs(path, mode) os.makedirs(path, mode)
else: else:
raise exceptions.CollectionNotFound('Directory {} does not exist.' raise exceptions.CollectionNotFound(
.format(path)) "Directory {} does not exist.".format(path)
)
def checkfile(path, create=False): def checkfile(path, create=False):
''' """
Check whether ``path`` is a file. Check whether ``path`` is a file.
:param create: Whether to create the file's parent directories if they do :param create: Whether to create the file's parent directories if they do
not exist. not exist.
''' """
checkdir(os.path.dirname(path), create=create) checkdir(os.path.dirname(path), create=create)
if not os.path.isfile(path): if not os.path.isfile(path):
if os.path.exists(path): if os.path.exists(path):
raise OSError(f'{path} is not a file.') raise OSError(f"{path} is not a file.")
if create: if create:
with open(path, 'wb'): with open(path, "wb"):
pass pass
else: else:
raise exceptions.CollectionNotFound('File {} does not exist.' raise exceptions.CollectionNotFound("File {} does not exist.".format(path))
.format(path))
class cached_property: class cached_property:
'''A read-only @property that is only evaluated once. Only usable on class """A read-only @property that is only evaluated once. Only usable on class
instances' methods. instances' methods.
''' """
def __init__(self, fget, doc=None): def __init__(self, fget, doc=None):
self.__name__ = fget.__name__ self.__name__ = fget.__name__
self.__module__ = fget.__module__ self.__module__ = fget.__module__
@ -173,12 +175,12 @@ def href_safe(ident, safe=SAFE_UID_CHARS):
def generate_href(ident=None, safe=SAFE_UID_CHARS): def generate_href(ident=None, safe=SAFE_UID_CHARS):
''' """
Generate a safe identifier, suitable for URLs, storage hrefs or UIDs. Generate a safe identifier, suitable for URLs, storage hrefs or UIDs.
If the given ident string is safe, it will be returned, otherwise a random If the given ident string is safe, it will be returned, otherwise a random
UUID. UUID.
''' """
if not ident or not href_safe(ident, safe): if not ident or not href_safe(ident, safe):
return str(uuid.uuid4()) return str(uuid.uuid4())
else: else:
@ -188,6 +190,7 @@ def generate_href(ident=None, safe=SAFE_UID_CHARS):
def synchronized(lock=None): def synchronized(lock=None):
if lock is None: if lock is None:
from threading import Lock from threading import Lock
lock = Lock() lock = Lock()
def inner(f): def inner(f):
@ -195,21 +198,24 @@ def synchronized(lock=None):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
with lock: with lock:
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
return inner return inner
def open_graphical_browser(url, new=0, autoraise=True): def open_graphical_browser(url, new=0, autoraise=True):
'''Open a graphical web browser. """Open a graphical web browser.
This is basically like `webbrowser.open`, but without trying to launch CLI This is basically like `webbrowser.open`, but without trying to launch CLI
browsers at all. We're excluding those since it's undesirable to launch browsers at all. We're excluding those since it's undesirable to launch
those when you're using vdirsyncer on a server. Rather copypaste the URL those when you're using vdirsyncer on a server. Rather copypaste the URL
into the local browser, or use the URL-yanking features of your terminal into the local browser, or use the URL-yanking features of your terminal
emulator. emulator.
''' """
import webbrowser import webbrowser
cli_names = {'www-browser', 'links', 'links2', 'elinks', 'lynx', 'w3m'}
cli_names = {"www-browser", "links", "links2", "elinks", "lynx", "w3m"}
if webbrowser._tryorder is None: # Python 3.7 if webbrowser._tryorder is None: # Python 3.7
webbrowser.register_standard_browsers() webbrowser.register_standard_browsers()
@ -222,5 +228,4 @@ def open_graphical_browser(url, new=0, autoraise=True):
if browser.open(url, new, autoraise): if browser.open(url, new, autoraise):
return return
raise RuntimeError('No graphical browser found. Please open the URL ' raise RuntimeError("No graphical browser found. Please open the URL " "manually.")
'manually.')

View file

@ -8,36 +8,36 @@ from .utils import uniq
IGNORE_PROPS = ( IGNORE_PROPS = (
# PRODID is changed by radicale for some reason after upload # PRODID is changed by radicale for some reason after upload
'PRODID', "PRODID",
# Sometimes METHOD:PUBLISH is added by WebCAL providers, for us it doesn't # Sometimes METHOD:PUBLISH is added by WebCAL providers, for us it doesn't
# make a difference # make a difference
'METHOD', "METHOD",
# X-RADICALE-NAME is used by radicale, because hrefs don't really exist in # X-RADICALE-NAME is used by radicale, because hrefs don't really exist in
# their filesystem backend # their filesystem backend
'X-RADICALE-NAME', "X-RADICALE-NAME",
# Apparently this is set by Horde? # Apparently this is set by Horde?
# https://github.com/pimutils/vdirsyncer/issues/318 # https://github.com/pimutils/vdirsyncer/issues/318
'X-WR-CALNAME', "X-WR-CALNAME",
# Those are from the VCARD specification and is supposed to change when the # Those are from the VCARD specification and is supposed to change when the
# item does -- however, we can determine that ourselves # item does -- however, we can determine that ourselves
'REV', "REV",
'LAST-MODIFIED', "LAST-MODIFIED",
'CREATED', "CREATED",
# Some iCalendar HTTP calendars generate the DTSTAMP at request time, so # Some iCalendar HTTP calendars generate the DTSTAMP at request time, so
# this property always changes when the rest of the item didn't. Some do # this property always changes when the rest of the item didn't. Some do
# the same with the UID. # the same with the UID.
# #
# - Google's read-only calendar links # - Google's read-only calendar links
# - http://www.feiertage-oesterreich.at/ # - http://www.feiertage-oesterreich.at/
'DTSTAMP', "DTSTAMP",
'UID', "UID",
) )
class Item: class Item:
'''Immutable wrapper class for VCALENDAR (VEVENT, VTODO) and """Immutable wrapper class for VCALENDAR (VEVENT, VTODO) and
VCARD''' VCARD"""
def __init__(self, raw): def __init__(self, raw):
assert isinstance(raw, str), type(raw) assert isinstance(raw, str), type(raw)
@ -50,43 +50,43 @@ class Item:
component = stack.pop() component = stack.pop()
stack.extend(component.subcomponents) stack.extend(component.subcomponents)
if component.name in ('VEVENT', 'VTODO', 'VJOURNAL', 'VCARD'): if component.name in ("VEVENT", "VTODO", "VJOURNAL", "VCARD"):
del component['UID'] del component["UID"]
if new_uid: if new_uid:
component['UID'] = new_uid component["UID"] = new_uid
return Item('\r\n'.join(parsed.dump_lines())) return Item("\r\n".join(parsed.dump_lines()))
@cached_property @cached_property
def raw(self): def raw(self):
'''Raw content of the item, as unicode string. """Raw content of the item, as unicode string.
Vdirsyncer doesn't validate the content in any way. Vdirsyncer doesn't validate the content in any way.
''' """
return self._raw return self._raw
@cached_property @cached_property
def uid(self): def uid(self):
'''Global identifier of the item, across storages, doesn't change after """Global identifier of the item, across storages, doesn't change after
a modification of the item.''' a modification of the item."""
# Don't actually parse component, but treat all lines as single # Don't actually parse component, but treat all lines as single
# component, avoiding traversal through all subcomponents. # component, avoiding traversal through all subcomponents.
x = _Component('TEMP', self.raw.splitlines(), []) x = _Component("TEMP", self.raw.splitlines(), [])
try: try:
return x['UID'].strip() or None return x["UID"].strip() or None
except KeyError: except KeyError:
return None return None
@cached_property @cached_property
def hash(self): def hash(self):
'''Hash of self.raw, used for etags.''' """Hash of self.raw, used for etags."""
return hash_item(self.raw) return hash_item(self.raw)
@cached_property @cached_property
def ident(self): def ident(self):
'''Used for generating hrefs and matching up items during """Used for generating hrefs and matching up items during
synchronization. This is either the UID or the hash of the item's synchronization. This is either the UID or the hash of the item's
content.''' content."""
# We hash the item instead of directly using its raw content, because # We hash the item instead of directly using its raw content, because
# #
@ -98,7 +98,7 @@ class Item:
@property @property
def parsed(self): def parsed(self):
'''Don't cache because the rv is mutable.''' """Don't cache because the rv is mutable."""
try: try:
return _Component.parse(self.raw) return _Component.parse(self.raw)
except Exception: except Exception:
@ -106,33 +106,32 @@ class Item:
def normalize_item(item, ignore_props=IGNORE_PROPS): def normalize_item(item, ignore_props=IGNORE_PROPS):
'''Create syntactically invalid mess that is equal for similar items.''' """Create syntactically invalid mess that is equal for similar items."""
if not isinstance(item, Item): if not isinstance(item, Item):
item = Item(item) item = Item(item)
item = _strip_timezones(item) item = _strip_timezones(item)
x = _Component('TEMP', item.raw.splitlines(), []) x = _Component("TEMP", item.raw.splitlines(), [])
for prop in IGNORE_PROPS: for prop in IGNORE_PROPS:
del x[prop] del x[prop]
x.props.sort() x.props.sort()
return '\r\n'.join(filter(bool, (line.strip() for line in x.props))) return "\r\n".join(filter(bool, (line.strip() for line in x.props)))
def _strip_timezones(item): def _strip_timezones(item):
parsed = item.parsed parsed = item.parsed
if not parsed or parsed.name != 'VCALENDAR': if not parsed or parsed.name != "VCALENDAR":
return item return item
parsed.subcomponents = [c for c in parsed.subcomponents parsed.subcomponents = [c for c in parsed.subcomponents if c.name != "VTIMEZONE"]
if c.name != 'VTIMEZONE']
return Item('\r\n'.join(parsed.dump_lines())) return Item("\r\n".join(parsed.dump_lines()))
def hash_item(text): def hash_item(text):
return hashlib.sha256(normalize_item(text).encode('utf-8')).hexdigest() return hashlib.sha256(normalize_item(text).encode("utf-8")).hexdigest()
def split_collection(text): def split_collection(text):
@ -146,16 +145,16 @@ def split_collection(text):
for item in chain(items.values(), ungrouped_items): for item in chain(items.values(), ungrouped_items):
item.subcomponents.extend(inline) item.subcomponents.extend(inline)
yield '\r\n'.join(item.dump_lines()) yield "\r\n".join(item.dump_lines())
def _split_collection_impl(item, main, inline, items, ungrouped_items): def _split_collection_impl(item, main, inline, items, ungrouped_items):
if item.name == 'VTIMEZONE': if item.name == "VTIMEZONE":
inline.append(item) inline.append(item)
elif item.name == 'VCARD': elif item.name == "VCARD":
ungrouped_items.append(item) ungrouped_items.append(item)
elif item.name in ('VTODO', 'VEVENT', 'VJOURNAL'): elif item.name in ("VTODO", "VEVENT", "VJOURNAL"):
uid = item.get('UID', '') uid = item.get("UID", "")
wrapper = _Component(main.name, main.props[:], []) wrapper = _Component(main.name, main.props[:], [])
if uid.strip(): if uid.strip():
@ -164,34 +163,31 @@ def _split_collection_impl(item, main, inline, items, ungrouped_items):
ungrouped_items.append(wrapper) ungrouped_items.append(wrapper)
wrapper.subcomponents.append(item) wrapper.subcomponents.append(item)
elif item.name in ('VCALENDAR', 'VADDRESSBOOK'): elif item.name in ("VCALENDAR", "VADDRESSBOOK"):
if item.name == 'VCALENDAR': if item.name == "VCALENDAR":
del item['METHOD'] del item["METHOD"]
for subitem in item.subcomponents: for subitem in item.subcomponents:
_split_collection_impl(subitem, item, inline, items, _split_collection_impl(subitem, item, inline, items, ungrouped_items)
ungrouped_items)
else: else:
raise ValueError('Unknown component: {}' raise ValueError("Unknown component: {}".format(item.name))
.format(item.name))
_default_join_wrappers = { _default_join_wrappers = {
'VCALENDAR': 'VCALENDAR', "VCALENDAR": "VCALENDAR",
'VEVENT': 'VCALENDAR', "VEVENT": "VCALENDAR",
'VTODO': 'VCALENDAR', "VTODO": "VCALENDAR",
'VCARD': 'VADDRESSBOOK' "VCARD": "VADDRESSBOOK",
} }
def join_collection(items, wrappers=_default_join_wrappers): def join_collection(items, wrappers=_default_join_wrappers):
''' """
:param wrappers: { :param wrappers: {
item_type: wrapper_type item_type: wrapper_type
} }
''' """
items1, items2 = tee((_Component.parse(x) items1, items2 = tee((_Component.parse(x) for x in items), 2)
for x in items), 2)
item_type, wrapper_type = _get_item_type(items1, wrappers) item_type, wrapper_type = _get_item_type(items1, wrappers)
wrapper_props = [] wrapper_props = []
@ -206,17 +202,19 @@ def join_collection(items, wrappers=_default_join_wrappers):
lines = chain(*uniq(tuple(x.dump_lines()) for x in components)) lines = chain(*uniq(tuple(x.dump_lines()) for x in components))
if wrapper_type is not None: if wrapper_type is not None:
lines = chain(*( lines = chain(
[f'BEGIN:{wrapper_type}'], *(
# XXX: wrapper_props is a list of lines (with line-wrapping), so [f"BEGIN:{wrapper_type}"],
# filtering out duplicate lines will almost certainly break # XXX: wrapper_props is a list of lines (with line-wrapping), so
# multiline-values. Since the only props we usually need to # filtering out duplicate lines will almost certainly break
# support are PRODID and VERSION, I don't care. # multiline-values. Since the only props we usually need to
uniq(wrapper_props), # support are PRODID and VERSION, I don't care.
lines, uniq(wrapper_props),
[f'END:{wrapper_type}'] lines,
)) [f"END:{wrapper_type}"],
return ''.join(line + '\r\n' for line in lines) )
)
return "".join(line + "\r\n" for line in lines)
def _get_item_type(components, wrappers): def _get_item_type(components, wrappers):
@ -234,11 +232,11 @@ def _get_item_type(components, wrappers):
if not i: if not i:
return None, None return None, None
else: else:
raise ValueError('Not sure how to join components.') raise ValueError("Not sure how to join components.")
class _Component: class _Component:
''' """
Raw outline of the components. Raw outline of the components.
Vdirsyncer's operations on iCalendar and VCard objects are limited to Vdirsyncer's operations on iCalendar and VCard objects are limited to
@ -253,15 +251,15 @@ class _Component:
Original version from https://github.com/collective/icalendar/, but apart Original version from https://github.com/collective/icalendar/, but apart
from the similar API, very few parts have been reused. from the similar API, very few parts have been reused.
''' """
def __init__(self, name, lines, subcomponents): def __init__(self, name, lines, subcomponents):
''' """
:param name: The component name. :param name: The component name.
:param lines: The component's own properties, as list of lines :param lines: The component's own properties, as list of lines
(strings). (strings).
:param subcomponents: List of components. :param subcomponents: List of components.
''' """
self.name = name self.name = name
self.props = lines self.props = lines
self.subcomponents = subcomponents self.subcomponents = subcomponents
@ -269,7 +267,7 @@ class _Component:
@classmethod @classmethod
def parse(cls, lines, multiple=False): def parse(cls, lines, multiple=False):
if isinstance(lines, bytes): if isinstance(lines, bytes):
lines = lines.decode('utf-8') lines = lines.decode("utf-8")
if isinstance(lines, str): if isinstance(lines, str):
lines = lines.splitlines() lines = lines.splitlines()
@ -277,10 +275,10 @@ class _Component:
rv = [] rv = []
try: try:
for _i, line in enumerate(lines): for _i, line in enumerate(lines):
if line.startswith('BEGIN:'): if line.startswith("BEGIN:"):
c_name = line[len('BEGIN:'):].strip().upper() c_name = line[len("BEGIN:") :].strip().upper()
stack.append(cls(c_name, [], [])) stack.append(cls(c_name, [], []))
elif line.startswith('END:'): elif line.startswith("END:"):
component = stack.pop() component = stack.pop()
if stack: if stack:
stack[-1].subcomponents.append(component) stack[-1].subcomponents.append(component)
@ -290,25 +288,24 @@ class _Component:
if line.strip(): if line.strip():
stack[-1].props.append(line) stack[-1].props.append(line)
except IndexError: except IndexError:
raise ValueError('Parsing error at line {}'.format(_i + 1)) raise ValueError("Parsing error at line {}".format(_i + 1))
if multiple: if multiple:
return rv return rv
elif len(rv) != 1: elif len(rv) != 1:
raise ValueError('Found {} components, expected one.' raise ValueError("Found {} components, expected one.".format(len(rv)))
.format(len(rv)))
else: else:
return rv[0] return rv[0]
def dump_lines(self): def dump_lines(self):
yield f'BEGIN:{self.name}' yield f"BEGIN:{self.name}"
yield from self.props yield from self.props
for c in self.subcomponents: for c in self.subcomponents:
yield from c.dump_lines() yield from c.dump_lines()
yield f'END:{self.name}' yield f"END:{self.name}"
def __delitem__(self, key): def __delitem__(self, key):
prefix = (f'{key}:', f'{key};') prefix = (f"{key}:", f"{key};")
new_lines = [] new_lines = []
lineiter = iter(self.props) lineiter = iter(self.props)
while True: while True:
@ -321,7 +318,7 @@ class _Component:
break break
for line in lineiter: for line in lineiter:
if not line.startswith((' ', '\t')): if not line.startswith((" ", "\t")):
new_lines.append(line) new_lines.append(line)
break break
@ -329,36 +326,37 @@ class _Component:
def __setitem__(self, key, val): def __setitem__(self, key, val):
assert isinstance(val, str) assert isinstance(val, str)
assert '\n' not in val assert "\n" not in val
del self[key] del self[key]
line = f'{key}:{val}' line = f"{key}:{val}"
self.props.append(line) self.props.append(line)
def __contains__(self, obj): def __contains__(self, obj):
if isinstance(obj, type(self)): if isinstance(obj, type(self)):
return obj not in self.subcomponents and \ return obj not in self.subcomponents and not any(
not any(obj in x for x in self.subcomponents) obj in x for x in self.subcomponents
)
elif isinstance(obj, str): elif isinstance(obj, str):
return self.get(obj, None) is not None return self.get(obj, None) is not None
else: else:
raise ValueError(obj) raise ValueError(obj)
def __getitem__(self, key): def __getitem__(self, key):
prefix_without_params = f'{key}:' prefix_without_params = f"{key}:"
prefix_with_params = f'{key};' prefix_with_params = f"{key};"
iterlines = iter(self.props) iterlines = iter(self.props)
for line in iterlines: for line in iterlines:
if line.startswith(prefix_without_params): if line.startswith(prefix_without_params):
rv = line[len(prefix_without_params):] rv = line[len(prefix_without_params) :]
break break
elif line.startswith(prefix_with_params): elif line.startswith(prefix_with_params):
rv = line[len(prefix_with_params):].split(':', 1)[-1] rv = line[len(prefix_with_params) :].split(":", 1)[-1]
break break
else: else:
raise KeyError() raise KeyError()
for line in iterlines: for line in iterlines:
if line.startswith((' ', '\t')): if line.startswith((" ", "\t")):
rv += line[1:] rv += line[1:]
else: else:
break break