summaryrefslogtreecommitdiff
path: root/src/home/config/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/config/config.py')
-rw-r--r--src/home/config/config.py170
1 files changed, 104 insertions, 66 deletions
diff --git a/src/home/config/config.py b/src/home/config/config.py
index 4681685..2d49524 100644
--- a/src/home/config/config.py
+++ b/src/home/config/config.py
@@ -3,45 +3,116 @@ import yaml
import logging
import os
+from . import validators
from os.path import join, isdir, isfile
from typing import Optional, Any, MutableMapping
from argparse import ArgumentParser
from ..util import parse_addr
-def _get_config_path(name: str) -> str:
- formats = ['toml', 'yaml']
+_my_validators = {}
- dirname = join(os.environ['HOME'], '.config', name)
- if isdir(dirname):
- for fmt in formats:
- filename = join(dirname, f'config.{fmt}')
- if isfile(filename):
- return filename
+def _get_validator(name: str) -> Optional[callable]:
+ if hasattr(validators, f'{name}_validator'):
+ return getattr(validators, f'{name}_validator')
+ if name in _my_validators:
+ return _my_validators[name]
+ return None
- raise IOError(f'config not found in {dirname}')
- else:
- filenames = [join(os.environ['HOME'], '.config', f'{name}.{format}') for format in formats]
- for file in filenames:
- if isfile(file):
- return file
+def add_validator(name: str, f: callable):
+ _my_validators[name] = f
- raise IOError(f'config not found')
+class ConfigUnit:
+ NAME = 'dumb'
-class ConfigStore:
data: MutableMapping[str, Any]
+
+ @classmethod
+ def get_config_path(cls, name=None) -> str:
+ if name is None:
+ name = cls.NAME
+ if name is None:
+ raise ValueError('get_config_path: name is none')
+
+ dirnames = (
+ join(os.environ['HOME'], '.config', 'homekit'),
+ '/etc/homekit'
+ )
+
+ for dirname in dirnames:
+ if isdir(dirname):
+ for fmt in ('toml', 'yaml'):
+ filename = join(dirname, f'{name}.{fmt}')
+ if isfile(filename):
+ return filename
+
+ raise IOError(f'config \'{name}\' not found')
+
+ def __init__(self, name=None):
+ self.data = {}
+
+ if self.NAME != 'dumb':
+ self.load_from(self.get_config_path())
+ self.validate()
+
+ elif name is not None:
+ self.NAME = name
+
+ def load_from(self, path: str):
+ if path.endswith('.toml'):
+ self.data = toml.load(path)
+ elif path.endswith('.yaml'):
+ with open(path, 'r') as fd:
+ self.data = yaml.safe_load(fd)
+
+ def validate(self):
+ v = _get_validator(self.NAME)
+ v(self.data)
+
+ def __getitem__(self, key):
+ return self.data[key]
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError('overwriting config values is prohibited')
+
+ def __contains__(self, key):
+ return key in self.data
+
+ def get(self, key: str, default=None):
+ cur = self.data
+ pts = key.split('.')
+ for i in range(len(pts)):
+ k = pts[i]
+ if i < len(pts)-1:
+ if k not in cur:
+ raise KeyError(f'key {k} not found')
+ else:
+ return cur[k] if k in cur else default
+ cur = self.data[k]
+ raise KeyError(f'option {key} not found')
+
+ def get_addr(self, key: str):
+ return parse_addr(self.get(key))
+
+ def items(self):
+ return self.data.items()
+
+
+class Config:
app_name: Optional[str]
+ app_config: ConfigUnit
def __init__(self):
- self.data = {}
self.app_name = None
+ self.app_config = ConfigUnit()
- def load(self, name: Optional[str] = None,
- use_cli=True,
- parser: ArgumentParser = None):
+ def load_app(self,
+ name: Optional[str] = None,
+ use_cli=True,
+ parser: ArgumentParser = None):
self.app_name = name
if (name is None) and (not use_cli):
@@ -75,65 +146,32 @@ class ConfigStore:
log_default_fmt = args.log_default_fmt
if not no_config and path is None:
- path = _get_config_path(name)
-
- if no_config:
- self.data = {}
- else:
- if path.endswith('.toml'):
- self.data = toml.load(path)
- elif path.endswith('.yaml'):
- with open(path, 'r') as fd:
- self.data = yaml.safe_load(fd)
-
- if 'logging' in self:
- if not log_file and 'file' in self['logging']:
- log_file = self['logging']['file']
- if log_default_fmt and 'default_fmt' in self['logging']:
- log_default_fmt = self['logging']['default_fmt']
+ path = ConfigUnit.get_config_path(name=name)
+
+ if not no_config:
+ self.app_config.load_from(path)
+
+ if 'logging' in self.app_config:
+ if not log_file and 'file' in self.app_config['logging']:
+ log_file = self.app_config['logging']['file']
+ if log_default_fmt and 'default_fmt' in self.app_config['logging']:
+ log_default_fmt = self.app_config['logging']['default_fmt']
setup_logging(log_verbose, log_file, log_default_fmt)
if use_cli:
return args
- def __getitem__(self, key):
- return self.data[key]
-
- def __setitem__(self, key, value):
- raise NotImplementedError('overwriting config values is prohibited')
-
- def __contains__(self, key):
- return key in self.data
-
- def get(self, key: str, default=None):
- cur = self.data
- pts = key.split('.')
- for i in range(len(pts)):
- k = pts[i]
- if i < len(pts)-1:
- if k not in cur:
- raise KeyError(f'key {k} not found')
- else:
- return cur[k] if k in cur else default
- cur = self.data[k]
- raise KeyError(f'option {key} not found')
-
- def get_addr(self, key: str):
- return parse_addr(self.get(key))
-
- def items(self):
- return self.data.items()
-
-config = ConfigStore()
+config = Config()
+app_config = config.app_config
def is_development_mode() -> bool:
if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev':
return True
- return ('logging' in config) and ('verbose' in config['logging']) and (config['logging']['verbose'] is True)
+ return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True)
def setup_logging(verbose=False, log_file=None, default_fmt=False):