Stricter config validation

This commit is contained in:
Markus Unterwaditzer 2014-12-26 00:50:15 +01:00
parent 4757fac383
commit 6ef330aac5
2 changed files with 63 additions and 25 deletions

View file

@ -187,7 +187,7 @@ def test_missing_general_section(tmpdir, runner):
result = runner.invoke(['sync']) result = runner.invoke(['sync'])
assert result.exception assert result.exception
assert result.output.startswith('critical:') assert result.output.startswith('critical:')
assert 'unable to find general section' in result.output.lower() assert 'invalid general section' in result.output.lower()
def test_wrong_general_section(tmpdir, runner): def test_wrong_general_section(tmpdir, runner):
@ -199,12 +199,11 @@ def test_wrong_general_section(tmpdir, runner):
assert result.exception assert result.exception
lines = result.output.splitlines() lines = result.output.splitlines()
assert lines[:-1] == [ assert lines[:-2] == [
'critical: general section doesn\'t take the parameters: wrong', 'critical: general section doesn\'t take the parameters: wrong',
'critical: general section is missing the parameters: status_path' 'critical: general section is missing the parameters: status_path'
] ]
assert lines[-1].startswith('critical:') assert 'Invalid general section.' in lines[-2]
assert lines[-1].endswith('Invalid general section.')
def test_verbosity(tmpdir): def test_verbosity(tmpdir):
@ -270,7 +269,7 @@ def test_collections_cache_invalidation(tmpdir, runner):
[pair foobar] [pair foobar]
a = foo a = foo
b = bar b = bar
collections = a, b, c collections = ["a", "b", "c"]
''').format(str(tmpdir))) ''').format(str(tmpdir)))
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir('foo')
@ -298,7 +297,7 @@ def test_collections_cache_invalidation(tmpdir, runner):
[pair foobar] [pair foobar]
a = foo a = foo
b = bar b = bar
collections = a, b, c collections = ["a", "b", "c"]
''').format(str(tmpdir))) ''').format(str(tmpdir)))
tmpdir.join('status').remove() tmpdir.join('status').remove()
@ -327,7 +326,7 @@ def test_invalid_pairs_as_cli_arg(tmpdir, runner):
[pair foobar] [pair foobar]
a = foo a = foo
b = bar b = bar
collections = a, b, c collections = ["a", "b", "c"]
''').format(str(tmpdir))) ''').format(str(tmpdir)))
tmpdir.mkdir('foo') tmpdir.mkdir('foo')
@ -353,7 +352,7 @@ def test_discover_command(tmpdir, runner):
[pair foobar] [pair foobar]
a = foo a = foo
b = bar b = bar
collections = from a collections = ["from a"]
''').format(str(tmpdir))) ''').format(str(tmpdir)))
foo = tmpdir.mkdir('foo') foo = tmpdir.mkdir('foo')
@ -410,3 +409,29 @@ def test_multiple_pairs(tmpdir, runner):
'Syncing bambaz', 'Syncing bambaz',
'Syncing foobar', 'Syncing foobar',
] ]
def test_invalid_collections_arg(tmpdir, runner):
runner.write_with_general(dedent('''
[pair foobar]
a = foo
b = bar
collections = [null]
[storage foo]
type = filesystem
path = {base}/foo/
fileext = .txt
[storage bar]
type = filesystem
path = {base}/bar/
fileext = .txt
'''.format(base=str(tmpdir))))
result = runner.invoke(['sync'])
assert result.exception
assert result.output.strip().endswith(
'Section `pair foobar`: `collections` parameter must be a list of '
'collection names (strings!) or `null`.'
)

View file

@ -22,6 +22,7 @@ from ..storage import storage_names
from ..sync import StorageEmpty, SyncConflict from ..sync import StorageEmpty, SyncConflict
from ..utils import expand_path, get_class_init_args, parse_options, \ from ..utils import expand_path, get_class_init_args, parse_options, \
safe_write safe_write
from ..utils.compat import text_type
try: try:
@ -222,12 +223,6 @@ def _collections_for_pair_impl(status_path, name_a, name_b, pair_name,
def _validate_general_section(general_config): def _validate_general_section(general_config):
if general_config is None:
raise CliError(
'Unable to find general section. You should copy the example '
'config from the repository and edit it.\n{}'.format(PROJECT_HOME)
)
if 'passwordeval' in general_config: if 'passwordeval' in general_config:
# XXX: Deprecation # XXX: Deprecation
cli_logger.warning('The `passwordeval` parameter has been renamed to ' cli_logger.warning('The `passwordeval` parameter has been renamed to '
@ -245,7 +240,20 @@ def _validate_general_section(general_config):
.format(u', '.join(missing))) .format(u', '.join(missing)))
if invalid or missing: if invalid or missing:
raise CliError('Invalid general section.') raise ValueError('Invalid general section. You should copy the '
'example config from the repository and edit it.\n{}'
.format(PROJECT_HOME))
def _validate_pair_section(pair_config):
collections = pair_config.get('collections', None)
if collections is None:
return
e = ValueError('`collections` parameter must be a list of collection '
'names (strings!) or `null`.')
if not isinstance(collections, list) or \
any(not isinstance(x, (text_type, bytes)) for x in collections):
raise e
def load_config(f): def load_config(f):
@ -254,24 +262,29 @@ def load_config(f):
get_options = lambda s: dict(parse_options(c.items(s), section=s)) get_options = lambda s: dict(parse_options(c.items(s), section=s))
general = None general = {}
pairs = {} pairs = {}
storages = {} storages = {}
def handle_storage(storage_name, options): def handle_storage(storage_name, options):
validate_section_name(storage_name, 'storage')
storages.setdefault(storage_name, {}).update(options) storages.setdefault(storage_name, {}).update(options)
storages[storage_name]['instance_name'] = storage_name storages[storage_name]['instance_name'] = storage_name
def handle_pair(pair_name, options): def handle_pair(pair_name, options):
validate_section_name(pair_name, 'pair') _validate_pair_section(options)
a, b = options.pop('a'), options.pop('b') a, b = options.pop('a'), options.pop('b')
pairs[pair_name] = a, b, options pairs[pair_name] = a, b, options
def handle_general(_, options):
if general:
raise CliError('More than one general section in config file.')
general.update(options)
def bad_section(name, options): def bad_section(name, options):
cli_logger.error('Unknown section: {}'.format(name)) cli_logger.error('Unknown section: {}'.format(name))
handlers = {'storage': handle_storage, 'pair': handle_pair} handlers = {'storage': handle_storage, 'pair': handle_pair, 'general':
handle_general}
for section in c.sections(): for section in c.sections():
if ' ' in section: if ' ' in section:
@ -279,12 +292,12 @@ def load_config(f):
else: else:
section_type = name = section section_type = name = section
if section_type == 'general': try:
if general is not None: validate_section_name(name, section_type)
raise CliError('More than one general section in config file.') f = handlers.get(section_type, bad_section)
general = get_options(section_type) f(name, get_options(section))
else: except ValueError as e:
handlers.get(section_type, bad_section)(name, get_options(section)) raise CliError('Section `{}`: {}'.format(section, str(e)))
_validate_general_section(general) _validate_general_section(general)
return general, pairs, storages return general, pairs, storages