diff options
Diffstat (limited to 'src/home/config/config.py')
-rw-r--r-- | src/home/config/config.py | 170 |
1 files changed, 104 insertions, 66 deletions
diff --git a/src/home/config/config.py b/src/home/config/config.py index 4681685..2d49524 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -3,45 +3,116 @@ import yaml import logging import os +from . import validators from os.path import join, isdir, isfile from typing import Optional, Any, MutableMapping from argparse import ArgumentParser from ..util import parse_addr -def _get_config_path(name: str) -> str: - formats = ['toml', 'yaml'] +_my_validators = {} - dirname = join(os.environ['HOME'], '.config', name) - if isdir(dirname): - for fmt in formats: - filename = join(dirname, f'config.{fmt}') - if isfile(filename): - return filename +def _get_validator(name: str) -> Optional[callable]: + if hasattr(validators, f'{name}_validator'): + return getattr(validators, f'{name}_validator') + if name in _my_validators: + return _my_validators[name] + return None - raise IOError(f'config not found in {dirname}') - else: - filenames = [join(os.environ['HOME'], '.config', f'{name}.{format}') for format in formats] - for file in filenames: - if isfile(file): - return file +def add_validator(name: str, f: callable): + _my_validators[name] = f - raise IOError(f'config not found') +class ConfigUnit: + NAME = 'dumb' -class ConfigStore: data: MutableMapping[str, Any] + + @classmethod + def get_config_path(cls, name=None) -> str: + if name is None: + name = cls.NAME + if name is None: + raise ValueError('get_config_path: name is none') + + dirnames = ( + join(os.environ['HOME'], '.config', 'homekit'), + '/etc/homekit' + ) + + for dirname in dirnames: + if isdir(dirname): + for fmt in ('toml', 'yaml'): + filename = join(dirname, f'{name}.{fmt}') + if isfile(filename): + return filename + + raise IOError(f'config \'{name}\' not found') + + def __init__(self, name=None): + self.data = {} + + if self.NAME != 'dumb': + self.load_from(self.get_config_path()) + self.validate() + + elif name is not None: + self.NAME = name + + def load_from(self, path: str): + if path.endswith('.toml'): + self.data = toml.load(path) + elif path.endswith('.yaml'): + with open(path, 'r') as fd: + self.data = yaml.safe_load(fd) + + def validate(self): + v = _get_validator(self.NAME) + v(self.data) + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + raise NotImplementedError('overwriting config values is prohibited') + + def __contains__(self, key): + return key in self.data + + def get(self, key: str, default=None): + cur = self.data + pts = key.split('.') + for i in range(len(pts)): + k = pts[i] + if i < len(pts)-1: + if k not in cur: + raise KeyError(f'key {k} not found') + else: + return cur[k] if k in cur else default + cur = self.data[k] + raise KeyError(f'option {key} not found') + + def get_addr(self, key: str): + return parse_addr(self.get(key)) + + def items(self): + return self.data.items() + + +class Config: app_name: Optional[str] + app_config: ConfigUnit def __init__(self): - self.data = {} self.app_name = None + self.app_config = ConfigUnit() - def load(self, name: Optional[str] = None, - use_cli=True, - parser: ArgumentParser = None): + def load_app(self, + name: Optional[str] = None, + use_cli=True, + parser: ArgumentParser = None): self.app_name = name if (name is None) and (not use_cli): @@ -75,65 +146,32 @@ class ConfigStore: log_default_fmt = args.log_default_fmt if not no_config and path is None: - path = _get_config_path(name) - - if no_config: - self.data = {} - else: - if path.endswith('.toml'): - self.data = toml.load(path) - elif path.endswith('.yaml'): - with open(path, 'r') as fd: - self.data = yaml.safe_load(fd) - - if 'logging' in self: - if not log_file and 'file' in self['logging']: - log_file = self['logging']['file'] - if log_default_fmt and 'default_fmt' in self['logging']: - log_default_fmt = self['logging']['default_fmt'] + path = ConfigUnit.get_config_path(name=name) + + if not no_config: + self.app_config.load_from(path) + + if 'logging' in self.app_config: + if not log_file and 'file' in self.app_config['logging']: + log_file = self.app_config['logging']['file'] + if log_default_fmt and 'default_fmt' in self.app_config['logging']: + log_default_fmt = self.app_config['logging']['default_fmt'] setup_logging(log_verbose, log_file, log_default_fmt) if use_cli: return args - def __getitem__(self, key): - return self.data[key] - - def __setitem__(self, key, value): - raise NotImplementedError('overwriting config values is prohibited') - - def __contains__(self, key): - return key in self.data - - def get(self, key: str, default=None): - cur = self.data - pts = key.split('.') - for i in range(len(pts)): - k = pts[i] - if i < len(pts)-1: - if k not in cur: - raise KeyError(f'key {k} not found') - else: - return cur[k] if k in cur else default - cur = self.data[k] - raise KeyError(f'option {key} not found') - - def get_addr(self, key: str): - return parse_addr(self.get(key)) - - def items(self): - return self.data.items() - -config = ConfigStore() +config = Config() +app_config = config.app_config def is_development_mode() -> bool: if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev': return True - return ('logging' in config) and ('verbose' in config['logging']) and (config['logging']['verbose'] is True) + return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True) def setup_logging(verbose=False, log_file=None, default_fmt=False): |