Browse Source

use atomic writes for config file writing as well

Nick Sweeting 5 years ago
parent
commit
79b19ddf35

+ 36 - 37
archivebox/config/__init__.py

@@ -320,65 +320,64 @@ def load_config_file(out_dir: str=None) -> Optional[Dict[str, str]]:
         return config_file_vars
         return config_file_vars
     return None
     return None
 
 
+
 def write_config_file(config: Dict[str, str], out_dir: str=None) -> ConfigDict:
 def write_config_file(config: Dict[str, str], out_dir: str=None) -> ConfigDict:
     """load the ini-formatted config file from OUTPUT_DIR/Archivebox.conf"""
     """load the ini-formatted config file from OUTPUT_DIR/Archivebox.conf"""
 
 
     out_dir = out_dir or os.path.abspath(os.getenv('OUTPUT_DIR', '.'))
     out_dir = out_dir or os.path.abspath(os.getenv('OUTPUT_DIR', '.'))
     config_path = os.path.join(out_dir, CONFIG_FILENAME)
     config_path = os.path.join(out_dir, CONFIG_FILENAME)
+    
     if not os.path.exists(config_path):
     if not os.path.exists(config_path):
-        with open(config_path, 'w+') as f:
-            f.write(CONFIG_HEADER)
+        atomic_write(config_path, CONFIG_HEADER)
 
 
     config_file = ConfigParser()
     config_file = ConfigParser()
     config_file.optionxform = str
     config_file.optionxform = str
     config_file.read(config_path)
     config_file.read(config_path)
 
 
+    with open(config_path, 'r') as old:
+        atomic_write(f'{config_path}.bak', old.read())
+
     find_section = lambda key: [name for name, opts in CONFIG_DEFAULTS.items() if key in opts][0]
     find_section = lambda key: [name for name, opts in CONFIG_DEFAULTS.items() if key in opts][0]
 
 
-    with open(f'{config_path}.old', 'w+') as old:
-        with open(config_path, 'r') as new:
-            old.write(new.read())
-
-    with open(config_path, 'w+') as f:
-        for key, val in config.items():
-            section = find_section(key)
-            if section in config_file:
-                existing_config = dict(config_file[section])
-            else:
-                existing_config = {}
-
-            config_file[section] = {**existing_config, key: val}
-
-        # always make sure there's a SECRET_KEY defined for Django
-        existing_secret_key = None
-        if 'SERVER_CONFIG' in config_file and 'SECRET_KEY' in config_file['SERVER_CONFIG']:
-            existing_secret_key = config_file['SERVER_CONFIG']['SECRET_KEY']
-
-        if (not existing_secret_key) or ('not a valid secret' in existing_secret_key):
-            from django.utils.crypto import get_random_string
-            chars = 'abcdefghijklmnopqrstuvwxyz0123456789-_+!.'
-            random_secret_key = get_random_string(50, chars)
-            if 'SERVER_CONFIG' in config_file:
-                config_file['SERVER_CONFIG']['SECRET_KEY'] = random_secret_key
-            else:
-                config_file['SERVER_CONFIG'] = {'SECRET_KEY': random_secret_key}
-
-        f.write(CONFIG_HEADER)
-        config_file.write(f)
+    # Set up sections in empty config file
+    for key, val in config.items():
+        section = find_section(key)
+        if section in config_file:
+            existing_config = dict(config_file[section])
+        else:
+            existing_config = {}
+        config_file[section] = {**existing_config, key: val}
+
+    # always make sure there's a SECRET_KEY defined for Django
+    existing_secret_key = None
+    if 'SERVER_CONFIG' in config_file and 'SECRET_KEY' in config_file['SERVER_CONFIG']:
+        existing_secret_key = config_file['SERVER_CONFIG']['SECRET_KEY']
+
+    if (not existing_secret_key) or ('not a valid secret' in existing_secret_key):
+        from django.utils.crypto import get_random_string
+        chars = 'abcdefghijklmnopqrstuvwxyz0123456789-_+!.'
+        random_secret_key = get_random_string(50, chars)
+        if 'SERVER_CONFIG' in config_file:
+            config_file['SERVER_CONFIG']['SECRET_KEY'] = random_secret_key
+        else:
+            config_file['SERVER_CONFIG'] = {'SECRET_KEY': random_secret_key}
+
 
 
+    atomic_write(config_path, '\n'.join((CONFIG_HEADER, config_file)))
     try:
     try:
+        # validate the config by attempting to re-parse it
         CONFIG = load_all_config()
         CONFIG = load_all_config()
         return {
         return {
             key.upper(): CONFIG.get(key.upper())
             key.upper(): CONFIG.get(key.upper())
             for key in config.keys()
             for key in config.keys()
         }
         }
     except:
     except:
-        with open(f'{config_path}.old', 'r') as old:
-            with open(config_path, 'w+') as new:
-                new.write(old.read())
+        # something went horribly wrong, rever to the previous version
+        with open(f'{config_path}.bak', 'r') as old:
+            atomic_write(config_path, old.read())
 
 
-    if os.path.exists(f'{config_path}.old'):
-        os.remove(f'{config_path}.old')
+    if os.path.exists(f'{config_path}.bak'):
+        os.remove(f'{config_path}.bak')
 
 
     return {}
     return {}
 
 

+ 3 - 3
archivebox/extractors/dom.py

@@ -5,7 +5,7 @@ import os
 from typing import Optional
 from typing import Optional
 
 
 from ..index.schema import Link, ArchiveResult, ArchiveOutput, ArchiveError
 from ..index.schema import Link, ArchiveResult, ArchiveOutput, ArchiveError
-from ..system import run, chmod_file
+from ..system import run, chmod_file, atomic_write
 from ..util import (
 from ..util import (
     enforce_types,
     enforce_types,
     is_static_file,
     is_static_file,
@@ -46,8 +46,8 @@ def save_dom(link: Link, out_dir: Optional[str]=None, timeout: int=TIMEOUT) -> A
     status = 'succeeded'
     status = 'succeeded'
     timer = TimedProgress(timeout, prefix='      ')
     timer = TimedProgress(timeout, prefix='      ')
     try:
     try:
-        with open(output_path, 'w+') as f:
-            result = run(cmd, stdout=f, cwd=out_dir, timeout=timeout)
+        result = run(cmd, cwd=out_dir, timeout=timeout)
+        atomic_write(output_path, result.stdout)
 
 
         if result.returncode:
         if result.returncode:
             hints = result.stderr.decode()
             hints = result.stderr.decode()

+ 0 - 1
archivebox/extractors/git.py

@@ -65,7 +65,6 @@ def save_git(link: Link, out_dir: Optional[str]=None, timeout: int=TIMEOUT) -> A
     timer = TimedProgress(timeout, prefix='      ')
     timer = TimedProgress(timeout, prefix='      ')
     try:
     try:
         result = run(cmd, cwd=output_path, timeout=timeout + 1)
         result = run(cmd, cwd=output_path, timeout=timeout + 1)
-
         if result.returncode == 128:
         if result.returncode == 128:
             # ignore failed re-download when the folder already exists
             # ignore failed re-download when the folder already exists
             pass
             pass

+ 1 - 0
archivebox/extractors/pdf.py

@@ -58,6 +58,7 @@ def save_pdf(link: Link, out_dir: Optional[str]=None, timeout: int=TIMEOUT) -> A
     finally:
     finally:
         timer.end()
         timer.end()
 
 
+
     return ArchiveResult(
     return ArchiveResult(
         cmd=cmd,
         cmd=cmd,
         pwd=out_dir,
         pwd=out_dir,