summaryrefslogtreecommitdiff
path: root/include/py/homekit/config/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'include/py/homekit/config/config.py')
-rw-r--r--include/py/homekit/config/config.py387
1 files changed, 387 insertions, 0 deletions
diff --git a/include/py/homekit/config/config.py b/include/py/homekit/config/config.py
new file mode 100644
index 0000000..7344386
--- /dev/null
+++ b/include/py/homekit/config/config.py
@@ -0,0 +1,387 @@
+import yaml
+import logging
+import os
+import cerberus
+import cerberus.errors
+
+from abc import ABC
+from typing import Optional, Any, MutableMapping, Union
+from argparse import ArgumentParser
+from enum import Enum, auto
+from os.path import join, isdir, isfile
+from ..util import Addr
+
+
+class MyValidator(cerberus.Validator):
+ def _normalize_coerce_addr(self, value):
+ return Addr.fromstring(value)
+
+
+MyValidator.types_mapping['addr'] = cerberus.TypeDefinition('Addr', (Addr,), ())
+
+
+CONFIG_DIRECTORIES = (
+ join(os.environ['HOME'], '.config', 'homekit'),
+ '/etc/homekit'
+)
+
+
+class RootSchemaType(Enum):
+ DEFAULT = auto()
+ DICT = auto()
+ LIST = auto()
+
+
+class BaseConfigUnit(ABC):
+ _data: MutableMapping[str, Any]
+ _logger: logging.Logger
+
+ def __init__(self):
+ self._data = {}
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ def __getitem__(self, key):
+ return self._data[key]
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError('overwriting config values is prohibited')
+
+ def __contains__(self, key):
+ return key in self._data
+
+ def load_from(self, path: str):
+ with open(path, 'r') as fd:
+ self._data = yaml.safe_load(fd)
+
+ def get(self,
+ key: Optional[str] = None,
+ default=None):
+ if key is None:
+ return self._data
+
+ cur = self._data
+ pts = key.split('.')
+ for i in range(len(pts)):
+ k = pts[i]
+ if i < len(pts)-1:
+ if k not in cur:
+ raise KeyError(f'key {k} not found')
+ else:
+ return cur[k] if k in cur else default
+ cur = self._data[k]
+
+ raise KeyError(f'option {key} not found')
+
+
+class ConfigUnit(BaseConfigUnit):
+ NAME = 'dumb'
+
+ def __init__(self, name=None, load=True):
+ super().__init__()
+
+ self._data = {}
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ if self.NAME != 'dumb' and load:
+ self.load_from(self.get_config_path())
+ self.validate()
+
+ elif name is not None:
+ self.NAME = name
+
+ @classmethod
+ def get_config_path(cls, name=None) -> str:
+ if name is None:
+ name = cls.NAME
+ if name is None:
+ raise ValueError('get_config_path: name is none')
+
+ for dirname in CONFIG_DIRECTORIES:
+ if isdir(dirname):
+ filename = join(dirname, f'{name}.yaml')
+ if isfile(filename):
+ return filename
+
+ raise IOError(f'\'{name}.yaml\' not found')
+
+ @classmethod
+ def schema(cls) -> Optional[dict]:
+ return None
+
+ @classmethod
+ def _addr_schema(cls, required=False, **kwargs):
+ return {
+ 'type': 'addr',
+ 'coerce': Addr.fromstring,
+ 'required': required,
+ **kwargs
+ }
+
+ def validate(self):
+ schema = self.schema()
+ if not schema:
+ self._logger.warning('validate: no schema')
+ return
+
+ if isinstance(self, AppConfigUnit):
+ schema['logging'] = {
+ 'type': 'dict',
+ 'schema': {
+ 'logging': {'type': 'boolean'}
+ }
+ }
+
+ rst = RootSchemaType.DEFAULT
+ try:
+ if schema['type'] == 'dict':
+ rst = RootSchemaType.DICT
+ elif schema['type'] == 'list':
+ rst = RootSchemaType.LIST
+ elif schema['roottype'] == 'dict':
+ del schema['roottype']
+ rst = RootSchemaType.DICT
+ except KeyError:
+ pass
+
+ v = MyValidator()
+
+ if rst == RootSchemaType.DICT:
+ normalized = v.validated({'document': self._data},
+ {'document': {
+ 'type': 'dict',
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': schema
+ }})['document']
+ elif rst == RootSchemaType.LIST:
+ v = MyValidator()
+ normalized = v.validated({'document': self._data}, {'document': schema})['document']
+ else:
+ normalized = v.validated(self._data, schema)
+
+ self._data = normalized
+
+ try:
+ self.custom_validator(self._data)
+ except Exception as e:
+ raise cerberus.DocumentError(f'{self.__class__.__name__}: {str(e)}')
+
+ @staticmethod
+ def custom_validator(data):
+ pass
+
+ def get_addr(self, key: str):
+ return Addr.fromstring(self.get(key))
+
+
+class AppConfigUnit(ConfigUnit):
+ _logging_verbose: bool
+ _logging_fmt: Optional[str]
+ _logging_file: Optional[str]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(load=False, *args, **kwargs)
+ self._logging_verbose = False
+ self._logging_fmt = None
+ self._logging_file = None
+
+ def logging_set_fmt(self, fmt: str) -> None:
+ self._logging_fmt = fmt
+
+ def logging_get_fmt(self) -> Optional[str]:
+ try:
+ return self['logging']['default_fmt']
+ except KeyError:
+ return self._logging_fmt
+
+ def logging_set_file(self, file: str) -> None:
+ self._logging_file = file
+
+ def logging_get_file(self) -> Optional[str]:
+ try:
+ return self['logging']['file']
+ except KeyError:
+ return self._logging_file
+
+ def logging_set_verbose(self):
+ self._logging_verbose = True
+
+ def logging_is_verbose(self) -> bool:
+ try:
+ return bool(self['logging']['verbose'])
+ except KeyError:
+ return self._logging_verbose
+
+
+class TranslationUnit(BaseConfigUnit):
+ pass
+
+
+class Translation:
+ LANGUAGES = ('en', 'ru')
+ _langs: dict[str, TranslationUnit]
+
+ def __init__(self, name: str):
+ super().__init__()
+ self._langs = {}
+ for lang in self.LANGUAGES:
+ for dirname in CONFIG_DIRECTORIES:
+ if isdir(dirname):
+ filename = join(dirname, f'i18n-{lang}', f'{name}.yaml')
+ if lang in self._langs:
+ raise RuntimeError(f'{name}: translation unit for lang \'{lang}\' already loaded')
+ self._langs[lang] = TranslationUnit()
+ self._langs[lang].load_from(filename)
+ diff = set()
+ for data in self._langs.values():
+ diff ^= data.get().keys()
+ if len(diff) > 0:
+ raise RuntimeError(f'{name}: translation units have difference in keys: ' + ', '.join(diff))
+
+ def get(self, lang: str) -> TranslationUnit:
+ return self._langs[lang]
+
+
+class Config:
+ app_name: Optional[str]
+ app_config: AppConfigUnit
+
+ def __init__(self):
+ self.app_name = None
+ self.app_config = AppConfigUnit()
+
+ def load_app(self,
+ name: Optional[Union[str, AppConfigUnit, bool]] = None,
+ use_cli=True,
+ parser: ArgumentParser = None,
+ no_config=False):
+ global app_config
+
+ if not no_config \
+ and not isinstance(name, str) \
+ and not isinstance(name, bool) \
+ and issubclass(name, AppConfigUnit) or name == AppConfigUnit:
+ self.app_name = name.NAME
+ self.app_config = name()
+ app_config = self.app_config
+ else:
+ self.app_name = name if isinstance(name, str) else None
+
+ if self.app_name is None and not use_cli:
+ raise RuntimeError('either config name must be none or use_cli must be True')
+
+ no_config = name is False or no_config
+ path = None
+
+ if use_cli:
+ if parser is None:
+ parser = ArgumentParser()
+ if not no_config:
+ parser.add_argument('-c', '--config', type=str, required=name is None,
+ help='Path to the config in TOML or YAML format')
+ parser.add_argument('-V', '--verbose', action='store_true')
+ parser.add_argument('--log-file', type=str)
+ parser.add_argument('--log-default-fmt', action='store_true')
+ args = parser.parse_args()
+
+ if not no_config and args.config:
+ path = args.config
+
+ if args.verbose:
+ self.app_config.logging_set_verbose()
+ if args.log_file:
+ self.app_config.logging_set_file(args.log_file)
+ if args.log_default_fmt:
+ self.app_config.logging_set_fmt(args.log_default_fmt)
+
+ if not isinstance(name, ConfigUnit):
+ if not no_config and path is None:
+ path = ConfigUnit.get_config_path(name=self.app_name)
+
+ if not no_config:
+ self.app_config.load_from(path)
+ self.app_config.validate()
+
+ setup_logging(self.app_config.logging_is_verbose(),
+ self.app_config.logging_get_file(),
+ self.app_config.logging_get_fmt())
+
+ if use_cli:
+ return args
+
+
+config = Config()
+
+
+def is_development_mode() -> bool:
+ if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev':
+ return True
+
+ return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True)
+
+
+def setup_logging(verbose=False, log_file=None, default_fmt=None):
+ logging_level = logging.INFO
+ if is_development_mode() or verbose:
+ logging_level = logging.DEBUG
+ _add_logging_level('TRACE', logging.DEBUG-5)
+
+ log_config = {'level': logging_level}
+ if not default_fmt:
+ log_config['format'] = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+ if log_file is not None:
+ log_config['filename'] = log_file
+ log_config['encoding'] = 'utf-8'
+
+ logging.basicConfig(**log_config)
+
+
+# https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility/35804945#35804945
+def _add_logging_level(levelName, levelNum, methodName=None):
+ """
+ Comprehensively adds a new logging level to the `logging` module and the
+ currently configured logging class.
+
+ `levelName` becomes an attribute of the `logging` module with the value
+ `levelNum`. `methodName` becomes a convenience method for both `logging`
+ itself and the class returned by `logging.getLoggerClass()` (usually just
+ `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
+ used.
+
+ To avoid accidental clobberings of existing attributes, this method will
+ raise an `AttributeError` if the level name is already an attribute of the
+ `logging` module or if the method name is already present
+
+ Example
+ -------
+ >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
+ >>> logging.getLogger(__name__).setLevel("TRACE")
+ >>> logging.getLogger(__name__).trace('that worked')
+ >>> logging.trace('so did this')
+ >>> logging.TRACE
+ 5
+
+ """
+ if not methodName:
+ methodName = levelName.lower()
+
+ if hasattr(logging, levelName):
+ raise AttributeError('{} already defined in logging module'.format(levelName))
+ if hasattr(logging, methodName):
+ raise AttributeError('{} already defined in logging module'.format(methodName))
+ if hasattr(logging.getLoggerClass(), methodName):
+ raise AttributeError('{} already defined in logger class'.format(methodName))
+
+ # This method was inspired by the answers to Stack Overflow post
+ # http://stackoverflow.com/q/2183233/2988730, especially
+ # http://stackoverflow.com/a/13638084/2988730
+ def logForLevel(self, message, *args, **kwargs):
+ if self.isEnabledFor(levelNum):
+ self._log(levelNum, message, args, **kwargs)
+ def logToRoot(message, *args, **kwargs):
+ logging.log(levelNum, message, *args, **kwargs)
+
+ logging.addLevelName(levelNum, levelName)
+ setattr(logging, levelName, levelNum)
+ setattr(logging.getLoggerClass(), methodName, logForLevel)
+ setattr(logging, methodName, logToRoot) \ No newline at end of file