diff --git a/tests/utils/test_main.py b/tests/utils/test_main.py index ad3d984..e775dee 100644 --- a/tests/utils/test_main.py +++ b/tests/utils/test_main.py @@ -227,3 +227,17 @@ def test_request_ssl(httpsserver): utils.request('GET', httpsserver.url, verify_fingerprint=''.join(reversed(sha1))) assert 'Fingerprints did not match' in str(excinfo.value) + + +def test_atomic_write(tmpdir): + x = utils.atomic_write + fname = tmpdir.join('ha') + for i in range(2): + with x(str(fname), binary=False, overwrite=True) as f: + f.write('hoho') + + with pytest.raises(OSError): + with x(str(fname), binary=False, overwrite=False) as f: + f.write('haha') + + assert fname.read() == 'hoho' diff --git a/vdirsyncer/cli/utils.py b/vdirsyncer/cli/utils.py index d8a38d7..a57e0f0 100644 --- a/vdirsyncer/cli/utils.py +++ b/vdirsyncer/cli/utils.py @@ -20,7 +20,7 @@ from .. import DOCS_HOME, PROJECT_HOME, exceptions, log from ..doubleclick import click from ..storage import storage_names from ..sync import StorageEmpty, SyncConflict -from ..utils import expand_path, get_class_init_args, safe_write +from ..utils import atomic_write, expand_path, get_class_init_args from ..utils.compat import text_type @@ -345,7 +345,7 @@ def save_status(base_path, pair, collection=None, data_type=None, data=None): if e.errno != errno.EEXIST: raise - with safe_write(path, 'w+') as f: + with atomic_write(path, binary=True, overwrite=True) as f: json.dump(data, f) diff --git a/vdirsyncer/storage/filesystem.py b/vdirsyncer/storage/filesystem.py index 670d3e0..4f5ca01 100644 --- a/vdirsyncer/storage/filesystem.py +++ b/vdirsyncer/storage/filesystem.py @@ -12,7 +12,7 @@ import os from .base import Item, Storage from .. import exceptions, log -from ..utils import checkdir, expand_path, get_etag_from_file, safe_write +from ..utils import atomic_write, checkdir, expand_path, get_etag_from_file from ..utils.compat import text_type logger = log.get(__name__) @@ -101,15 +101,21 @@ class FilesystemStorage(Storage): def upload(self, item): href = self._get_href(item) fpath = self._get_filepath(href) - if os.path.exists(fpath): - raise exceptions.AlreadyExistingError(item) if not isinstance(item.raw, text_type): raise TypeError('item.raw must be a unicode string.') - with safe_write(fpath, 'wb+') as f: - f.write(item.raw.encode(self.encoding)) - return href, f.get_etag() + + try: + with atomic_write(fpath, binary=True, overwrite=False) as f: + f.write(item.raw.encode(self.encoding)) + return href, f.get_etag() + except OSError as e: + import errno + if e.errno == errno.EEXIST: + raise exceptions.AlreadyExistingError(item) + else: + raise def update(self, href, item, etag): fpath = self._get_filepath(href) @@ -125,7 +131,7 @@ class FilesystemStorage(Storage): if not isinstance(item.raw, text_type): raise TypeError('item.raw must be a unicode string.') - with safe_write(fpath, 'wb') as f: + with atomic_write(fpath, binary=True, overwrite=True) as f: f.write(item.raw.encode(self.encoding)) return f.get_etag() diff --git a/vdirsyncer/storage/singlefile.py b/vdirsyncer/storage/singlefile.py index 4d50516..b5525f2 100644 --- a/vdirsyncer/storage/singlefile.py +++ b/vdirsyncer/storage/singlefile.py @@ -12,7 +12,7 @@ import os from .base import Item, Storage from .. import exceptions, log -from ..utils import checkfile, expand_path, safe_write +from ..utils import atomic_write, checkfile, expand_path from ..utils.compat import iteritems, itervalues from ..utils.vobject import join_collection, split_collection @@ -166,7 +166,7 @@ class SingleFileStorage(Storage): (item.raw for item, etag in itervalues(self._items)), ) try: - with safe_write(self.path, self._write_mode) as f: + with atomic_write(self.path, binary=True, overwrite=True) as f: f.write(text.encode(self.encoding)) finally: self._items = None diff --git a/vdirsyncer/utils/__init__.py b/vdirsyncer/utils/__init__.py index 18c1410..d1d9100 100644 --- a/vdirsyncer/utils/__init__.py +++ b/vdirsyncer/utils/__init__.py @@ -9,6 +9,7 @@ import os import threading +import uuid import requests from requests.packages.urllib3.poolmanager import PoolManager @@ -241,15 +242,21 @@ def request(method, url, session=None, latin1_fallback=True, return r -class safe_write(object): - '''A helper class for performing atomic writes. Writes to a tempfile in - the same directory and then renames. The tempfile location can be - overridden, but must reside on the same filesystem to be atomic. +class atomic_write(object): + ''' + A helper class for performing atomic writes. Usage:: - with safe_write(fpath, 'w+') as f: + with safe_write(fpath, binary=False, overwrite=False) as f: f.write('hohoho') + + :param fpath: The destination filepath. May or may not exist. + :param binary: Whether binary write mode should be used. + :param overwrite: If set to false, an error is raised if ``fpath`` exists. + This should still be atomic. + :param tmppath: An alternative tmpfile location. Must reside on the same + filesystem to be atomic. ''' f = None @@ -257,23 +264,35 @@ class safe_write(object): fpath = None mode = None - def __init__(self, fpath, mode, tmppath=None): - self.tmppath = tmppath or fpath + '.tmp' + def __init__(self, fpath, binary, overwrite, tmppath=None): + if not tmppath: + base = os.path.dirname(fpath) + tmppath = os.path.join(base, str(uuid.uuid4()) + '.tmp') + self.fpath = fpath - self.mode = mode + self.binary = binary + self.overwrite = overwrite + self.tmppath = tmppath def __enter__(self): - self.f = f = open(self.tmppath, self.mode) + self.f = f = open(self.tmppath, 'wb' if self.binary else 'w') self.write = f.write return self def __exit__(self, cls, value, tb): self.f.close() if cls is None: - os.rename(self.tmppath, self.fpath) + self._commit() else: os.remove(self.tmppath) + def _commit(self): + if self.overwrite: + os.rename(self.tmppath, self.fpath) # atomic + else: + os.link(self.tmppath, self.fpath) # atomic, fails if file exists + os.unlink(self.tmppath) # doesn't matter if atomic + def get_etag(self): self.f.flush() return get_etag_from_file(self.tmppath)