diff options
Diffstat (limited to 'src/home/mqtt/payload/base_payload.py')
-rw-r--r-- | src/home/mqtt/payload/base_payload.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/payload/base_payload.py new file mode 100644 index 0000000..c9ec907 --- /dev/null +++ b/src/home/mqtt/payload/base_payload.py @@ -0,0 +1,129 @@ +import abc +import struct + +from typing import Generic, TypeVar + + +class MQTTPayload(abc.ABC): + FORMAT = '' + PACKER = {} + UNPACKER = {} + + def __init__(self, **kwargs): + for field in self.__class__.__annotations__: + setattr(self, field, kwargs[field]) + + def pack(self): + args = [] + bf_number = -1 + bf_arg = 0 + bf_progress = 0 + + for field, field_type in self.__class__.__annotations__.items(): + field_type_origin = None + if hasattr(field_type, '__extra__') or hasattr(field_type, '__origin__'): + try: + field_type_origin = field_type.__extra__ + except AttributeError: + field_type_origin = field_type.__origin__ + + if field_type_origin is not None and issubclass(field_type_origin, MQTTPayloadBitField): + n, s, b = field_type.__args__ + if n != bf_number: + if bf_number != -1: + args.append(bf_arg) + bf_number = n + bf_progress = 0 + bf_arg = 0 + bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress + bf_progress += b + + else: + if bf_number != -1: + args.append(bf_arg) + bf_number = -1 + bf_progress = 0 + bf_arg = 0 + + args.append(self._pack_field(field)) + + if bf_number != -1: + args.append(bf_arg) + + return struct.pack(self.FORMAT, *args) + + @classmethod + def unpack(cls, buf: bytes): + data = struct.unpack(cls.FORMAT, buf) + kwargs = {} + i = 0 + bf_number = -1 + bf_progress = 0 + + for field, field_type in cls.__annotations__.items(): + field_type_origin = None + if hasattr(field_type, '__extra__') or hasattr(field_type, '__origin__'): + try: + field_type_origin = field_type.__extra__ + except AttributeError: + field_type_origin = field_type.__origin__ + + if field_type_origin is not None and issubclass(field_type_origin, MQTTPayloadBitField): + n, s, b = field_type.__args__ + if n != bf_number: + bf_number = n + bf_progress = 0 + kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) + bf_progress += b + continue # don't increment i + + if bf_number != -1: + bf_number = -1 + i += 1 + + if issubclass(field_type, MQTTPayloadCustomField): + kwargs[field] = field_type.unpack(data[i]) + else: + kwargs[field] = cls._unpack_field(field, data[i]) + i += 1 + return cls(**kwargs) + + def _pack_field(self, name): + val = getattr(self, name) + if self.PACKER and name in self.PACKER: + return self.PACKER[name](val) + else: + return val + + @classmethod + def _unpack_field(cls, name, val): + if isinstance(val, MQTTPayloadCustomField): + return + if cls.UNPACKER and name in cls.UNPACKER: + return cls.UNPACKER[name](val) + else: + return val + + +class MQTTPayloadCustomField(abc.ABC): + def __init__(self, **kwargs): + for field in self.__class__.__annotations__: + setattr(self, field, kwargs[field]) + + @abc.abstractmethod + def __index__(self): + pass + + @classmethod + @abc.abstractmethod + def unpack(cls, *args, **kwargs): + pass + + +NT = TypeVar('NT') # number of bit field +ST = TypeVar('ST') # size in bytes +BT = TypeVar('BT') # size in bits of particular value + + +class MQTTPayloadBitField(int, Generic[NT, ST, BT]): + pass |