summaryrefslogtreecommitdiff
path: root/src/home/config
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/config')
-rw-r--r--src/home/config/__init__.py1
-rw-r--r--src/home/config/config.py143
2 files changed, 86 insertions, 58 deletions
diff --git a/src/home/config/__init__.py b/src/home/config/__init__.py
index 03b9f25..1321047 100644
--- a/src/home/config/__init__.py
+++ b/src/home/config/__init__.py
@@ -2,6 +2,7 @@ from .config import (
Config,
ConfigUnit,
AppConfigUnit,
+ TranslationsUnit,
config,
is_development_mode,
setup_logging,
diff --git a/src/home/config/config.py b/src/home/config/config.py
index e1a089d..37dd5e8 100644
--- a/src/home/config/config.py
+++ b/src/home/config/config.py
@@ -1,9 +1,9 @@
-import toml
import yaml
import logging
import os
import pprint
+from abc import ABC
from cerberus import Validator, DocumentError
from typing import Optional, Any, MutableMapping, Union
from argparse import ArgumentParser
@@ -12,43 +12,65 @@ from os.path import join, isdir, isfile
from ..util import parse_addr
+SUPPORTED_LANGUAGES = ('en', 'ru')
+CONFIG_DIRECTORIES = (
+ join(os.environ['HOME'], '.config', 'homekit'),
+ '/etc/homekit'
+)
+
class RootSchemaType(Enum):
DEFAULT = auto()
DICT = auto()
LIST = auto()
-class ConfigUnit:
- NAME = 'dumb'
-
+class BaseConfigUnit(ABC):
_data: MutableMapping[str, Any]
+ _logger: logging.Logger
- @classmethod
- def get_config_path(cls, name=None) -> str:
- if name is None:
- name = cls.NAME
- if name is None:
- raise ValueError('get_config_path: name is none')
+ def __init__(self, name=None):
+ self._data = {}
+ self._logger = logging.getLogger(self.__class__.__name__)
- dirnames = (
- join(os.environ['HOME'], '.config', 'homekit'),
- '/etc/homekit'
- )
+ def __getitem__(self, key):
+ return self._data[key]
- for dirname in dirnames:
- if isdir(dirname):
- for fmt in ('toml', 'yaml'):
- filename = join(dirname, f'{name}.{fmt}')
- if isfile(filename):
- return filename
+ def __setitem__(self, key, value):
+ raise NotImplementedError('overwriting config values is prohibited')
- raise IOError(f'config file for \'{name}\' not found')
+ def __contains__(self, key):
+ return key in self._data
- @staticmethod
- def schema() -> Optional[dict]:
- return None
+ def load_from(self, path: str):
+ with open(path, 'r') as fd:
+ self._data = yaml.safe_load(fd)
+
+ def get(self,
+ key: Optional[str] = None,
+ default=None):
+ if key is None:
+ return self._data
+
+ cur = self._data
+ pts = key.split('.')
+ for i in range(len(pts)):
+ k = pts[i]
+ if i < len(pts)-1:
+ if k not in cur:
+ raise KeyError(f'key {k} not found')
+ else:
+ return cur[k] if k in cur else default
+ cur = self._data[k]
+
+ raise KeyError(f'option {key} not found')
+
+
+class ConfigUnit(BaseConfigUnit):
+ NAME = 'dumb'
def __init__(self, name=None):
+ super().__init__()
+
self._data = {}
self._logger = logging.getLogger(self.__class__.__name__)
@@ -59,12 +81,24 @@ class ConfigUnit:
elif name is not None:
self.NAME = name
- def load_from(self, path: str):
- if path.endswith('.toml'):
- self._data = toml.load(path)
- elif path.endswith('.yaml'):
- with open(path, 'r') as fd:
- self._data = yaml.safe_load(fd)
+ @classmethod
+ def get_config_path(cls, name=None) -> str:
+ if name is None:
+ name = cls.NAME
+ if name is None:
+ raise ValueError('get_config_path: name is none')
+
+ for dirname in CONFIG_DIRECTORIES:
+ if isdir(dirname):
+ filename = join(dirname, f'{name}.yaml')
+ if isfile(filename):
+ return filename
+
+ raise IOError(f'\'{name}\'.yaml not found')
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return None
def validate(self):
schema = self.schema()
@@ -115,34 +149,6 @@ class ConfigUnit:
def custom_validator(data):
pass
- def __getitem__(self, key):
- return self._data[key]
-
- def __setitem__(self, key, value):
- raise NotImplementedError('overwriting config values is prohibited')
-
- def __contains__(self, key):
- return key in self._data
-
- def get(self,
- key: Optional[str] = None,
- default=None):
- if key is None:
- return self._data
-
- cur = self._data
- pts = key.split('.')
- for i in range(len(pts)):
- k = pts[i]
- if i < len(pts)-1:
- if k not in cur:
- raise KeyError(f'key {k} not found')
- else:
- return cur[k] if k in cur else default
- cur = self._data[k]
-
- raise KeyError(f'option {key} not found')
-
def get_addr(self, key: str):
return parse_addr(self.get(key))
@@ -189,6 +195,27 @@ class AppConfigUnit(ConfigUnit):
return self._logging_verbose
+class TranslationsUnit(BaseConfigUnit):
+ _lang: str
+ _name: str
+
+ def __init__(self,
+ lang: str,
+ name: str):
+ super().__init__()
+ self._lang = lang
+ self._name = name
+
+ for dirname in CONFIG_DIRECTORIES:
+ if isdir(dirname):
+ filename = join(dirname, f'i18n-{lang}', f'{name}.yaml')
+ if isfile(filename):
+ self.load_from(filename)
+ break
+
+ raise IOError(f'i18n-{lang}/{name}.yaml not found')
+
+
class Config:
app_name: Optional[str]
app_config: AppConfigUnit