From eece22723e4d62a80c73ebb3e85251dae30d7844 Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Mon, 19 May 2014 18:32:31 +0200 Subject: [PATCH] Move class inspection code to utils --- tests/utils/test_main.py | 15 +++++++++++++++ vdirsyncer/cli.py | 18 ++---------------- vdirsyncer/utils/__init__.py | 25 +++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/utils/test_main.py b/tests/utils/test_main.py index c9ed34a..cef05b0 100644 --- a/tests/utils/test_main.py +++ b/tests/utils/test_main.py @@ -108,4 +108,19 @@ def test_get_password_from_system_keyring(monkeypatch, resources_to_test): assert netrc_calls == [hostname] +def test_get_class_init_args(): + class Foobar(object): + def __init__(self, foo, bar, baz=None): + pass + all, required = utils.get_class_init_args(Foobar) + assert all == {'foo', 'bar', 'baz'} + assert required == {'foo', 'bar'} + + +def test_get_class_init_args_on_storage(): + from vdirsyncer.storage.memory import MemoryStorage + + all, required = utils.get_class_init_args(MemoryStorage) + assert not all + assert not required diff --git a/vdirsyncer/cli.py b/vdirsyncer/cli.py index abe5cb2..241a946 100644 --- a/vdirsyncer/cli.py +++ b/vdirsyncer/cli.py @@ -15,7 +15,7 @@ import argvard from .storage import storage_names from .sync import sync, StorageEmpty -from .utils import expand_path, parse_options, split_dict +from .utils import expand_path, parse_options, split_dict, get_class_init_args import vdirsyncer.log as log @@ -123,7 +123,7 @@ def storage_instance_from_config(config, description=None): try: return cls(**config) except Exception: - all, required = get_init_args(cls) + all, required = get_class_init_args(cls) given = set(config) missing = required - given invalid = given - all @@ -147,20 +147,6 @@ def storage_instance_from_config(config, description=None): sys.exit(1) -def get_init_args(cls): - from vdirsyncer.storage.base import Storage - import inspect - - if cls is Storage: - return set(), set() - - spec = inspect.getargspec(cls.__init__) - all = set(spec.args[1:]) - required = set(spec.args[1:-len(spec.defaults)]) - supercls = next(x for x in cls.__mro__[1:] if hasattr(x, '__init__')) - s_all, s_required = get_init_args(supercls) - - return all | s_all, required | s_required def main(): diff --git a/vdirsyncer/utils/__init__.py b/vdirsyncer/utils/__init__.py index 70a71aa..4fa5034 100644 --- a/vdirsyncer/utils/__init__.py +++ b/vdirsyncer/utils/__init__.py @@ -251,3 +251,28 @@ class safe_write(object): def get_etag_from_file(fpath): return '{:.9f}'.format(os.path.getmtime(fpath)) + + +def get_class_init_args(cls): + ''' + Get args which are taken during class initialization. Assumes that all + classes' __init__ calls super().__init__ with the rest of the arguments. + + :param cls: The class to inspect. + :returns: (all, required), where ``all`` is a set of all arguments the + class can take, and ``required`` is the subset of arguments the class + requires. + ''' + import inspect + + if cls is object: + return set(), set() + + spec = inspect.getargspec(cls.__init__) + all = set(spec.args[1:]) + required = set(spec.args[1:-len(spec.defaults or ())]) + supercls = next(getattr(x.__init__, '__objclass__', x) + for x in cls.__mro__[1:]) + s_all, s_required = get_class_init_args(supercls) + + return all | s_all, required | s_required