Add type hint to vdirsyncer/cli/utils.py

This commit is contained in:
Justin ! 2023-08-22 23:52:03 -04:00 committed by Hugo
parent adc974bdd1
commit 3611e7d62f

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import contextlib import contextlib
import errno import errno
import importlib import importlib
@ -12,6 +14,7 @@ from atomicwrites import atomic_write
from .. import BUGTRACKER_HOME from .. import BUGTRACKER_HOME
from .. import DOCS_HOME from .. import DOCS_HOME
from .. import exceptions from .. import exceptions
from ..storage.base import Storage
from ..sync.exceptions import IdentConflict from ..sync.exceptions import IdentConflict
from ..sync.exceptions import PartialSync from ..sync.exceptions import PartialSync
from ..sync.exceptions import StorageEmpty from ..sync.exceptions import StorageEmpty
@ -27,7 +30,7 @@ STATUS_DIR_PERMISSIONS = 0o700
class _StorageIndex: class _StorageIndex:
def __init__(self): def __init__(self):
self._storages = { self._storages: dict[str, str] = {
"caldav": "vdirsyncer.storage.dav.CalDAVStorage", "caldav": "vdirsyncer.storage.dav.CalDAVStorage",
"carddav": "vdirsyncer.storage.dav.CardDAVStorage", "carddav": "vdirsyncer.storage.dav.CardDAVStorage",
"filesystem": "vdirsyncer.storage.filesystem.FilesystemStorage", "filesystem": "vdirsyncer.storage.filesystem.FilesystemStorage",
@ -37,7 +40,7 @@ class _StorageIndex:
"google_contacts": "vdirsyncer.storage.google.GoogleContactsStorage", "google_contacts": "vdirsyncer.storage.google.GoogleContactsStorage",
} }
def __getitem__(self, name): def __getitem__(self, name: str) -> Storage:
item = self._storages[name] item = self._storages[name]
if not isinstance(item, str): if not isinstance(item, str):
return item return item
@ -154,13 +157,18 @@ def handle_cli_error(status_name=None, e=None):
cli_logger.debug("".join(tb)) cli_logger.debug("".join(tb))
def get_status_name(pair, collection): def get_status_name(pair: str, collection: str | None) -> str:
if collection is None: if collection is None:
return pair return pair
return pair + "/" + collection return pair + "/" + collection
def get_status_path(base_path, pair, collection=None, data_type=None): def get_status_path(
base_path: str,
pair: str,
collection: str | None = None,
data_type: str | None = None,
) -> str:
assert data_type is not None assert data_type is not None
status_name = get_status_name(pair, collection) status_name = get_status_name(pair, collection)
path = expand_path(os.path.join(base_path, status_name)) path = expand_path(os.path.join(base_path, status_name))
@ -174,7 +182,12 @@ def get_status_path(base_path, pair, collection=None, data_type=None):
return path return path
def load_status(base_path, pair, collection=None, data_type=None): def load_status(
base_path: str,
pair: str,
collection: str | None = None,
data_type: str | None = None
) -> dict | None:
path = get_status_path(base_path, pair, collection, data_type) path = get_status_path(base_path, pair, collection, data_type)
if not os.path.exists(path): if not os.path.exists(path):
return None return None
@ -189,7 +202,7 @@ def load_status(base_path, pair, collection=None, data_type=None):
return {} return {}
def prepare_status_path(path): def prepare_status_path(path: str) -> None:
dirname = os.path.dirname(path) dirname = os.path.dirname(path)
try: try:
@ -200,7 +213,7 @@ def prepare_status_path(path):
@contextlib.contextmanager @contextlib.contextmanager
def manage_sync_status(base_path, pair_name, collection_name): def manage_sync_status(base_path: str, pair_name: str, collection_name: str):
path = get_status_path(base_path, pair_name, collection_name, "items") path = get_status_path(base_path, pair_name, collection_name, "items")
status = None status = None
legacy_status = None legacy_status = None
@ -225,7 +238,13 @@ def manage_sync_status(base_path, pair_name, collection_name):
yield status yield status
def save_status(base_path, pair, collection=None, data_type=None, data=None): def save_status(
base_path: str,
pair: str,
collection: str | None = None,
data_type: str | None = None,
data: dict | None = None,
) -> None:
assert data_type is not None assert data_type is not None
assert data is not None assert data is not None
status_name = get_status_name(pair, collection) status_name = get_status_name(pair, collection)
@ -319,7 +338,7 @@ def handle_storage_init_error(cls, config):
) )
def assert_permissions(path, wanted): def assert_permissions(path: str, wanted: int) -> None:
permissions = os.stat(path).st_mode & 0o777 permissions = os.stat(path).st_mode & 0o777
if permissions > wanted: if permissions > wanted:
cli_logger.warning( cli_logger.warning(